Add cluster unit tests

Signed-off-by: sunby <bingyi.sun@zilliz.com>
pull/4973/head^2
sunby 2021-04-08 15:08:34 +08:00 committed by yefu.chen
parent 1d7195e036
commit 6e70ce3f66
3 changed files with 87 additions and 23 deletions

View File

@ -25,7 +25,7 @@ type dataNode struct {
channelNum int
}
type dataNodeCluster struct {
mu sync.RWMutex
sync.RWMutex
finishCh chan struct{}
nodes []*dataNode
}
@ -42,8 +42,8 @@ func newDataNodeCluster(finishCh chan struct{}) *dataNodeCluster {
}
func (c *dataNodeCluster) Register(dataNode *dataNode) {
c.mu.Lock()
defer c.mu.Unlock()
c.Lock()
defer c.Unlock()
if c.checkDataNodeNotExist(dataNode.address.ip, dataNode.address.port) {
c.nodes = append(c.nodes, dataNode)
if len(c.nodes) == Params.DataNodeNum {
@ -62,23 +62,25 @@ func (c *dataNodeCluster) checkDataNodeNotExist(ip string, port int64) bool {
}
func (c *dataNodeCluster) GetNumOfNodes() int {
c.RLock()
defer c.RUnlock()
return len(c.nodes)
}
func (c *dataNodeCluster) GetNodeIDs() []int64 {
c.mu.RLock()
defer c.mu.RUnlock()
ret := make([]int64, len(c.nodes))
for i, node := range c.nodes {
ret[i] = node.id
c.RLock()
defer c.RUnlock()
ret := make([]int64, 0, len(c.nodes))
for _, node := range c.nodes {
ret = append(ret, node.id)
}
return ret
}
func (c *dataNodeCluster) WatchInsertChannels(channels []string) {
ctx := context.TODO()
c.mu.Lock()
defer c.mu.Unlock()
c.Lock()
defer c.Unlock()
var groups [][]string
if len(channels) < len(c.nodes) {
groups = make([][]string, len(channels))
@ -108,8 +110,8 @@ func (c *dataNodeCluster) WatchInsertChannels(channels []string) {
}
func (c *dataNodeCluster) GetDataNodeStates(ctx context.Context) ([]*internalpb.ComponentInfo, error) {
c.mu.RLock()
defer c.mu.RUnlock()
c.RLock()
defer c.RUnlock()
ret := make([]*internalpb.ComponentInfo, 0)
for _, node := range c.nodes {
states, err := node.client.GetComponentStates(ctx)
@ -124,8 +126,8 @@ func (c *dataNodeCluster) GetDataNodeStates(ctx context.Context) ([]*internalpb.
func (c *dataNodeCluster) FlushSegment(request *datapb.FlushSegmentsRequest) {
ctx := context.TODO()
c.mu.RLock()
defer c.mu.RUnlock()
c.Lock()
defer c.Unlock()
for _, node := range c.nodes {
if _, err := node.client.FlushSegments(ctx, request); err != nil {
log.Error("flush segment err", zap.Stringer("dataNode", node), zap.Error(err))
@ -135,6 +137,8 @@ func (c *dataNodeCluster) FlushSegment(request *datapb.FlushSegmentsRequest) {
}
func (c *dataNodeCluster) ShutDownClients() {
c.Lock()
defer c.Unlock()
for _, node := range c.nodes {
if err := node.client.Stop(); err != nil {
log.Error("stop client error", zap.Stringer("dataNode", node), zap.Error(err))
@ -145,8 +149,8 @@ func (c *dataNodeCluster) ShutDownClients() {
// Clear only for test
func (c *dataNodeCluster) Clear() {
c.mu.Lock()
defer c.mu.Unlock()
c.Lock()
defer c.Unlock()
c.finishCh = make(chan struct{})
c.nodes = make([]*dataNode, 0)
}

View File

@ -4,8 +4,52 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
"golang.org/x/net/context"
)
func TestDataNodeClusterRegister(t *testing.T) {
Params.Init()
Params.DataNodeNum = 3
ch := make(chan struct{})
cluster := newDataNodeCluster(ch)
ids := make([]int64, 0, Params.DataNodeNum)
for i := 0; i < Params.DataNodeNum; i++ {
c := newMockDataNodeClient(int64(i))
err := c.Init()
assert.Nil(t, err)
err = c.Start()
assert.Nil(t, err)
cluster.Register(&dataNode{
id: int64(i),
address: struct {
ip string
port int64
}{"localhost", int64(9999 + i)},
client: c,
channelNum: 0,
})
ids = append(ids, int64(i))
}
_, ok := <-ch
assert.False(t, ok)
assert.EqualValues(t, Params.DataNodeNum, cluster.GetNumOfNodes())
assert.EqualValues(t, ids, cluster.GetNodeIDs())
states, err := cluster.GetDataNodeStates(context.TODO())
assert.Nil(t, err)
assert.EqualValues(t, Params.DataNodeNum, len(states))
for _, s := range states {
assert.EqualValues(t, internalpb.StateCode_Healthy, s.StateCode)
}
cluster.ShutDownClients()
states, err = cluster.GetDataNodeStates(context.TODO())
assert.Nil(t, err)
assert.EqualValues(t, Params.DataNodeNum, len(states))
for _, s := range states {
assert.EqualValues(t, internalpb.StateCode_Abnormal, s.StateCode)
}
}
func TestWatchChannels(t *testing.T) {
Params.Init()
Params.DataNodeNum = 3
@ -23,13 +67,18 @@ func TestWatchChannels(t *testing.T) {
cluster := newDataNodeCluster(make(chan struct{}))
for _, c := range cases {
for i := 0; i < Params.DataNodeNum; i++ {
c := newMockDataNodeClient(int64(i))
err := c.Init()
assert.Nil(t, err)
err = c.Start()
assert.Nil(t, err)
cluster.Register(&dataNode{
id: int64(i),
address: struct {
ip string
port int64
}{"localhost", int64(9999 + i)},
client: newMockDataNodeClient(),
client: c,
channelNum: 0,
})
}

View File

@ -53,6 +53,15 @@ func newTestSchema() *schemapb.CollectionSchema {
}
type mockDataNodeClient struct {
id int64
state internalpb.StateCode
}
func newMockDataNodeClient(id int64) *mockDataNodeClient {
return &mockDataNodeClient{
id: id,
state: internalpb.StateCode_Initializing,
}
}
func (c *mockDataNodeClient) Init() error {
@ -60,22 +69,23 @@ func (c *mockDataNodeClient) Init() error {
}
func (c *mockDataNodeClient) Start() error {
c.state = internalpb.StateCode_Healthy
return nil
}
func (c *mockDataNodeClient) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
//TODO
return nil, nil
return &internalpb.ComponentStates{
State: &internalpb.ComponentInfo{
NodeID: c.id,
StateCode: c.state,
},
}, nil
}
func (c *mockDataNodeClient) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
return nil, nil
}
func newMockDataNodeClient() *mockDataNodeClient {
return &mockDataNodeClient{}
}
func (c *mockDataNodeClient) WatchDmChannels(ctx context.Context, in *datapb.WatchDmChannelsRequest) (*commonpb.Status, error) {
return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil
}
@ -85,5 +95,6 @@ func (c *mockDataNodeClient) FlushSegments(ctx context.Context, in *datapb.Flush
}
func (c *mockDataNodeClient) Stop() error {
c.state = internalpb.StateCode_Abnormal
return nil
}