diff --git a/tcp/mux.go b/tcp/mux.go index a23eb6570c..124500671d 100644 --- a/tcp/mux.go +++ b/tcp/mux.go @@ -22,6 +22,8 @@ type Mux struct { ln net.Listener m map[byte]*listener + defaultListener *listener + wg sync.WaitGroup // The amount of time to wait for the first header byte. @@ -31,6 +33,26 @@ type Mux struct { Logger *log.Logger } +type replayConn struct { + net.Conn + firstByte byte + readFirstbyte bool +} + +func (rc *replayConn) Read(b []byte) (int, error) { + if rc.readFirstbyte { + return rc.Conn.Read(b) + } + + if len(b) == 0 { + return 0, nil + } + + b[0] = rc.firstByte + rc.readFirstbyte = true + return 1, nil +} + // NewMux returns a new instance of Mux for ln. func NewMux() *Mux { return &Mux{ @@ -61,6 +83,11 @@ func (mux *Mux) Serve(ln net.Listener) error { for _, ln := range mux.m { close(ln.c) } + + if mux.defaultListener != nil { + close(mux.defaultListener.c) + } + return err } @@ -97,9 +124,17 @@ func (mux *Mux) handleConn(conn net.Conn) { // Retrieve handler based on first byte. handler := mux.m[typ[0]] if handler == nil { - conn.Close() - mux.Logger.Printf("tcp.Mux: handler not registered: %d. Connection from %s closed", typ[0], conn.RemoteAddr()) - return + if mux.defaultListener == nil { + conn.Close() + mux.Logger.Printf("tcp.Mux: handler not registered: %d. Connection from %s closed", typ[0], conn.RemoteAddr()) + return + } + + conn = &replayConn{ + Conn: conn, + firstByte: typ[0], + } + handler = mux.defaultListener } // Send connection to handler. The handler is responsible for closing the connection. @@ -133,6 +168,25 @@ func (mux *Mux) Listen(header byte) net.Listener { return ln } +// DefaultListener() will return a net.Listener that will pass-through any +// connections with non-registered values for the first byte of the connection. +// The connections returned from this listener's Accept() method will replay the +// first byte of the connection as a short first Read(). +// +// This can be used to pass to an HTTP server, so long as there are no conflicts +// with registsered listener bytes and the first character of the HTTP request: +// 71 ('G') for GET, etc. +func (mux *Mux) DefaultListener() net.Listener { + if mux.defaultListener == nil { + mux.defaultListener = &listener{ + c: make(chan net.Conn), + mux: mux, + } + } + + return mux.defaultListener +} + // listener is a receiver for connections received by Mux. type listener struct { c chan net.Conn