// Copyright (C) 2019-2020 Zilliz. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software distributed under the License // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. See the License for the specific language governing permissions and limitations under the License. package datacoord import ( "context" "errors" "fmt" "testing" "github.com/milvus-io/milvus/internal/kv" memkv "github.com/milvus-io/milvus/internal/kv/mem" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/stretchr/testify/assert" ) type SpyClusterStore struct { *NodesInfo ch chan interface{} } func (s *SpyClusterStore) SetNode(nodeID UniqueID, node *NodeInfo) { s.NodesInfo.SetNode(nodeID, node) s.ch <- struct{}{} } func (s *SpyClusterStore) DeleteNode(nodeID UniqueID) { s.NodesInfo.DeleteNode(nodeID) s.ch <- struct{}{} } func spyWatchPolicy(ch chan interface{}) channelAssignPolicy { return func(cluster []*NodeInfo, channel string, collectionID UniqueID) []*NodeInfo { for _, node := range cluster { for _, c := range node.Info.GetChannels() { if c.GetName() == channel && c.GetCollectionID() == collectionID { ch <- struct{}{} return nil } } } ret := make([]*NodeInfo, 0) c := &datapb.ChannelStatus{ Name: channel, State: datapb.ChannelWatchState_Uncomplete, CollectionID: collectionID, } n := cluster[0].Clone(AddChannels([]*datapb.ChannelStatus{c})) ret = append(ret, n) return ret } } // a mock kv that always fail when LoadWithPrefix type loadPrefixFailKv struct { kv.TxnKV } // LoadWithPrefix override behavior func (kv *loadPrefixFailKv) LoadWithPrefix(key string) ([]string, []string, error) { return []string{}, []string{}, errors.New("mocked fail") } func TestClusterCreate(t *testing.T) { ch := make(chan interface{}) memKv := memkv.NewMemoryKV() spyClusterStore := &SpyClusterStore{ NodesInfo: NewNodesInfo(), ch: ch, } cluster, err := NewCluster(context.TODO(), memKv, spyClusterStore, dummyPosProvider{}) assert.Nil(t, err) defer cluster.Close() addr := "localhost:8080" info := &datapb.DataNodeInfo{ Address: addr, Version: 1, Channels: []*datapb.ChannelStatus{}, } nodes := []*NodeInfo{NewNodeInfo(context.TODO(), info)} cluster.Startup(nodes) <-ch dataNodes := cluster.GetNodes() assert.EqualValues(t, 1, len(dataNodes)) assert.EqualValues(t, "localhost:8080", dataNodes[0].Info.GetAddress()) t.Run("loadKv Fails", func(t *testing.T) { fkv := &loadPrefixFailKv{TxnKV: memKv} cluster, err := NewCluster(context.TODO(), fkv, spyClusterStore, dummyPosProvider{}) assert.NotNil(t, err) assert.Nil(t, cluster) }) } func TestRegister(t *testing.T) { registerPolicy := newEmptyRegisterPolicy() ch := make(chan interface{}) kv := memkv.NewMemoryKV() spyClusterStore := &SpyClusterStore{ NodesInfo: NewNodesInfo(), ch: ch, } cluster, err := NewCluster(context.TODO(), kv, spyClusterStore, dummyPosProvider{}, withRegisterPolicy(registerPolicy)) assert.Nil(t, err) defer cluster.Close() addr := "localhost:8080" cluster.Startup(nil) info := &datapb.DataNodeInfo{ Address: addr, Version: 1, Channels: []*datapb.ChannelStatus{}, } node := NewNodeInfo(context.TODO(), info) cluster.Register(node) <-ch dataNodes := cluster.GetNodes() assert.EqualValues(t, 1, len(dataNodes)) assert.EqualValues(t, "localhost:8080", dataNodes[0].Info.GetAddress()) } func TestUnregister(t *testing.T) { t.Run("remove node after unregister", func(t *testing.T) { unregisterPolicy := newEmptyUnregisterPolicy() ch := make(chan interface{}) kv := memkv.NewMemoryKV() spyClusterStore := &SpyClusterStore{ NodesInfo: NewNodesInfo(), ch: ch, } cluster, err := NewCluster(context.TODO(), kv, spyClusterStore, dummyPosProvider{}, withUnregistorPolicy(unregisterPolicy)) assert.Nil(t, err) defer cluster.Close() addr := "localhost:8080" info := &datapb.DataNodeInfo{ Address: addr, Version: 1, Channels: []*datapb.ChannelStatus{}, } nodes := []*NodeInfo{NewNodeInfo(context.TODO(), info)} cluster.Startup(nodes) <-ch dataNodes := cluster.GetNodes() assert.EqualValues(t, 1, len(dataNodes)) assert.EqualValues(t, "localhost:8080", dataNodes[0].Info.GetAddress()) cluster.UnRegister(nodes[0]) <-ch dataNodes = cluster.GetNodes() assert.EqualValues(t, 0, len(dataNodes)) }) t.Run("move channels to online nodes after unregister", func(t *testing.T) { ch := make(chan interface{}) kv := memkv.NewMemoryKV() spyClusterStore := &SpyClusterStore{ NodesInfo: NewNodesInfo(), ch: ch, } cluster, err := NewCluster(context.TODO(), kv, spyClusterStore, dummyPosProvider{}) assert.Nil(t, err) defer cluster.Close() ch1 := &datapb.ChannelStatus{ Name: "ch_1", State: datapb.ChannelWatchState_Uncomplete, CollectionID: 100, } nodeInfo1 := &datapb.DataNodeInfo{ Address: "localhost:8080", Version: 1, Channels: []*datapb.ChannelStatus{ch1}, } nodeInfo2 := &datapb.DataNodeInfo{ Address: "localhost:8081", Version: 2, Channels: []*datapb.ChannelStatus{}, } node1 := NewNodeInfo(context.TODO(), nodeInfo1) node2 := NewNodeInfo(context.TODO(), nodeInfo2) cli1, err := newMockDataNodeClient(1, make(chan interface{})) assert.Nil(t, err) cli2, err := newMockDataNodeClient(2, make(chan interface{})) assert.Nil(t, err) node1.client = cli1 node2.client = cli2 nodes := []*NodeInfo{node1, node2} cluster.Startup(nodes) <-ch <-ch dataNodes := cluster.GetNodes() assert.EqualValues(t, 2, len(dataNodes)) for _, node := range dataNodes { if node.Info.GetVersion() == 1 { cluster.UnRegister(node) <-ch <-ch break } } dataNodes = cluster.GetNodes() assert.EqualValues(t, 1, len(dataNodes)) assert.EqualValues(t, 2, dataNodes[0].Info.GetVersion()) assert.EqualValues(t, ch1.Name, dataNodes[0].Info.GetChannels()[0].Name) }) t.Run("remove all channels after unregsiter", func(t *testing.T) { ch := make(chan interface{}, 10) kv := memkv.NewMemoryKV() spyClusterStore := &SpyClusterStore{ NodesInfo: NewNodesInfo(), ch: ch, } cluster, err := NewCluster(context.TODO(), kv, spyClusterStore, dummyPosProvider{}) assert.Nil(t, err) defer cluster.Close() chstatus := &datapb.ChannelStatus{ Name: "ch_1", State: datapb.ChannelWatchState_Uncomplete, CollectionID: 100, } nodeInfo := &datapb.DataNodeInfo{ Address: "localhost:8080", Version: 1, Channels: []*datapb.ChannelStatus{chstatus}, } node := NewNodeInfo(context.TODO(), nodeInfo) cli, err := newMockDataNodeClient(1, make(chan interface{})) assert.Nil(t, err) node.client = cli cluster.Startup([]*NodeInfo{node}) <-ch cluster.UnRegister(node) <-ch spyClusterStore2 := &SpyClusterStore{ NodesInfo: NewNodesInfo(), ch: ch, } cluster2, err := NewCluster(context.TODO(), kv, spyClusterStore2, dummyPosProvider{}) <-ch assert.Nil(t, err) nodes := cluster2.GetNodes() assert.EqualValues(t, 1, len(nodes)) assert.EqualValues(t, 1, nodes[0].Info.GetVersion()) assert.EqualValues(t, 0, len(nodes[0].Info.GetChannels())) }) } func TestWatchIfNeeded(t *testing.T) { ch := make(chan interface{}) kv := memkv.NewMemoryKV() spyClusterStore := &SpyClusterStore{ NodesInfo: NewNodesInfo(), ch: ch, } pch := make(chan interface{}) spyPolicy := spyWatchPolicy(pch) cluster, err := NewCluster(context.TODO(), kv, spyClusterStore, dummyPosProvider{}, withAssignPolicy(spyPolicy)) assert.Nil(t, err) defer cluster.Close() addr := "localhost:8080" info := &datapb.DataNodeInfo{ Address: addr, Version: 1, Channels: []*datapb.ChannelStatus{}, } node := NewNodeInfo(context.TODO(), info) node.client, err = newMockDataNodeClient(1, make(chan interface{})) assert.Nil(t, err) nodes := []*NodeInfo{node} cluster.Startup(nodes) fmt.Println("11111") <-ch chName := "ch1" cluster.Watch(chName, 0) fmt.Println("222") <-ch dataNodes := cluster.GetNodes() assert.EqualValues(t, 1, len(dataNodes[0].Info.GetChannels())) assert.EqualValues(t, chName, dataNodes[0].Info.Channels[0].Name) cluster.Watch(chName, 0) <-pch }