move tcp to cluster
parent
3dc688cff2
commit
1228de4e7c
|
@ -0,0 +1,59 @@
|
|||
package cluster
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/fatih/pool"
|
||||
"github.com/influxdb/influxdb/meta"
|
||||
)
|
||||
|
||||
type clientPool struct {
|
||||
mu sync.RWMutex
|
||||
pool map[*meta.NodeInfo]pool.Pool
|
||||
}
|
||||
|
||||
func newClientPool() *clientPool {
|
||||
return &clientPool{
|
||||
pool: make(map[*meta.NodeInfo]pool.Pool),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientPool) setPool(n *meta.NodeInfo, p pool.Pool) {
|
||||
c.mu.Lock()
|
||||
c.pool[n] = p
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *clientPool) getPool(n *meta.NodeInfo) (pool.Pool, bool) {
|
||||
c.mu.Lock()
|
||||
p, ok := c.pool[n]
|
||||
c.mu.Unlock()
|
||||
return p, ok
|
||||
}
|
||||
|
||||
func (c *clientPool) size() int {
|
||||
c.mu.RLock()
|
||||
var size int
|
||||
for _, p := range c.pool {
|
||||
size += p.Len()
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
return size
|
||||
}
|
||||
|
||||
func (c *clientPool) conn(n *meta.NodeInfo) (net.Conn, error) {
|
||||
c.mu.Lock()
|
||||
conn, err := c.pool[n].Get()
|
||||
c.mu.Unlock()
|
||||
return conn, err
|
||||
}
|
||||
|
||||
func (c *clientPool) close() error {
|
||||
c.mu.Lock()
|
||||
for _, p := range c.pool {
|
||||
p.Close()
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package tcp
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
@ -9,7 +9,6 @@ import (
|
|||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/influxdb/influxdb/cluster"
|
||||
"github.com/influxdb/influxdb/tsdb"
|
||||
)
|
||||
|
||||
|
@ -171,7 +170,7 @@ func (s *Server) writeShardRequest(conn net.Conn) error {
|
|||
return err
|
||||
}
|
||||
|
||||
var wsr cluster.WriteShardRequest
|
||||
var wsr WriteShardRequest
|
||||
if err := wsr.UnmarshalBinary(message); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -185,7 +184,7 @@ func (s *Server) writeShardResponse(conn net.Conn, e error) {
|
|||
return
|
||||
}
|
||||
|
||||
var wsr cluster.WriteShardResponse
|
||||
var wsr WriteShardResponse
|
||||
if e != nil {
|
||||
wsr.SetCode(1)
|
||||
wsr.SetMessage(e.Error())
|
|
@ -1,4 +1,4 @@
|
|||
package tcp_test
|
||||
package cluster_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
@ -6,10 +6,22 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/influxdb/influxdb/tcp"
|
||||
"github.com/influxdb/influxdb/cluster"
|
||||
"github.com/influxdb/influxdb/meta"
|
||||
"github.com/influxdb/influxdb/tsdb"
|
||||
)
|
||||
|
||||
type metaStore struct {
|
||||
host string
|
||||
}
|
||||
|
||||
func (m *metaStore) Node(nodeID uint64) (*meta.NodeInfo, error) {
|
||||
return &meta.NodeInfo{
|
||||
ID: nodeID,
|
||||
Host: m.host,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type testServer struct {
|
||||
writeShardFunc func(shardID uint64, points []tsdb.Point) error
|
||||
}
|
||||
|
@ -62,7 +74,7 @@ func (testServer) ResponseN(n int) ([]*serverResponse, error) {
|
|||
func TestServer_Close_ErrServerClosed(t *testing.T) {
|
||||
var (
|
||||
ts testServer
|
||||
s = tcp.NewServer(ts)
|
||||
s = cluster.NewServer(ts)
|
||||
)
|
||||
|
||||
// Start on a random port
|
||||
|
@ -75,7 +87,7 @@ func TestServer_Close_ErrServerClosed(t *testing.T) {
|
|||
s.Close()
|
||||
|
||||
// Try to close it again
|
||||
if err := s.Close(); err != tcp.ErrServerClosed {
|
||||
if err := s.Close(); err != cluster.ErrServerClosed {
|
||||
t.Fatalf("expected an error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
@ -83,13 +95,13 @@ func TestServer_Close_ErrServerClosed(t *testing.T) {
|
|||
func TestServer_Close_ErrBindAddressRequired(t *testing.T) {
|
||||
var (
|
||||
ts testServer
|
||||
s = tcp.NewServer(ts)
|
||||
s = cluster.NewServer(ts)
|
||||
)
|
||||
|
||||
// Start on a random port
|
||||
_, e := s.ListenAndServe("")
|
||||
if e == nil {
|
||||
t.Fatalf("exprected error %s, got nil.", tcp.ErrBindAddressRequired)
|
||||
t.Fatalf("exprected error %s, got nil.", cluster.ErrBindAddressRequired)
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -97,7 +109,7 @@ func TestServer_Close_ErrBindAddressRequired(t *testing.T) {
|
|||
func TestServer_WriteShardRequestSuccess(t *testing.T) {
|
||||
var (
|
||||
ts = newTestServer(writeShardSuccess)
|
||||
s = tcp.NewServer(ts)
|
||||
s = cluster.NewServer(ts)
|
||||
)
|
||||
// Close the server
|
||||
defer s.Close()
|
||||
|
@ -108,21 +120,22 @@ func TestServer_WriteShardRequestSuccess(t *testing.T) {
|
|||
t.Fatalf("err does not match. expected %v, got %v", nil, e)
|
||||
}
|
||||
|
||||
client := tcp.NewClient()
|
||||
writer := cluster.NewWriter(&metaStore{host: host})
|
||||
|
||||
now := time.Now()
|
||||
|
||||
shardID := uint64(1)
|
||||
ownerID := uint64(2)
|
||||
var points []tsdb.Point
|
||||
points = append(points, tsdb.NewPoint(
|
||||
"cpu", tsdb.Tags{"host": "server01"}, map[string]interface{}{"value": int64(100)}, now,
|
||||
))
|
||||
|
||||
if err := client.WriteShard(host, shardID, points); err != nil {
|
||||
if err := writer.Write(shardID, ownerID, points); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := client.Close(); err != nil {
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
@ -162,7 +175,7 @@ func TestServer_WriteShardRequestSuccess(t *testing.T) {
|
|||
func TestServer_WriteShardRequestMultipleSuccess(t *testing.T) {
|
||||
var (
|
||||
ts = newTestServer(writeShardSuccess)
|
||||
s = tcp.NewServer(ts)
|
||||
s = cluster.NewServer(ts)
|
||||
)
|
||||
// Close the server
|
||||
defer s.Close()
|
||||
|
@ -173,17 +186,18 @@ func TestServer_WriteShardRequestMultipleSuccess(t *testing.T) {
|
|||
t.Fatalf("err does not match. expected %v, got %v", nil, e)
|
||||
}
|
||||
|
||||
client := tcp.NewClient()
|
||||
writer := cluster.NewWriter(&metaStore{host: host})
|
||||
|
||||
now := time.Now()
|
||||
|
||||
shardID := uint64(1)
|
||||
ownerID := uint64(2)
|
||||
var points []tsdb.Point
|
||||
points = append(points, tsdb.NewPoint(
|
||||
"cpu", tsdb.Tags{"host": "server01"}, map[string]interface{}{"value": int64(100)}, now,
|
||||
))
|
||||
|
||||
if err := client.WriteShard(host, shardID, points); err != nil {
|
||||
if err := writer.Write(shardID, ownerID, points); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
@ -193,11 +207,11 @@ func TestServer_WriteShardRequestMultipleSuccess(t *testing.T) {
|
|||
"cpu", tsdb.Tags{"host": "server01"}, map[string]interface{}{"value": int64(100)}, now,
|
||||
))
|
||||
|
||||
if err := client.WriteShard(host, shardID, points[1:]); err != nil {
|
||||
if err := writer.Write(shardID, ownerID, points[1:]); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := client.Close(); err != nil {
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
@ -237,7 +251,7 @@ func TestServer_WriteShardRequestMultipleSuccess(t *testing.T) {
|
|||
func TestServer_WriteShardRequestFail(t *testing.T) {
|
||||
var (
|
||||
ts = newTestServer(writeShardFail)
|
||||
s = tcp.NewServer(ts)
|
||||
s = cluster.NewServer(ts)
|
||||
)
|
||||
// Close the server
|
||||
defer s.Close()
|
||||
|
@ -248,16 +262,17 @@ func TestServer_WriteShardRequestFail(t *testing.T) {
|
|||
t.Fatalf("err does not match. expected %v, got %v", nil, e)
|
||||
}
|
||||
|
||||
client := tcp.NewClient()
|
||||
writer := cluster.NewWriter(&metaStore{host: host})
|
||||
now := time.Now()
|
||||
|
||||
shardID := uint64(1)
|
||||
ownerID := uint64(2)
|
||||
var points []tsdb.Point
|
||||
points = append(points, tsdb.NewPoint(
|
||||
"cpu", tsdb.Tags{"host": "server01"}, map[string]interface{}{"value": int64(100)}, now,
|
||||
))
|
||||
|
||||
if err, exp := client.WriteShard(host, shardID, points), "error code 1: failed to write"; err == nil || err.Error() != exp {
|
||||
if err, exp := writer.Write(shardID, ownerID, points), "error code 1: failed to write"; err == nil || err.Error() != exp {
|
||||
t.Fatalf("expected error %s, got %v", exp, err)
|
||||
}
|
||||
}
|
|
@ -1,78 +1,81 @@
|
|||
package tcp
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/influxdb/influxdb/cluster"
|
||||
"github.com/fatih/pool"
|
||||
"github.com/influxdb/influxdb/meta"
|
||||
"github.com/influxdb/influxdb/tsdb"
|
||||
)
|
||||
|
||||
const (
|
||||
writeShardRequestMessage byte = iota + 1
|
||||
writeShardResponseMessage
|
||||
)
|
||||
|
||||
const maxConnections = 500
|
||||
|
||||
var errMaxConnectionsExceeded = fmt.Errorf("can not exceed max connections of %d", maxConnections)
|
||||
|
||||
type clientConn struct {
|
||||
client *Client
|
||||
addr string
|
||||
type metaStore interface {
|
||||
Node(id uint64) (ni *meta.NodeInfo, err error)
|
||||
}
|
||||
|
||||
func newClientConn(addr string, c *Client) *clientConn {
|
||||
return &clientConn{
|
||||
addr: addr,
|
||||
client: c,
|
||||
type connFactory struct {
|
||||
nodeInfo *meta.NodeInfo
|
||||
clientPool interface {
|
||||
size() int
|
||||
}
|
||||
}
|
||||
func (c *clientConn) dial() (net.Conn, error) {
|
||||
if c.client.poolSize() > maxConnections {
|
||||
|
||||
func (c *connFactory) dial() (net.Conn, error) {
|
||||
if c.clientPool.size() > maxConnections {
|
||||
return nil, errMaxConnectionsExceeded
|
||||
}
|
||||
|
||||
conn, err := net.Dial("tcp", c.addr)
|
||||
conn, err := net.Dial("tcp", c.nodeInfo.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
pool *connectionPool
|
||||
type Writer struct {
|
||||
pool *clientPool
|
||||
metaStore metaStore
|
||||
}
|
||||
|
||||
func NewClient() *Client {
|
||||
return &Client{
|
||||
pool: newConnectionPool(),
|
||||
func NewWriter(m metaStore) *Writer {
|
||||
return &Writer{
|
||||
pool: newClientPool(),
|
||||
metaStore: m,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) poolSize() int {
|
||||
if c.pool == nil {
|
||||
return 0
|
||||
func (c *Writer) dial(nodeID uint64) (net.Conn, error) {
|
||||
nodeInfo, err := c.metaStore.Node(nodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.pool.size()
|
||||
}
|
||||
|
||||
func (c *Client) dial(addr string) (net.Conn, error) {
|
||||
addr = strings.ToLower(addr)
|
||||
// if we don't have a connection pool for that addr yet, create one
|
||||
_, ok := c.pool.getPool(addr)
|
||||
_, ok := c.pool.getPool(nodeInfo)
|
||||
if !ok {
|
||||
conn := newClientConn(addr, c)
|
||||
p, err := NewChannelPool(1, 3, conn.dial)
|
||||
factory := &connFactory{nodeInfo: nodeInfo, clientPool: c.pool}
|
||||
p, err := pool.NewChannelPool(1, 3, factory.dial)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.pool.setPool(addr, p)
|
||||
c.pool.setPool(nodeInfo, p)
|
||||
}
|
||||
return c.pool.conn(addr)
|
||||
return c.pool.conn(nodeInfo)
|
||||
}
|
||||
|
||||
func (c *Client) WriteShard(addr string, shardID uint64, points []tsdb.Point) error {
|
||||
conn, err := c.dial(addr)
|
||||
func (w *Writer) Write(shardID, ownerID uint64, points []tsdb.Point) error {
|
||||
conn, err := w.dial(ownerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -85,7 +88,7 @@ func (c *Client) WriteShard(addr string, shardID uint64, points []tsdb.Point) er
|
|||
return err
|
||||
}
|
||||
|
||||
var request cluster.WriteShardRequest
|
||||
var request WriteShardRequest
|
||||
request.SetShardID(shardID)
|
||||
request.AddPoints(points)
|
||||
|
||||
|
@ -121,7 +124,7 @@ func (c *Client) WriteShard(addr string, shardID uint64, points []tsdb.Point) er
|
|||
return err
|
||||
}
|
||||
|
||||
var response cluster.WriteShardResponse
|
||||
var response WriteShardResponse
|
||||
if err := response.UnmarshalBinary(message); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -132,11 +135,10 @@ func (c *Client) WriteShard(addr string, shardID uint64, points []tsdb.Point) er
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
if c.pool == nil {
|
||||
func (w *Writer) Close() error {
|
||||
if w.pool == nil {
|
||||
return fmt.Errorf("client already closed")
|
||||
}
|
||||
c.pool = nil
|
||||
w.pool = nil
|
||||
return nil
|
||||
}
|
131
tcp/channel.go
131
tcp/channel.go
|
@ -1,131 +0,0 @@
|
|||
package tcp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// channelPool implements the Pool interface based on buffered channels.
|
||||
type channelPool struct {
|
||||
// storage for our net.Conn connections
|
||||
mu sync.Mutex
|
||||
conns chan net.Conn
|
||||
|
||||
// net.Conn generator
|
||||
factory Factory
|
||||
}
|
||||
|
||||
// Factory is a function to create new connections.
|
||||
type Factory func() (net.Conn, error)
|
||||
|
||||
// NewChannelPool returns a new pool based on buffered channels with an initial
|
||||
// capacity and maximum capacity. Factory is used when initial capacity is
|
||||
// greater than zero to fill the pool. A zero initialCap doesn't fill the Pool
|
||||
// until a new Get() is called. During a Get(), If there is no new connection
|
||||
// available in the pool, a new connection will be created via the Factory()
|
||||
// method.
|
||||
func NewChannelPool(initialCap, maxCap int, factory Factory) (Pool, error) {
|
||||
if initialCap < 0 || maxCap <= 0 || initialCap > maxCap {
|
||||
return nil, errors.New("invalid capacity settings")
|
||||
}
|
||||
|
||||
c := &channelPool{
|
||||
conns: make(chan net.Conn, maxCap),
|
||||
factory: factory,
|
||||
}
|
||||
|
||||
// create initial connections, if something goes wrong,
|
||||
// just close the pool error out.
|
||||
for i := 0; i < initialCap; i++ {
|
||||
conn, err := factory()
|
||||
if err != nil {
|
||||
c.Close()
|
||||
return nil, fmt.Errorf("factory is not able to fill the pool: %s", err)
|
||||
}
|
||||
c.conns <- conn
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *channelPool) getConns() chan net.Conn {
|
||||
c.mu.Lock()
|
||||
conns := c.conns
|
||||
c.mu.Unlock()
|
||||
return conns
|
||||
}
|
||||
|
||||
// Get implements the Pool interfaces Get() method. If there is no new
|
||||
// connection available in the pool, a new connection will be created via the
|
||||
// Factory() method.
|
||||
func (c *channelPool) Get() (net.Conn, error) {
|
||||
conns := c.getConns()
|
||||
if conns == nil {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
|
||||
// wrap our connections with out custom net.Conn implementation (wrapConn
|
||||
// method) that puts the connection back to the pool if it's closed.
|
||||
select {
|
||||
case conn := <-conns:
|
||||
if conn == nil {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
|
||||
return c.wrapConn(conn), nil
|
||||
default:
|
||||
conn, err := c.factory()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.wrapConn(conn), nil
|
||||
}
|
||||
}
|
||||
|
||||
// put puts the connection back to the pool. If the pool is full or closed,
|
||||
// conn is simply closed. A nil conn will be rejected.
|
||||
func (c *channelPool) put(conn net.Conn) error {
|
||||
if conn == nil {
|
||||
return errors.New("connection is nil. rejecting")
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.conns == nil {
|
||||
// pool is closed, close passed connection
|
||||
return conn.Close()
|
||||
}
|
||||
|
||||
// put the resource back into the pool. If the pool is full, this will
|
||||
// block and the default case will be executed.
|
||||
select {
|
||||
case c.conns <- conn:
|
||||
return nil
|
||||
default:
|
||||
// pool is full, close passed connection
|
||||
return conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *channelPool) Close() {
|
||||
c.mu.Lock()
|
||||
conns := c.conns
|
||||
c.conns = nil
|
||||
c.factory = nil
|
||||
c.mu.Unlock()
|
||||
|
||||
if conns == nil {
|
||||
return
|
||||
}
|
||||
|
||||
close(conns)
|
||||
for conn := range conns {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *channelPool) Len() int { return len(c.getConns()) }
|
|
@ -1,247 +0,0 @@
|
|||
package tcp
|
||||
|
||||
import (
|
||||
"log"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
InitialCap = 5
|
||||
MaximumCap = 30
|
||||
network = "tcp"
|
||||
address = "127.0.0.1:7777"
|
||||
factory = func() (net.Conn, error) { return net.Dial(network, address) }
|
||||
)
|
||||
|
||||
func init() {
|
||||
// used for factory function
|
||||
go simpleTCPServer()
|
||||
time.Sleep(time.Millisecond * 300) // wait until tcp server has been settled
|
||||
|
||||
rand.Seed(time.Now().UTC().UnixNano())
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
_, err := newChannelPool()
|
||||
if err != nil {
|
||||
t.Errorf("New error: %s", err)
|
||||
}
|
||||
}
|
||||
func TestPool_Get_Impl(t *testing.T) {
|
||||
p, _ := newChannelPool()
|
||||
defer p.Close()
|
||||
|
||||
conn, err := p.Get()
|
||||
if err != nil {
|
||||
t.Errorf("Get error: %s", err)
|
||||
}
|
||||
|
||||
_, ok := conn.(poolConn)
|
||||
if !ok {
|
||||
t.Errorf("Conn is not of type poolConn")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_Get(t *testing.T) {
|
||||
p, _ := newChannelPool()
|
||||
defer p.Close()
|
||||
|
||||
_, err := p.Get()
|
||||
if err != nil {
|
||||
t.Errorf("Get error: %s", err)
|
||||
}
|
||||
|
||||
// after one get, current capacity should be lowered by one.
|
||||
if p.Len() != (InitialCap - 1) {
|
||||
t.Errorf("Get error. Expecting %d, got %d",
|
||||
(InitialCap - 1), p.Len())
|
||||
}
|
||||
|
||||
// get them all
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < (InitialCap - 1); i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := p.Get()
|
||||
if err != nil {
|
||||
t.Errorf("Get error: %s", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if p.Len() != 0 {
|
||||
t.Errorf("Get error. Expecting %d, got %d",
|
||||
(InitialCap - 1), p.Len())
|
||||
}
|
||||
|
||||
_, err = p.Get()
|
||||
if err != nil {
|
||||
t.Errorf("Get error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_Put(t *testing.T) {
|
||||
p, err := NewChannelPool(0, 30, factory)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer p.Close()
|
||||
|
||||
// get/create from the pool
|
||||
conns := make([]net.Conn, MaximumCap)
|
||||
for i := 0; i < MaximumCap; i++ {
|
||||
conn, _ := p.Get()
|
||||
conns[i] = conn
|
||||
}
|
||||
|
||||
// now put them all back
|
||||
for _, conn := range conns {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
if p.Len() != MaximumCap {
|
||||
t.Errorf("Put error len. Expecting %d, got %d",
|
||||
1, p.Len())
|
||||
}
|
||||
|
||||
conn, _ := p.Get()
|
||||
p.Close() // close pool
|
||||
|
||||
conn.Close() // try to put into a full pool
|
||||
if p.Len() != 0 {
|
||||
t.Errorf("Put error. Closed pool shouldn't allow to put connections.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_UsedCapacity(t *testing.T) {
|
||||
p, _ := newChannelPool()
|
||||
defer p.Close()
|
||||
|
||||
if p.Len() != InitialCap {
|
||||
t.Errorf("InitialCap error. Expecting %d, got %d",
|
||||
InitialCap, p.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_Close(t *testing.T) {
|
||||
p, _ := newChannelPool()
|
||||
|
||||
// now close it and test all cases we are expecting.
|
||||
p.Close()
|
||||
|
||||
c := p.(*channelPool)
|
||||
|
||||
if c.conns != nil {
|
||||
t.Errorf("Close error, conns channel should be nil")
|
||||
}
|
||||
|
||||
if c.factory != nil {
|
||||
t.Errorf("Close error, factory should be nil")
|
||||
}
|
||||
|
||||
_, err := p.Get()
|
||||
if err == nil {
|
||||
t.Errorf("Close error, get conn should return an error")
|
||||
}
|
||||
|
||||
if p.Len() != 0 {
|
||||
t.Errorf("Close error used capacity. Expecting 0, got %d", p.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolConcurrent(t *testing.T) {
|
||||
p, _ := newChannelPool()
|
||||
pipe := make(chan net.Conn, 0)
|
||||
|
||||
go func() {
|
||||
p.Close()
|
||||
}()
|
||||
|
||||
for i := 0; i < MaximumCap; i++ {
|
||||
go func() {
|
||||
conn, _ := p.Get()
|
||||
|
||||
pipe <- conn
|
||||
}()
|
||||
|
||||
go func() {
|
||||
conn := <-pipe
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
conn.Close()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolWriteRead(t *testing.T) {
|
||||
p, _ := NewChannelPool(0, 30, factory)
|
||||
|
||||
conn, _ := p.Get()
|
||||
|
||||
msg := "hello"
|
||||
_, err := conn.Write([]byte(msg))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolConcurrent2(t *testing.T) {
|
||||
p, _ := NewChannelPool(0, 30, factory)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
go func() {
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
conn, _ := p.Get()
|
||||
time.Sleep(time.Millisecond * time.Duration(rand.Intn(100)))
|
||||
conn.Close()
|
||||
wg.Done()
|
||||
}(i)
|
||||
}
|
||||
}()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
conn, _ := p.Get()
|
||||
time.Sleep(time.Millisecond * time.Duration(rand.Intn(100)))
|
||||
conn.Close()
|
||||
wg.Done()
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func newChannelPool() (Pool, error) {
|
||||
return NewChannelPool(InitialCap, MaximumCap, factory)
|
||||
}
|
||||
|
||||
func simpleTCPServer() {
|
||||
l, err := net.Listen(network, address)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer l.Close()
|
||||
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
buffer := make([]byte, 256)
|
||||
conn.Read(buffer)
|
||||
}()
|
||||
}
|
||||
}
|
27
tcp/conn.go
27
tcp/conn.go
|
@ -1,27 +0,0 @@
|
|||
package tcp
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
)
|
||||
|
||||
// poolConn is a wrapper around net.Conn to modify the the behavior of
|
||||
// net.Conn's Close() method.
|
||||
type poolConn struct {
|
||||
net.Conn
|
||||
c *channelPool
|
||||
}
|
||||
|
||||
// Close() puts the given connects back to the pool instead of closing it.
|
||||
func (p poolConn) Close() error {
|
||||
spew.Dump("I'm back on the queue!")
|
||||
return p.c.put(p.Conn)
|
||||
}
|
||||
|
||||
// newConn wraps a standard net.Conn to a poolConn net.Conn.
|
||||
func (c *channelPool) wrapConn(conn net.Conn) net.Conn {
|
||||
p := poolConn{c: c}
|
||||
p.Conn = conn
|
||||
return p
|
||||
}
|
|
@ -1,10 +0,0 @@
|
|||
package tcp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConn_Impl(t *testing.T) {
|
||||
var _ net.Conn = new(poolConn)
|
||||
}
|
|
@ -1,66 +0,0 @@
|
|||
package tcp
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type connectionPool struct {
|
||||
mu sync.RWMutex
|
||||
pool map[string]Pool
|
||||
}
|
||||
|
||||
func newConnectionPool() *connectionPool {
|
||||
return &connectionPool{
|
||||
pool: make(map[string]Pool),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connectionPool) setPool(addr string, p Pool) {
|
||||
log.Println("setting pool")
|
||||
c.mu.Lock()
|
||||
c.pool[addr] = p
|
||||
c.mu.Unlock()
|
||||
log.Println("setting pool complete")
|
||||
}
|
||||
|
||||
func (c *connectionPool) getPool(addr string) (Pool, bool) {
|
||||
log.Println("getting pool")
|
||||
c.mu.Lock()
|
||||
p, ok := c.pool[addr]
|
||||
c.mu.Unlock()
|
||||
log.Println("getting pool complete")
|
||||
return p, ok
|
||||
}
|
||||
|
||||
func (c *connectionPool) size() int {
|
||||
log.Println("getting pool size")
|
||||
c.mu.RLock()
|
||||
var size int
|
||||
for _, p := range c.pool {
|
||||
size += p.Len()
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
log.Println("getting pool size complete")
|
||||
return size
|
||||
}
|
||||
|
||||
func (c *connectionPool) conn(addr string) (net.Conn, error) {
|
||||
log.Println("getting connection")
|
||||
c.mu.Lock()
|
||||
conn, err := c.pool[addr].Get()
|
||||
c.mu.Unlock()
|
||||
log.Println("getting connection complete")
|
||||
return conn, err
|
||||
}
|
||||
|
||||
func (c *connectionPool) close() error {
|
||||
log.Println("closing")
|
||||
c.mu.Lock()
|
||||
for _, p := range c.pool {
|
||||
p.Close()
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
|
@ -1,6 +0,0 @@
|
|||
package tcp
|
||||
|
||||
const (
|
||||
writeShardRequestMessage byte = iota + 1
|
||||
writeShardResponseMessage
|
||||
)
|
28
tcp/pool.go
28
tcp/pool.go
|
@ -1,28 +0,0 @@
|
|||
// Design is based heavily (or exactly) on the https://github.com/fatih/pool package
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrClosed is the error resulting if the pool is closed via pool.Close().
|
||||
ErrClosed = errors.New("pool is closed")
|
||||
)
|
||||
|
||||
// Pool interface describes a pool implementation. A pool should have maximum
|
||||
// capacity. An ideal pool is threadsafe and easy to use.
|
||||
type Pool interface {
|
||||
// Get returns a new connection from the pool. Closing the connections puts
|
||||
// it back to the Pool. Closing it when the pool is destroyed or full will
|
||||
// be counted as an error.
|
||||
Get() (net.Conn, error)
|
||||
|
||||
// Close closes the pool and all its connections. After Close() the pool is
|
||||
// no longer usable.
|
||||
Close()
|
||||
|
||||
// Len returns the current number of connections of the pool.
|
||||
Len() int
|
||||
}
|
Loading…
Reference in New Issue