261 lines
5.1 KiB
Go
261 lines
5.1 KiB
Go
// Copyright 2009 The Go9p Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package go9p
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
)
|
|
|
|
func (srv *Srv) NewConn(c net.Conn) {
|
|
conn := new(Conn)
|
|
conn.Srv = srv
|
|
conn.Msize = srv.Msize
|
|
conn.Dotu = srv.Dotu
|
|
conn.Debuglevel = srv.Debuglevel
|
|
conn.conn = c
|
|
conn.fidpool = make(map[uint32]*SrvFid)
|
|
conn.reqs = make(map[uint16]*SrvReq)
|
|
conn.reqout = make(chan *SrvReq, srv.Maxpend)
|
|
conn.done = make(chan bool)
|
|
conn.rchan = make(chan *Fcall, 64)
|
|
|
|
srv.Lock()
|
|
if srv.conns == nil {
|
|
srv.conns = make(map[*Conn]*Conn)
|
|
}
|
|
srv.conns[conn] = conn
|
|
srv.Unlock()
|
|
|
|
conn.Id = c.RemoteAddr().String()
|
|
if op, ok := (conn.Srv.ops).(ConnOps); ok {
|
|
op.ConnOpened(conn)
|
|
}
|
|
|
|
if sop, ok := (interface{}(conn)).(StatsOps); ok {
|
|
sop.statsRegister()
|
|
}
|
|
|
|
go conn.recv()
|
|
go conn.send()
|
|
}
|
|
|
|
func (conn *Conn) close() {
|
|
conn.done <- true
|
|
conn.Srv.Lock()
|
|
delete(conn.Srv.conns, conn)
|
|
conn.Srv.Unlock()
|
|
|
|
if sop, ok := (interface{}(conn)).(StatsOps); ok {
|
|
sop.statsUnregister()
|
|
}
|
|
if op, ok := (conn.Srv.ops).(ConnOps); ok {
|
|
op.ConnClosed(conn)
|
|
}
|
|
|
|
/* call FidDestroy for all remaining fids */
|
|
if op, ok := (conn.Srv.ops).(SrvFidOps); ok {
|
|
for _, fid := range conn.fidpool {
|
|
op.FidDestroy(fid)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (conn *Conn) recv() {
|
|
var err error
|
|
var n int
|
|
|
|
buf := make([]byte, conn.Msize*8)
|
|
pos := 0
|
|
for {
|
|
if len(buf) < int(conn.Msize) {
|
|
b := make([]byte, conn.Msize*8)
|
|
copy(b, buf[0:pos])
|
|
buf = b
|
|
b = nil
|
|
}
|
|
|
|
n, err = conn.conn.Read(buf[pos:])
|
|
if err != nil || n == 0 {
|
|
conn.close()
|
|
return
|
|
}
|
|
|
|
pos += n
|
|
for pos > 4 {
|
|
sz, _ := Gint32(buf)
|
|
if sz > conn.Msize {
|
|
log.Println("bad client connection: ", conn.conn.RemoteAddr())
|
|
conn.conn.Close()
|
|
conn.close()
|
|
return
|
|
}
|
|
if pos < int(sz) {
|
|
if len(buf) < int(sz) {
|
|
b := make([]byte, conn.Msize*8)
|
|
copy(b, buf[0:pos])
|
|
buf = b
|
|
b = nil
|
|
}
|
|
|
|
break
|
|
}
|
|
fc, err, fcsize := Unpack(buf, conn.Dotu)
|
|
if err != nil {
|
|
log.Println(fmt.Sprintf("invalid packet : %v %v", err, buf))
|
|
conn.conn.Close()
|
|
conn.close()
|
|
return
|
|
}
|
|
|
|
tag := fc.Tag
|
|
req := new(SrvReq)
|
|
select {
|
|
case req.Rc = <-conn.rchan:
|
|
break
|
|
default:
|
|
req.Rc = NewFcall(conn.Msize)
|
|
}
|
|
|
|
req.Conn = conn
|
|
req.Tc = fc
|
|
// req.Rc = rc
|
|
if conn.Debuglevel > 0 {
|
|
conn.logFcall(req.Tc)
|
|
if conn.Debuglevel&DbgPrintPackets != 0 {
|
|
log.Println(">->", conn.Id, fmt.Sprint(req.Tc.Pkt))
|
|
}
|
|
|
|
if conn.Debuglevel&DbgPrintFcalls != 0 {
|
|
log.Println(">>>", conn.Id, req.Tc.String())
|
|
}
|
|
}
|
|
|
|
conn.Lock()
|
|
conn.nreqs++
|
|
conn.tsz += uint64(fc.Size)
|
|
conn.npend++
|
|
if conn.npend > conn.maxpend {
|
|
conn.maxpend = conn.npend
|
|
}
|
|
|
|
req.next = conn.reqs[tag]
|
|
conn.reqs[tag] = req
|
|
process := req.next == nil
|
|
if req.next != nil {
|
|
req.next.prev = req
|
|
}
|
|
conn.Unlock()
|
|
if process {
|
|
// Tversion may change some attributes of the
|
|
// connection, so we block on it. Otherwise,
|
|
// we may loop back to reading and that is a race.
|
|
// This fix brought to you by the race detector.
|
|
if req.Tc.Type == Tversion {
|
|
req.process()
|
|
} else {
|
|
go req.process()
|
|
}
|
|
}
|
|
|
|
buf = buf[fcsize:]
|
|
pos -= fcsize
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
func (conn *Conn) send() {
|
|
for {
|
|
select {
|
|
case <-conn.done:
|
|
return
|
|
|
|
case req := <-conn.reqout:
|
|
SetTag(req.Rc, req.Tc.Tag)
|
|
conn.Lock()
|
|
conn.rsz += uint64(req.Rc.Size)
|
|
conn.npend--
|
|
conn.Unlock()
|
|
if conn.Debuglevel > 0 {
|
|
conn.logFcall(req.Rc)
|
|
if conn.Debuglevel&DbgPrintPackets != 0 {
|
|
log.Println("<-<", conn.Id, fmt.Sprint(req.Rc.Pkt))
|
|
}
|
|
|
|
if conn.Debuglevel&DbgPrintFcalls != 0 {
|
|
log.Println("<<<", conn.Id, req.Rc.String())
|
|
}
|
|
}
|
|
|
|
for buf := req.Rc.Pkt; len(buf) > 0; {
|
|
n, err := conn.conn.Write(buf)
|
|
if err != nil {
|
|
/* just close the socket, will get signal on conn.done */
|
|
log.Println("error while writing")
|
|
conn.conn.Close()
|
|
break
|
|
}
|
|
|
|
buf = buf[n:]
|
|
}
|
|
|
|
select {
|
|
case conn.rchan <- req.Rc:
|
|
break
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (conn *Conn) RemoteAddr() net.Addr {
|
|
return conn.conn.RemoteAddr()
|
|
}
|
|
|
|
func (conn *Conn) LocalAddr() net.Addr {
|
|
return conn.conn.LocalAddr()
|
|
}
|
|
|
|
func (conn *Conn) logFcall(fc *Fcall) {
|
|
if conn.Debuglevel&DbgLogPackets != 0 {
|
|
pkt := make([]byte, len(fc.Pkt))
|
|
copy(pkt, fc.Pkt)
|
|
conn.Srv.Log.Log(pkt, conn, DbgLogPackets)
|
|
}
|
|
|
|
if conn.Debuglevel&DbgLogFcalls != 0 {
|
|
f := new(Fcall)
|
|
*f = *fc
|
|
f.Pkt = nil
|
|
conn.Srv.Log.Log(f, conn, DbgLogFcalls)
|
|
}
|
|
}
|
|
|
|
func (srv *Srv) StartNetListener(ntype, addr string) error {
|
|
l, err := net.Listen(ntype, addr)
|
|
if err != nil {
|
|
return &Error{err.Error(), EIO}
|
|
}
|
|
|
|
return srv.StartListener(l)
|
|
}
|
|
|
|
// Start listening on the specified network and address for incoming
|
|
// connections. Once a connection is established, create a new Conn
|
|
// value, read messages from the socket, send them to the specified
|
|
// server, and send back responses received from the server.
|
|
func (srv *Srv) StartListener(l net.Listener) error {
|
|
for {
|
|
c, err := l.Accept()
|
|
if err != nil {
|
|
return &Error{err.Error(), EIO}
|
|
}
|
|
|
|
srv.NewConn(c)
|
|
}
|
|
}
|