minikube/third_party/go9p/srv_conn.go

261 lines
5.1 KiB
Go

// Copyright 2009 The Go9p Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package go9p
import (
"fmt"
"log"
"net"
)
// NewConn creates a Conn for the network connection c, registers it with
// srv, notifies the optional ConnOps and StatsOps hooks, and starts the
// goroutines that receive requests from and send responses to the peer.
func (srv *Srv) NewConn(c net.Conn) {
	nc := &Conn{}

	// Settings inherited from the server.
	nc.Srv = srv
	nc.Msize = srv.Msize
	nc.Dotu = srv.Dotu
	nc.Debuglevel = srv.Debuglevel

	// Per-connection state.
	nc.conn = c
	nc.fidpool = make(map[uint32]*SrvFid)
	nc.reqs = make(map[uint16]*SrvReq)
	nc.reqout = make(chan *SrvReq, srv.Maxpend)
	nc.rchan = make(chan *Fcall, 64)
	nc.done = make(chan bool)

	// Register the connection; the server's map is created on first use.
	srv.Lock()
	if srv.conns == nil {
		srv.conns = make(map[*Conn]*Conn)
	}
	srv.conns[nc] = nc
	srv.Unlock()

	nc.Id = c.RemoteAddr().String()

	if hooks, ok := nc.Srv.ops.(ConnOps); ok {
		hooks.ConnOpened(nc)
	}

	var self interface{} = nc
	if stats, ok := self.(StatsOps); ok {
		stats.statsRegister()
	}

	go nc.recv()
	go nc.send()
}
// close tears down the server-side state of the connection: it signals the
// send goroutine through conn.done, removes the connection from the
// server's connection map, and invokes the optional StatsOps, ConnOps and
// SrvFidOps hooks. It does not close the underlying network connection.
func (conn *Conn) close() {
	// done is received by send(); this blocks until the signal is taken.
	conn.done <- true

	srv := conn.Srv
	srv.Lock()
	delete(srv.conns, conn)
	srv.Unlock()

	var self interface{} = conn
	if stats, ok := self.(StatsOps); ok {
		stats.statsUnregister()
	}

	if hooks, ok := conn.Srv.ops.(ConnOps); ok {
		hooks.ConnClosed(conn)
	}

	// Call FidDestroy for every fid still present in the pool.
	fidOps, ok := conn.Srv.ops.(SrvFidOps)
	if !ok {
		return
	}
	for _, f := range conn.fidpool {
		fidOps.FidDestroy(f)
	}
}
// recv is the connection's read loop. It reads bytes from the network
// connection, splits them into 9P messages using the 4-byte size prefix,
// unpacks each message into an Fcall and hands it to a new SrvReq for
// processing.
//
// The loop ends when reading fails, when a message announces a size larger
// than conn.Msize, or when a message cannot be unpacked. In every case the
// underlying network connection is closed and conn.close() is called.
func (conn *Conn) recv() {
	buf := make([]byte, conn.Msize*8)
	pos := 0 // number of received, not yet consumed bytes at the start of buf

	for {
		// Consumed messages are sliced off the front of buf, so it shrinks
		// over time. Once it can no longer hold a maximum-size message,
		// move the pending bytes into a fresh buffer.
		// NOTE(review): a new buffer is allocated instead of compacting in
		// place, presumably because unpacked Fcalls may still reference the
		// old one; verify against Unpack.
		if len(buf) < int(conn.Msize) {
			b := make([]byte, conn.Msize*8)
			copy(b, buf[0:pos])
			buf = b
		}

		n, err := conn.conn.Read(buf[pos:])
		if err != nil || n == 0 {
			// Release the socket here as well, consistent with the other
			// failure paths below; otherwise it stays open after the peer
			// disconnects. Close on an already closed connection only
			// returns an error, which is ignored.
			conn.conn.Close()
			conn.close()
			return
		}
		pos += n

		// Process every complete message currently in the buffer.
		for pos > 4 {
			sz, _ := Gint32(buf)
			if sz > conn.Msize {
				log.Println("bad client connection: ", conn.conn.RemoteAddr())
				conn.conn.Close()
				conn.close()
				return
			}

			if pos < int(sz) {
				// Message is incomplete: make sure the buffer can hold the
				// rest of it, then go back to reading.
				if len(buf) < int(sz) {
					b := make([]byte, conn.Msize*8)
					copy(b, buf[0:pos])
					buf = b
				}
				break
			}

			fc, err, fcsize := Unpack(buf, conn.Dotu)
			if err != nil {
				log.Printf("invalid packet : %v %v\n", err, buf)
				conn.conn.Close()
				conn.close()
				return
			}

			tag := fc.Tag
			req := new(SrvReq)

			// Reuse a response Fcall recycled by send() if one is
			// available, otherwise allocate a new one.
			select {
			case req.Rc = <-conn.rchan:
			default:
				req.Rc = NewFcall(conn.Msize)
			}

			req.Conn = conn
			req.Tc = fc

			if conn.Debuglevel > 0 {
				conn.logFcall(req.Tc)
				if conn.Debuglevel&DbgPrintPackets != 0 {
					log.Println(">->", conn.Id, fmt.Sprint(req.Tc.Pkt))
				}
				if conn.Debuglevel&DbgPrintFcalls != 0 {
					log.Println(">>>", conn.Id, req.Tc.String())
				}
			}

			// Update the statistics and link the request at the head of the
			// list of outstanding requests that share its tag.
			conn.Lock()
			conn.nreqs++
			conn.tsz += uint64(fc.Size)
			conn.npend++
			if conn.npend > conn.maxpend {
				conn.maxpend = conn.npend
			}
			req.next = conn.reqs[tag]
			conn.reqs[tag] = req
			// Only start processing when no other request with this tag is
			// already outstanding.
			process := req.next == nil
			if req.next != nil {
				req.next.prev = req
			}
			conn.Unlock()

			if process {
				// Tversion may change some attributes of the
				// connection, so we block on it. Otherwise,
				// we may loop back to reading and that is a race.
				// This fix brought to you by the race detector.
				if req.Tc.Type == Tversion {
					req.process()
				} else {
					go req.process()
				}
			}

			// Drop the consumed message from the front of the buffer.
			buf = buf[fcsize:]
			pos -= fcsize
		}
	}
}
// send is the connection's write loop. It takes completed requests from
// conn.reqout, writes their response packet to the network connection and
// offers the response Fcall back for reuse through conn.rchan. The loop
// exits when a value arrives on conn.done.
func (conn *Conn) send() {
	for {
		var req *SrvReq
		select {
		case <-conn.done:
			return
		case req = <-conn.reqout:
		}

		resp := req.Rc
		SetTag(resp, req.Tc.Tag)

		// Account for the response in the connection statistics.
		conn.Lock()
		conn.rsz += uint64(req.Rc.Size)
		conn.npend--
		conn.Unlock()

		if conn.Debuglevel > 0 {
			conn.logFcall(req.Rc)
			if conn.Debuglevel&DbgPrintPackets != 0 {
				log.Println("<-<", conn.Id, fmt.Sprint(req.Rc.Pkt))
			}
			if conn.Debuglevel&DbgPrintFcalls != 0 {
				log.Println("<<<", conn.Id, req.Rc.String())
			}
		}

		// Write until the whole packet has been sent or a write fails.
		pending := req.Rc.Pkt
		for len(pending) > 0 {
			written, werr := conn.conn.Write(pending)
			if werr != nil {
				/* just close the socket, will get signal on conn.done */
				log.Println("error while writing")
				conn.conn.Close()
				break
			}
			pending = pending[written:]
		}

		// Hand the response Fcall back for reuse; drop it when the
		// recycle channel is full.
		select {
		case conn.rchan <- req.Rc:
		default:
		}
	}
}
// RemoteAddr returns the remote address reported by the underlying
// network connection.
func (conn *Conn) RemoteAddr() net.Addr {
	peer := conn.conn.RemoteAddr()
	return peer
}
// LocalAddr returns the local address reported by the underlying
// network connection.
func (conn *Conn) LocalAddr() net.Addr {
	local := conn.conn.LocalAddr()
	return local
}
// logFcall passes fc to the server's logger according to the connection's
// debug level: a copy of the raw packet bytes when DbgLogPackets is set,
// and a copy of the Fcall with its Pkt field cleared when DbgLogFcalls is
// set. Copies are logged so the logger never holds the caller's data.
func (conn *Conn) logFcall(fc *Fcall) {
	if conn.Debuglevel&DbgLogPackets != 0 {
		raw := make([]byte, len(fc.Pkt))
		copy(raw, fc.Pkt)
		conn.Srv.Log.Log(raw, conn, DbgLogPackets)
	}

	if conn.Debuglevel&DbgLogFcalls != 0 {
		// Shallow copy of the Fcall without the packet bytes.
		snapshot := *fc
		snapshot.Pkt = nil
		conn.Srv.Log.Log(&snapshot, conn, DbgLogFcalls)
	}
}
// StartNetListener listens on the given network type and address and
// serves incoming connections through StartListener. It returns an *Error
// with code EIO if listening fails; otherwise it returns whatever
// StartListener returns.
func (srv *Srv) StartNetListener(ntype, addr string) error {
	l, err := net.Listen(ntype, addr)
	if err != nil {
		return &Error{err.Error(), EIO}
	}
	// The listener is created here and never handed to the caller, so it
	// must be released here once StartListener returns; otherwise the
	// listening socket stays open.
	defer l.Close()

	return srv.StartListener(l)
}
// StartListener accepts incoming connections on l and creates a new Conn
// for each of them through NewConn, which starts reading requests from the
// connection and sending back responses. It loops until Accept fails and
// then returns that failure as an *Error with code EIO.
func (srv *Srv) StartListener(l net.Listener) error {
	for {
		c, acceptErr := l.Accept()
		if acceptErr == nil {
			srv.NewConn(c)
			continue
		}
		return &Error{acceptErr.Error(), EIO}
	}
}