mirror of https://github.com/milvus-io/milvus.git
parent
1d7195e036
commit
6e70ce3f66
|
@ -25,7 +25,7 @@ type dataNode struct {
|
|||
channelNum int
|
||||
}
|
||||
type dataNodeCluster struct {
|
||||
mu sync.RWMutex
|
||||
sync.RWMutex
|
||||
finishCh chan struct{}
|
||||
nodes []*dataNode
|
||||
}
|
||||
|
@ -42,8 +42,8 @@ func newDataNodeCluster(finishCh chan struct{}) *dataNodeCluster {
|
|||
}
|
||||
|
||||
func (c *dataNodeCluster) Register(dataNode *dataNode) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
if c.checkDataNodeNotExist(dataNode.address.ip, dataNode.address.port) {
|
||||
c.nodes = append(c.nodes, dataNode)
|
||||
if len(c.nodes) == Params.DataNodeNum {
|
||||
|
@ -62,23 +62,25 @@ func (c *dataNodeCluster) checkDataNodeNotExist(ip string, port int64) bool {
|
|||
}
|
||||
|
||||
func (c *dataNodeCluster) GetNumOfNodes() int {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
return len(c.nodes)
|
||||
}
|
||||
|
||||
func (c *dataNodeCluster) GetNodeIDs() []int64 {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
ret := make([]int64, len(c.nodes))
|
||||
for i, node := range c.nodes {
|
||||
ret[i] = node.id
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
ret := make([]int64, 0, len(c.nodes))
|
||||
for _, node := range c.nodes {
|
||||
ret = append(ret, node.id)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (c *dataNodeCluster) WatchInsertChannels(channels []string) {
|
||||
ctx := context.TODO()
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
var groups [][]string
|
||||
if len(channels) < len(c.nodes) {
|
||||
groups = make([][]string, len(channels))
|
||||
|
@ -108,8 +110,8 @@ func (c *dataNodeCluster) WatchInsertChannels(channels []string) {
|
|||
}
|
||||
|
||||
func (c *dataNodeCluster) GetDataNodeStates(ctx context.Context) ([]*internalpb.ComponentInfo, error) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
ret := make([]*internalpb.ComponentInfo, 0)
|
||||
for _, node := range c.nodes {
|
||||
states, err := node.client.GetComponentStates(ctx)
|
||||
|
@ -124,8 +126,8 @@ func (c *dataNodeCluster) GetDataNodeStates(ctx context.Context) ([]*internalpb.
|
|||
|
||||
func (c *dataNodeCluster) FlushSegment(request *datapb.FlushSegmentsRequest) {
|
||||
ctx := context.TODO()
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
for _, node := range c.nodes {
|
||||
if _, err := node.client.FlushSegments(ctx, request); err != nil {
|
||||
log.Error("flush segment err", zap.Stringer("dataNode", node), zap.Error(err))
|
||||
|
@ -135,6 +137,8 @@ func (c *dataNodeCluster) FlushSegment(request *datapb.FlushSegmentsRequest) {
|
|||
}
|
||||
|
||||
func (c *dataNodeCluster) ShutDownClients() {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
for _, node := range c.nodes {
|
||||
if err := node.client.Stop(); err != nil {
|
||||
log.Error("stop client error", zap.Stringer("dataNode", node), zap.Error(err))
|
||||
|
@ -145,8 +149,8 @@ func (c *dataNodeCluster) ShutDownClients() {
|
|||
|
||||
// Clear only for test
|
||||
func (c *dataNodeCluster) Clear() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
c.finishCh = make(chan struct{})
|
||||
c.nodes = make([]*dataNode, 0)
|
||||
}
|
||||
|
|
|
@ -4,8 +4,52 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
func TestDataNodeClusterRegister(t *testing.T) {
|
||||
Params.Init()
|
||||
Params.DataNodeNum = 3
|
||||
ch := make(chan struct{})
|
||||
cluster := newDataNodeCluster(ch)
|
||||
ids := make([]int64, 0, Params.DataNodeNum)
|
||||
for i := 0; i < Params.DataNodeNum; i++ {
|
||||
c := newMockDataNodeClient(int64(i))
|
||||
err := c.Init()
|
||||
assert.Nil(t, err)
|
||||
err = c.Start()
|
||||
assert.Nil(t, err)
|
||||
cluster.Register(&dataNode{
|
||||
id: int64(i),
|
||||
address: struct {
|
||||
ip string
|
||||
port int64
|
||||
}{"localhost", int64(9999 + i)},
|
||||
client: c,
|
||||
channelNum: 0,
|
||||
})
|
||||
ids = append(ids, int64(i))
|
||||
}
|
||||
_, ok := <-ch
|
||||
assert.False(t, ok)
|
||||
assert.EqualValues(t, Params.DataNodeNum, cluster.GetNumOfNodes())
|
||||
assert.EqualValues(t, ids, cluster.GetNodeIDs())
|
||||
states, err := cluster.GetDataNodeStates(context.TODO())
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, Params.DataNodeNum, len(states))
|
||||
for _, s := range states {
|
||||
assert.EqualValues(t, internalpb.StateCode_Healthy, s.StateCode)
|
||||
}
|
||||
cluster.ShutDownClients()
|
||||
states, err = cluster.GetDataNodeStates(context.TODO())
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, Params.DataNodeNum, len(states))
|
||||
for _, s := range states {
|
||||
assert.EqualValues(t, internalpb.StateCode_Abnormal, s.StateCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatchChannels(t *testing.T) {
|
||||
Params.Init()
|
||||
Params.DataNodeNum = 3
|
||||
|
@ -23,13 +67,18 @@ func TestWatchChannels(t *testing.T) {
|
|||
cluster := newDataNodeCluster(make(chan struct{}))
|
||||
for _, c := range cases {
|
||||
for i := 0; i < Params.DataNodeNum; i++ {
|
||||
c := newMockDataNodeClient(int64(i))
|
||||
err := c.Init()
|
||||
assert.Nil(t, err)
|
||||
err = c.Start()
|
||||
assert.Nil(t, err)
|
||||
cluster.Register(&dataNode{
|
||||
id: int64(i),
|
||||
address: struct {
|
||||
ip string
|
||||
port int64
|
||||
}{"localhost", int64(9999 + i)},
|
||||
client: newMockDataNodeClient(),
|
||||
client: c,
|
||||
channelNum: 0,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -53,6 +53,15 @@ func newTestSchema() *schemapb.CollectionSchema {
|
|||
}
|
||||
|
||||
type mockDataNodeClient struct {
|
||||
id int64
|
||||
state internalpb.StateCode
|
||||
}
|
||||
|
||||
func newMockDataNodeClient(id int64) *mockDataNodeClient {
|
||||
return &mockDataNodeClient{
|
||||
id: id,
|
||||
state: internalpb.StateCode_Initializing,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *mockDataNodeClient) Init() error {
|
||||
|
@ -60,22 +69,23 @@ func (c *mockDataNodeClient) Init() error {
|
|||
}
|
||||
|
||||
func (c *mockDataNodeClient) Start() error {
|
||||
c.state = internalpb.StateCode_Healthy
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mockDataNodeClient) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
|
||||
//TODO
|
||||
return nil, nil
|
||||
return &internalpb.ComponentStates{
|
||||
State: &internalpb.ComponentInfo{
|
||||
NodeID: c.id,
|
||||
StateCode: c.state,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *mockDataNodeClient) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func newMockDataNodeClient() *mockDataNodeClient {
|
||||
return &mockDataNodeClient{}
|
||||
}
|
||||
|
||||
func (c *mockDataNodeClient) WatchDmChannels(ctx context.Context, in *datapb.WatchDmChannelsRequest) (*commonpb.Status, error) {
|
||||
return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil
|
||||
}
|
||||
|
@ -85,5 +95,6 @@ func (c *mockDataNodeClient) FlushSegments(ctx context.Context, in *datapb.Flush
|
|||
}
|
||||
|
||||
func (c *mockDataNodeClient) Stop() error {
|
||||
c.state = internalpb.StateCode_Abnormal
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue