diff --git a/internal/datanode/data_node.go b/internal/datanode/data_node.go index ad2762a20f..8efcd714a5 100644 --- a/internal/datanode/data_node.go +++ b/internal/datanode/data_node.go @@ -320,11 +320,11 @@ func (node *DataNode) Start() error { } connectEtcdFn := func() error { - etcdClient, err := clientv3.New(clientv3.Config{Endpoints: Params.EtcdEndpoints}) + etcdKV, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath) if err != nil { return err } - node.kvClient = etcdkv.NewEtcdKV(etcdClient, Params.MetaRootPath) + node.kvClient = etcdKV return nil } err = retry.Do(node.ctx, connectEtcdFn, retry.Attempts(ConnectEtcdMaxRetryTime)) diff --git a/internal/datanode/data_node_test.go b/internal/datanode/data_node_test.go index c9917274b3..9d5e99ad9c 100644 --- a/internal/datanode/data_node_test.go +++ b/internal/datanode/data_node_test.go @@ -24,7 +24,6 @@ import ( "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.etcd.io/etcd/clientv3" "go.uber.org/zap" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" @@ -333,59 +332,56 @@ func TestWatchChannel(t *testing.T) { defer cancel() t.Run("test watch channel", func(t *testing.T) { - client, err := clientv3.New(clientv3.Config{Endpoints: Params.EtcdEndpoints}) - assert.Nil(t, err) - if assert.NotNil(t, client) { - kv := etcdkv.NewEtcdKV(client, Params.MetaRootPath) - ch := fmt.Sprintf("datanode-etcd-test-channel_%d", rand.Int31()) - path := fmt.Sprintf("channel/%d/%s", node.NodeID, ch) - c := make(chan struct{}) - go func() { - ec := kv.WatchWithPrefix(fmt.Sprintf("channel/%d", node.NodeID)) - cnt := 0 - for { - evt := <-ec - for _, event := range evt.Events { - if strings.Contains(string(event.Kv.Key), ch) { - cnt++ - } - } - if cnt >= 2 { - break + kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath) + require.NoError(t, err) + ch := fmt.Sprintf("datanode-etcd-test-channel_%d", rand.Int31()) + path := fmt.Sprintf("channel/%d/%s", node.NodeID, ch) + c := make(chan struct{}) + go func() { + ec := kv.WatchWithPrefix(fmt.Sprintf("channel/%d", node.NodeID)) + cnt := 0 + for { + evt := <-ec + for _, event := range evt.Events { + if strings.Contains(string(event.Kv.Key), ch) { + cnt++ } } - c <- struct{}{} - }() - - vchan := &datapb.VchannelInfo{ - CollectionID: 1, - ChannelName: ch, - UnflushedSegments: []*datapb.SegmentInfo{}, + if cnt >= 2 { + break + } } - info := &datapb.ChannelWatchInfo{ - State: datapb.ChannelWatchState_Uncomplete, - Vchan: vchan, - } - val, err := proto.Marshal(info) - assert.Nil(t, err) - err = kv.Save(path, string(val)) - assert.Nil(t, err) + c <- struct{}{} + }() - <-c - node.chanMut.RLock() - _, has := node.vchan2SyncService[ch] - node.chanMut.RUnlock() - assert.True(t, has) - - kv.RemoveWithPrefix(fmt.Sprintf("channel/%d", node.NodeID)) - //TODO there is not way to sync Release done, use sleep for now - time.Sleep(100 * time.Millisecond) - - node.chanMut.RLock() - _, has = node.vchan2SyncService[ch] - node.chanMut.RUnlock() - assert.False(t, has) + vchan := &datapb.VchannelInfo{ + CollectionID: 1, + ChannelName: ch, + UnflushedSegments: []*datapb.SegmentInfo{}, } + info := &datapb.ChannelWatchInfo{ + State: datapb.ChannelWatchState_Uncomplete, + Vchan: vchan, + } + val, err := proto.Marshal(info) + assert.Nil(t, err) + err = kv.Save(path, string(val)) + assert.Nil(t, err) + + <-c + node.chanMut.RLock() + _, has := node.vchan2SyncService[ch] + node.chanMut.RUnlock() + assert.True(t, has) + + kv.RemoveWithPrefix(fmt.Sprintf("channel/%d", node.NodeID)) + //TODO there is not way to sync Release done, use sleep for now + time.Sleep(100 * time.Millisecond) + + node.chanMut.RLock() + _, has = node.vchan2SyncService[ch] + node.chanMut.RUnlock() + assert.False(t, has) }) }