261 lines
		
	
	
		
			5.1 KiB
		
	
	
	
		
			Go
		
	
	
			
		
		
	
	
			261 lines
		
	
	
		
			5.1 KiB
		
	
	
	
		
			Go
		
	
	
// Copyright 2009 The Go9p Authors.  All rights reserved.
 | 
						|
// Use of this source code is governed by a BSD-style
 | 
						|
// license that can be found in the LICENSE file.
 | 
						|
 | 
						|
package go9p
 | 
						|
 | 
						|
import (
 | 
						|
	"fmt"
 | 
						|
	"log"
 | 
						|
	"net"
 | 
						|
)
 | 
						|
 | 
						|
func (srv *Srv) NewConn(c net.Conn) {
 | 
						|
	conn := new(Conn)
 | 
						|
	conn.Srv = srv
 | 
						|
	conn.Msize = srv.Msize
 | 
						|
	conn.Dotu = srv.Dotu
 | 
						|
	conn.Debuglevel = srv.Debuglevel
 | 
						|
	conn.conn = c
 | 
						|
	conn.fidpool = make(map[uint32]*SrvFid)
 | 
						|
	conn.reqs = make(map[uint16]*SrvReq)
 | 
						|
	conn.reqout = make(chan *SrvReq, srv.Maxpend)
 | 
						|
	conn.done = make(chan bool)
 | 
						|
	conn.rchan = make(chan *Fcall, 64)
 | 
						|
 | 
						|
	srv.Lock()
 | 
						|
	if srv.conns == nil {
 | 
						|
		srv.conns = make(map[*Conn]*Conn)
 | 
						|
	}
 | 
						|
	srv.conns[conn] = conn
 | 
						|
	srv.Unlock()
 | 
						|
 | 
						|
	conn.Id = c.RemoteAddr().String()
 | 
						|
	if op, ok := (conn.Srv.ops).(ConnOps); ok {
 | 
						|
		op.ConnOpened(conn)
 | 
						|
	}
 | 
						|
 | 
						|
	if sop, ok := (interface{}(conn)).(StatsOps); ok {
 | 
						|
		sop.statsRegister()
 | 
						|
	}
 | 
						|
 | 
						|
	go conn.recv()
 | 
						|
	go conn.send()
 | 
						|
}
 | 
						|
 | 
						|
func (conn *Conn) close() {
 | 
						|
	conn.done <- true
 | 
						|
	conn.Srv.Lock()
 | 
						|
	delete(conn.Srv.conns, conn)
 | 
						|
	conn.Srv.Unlock()
 | 
						|
 | 
						|
	if sop, ok := (interface{}(conn)).(StatsOps); ok {
 | 
						|
		sop.statsUnregister()
 | 
						|
	}
 | 
						|
	if op, ok := (conn.Srv.ops).(ConnOps); ok {
 | 
						|
		op.ConnClosed(conn)
 | 
						|
	}
 | 
						|
 | 
						|
	/* call FidDestroy for all remaining fids */
 | 
						|
	if op, ok := (conn.Srv.ops).(SrvFidOps); ok {
 | 
						|
		for _, fid := range conn.fidpool {
 | 
						|
			op.FidDestroy(fid)
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (conn *Conn) recv() {
 | 
						|
	var err error
 | 
						|
	var n int
 | 
						|
 | 
						|
	buf := make([]byte, conn.Msize*8)
 | 
						|
	pos := 0
 | 
						|
	for {
 | 
						|
		if len(buf) < int(conn.Msize) {
 | 
						|
			b := make([]byte, conn.Msize*8)
 | 
						|
			copy(b, buf[0:pos])
 | 
						|
			buf = b
 | 
						|
			b = nil
 | 
						|
		}
 | 
						|
 | 
						|
		n, err = conn.conn.Read(buf[pos:])
 | 
						|
		if err != nil || n == 0 {
 | 
						|
			conn.close()
 | 
						|
			return
 | 
						|
		}
 | 
						|
 | 
						|
		pos += n
 | 
						|
		for pos > 4 {
 | 
						|
			sz, _ := Gint32(buf)
 | 
						|
			if sz > conn.Msize {
 | 
						|
				log.Println("bad client connection: ", conn.conn.RemoteAddr())
 | 
						|
				conn.conn.Close()
 | 
						|
				conn.close()
 | 
						|
				return
 | 
						|
			}
 | 
						|
			if pos < int(sz) {
 | 
						|
				if len(buf) < int(sz) {
 | 
						|
					b := make([]byte, conn.Msize*8)
 | 
						|
					copy(b, buf[0:pos])
 | 
						|
					buf = b
 | 
						|
					b = nil
 | 
						|
				}
 | 
						|
 | 
						|
				break
 | 
						|
			}
 | 
						|
			fc, err, fcsize := Unpack(buf, conn.Dotu)
 | 
						|
			if err != nil {
 | 
						|
				log.Println(fmt.Sprintf("invalid packet : %v %v", err, buf))
 | 
						|
				conn.conn.Close()
 | 
						|
				conn.close()
 | 
						|
				return
 | 
						|
			}
 | 
						|
 | 
						|
			tag := fc.Tag
 | 
						|
			req := new(SrvReq)
 | 
						|
			select {
 | 
						|
			case req.Rc = <-conn.rchan:
 | 
						|
				break
 | 
						|
			default:
 | 
						|
				req.Rc = NewFcall(conn.Msize)
 | 
						|
			}
 | 
						|
 | 
						|
			req.Conn = conn
 | 
						|
			req.Tc = fc
 | 
						|
			//			req.Rc = rc
 | 
						|
			if conn.Debuglevel > 0 {
 | 
						|
				conn.logFcall(req.Tc)
 | 
						|
				if conn.Debuglevel&DbgPrintPackets != 0 {
 | 
						|
					log.Println(">->", conn.Id, fmt.Sprint(req.Tc.Pkt))
 | 
						|
				}
 | 
						|
 | 
						|
				if conn.Debuglevel&DbgPrintFcalls != 0 {
 | 
						|
					log.Println(">>>", conn.Id, req.Tc.String())
 | 
						|
				}
 | 
						|
			}
 | 
						|
 | 
						|
			conn.Lock()
 | 
						|
			conn.nreqs++
 | 
						|
			conn.tsz += uint64(fc.Size)
 | 
						|
			conn.npend++
 | 
						|
			if conn.npend > conn.maxpend {
 | 
						|
				conn.maxpend = conn.npend
 | 
						|
			}
 | 
						|
 | 
						|
			req.next = conn.reqs[tag]
 | 
						|
			conn.reqs[tag] = req
 | 
						|
			process := req.next == nil
 | 
						|
			if req.next != nil {
 | 
						|
				req.next.prev = req
 | 
						|
			}
 | 
						|
			conn.Unlock()
 | 
						|
			if process {
 | 
						|
				// Tversion may change some attributes of the
 | 
						|
				// connection, so we block on it. Otherwise,
 | 
						|
				// we may loop back to reading and that is a race.
 | 
						|
				// This fix brought to you by the race detector.
 | 
						|
				if req.Tc.Type == Tversion {
 | 
						|
					req.process()
 | 
						|
				} else {
 | 
						|
					go req.process()
 | 
						|
				}
 | 
						|
			}
 | 
						|
 | 
						|
			buf = buf[fcsize:]
 | 
						|
			pos -= fcsize
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
}
 | 
						|
 | 
						|
func (conn *Conn) send() {
 | 
						|
	for {
 | 
						|
		select {
 | 
						|
		case <-conn.done:
 | 
						|
			return
 | 
						|
 | 
						|
		case req := <-conn.reqout:
 | 
						|
			SetTag(req.Rc, req.Tc.Tag)
 | 
						|
			conn.Lock()
 | 
						|
			conn.rsz += uint64(req.Rc.Size)
 | 
						|
			conn.npend--
 | 
						|
			conn.Unlock()
 | 
						|
			if conn.Debuglevel > 0 {
 | 
						|
				conn.logFcall(req.Rc)
 | 
						|
				if conn.Debuglevel&DbgPrintPackets != 0 {
 | 
						|
					log.Println("<-<", conn.Id, fmt.Sprint(req.Rc.Pkt))
 | 
						|
				}
 | 
						|
 | 
						|
				if conn.Debuglevel&DbgPrintFcalls != 0 {
 | 
						|
					log.Println("<<<", conn.Id, req.Rc.String())
 | 
						|
				}
 | 
						|
			}
 | 
						|
 | 
						|
			for buf := req.Rc.Pkt; len(buf) > 0; {
 | 
						|
				n, err := conn.conn.Write(buf)
 | 
						|
				if err != nil {
 | 
						|
					/* just close the socket, will get signal on conn.done */
 | 
						|
					log.Println("error while writing")
 | 
						|
					conn.conn.Close()
 | 
						|
					break
 | 
						|
				}
 | 
						|
 | 
						|
				buf = buf[n:]
 | 
						|
			}
 | 
						|
 | 
						|
			select {
 | 
						|
			case conn.rchan <- req.Rc:
 | 
						|
				break
 | 
						|
			default:
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (conn *Conn) RemoteAddr() net.Addr {
 | 
						|
	return conn.conn.RemoteAddr()
 | 
						|
}
 | 
						|
 | 
						|
func (conn *Conn) LocalAddr() net.Addr {
 | 
						|
	return conn.conn.LocalAddr()
 | 
						|
}
 | 
						|
 | 
						|
func (conn *Conn) logFcall(fc *Fcall) {
 | 
						|
	if conn.Debuglevel&DbgLogPackets != 0 {
 | 
						|
		pkt := make([]byte, len(fc.Pkt))
 | 
						|
		copy(pkt, fc.Pkt)
 | 
						|
		conn.Srv.Log.Log(pkt, conn, DbgLogPackets)
 | 
						|
	}
 | 
						|
 | 
						|
	if conn.Debuglevel&DbgLogFcalls != 0 {
 | 
						|
		f := new(Fcall)
 | 
						|
		*f = *fc
 | 
						|
		f.Pkt = nil
 | 
						|
		conn.Srv.Log.Log(f, conn, DbgLogFcalls)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (srv *Srv) StartNetListener(ntype, addr string) error {
 | 
						|
	l, err := net.Listen(ntype, addr)
 | 
						|
	if err != nil {
 | 
						|
		return &Error{err.Error(), EIO}
 | 
						|
	}
 | 
						|
 | 
						|
	return srv.StartListener(l)
 | 
						|
}
 | 
						|
 | 
						|
// Start listening on the specified network and address for incoming
 | 
						|
// connections. Once a connection is established, create a new Conn
 | 
						|
// value, read messages from the socket, send them to the specified
 | 
						|
// server, and send back responses received from the server.
 | 
						|
func (srv *Srv) StartListener(l net.Listener) error {
 | 
						|
	for {
 | 
						|
		c, err := l.Accept()
 | 
						|
		if err != nil {
 | 
						|
			return &Error{err.Error(), EIO}
 | 
						|
		}
 | 
						|
 | 
						|
		srv.NewConn(c)
 | 
						|
	}
 | 
						|
}
 |