Integration test framework (#21283)

Signed-off-by: wayblink <anyang.wang@zilliz.com>
pull/21679/head
wayblink 2023-01-12 19:49:40 +08:00 committed by GitHub
parent 76c0292bca
commit 6a722396bd
67 changed files with 3102 additions and 527 deletions

View File

@ -121,7 +121,7 @@ func TestServer_CreateIndex(t *testing.T) {
Value: "DISKANN",
},
}
s.indexNodeManager = NewNodeManager(ctx)
s.indexNodeManager = NewNodeManager(ctx, defaultIndexNodeCreatorFunc)
resp, err := s.CreateIndex(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetErrorCode())

View File

@ -24,7 +24,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
grpcindexnodeclient "github.com/milvus-io/milvus/internal/distributed/indexnode/client"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/metrics"
@ -34,19 +33,21 @@ import (
// IndexNodeManager is used to manage the client of IndexNode.
type IndexNodeManager struct {
nodeClients map[UniqueID]types.IndexNode
stoppingNodes map[UniqueID]struct{}
lock sync.RWMutex
ctx context.Context
nodeClients map[UniqueID]types.IndexNode
stoppingNodes map[UniqueID]struct{}
lock sync.RWMutex
ctx context.Context
indexNodeCreator indexNodeCreatorFunc
}
// NewNodeManager is used to create a new IndexNodeManager.
func NewNodeManager(ctx context.Context) *IndexNodeManager {
func NewNodeManager(ctx context.Context, indexNodeCreator indexNodeCreatorFunc) *IndexNodeManager {
return &IndexNodeManager{
nodeClients: make(map[UniqueID]types.IndexNode),
stoppingNodes: make(map[UniqueID]struct{}),
lock: sync.RWMutex{},
ctx: ctx,
nodeClients: make(map[UniqueID]types.IndexNode),
stoppingNodes: make(map[UniqueID]struct{}),
lock: sync.RWMutex{},
ctx: ctx,
indexNodeCreator: indexNodeCreator,
}
}
@ -84,7 +85,7 @@ func (nm *IndexNodeManager) AddNode(nodeID UniqueID, address string) error {
err error
)
nodeClient, err = grpcindexnodeclient.NewClient(context.TODO(), address, Params.DataCoordCfg.WithCredential.GetAsBool())
nodeClient, err = nm.indexNodeCreator(context.TODO(), address)
if err != nil {
log.Error("create IndexNode client fail", zap.Error(err))
return err
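
With the IndexNode client factory now passed into NewNodeManager, a test can hand the manager an in-process stub instead of a gRPC client. A minimal sketch, assuming a hypothetical mockIndexNode type that satisfies types.IndexNode:

// Editorial sketch, not part of this commit; mockIndexNode is hypothetical.
func TestAddNode_WithInjectedCreator(t *testing.T) {
	nm := NewNodeManager(context.Background(), func(ctx context.Context, addr string) (types.IndexNode, error) {
		return &mockIndexNode{}, nil // in-process stub, no gRPC dial to addr
	})
	err := nm.AddNode(100, "in-memory-addr")
	assert.NoError(t, err)
}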

View File

@ -31,7 +31,7 @@ import (
)
func TestIndexNodeManager_AddNode(t *testing.T) {
nm := NewNodeManager(context.Background())
nm := NewNodeManager(context.Background(), defaultIndexNodeCreatorFunc)
nodeID, client := nm.PeekClient(&model.SegmentIndex{})
assert.Equal(t, int64(-1), nodeID)
assert.Nil(t, client)
@ -255,7 +255,7 @@ func TestIndexNodeManager_ClientSupportDisk(t *testing.T) {
}
func TestNodeManager_StoppingNode(t *testing.T) {
nm := NewNodeManager(context.Background())
nm := NewNodeManager(context.Background(), defaultIndexNodeCreatorFunc)
err := nm.AddNode(1, "indexnode-1")
assert.NoError(t, err)
assert.Equal(t, 1, len(nm.GetAllClients()))

View File

@ -229,6 +229,18 @@ func newMockDataNodeClient(id int64, ch chan interface{}) (*mockDataNodeClient,
}, nil
}
type mockIndexNodeClient struct {
id int64
state commonpb.StateCode
}
func newMockIndexNodeClient(id int64) (*mockIndexNodeClient, error) {
return &mockIndexNodeClient{
id: id,
state: commonpb.StateCode_Initializing,
}, nil
}
func (c *mockDataNodeClient) Init() error {
return nil
}
@ -417,7 +429,7 @@ func (m *mockRootCoordService) GetStatisticsChannel(ctx context.Context) (*milvu
panic("not implemented") // TODO: Implement
}
//DDL request
// DDL request
func (m *mockRootCoordService) CreateCollection(ctx context.Context, req *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) {
panic("not implemented") // TODO: Implement
}
@ -489,7 +501,7 @@ func (m *mockRootCoordService) ShowPartitionsInternal(ctx context.Context, req *
return m.ShowPartitions(ctx, req)
}
//global timestamp allocator
// global timestamp allocator
func (m *mockRootCoordService) AllocTimestamp(ctx context.Context, req *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error) {
if m.state != commonpb.StateCode_Healthy {
return &rootcoordpb.AllocTimestampResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}}, nil
@ -523,7 +535,7 @@ func (m *mockRootCoordService) AllocID(ctx context.Context, req *rootcoordpb.All
}, nil
}
//segment
// segment
func (m *mockRootCoordService) DescribeSegment(ctx context.Context, req *milvuspb.DescribeSegmentRequest) (*milvuspb.DescribeSegmentResponse, error) {
panic("not implemented") // TODO: Implement
}

View File

@ -34,6 +34,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
datanodeclient "github.com/milvus-io/milvus/internal/distributed/datanode/client"
indexnodeclient "github.com/milvus-io/milvus/internal/distributed/indexnode/client"
rootcoordclient "github.com/milvus-io/milvus/internal/distributed/rootcoord/client"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/log"
@ -79,6 +80,7 @@ type (
)
type dataNodeCreatorFunc func(ctx context.Context, addr string) (types.DataNode, error)
type indexNodeCreatorFunc func(ctx context.Context, addr string) (types.IndexNode, error)
type rootCoordCreatorFunc func(ctx context.Context, metaRootPath string, etcdClient *clientv3.Client) (types.RootCoord, error)
// makes sure Server implements `DataCoord`
@ -131,6 +133,7 @@ type Server struct {
activateFunc func()
dataNodeCreator dataNodeCreatorFunc
indexNodeCreator indexNodeCreatorFunc
rootCoordClientCreator rootCoordCreatorFunc
//indexCoord types.IndexCoord
@ -153,36 +156,36 @@ func defaultServerHelper() ServerHelper {
// Option utility function signature to set DataCoord server attributes
type Option func(svr *Server)
// SetRootCoordCreator returns an `Option` setting RootCoord creator with provided parameter
func SetRootCoordCreator(creator rootCoordCreatorFunc) Option {
// WithRootCoordCreator returns an `Option` setting RootCoord creator with provided parameter
func WithRootCoordCreator(creator rootCoordCreatorFunc) Option {
return func(svr *Server) {
svr.rootCoordClientCreator = creator
}
}
// SetServerHelper returns an `Option` setting ServerHelp with provided parameter
func SetServerHelper(helper ServerHelper) Option {
// WithServerHelper returns an `Option` setting ServerHelp with provided parameter
func WithServerHelper(helper ServerHelper) Option {
return func(svr *Server) {
svr.helper = helper
}
}
// SetCluster returns an `Option` setting Cluster with provided parameter
func SetCluster(cluster *Cluster) Option {
// WithCluster returns an `Option` setting Cluster with provided parameter
func WithCluster(cluster *Cluster) Option {
return func(svr *Server) {
svr.cluster = cluster
}
}
// SetDataNodeCreator returns an `Option` setting DataNode create function
func SetDataNodeCreator(creator dataNodeCreatorFunc) Option {
// WithDataNodeCreator returns an `Option` setting DataNode create function
func WithDataNodeCreator(creator dataNodeCreatorFunc) Option {
return func(svr *Server) {
svr.dataNodeCreator = creator
}
}
// SetSegmentManager returns an Option to set SegmentManager
func SetSegmentManager(manager Manager) Option {
// WithSegmentManager returns an Option to set SegmentManager
func WithSegmentManager(manager Manager) Option {
return func(svr *Server) {
svr.segmentManager = manager
}
@ -199,6 +202,7 @@ func CreateServer(ctx context.Context, factory dependency.Factory, opts ...Optio
buildIndexCh: make(chan UniqueID, 1024),
notifyIndexChan: make(chan UniqueID),
dataNodeCreator: defaultDataNodeCreatorFunc,
indexNodeCreator: defaultIndexNodeCreatorFunc,
rootCoordClientCreator: defaultRootCoordCreatorFunc,
helper: defaultServerHelper(),
metricsCacheManager: metricsinfo.NewMetricsCacheManager(),
@ -215,6 +219,10 @@ func defaultDataNodeCreatorFunc(ctx context.Context, addr string) (types.DataNod
return datanodeclient.NewClient(ctx, addr)
}
func defaultIndexNodeCreatorFunc(ctx context.Context, addr string) (types.IndexNode, error) {
return indexnodeclient.NewClient(context.TODO(), addr, Params.DataCoordCfg.WithCredential.GetAsBool())
}
func defaultRootCoordCreatorFunc(ctx context.Context, metaRootPath string, client *clientv3.Client) (types.RootCoord, error) {
return rootcoordclient.NewClient(ctx, metaRootPath, client)
}
@ -374,6 +382,18 @@ func (s *Server) SetEtcdClient(client *clientv3.Client) {
s.etcdCli = client
}
func (s *Server) SetRootCoord(rootCoord types.RootCoord) {
s.rootCoordClient = rootCoord
}
func (s *Server) SetDataNodeCreator(f func(context.Context, string) (types.DataNode, error)) {
s.dataNodeCreator = f
}
func (s *Server) SetIndexNodeCreator(f func(context.Context, string) (types.IndexNode, error)) {
s.indexNodeCreator = f
}
func (s *Server) createCompactionHandler() {
s.compactionHandler = newCompactionPlanHandler(s.sessionManager, s.channelManager, s.meta, s.allocator, s.flushCh)
}
@ -465,6 +485,9 @@ func (s *Server) initSegmentManager() {
}
func (s *Server) initMeta(chunkManager storage.ChunkManager) error {
if s.meta != nil {
return nil
}
etcdKV := etcdkv.NewEtcdKV(s.etcdCli, Params.EtcdCfg.MetaRootPath.GetValue())
s.kvClient = etcdKV
@ -488,7 +511,7 @@ func (s *Server) initIndexBuilder(manager storage.ChunkManager) {
func (s *Server) initIndexNodeManager() {
if s.indexNodeManager == nil {
s.indexNodeManager = NewNodeManager(s.ctx)
s.indexNodeManager = NewNodeManager(s.ctx, s.indexNodeCreator)
}
}
@ -834,8 +857,10 @@ func (s *Server) handleFlushingSegments(ctx context.Context) {
func (s *Server) initRootCoordClient() error {
var err error
if s.rootCoordClient, err = s.rootCoordClientCreator(s.ctx, Params.EtcdCfg.MetaRootPath.GetValue(), s.etcdCli); err != nil {
return err
if s.rootCoordClient == nil {
if s.rootCoordClient, err = s.rootCoordClientCreator(s.ctx, Params.EtcdCfg.MetaRootPath.GetValue(), s.etcdCli); err != nil {
return err
}
}
if err = s.rootCoordClient.Init(); err != nil {
return err
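
The renamed functional options (SetXxx to WithXxx), the new indexNodeCreator default, and the SetRootCoord/SetDataNodeCreator/SetIndexNodeCreator setters together let a DataCoord server be assembled with every remote dependency replaced in-process. A rough wiring sketch, assuming cluster, segmentManager, etcdCli, factory and mockRootCoord are built elsewhere in the test:

// Sketch only: DataCoord with injected dependencies for an integration test.
func newInjectedDataCoord(ctx context.Context) *Server {
	svr := CreateServer(ctx, factory, WithCluster(cluster), WithSegmentManager(segmentManager))
	svr.SetEtcdClient(etcdCli)
	svr.SetRootCoord(mockRootCoord) // skips rootCoordClientCreator in initRootCoordClient
	svr.SetDataNodeCreator(func(ctx context.Context, addr string) (types.DataNode, error) {
		return newMockDataNodeClient(0, nil) // in-process DataNode
	})
	svr.SetIndexNodeCreator(func(ctx context.Context, addr string) (types.IndexNode, error) {
		return indexnode.NewMockIndexNodeComponent(ctx) // in-process IndexNode
	})
	return svr
}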

View File

@ -40,6 +40,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/indexnode"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/mq/msgstream"
@ -1302,7 +1303,7 @@ func TestSaveBinlogPaths(t *testing.T) {
/*
t.Run("test save dropped segment and remove channel", func(t *testing.T) {
spyCh := make(chan struct{}, 1)
svr := newTestServer(t, nil, SetSegmentManager(&spySegmentManager{spyCh: spyCh}))
svr := newTestServer(t, nil, WithSegmentManager(&spySegmentManager{spyCh: spyCh}))
defer closeTestServer(t, svr)
svr.meta.AddCollection(&collectionInfo{ID: 1})
@ -1333,7 +1334,7 @@ func TestSaveBinlogPaths(t *testing.T) {
func TestDropVirtualChannel(t *testing.T) {
t.Run("normal DropVirtualChannel", func(t *testing.T) {
spyCh := make(chan struct{}, 1)
svr := newTestServer(t, nil, SetSegmentManager(&spySegmentManager{spyCh: spyCh}))
svr := newTestServer(t, nil, WithSegmentManager(&spySegmentManager{spyCh: spyCh}))
defer closeTestServer(t, svr)
@ -1668,7 +1669,7 @@ func TestDataNodeTtChannel(t *testing.T) {
helper := ServerHelper{
eventAfterHandleDataNodeTt: func() { ch <- struct{}{} },
}
svr := newTestServer(t, nil, SetServerHelper(helper))
svr := newTestServer(t, nil, WithServerHelper(helper))
defer closeTestServer(t, svr)
svr.meta.AddCollection(&collectionInfo{
@ -2835,13 +2836,13 @@ func TestOptions(t *testing.T) {
kv.Close()
}()
t.Run("SetRootCoordCreator", func(t *testing.T) {
t.Run("WithRootCoordCreator", func(t *testing.T) {
svr := newTestServer(t, nil)
defer closeTestServer(t, svr)
var crt rootCoordCreatorFunc = func(ctx context.Context, metaRoot string, etcdClient *clientv3.Client) (types.RootCoord, error) {
return nil, errors.New("dummy")
}
opt := SetRootCoordCreator(crt)
opt := WithRootCoordCreator(crt)
assert.NotNil(t, opt)
svr.rootCoordClientCreator = nil
opt(svr)
@ -2850,7 +2851,7 @@ func TestOptions(t *testing.T) {
assert.NotNil(t, crt)
assert.NotNil(t, svr.rootCoordClientCreator)
})
t.Run("SetCluster", func(t *testing.T) {
t.Run("WithCluster", func(t *testing.T) {
defer kv.RemoveWithPrefix("")
sessionManager := NewSessionManager()
@ -2859,17 +2860,17 @@ func TestOptions(t *testing.T) {
cluster := NewCluster(sessionManager, channelManager)
assert.Nil(t, err)
opt := SetCluster(cluster)
opt := WithCluster(cluster)
assert.NotNil(t, opt)
svr := newTestServer(t, nil, opt)
defer closeTestServer(t, svr)
assert.Same(t, cluster, svr.cluster)
})
t.Run("SetDataNodeCreator", func(t *testing.T) {
t.Run("WithDataNodeCreator", func(t *testing.T) {
var target int64
var val = rand.Int63()
opt := SetDataNodeCreator(func(context.Context, string) (types.DataNode, error) {
opt := WithDataNodeCreator(func(context.Context, string) (types.DataNode, error) {
target = val
return nil, nil
})
@ -2918,7 +2919,7 @@ func TestHandleSessionEvent(t *testing.T) {
assert.Nil(t, err)
defer cluster.Close()
svr := newTestServer(t, nil, SetCluster(cluster))
svr := newTestServer(t, nil, WithCluster(cluster))
defer closeTestServer(t, svr)
t.Run("handle events", func(t *testing.T) {
// None event
@ -3779,6 +3780,7 @@ func testDataCoordBase(t *testing.T, opts ...Option) *Server {
paramtable.Get().Save(Params.CommonCfg.DataCoordTimeTick.Key, Params.CommonCfg.DataCoordTimeTick.GetValue()+strconv.Itoa(rand.Int()))
factory := dependency.NewDefaultFactory(true)
ctx := context.Background()
etcdCli, err := etcd.GetEtcdClient(
Params.EtcdCfg.UseEmbedEtcd.GetAsBool(),
Params.EtcdCfg.EtcdUseSSL.GetAsBool(),
@ -3789,17 +3791,18 @@ func testDataCoordBase(t *testing.T, opts ...Option) *Server {
Params.EtcdCfg.EtcdTLSMinVersion.GetValue())
assert.Nil(t, err)
sessKey := path.Join(Params.EtcdCfg.MetaRootPath.GetValue(), sessionutil.DefaultServiceRoot)
_, err = etcdCli.Delete(context.Background(), sessKey, clientv3.WithPrefix())
_, err = etcdCli.Delete(ctx, sessKey, clientv3.WithPrefix())
assert.Nil(t, err)
svr := CreateServer(context.TODO(), factory, opts...)
svr := CreateServer(ctx, factory, opts...)
svr.SetEtcdClient(etcdCli)
svr.dataNodeCreator = func(ctx context.Context, addr string) (types.DataNode, error) {
svr.SetDataNodeCreator(func(ctx context.Context, addr string) (types.DataNode, error) {
return newMockDataNodeClient(0, nil)
}
svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) {
return newMockRootCoordService(), nil
}
})
svr.SetIndexNodeCreator(func(ctx context.Context, addr string) (types.IndexNode, error) {
return indexnode.NewMockIndexNodeComponent(ctx)
})
svr.SetRootCoord(newMockRootCoordService())
err = svr.Init()
assert.Nil(t, err)

View File

@ -117,6 +117,9 @@ type DataNode struct {
rootCoord types.RootCoord
dataCoord types.DataCoord
//call once
initOnce sync.Once
sessionMu sync.Mutex // to fix data race
session *sessionutil.Session
watchKv kv.MetaKv
chunkManager storage.ChunkManager
@ -153,6 +156,10 @@ func (node *DataNode) SetAddress(address string) {
node.address = address
}
func (node *DataNode) GetAddress() string {
return node.address
}
// SetEtcdClient sets etcd client for DataNode
func (node *DataNode) SetEtcdClient(etcdCli *clientv3.Client) {
node.etcdCli = etcdCli
@ -186,7 +193,7 @@ func (node *DataNode) Register() error {
// Start liveness check
go node.session.LivenessCheck(node.ctx, func() {
log.Error("Data Node disconnected from etcd, process will exit", zap.Int64("Server Id", node.session.ServerID))
log.Error("Data Node disconnected from etcd, process will exit", zap.Int64("Server Id", node.GetSession().ServerID))
if err := node.Stop(); err != nil {
log.Fatal("failed to stop server", zap.Error(err))
}
@ -222,37 +229,42 @@ func (node *DataNode) initRateCollector() error {
return nil
}
// Init function does nothing now.
func (node *DataNode) Init() error {
log.Info("DataNode server initializing",
zap.String("TimeTickChannelName", Params.CommonCfg.DataCoordTimeTick.GetValue()),
)
if err := node.initSession(); err != nil {
log.Error("DataNode server init session failed", zap.Error(err))
return err
}
var initError error
node.initOnce.Do(func() {
logutil.Logger(node.ctx).Info("DataNode server initializing",
zap.String("TimeTickChannelName", Params.CommonCfg.DataCoordTimeTick.GetValue()),
)
if err := node.initSession(); err != nil {
log.Error("DataNode server init session failed", zap.Error(err))
initError = err
return
}
err := node.initRateCollector()
if err != nil {
log.Error("DataNode server init rateCollector failed", zap.Int64("node ID", paramtable.GetNodeID()), zap.Error(err))
return err
}
log.Info("DataNode server init rateCollector done", zap.Int64("node ID", paramtable.GetNodeID()))
err := node.initRateCollector()
if err != nil {
log.Error("DataNode server init rateCollector failed", zap.Int64("node ID", paramtable.GetNodeID()), zap.Error(err))
initError = err
return
}
log.Info("DataNode server init rateCollector done", zap.Int64("node ID", paramtable.GetNodeID()))
idAllocator, err := allocator2.NewIDAllocator(node.ctx, node.rootCoord, paramtable.GetNodeID())
if err != nil {
log.Error("failed to create id allocator",
zap.Error(err),
zap.String("role", typeutil.DataNodeRole), zap.Int64("DataNode ID", paramtable.GetNodeID()))
return err
}
node.rowIDAllocator = idAllocator
idAllocator, err := allocator2.NewIDAllocator(node.ctx, node.rootCoord, paramtable.GetNodeID())
if err != nil {
log.Error("failed to create id allocator",
zap.Error(err),
zap.String("role", typeutil.DataNodeRole), zap.Int64("DataNode ID", paramtable.GetNodeID()))
initError = err
return
}
node.rowIDAllocator = idAllocator
node.factory.Init(Params)
log.Info("DataNode server init succeeded",
zap.String("MsgChannelSubName", Params.CommonCfg.DataNodeSubName.GetValue()))
node.factory.Init(Params)
log.Info("DataNode server init succeeded",
zap.String("MsgChannelSubName", Params.CommonCfg.DataNodeSubName.GetValue()))
return nil
})
return initError
}
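
DataNode.Init is now wrapped in sync.Once, so repeated Init calls inside a single shared process (as the integration test framework does) run the body only once. A stripped-down illustration of the pattern, not Milvus code; here the error is stored on the struct so later callers observe the same result:

type component struct {
	initOnce sync.Once
	initErr  error
}

func (c *component) Init() error {
	c.initOnce.Do(func() {
		// setup is a placeholder for the real initialization work; any failure
		// is recorded once and returned to every subsequent caller.
		c.initErr = setup()
	})
	return c.initErr
}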
// StartWatchChannels start loop to watch channel allocation status via kv(etcd for now)
@ -260,7 +272,8 @@ func (node *DataNode) StartWatchChannels(ctx context.Context) {
defer logutil.LogPanic()
// REF MEP#7 watch path should be [prefix]/channel/{node_id}/{channel_name}
// TODO, this is risky, we'd better watch etcd with revision rather simply a path
watchPrefix := path.Join(Params.CommonCfg.DataCoordWatchSubPath.GetValue(), fmt.Sprintf("%d", paramtable.GetNodeID()))
watchPrefix := path.Join(Params.CommonCfg.DataCoordWatchSubPath.GetValue(), fmt.Sprintf("%d", node.GetSession().ServerID))
log.Info("Start watch channel", zap.String("prefix", watchPrefix))
evtChan := node.watchKv.WatchWithPrefix(watchPrefix)
// after watch, first check all exists nodes first
err := node.checkWatchedList()
@ -412,7 +425,7 @@ func (node *DataNode) handlePutEvent(watchInfo *datapb.ChannelWatchInfo, version
return fmt.Errorf("fail to marshal watchInfo with state, vChanName: %s, state: %s ,err: %w", vChanName, watchInfo.State.String(), err)
}
key := path.Join(Params.CommonCfg.DataCoordWatchSubPath.GetValue(), fmt.Sprintf("%d", paramtable.GetNodeID()), vChanName)
key := path.Join(Params.CommonCfg.DataCoordWatchSubPath.GetValue(), fmt.Sprintf("%d", node.GetSession().ServerID), vChanName)
success, err := node.watchKv.CompareVersionAndSwap(key, version, string(v))
// etcd error, retrying
@ -558,3 +571,17 @@ func (node *DataNode) Stop() error {
return nil
}
// to fix data race
func (node *DataNode) SetSession(session *sessionutil.Session) {
node.sessionMu.Lock()
defer node.sessionMu.Unlock()
node.session = session
}
// to fix data race
func (node *DataNode) GetSession() *sessionutil.Session {
node.sessionMu.Lock()
defer node.sessionMu.Unlock()
return node.session
}

View File

@ -102,6 +102,9 @@ func TestDataNode(t *testing.T) {
assert.Nil(t, err)
err = node.Start()
assert.Nil(t, err)
assert.Empty(t, node.GetAddress())
node.SetAddress("address")
assert.Equal(t, "address", node.GetAddress())
defer node.Stop()
node.chunkManager = storage.NewLocalChunkManager(storage.RootPath("/tmp/milvus_test/datanode"))
@ -155,7 +158,7 @@ func TestDataNode(t *testing.T) {
t.Run("Test getSystemInfoMetrics", func(t *testing.T) {
emptyNode := &DataNode{}
emptyNode.session = &sessionutil.Session{ServerID: 1}
emptyNode.SetSession(&sessionutil.Session{ServerID: 1})
emptyNode.flowgraphManager = newFlowgraphManager()
req, err := metricsinfo.ConstructRequestByMetricType(metricsinfo.SystemInfoMetrics)
@ -170,7 +173,7 @@ func TestDataNode(t *testing.T) {
t.Run("Test getSystemInfoMetrics with quotaMetric error", func(t *testing.T) {
emptyNode := &DataNode{}
emptyNode.session = &sessionutil.Session{ServerID: 1}
emptyNode.SetSession(&sessionutil.Session{ServerID: 1})
emptyNode.flowgraphManager = newFlowgraphManager()
req, err := metricsinfo.ConstructRequestByMetricType(metricsinfo.SystemInfoMetrics)

View File

@ -61,6 +61,7 @@ type dataSyncService struct {
chunkManager storage.ChunkManager
compactor *compactionExecutor // reference to compaction executor
serverID int64
stopOnce sync.Once
flushListener chan *segmentFlushPack // chan to listen flush event
}
@ -77,6 +78,7 @@ func newDataSyncService(ctx context.Context,
flushingSegCache *Cache,
chunkManager storage.ChunkManager,
compactor *compactionExecutor,
serverID int64,
) (*dataSyncService, error) {
if channel == nil {
@ -108,6 +110,7 @@ func newDataSyncService(ctx context.Context,
flushingSegCache: flushingSegCache,
chunkManager: chunkManager,
compactor: compactor,
serverID: serverID,
}
if err := service.initNodes(vchan); err != nil {
@ -127,7 +130,7 @@ type nodeConfig struct {
vChannelName string
channel Channel // Channel info
allocator allocatorInterface
serverID int64
// defaults
parallelConfig
}
@ -280,6 +283,7 @@ func (dsService *dataSyncService) initNodes(vchanInfo *datapb.VchannelInfo) erro
allocator: dsService.idAllocator,
parallelConfig: newParallelConfig(),
serverID: dsService.serverID,
}
var dmStreamNode Node

View File

@ -172,6 +172,7 @@ func TestDataSyncService_newDataSyncService(te *testing.T) {
newCache(),
cm,
newCompactionExecutor(),
0,
)
if !test.isValidCase {
@ -269,7 +270,7 @@ func TestDataSyncService_Start(t *testing.T) {
},
}
sync, err := newDataSyncService(ctx, flushChan, resendTTChan, channel, allocFactory, factory, vchan, signalCh, dataCoord, newCache(), cm, newCompactionExecutor())
sync, err := newDataSyncService(ctx, flushChan, resendTTChan, channel, allocFactory, factory, vchan, signalCh, dataCoord, newCache(), cm, newCompactionExecutor(), 0)
assert.Nil(t, err)
sync.flushListener = make(chan *segmentFlushPack)
@ -424,7 +425,7 @@ func TestDataSyncService_Close(t *testing.T) {
},
}
sync, err := newDataSyncService(ctx, flushChan, resendTTChan, channel, allocFactory, factory, vchan, signalCh, dataCoord, newCache(), cm, newCompactionExecutor())
sync, err := newDataSyncService(ctx, flushChan, resendTTChan, channel, allocFactory, factory, vchan, signalCh, dataCoord, newCache(), cm, newCompactionExecutor(), 0)
assert.Nil(t, err)
sync.flushListener = make(chan *segmentFlushPack, 10)

View File

@ -681,7 +681,7 @@ func newInsertBufferNode(ctx context.Context, collID UniqueID, delBufManager *De
commonpbutil.WithMsgType(commonpb.MsgType_DataNodeTt),
commonpbutil.WithMsgID(0),
commonpbutil.WithTimeStamp(ts),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
commonpbutil.WithSourceID(config.serverID),
),
ChannelName: config.vChannelName,
Timestamp: ts,

View File

@ -48,7 +48,7 @@ func (fm *flowgraphManager) addAndStart(dn *DataNode, vchan *datapb.VchannelInfo
var alloc allocatorInterface = newAllocator(dn.rootCoord)
dataSyncService, err := newDataSyncService(dn.ctx, make(chan flushMsg, 100), make(chan resendTTMsg, 100), channel,
alloc, dn.factory, vchan, dn.clearSignal, dn.dataCoord, dn.segmentCache, dn.chunkManager, dn.compactionExecutor)
alloc, dn.factory, vchan, dn.clearSignal, dn.dataCoord, dn.segmentCache, dn.chunkManager, dn.compactionExecutor, dn.GetSession().ServerID)
if err != nil {
log.Warn("new data sync service fail", zap.String("vChannelName", vchan.GetChannelName()), zap.Error(err))
return err

View File

@ -801,7 +801,7 @@ func flushNotifyFunc(dsService *dataSyncService, opts ...retry.Option) notifyMet
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(0),
commonpbutil.WithMsgID(0),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
commonpbutil.WithSourceID(dsService.serverID),
),
SegmentID: pack.segmentID,
CollectionID: dsService.collectionID,

View File

@ -94,7 +94,7 @@ func (node *DataNode) getSystemInfoMetrics(ctx context.Context, req *milvuspb.Ge
CreatedTime: paramtable.GetCreateTime().String(),
UpdatedTime: paramtable.GetUpdateTime().String(),
Type: typeutil.DataNodeRole,
ID: node.session.ServerID,
ID: node.GetSession().ServerID,
},
SystemConfigurations: metricsinfo.DataNodeConfiguration{
FlushInsertBufferSize: Params.DataNodeCfg.FlushInsertBufferSize.GetAsInt64(),

View File

@ -37,6 +37,7 @@ import (
s "github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/tsoutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
@ -79,6 +80,7 @@ var emptyFlushAndDropFunc flushAndDropFunc = func(_ []*segmentFlushPack) {}
func newIDLEDataNodeMock(ctx context.Context, pkType schemapb.DataType) *DataNode {
factory := dependency.NewDefaultFactory(true)
node := NewDataNode(ctx, factory)
node.SetSession(&sessionutil.Session{ServerID: 1})
rc := &RootCoordFactory{
ID: 0,

View File

@ -65,8 +65,8 @@ func (node *DataNode) WatchDmChannels(ctx context.Context, in *datapb.WatchDmCha
func (node *DataNode) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) {
log.Debug("DataNode current state", zap.Any("State", node.stateCode.Load()))
nodeID := common.NotRegisteredID
if node.session != nil && node.session.Registered() {
nodeID = node.session.ServerID
if node.GetSession() != nil && node.session.Registered() {
nodeID = node.GetSession().ServerID
}
states := &milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
@ -100,14 +100,15 @@ func (node *DataNode) FlushSegments(ctx context.Context, req *datapb.FlushSegmen
return errStatus, nil
}
if req.GetBase().GetTargetID() != node.session.ServerID {
serverID := node.GetSession().ServerID
if req.GetBase().GetTargetID() != serverID {
log.Warn("flush segment target id not matched",
zap.Int64("targetID", req.GetBase().GetTargetID()),
zap.Int64("serverID", node.session.ServerID),
zap.Int64("serverID", serverID),
)
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
Reason: common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), node.session.ServerID),
Reason: common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), serverID),
}
return status, nil
}
@ -814,7 +815,7 @@ func saveSegmentFunc(node *DataNode, req *datapb.ImportTaskRequest, res *rootcoo
commonpbutil.WithMsgType(0),
commonpbutil.WithMsgID(0),
commonpbutil.WithTimeStamp(ts),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
commonpbutil.WithSourceID(node.session.ServerID),
),
SegmentID: segmentID,
CollectionID: req.GetImportTask().GetCollectionId(),

View File

@ -123,7 +123,7 @@ func (s *DataNodeServicesSuite) TestGetComponentStates() {
s.Assert().Equal(commonpb.ErrorCode_Success, resp.Status.ErrorCode)
s.Assert().Equal(common.NotRegisteredID, resp.State.NodeID)
s.node.session = &sessionutil.Session{}
s.node.SetSession(&sessionutil.Session{})
s.node.session.UpdateRegistered(true)
resp, err = s.node.GetComponentStates(context.Background())
s.Assert().NoError(err)
@ -203,7 +203,7 @@ func (s *DataNodeServicesSuite) TestFlushSegments() {
req := &datapb.FlushSegmentsRequest{
Base: &commonpb.MsgBase{
TargetID: s.node.session.ServerID,
TargetID: s.node.GetSession().ServerID,
},
DbID: 0,
CollectionID: 1,
@ -277,7 +277,7 @@ func (s *DataNodeServicesSuite) TestFlushSegments() {
req = &datapb.FlushSegmentsRequest{
Base: &commonpb.MsgBase{
TargetID: s.node.session.ServerID,
TargetID: s.node.GetSession().ServerID,
},
DbID: 0,
CollectionID: 1,
@ -290,7 +290,7 @@ func (s *DataNodeServicesSuite) TestFlushSegments() {
req = &datapb.FlushSegmentsRequest{
Base: &commonpb.MsgBase{
TargetID: s.node.session.ServerID,
TargetID: s.node.GetSession().ServerID,
},
DbID: 0,
CollectionID: 1,
@ -314,7 +314,7 @@ func (s *DataNodeServicesSuite) TestShowConfigurations() {
//test closed server
node := &DataNode{}
node.session = &sessionutil.Session{ServerID: 1}
node.SetSession(&sessionutil.Session{ServerID: 1})
node.stateCode.Store(commonpb.StateCode_Abnormal)
resp, err := node.ShowConfigurations(s.ctx, req)
@ -331,7 +331,7 @@ func (s *DataNodeServicesSuite) TestShowConfigurations() {
func (s *DataNodeServicesSuite) TestGetMetrics() {
node := &DataNode{}
node.session = &sessionutil.Session{ServerID: 1}
node.SetSession(&sessionutil.Session{ServerID: 1})
node.flowgraphManager = newFlowgraphManager()
// server is closed
node.stateCode.Store(commonpb.StateCode_Abnormal)

View File

@ -32,7 +32,6 @@ import (
clientv3 "go.etcd.io/etcd/client/v3"
)
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
type MockDataCoord struct {
types.DataCoord
@ -102,6 +101,15 @@ func (*MockDataCoord) SetAddress(address string) {
func (m *MockDataCoord) SetEtcdClient(etcdClient *clientv3.Client) {
}
func (m *MockDataCoord) SetRootCoord(rootCoord types.RootCoord) {
}
func (m *MockDataCoord) SetDataNodeCreator(func(context.Context, string) (types.DataNode, error)) {
}
func (m *MockDataCoord) SetIndexNodeCreator(func(context.Context, string) (types.IndexNode, error)) {
}
func (m *MockDataCoord) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) {
return m.states, m.err
}
@ -264,7 +272,6 @@ func (m *MockDataCoord) DropIndex(ctx context.Context, req *indexpb.DropIndexReq
return m.dropIndexResp, m.err
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
func Test_NewServer(t *testing.T) {
paramtable.Init()
ctx := context.Background()

View File

@ -42,6 +42,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/tracer"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/componentutil"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
@ -260,7 +261,7 @@ func (s *Server) init() error {
log.Error("failed to start RootCoord client", zap.Error(err))
panic(err)
}
if err = funcutil.WaitForComponentHealthy(ctx, rootCoordClient, "RootCoord", 1000000, time.Millisecond*200); err != nil {
if err = componentutil.WaitForComponentHealthy(ctx, rootCoordClient, "RootCoord", 1000000, time.Millisecond*200); err != nil {
log.Error("failed to wait for RootCoord client to be ready", zap.Error(err))
panic(err)
}
@ -286,7 +287,7 @@ func (s *Server) init() error {
log.Error("failed to start DataCoord client", zap.Error(err))
panic(err)
}
if err = funcutil.WaitForComponentInitOrHealthy(ctx, dataCoordClient, "DataCoord", 1000000, time.Millisecond*200); err != nil {
if err = componentutil.WaitForComponentInitOrHealthy(ctx, dataCoordClient, "DataCoord", 1000000, time.Millisecond*200); err != nil {
log.Error("failed to wait for DataCoord client to be ready", zap.Error(err))
panic(err)
}

View File

@ -84,6 +84,10 @@ func (m *MockDataNode) GetStateCode() commonpb.StateCode {
func (m *MockDataNode) SetAddress(address string) {
}
func (m *MockDataNode) GetAddress() string {
return ""
}
func (m *MockDataNode) SetRootCoord(rc types.RootCoord) error {
return m.err
}

View File

@ -259,11 +259,7 @@ func (s *Server) GetMetrics(ctx context.Context, request *milvuspb.GetMetricsReq
// NewServer create a new IndexNode grpc server.
func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error) {
ctx1, cancel := context.WithCancel(ctx)
node, err := indexnode.NewIndexNode(ctx1, factory)
if err != nil {
defer cancel()
return nil, err
}
node := indexnode.NewIndexNode(ctx1, factory)
return &Server{
loopCtx: ctx1,

View File

@ -49,6 +49,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/proxy"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/componentutil"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
@ -372,7 +373,7 @@ func (s *Server) init() error {
log.Debug("init RootCoord client for Proxy done")
log.Debug("Proxy wait for RootCoord to be healthy")
if err := funcutil.WaitForComponentHealthy(s.ctx, s.rootCoordClient, "RootCoord", 1000000, time.Millisecond*200); err != nil {
if err := componentutil.WaitForComponentHealthy(s.ctx, s.rootCoordClient, "RootCoord", 1000000, time.Millisecond*200); err != nil {
log.Warn("Proxy failed to wait for RootCoord to be healthy", zap.Error(err))
return err
}
@ -401,7 +402,7 @@ func (s *Server) init() error {
log.Debug("init DataCoord client for Proxy done")
log.Debug("Proxy wait for DataCoord to be healthy")
if err := funcutil.WaitForComponentHealthy(s.ctx, s.dataCoordClient, "DataCoord", 1000000, time.Millisecond*200); err != nil {
if err := componentutil.WaitForComponentHealthy(s.ctx, s.dataCoordClient, "DataCoord", 1000000, time.Millisecond*200); err != nil {
log.Warn("Proxy failed to wait for DataCoord to be healthy", zap.Error(err))
return err
}
@ -430,7 +431,7 @@ func (s *Server) init() error {
log.Debug("init QueryCoord client for Proxy done")
log.Debug("Proxy wait for QueryCoord to be healthy")
if err := funcutil.WaitForComponentHealthy(s.ctx, s.queryCoordClient, "QueryCoord", 1000000, time.Millisecond*200); err != nil {
if err := componentutil.WaitForComponentHealthy(s.ctx, s.queryCoordClient, "QueryCoord", 1000000, time.Millisecond*200); err != nil {
log.Warn("Proxy failed to wait for QueryCoord to be healthy", zap.Error(err))
return err
}

View File

@ -797,6 +797,10 @@ func (m *MockProxy) SetQueryCoordClient(queryCoord types.QueryCoord) {
}
func (m *MockProxy) SetQueryNodeCreator(func(ctx context.Context, addr string) (types.QueryNode, error)) {
}
func (m *MockProxy) GetRateLimiter() (types.Limiter, error) {
return nil, nil
}
@ -808,6 +812,10 @@ func (m *MockProxy) UpdateStateCode(stateCode commonpb.StateCode) {
func (m *MockProxy) SetAddress(address string) {
}
func (m *MockProxy) GetAddress() string {
return ""
}
func (m *MockProxy) SetEtcdClient(etcdClient *clientv3.Client) {
}

View File

@ -40,6 +40,7 @@ import (
qc "github.com/milvus-io/milvus/internal/querycoordv2"
"github.com/milvus-io/milvus/internal/tracer"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/componentutil"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
@ -132,25 +133,25 @@ func (s *Server) init() error {
if s.rootCoord == nil {
s.rootCoord, err = rcc.NewClient(s.loopCtx, qc.Params.EtcdCfg.MetaRootPath.GetValue(), s.etcdCli)
if err != nil {
log.Debug("QueryCoord try to new RootCoord client failed", zap.Error(err))
log.Error("QueryCoord try to new RootCoord client failed", zap.Error(err))
panic(err)
}
}
if err = s.rootCoord.Init(); err != nil {
log.Debug("QueryCoord RootCoordClient Init failed", zap.Error(err))
log.Error("QueryCoord RootCoordClient Init failed", zap.Error(err))
panic(err)
}
if err = s.rootCoord.Start(); err != nil {
log.Debug("QueryCoord RootCoordClient Start failed", zap.Error(err))
log.Error("QueryCoord RootCoordClient Start failed", zap.Error(err))
panic(err)
}
// wait for master init or healthy
log.Debug("QueryCoord try to wait for RootCoord ready")
err = funcutil.WaitForComponentHealthy(s.loopCtx, s.rootCoord, "RootCoord", 1000000, time.Millisecond*200)
err = componentutil.WaitForComponentHealthy(s.loopCtx, s.rootCoord, "RootCoord", 1000000, time.Millisecond*200)
if err != nil {
log.Debug("QueryCoord wait for RootCoord ready failed", zap.Error(err))
log.Error("QueryCoord wait for RootCoord ready failed", zap.Error(err))
panic(err)
}
@ -163,23 +164,23 @@ func (s *Server) init() error {
if s.dataCoord == nil {
s.dataCoord, err = dcc.NewClient(s.loopCtx, qc.Params.EtcdCfg.MetaRootPath.GetValue(), s.etcdCli)
if err != nil {
log.Debug("QueryCoord try to new DataCoord client failed", zap.Error(err))
log.Error("QueryCoord try to new DataCoord client failed", zap.Error(err))
panic(err)
}
}
if err = s.dataCoord.Init(); err != nil {
log.Debug("QueryCoord DataCoordClient Init failed", zap.Error(err))
log.Error("QueryCoord DataCoordClient Init failed", zap.Error(err))
panic(err)
}
if err = s.dataCoord.Start(); err != nil {
log.Debug("QueryCoord DataCoordClient Start failed", zap.Error(err))
log.Error("QueryCoord DataCoordClient Start failed", zap.Error(err))
panic(err)
}
log.Debug("QueryCoord try to wait for DataCoord ready")
err = funcutil.WaitForComponentHealthy(s.loopCtx, s.dataCoord, "DataCoord", 1000000, time.Millisecond*200)
err = componentutil.WaitForComponentHealthy(s.loopCtx, s.dataCoord, "DataCoord", 1000000, time.Millisecond*200)
if err != nil {
log.Debug("QueryCoord wait for DataCoord ready failed", zap.Error(err))
log.Error("QueryCoord wait for DataCoord ready failed", zap.Error(err))
panic(err)
}
if err := s.SetDataCoord(s.dataCoord); err != nil {

View File

@ -34,7 +34,7 @@ import (
clientv3 "go.etcd.io/etcd/client/v3"
)
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
type MockQueryCoord struct {
states *milvuspb.ComponentStates
status *commonpb.Status
@ -88,6 +88,9 @@ func (m *MockQueryCoord) SetDataCoord(types.DataCoord) error {
return nil
}
func (m *MockQueryCoord) SetQueryNodeCreator(func(ctx context.Context, addr string) (types.QueryNode, error)) {
}
func (m *MockQueryCoord) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) {
log.Debug("MockQueryCoord::WaitForComponentStates")
return m.states, m.err
@ -159,7 +162,7 @@ func (m *MockQueryCoord) CheckHealth(ctx context.Context, req *milvuspb.CheckHea
}, m.err
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
type MockRootCoord struct {
types.RootCoord
initErr error
@ -192,7 +195,7 @@ func (m *MockRootCoord) GetComponentStates(ctx context.Context) (*milvuspb.Compo
}, nil
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
type MockDataCoord struct {
types.DataCoord
initErr error
@ -225,7 +228,7 @@ func (m *MockDataCoord) GetComponentStates(ctx context.Context) (*milvuspb.Compo
}, nil
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
func TestMain(m *testing.M) {
paramtable.Init()
code := m.Run()

View File

@ -34,7 +34,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/querypb"
)
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
type MockQueryNode struct {
states *milvuspb.ComponentStates
status *commonpb.Status
@ -128,6 +128,10 @@ func (m *MockQueryNode) SyncReplicaSegments(ctx context.Context, req *querypb.Sy
func (m *MockQueryNode) SetAddress(address string) {
}
func (m *MockQueryNode) GetAddress() string {
return ""
}
func (m *MockQueryNode) SetEtcdClient(client *clientv3.Client) {
}

View File

@ -170,19 +170,35 @@ func (s *Server) init() error {
if s.newDataCoordClient != nil {
log.Debug("RootCoord start to create DataCoord client")
dataCoord := s.newDataCoordClient(rootcoord.Params.EtcdCfg.MetaRootPath.GetValue(), s.etcdCli)
if err := s.rootCoord.SetDataCoord(s.ctx, dataCoord); err != nil {
s.dataCoord = dataCoord
if err = s.dataCoord.Init(); err != nil {
log.Error("RootCoord DataCoordClient Init failed", zap.Error(err))
panic(err)
}
if err = s.dataCoord.Start(); err != nil {
log.Error("RootCoord DataCoordClient Start failed", zap.Error(err))
panic(err)
}
if err := s.rootCoord.SetDataCoord(dataCoord); err != nil {
panic(err)
}
s.dataCoord = dataCoord
}
if s.newQueryCoordClient != nil {
log.Debug("RootCoord start to create QueryCoord client")
queryCoord := s.newQueryCoordClient(rootcoord.Params.EtcdCfg.MetaRootPath.GetValue(), s.etcdCli)
s.queryCoord = queryCoord
if err := s.queryCoord.Init(); err != nil {
log.Error("RootCoord QueryCoordClient Init failed", zap.Error(err))
panic(err)
}
if err := s.queryCoord.Start(); err != nil {
log.Error("RootCoord QueryCoordClient Start failed", zap.Error(err))
panic(err)
}
if err := s.rootCoord.SetQueryCoord(queryCoord); err != nil {
panic(err)
}
s.queryCoord = queryCoord
}
return s.rootCoord.Init()

View File

@ -18,6 +18,7 @@ package grpcrootcoord
import (
"context"
"errors"
"fmt"
"math/rand"
"path"
@ -35,7 +36,6 @@ import (
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/sessionutil"
)
@ -66,7 +66,7 @@ func (m *mockCore) SetAddress(address string) {
func (m *mockCore) SetEtcdClient(etcdClient *clientv3.Client) {
}
func (m *mockCore) SetDataCoord(context.Context, types.DataCoord) error {
func (m *mockCore) SetDataCoord(types.DataCoord) error {
return nil
}
@ -74,6 +74,9 @@ func (m *mockCore) SetQueryCoord(types.QueryCoord) error {
return nil
}
func (m *mockCore) SetProxyCreator(func(ctx context.Context, addr string) (types.Proxy, error)) {
}
func (m *mockCore) Register() error {
return nil
}
@ -92,13 +95,15 @@ func (m *mockCore) Stop() error {
type mockDataCoord struct {
types.DataCoord
initErr error
startErr error
}
func (m *mockDataCoord) Init() error {
return nil
return m.initErr
}
func (m *mockDataCoord) Start() error {
return nil
return m.startErr
}
func (m *mockDataCoord) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) {
return &milvuspb.ComponentStates{
@ -119,19 +124,21 @@ func (m *mockDataCoord) Stop() error {
return fmt.Errorf("stop error")
}
type mockQuery struct {
type mockQueryCoord struct {
types.QueryCoord
initErr error
startErr error
}
func (m *mockQuery) Init() error {
return nil
func (m *mockQueryCoord) Init() error {
return m.initErr
}
func (m *mockQuery) Start() error {
return nil
func (m *mockQueryCoord) Start() error {
return m.startErr
}
func (m *mockQuery) Stop() error {
func (m *mockQueryCoord) Stop() error {
return fmt.Errorf("stop error")
}
@ -154,7 +161,7 @@ func TestRun(t *testing.T) {
return &mockDataCoord{}
}
svr.newQueryCoordClient = func(string, *clientv3.Client) types.QueryCoord {
return &mockQuery{}
return &mockQueryCoord{}
}
paramtable.Get().Save(rcServerConfig.Port.Key, fmt.Sprintf("%d", rand.Int()%100+10000))
@ -192,19 +199,66 @@ func TestRun(t *testing.T) {
}
func initEtcd(etcdEndpoints []string) (*clientv3.Client, error) {
var etcdCli *clientv3.Client
connectEtcdFn := func() error {
etcd, err := clientv3.New(clientv3.Config{Endpoints: etcdEndpoints, DialTimeout: 5 * time.Second})
if err != nil {
return err
}
etcdCli = etcd
return nil
}
err := retry.Do(context.TODO(), connectEtcdFn, retry.Attempts(100))
if err != nil {
return nil, err
}
return etcdCli, nil
}
func TestServerRun_DataCoordClientInitErr(t *testing.T) {
paramtable.Init()
ctx := context.Background()
server, err := NewServer(ctx, nil)
assert.Nil(t, err)
assert.NotNil(t, server)
server.newDataCoordClient = func(string, *clientv3.Client) types.DataCoord {
return &mockDataCoord{initErr: errors.New("mock datacoord init error")}
}
assert.Panics(t, func() { server.Run() })
err = server.Stop()
assert.Nil(t, err)
}
func TestServerRun_DataCoordClientStartErr(t *testing.T) {
paramtable.Init()
ctx := context.Background()
server, err := NewServer(ctx, nil)
assert.Nil(t, err)
assert.NotNil(t, server)
server.newDataCoordClient = func(string, *clientv3.Client) types.DataCoord {
return &mockDataCoord{startErr: errors.New("mock datacoord start error")}
}
assert.Panics(t, func() { server.Run() })
err = server.Stop()
assert.Nil(t, err)
}
func TestServerRun_QueryCoordClientInitErr(t *testing.T) {
paramtable.Init()
ctx := context.Background()
server, err := NewServer(ctx, nil)
assert.Nil(t, err)
assert.NotNil(t, server)
server.newQueryCoordClient = func(string, *clientv3.Client) types.QueryCoord {
return &mockQueryCoord{initErr: errors.New("mock querycoord init error")}
}
assert.Panics(t, func() { server.Run() })
err = server.Stop()
assert.Nil(t, err)
}
func TestServer_QueryCoordClientStartErr(t *testing.T) {
paramtable.Init()
ctx := context.Background()
server, err := NewServer(ctx, nil)
assert.Nil(t, err)
assert.NotNil(t, server)
server.newQueryCoordClient = func(string, *clientv3.Client) types.QueryCoord {
return &mockQueryCoord{startErr: errors.New("mock querycoord start error")}
}
assert.Panics(t, func() { server.Run() })
err = server.Stop()
assert.Nil(t, err)
}

View File

@ -97,7 +97,7 @@ type IndexNode struct {
}
// NewIndexNode creates a new IndexNode component.
func NewIndexNode(ctx context.Context, factory dependency.Factory) (*IndexNode, error) {
func NewIndexNode(ctx context.Context, factory dependency.Factory) *IndexNode {
log.Debug("New IndexNode ...")
rand.Seed(time.Now().UnixNano())
ctx1, cancel := context.WithCancel(ctx)
@ -109,13 +109,10 @@ func NewIndexNode(ctx context.Context, factory dependency.Factory) (*IndexNode,
tasks: map[taskKey]*taskInfo{},
}
b.UpdateStateCode(commonpb.StateCode_Abnormal)
sc, err := NewTaskScheduler(b.loopCtx)
if err != nil {
return nil, err
}
sc := NewTaskScheduler(b.loopCtx)
b.sched = sc
return b, nil
return b
}
// Register register index node at etcd.
@ -349,3 +346,7 @@ func (i *IndexNode) ShowConfigurations(ctx context.Context, req *internalpb.Show
func (i *IndexNode) SetAddress(address string) {
i.address = address
}
func (i *IndexNode) GetAddress() string {
return i.address
}

View File

@ -18,10 +18,7 @@ func NewMockIndexNodeComponent(ctx context.Context) (types.IndexNodeComponent, e
chunkMgr: &mockChunkmgr{},
}
node, err := NewIndexNode(ctx, factory)
if err != nil {
return nil, err
}
node := NewIndexNode(ctx, factory)
startEmbedEtcd()
etcdCli := getEtcdClient()

View File

@ -183,6 +183,11 @@ func (m *Mock) Register() error {
func (m *Mock) SetAddress(address string) {
m.CallSetAddress(address)
}
func (m *Mock) GetAddress() string {
return ""
}
func (m *Mock) SetEtcdClient(etcdClient *clientv3.Client) {
}
@ -209,7 +214,7 @@ func (m *Mock) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest)
return m.CallGetMetrics(ctx, req)
}
//ShowConfigurations returns the configurations of Mock indexNode matching req.Pattern
// ShowConfigurations returns the configurations of Mock indexNode matching req.Pattern
func (m *Mock) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) {
return m.CallShowConfigurations(ctx, req)
}

View File

@ -461,8 +461,7 @@ func TestComponentState(t *testing.T) {
ctx = context.TODO()
)
Params.Init()
in, err := NewIndexNode(ctx, factory)
assert.Nil(t, err)
in := NewIndexNode(ctx, factory)
in.SetEtcdClient(getEtcdClient())
state, err := in.GetComponentStates(ctx)
assert.Nil(t, err)
@ -497,8 +496,7 @@ func TestGetTimeTickChannel(t *testing.T) {
ctx = context.TODO()
)
Params.Init()
in, err := NewIndexNode(ctx, factory)
assert.Nil(t, err)
in := NewIndexNode(ctx, factory)
ret, err := in.GetTimeTickChannel(ctx)
assert.Nil(t, err)
assert.Equal(t, ret.Status.ErrorCode, commonpb.ErrorCode_Success)
@ -512,8 +510,7 @@ func TestGetStatisticChannel(t *testing.T) {
ctx = context.TODO()
)
Params.Init()
in, err := NewIndexNode(ctx, factory)
assert.Nil(t, err)
in := NewIndexNode(ctx, factory)
ret, err := in.GetStatisticsChannel(ctx)
assert.Nil(t, err)
@ -528,8 +525,7 @@ func TestIndexTaskWhenStoppingNode(t *testing.T) {
ctx = context.TODO()
)
Params.Init()
in, err := NewIndexNode(ctx, factory)
assert.Nil(t, err)
in := NewIndexNode(ctx, factory)
in.loadOrStoreTask("cluster-1", 1, &taskInfo{
state: commonpb.IndexState_InProgress,
@ -555,6 +551,19 @@ func TestIndexTaskWhenStoppingNode(t *testing.T) {
}
}
func TestGetSetAddress(t *testing.T) {
var (
factory = &mockFactory{
chunkMgr: &mockChunkmgr{},
}
ctx = context.TODO()
)
Params.Init()
in := NewIndexNode(ctx, factory)
in.SetAddress("address")
assert.Equal(t, "address", in.GetAddress())
}
func TestInitErr(t *testing.T) {
// var (
// factory = &mockFactory{}

View File

@ -172,7 +172,7 @@ type TaskScheduler struct {
}
// NewTaskScheduler creates a new task scheduler of indexing tasks.
func NewTaskScheduler(ctx context.Context) (*TaskScheduler, error) {
func NewTaskScheduler(ctx context.Context) *TaskScheduler {
ctx1, cancel := context.WithCancel(ctx)
s := &TaskScheduler{
ctx: ctx1,
@ -181,7 +181,7 @@ func NewTaskScheduler(ctx context.Context) (*TaskScheduler, error) {
}
s.IndexBuildQueue = NewIndexBuildTaskQueue(s)
return s, nil
return s
}
func (sched *TaskScheduler) scheduleIndexBuildTask() []task {

View File

@ -157,8 +157,7 @@ func newTask(cancelStage fakeTaskState, reterror map[fakeTaskState]error, expect
func TestIndexTaskScheduler(t *testing.T) {
Params.Init()
scheduler, err := NewTaskScheduler(context.TODO())
assert.Nil(t, err)
scheduler := NewTaskScheduler(context.TODO())
scheduler.Start()
tasks := make([]task, 0)
@ -188,15 +187,14 @@ func TestIndexTaskScheduler(t *testing.T) {
assert.Equal(t, tasks[len(tasks)-1].GetState(), tasks[len(tasks)-1].(*fakeTask).expectedState)
assert.Equal(t, tasks[len(tasks)-1].Ctx().(*stagectx).curstate, fakeTaskState(fakeTaskSavedIndexes))
scheduler, err = NewTaskScheduler(context.TODO())
assert.Nil(t, err)
scheduler = NewTaskScheduler(context.TODO())
tasks = make([]task, 0, 1024)
for i := 0; i < 1024; i++ {
tasks = append(tasks, newTask(fakeTaskSavedIndexes, nil, commonpb.IndexState_Finished))
assert.Nil(t, scheduler.IndexBuildQueue.Enqueue(tasks[len(tasks)-1]))
}
failTask := newTask(fakeTaskSavedIndexes, nil, commonpb.IndexState_Finished)
err = scheduler.IndexBuildQueue.Enqueue(failTask)
err := scheduler.IndexBuildQueue.Enqueue(failTask)
assert.Error(t, err)
failTask.Reset()

View File

@ -438,6 +438,10 @@ func (node *Proxy) SetAddress(address string) {
node.address = address
}
func (node *Proxy) GetAddress() string {
return node.address
}
// SetEtcdClient sets etcd client for proxy.
func (node *Proxy) SetEtcdClient(client *clientv3.Client) {
node.etcdCli = client
@ -458,6 +462,10 @@ func (node *Proxy) SetQueryCoordClient(cli types.QueryCoord) {
node.queryCoord = cli
}
func (node *Proxy) SetQueryNodeCreator(f func(ctx context.Context, addr string) (types.QueryNode, error)) {
node.shardMgr.clientCreator = f
}
// GetRateLimiter returns the rateLimiter in Proxy.
func (node *Proxy) GetRateLimiter() (types.Limiter, error) {
if node.multiRateLimiter == nil {

View File

@ -58,6 +58,7 @@ import (
"github.com/milvus-io/milvus/internal/rootcoord"
"github.com/milvus-io/milvus/internal/tracer"
"github.com/milvus-io/milvus/internal/util"
"github.com/milvus-io/milvus/internal/util/componentutil"
"github.com/milvus-io/milvus/internal/util/crypto"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/distance"
@ -474,6 +475,7 @@ func TestProxy(t *testing.T) {
var p paramtable.GrpcServerConfig
p.Init(typeutil.ProxyRole, &base)
testServer.Proxy.SetAddress(p.GetAddress())
assert.Equal(t, p.GetAddress(), testServer.Proxy.GetAddress())
go testServer.startGrpc(ctx, &wg, &p)
assert.NoError(t, testServer.waitForGrpcReady())
@ -482,7 +484,7 @@ func TestProxy(t *testing.T) {
assert.NoError(t, err)
err = rootCoordClient.Init()
assert.NoError(t, err)
err = funcutil.WaitForComponentHealthy(ctx, rootCoordClient, typeutil.RootCoordRole, attempts, sleepDuration)
err = componentutil.WaitForComponentHealthy(ctx, rootCoordClient, typeutil.RootCoordRole, attempts, sleepDuration)
assert.NoError(t, err)
proxy.SetRootCoordClient(rootCoordClient)
log.Info("Proxy set root coordinator client")
@ -491,7 +493,7 @@ func TestProxy(t *testing.T) {
assert.NoError(t, err)
err = dataCoordClient.Init()
assert.NoError(t, err)
err = funcutil.WaitForComponentHealthy(ctx, dataCoordClient, typeutil.DataCoordRole, attempts, sleepDuration)
err = componentutil.WaitForComponentHealthy(ctx, dataCoordClient, typeutil.DataCoordRole, attempts, sleepDuration)
assert.NoError(t, err)
proxy.SetDataCoordClient(dataCoordClient)
log.Info("Proxy set data coordinator client")
@ -500,9 +502,10 @@ func TestProxy(t *testing.T) {
assert.NoError(t, err)
err = queryCoordClient.Init()
assert.NoError(t, err)
err = funcutil.WaitForComponentHealthy(ctx, queryCoordClient, typeutil.QueryCoordRole, attempts, sleepDuration)
err = componentutil.WaitForComponentHealthy(ctx, queryCoordClient, typeutil.QueryCoordRole, attempts, sleepDuration)
assert.NoError(t, err)
proxy.SetQueryCoordClient(queryCoordClient)
proxy.SetQueryNodeCreator(defaultQueryNodeClientCreator)
log.Info("Proxy set query coordinator client")
proxy.UpdateStateCode(commonpb.StateCode_Initializing)

View File

@ -312,6 +312,7 @@ func (sa *segIDAssigner) syncSegments() (bool, error) {
sa.segReqs = nil
log.Debug("syncSegments call dataCoord.AssignSegmentID", zap.String("request", req.String()))
resp, err := sa.dataCoord.AssignSegmentID(context.Background(), req)
if err != nil {

View File

@ -106,7 +106,7 @@ func withShardClientCreator(creator queryNodeCreatorFunc) shardClientMgrOpt {
return func(s *shardClientMgr) { s.clientCreator = creator }
}
func defaultShardClientCreator(ctx context.Context, addr string) (types.QueryNode, error) {
func defaultQueryNodeClientCreator(ctx context.Context, addr string) (types.QueryNode, error) {
return qnClient.NewClient(ctx, addr)
}
@ -117,7 +117,7 @@ func newShardClientMgr(options ...shardClientMgrOpt) *shardClientMgr {
sync.RWMutex
data map[UniqueID]*shardClient
}{data: make(map[UniqueID]*shardClient)},
clientCreator: defaultShardClientCreator,
clientCreator: defaultQueryNodeClientCreator,
}
for _, opt := range options {
opt(s)

View File

@ -85,8 +85,9 @@ type Server struct {
broker meta.Broker
// Session
cluster session.Cluster
nodeMgr *session.NodeManager
cluster session.Cluster
nodeMgr *session.NodeManager
queryNodeCreator session.QueryNodeCreator
// Schedulers
jobScheduler *job.Scheduler
@ -117,6 +118,7 @@ func NewQueryCoord(ctx context.Context) (*Server, error) {
cancel: cancel,
}
server.UpdateStateCode(commonpb.StateCode_Abnormal)
server.queryNodeCreator = session.DefaultQueryNodeCreator
return server, nil
}
@ -182,7 +184,7 @@ func (s *Server) Init() error {
// Init session
log.Info("init session")
s.nodeMgr = session.NewNodeManager()
s.cluster = session.NewCluster(s.nodeMgr)
s.cluster = session.NewCluster(s.nodeMgr, s.queryNodeCreator)
// Init schedulers
log.Info("init schedulers")
@ -479,6 +481,10 @@ func (s *Server) SetDataCoord(dataCoord types.DataCoord) error {
return nil
}
func (s *Server) SetQueryNodeCreator(f func(ctx context.Context, addr string) (types.QueryNode, error)) {
s.queryNodeCreator = f
}
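
QueryCoord gets the same dependency-injection treatment: the QueryNode client factory defaults to session.DefaultQueryNodeCreator and can be overridden before Init builds the cluster. A hedged sketch of the test-side wiring, where NewMockQueryNode is a hypothetical in-process helper:

// Sketch only: swap the gRPC QueryNode client for an in-process mock.
server, err := NewQueryCoord(ctx)
if err != nil {
	panic(err)
}
server.SetQueryNodeCreator(func(ctx context.Context, addr string) (types.QueryNode, error) {
	return NewMockQueryNode(ctx), nil // hypothetical mock QueryNode
})
if err := server.Init(); err != nil { // Init creates the cluster with the injected creator
	panic(err)
}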
func (s *Server) recover() error {
// Recover target managers
group, ctx := errgroup.WithContext(s.ctx)

View File

@ -36,6 +36,7 @@ import (
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/mocks"
"github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/task"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/paramtable"
@ -429,7 +430,7 @@ func newQueryCoord() (*Server, error) {
return nil, err
}
server.SetEtcdClient(etcdCli)
server.SetQueryNodeCreator(session.DefaultQueryNodeCreator)
err = server.Init()
return server, err
}

View File

@ -29,6 +29,7 @@ import (
grpcquerynodeclient "github.com/milvus-io/milvus/internal/distributed/querynode/client"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"go.uber.org/zap"
)
@ -68,9 +69,15 @@ type QueryCluster struct {
stopOnce sync.Once
}
func NewCluster(nodeManager *NodeManager) *QueryCluster {
type QueryNodeCreator func(ctx context.Context, addr string) (types.QueryNode, error)
func DefaultQueryNodeCreator(ctx context.Context, addr string) (types.QueryNode, error) {
return grpcquerynodeclient.NewClient(ctx, addr)
}
func NewCluster(nodeManager *NodeManager, queryNodeCreator QueryNodeCreator) *QueryCluster {
c := &QueryCluster{
clients: newClients(),
clients: newClients(queryNodeCreator),
nodeManager: nodeManager,
ch: make(chan struct{}),
}
@ -112,7 +119,7 @@ func (c *QueryCluster) updateLoop() {
func (c *QueryCluster) LoadSegments(ctx context.Context, nodeID int64, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) {
var status *commonpb.Status
var err error
err1 := c.send(ctx, nodeID, func(cli *grpcquerynodeclient.Client) {
err1 := c.send(ctx, nodeID, func(cli types.QueryNode) {
req := proto.Clone(req).(*querypb.LoadSegmentsRequest)
req.Base.TargetID = nodeID
status, err = cli.LoadSegments(ctx, req)
@ -126,7 +133,7 @@ func (c *QueryCluster) LoadSegments(ctx context.Context, nodeID int64, req *quer
func (c *QueryCluster) WatchDmChannels(ctx context.Context, nodeID int64, req *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) {
var status *commonpb.Status
var err error
err1 := c.send(ctx, nodeID, func(cli *grpcquerynodeclient.Client) {
err1 := c.send(ctx, nodeID, func(cli types.QueryNode) {
req := proto.Clone(req).(*querypb.WatchDmChannelsRequest)
req.Base.TargetID = nodeID
status, err = cli.WatchDmChannels(ctx, req)
@ -140,7 +147,7 @@ func (c *QueryCluster) WatchDmChannels(ctx context.Context, nodeID int64, req *q
func (c *QueryCluster) UnsubDmChannel(ctx context.Context, nodeID int64, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) {
var status *commonpb.Status
var err error
err1 := c.send(ctx, nodeID, func(cli *grpcquerynodeclient.Client) {
err1 := c.send(ctx, nodeID, func(cli types.QueryNode) {
req := proto.Clone(req).(*querypb.UnsubDmChannelRequest)
req.Base.TargetID = nodeID
status, err = cli.UnsubDmChannel(ctx, req)
@ -154,7 +161,7 @@ func (c *QueryCluster) UnsubDmChannel(ctx context.Context, nodeID int64, req *qu
func (c *QueryCluster) ReleaseSegments(ctx context.Context, nodeID int64, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
var status *commonpb.Status
var err error
err1 := c.send(ctx, nodeID, func(cli *grpcquerynodeclient.Client) {
err1 := c.send(ctx, nodeID, func(cli types.QueryNode) {
req := proto.Clone(req).(*querypb.ReleaseSegmentsRequest)
req.Base.TargetID = nodeID
status, err = cli.ReleaseSegments(ctx, req)
@ -168,7 +175,7 @@ func (c *QueryCluster) ReleaseSegments(ctx context.Context, nodeID int64, req *q
func (c *QueryCluster) GetDataDistribution(ctx context.Context, nodeID int64, req *querypb.GetDataDistributionRequest) (*querypb.GetDataDistributionResponse, error) {
var resp *querypb.GetDataDistributionResponse
var err error
err1 := c.send(ctx, nodeID, func(cli *grpcquerynodeclient.Client) {
err1 := c.send(ctx, nodeID, func(cli types.QueryNode) {
req := proto.Clone(req).(*querypb.GetDataDistributionRequest)
req.Base = &commonpb.MsgBase{
TargetID: nodeID,
@ -186,7 +193,7 @@ func (c *QueryCluster) GetMetrics(ctx context.Context, nodeID int64, req *milvus
resp *milvuspb.GetMetricsResponse
err error
)
err1 := c.send(ctx, nodeID, func(cli *grpcquerynodeclient.Client) {
err1 := c.send(ctx, nodeID, func(cli types.QueryNode) {
resp, err = cli.GetMetrics(ctx, req)
})
if err1 != nil {
@ -200,7 +207,7 @@ func (c *QueryCluster) SyncDistribution(ctx context.Context, nodeID int64, req *
resp *commonpb.Status
err error
)
err1 := c.send(ctx, nodeID, func(cli *grpcquerynodeclient.Client) {
err1 := c.send(ctx, nodeID, func(cli types.QueryNode) {
req := proto.Clone(req).(*querypb.SyncDistributionRequest)
req.Base.TargetID = nodeID
resp, err = cli.SyncDistribution(ctx, req)
@ -216,18 +223,16 @@ func (c *QueryCluster) GetComponentStates(ctx context.Context, nodeID int64) (*m
resp *milvuspb.ComponentStates
err error
)
err1 := c.send(ctx, nodeID, func(cli *grpcquerynodeclient.Client) {
err1 := c.send(ctx, nodeID, func(cli types.QueryNode) {
resp, err = cli.GetComponentStates(ctx)
})
if err1 != nil {
return nil, err1
}
return resp, err
}
func (c *QueryCluster) send(ctx context.Context, nodeID int64, fn func(cli *grpcquerynodeclient.Client)) error {
func (c *QueryCluster) send(ctx context.Context, nodeID int64, fn func(cli types.QueryNode)) error {
node := c.nodeManager.Get(nodeID)
if node == nil {
return WrapErrNodeNotFound(nodeID)
@ -244,7 +249,8 @@ func (c *QueryCluster) send(ctx context.Context, nodeID int64, fn func(cli *grpc
type clients struct {
sync.RWMutex
clients map[int64]*grpcquerynodeclient.Client // nodeID -> client
clients map[int64]types.QueryNode // nodeID -> client
queryNodeCreator QueryNodeCreator
}
func (c *clients) getAllNodeIDs() []int64 {
@ -258,15 +264,15 @@ func (c *clients) getAllNodeIDs() []int64 {
return ret
}
func (c *clients) getOrCreate(ctx context.Context, node *NodeInfo) (*grpcquerynodeclient.Client, error) {
func (c *clients) getOrCreate(ctx context.Context, node *NodeInfo) (types.QueryNode, error) {
if cli := c.get(node.ID()); cli != nil {
return cli, nil
}
return c.create(node)
}
func createNewClient(ctx context.Context, addr string) (*grpcquerynodeclient.Client, error) {
newCli, err := grpcquerynodeclient.NewClient(ctx, addr)
func createNewClient(ctx context.Context, addr string, queryNodeCreator QueryNodeCreator) (types.QueryNode, error) {
newCli, err := queryNodeCreator(ctx, addr)
if err != nil {
return nil, err
}
@ -279,13 +285,13 @@ func createNewClient(ctx context.Context, addr string) (*grpcquerynodeclient.Cli
return newCli, nil
}
func (c *clients) create(node *NodeInfo) (*grpcquerynodeclient.Client, error) {
func (c *clients) create(node *NodeInfo) (types.QueryNode, error) {
c.Lock()
defer c.Unlock()
if cli, ok := c.clients[node.ID()]; ok {
return cli, nil
}
cli, err := createNewClient(context.Background(), node.Addr())
cli, err := createNewClient(context.Background(), node.Addr(), c.queryNodeCreator)
if err != nil {
return nil, err
}
@ -293,7 +299,7 @@ func (c *clients) create(node *NodeInfo) (*grpcquerynodeclient.Client, error) {
return cli, nil
}
func (c *clients) get(nodeID int64) *grpcquerynodeclient.Client {
func (c *clients) get(nodeID int64) types.QueryNode {
c.RLock()
defer c.RUnlock()
return c.clients[nodeID]
@ -320,6 +326,9 @@ func (c *clients) closeAll() {
}
}
func newClients() *clients {
return &clients{clients: make(map[int64]*grpcquerynodeclient.Client)}
func newClients(queryNodeCreator QueryNodeCreator) *clients {
return &clients{
clients: make(map[int64]types.QueryNode),
queryNodeCreator: queryNodeCreator,
}
}
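
This rewires QueryCoord's cluster layer so that the concrete gRPC client is no longer baked in: callers pass a QueryNodeCreator, and the clients pool builds missing clients lazily through it. The sketch below is a minimal, self-contained illustration of that injection-plus-lazy-cache pattern and assumes nothing from the Milvus codebase; nodeClient, clientPool and mockNode are made-up names, not real Milvus types.

package main

import (
	"context"
	"fmt"
	"sync"
)

// nodeClient is a stand-in for the slice of types.QueryNode the pool needs.
type nodeClient interface {
	GetComponentStates(ctx context.Context) (string, error)
}

// creatorFunc mirrors the QueryNodeCreator signature: an address in, a client out.
type creatorFunc func(ctx context.Context, addr string) (nodeClient, error)

// clientPool caches one client per node ID and builds missing ones lazily
// through the injected creator, as the reworked clients struct does.
type clientPool struct {
	mu      sync.Mutex
	clients map[int64]nodeClient
	creator creatorFunc
}

func newClientPool(creator creatorFunc) *clientPool {
	return &clientPool{clients: make(map[int64]nodeClient), creator: creator}
}

func (p *clientPool) getOrCreate(ctx context.Context, nodeID int64, addr string) (nodeClient, error) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if cli, ok := p.clients[nodeID]; ok {
		return cli, nil
	}
	cli, err := p.creator(ctx, addr)
	if err != nil {
		return nil, err
	}
	p.clients[nodeID] = cli
	return cli, nil
}

// mockNode lets a test exercise the pool without any gRPC connection.
type mockNode struct{ addr string }

func (m *mockNode) GetComponentStates(ctx context.Context) (string, error) {
	return "Healthy@" + m.addr, nil
}

func main() {
	// A test injects a creator that returns in-process mocks instead of dialing.
	pool := newClientPool(func(ctx context.Context, addr string) (nodeClient, error) {
		return &mockNode{addr: addr}, nil
	})
	cli, _ := pool.getOrCreate(context.Background(), 1, "in-memory:1")
	state, _ := cli.GetComponentStates(context.Background())
	fmt.Println(state) // Healthy@in-memory:1
}

In a test this lets a whole QueryCoord talk to in-process fakes instead of dialing real QueryNodes, which is the point of the new integration test framework.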

View File

@ -90,7 +90,7 @@ func (suite *ClusterTestSuite) setupCluster() {
node := NewNodeInfo(int64(i), lis.Addr().String())
suite.nodeManager.Add(node)
}
suite.cluster = NewCluster(suite.nodeManager)
suite.cluster = NewCluster(suite.nodeManager, DefaultQueryNodeCreator)
}
func (suite *ClusterTestSuite) createTestServers() []querypb.QueryNodeServer {

View File

@ -41,7 +41,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/metricsinfo"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/timerecord"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
@ -75,7 +75,7 @@ func (node *QueryNode) GetComponentStates(ctx context.Context) (*milvuspb.Compon
}
nodeID := common.NotRegisteredID
if node.session != nil && node.session.Registered() {
nodeID = paramtable.GetNodeID()
nodeID = node.GetSession().ServerID
}
info := &milvuspb.ComponentInfo{
NodeID: nodeID,
@ -172,7 +172,7 @@ func (node *QueryNode) getStatisticsWithDmlChannel(ctx context.Context, req *que
}
if !node.isHealthyOrStopping() {
failRet.Status.Reason = msgQueryNodeIsUnhealthy(paramtable.GetNodeID())
failRet.Status.Reason = msgQueryNodeIsUnhealthy(node.GetSession().ServerID)
return failRet, nil
}
node.wg.Add(1)
@ -299,9 +299,10 @@ func (node *QueryNode) getStatisticsWithDmlChannel(ctx context.Context, req *que
// WatchDmChannels creates consumers on dmChannels to receive incremental data, which is an important part of real-time query
func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) {
nodeID := node.GetSession().ServerID
// check node healthy
if !node.isHealthy() {
err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID())
err := fmt.Errorf("query node %d is not ready", nodeID)
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
@ -312,17 +313,17 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
defer node.wg.Done()
// check target matches
if in.GetBase().GetTargetID() != paramtable.GetNodeID() {
if in.GetBase().GetTargetID() != nodeID {
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
Reason: common.WrapNodeIDNotMatchMsg(in.GetBase().GetTargetID(), paramtable.GetNodeID()),
Reason: common.WrapNodeIDNotMatchMsg(in.GetBase().GetTargetID(), nodeID),
}
return status, nil
}
log := log.With(
zap.Int64("collectionID", in.GetCollectionID()),
zap.Int64("nodeID", paramtable.GetNodeID()),
zap.Int64("nodeID", nodeID),
zap.Strings("channels", lo.Map(in.GetInfos(), func(info *datapb.VchannelInfo, _ int) string {
return info.GetChannelName()
})),
@ -390,8 +391,9 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) {
// check node healthy
nodeID := node.GetSession().ServerID
if !node.isHealthyOrStopping() {
err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID())
err := fmt.Errorf("query node %d is not ready", nodeID)
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
@ -402,10 +404,10 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC
defer node.wg.Done()
// check target matches
if req.GetBase().GetTargetID() != paramtable.GetNodeID() {
if req.GetBase().GetTargetID() != nodeID {
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
Reason: common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), paramtable.GetNodeID()),
Reason: common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), nodeID),
}
return status, nil
}
@ -451,9 +453,10 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC
// LoadSegments loads historical data into the query node; historical data can be vector data or index data
func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegmentsRequest) (*commonpb.Status, error) {
nodeID := node.GetSession().ServerID
// check node healthy
if !node.isHealthy() {
err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID())
err := fmt.Errorf("query node %d is not ready", nodeID)
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
@ -464,10 +467,10 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegment
defer node.wg.Done()
// check target matches
if in.GetBase().GetTargetID() != paramtable.GetNodeID() {
if in.GetBase().GetTargetID() != nodeID {
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
Reason: common.WrapNodeIDNotMatchMsg(in.GetBase().GetTargetID(), paramtable.GetNodeID()),
Reason: common.WrapNodeIDNotMatchMsg(in.GetBase().GetTargetID(), nodeID),
}
return status, nil
}
@ -496,7 +499,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegment
startTs := time.Now()
log.Info("loadSegmentsTask init", zap.Int64("collectionID", in.CollectionID),
zap.Int64s("segmentIDs", segmentIDs),
zap.Int64("nodeID", paramtable.GetNodeID()))
zap.Int64("nodeID", nodeID))
// TODO remove concurrent load segment for now, unless we solve the memory issue
log.Info("loadSegmentsTask start ", zap.Int64("collectionID", in.CollectionID),
@ -512,7 +515,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegment
return status, nil
}
log.Info("loadSegmentsTask Enqueue done", zap.Int64("collectionID", in.CollectionID), zap.Int64s("segmentIDs", segmentIDs), zap.Int64("nodeID", paramtable.GetNodeID()))
log.Info("loadSegmentsTask Enqueue done", zap.Int64("collectionID", in.CollectionID), zap.Int64s("segmentIDs", segmentIDs), zap.Int64("nodeID", nodeID))
waitFunc := func() (*commonpb.Status, error) {
err = task.WaitToFinish()
@ -527,7 +530,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegment
log.Warn(err.Error())
return status, nil
}
log.Info("loadSegmentsTask WaitToFinish done", zap.Int64("collectionID", in.CollectionID), zap.Int64s("segmentIDs", segmentIDs), zap.Int64("nodeID", paramtable.GetNodeID()))
log.Info("loadSegmentsTask WaitToFinish done", zap.Int64("collectionID", in.CollectionID), zap.Int64s("segmentIDs", segmentIDs), zap.Int64("nodeID", nodeID))
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}, nil
@ -539,7 +542,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegment
// ReleaseCollection clears all data related to this collection on the querynode
func (node *QueryNode) ReleaseCollection(ctx context.Context, in *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
if !node.isHealthyOrStopping() {
err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID())
err := fmt.Errorf("query node %d is not ready", node.GetSession().ServerID)
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
@ -587,7 +590,7 @@ func (node *QueryNode) ReleaseCollection(ctx context.Context, in *querypb.Releas
// ReleasePartitions clears all data related to this partition on the querynode
func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
if !node.isHealthyOrStopping() {
err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID())
err := fmt.Errorf("query node %d is not ready", node.GetSession().ServerID)
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
@ -634,8 +637,9 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.Releas
// ReleaseSegments removes the specified segments from the query node according to segmentIDs, partitionIDs, and collectionID
func (node *QueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
nodeID := node.GetSession().ServerID
if !node.isHealthyOrStopping() {
err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID())
err := fmt.Errorf("query node %d is not ready", nodeID)
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
@ -646,10 +650,10 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseS
defer node.wg.Done()
// check target matches
if in.GetBase().GetTargetID() != paramtable.GetNodeID() {
if in.GetBase().GetTargetID() != nodeID {
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
Reason: common.WrapNodeIDNotMatchMsg(in.GetBase().GetTargetID(), paramtable.GetNodeID()),
Reason: common.WrapNodeIDNotMatchMsg(in.GetBase().GetTargetID(), nodeID),
}
return status, nil
}
@ -684,7 +688,7 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseS
// GetSegmentInfo returns segment information of the collection on the queryNode, and the information includes memSize, numRow, indexName, indexID ...
func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
if !node.isHealthyOrStopping() {
err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID())
err := fmt.Errorf("query node %d is not ready", node.GetSession().ServerID)
res := &querypb.GetSegmentInfoResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
@ -732,13 +736,14 @@ func filterSegmentInfo(segmentInfos []*querypb.SegmentInfo, segmentIDs map[int64
// Search performs replica search tasks.
func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) {
if !node.IsStandAlone && req.GetReq().GetBase().GetTargetID() != paramtable.GetNodeID() {
nodeID := node.GetSession().ServerID
if !node.IsStandAlone && req.GetReq().GetBase().GetTargetID() != nodeID {
return &internalpb.SearchResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
Reason: fmt.Sprintf("QueryNode %d can't serve, recovering: %s",
paramtable.GetNodeID(),
common.WrapNodeIDNotMatchMsg(req.GetReq().GetBase().GetTargetID(), paramtable.GetNodeID())),
nodeID,
common.WrapNodeIDNotMatchMsg(req.GetReq().GetBase().GetTargetID(), nodeID)),
},
}, nil
}
@ -807,13 +812,14 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
tr.CtxElapse(ctx, "search done in all shards")
rateCol.Add(metricsinfo.NQPerSecond, float64(req.GetReq().GetNq()))
rateCol.Add(metricsinfo.SearchThroughput, float64(proto.Size(req)))
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.SearchLabel).Add(float64(proto.Size(req)))
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(nodeID, 10), metrics.SearchLabel).Add(float64(proto.Size(req)))
}
return ret, nil
}
func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *querypb.SearchRequest, dmlChannel string) (*internalpb.SearchResults, error) {
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.TotalLabel).Inc()
nodeID := node.GetSession().ServerID
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(nodeID), metrics.SearchLabel, metrics.TotalLabel).Inc()
failRet := &internalpb.SearchResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
@ -822,11 +828,11 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *querypb.Se
defer func() {
if failRet.Status.ErrorCode != commonpb.ErrorCode_Success {
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.FailLabel).Inc()
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(nodeID), metrics.SearchLabel, metrics.FailLabel).Inc()
}
}()
if !node.isHealthyOrStopping() {
failRet.Status.Reason = msgQueryNodeIsUnhealthy(paramtable.GetNodeID())
failRet.Status.Reason = msgQueryNodeIsUnhealthy(nodeID)
return failRet, nil
}
node.wg.Add(1)
@ -876,13 +882,13 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *querypb.Se
tr.CtxElapse(ctx, fmt.Sprintf("do subsearch done, vChannel = %s, segmentIDs = %v", dmlChannel, req.GetSegmentIDs()))
failRet.Status.ErrorCode = commonpb.ErrorCode_Success
metrics.QueryNodeSQLatencyInQueue.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.QueryNodeSQLatencyInQueue.WithLabelValues(fmt.Sprint(nodeID),
metrics.SearchLabel).Observe(float64(historicalTask.queueDur.Milliseconds()))
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(nodeID),
metrics.SearchLabel).Observe(float64(historicalTask.reduceDur.Milliseconds()))
latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel).Inc()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(nodeID), metrics.SearchLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(nodeID), metrics.SearchLabel, metrics.SuccessLabel).Inc()
return historicalTask.Ret, nil
}
@ -923,9 +929,9 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *querypb.Se
if err != nil {
return err
}
metrics.QueryNodeSQLatencyInQueue.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.QueryNodeSQLatencyInQueue.WithLabelValues(fmt.Sprint(nodeID),
metrics.SearchLabel).Observe(float64(streamingTask.queueDur.Milliseconds()))
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(nodeID),
metrics.SearchLabel).Observe(float64(streamingTask.reduceDur.Milliseconds()))
streamingResult = streamingTask.Ret
return nil
@ -951,16 +957,17 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *querypb.Se
tr.CtxElapse(ctx, fmt.Sprintf("do reduce done in shard cluster, vChannel = %s, segmentIDs = %v", dmlChannel, req.GetSegmentIDs()))
failRet.Status.ErrorCode = commonpb.ErrorCode_Success
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.Leader).Observe(float64(tr.ElapseSpan().Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel).Inc()
metrics.QueryNodeSearchNQ.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(req.Req.GetNq()))
metrics.QueryNodeSearchTopK.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(req.Req.GetTopk()))
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(nodeID), metrics.SearchLabel, metrics.Leader).Observe(float64(tr.ElapseSpan().Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(nodeID), metrics.SearchLabel, metrics.SuccessLabel).Inc()
metrics.QueryNodeSearchNQ.WithLabelValues(fmt.Sprint(nodeID)).Observe(float64(req.Req.GetNq()))
metrics.QueryNodeSearchTopK.WithLabelValues(fmt.Sprint(nodeID)).Observe(float64(req.Req.GetTopk()))
return ret, nil
}
func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.QueryRequest, dmlChannel string) (*internalpb.RetrieveResults, error) {
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel).Inc()
nodeID := node.GetSession().ServerID
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(nodeID), metrics.QueryLabel, metrics.TotalLabel).Inc()
failRet := &internalpb.RetrieveResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
@ -972,11 +979,11 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.Que
defer func() {
if failRet.Status.ErrorCode != commonpb.ErrorCode_Success {
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.FailLabel).Inc()
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(nodeID), metrics.SearchLabel, metrics.FailLabel).Inc()
}
}()
if !node.isHealthyOrStopping() {
failRet.Status.Reason = msgQueryNodeIsUnhealthy(paramtable.GetNodeID())
failRet.Status.Reason = msgQueryNodeIsUnhealthy(nodeID)
return failRet, nil
}
node.wg.Add(1)
@ -1030,13 +1037,13 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.Que
req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs()))
failRet.Status.ErrorCode = commonpb.ErrorCode_Success
metrics.QueryNodeSQLatencyInQueue.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.QueryNodeSQLatencyInQueue.WithLabelValues(fmt.Sprint(nodeID),
metrics.QueryLabel).Observe(float64(queryTask.queueDur.Milliseconds()))
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(nodeID),
metrics.QueryLabel).Observe(float64(queryTask.reduceDur.Milliseconds()))
latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel).Inc()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(nodeID), metrics.QueryLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(nodeID), metrics.QueryLabel, metrics.SuccessLabel).Inc()
return queryTask.Ret, nil
}
@ -1067,9 +1074,9 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.Que
if err != nil {
return err
}
metrics.QueryNodeSQLatencyInQueue.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.QueryNodeSQLatencyInQueue.WithLabelValues(fmt.Sprint(nodeID),
metrics.QueryLabel).Observe(float64(streamingTask.queueDur.Milliseconds()))
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(nodeID),
metrics.QueryLabel).Observe(float64(streamingTask.reduceDur.Milliseconds()))
streamingResult = streamingTask.Ret
return nil
@ -1101,8 +1108,8 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.Que
failRet.Status.ErrorCode = commonpb.ErrorCode_Success
latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel).Inc()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(nodeID), metrics.QueryLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(nodeID), metrics.QueryLabel, metrics.SuccessLabel).Inc()
return ret, nil
}
@ -1115,13 +1122,14 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
zap.Uint64("guaranteeTimestamp", req.Req.GetGuaranteeTimestamp()),
zap.Uint64("timeTravel", req.GetReq().GetTravelTimestamp()))
if req.GetReq().GetBase().GetTargetID() != paramtable.GetNodeID() {
nodeID := node.GetSession().ServerID
if req.GetReq().GetBase().GetTargetID() != nodeID {
return &internalpb.RetrieveResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
Reason: fmt.Sprintf("QueryNode %d can't serve, recovering: %s",
paramtable.GetNodeID(),
common.WrapNodeIDNotMatchMsg(req.GetReq().GetBase().GetTargetID(), paramtable.GetNodeID())),
nodeID,
common.WrapNodeIDNotMatchMsg(req.GetReq().GetBase().GetTargetID(), nodeID)),
},
}, nil
}
@ -1185,7 +1193,7 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
if !req.FromShardLeader {
rateCol.Add(metricsinfo.NQPerSecond, 1)
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.QueryLabel).Add(float64(proto.Size(req)))
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(nodeID, 10), metrics.QueryLabel).Add(float64(proto.Size(req)))
}
return ret, nil
}
@ -1195,7 +1203,7 @@ func (node *QueryNode) SyncReplicaSegments(ctx context.Context, req *querypb.Syn
if !node.isHealthy() {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: msgQueryNodeIsUnhealthy(paramtable.GetNodeID()),
Reason: msgQueryNodeIsUnhealthy(node.GetSession().ServerID),
}, nil
}
node.wg.Add(1)
@ -1219,16 +1227,17 @@ func (node *QueryNode) SyncReplicaSegments(ctx context.Context, req *querypb.Syn
// ShowConfigurations returns the configurations of queryNode matching req.Pattern
func (node *QueryNode) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) {
nodeID := node.GetSession().ServerID
if !node.isHealthyOrStopping() {
log.Warn("QueryNode.ShowConfigurations failed",
zap.Int64("nodeID", paramtable.GetNodeID()),
zap.Int64("nodeId", nodeID),
zap.String("req", req.Pattern),
zap.Error(errQueryNodeIsUnhealthy(paramtable.GetNodeID())))
zap.Error(errQueryNodeIsUnhealthy(nodeID)))
return &internalpb.ShowConfigurationsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: msgQueryNodeIsUnhealthy(paramtable.GetNodeID()),
Reason: msgQueryNodeIsUnhealthy(nodeID),
},
Configuations: nil,
}, nil
@ -1256,16 +1265,17 @@ func (node *QueryNode) ShowConfigurations(ctx context.Context, req *internalpb.S
// GetMetrics return system infos of the query node, such as total memory, memory usage, cpu usage ...
func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
nodeID := node.GetSession().ServerID
if !node.isHealthyOrStopping() {
log.Ctx(ctx).Warn("QueryNode.GetMetrics failed",
zap.Int64("nodeID", paramtable.GetNodeID()),
zap.Int64("nodeId", nodeID),
zap.String("req", req.Request),
zap.Error(errQueryNodeIsUnhealthy(paramtable.GetNodeID())))
zap.Error(errQueryNodeIsUnhealthy(nodeID)))
return &milvuspb.GetMetricsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: msgQueryNodeIsUnhealthy(paramtable.GetNodeID()),
Reason: msgQueryNodeIsUnhealthy(nodeID),
},
Response: "",
}, nil
@ -1276,7 +1286,7 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR
metricType, err := metricsinfo.ParseMetricType(req.Request)
if err != nil {
log.Ctx(ctx).Warn("QueryNode.GetMetrics failed to parse metric type",
zap.Int64("nodeID", paramtable.GetNodeID()),
zap.Int64("nodeId", nodeID),
zap.String("req", req.Request),
zap.Error(err))
@ -1292,7 +1302,7 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR
queryNodeMetrics, err := getSystemInfoMetrics(ctx, req, node)
if err != nil {
log.Ctx(ctx).Warn("QueryNode.GetMetrics failed",
zap.Int64("nodeID", paramtable.GetNodeID()),
zap.Int64("nodeId", nodeID),
zap.String("req", req.Request),
zap.String("metricType", metricType),
zap.Error(err))
@ -1307,7 +1317,7 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR
}
log.Ctx(ctx).RatedDebug(60, "QueryNode.GetMetrics failed, request metric type is not implemented yet",
zap.Int64("nodeID", paramtable.GetNodeID()),
zap.Int64("nodeId", nodeID),
zap.String("req", req.Request),
zap.String("metricType", metricType))
@ -1321,18 +1331,19 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR
}
func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.GetDataDistributionRequest) (*querypb.GetDataDistributionResponse, error) {
nodeID := node.GetSession().ServerID
log := log.With(
zap.Int64("msg-id", req.GetBase().GetMsgID()),
zap.Int64("node-id", paramtable.GetNodeID()),
zap.Int64("node-id", nodeID),
)
if !node.isHealthyOrStopping() {
log.Warn("QueryNode.GetMetrics failed",
zap.Error(errQueryNodeIsUnhealthy(paramtable.GetNodeID())))
zap.Error(errQueryNodeIsUnhealthy(nodeID)))
return &querypb.GetDataDistributionResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: msgQueryNodeIsUnhealthy(paramtable.GetNodeID()),
Reason: msgQueryNodeIsUnhealthy(nodeID),
},
}, nil
}
@ -1340,12 +1351,12 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get
defer node.wg.Done()
// check target matches
if req.GetBase().GetTargetID() != paramtable.GetNodeID() {
if req.GetBase().GetTargetID() != nodeID {
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
Reason: fmt.Sprintf("QueryNode %d can't serve, recovering: %s",
paramtable.GetNodeID(),
common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), paramtable.GetNodeID())),
nodeID,
common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), nodeID)),
}
return &querypb.GetDataDistributionResponse{Status: status}, nil
}
@ -1407,7 +1418,7 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get
return &querypb.GetDataDistributionResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
NodeID: paramtable.GetNodeID(),
NodeID: nodeID,
Segments: segmentVersionInfos,
Channels: channelVersionInfos,
LeaderViews: leaderViews,
@ -1416,9 +1427,10 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get
func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDistributionRequest) (*commonpb.Status, error) {
log := log.Ctx(ctx).With(zap.Int64("collectionID", req.GetCollectionID()), zap.String("channel", req.GetChannel()))
nodeID := node.GetSession().ServerID
// check node healthy
if !node.isHealthyOrStopping() {
err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID())
err := fmt.Errorf("query node %d is not ready", nodeID)
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
@ -1429,11 +1441,11 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi
defer node.wg.Done()
// check target matches
if req.GetBase().GetTargetID() != paramtable.GetNodeID() {
log.Warn("failed to do match target id when sync ", zap.Int64("expect", req.GetBase().GetTargetID()), zap.Int64("actual", node.session.ServerID))
if req.GetBase().GetTargetID() != nodeID {
log.Warn("failed to do match target id when sync ", zap.Int64("expect", req.GetBase().GetTargetID()), zap.Int64("actual", nodeID))
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
Reason: common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), paramtable.GetNodeID()),
Reason: common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), nodeID),
}
return status, nil
}
@ -1476,3 +1488,17 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi
Reason: "",
}, nil
}
// to fix data race
func (node *QueryNode) SetSession(session *sessionutil.Session) {
node.sessionMu.Lock()
defer node.sessionMu.Unlock()
node.session = session
}
// to fix data race
func (node *QueryNode) GetSession() *sessionutil.Session {
node.sessionMu.Lock()
defer node.sessionMu.Unlock()
return node.session
}
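
The SetSession/GetSession pair exists because the request handlers above now read the server ID from the session on every call while tests (and registration) may swap the session concurrently; the sessionMu lock makes both sides safe. Below is a minimal sketch of the same guard, assuming only the standard library; the session and node types are simplified stand-ins, not the real sessionutil.Session or QueryNode. Running the unsynchronized version of this under go test -race is what surfaces the data race the comment refers to.

package main

import (
	"fmt"
	"sync"
)

// session is a stand-in for sessionutil.Session; only ServerID matters here.
type session struct{ ServerID int64 }

type node struct {
	mu      sync.Mutex // plays the role of sessionMu in the diff
	session *session
}

// SetSession replaces the session under the lock, e.g. when a test re-registers the node.
func (n *node) SetSession(s *session) {
	n.mu.Lock()
	defer n.mu.Unlock()
	n.session = s
}

// GetSession returns the current session; callers read ServerID from the value
// they got back instead of touching n.session from many goroutines.
func (n *node) GetSession() *session {
	n.mu.Lock()
	defer n.mu.Unlock()
	return n.session
}

func main() {
	n := &node{}
	var wg sync.WaitGroup
	wg.Add(2)
	// One goroutine keeps swapping the session while the other reads the server ID,
	// which is exactly the access pattern the race detector flagged on the raw field.
	go func() {
		defer wg.Done()
		for i := int64(0); i < 1000; i++ {
			n.SetSession(&session{ServerID: i})
		}
	}()
	go func() {
		defer wg.Done()
		for i := 0; i < 1000; i++ {
			if s := n.GetSession(); s != nil {
				_ = s.ServerID
			}
		}
	}()
	wg.Wait()
	fmt.Println("final server ID:", n.GetSession().ServerID)
}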

View File

@ -100,7 +100,7 @@ func TestImpl_WatchDmChannels(t *testing.T) {
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchDmChannels,
MsgID: rand.Int63(),
TargetID: node.session.ServerID,
TargetID: node.GetSession().ServerID,
},
NodeID: 0,
CollectionID: defaultCollectionID,
@ -187,7 +187,7 @@ func TestImpl_WatchDmChannels(t *testing.T) {
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchDmChannels,
MsgID: rand.Int63(),
TargetID: node.session.ServerID,
TargetID: node.GetSession().ServerID,
},
CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{defaultPartitionID},
@ -218,7 +218,7 @@ func TestImpl_UnsubDmChannel(t *testing.T) {
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchDmChannels,
MsgID: rand.Int63(),
TargetID: node.session.ServerID,
TargetID: node.GetSession().ServerID,
},
NodeID: 0,
CollectionID: defaultCollectionID,
@ -241,7 +241,7 @@ func TestImpl_UnsubDmChannel(t *testing.T) {
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_UnsubDmChannel,
MsgID: rand.Int63(),
TargetID: node.session.ServerID,
TargetID: node.GetSession().ServerID,
},
NodeID: 0,
CollectionID: defaultCollectionID,
@ -299,7 +299,7 @@ func TestImpl_LoadSegments(t *testing.T) {
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchQueryChannels,
MsgID: rand.Int63(),
TargetID: node.session.ServerID,
TargetID: node.GetSession().ServerID,
},
DstNodeID: 0,
Schema: schema,
@ -540,11 +540,11 @@ func TestImpl_ShowConfigurations(t *testing.T) {
t.Run("test ShowConfigurations", func(t *testing.T) {
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
node.session = sessionutil.NewSession(node.queryNodeLoopCtx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdCli)
node.SetSession(sessionutil.NewSession(node.queryNodeLoopCtx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdCli))
pattern := "Cache"
req := &internalpb.ShowConfigurationsRequest{
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels, node.session.ServerID),
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels, node.GetSession().ServerID),
Pattern: pattern,
}
@ -556,12 +556,12 @@ func TestImpl_ShowConfigurations(t *testing.T) {
t.Run("test ShowConfigurations node failed", func(t *testing.T) {
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
node.session = sessionutil.NewSession(node.queryNodeLoopCtx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdCli)
node.SetSession(sessionutil.NewSession(node.queryNodeLoopCtx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdCli))
node.UpdateStateCode(commonpb.StateCode_Abnormal)
pattern := "Cache"
req := &internalpb.ShowConfigurationsRequest{
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels, node.session.ServerID),
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels, node.GetSession().ServerID),
Pattern: pattern,
}
@ -592,7 +592,7 @@ func TestImpl_GetMetrics(t *testing.T) {
defer wg.Done()
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
node.session = sessionutil.NewSession(node.queryNodeLoopCtx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdCli)
node.SetSession(sessionutil.NewSession(node.queryNodeLoopCtx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdCli))
metricReq := make(map[string]string)
metricReq[metricsinfo.MetricTypeKey] = "system_info"
@ -644,7 +644,7 @@ func TestImpl_ReleaseSegments(t *testing.T) {
defer node.Stop()
req := &queryPb.ReleaseSegmentsRequest{
Base: genCommonMsgBase(commonpb.MsgType_ReleaseSegments, node.session.ServerID),
Base: genCommonMsgBase(commonpb.MsgType_ReleaseSegments, node.GetSession().ServerID),
CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{defaultPartitionID},
SegmentIDs: []UniqueID{defaultSegmentID},
@ -669,7 +669,7 @@ func TestImpl_ReleaseSegments(t *testing.T) {
defer node.Stop()
req := &queryPb.ReleaseSegmentsRequest{
Base: genCommonMsgBase(commonpb.MsgType_ReleaseSegments, node.session.ServerID),
Base: genCommonMsgBase(commonpb.MsgType_ReleaseSegments, node.GetSession().ServerID),
CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{defaultPartitionID},
SegmentIDs: []UniqueID{defaultSegmentID},
@ -704,7 +704,7 @@ func TestImpl_ReleaseSegments(t *testing.T) {
defer node.Stop()
req := &queryPb.ReleaseSegmentsRequest{
Base: genCommonMsgBase(commonpb.MsgType_ReleaseSegments, node.session.ServerID),
Base: genCommonMsgBase(commonpb.MsgType_ReleaseSegments, node.GetSession().ServerID),
CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{defaultPartitionID},
SegmentIDs: []UniqueID{defaultSegmentID},
@ -725,7 +725,7 @@ func TestImpl_ReleaseSegments(t *testing.T) {
assert.NoError(t, err)
req := &queryPb.ReleaseSegmentsRequest{
Base: genCommonMsgBase(commonpb.MsgType_ReleaseSegments, node.session.ServerID),
Base: genCommonMsgBase(commonpb.MsgType_ReleaseSegments, node.GetSession().ServerID),
CollectionID: defaultCollectionID,
}
@ -1056,7 +1056,7 @@ func TestSyncDistribution(t *testing.T) {
defer node.Stop()
resp, err := node.SyncDistribution(ctx, &querypb.SyncDistributionRequest{
Base: &commonpb.MsgBase{TargetID: node.session.ServerID},
Base: &commonpb.MsgBase{TargetID: node.GetSession().ServerID},
CollectionID: defaultCollectionID,
Channel: defaultDMLChannel,
Actions: []*querypb.SyncAction{
@ -1086,7 +1086,7 @@ func TestSyncDistribution(t *testing.T) {
cs.SetupFirstVersion()
resp, err := node.SyncDistribution(ctx, &querypb.SyncDistributionRequest{
Base: &commonpb.MsgBase{TargetID: node.session.ServerID},
Base: &commonpb.MsgBase{TargetID: node.GetSession().ServerID},
CollectionID: defaultCollectionID,
Channel: defaultDMLChannel,
Actions: []*querypb.SyncAction{
@ -1109,7 +1109,7 @@ func TestSyncDistribution(t *testing.T) {
assert.Equal(t, segmentStateLoaded, segment.state)
assert.EqualValues(t, 1, segment.version)
resp, err = node.SyncDistribution(ctx, &querypb.SyncDistributionRequest{
Base: &commonpb.MsgBase{TargetID: node.session.ServerID},
Base: &commonpb.MsgBase{TargetID: node.GetSession().ServerID},
CollectionID: defaultCollectionID,
Channel: defaultDMLChannel,
Actions: []*querypb.SyncAction{

View File

@ -38,7 +38,7 @@ type ImplUtilsSuite struct {
func (s *ImplUtilsSuite) SetupSuite() {
s.querynode = newQueryNodeMock()
client := v3client.New(embedetcdServer.Server)
s.querynode.session = sessionutil.NewSession(context.Background(), "milvus_ut/sessions", client)
s.querynode.SetSession(sessionutil.NewSession(context.Background(), "milvus_ut/sessions", client))
s.querynode.UpdateStateCode(commonpb.StateCode_Healthy)
s.querynode.ShardClusterService = newShardClusterService(client, s.querynode.session, s.querynode)
@ -52,8 +52,8 @@ func (s *ImplUtilsSuite) SetupTest() {
nodeEvent := []nodeEvent{
{
nodeID: s.querynode.session.ServerID,
nodeAddr: s.querynode.session.ServerName,
nodeID: s.querynode.GetSession().ServerID,
nodeAddr: s.querynode.GetSession().ServerName,
isLeader: true,
},
}
@ -75,9 +75,9 @@ func (s *ImplUtilsSuite) TestTransferLoad() {
s.Run("normal transfer load", func() {
status, err := s.querynode.TransferLoad(ctx, &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{
TargetID: s.querynode.session.ServerID,
TargetID: s.querynode.GetSession().ServerID,
},
DstNodeID: s.querynode.session.ServerID,
DstNodeID: s.querynode.GetSession().ServerID,
Infos: []*querypb.SegmentLoadInfo{
{
SegmentID: defaultSegmentID,
@ -95,9 +95,9 @@ func (s *ImplUtilsSuite) TestTransferLoad() {
s.Run("transfer non-exist channel load", func() {
status, err := s.querynode.TransferLoad(ctx, &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{
TargetID: s.querynode.session.ServerID,
TargetID: s.querynode.GetSession().ServerID,
},
DstNodeID: s.querynode.session.ServerID,
DstNodeID: s.querynode.GetSession().ServerID,
Infos: []*querypb.SegmentLoadInfo{
{
SegmentID: defaultSegmentID,
@ -115,9 +115,9 @@ func (s *ImplUtilsSuite) TestTransferLoad() {
s.Run("transfer empty load segments", func() {
status, err := s.querynode.TransferLoad(ctx, &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{
TargetID: s.querynode.session.ServerID,
TargetID: s.querynode.GetSession().ServerID,
},
DstNodeID: s.querynode.session.ServerID,
DstNodeID: s.querynode.GetSession().ServerID,
Infos: []*querypb.SegmentLoadInfo{},
})
@ -141,7 +141,7 @@ func (s *ImplUtilsSuite) TestTransferLoad() {
status, err := s.querynode.TransferLoad(ctx, &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{
TargetID: s.querynode.session.ServerID,
TargetID: s.querynode.GetSession().ServerID,
},
DstNodeID: 100,
Infos: []*querypb.SegmentLoadInfo{
@ -197,12 +197,12 @@ func (s *ImplUtilsSuite) TestTransferRelease() {
s.Run("normal transfer release", func() {
status, err := s.querynode.TransferRelease(ctx, &querypb.ReleaseSegmentsRequest{
Base: &commonpb.MsgBase{
TargetID: s.querynode.session.ServerID,
TargetID: s.querynode.GetSession().ServerID,
},
SegmentIDs: []int64{},
Scope: querypb.DataScope_All,
Shard: defaultChannelName,
NodeID: s.querynode.session.ServerID,
NodeID: s.querynode.GetSession().ServerID,
})
s.NoError(err)
@ -212,12 +212,12 @@ func (s *ImplUtilsSuite) TestTransferRelease() {
s.Run("transfer non-exist channel release", func() {
status, err := s.querynode.TransferRelease(ctx, &querypb.ReleaseSegmentsRequest{
Base: &commonpb.MsgBase{
TargetID: s.querynode.session.ServerID,
TargetID: s.querynode.GetSession().ServerID,
},
SegmentIDs: []int64{},
Scope: querypb.DataScope_All,
Shard: "invalid_channel",
NodeID: s.querynode.session.ServerID,
NodeID: s.querynode.GetSession().ServerID,
})
s.NoError(err)
@ -239,7 +239,7 @@ func (s *ImplUtilsSuite) TestTransferRelease() {
status, err := s.querynode.TransferRelease(ctx, &querypb.ReleaseSegmentsRequest{
Base: &commonpb.MsgBase{
TargetID: s.querynode.session.ServerID,
TargetID: s.querynode.GetSession().ServerID,
},
SegmentIDs: []int64{},
Scope: querypb.DataScope_All,

View File

@ -89,7 +89,7 @@ func TestTask_loadSegmentsTask(t *testing.T) {
node.metaReplica.removeSegment(defaultSegmentID, segmentTypeSealed)
req := &querypb.LoadSegmentsRequest{
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.session.ServerID),
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.GetSession().ServerID),
Schema: schema,
Infos: []*querypb.SegmentLoadInfo{
{
@ -117,7 +117,7 @@ func TestTask_loadSegmentsTask(t *testing.T) {
node.metaReplica.removeSegment(defaultSegmentID, segmentTypeSealed)
req := &querypb.LoadSegmentsRequest{
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.session.ServerID),
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.GetSession().ServerID),
Schema: schema,
Infos: []*querypb.SegmentLoadInfo{
{
@ -210,7 +210,7 @@ func TestTask_loadSegmentsTask(t *testing.T) {
CollectionID: defaultCollectionID,
}
req := &querypb.LoadSegmentsRequest{
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.session.ServerID),
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.GetSession().ServerID),
Schema: schema,
Infos: []*querypb.SegmentLoadInfo{segmentLoadInfo},
DeltaPositions: []*internalpb.MsgPosition{
@ -305,7 +305,7 @@ func TestTask_loadSegmentsTask(t *testing.T) {
node.metaReplica.removeSegment(defaultSegmentID, segmentTypeSealed)
req := &querypb.LoadSegmentsRequest{
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.session.ServerID),
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.GetSession().ServerID),
Schema: schema,
Infos: []*querypb.SegmentLoadInfo{
{
@ -360,7 +360,7 @@ func TestTask_loadSegmentsTask(t *testing.T) {
segmentID1 := defaultSegmentID
segmentID2 := defaultSegmentID + 1
req := &querypb.LoadSegmentsRequest{
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.session.ServerID),
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.GetSession().ServerID),
Schema: schema,
Infos: []*querypb.SegmentLoadInfo{
{
@ -428,7 +428,7 @@ func TestTask_loadSegmentsTaskLoadDelta(t *testing.T) {
CollectionID: defaultCollectionID,
}
loadReq := &querypb.LoadSegmentsRequest{
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.session.ServerID),
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.GetSession().ServerID),
Schema: schema,
Infos: []*querypb.SegmentLoadInfo{segmentLoadInfo},
DeltaPositions: []*internalpb.MsgPosition{
@ -458,7 +458,7 @@ func TestTask_loadSegmentsTaskLoadDelta(t *testing.T) {
// load second segments with same channel
loadReq = &querypb.LoadSegmentsRequest{
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.session.ServerID),
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.GetSession().ServerID),
Schema: schema,
Infos: []*querypb.SegmentLoadInfo{
{

View File

@ -81,7 +81,7 @@ func getSystemInfoMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest,
}, nil
}
hardwareInfos := metricsinfo.HardwareMetrics{
IP: node.session.Address,
IP: node.GetSession().Address,
CPUCoreCount: hardware.GetCPUNum(),
CPUCoreUsage: hardware.GetCPUUsage(),
Memory: totalMem,
@ -99,7 +99,7 @@ func getSystemInfoMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest,
CreatedTime: paramtable.GetCreateTime().String(),
UpdatedTime: paramtable.GetUpdateTime().String(),
Type: typeutil.QueryNodeRole,
ID: node.session.ServerID,
ID: node.GetSession().ServerID,
},
SystemConfigurations: metricsinfo.QueryNodeConfiguration{
SimdType: Params.CommonCfg.SimdType.GetValue(),

View File

@ -48,10 +48,10 @@ func TestGetSystemInfoMetrics(t *testing.T) {
Params.EtcdCfg.EtcdTLSMinVersion.GetValue())
assert.NoError(t, err)
defer etcdCli.Close()
node.session = sessionutil.NewSession(node.queryNodeLoopCtx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdCli)
node.SetSession(sessionutil.NewSession(node.queryNodeLoopCtx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdCli))
req := &milvuspb.GetMetricsRequest{
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels, node.session.ServerID),
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels, node.GetSession().ServerID),
}
resp, err := getSystemInfoMetrics(ctx, req, node)
assert.NoError(t, err)

View File

@ -1741,7 +1741,7 @@ func genSimpleQueryNodeWithMQFactory(ctx context.Context, fac dependency.Factory
}
// init shard cluster service
node.ShardClusterService = newShardClusterService(node.etcdCli, node.session, node)
node.ShardClusterService = newShardClusterService(node.etcdCli, node.GetSession(), node)
node.queryShardService, err = newQueryShardService(node.queryNodeLoopCtx,
node.metaReplica, node.tSafeReplica,

View File

@ -110,8 +110,9 @@ type QueryNode struct {
factory dependency.Factory
scheduler *taskScheduler
session *sessionutil.Session
eventCh <-chan *sessionutil.SessionEvent
sessionMu sync.Mutex
session *sessionutil.Session
eventCh <-chan *sessionutil.SessionEvent
vectorStorage storage.ChunkManager
etcdKV *etcdkv.EtcdKV
@ -393,3 +394,7 @@ func (node *QueryNode) SetEtcdClient(client *clientv3.Client) {
func (node *QueryNode) SetAddress(address string) {
node.address = address
}
func (node *QueryNode) GetAddress() string {
return node.address
}

View File

@ -197,6 +197,9 @@ func TestQueryNode_init(t *testing.T) {
node.SetEtcdClient(etcdcli)
err = node.Init()
assert.Nil(t, err)
assert.Empty(t, node.GetAddress())
node.SetAddress("address")
assert.Equal(t, "address", node.GetAddress())
}
func genSimpleQueryNodeToTestWatchChangeInfo(ctx context.Context) (*QueryNode, error) {

View File

@ -27,6 +27,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
grpcproxyclient "github.com/milvus-io/milvus/internal/distributed/proxy/client"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/types"
@ -34,7 +35,21 @@ import (
"github.com/milvus-io/milvus/internal/util/sessionutil"
)
type proxyCreator func(sess *sessionutil.Session) (types.Proxy, error)
type proxyCreator func(ctx context.Context, addr string) (types.Proxy, error)
func DefaultProxyCreator(ctx context.Context, addr string) (types.Proxy, error) {
cli, err := grpcproxyclient.NewClient(ctx, addr)
if err != nil {
return nil, err
}
if err := cli.Init(); err != nil {
return nil, err
}
if err := cli.Start(); err != nil {
return nil, err
}
return cli, nil
}
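
DefaultProxyCreator now has the same (ctx, addr) shape as the other creator funcs and folds the whole client lifecycle (construct, Init, Start) into the factory, so the connect method that follows only needs session.Address. A small sketch of that factory shape, assuming a generic client with the usual Init/Start lifecycle; the client type and the newReadyClient name are illustrative, not Milvus APIs.

package main

import (
	"context"
	"fmt"
)

// client is a stand-in for a gRPC client wrapper with the usual lifecycle.
type client struct{ addr string }

func newClient(ctx context.Context, addr string) (*client, error) { return &client{addr: addr}, nil }
func (c *client) Init() error                                     { return nil }
func (c *client) Start() error                                    { return nil }

// newReadyClient mirrors the factory idea: the caller gets back a client that
// is already initialized and started, or an error from whichever step failed.
func newReadyClient(ctx context.Context, addr string) (*client, error) {
	cli, err := newClient(ctx, addr)
	if err != nil {
		return nil, err
	}
	if err := cli.Init(); err != nil {
		return nil, err
	}
	if err := cli.Start(); err != nil {
		return nil, err
	}
	return cli, nil
}

func main() {
	cli, err := newReadyClient(context.Background(), "localhost:19530")
	if err != nil {
		panic(err)
	}
	fmt.Println("proxy client ready at", cli.addr)
}
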
type proxyClientManager struct {
creator proxyCreator
@ -85,7 +100,7 @@ func (p *proxyClientManager) GetProxyCount() int {
}
func (p *proxyClientManager) connect(session *sessionutil.Session) {
pc, err := p.creator(session)
pc, err := p.creator(context.Background(), session.Address)
if err != nil {
log.Warn("failed to create proxy client", zap.String("address", session.Address), zap.Int64("serverID", session.ServerID), zap.Error(err))
return

View File

@ -117,7 +117,7 @@ func TestProxyClientManager_GetProxyClients(t *testing.T) {
defer cli.Close()
assert.Nil(t, err)
core.etcdCli = cli
core.proxyCreator = func(se *sessionutil.Session) (types.Proxy, error) {
core.proxyCreator = func(ctx context.Context, addr string) (types.Proxy, error) {
return nil, errors.New("failed")
}
@ -149,7 +149,7 @@ func TestProxyClientManager_AddProxyClient(t *testing.T) {
defer cli.Close()
core.etcdCli = cli
core.proxyCreator = func(se *sessionutil.Session) (types.Proxy, error) {
core.proxyCreator = func(ctx context.Context, addr string) (types.Proxy, error) {
return nil, errors.New("failed")
}

View File

@ -36,7 +36,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/common"
pnc "github.com/milvus-io/milvus/internal/distributed/proxy/client"
"github.com/milvus-io/milvus/internal/kv"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/log"
@ -150,19 +149,7 @@ func NewCore(c context.Context, factory dependency.Factory) (*Core, error) {
}
core.UpdateStateCode(commonpb.StateCode_Abnormal)
core.proxyCreator = func(se *sessionutil.Session) (types.Proxy, error) {
cli, err := pnc.NewClient(c, se.Address)
if err != nil {
return nil, err
}
if err := cli.Init(); err != nil {
return nil, err
}
if err := cli.Start(); err != nil {
return nil, err
}
return cli, nil
}
core.SetProxyCreator(DefaultProxyCreator)
return core, nil
}
@ -263,23 +250,21 @@ func (c *Core) tsLoop() {
}
}
func (c *Core) SetDataCoord(ctx context.Context, s types.DataCoord) error {
if err := s.Init(); err != nil {
return err
}
if err := s.Start(); err != nil {
return err
func (c *Core) SetProxyCreator(f func(ctx context.Context, addr string) (types.Proxy, error)) {
c.proxyCreator = f
}
func (c *Core) SetDataCoord(s types.DataCoord) error {
if s == nil {
return errors.New("null DataCoord interface")
}
c.dataCoord = s
return nil
}
func (c *Core) SetQueryCoord(s types.QueryCoord) error {
if err := s.Init(); err != nil {
return err
}
if err := s.Start(); err != nil {
return err
if s == nil {
return errors.New("null QueryCoord interface")
}
c.queryCoord = s
return nil
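
SetDataCoord and SetQueryCoord no longer drive Init and Start on the injected coordinator; they only nil-check and store it, which moves lifecycle ownership to the caller (for example, an integration test harness that builds everything in one process). A rough sketch of the resulting wiring order, assuming a simplified coord interface; core, mockDataCoord and the setter shown here are stand-ins, not the real RootCoord types.

package main

import (
	"errors"
	"fmt"
)

// coord is a stand-in for the slice of types.DataCoord used in this sketch.
type coord interface {
	Init() error
	Start() error
}

type core struct{ dataCoord coord }

// SetDataCoord mirrors the reworked setter: store only, no lifecycle calls.
func (c *core) SetDataCoord(s coord) error {
	if s == nil {
		return errors.New("null DataCoord interface")
	}
	c.dataCoord = s
	return nil
}

type mockDataCoord struct{ started bool }

func (m *mockDataCoord) Init() error  { return nil }
func (m *mockDataCoord) Start() error { m.started = true; return nil }

func main() {
	// The caller (e.g. the test harness) now drives the lifecycle first...
	dc := &mockDataCoord{}
	if err := dc.Init(); err != nil {
		panic(err)
	}
	if err := dc.Start(); err != nil {
		panic(err)
	}

	// ...and only then hands the ready client to the core.
	c := &core{}
	if err := c.SetDataCoord(dc); err != nil {
		panic(err)
	}
	fmt.Println("data coord wired, started:", dc.started)
}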

View File

@ -52,6 +52,8 @@ type Component interface {
GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error)
GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error)
Register() error
//SetAddress(address string)
//GetAddress() string
}
// DataNode is the interface `datanode` package implements
@ -112,6 +114,7 @@ type DataNodeComponent interface {
GetStateCode() commonpb.StateCode
SetAddress(address string)
GetAddress() string
// SetEtcdClient set etcd client for DataNode
SetEtcdClient(etcdClient *clientv3.Client)
@ -370,6 +373,14 @@ type DataCoordComponent interface {
// SetEtcdClient set EtcdClient for DataCoord
// `etcdClient` is a client of etcd
SetEtcdClient(etcdClient *clientv3.Client)
SetRootCoord(rootCoord RootCoord)
// SetDataNodeCreator set DataNode client creator func for DataCoord
SetDataNodeCreator(func(context.Context, string) (DataNode, error))
// SetIndexNodeCreator set IndexNode client creator func for DataCoord
SetIndexNodeCreator(func(context.Context, string) (IndexNode, error))
}
// IndexNode is the interface `indexnode` package implements
@ -406,7 +417,7 @@ type IndexNodeComponent interface {
IndexNode
SetAddress(address string)
GetAddress() string
// SetEtcdClient set etcd client for IndexNodeComponent
SetEtcdClient(etcdClient *clientv3.Client)
@ -763,10 +774,9 @@ type RootCoordComponent interface {
// SetDataCoord set DataCoord for RootCoord
// `dataCoord` is a client of data coordinator.
// `ctx` is the context passed to the DataCoord api.
//
// Always return nil.
SetDataCoord(ctx context.Context, dataCoord DataCoord) error
SetDataCoord(dataCoord DataCoord) error
// SetQueryCoord set QueryCoord for RootCoord
// `queryCoord` is a client of query coordinator.
@ -774,6 +784,9 @@ type RootCoordComponent interface {
// Always return nil.
SetQueryCoord(queryCoord QueryCoord) error
// SetProxyCreator set Proxy client creator func for RootCoord
SetProxyCreator(func(ctx context.Context, addr string) (Proxy, error))
// GetMetrics notifies RootCoordComponent to collect metrics for specified component
GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error)
}
@ -826,6 +839,7 @@ type ProxyComponent interface {
Proxy
SetAddress(address string)
GetAddress() string
// SetEtcdClient set EtcdClient for Proxy
// `etcdClient` is a client of etcd
SetEtcdClient(etcdClient *clientv3.Client)
@ -846,6 +860,9 @@ type ProxyComponent interface {
// `queryCoord` is a client of query coordinator.
SetQueryCoordClient(queryCoord QueryCoord)
// SetQueryNodeCreator set QueryNode client creator func for Proxy
SetQueryNodeCreator(func(ctx context.Context, addr string) (QueryNode, error))
// GetRateLimiter returns the rateLimiter in Proxy
GetRateLimiter() (Limiter, error)
@ -1326,6 +1343,7 @@ type QueryNodeComponent interface {
UpdateStateCode(stateCode commonpb.StateCode)
SetAddress(address string)
GetAddress() string
// SetEtcdClient set etcd client for QueryNode
SetEtcdClient(etcdClient *clientv3.Client)
@ -1385,4 +1403,7 @@ type QueryCoordComponent interface {
// Return nil in status:
// The rootCoord is not nil.
SetRootCoord(rootCoord RootCoord) error
// SetQueryNodeCreator set QueryNode client creator func for QueryCoord
SetQueryNodeCreator(func(ctx context.Context, addr string) (QueryNode, error))
}
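
These interface additions (SetAddress/GetAddress plus the SetProxyCreator, SetDataNodeCreator, SetIndexNodeCreator and SetQueryNodeCreator hooks) are what allow an integration test to run every component in one process: before Init, the harness can replace each default gRPC creator with one that resolves an address to an already-running in-process component. The sketch below illustrates that idea only; registry, component and queryNode are hypothetical names, not part of the Milvus API.

package main

import (
	"context"
	"fmt"
)

// component is a stand-in for an in-process QueryNode/DataNode/Proxy.
type component interface{ Name() string }

type queryNode struct{ addr string }

func (q *queryNode) Name() string { return "querynode@" + q.addr }

// registry maps addresses to already-running in-process components,
// so a creator can "dial" them without any network connection.
type registry struct{ nodes map[string]component }

// creatorFor returns a creator func in the (ctx, addr) shape the interfaces now expect.
func (r *registry) creatorFor() func(ctx context.Context, addr string) (component, error) {
	return func(ctx context.Context, addr string) (component, error) {
		if c, ok := r.nodes[addr]; ok {
			return c, nil
		}
		return nil, fmt.Errorf("no in-process component at %s", addr)
	}
}

func main() {
	r := &registry{nodes: map[string]component{
		"in-memory:qn-1": &queryNode{addr: "in-memory:qn-1"},
	}}
	create := r.creatorFor()
	c, err := create(context.Background(), "in-memory:qn-1")
	if err != nil {
		panic(err)
	}
	fmt.Println("resolved:", c.Name())
}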

View File

@ -0,0 +1,73 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package componentutil
import (
"context"
"errors"
"fmt"
"time"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/retry"
)
// WaitForComponentStates waits for the component's state to become one of the specified states
func WaitForComponentStates(ctx context.Context, service types.Component, serviceName string, states []commonpb.StateCode, attempts uint, sleep time.Duration) error {
checkFunc := func() error {
resp, err := service.GetComponentStates(ctx)
if err != nil {
return err
}
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
return errors.New(resp.Status.Reason)
}
meet := false
for _, state := range states {
if resp.State.StateCode == state {
meet = true
break
}
}
if !meet {
return fmt.Errorf(
"WaitForComponentStates, not meet, %s current state: %s",
serviceName,
resp.State.StateCode.String())
}
return nil
}
return retry.Do(ctx, checkFunc, retry.Attempts(attempts), retry.Sleep(sleep))
}
// WaitForComponentInitOrHealthy waits for the component's state to become Initializing or Healthy
func WaitForComponentInitOrHealthy(ctx context.Context, service types.Component, serviceName string, attempts uint, sleep time.Duration) error {
return WaitForComponentStates(ctx, service, serviceName, []commonpb.StateCode{commonpb.StateCode_Initializing, commonpb.StateCode_Healthy}, attempts, sleep)
}
// WaitForComponentInit waits for the component's state to become Initializing
func WaitForComponentInit(ctx context.Context, service types.Component, serviceName string, attempts uint, sleep time.Duration) error {
return WaitForComponentStates(ctx, service, serviceName, []commonpb.StateCode{commonpb.StateCode_Initializing}, attempts, sleep)
}
// WaitForComponentHealthy waits for the component's state to become Healthy
func WaitForComponentHealthy(ctx context.Context, service types.Component, serviceName string, attempts uint, sleep time.Duration) error {
return WaitForComponentStates(ctx, service, serviceName, []commonpb.StateCode{commonpb.StateCode_Healthy}, attempts, sleep)
}
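A short usage sketch of the helpers above: block until a dependency reports Healthy, retrying a bounded number of times. The waitDataCoordHealthy wrapper and the componentutil import path are assumptions for illustration; WaitForComponentHealthy itself is the function defined above.
import (
	"context"
	"time"
	"github.com/milvus-io/milvus/internal/types"
	"github.com/milvus-io/milvus/internal/util/componentutil" // assumed package path
)
// waitDataCoordHealthy retries up to 20 times, sleeping 500ms between
// attempts, until the given client reports StateCode_Healthy.
func waitDataCoordHealthy(ctx context.Context, client types.DataCoord) error {
	return componentutil.WaitForComponentHealthy(ctx, client, "DataCoord", 20, 500*time.Millisecond)
}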

View File

@ -0,0 +1,133 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package componentutil
import (
"context"
"errors"
"testing"
"time"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/stretchr/testify/assert"
)
type MockComponent struct {
compState *milvuspb.ComponentStates
strResp *milvuspb.StringResponse
compErr error
}
func (mc *MockComponent) SetCompState(state *milvuspb.ComponentStates) {
mc.compState = state
}
func (mc *MockComponent) SetStrResp(resp *milvuspb.StringResponse) {
mc.strResp = resp
}
func (mc *MockComponent) Init() error {
return nil
}
func (mc *MockComponent) Start() error {
return nil
}
func (mc *MockComponent) Stop() error {
return nil
}
func (mc *MockComponent) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) {
return mc.compState, mc.compErr
}
func (mc *MockComponent) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
return mc.strResp, nil
}
func (mc *MockComponent) Register() error {
return nil
}
func buildMockComponent(code commonpb.StateCode) *MockComponent {
mc := &MockComponent{
compState: &milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
NodeID: 0,
Role: "role",
StateCode: code,
},
SubcomponentStates: nil,
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
},
strResp: nil,
compErr: nil,
}
return mc
}
func Test_WaitForComponentInitOrHealthy(t *testing.T) {
mc := &MockComponent{
compState: nil,
strResp: nil,
compErr: errors.New("error"),
}
err := WaitForComponentInitOrHealthy(context.TODO(), mc, "mockService", 1, 10*time.Millisecond)
assert.NotNil(t, err)
mc = &MockComponent{
compState: &milvuspb.ComponentStates{
State: nil,
SubcomponentStates: nil,
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError},
},
strResp: nil,
compErr: nil,
}
err = WaitForComponentInitOrHealthy(context.TODO(), mc, "mockService", 1, 10*time.Millisecond)
assert.NotNil(t, err)
validCodes := []commonpb.StateCode{commonpb.StateCode_Initializing, commonpb.StateCode_Healthy}
testCodes := []commonpb.StateCode{commonpb.StateCode_Initializing, commonpb.StateCode_Healthy, commonpb.StateCode_Abnormal}
for _, code := range testCodes {
mc := buildMockComponent(code)
err := WaitForComponentInitOrHealthy(context.TODO(), mc, "mockService", 1, 10*time.Millisecond)
if funcutil.SliceContain(validCodes, code) {
assert.Nil(t, err)
} else {
assert.NotNil(t, err)
}
}
}
func Test_WaitForComponentInit(t *testing.T) {
validCodes := []commonpb.StateCode{commonpb.StateCode_Initializing}
testCodes := []commonpb.StateCode{commonpb.StateCode_Initializing, commonpb.StateCode_Healthy, commonpb.StateCode_Abnormal}
for _, code := range testCodes {
mc := buildMockComponent(code)
err := WaitForComponentInit(context.TODO(), mc, "mockService", 1, 10*time.Millisecond)
if funcutil.SliceContain(validCodes, code) {
assert.Nil(t, err)
} else {
assert.NotNil(t, err)
}
}
}

View File

@ -39,7 +39,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/retry"
)
@ -67,51 +66,6 @@ func GetLocalIP() string {
return "127.0.0.1"
}
// WaitForComponentStates wait for component's state to be one of the specific states
func WaitForComponentStates(ctx context.Context, service types.Component, serviceName string, states []commonpb.StateCode, attempts uint, sleep time.Duration) error {
checkFunc := func() error {
resp, err := service.GetComponentStates(ctx)
if err != nil {
return err
}
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
return errors.New(resp.Status.Reason)
}
meet := false
for _, state := range states {
if resp.State.StateCode == state {
meet = true
break
}
}
if !meet {
return fmt.Errorf(
"WaitForComponentStates, not meet, %s current state: %s",
serviceName,
resp.State.StateCode.String())
}
return nil
}
return retry.Do(ctx, checkFunc, retry.Attempts(attempts), retry.Sleep(sleep))
}
// WaitForComponentInitOrHealthy wait for component's state to be initializing or healthy
func WaitForComponentInitOrHealthy(ctx context.Context, service types.Component, serviceName string, attempts uint, sleep time.Duration) error {
return WaitForComponentStates(ctx, service, serviceName, []commonpb.StateCode{commonpb.StateCode_Initializing, commonpb.StateCode_Healthy}, attempts, sleep)
}
// WaitForComponentInit wait for component's state to be initializing
func WaitForComponentInit(ctx context.Context, service types.Component, serviceName string, attempts uint, sleep time.Duration) error {
return WaitForComponentStates(ctx, service, serviceName, []commonpb.StateCode{commonpb.StateCode_Initializing}, attempts, sleep)
}
// WaitForComponentHealthy wait for component's state to be healthy
func WaitForComponentHealthy(ctx context.Context, service types.Component, serviceName string, attempts uint, sleep time.Duration) error {
return WaitForComponentStates(ctx, service, serviceName, []commonpb.StateCode{commonpb.StateCode_Healthy}, attempts, sleep)
}
// JSONToMap parse the jsonic index parameters to map
func JSONToMap(mStr string) (map[string]string, error) {
buffer := make(map[string]any)

View File

@ -35,62 +35,6 @@ import (
grpcStatus "google.golang.org/grpc/status"
)
type MockComponent struct {
compState *milvuspb.ComponentStates
strResp *milvuspb.StringResponse
compErr error
}
func (mc *MockComponent) SetCompState(state *milvuspb.ComponentStates) {
mc.compState = state
}
func (mc *MockComponent) SetStrResp(resp *milvuspb.StringResponse) {
mc.strResp = resp
}
func (mc *MockComponent) Init() error {
return nil
}
func (mc *MockComponent) Start() error {
return nil
}
func (mc *MockComponent) Stop() error {
return nil
}
func (mc *MockComponent) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) {
return mc.compState, mc.compErr
}
func (mc *MockComponent) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
return mc.strResp, nil
}
func (mc *MockComponent) Register() error {
return nil
}
func buildMockComponent(code commonpb.StateCode) *MockComponent {
mc := &MockComponent{
compState: &milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
NodeID: 0,
Role: "role",
StateCode: code,
},
SubcomponentStates: nil,
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
},
strResp: nil,
compErr: nil,
}
return mc
}
func Test_CheckGrpcReady(t *testing.T) {
errChan := make(chan error)
@ -112,68 +56,6 @@ func Test_GetLocalIP(t *testing.T) {
assert.NotZero(t, len(ip))
}
func Test_WaitForComponentInitOrHealthy(t *testing.T) {
mc := &MockComponent{
compState: nil,
strResp: nil,
compErr: errors.New("error"),
}
err := WaitForComponentInitOrHealthy(context.TODO(), mc, "mockService", 1, 10*time.Millisecond)
assert.NotNil(t, err)
mc = &MockComponent{
compState: &milvuspb.ComponentStates{
State: nil,
SubcomponentStates: nil,
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError},
},
strResp: nil,
compErr: nil,
}
err = WaitForComponentInitOrHealthy(context.TODO(), mc, "mockService", 1, 10*time.Millisecond)
assert.NotNil(t, err)
validCodes := []commonpb.StateCode{commonpb.StateCode_Initializing, commonpb.StateCode_Healthy}
testCodes := []commonpb.StateCode{commonpb.StateCode_Initializing, commonpb.StateCode_Healthy, commonpb.StateCode_Abnormal}
for _, code := range testCodes {
mc := buildMockComponent(code)
err := WaitForComponentInitOrHealthy(context.TODO(), mc, "mockService", 1, 10*time.Millisecond)
if SliceContain(validCodes, code) {
assert.Nil(t, err)
} else {
assert.NotNil(t, err)
}
}
}
func Test_WaitForComponentInit(t *testing.T) {
validCodes := []commonpb.StateCode{commonpb.StateCode_Initializing}
testCodes := []commonpb.StateCode{commonpb.StateCode_Initializing, commonpb.StateCode_Healthy, commonpb.StateCode_Abnormal}
for _, code := range testCodes {
mc := buildMockComponent(code)
err := WaitForComponentInit(context.TODO(), mc, "mockService", 1, 10*time.Millisecond)
if SliceContain(validCodes, code) {
assert.Nil(t, err)
} else {
assert.NotNil(t, err)
}
}
}
func Test_WaitForComponentHealthy(t *testing.T) {
validCodes := []commonpb.StateCode{commonpb.StateCode_Healthy}
testCodes := []commonpb.StateCode{commonpb.StateCode_Initializing, commonpb.StateCode_Healthy, commonpb.StateCode_Abnormal}
for _, code := range testCodes {
mc := buildMockComponent(code)
err := WaitForComponentHealthy(context.TODO(), mc, "mockService", 1, 10*time.Millisecond)
if SliceContain(validCodes, code) {
assert.Nil(t, err)
} else {
assert.NotNil(t, err)
}
}
}
func Test_ParseIndexParamsMap(t *testing.T) {
num := 10
keys := make([]string, 0)

View File

@ -83,6 +83,8 @@ type ComponentParam struct {
DataNodeGrpcClientCfg GrpcClientConfig
IndexCoordGrpcClientCfg GrpcClientConfig
IndexNodeGrpcClientCfg GrpcClientConfig
IntegrationTestCfg integrationTestConfig
}
// InitOnce initialize once
@ -126,6 +128,8 @@ func (p *ComponentParam) Init() {
p.DataCoordGrpcClientCfg.Init(typeutil.DataCoordRole, &p.BaseTable)
p.DataNodeGrpcClientCfg.Init(typeutil.DataNodeRole, &p.BaseTable)
p.IndexNodeGrpcClientCfg.Init(typeutil.IndexNodeRole, &p.BaseTable)
p.IntegrationTestCfg.init(&p.BaseTable)
}
func (p *ComponentParam) RocksmqEnable() bool {
@ -1732,3 +1736,17 @@ func (p *indexNodeConfig) init(base *BaseTable) {
}
p.GracefulStopTimeout.Init(base.mgr)
}
type integrationTestConfig struct {
IntegrationMode ParamItem `refreshable:"false"`
}
func (p *integrationTestConfig) init(base *BaseTable) {
p.IntegrationMode = ParamItem{
Key: "integration.test.mode",
Version: "2.2.0",
DefaultValue: "false",
PanicIfEmpty: true,
}
p.IntegrationMode.Init(base.mgr)
}
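The flag is meant to be flipped from test code before any session is created, as the session test further down does. A minimal sketch of setting and reading it, using only the paramtable accessors that appear elsewhere in this change:
params := paramtable.Get()
params.Init()
// switch this process into integration-test mode (defaults to "false")
params.Save(params.IntegrationTestCfg.IntegrationMode.Key, "true")
// components can then branch on the flag, for example to stop reusing a
// single server ID for every node in the process
if params.IntegrationTestCfg.IntegrationMode.GetAsBool() {
	// test-only behaviour
}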

View File

@ -209,6 +209,11 @@ func NewSession(ctx context.Context, metaRoot string, client *clientv3.Client, o
reuseNodeID: true,
}
// integration tests create a cluster with multiple nodes in one process, so each session needs its own node ID
if paramtable.Get().IntegrationTestCfg.IntegrationMode.GetAsBool() {
session.reuseNodeID = false
}
session.apply(opts...)
session.UpdateRegistered(false)

View File

@ -709,3 +709,28 @@ func TestSession_apply(t *testing.T) {
assert.Equal(t, int64(100), session.sessionTTL)
assert.Equal(t, int64(200), session.sessionRetryTimes)
}
func TestIntegrationMode(t *testing.T) {
ctx := context.Background()
params := paramtable.Get()
params.Init()
params.Save(params.IntegrationTestCfg.IntegrationMode.Key, "true")
endpoints := params.GetWithDefault("etcd.endpoints", paramtable.DefaultEtcdEndpoints)
metaRoot := fmt.Sprintf("%d/%s", rand.Int(), DefaultServiceRoot)
etcdEndpoints := strings.Split(endpoints, ",")
etcdCli, err := etcd.GetRemoteEtcdClient(etcdEndpoints)
require.NoError(t, err)
etcdKV := etcdkv.NewEtcdKV(etcdCli, metaRoot)
err = etcdKV.RemoveWithPrefix("")
assert.NoError(t, err)
s1 := NewSession(ctx, metaRoot, etcdCli)
assert.Equal(t, false, s1.reuseNodeID)
s2 := NewSession(ctx, metaRoot, etcdCli)
assert.Equal(t, false, s2.reuseNodeID)
s1.Init("inittest1", "testAddr1", false, false)
s1.Init("inittest2", "testAddr2", false, false)
assert.NotEqual(t, s1.ServerID, s2.ServerID)
}

View File

@ -0,0 +1,372 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package integration
import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"math/rand"
"strconv"
"testing"
"time"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/util/distance"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
)
func TestHelloMilvus(t *testing.T) {
ctx := context.Background()
c, err := StartMiniCluster(ctx)
assert.NoError(t, err)
err = c.Start()
assert.NoError(t, err)
defer c.Stop()
assert.NoError(t, err)
prefix := "TestHelloMilvus"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
int64Field := "int64"
floatVecField := "fvec"
dim := 128
rowNum := 3000
constructCollectionSchema := func() *schemapb.CollectionSchema {
pk := &schemapb.FieldSchema{
FieldID: 0,
Name: int64Field,
IsPrimaryKey: true,
Description: "",
DataType: schemapb.DataType_Int64,
TypeParams: nil,
IndexParams: nil,
AutoID: true,
}
fVec := &schemapb.FieldSchema{
FieldID: 0,
Name: floatVecField,
IsPrimaryKey: false,
Description: "",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: strconv.Itoa(dim),
},
},
IndexParams: nil,
AutoID: false,
}
return &schemapb.CollectionSchema{
Name: collectionName,
Description: "",
AutoID: false,
Fields: []*schemapb.FieldSchema{
pk,
fVec,
},
}
}
schema := constructCollectionSchema()
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
createCollectionStatus, err := c.proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: 2,
})
assert.NoError(t, err)
if createCollectionStatus.GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("createCollectionStatus fail reason", zap.String("reason", createCollectionStatus.GetReason()))
}
assert.Equal(t, createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)
log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus))
showCollectionsResp, err := c.proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{})
assert.NoError(t, err)
assert.Equal(t, showCollectionsResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp))
fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim)
hashKeys := generateHashKeys(rowNum)
insertResult, err := c.proxy.Insert(ctx, &milvuspb.InsertRequest{
DbName: dbName,
CollectionName: collectionName,
FieldsData: []*schemapb.FieldData{fVecColumn},
HashKeys: hashKeys,
NumRows: uint32(rowNum),
})
assert.Equal(t, insertResult.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
// flush
flushResp, err := c.proxy.Flush(ctx, &milvuspb.FlushRequest{
DbName: dbName,
CollectionNames: []string{collectionName},
})
assert.NoError(t, err)
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
ids := segmentIDs.GetData()
assert.NotEmpty(t, segmentIDs)
segments, err := c.metaWatcher.ShowSegments()
assert.NoError(t, err)
assert.NotEmpty(t, segments)
for _, segment := range segments {
log.Info("ShowSegments result", zap.String("segment", segment.String()))
}
if has && len(ids) > 0 {
flushed := func() bool {
resp, err := c.proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{
SegmentIDs: ids,
})
if err != nil {
panic(errors.New("GetFlushState failed"))
}
return resp.GetFlushed()
}
for !flushed() {
// respect context deadline/cancel
select {
case <-ctx.Done():
panic(errors.New("deadline exceeded"))
default:
}
time.Sleep(500 * time.Millisecond)
}
}
// create index
createIndexStatus, err := c.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
CollectionName: collectionName,
FieldName: floatVecField,
IndexName: "_default",
ExtraParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: strconv.Itoa(dim),
},
{
Key: common.MetricTypeKey,
Value: distance.L2,
},
{
Key: "index_type",
Value: "IVF_FLAT",
},
{
Key: "nlist",
Value: strconv.Itoa(10),
},
},
})
if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason()))
}
assert.Equal(t, commonpb.ErrorCode_Success, createIndexStatus.GetErrorCode())
// load
loadStatus, err := c.proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
})
if loadStatus.GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("loadStatus fail reason", zap.String("reason", loadStatus.GetReason()))
}
assert.Equal(t, commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
for {
select {
case <-ctx.Done():
errors.New("context deadline exceeded")
default:
}
loadProgress, err := c.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
CollectionName: collectionName,
})
if err != nil {
panic("GetLoadingProgress fail")
}
if loadProgress.GetProgress() == 100 {
break
}
time.Sleep(500 * time.Millisecond)
}
// search
expr := fmt.Sprintf("%s > 0", "int64")
nq := 10
topk := 10
roundDecimal := -1
nprobe := 10
searchReq := constructSearchRequest("", collectionName, expr,
floatVecField, nq, dim, nprobe, topk, roundDecimal)
searchResult, err := c.proxy.Search(ctx, searchReq)
if searchResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("searchResult fail reason", zap.String("reason", searchResult.GetStatus().GetReason()))
}
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, searchResult.GetStatus().GetErrorCode())
log.Info("TestHelloMilvus succeed")
}
const (
AnnsFieldKey = "anns_field"
TopKKey = "topk"
NQKey = "nq"
MetricTypeKey = "metric_type"
SearchParamsKey = "params"
RoundDecimalKey = "round_decimal"
OffsetKey = "offset"
LimitKey = "limit"
)
func constructSearchRequest(
dbName, collectionName string,
expr string,
floatVecField string,
nq, dim, nprobe, topk, roundDecimal int,
) *milvuspb.SearchRequest {
params := make(map[string]string)
params["nprobe"] = strconv.Itoa(nprobe)
b, err := json.Marshal(params)
if err != nil {
panic(err)
}
plg := constructPlaceholderGroup(nq, dim)
plgBs, err := proto.Marshal(plg)
if err != nil {
panic(err)
}
return &milvuspb.SearchRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
PartitionNames: nil,
Dsl: expr,
PlaceholderGroup: plgBs,
DslType: commonpb.DslType_BoolExprV1,
OutputFields: nil,
SearchParams: []*commonpb.KeyValuePair{
{
Key: common.MetricTypeKey,
Value: distance.L2,
},
{
Key: SearchParamsKey,
Value: string(b),
},
{
Key: AnnsFieldKey,
Value: floatVecField,
},
{
Key: TopKKey,
Value: strconv.Itoa(topk),
},
{
Key: RoundDecimalKey,
Value: strconv.Itoa(roundDecimal),
},
},
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
}
}
func constructPlaceholderGroup(
nq, dim int,
) *commonpb.PlaceholderGroup {
values := make([][]byte, 0, nq)
for i := 0; i < nq; i++ {
bs := make([]byte, 0, dim*4)
for j := 0; j < dim; j++ {
var buffer bytes.Buffer
f := rand.Float32()
err := binary.Write(&buffer, common.Endian, f)
if err != nil {
panic(err)
}
bs = append(bs, buffer.Bytes()...)
}
values = append(values, bs)
}
return &commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{
{
Tag: "$0",
Type: commonpb.PlaceholderType_FloatVector,
Values: values,
},
},
}
}
func newFloatVectorFieldData(fieldName string, numRows, dim int) *schemapb.FieldData {
return &schemapb.FieldData{
Type: schemapb.DataType_FloatVector,
FieldName: fieldName,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: int64(dim),
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: generateFloatVectors(numRows, dim),
},
},
},
},
}
}
func generateFloatVectors(numRows, dim int) []float32 {
total := numRows * dim
ret := make([]float32, 0, total)
for i := 0; i < total; i++ {
ret = append(ret, rand.Float32())
}
return ret
}
func generateHashKeys(numRows int) []uint32 {
ret := make([]uint32, 0, numRows)
for i := 0; i < numRows; i++ {
ret = append(ret, rand.Uint32())
}
return ret
}
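Both the flush wait and the load wait in this test poll in 500ms steps with an inline select on ctx.Done(). A small shared helper could factor that out; waitUntil is a hypothetical name and not part of this change, and context and time are already imported in this file.
// waitUntil polls cond at the given interval until it returns true or the
// context is cancelled, in which case the context error is returned.
func waitUntil(ctx context.Context, interval time.Duration, cond func() bool) error {
	for {
		if cond() {
			return nil
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(interval):
		}
	}
}
The flushed closure above could then be driven as waitUntil(ctx, 500*time.Millisecond, flushed) instead of the hand-rolled loop.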

View File

@ -0,0 +1,141 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package integration
import (
"context"
"encoding/json"
"fmt"
"path"
"sort"
"time"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/util/sessionutil"
clientv3 "go.etcd.io/etcd/client/v3"
)
// MetaWatcher observes the metadata of a Milvus cluster
type MetaWatcher interface {
ShowSessions() ([]*sessionutil.Session, error)
ShowSegments() ([]*datapb.SegmentInfo, error)
ShowReplicas() ([]*milvuspb.ReplicaInfo, error)
}
type EtcdMetaWatcher struct {
MetaWatcher
rootPath string
etcdCli *clientv3.Client
}
func (watcher *EtcdMetaWatcher) ShowSessions() ([]*sessionutil.Session, error) {
metaPath := watcher.rootPath + "/meta/session"
return listSessionsByPrefix(watcher.etcdCli, metaPath)
}
func (watcher *EtcdMetaWatcher) ShowSegments() ([]*datapb.SegmentInfo, error) {
metaBasePath := path.Join(watcher.rootPath, "/meta/datacoord-meta/s/")
return listSegments(watcher.etcdCli, metaBasePath, func(s *datapb.SegmentInfo) bool {
return true
})
}
func (watcher *EtcdMetaWatcher) ShowReplicas() ([]*milvuspb.ReplicaInfo, error) {
metaBasePath := path.Join(watcher.rootPath, "/meta/querycoord-replica/")
return listReplicas(watcher.etcdCli, metaBasePath)
}
//=================== Below largely copied from birdwatcher ========================
// listSessionsByPrefix returns all sessions stored under the given etcd prefix
func listSessionsByPrefix(cli *clientv3.Client, prefix string) ([]*sessionutil.Session, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
defer cancel()
resp, err := cli.Get(ctx, prefix, clientv3.WithPrefix())
if err != nil {
return nil, err
}
sessions := make([]*sessionutil.Session, 0, len(resp.Kvs))
for _, kv := range resp.Kvs {
session := &sessionutil.Session{}
err := json.Unmarshal(kv.Value, session)
if err != nil {
continue
}
sessions = append(sessions, session)
}
return sessions, nil
}
func listSegments(cli *clientv3.Client, prefix string, filter func(*datapb.SegmentInfo) bool) ([]*datapb.SegmentInfo, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
defer cancel()
resp, err := cli.Get(ctx, prefix, clientv3.WithPrefix())
if err != nil {
return nil, err
}
segments := make([]*datapb.SegmentInfo, 0, len(resp.Kvs))
for _, kv := range resp.Kvs {
info := &datapb.SegmentInfo{}
err = proto.Unmarshal(kv.Value, info)
if err != nil {
continue
}
if filter == nil || filter(info) {
segments = append(segments, info)
}
}
sort.Slice(segments, func(i, j int) bool {
return segments[i].GetID() < segments[j].GetID()
})
return segments, nil
}
func listReplicas(cli *clientv3.Client, prefix string) ([]*milvuspb.ReplicaInfo, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
defer cancel()
resp, err := cli.Get(ctx, prefix, clientv3.WithPrefix())
if err != nil {
return nil, err
}
replicas := make([]*milvuspb.ReplicaInfo, 0, len(resp.Kvs))
for _, kv := range resp.Kvs {
replica := &milvuspb.ReplicaInfo{}
if err := proto.Unmarshal(kv.Value, replica); err != nil {
continue
}
replicas = append(replicas, replica)
}
return replicas, nil
}
func PrettyReplica(replica *milvuspb.ReplicaInfo) string {
res := fmt.Sprintf("ReplicaID: %d CollectionID: %d\n", replica.ReplicaID, replica.CollectionID)
for _, shardReplica := range replica.ShardReplicas {
res = res + fmt.Sprintf("Channel %s leader %d\n", shardReplica.DmChannelName, shardReplica.LeaderID)
}
res = res + fmt.Sprintf("Nodes:%v\n", replica.NodeIds)
return res
}
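listSegments takes a filter callback, so callers are not limited to the show-everything behaviour of ShowSegments. A sketch of a collection-scoped variant; ShowSegmentsOfCollection is a hypothetical addition, while CollectionID is a standard field of datapb.SegmentInfo:
// ShowSegmentsOfCollection keeps only the segments that belong to the given collection.
func (watcher *EtcdMetaWatcher) ShowSegmentsOfCollection(collectionID int64) ([]*datapb.SegmentInfo, error) {
	metaBasePath := path.Join(watcher.rootPath, "/meta/datacoord-meta/s/")
	return listSegments(watcher.etcdCli, metaBasePath, func(s *datapb.SegmentInfo) bool {
		return s.GetCollectionID() == collectionID
	})
}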

View File

@ -0,0 +1,334 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package integration
import (
"context"
"errors"
"strconv"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/util/distance"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
)
func TestShowSessions(t *testing.T) {
ctx := context.Background()
c, err := StartMiniCluster(ctx)
assert.NoError(t, err)
err = c.Start()
assert.NoError(t, err)
defer c.Stop()
sessions, err := c.metaWatcher.ShowSessions()
assert.NoError(t, err)
assert.NotEmpty(t, sessions)
for _, session := range sessions {
log.Info("ShowSessions result", zap.String("session", session.String()))
}
}
func TestShowSegments(t *testing.T) {
ctx := context.Background()
c, err := StartMiniCluster(ctx)
assert.NoError(t, err)
err = c.Start()
assert.NoError(t, err)
defer c.Stop()
prefix := "TestShowSegments"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
int64Field := "int64"
floatVecField := "fvec"
dim := 128
rowNum := 3000
constructCollectionSchema := func() *schemapb.CollectionSchema {
pk := &schemapb.FieldSchema{
FieldID: 0,
Name: int64Field,
IsPrimaryKey: true,
Description: "",
DataType: schemapb.DataType_Int64,
TypeParams: nil,
IndexParams: nil,
AutoID: true,
}
fVec := &schemapb.FieldSchema{
FieldID: 0,
Name: floatVecField,
IsPrimaryKey: false,
Description: "",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: strconv.Itoa(dim),
},
},
IndexParams: nil,
AutoID: false,
}
return &schemapb.CollectionSchema{
Name: collectionName,
Description: "",
AutoID: false,
Fields: []*schemapb.FieldSchema{
pk,
fVec,
},
}
}
schema := constructCollectionSchema()
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
createCollectionStatus, err := c.proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: 2,
})
assert.NoError(t, err)
assert.Equal(t, createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)
log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus))
showCollectionsResp, err := c.proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{})
assert.NoError(t, err)
assert.Equal(t, showCollectionsResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp))
fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim)
hashKeys := generateHashKeys(rowNum)
insertResult, err := c.proxy.Insert(ctx, &milvuspb.InsertRequest{
DbName: dbName,
CollectionName: collectionName,
FieldsData: []*schemapb.FieldData{fVecColumn},
HashKeys: hashKeys,
NumRows: uint32(rowNum),
})
assert.Equal(t, insertResult.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
segments, err := c.metaWatcher.ShowSegments()
assert.NoError(t, err)
assert.NotEmpty(t, segments)
for _, segment := range segments {
log.Info("ShowSegments result", zap.String("segment", segment.String()))
}
}
func TestShowReplicas(t *testing.T) {
ctx := context.Background()
c, err := StartMiniCluster(ctx)
assert.NoError(t, err)
err = c.Start()
assert.NoError(t, err)
defer c.Stop()
assert.NoError(t, err)
prefix := "TestShowReplicas"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
int64Field := "int64"
floatVecField := "fvec"
dim := 128
rowNum := 3000
constructCollectionSchema := func() *schemapb.CollectionSchema {
pk := &schemapb.FieldSchema{
FieldID: 0,
Name: int64Field,
IsPrimaryKey: true,
Description: "",
DataType: schemapb.DataType_Int64,
TypeParams: nil,
IndexParams: nil,
AutoID: true,
}
fVec := &schemapb.FieldSchema{
FieldID: 0,
Name: floatVecField,
IsPrimaryKey: false,
Description: "",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: strconv.Itoa(dim),
},
},
IndexParams: nil,
AutoID: false,
}
return &schemapb.CollectionSchema{
Name: collectionName,
Description: "",
AutoID: false,
Fields: []*schemapb.FieldSchema{
pk,
fVec,
},
}
}
schema := constructCollectionSchema()
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
createCollectionStatus, err := c.proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: 2,
})
assert.NoError(t, err)
if createCollectionStatus.GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("createCollectionStatus fail reason", zap.String("reason", createCollectionStatus.GetReason()))
}
assert.Equal(t, createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)
log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus))
showCollectionsResp, err := c.proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{})
assert.NoError(t, err)
assert.Equal(t, showCollectionsResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp))
fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim)
hashKeys := generateHashKeys(rowNum)
insertResult, err := c.proxy.Insert(ctx, &milvuspb.InsertRequest{
DbName: dbName,
CollectionName: collectionName,
FieldsData: []*schemapb.FieldData{fVecColumn},
HashKeys: hashKeys,
NumRows: uint32(rowNum),
})
assert.Equal(t, insertResult.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
// flush
flushResp, err := c.proxy.Flush(ctx, &milvuspb.FlushRequest{
DbName: dbName,
CollectionNames: []string{collectionName},
})
assert.NoError(t, err)
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
ids := segmentIDs.GetData()
assert.NotEmpty(t, segmentIDs)
segments, err := c.metaWatcher.ShowSegments()
assert.NoError(t, err)
assert.NotEmpty(t, segments)
for _, segment := range segments {
log.Info("ShowSegments result", zap.String("segment", segment.String()))
}
if has && len(ids) > 0 {
flushed := func() bool {
resp, err := c.proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{
SegmentIDs: ids,
})
if err != nil {
panic(errors.New("GetFlushState failed"))
}
return resp.GetFlushed()
}
for !flushed() {
// respect context deadline/cancel
select {
case <-ctx.Done():
panic(errors.New("deadline exceeded"))
default:
}
time.Sleep(500 * time.Millisecond)
}
}
// create index
createIndexStatus, err := c.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
CollectionName: collectionName,
FieldName: floatVecField,
IndexName: "_default",
ExtraParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: strconv.Itoa(dim),
},
{
Key: common.MetricTypeKey,
Value: distance.L2,
},
{
Key: "index_type",
Value: "IVF_FLAT",
},
{
Key: "nlist",
Value: strconv.Itoa(10),
},
},
})
if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason()))
}
assert.Equal(t, commonpb.ErrorCode_Success, createIndexStatus.GetErrorCode())
// load
loadStatus, err := c.proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
})
if loadStatus.GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("loadStatus fail reason", zap.String("reason", loadStatus.GetReason()))
}
assert.Equal(t, commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
for {
select {
case <-ctx.Done():
errors.New("context deadline exceeded")
default:
}
loadProgress, err := c.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
CollectionName: collectionName,
})
if err != nil {
panic("GetLoadingProgress fail")
}
if loadProgress.GetProgress() == 100 {
break
}
time.Sleep(500 * time.Millisecond)
}
replicas, err := c.metaWatcher.ShowReplicas()
assert.NoError(t, err)
assert.NotEmpty(t, replicas)
for _, replica := range replicas {
log.Info("ShowReplicas result", zap.String("replica", PrettyReplica(replica)))
}
log.Info("TestShowReplicas succeed")
}

File diff suppressed because it is too large

View File

@ -0,0 +1,187 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package integration
import (
"context"
"testing"
"github.com/milvus-io/milvus/internal/datanode"
"github.com/milvus-io/milvus/internal/indexnode"
"github.com/milvus-io/milvus/internal/querynode"
"github.com/stretchr/testify/assert"
)
func TestAddRemoveDataNode(t *testing.T) {
ctx := context.Background()
c, err := StartMiniCluster(ctx)
assert.NoError(t, err)
err = c.Start()
assert.NoError(t, err)
defer c.Stop()
datanode := datanode.NewDataNode(ctx, c.factory)
datanode.SetEtcdClient(c.etcdCli)
//datanode := c.CreateDefaultDataNode()
err = c.AddDataNode(datanode)
assert.NoError(t, err)
assert.Equal(t, 2, c.clusterConfig.DataNodeNum)
assert.Equal(t, 2, len(c.dataNodes))
err = c.RemoveDataNode(datanode)
assert.NoError(t, err)
assert.Equal(t, 1, c.clusterConfig.DataNodeNum)
assert.Equal(t, 1, len(c.dataNodes))
// add default node and remove randomly
err = c.AddDataNode(nil)
assert.NoError(t, err)
assert.Equal(t, 2, c.clusterConfig.DataNodeNum)
assert.Equal(t, 2, len(c.dataNodes))
err = c.RemoveDataNode(nil)
assert.NoError(t, err)
assert.Equal(t, 1, c.clusterConfig.DataNodeNum)
assert.Equal(t, 1, len(c.dataNodes))
}
func TestAddRemoveQueryNode(t *testing.T) {
ctx := context.Background()
c, err := StartMiniCluster(ctx)
assert.NoError(t, err)
err = c.Start()
assert.NoError(t, err)
defer c.Stop()
queryNode := querynode.NewQueryNode(ctx, c.factory)
queryNode.SetEtcdClient(c.etcdCli)
//queryNode := c.CreateDefaultQueryNode()
err = c.AddQueryNode(queryNode)
assert.NoError(t, err)
assert.Equal(t, 2, c.clusterConfig.QueryNodeNum)
assert.Equal(t, 2, len(c.queryNodes))
err = c.RemoveQueryNode(queryNode)
assert.NoError(t, err)
assert.Equal(t, 1, c.clusterConfig.QueryNodeNum)
assert.Equal(t, 1, len(c.queryNodes))
// add default node and remove randomly
err = c.AddQueryNode(nil)
assert.NoError(t, err)
assert.Equal(t, 2, c.clusterConfig.QueryNodeNum)
assert.Equal(t, 2, len(c.queryNodes))
err = c.RemoveQueryNode(nil)
assert.NoError(t, err)
assert.Equal(t, 1, c.clusterConfig.QueryNodeNum)
assert.Equal(t, 1, len(c.queryNodes))
}
func TestAddRemoveIndexNode(t *testing.T) {
ctx := context.Background()
c, err := StartMiniCluster(ctx)
assert.NoError(t, err)
err = c.Start()
assert.NoError(t, err)
defer c.Stop()
indexNode := indexnode.NewIndexNode(ctx, c.factory)
indexNode.SetEtcdClient(c.etcdCli)
//indexNode := c.CreateDefaultIndexNode()
err = c.AddIndexNode(indexNode)
assert.NoError(t, err)
assert.Equal(t, 2, c.clusterConfig.IndexNodeNum)
assert.Equal(t, 2, len(c.indexNodes))
err = c.RemoveIndexNode(indexNode)
assert.NoError(t, err)
assert.Equal(t, 1, c.clusterConfig.IndexNodeNum)
assert.Equal(t, 1, len(c.indexNodes))
// add default node and remove randomly
err = c.AddIndexNode(nil)
assert.NoError(t, err)
assert.Equal(t, 2, c.clusterConfig.IndexNodeNum)
assert.Equal(t, 2, len(c.indexNodes))
err = c.RemoveIndexNode(nil)
assert.NoError(t, err)
assert.Equal(t, 1, c.clusterConfig.IndexNodeNum)
assert.Equal(t, 1, len(c.indexNodes))
}
func TestUpdateClusterSize(t *testing.T) {
ctx := context.Background()
c, err := StartMiniCluster(ctx)
assert.NoError(t, err)
err = c.Start()
assert.NoError(t, err)
defer c.Stop()
err = c.UpdateClusterSize(ClusterConfig{
QueryNodeNum: -1,
DataNodeNum: -1,
IndexNodeNum: -1,
})
assert.Error(t, err)
err = c.UpdateClusterSize(ClusterConfig{
QueryNodeNum: 2,
DataNodeNum: 2,
IndexNodeNum: 2,
})
assert.NoError(t, err)
assert.Equal(t, 2, c.clusterConfig.DataNodeNum)
assert.Equal(t, 2, c.clusterConfig.QueryNodeNum)
assert.Equal(t, 2, c.clusterConfig.IndexNodeNum)
assert.Equal(t, 2, len(c.dataNodes))
assert.Equal(t, 2, len(c.queryNodes))
assert.Equal(t, 2, len(c.indexNodes))
err = c.UpdateClusterSize(ClusterConfig{
DataNodeNum: 3,
QueryNodeNum: 2,
IndexNodeNum: 1,
})
assert.NoError(t, err)
assert.Equal(t, 3, c.clusterConfig.DataNodeNum)
assert.Equal(t, 2, c.clusterConfig.QueryNodeNum)
assert.Equal(t, 1, c.clusterConfig.IndexNodeNum)
assert.Equal(t, 3, len(c.dataNodes))
assert.Equal(t, 2, len(c.queryNodes))
assert.Equal(t, 1, len(c.indexNodes))
}