test: support multiple data/querynodes in integration test (#30618)

issue: https://github.com/milvus-io/milvus/issues/29507

Signed-off-by: yiwangdr <yiwangdr@gmail.com>
yiwangdr 2024-02-20 19:54:53 -08:00 committed by GitHub
parent 1346b57433
commit c6665c2a4c
35 changed files with 1025 additions and 211 deletions
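
The common thread in the diff below is that DataNode/QueryNode components stop reading the process-global paramtable.GetNodeID() and instead carry an explicit server ID, so several nodes can run inside a single integration-test process. A minimal sketch of how a test harness might wire that up, assuming the usual internal/... import paths and a factory supplied by the harness (the helper name is hypothetical):

package integration

import (
	"context"

	"github.com/milvus-io/milvus/internal/datanode"
	"github.com/milvus-io/milvus/internal/util/dependency"
	"github.com/milvus-io/milvus/internal/util/sessionutil"
)

// startDataNodes is a hypothetical harness helper: it builds several
// in-process DataNodes, each carrying its own server ID, instead of all of
// them reading the process-global paramtable.GetNodeID().
func startDataNodes(ctx context.Context, factory dependency.Factory, serverIDs ...int64) []*datanode.DataNode {
	nodes := make([]*datanode.DataNode, 0, len(serverIDs))
	for _, id := range serverIDs {
		// NewDataNode now takes the server ID explicitly (see the change below).
		dn := datanode.NewDataNode(ctx, factory, id)
		// Give each node a session whose ServerID matches the one passed in,
		// mirroring newIDLEDataNodeMock in the datanode tests.
		dn.SetSession(&sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: id}})
		nodes = append(nodes, dn)
	}
	return nodes
}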

View File

@ -22,13 +22,15 @@ type coordBroker struct {
*dataCoordBroker
}
func NewCoordBroker(rc types.RootCoordClient, dc types.DataCoordClient) Broker {
func NewCoordBroker(rc types.RootCoordClient, dc types.DataCoordClient, serverID int64) Broker {
return &coordBroker{
rootCoordBroker: &rootCoordBroker{
client: rc,
client: rc,
serverID: serverID,
},
dataCoordBroker: &dataCoordBroker{
client: dc,
client: dc,
serverID: serverID,
},
}
}
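
A hedged sketch of the widened constructor above: the new third argument replaces the paramtable.GetNodeID() lookups inside both sub-brokers, which stamp it into the SourceID of the requests they build. The helper name and broker import path are assumptions for illustration:

package datanode

import (
	"github.com/milvus-io/milvus/internal/datanode/broker" // assumed import path
	"github.com/milvus-io/milvus/internal/types"
)

// newBrokerForNode is a hypothetical helper showing the new call shape; rc and
// dc are the coordinator clients the DataNode already holds, and serverID is
// the owning node's ID (node.GetNodeID() in DataNode.Init further down).
func newBrokerForNode(rc types.RootCoordClient, dc types.DataCoordClient, serverID int64) broker.Broker {
	return broker.NewCoordBroker(rc, dc, serverID)
}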

View File

@ -14,18 +14,18 @@ import (
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type dataCoordBroker struct {
client types.DataCoordClient
client types.DataCoordClient
serverID int64
}
func (dc *dataCoordBroker) AssignSegmentID(ctx context.Context, reqs ...*datapb.SegmentIDRequest) ([]typeutil.UniqueID, error) {
req := &datapb.AssignSegmentIDRequest{
NodeID: paramtable.GetNodeID(),
NodeID: dc.serverID,
PeerRole: typeutil.ProxyRole,
SegmentIDRequests: reqs,
}
@ -48,7 +48,7 @@ func (dc *dataCoordBroker) ReportTimeTick(ctx context.Context, msgs []*msgpb.Dat
req := &datapb.ReportDataNodeTtMsgsRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_DataNodeTt),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
commonpbutil.WithSourceID(dc.serverID),
),
Msgs: msgs,
}
@ -69,7 +69,7 @@ func (dc *dataCoordBroker) GetSegmentInfo(ctx context.Context, segmentIDs []int6
infoResp, err := dc.client.GetSegmentInfo(ctx, &datapb.GetSegmentInfoRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_SegmentInfo),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
commonpbutil.WithSourceID(dc.serverID),
),
SegmentIDs: segmentIDs,
IncludeUnHealthy: true,
@ -96,7 +96,7 @@ func (dc *dataCoordBroker) UpdateChannelCheckpoint(ctx context.Context, channelN
req := &datapb.UpdateChannelCheckpointRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithSourceID(paramtable.GetNodeID()),
commonpbutil.WithSourceID(dc.serverID),
),
VChannel: channelName,
Position: cp,

View File

@ -33,7 +33,7 @@ func (s *dataCoordSuite) SetupSuite() {
func (s *dataCoordSuite) SetupTest() {
s.dc = mocks.NewMockDataCoordClient(s.T())
s.broker = NewCoordBroker(nil, s.dc)
s.broker = NewCoordBroker(nil, s.dc, 1)
}
func (s *dataCoordSuite) resetMock() {

View File

@ -13,12 +13,12 @@ import (
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type rootCoordBroker struct {
client types.RootCoordClient
client types.RootCoordClient
serverID int64
}
func (rc *rootCoordBroker) DescribeCollection(ctx context.Context, collectionID typeutil.UniqueID, timestamp typeutil.Timestamp) (*milvuspb.DescribeCollectionResponse, error) {
@ -29,7 +29,7 @@ func (rc *rootCoordBroker) DescribeCollection(ctx context.Context, collectionID
req := &milvuspb.DescribeCollectionRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
commonpbutil.WithSourceID(rc.serverID),
),
// please do not specify the collection name alone after database feature.
CollectionID: collectionID,
@ -89,7 +89,7 @@ func (rc *rootCoordBroker) AllocTimestamp(ctx context.Context, num uint32) (uint
req := &rootcoordpb.AllocTimestampRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_RequestTSO),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
commonpbutil.WithSourceID(rc.serverID),
),
Count: num,
}

View File

@ -33,7 +33,7 @@ func (s *rootCoordSuite) SetupSuite() {
func (s *rootCoordSuite) SetupTest() {
s.rc = mocks.NewMockRootCoordClient(s.T())
s.broker = NewCoordBroker(s.rc, nil)
s.broker = NewCoordBroker(s.rc, nil, 1)
}
func (s *rootCoordSuite) resetMock() {

View File

@ -26,12 +26,12 @@ import (
"math/rand"
"os"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/cockroachdb/errors"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/atomic"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@ -84,6 +84,7 @@ var Params *paramtable.ComponentParam = paramtable.Get()
// `segmentCache` stores all flushing and flushed segments.
type DataNode struct {
ctx context.Context
serverID int64
cancel context.CancelFunc
Role string
stateCode atomic.Value // commonpb.StateCode_Initializing
@ -127,7 +128,7 @@ type DataNode struct {
}
// NewDataNode will return a DataNode with abnormal state.
func NewDataNode(ctx context.Context, factory dependency.Factory) *DataNode {
func NewDataNode(ctx context.Context, factory dependency.Factory, serverID int64) *DataNode {
rand.Seed(time.Now().UnixNano())
ctx2, cancel2 := context.WithCancel(ctx)
node := &DataNode{
@ -138,6 +139,7 @@ func NewDataNode(ctx context.Context, factory dependency.Factory) *DataNode {
rootCoord: nil,
dataCoord: nil,
factory: factory,
serverID: serverID,
segmentCache: newCache(),
compactionExecutor: newCompactionExecutor(),
@ -189,9 +191,10 @@ func (node *DataNode) SetDataCoordClient(ds types.DataCoordClient) error {
// Register register datanode to etcd
func (node *DataNode) Register() error {
log.Debug("node begin to register to etcd", zap.String("serverName", node.session.ServerName), zap.Int64("ServerID", node.session.ServerID))
node.session.Register()
metrics.NumNodes.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), typeutil.DataNodeRole).Inc()
metrics.NumNodes.WithLabelValues(fmt.Sprint(node.GetNodeID()), typeutil.DataNodeRole).Inc()
log.Info("DataNode Register Finished")
// Start liveness check
node.session.LivenessCheck(node.ctx, func() {
@ -199,7 +202,7 @@ func (node *DataNode) Register() error {
if err := node.Stop(); err != nil {
log.Fatal("failed to stop server", zap.Error(err))
}
metrics.NumNodes.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), typeutil.DataNodeRole).Dec()
metrics.NumNodes.WithLabelValues(fmt.Sprint(node.GetNodeID()), typeutil.DataNodeRole).Dec()
// manually send signal to starter goroutine
if node.session.TriggerKill {
if p, err := os.FindProcess(os.Getpid()); err == nil {
@ -232,6 +235,10 @@ func (node *DataNode) initRateCollector() error {
return nil
}
func (node *DataNode) GetNodeID() int64 {
return node.serverID
}
func (node *DataNode) Init() error {
var initError error
node.initOnce.Do(func() {
@ -244,24 +251,24 @@ func (node *DataNode) Init() error {
return
}
node.broker = broker.NewCoordBroker(node.rootCoord, node.dataCoord)
node.broker = broker.NewCoordBroker(node.rootCoord, node.dataCoord, node.GetNodeID())
err := node.initRateCollector()
if err != nil {
log.Error("DataNode server init rateCollector failed", zap.Int64("node ID", paramtable.GetNodeID()), zap.Error(err))
log.Error("DataNode server init rateCollector failed", zap.Int64("node ID", node.GetNodeID()), zap.Error(err))
initError = err
return
}
log.Info("DataNode server init rateCollector done", zap.Int64("node ID", paramtable.GetNodeID()))
log.Info("DataNode server init rateCollector done", zap.Int64("node ID", node.GetNodeID()))
node.dispClient = msgdispatcher.NewClient(node.factory, typeutil.DataNodeRole, paramtable.GetNodeID())
log.Info("DataNode server init dispatcher client done", zap.Int64("node ID", paramtable.GetNodeID()))
node.dispClient = msgdispatcher.NewClient(node.factory, typeutil.DataNodeRole, node.GetNodeID())
log.Info("DataNode server init dispatcher client done", zap.Int64("node ID", node.GetNodeID()))
alloc, err := allocator.New(context.Background(), node.rootCoord, paramtable.GetNodeID())
alloc, err := allocator.New(context.Background(), node.rootCoord, node.GetNodeID())
if err != nil {
log.Error("failed to create id allocator",
zap.Error(err),
zap.String("role", typeutil.DataNodeRole), zap.Int64("DataNode ID", paramtable.GetNodeID()))
zap.String("role", typeutil.DataNodeRole), zap.Int64("DataNode ID", node.GetNodeID()))
initError = err
return
}
@ -292,7 +299,7 @@ func (node *DataNode) Init() error {
node.channelCheckpointUpdater = newChannelCheckpointUpdater(node)
log.Info("init datanode done", zap.Int64("nodeID", paramtable.GetNodeID()), zap.String("Address", node.address))
log.Info("init datanode done", zap.Int64("nodeID", node.GetNodeID()), zap.String("Address", node.address))
})
return initError
}
@ -354,7 +361,7 @@ func (node *DataNode) Start() error {
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_RequestTSO),
commonpbutil.WithMsgID(0),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
commonpbutil.WithSourceID(node.GetNodeID()),
),
Count: 1,
})

View File

@ -40,7 +40,6 @@ import (
"github.com/milvus-io/milvus/pkg/mq/msgdispatcher"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/conc"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
@ -351,7 +350,7 @@ func getServiceWithChannel(initCtx context.Context, node *DataNode, info *datapb
resendTTCh = make(chan resendTTMsg, 100)
)
node.writeBufferManager.Register(channelName, metacache, storageV2Cache, writebuffer.WithMetaWriter(syncmgr.BrokerMetaWriter(node.broker)), writebuffer.WithIDAllocator(node.allocator))
node.writeBufferManager.Register(channelName, metacache, storageV2Cache, writebuffer.WithMetaWriter(syncmgr.BrokerMetaWriter(node.broker, config.serverID)), writebuffer.WithIDAllocator(node.allocator))
ctx, cancel := context.WithCancel(node.ctx)
ds := &dataSyncService{
ctx: ctx,
@ -410,7 +409,7 @@ func getServiceWithChannel(initCtx context.Context, node *DataNode, info *datapb
}
m.AsProducer([]string{Params.CommonCfg.DataCoordTimeTick.GetValue()})
metrics.DataNodeNumProducers.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
metrics.DataNodeNumProducers.WithLabelValues(fmt.Sprint(config.serverID)).Inc()
log.Info("datanode AsProducer", zap.String("TimeTickChannelName", Params.CommonCfg.DataCoordTimeTick.GetValue()))
m.EnableProduce(true)

View File

@ -33,7 +33,6 @@ import (
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/logutil"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
const retryWatchInterval = 20 * time.Second
@ -93,7 +92,7 @@ func (node *DataNode) StartWatchChannels(ctx context.Context) {
// serves the corner case for etcd connection lost and missing some events
func (node *DataNode) checkWatchedList() error {
// REF MEP#7 watch path should be [prefix]/channel/{node_id}/{channel_name}
prefix := path.Join(Params.CommonCfg.DataCoordWatchSubPath.GetValue(), fmt.Sprintf("%d", paramtable.GetNodeID()))
prefix := path.Join(Params.CommonCfg.DataCoordWatchSubPath.GetValue(), fmt.Sprintf("%d", node.serverID))
keys, values, err := node.watchKv.LoadWithPrefix(prefix)
if err != nil {
return err

View File

@ -62,6 +62,7 @@ func (fm *fgManagerImpl) AddFlowgraph(ds *dataSyncService) {
func (fm *fgManagerImpl) AddandStartWithEtcdTickler(dn *DataNode, vchan *datapb.VchannelInfo, schema *schemapb.CollectionSchema, tickler *etcdTickler) error {
log := log.With(zap.String("channel", vchan.GetChannelName()))
log.Warn(fmt.Sprintf("debug AddandStartWithEtcdTickler %d", dn.GetNodeID()))
if fm.flowgraphs.Contain(vchan.GetChannelName()) {
log.Warn("try to add an existed DataSyncService")
return nil

View File

@ -83,7 +83,7 @@ var segID2SegInfo = map[int64]*datapb.SegmentInfo{
func newIDLEDataNodeMock(ctx context.Context, pkType schemapb.DataType) *DataNode {
factory := dependency.NewDefaultFactory(true)
node := NewDataNode(ctx, factory)
node := NewDataNode(ctx, factory, 1)
node.SetSession(&sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}})
node.dispClient = msgdispatcher.NewClient(factory, typeutil.DataNodeRole, paramtable.GetNodeID())

View File

@ -94,13 +94,13 @@ func (node *DataNode) GetComponentStates(ctx context.Context, req *milvuspb.GetC
// So if receiving calls to flush segment A, DataNode should guarantee the segment to be flushed.
func (node *DataNode) FlushSegments(ctx context.Context, req *datapb.FlushSegmentsRequest) (*commonpb.Status, error) {
metrics.DataNodeFlushReqCounter.WithLabelValues(
fmt.Sprint(paramtable.GetNodeID()),
fmt.Sprint(node.GetNodeID()),
metrics.TotalLabel).Inc()
log := log.Ctx(ctx)
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
log.Warn("DataNode.FlushSegments failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.Error(err))
log.Warn("DataNode.FlushSegments failed", zap.Int64("nodeId", node.GetNodeID()), zap.Error(err))
return merr.Status(err), nil
}
@ -111,6 +111,7 @@ func (node *DataNode) FlushSegments(ctx context.Context, req *datapb.FlushSegmen
zap.Int64("serverID", serverID),
)
log.Info(fmt.Sprintf("debug by FlushSegments:%v:%v", serverID, req.GetBase().GetTargetID()))
return merr.Status(merr.WrapErrNodeNotMatch(req.GetBase().GetTargetID(), serverID)), nil
}
@ -133,7 +134,7 @@ func (node *DataNode) FlushSegments(ctx context.Context, req *datapb.FlushSegmen
log.Info("sending segments to WriteBuffer Manager")
metrics.DataNodeFlushReqCounter.WithLabelValues(
fmt.Sprint(paramtable.GetNodeID()),
fmt.Sprint(node.GetNodeID()),
metrics.SuccessLabel).Inc()
return merr.Success(), nil
}
@ -166,7 +167,7 @@ func (node *DataNode) GetStatisticsChannel(ctx context.Context, req *internalpb.
func (node *DataNode) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) {
log.Debug("DataNode.ShowConfigurations", zap.String("pattern", req.Pattern))
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
log.Warn("DataNode.ShowConfigurations failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.Error(err))
log.Warn("DataNode.ShowConfigurations failed", zap.Int64("nodeId", node.GetNodeID()), zap.Error(err))
return &internalpb.ShowConfigurationsResponse{
Status: merr.Status(err),
@ -191,7 +192,7 @@ func (node *DataNode) ShowConfigurations(ctx context.Context, req *internalpb.Sh
// GetMetrics return datanode metrics
func (node *DataNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
log.Warn("DataNode.GetMetrics failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.Error(err))
log.Warn("DataNode.GetMetrics failed", zap.Int64("nodeId", node.GetNodeID()), zap.Error(err))
return &milvuspb.GetMetricsResponse{
Status: merr.Status(err),
@ -201,7 +202,7 @@ func (node *DataNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRe
metricType, err := metricsinfo.ParseMetricType(req.Request)
if err != nil {
log.Warn("DataNode.GetMetrics failed to parse metric type",
zap.Int64("nodeID", paramtable.GetNodeID()),
zap.Int64("nodeID", node.GetNodeID()),
zap.String("req", req.Request),
zap.Error(err))
@ -213,7 +214,7 @@ func (node *DataNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRe
if metricType == metricsinfo.SystemInfoMetrics {
systemInfoMetrics, err := node.getSystemInfoMetrics(ctx, req)
if err != nil {
log.Warn("DataNode GetMetrics failed", zap.Int64("nodeID", paramtable.GetNodeID()), zap.Error(err))
log.Warn("DataNode GetMetrics failed", zap.Int64("nodeID", node.GetNodeID()), zap.Error(err))
return &milvuspb.GetMetricsResponse{
Status: merr.Status(err),
}, nil
@ -223,7 +224,7 @@ func (node *DataNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRe
}
log.RatedWarn(60, "DataNode.GetMetrics failed, request metric type is not implemented yet",
zap.Int64("nodeID", paramtable.GetNodeID()),
zap.Int64("nodeID", node.GetNodeID()),
zap.String("req", req.Request),
zap.String("metric_type", metricType))
@ -237,7 +238,7 @@ func (node *DataNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRe
func (node *DataNode) Compaction(ctx context.Context, req *datapb.CompactionPlan) (*commonpb.Status, error) {
log := log.Ctx(ctx).With(zap.Int64("planID", req.GetPlanID()))
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
log.Warn("DataNode.Compaction failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.Error(err))
log.Warn("DataNode.Compaction failed", zap.Int64("nodeId", node.GetNodeID()), zap.Error(err))
return merr.Status(err), nil
}
@ -307,7 +308,7 @@ func (node *DataNode) Compaction(ctx context.Context, req *datapb.CompactionPlan
// return status of all compaction plans
func (node *DataNode) GetCompactionState(ctx context.Context, req *datapb.CompactionStateRequest) (*datapb.CompactionStateResponse, error) {
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
log.Warn("DataNode.GetCompactionState failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.Error(err))
log.Warn("DataNode.GetCompactionState failed", zap.Int64("nodeId", node.GetNodeID()), zap.Error(err))
return &datapb.CompactionStateResponse{
Status: merr.Status(err),
}, nil
@ -330,7 +331,7 @@ func (node *DataNode) SyncSegments(ctx context.Context, req *datapb.SyncSegments
)
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
log.Warn("DataNode.SyncSegments failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.Error(err))
log.Warn("DataNode.SyncSegments failed", zap.Int64("nodeId", node.GetNodeID()), zap.Error(err))
return merr.Status(err), nil
}
@ -366,7 +367,7 @@ func (node *DataNode) NotifyChannelOperation(ctx context.Context, req *datapb.Ch
zap.Int("operation count", len(req.GetInfos())))
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
log.Warn("DataNode.NotifyChannelOperation failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.Error(err))
log.Warn("DataNode.NotifyChannelOperation failed", zap.Int64("nodeId", node.GetNodeID()), zap.Error(err))
return merr.Status(err), nil
}
@ -389,7 +390,7 @@ func (node *DataNode) CheckChannelOperationProgress(ctx context.Context, req *da
log.Info("DataNode receives CheckChannelOperationProgress")
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
log.Warn("DataNode.CheckChannelOperationProgress failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.Error(err))
log.Warn("DataNode.CheckChannelOperationProgress failed", zap.Int64("nodeId", node.GetNodeID()), zap.Error(err))
return &datapb.ChannelOperationProgressResponse{
Status: merr.Status(err),
}, nil
@ -406,7 +407,7 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest)
zap.String("database name", req.GetImportTask().GetDatabaseName()),
zap.Strings("channel names", req.GetImportTask().GetChannelNames()),
zap.Int64s("working dataNodes", req.WorkingNodes),
zap.Int64("node ID", paramtable.GetNodeID()),
zap.Int64("node ID", node.GetNodeID()),
}
log.Info("DataNode receive import request", logFields...)
defer func() {
@ -416,7 +417,7 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest)
importResult := &rootcoordpb.ImportResult{
Status: merr.Success(),
TaskId: req.GetImportTask().TaskId,
DatanodeId: paramtable.GetNodeID(),
DatanodeId: node.GetNodeID(),
State: commonpb.ImportState_ImportStarted,
Segments: make([]int64, 0),
AutoIds: make([]int64, 0),
@ -513,7 +514,7 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest)
}
func (node *DataNode) FlushChannels(ctx context.Context, req *datapb.FlushChannelsRequest) (*commonpb.Status, error) {
log := log.Ctx(ctx).With(zap.Int64("nodeId", paramtable.GetNodeID()),
log := log.Ctx(ctx).With(zap.Int64("nodeId", node.GetNodeID()),
zap.Time("flushTs", tsoutil.PhysicalTime(req.GetFlushTs())),
zap.Strings("channels", req.GetChannels()))
@ -557,7 +558,7 @@ func (node *DataNode) AddImportSegment(ctx context.Context, req *datapb.AddImpor
return nil
}, retry.Attempts(getFlowGraphServiceAttempts))
if err != nil {
logFields = append(logFields, zap.Int64("node ID", paramtable.GetNodeID()))
logFields = append(logFields, zap.Int64("node ID", node.GetNodeID()))
log.Error("channel not found in current DataNode", logFields...)
return &datapb.AddImportSegmentResponse{
Status: &commonpb.Status{
@ -660,7 +661,7 @@ func assignSegmentFunc(node *DataNode, req *datapb.ImportTaskRequest) importutil
importResult := &rootcoordpb.ImportResult{
Status: merr.Success(),
TaskId: req.GetImportTask().TaskId,
DatanodeId: paramtable.GetNodeID(),
DatanodeId: node.GetNodeID(),
State: commonpb.ImportState_ImportStarted,
Segments: []int64{segmentID},
AutoIds: make([]int64, 0),
@ -732,7 +733,7 @@ func saveSegmentFunc(node *DataNode, req *datapb.ImportTaskRequest, res *rootcoo
err := node.broker.SaveImportSegment(context.Background(), &datapb.SaveImportSegmentRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithTimeStamp(ts), // Pass current timestamp downstream.
commonpbutil.WithSourceID(paramtable.GetNodeID()),
commonpbutil.WithSourceID(node.GetNodeID()),
),
SegmentId: segmentID,
ChannelName: targetChName,
@ -742,7 +743,7 @@ func saveSegmentFunc(node *DataNode, req *datapb.ImportTaskRequest, res *rootcoo
SaveBinlogPathReq: &datapb.SaveBinlogPathsRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithTimeStamp(ts),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
commonpbutil.WithSourceID(node.GetNodeID()),
),
SegmentID: segmentID,
CollectionID: req.GetImportTask().GetCollectionId(),

View File

@ -12,7 +12,6 @@ import (
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
)
@ -63,7 +62,7 @@ func (u *mqStatsUpdater) send(ts Timestamp, segmentIDs []int64) error {
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_DataNodeTt),
commonpbutil.WithTimeStamp(ts),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
commonpbutil.WithSourceID(u.config.serverID),
),
ChannelName: u.config.vChannelName,
Timestamp: ts,

View File

@ -13,7 +13,6 @@ import (
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/retry"
)
@ -25,14 +24,16 @@ type MetaWriter interface {
}
type brokerMetaWriter struct {
broker broker.Broker
opts []retry.Option
broker broker.Broker
opts []retry.Option
serverID int64
}
func BrokerMetaWriter(broker broker.Broker, opts ...retry.Option) MetaWriter {
func BrokerMetaWriter(broker broker.Broker, serverID int64, opts ...retry.Option) MetaWriter {
return &brokerMetaWriter{
broker: broker,
opts: opts,
broker: broker,
serverID: serverID,
opts: opts,
}
}
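
For the sync path, a similar hedged sketch of the updated MetaWriter construction, matching the call sites changed later in this diff; the helper name is hypothetical and retry.Attempts(1) is only an illustrative option:

package syncmgr

import (
	"github.com/milvus-io/milvus/internal/datanode/broker" // assumed import path
	"github.com/milvus-io/milvus/pkg/util/retry"
)

// newMetaWriterForNode is a hypothetical helper: serverID now rides along with
// the broker so that SaveBinlogPaths / DropVirtualChannel requests carry the
// node's ID as SourceID instead of the process-global one.
func newMetaWriterForNode(b broker.Broker, serverID int64) MetaWriter {
	return BrokerMetaWriter(b, serverID, retry.Attempts(1))
}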
@ -82,7 +83,7 @@ func (b *brokerMetaWriter) UpdateSync(pack *SyncTask) error {
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(0),
commonpbutil.WithMsgID(0),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
commonpbutil.WithSourceID(b.serverID),
),
SegmentID: pack.segmentID,
CollectionID: pack.collectionID,
@ -165,7 +166,7 @@ func (b *brokerMetaWriter) UpdateSyncV2(pack *SyncTaskV2) error {
req := &datapb.SaveBinlogPathsRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithSourceID(paramtable.GetNodeID()),
commonpbutil.WithSourceID(b.serverID),
),
SegmentID: pack.segmentID,
CollectionID: pack.collectionID,
@ -214,7 +215,7 @@ func (b *brokerMetaWriter) DropChannel(channelName string) error {
err := retry.Do(context.Background(), func() error {
status, err := b.broker.DropVirtualChannel(context.Background(), &datapb.DropVirtualChannelRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithSourceID(paramtable.GetNodeID()),
commonpbutil.WithSourceID(b.serverID),
),
ChannelName: channelName,
})

View File

@ -30,7 +30,7 @@ func (s *MetaWriterSuite) SetupSuite() {
func (s *MetaWriterSuite) SetupTest() {
s.broker = broker.NewMockBroker(s.T())
s.metacache = metacache.NewMockMetaCache(s.T())
s.writer = BrokerMetaWriter(s.broker, retry.Attempts(1))
s.writer = BrokerMetaWriter(s.broker, 1, retry.Attempts(1))
}
func (s *MetaWriterSuite) TestNormalSave() {

View File

@ -160,7 +160,7 @@ func (s *SyncManagerSuite) TestSubmit() {
manager, err := NewSyncManager(s.chunkManager, s.allocator)
s.NoError(err)
task := s.getSuiteSyncTask()
task.WithMetaWriter(BrokerMetaWriter(s.broker))
task.WithMetaWriter(BrokerMetaWriter(s.broker, 1))
task.WithTimeRange(50, 100)
task.WithCheckpoint(&msgpb.MsgPosition{
ChannelName: s.channelName,
@ -192,7 +192,7 @@ func (s *SyncManagerSuite) TestCompacted() {
manager, err := NewSyncManager(s.chunkManager, s.allocator)
s.NoError(err)
task := s.getSuiteSyncTask()
task.WithMetaWriter(BrokerMetaWriter(s.broker))
task.WithMetaWriter(BrokerMetaWriter(s.broker, 1))
task.WithTimeRange(50, 100)
task.WithCheckpoint(&msgpb.MsgPosition{
ChannelName: s.channelName,
@ -235,7 +235,7 @@ func (s *SyncManagerSuite) TestBlock() {
go func() {
task := s.getSuiteSyncTask()
task.WithMetaWriter(BrokerMetaWriter(s.broker))
task.WithMetaWriter(BrokerMetaWriter(s.broker, 1))
task.WithTimeRange(50, 100)
task.WithCheckpoint(&msgpb.MsgPosition{
ChannelName: s.channelName,

View File

@ -190,7 +190,7 @@ func (s *SyncTaskSuite) TestRunNormal() {
s.Run("without_data", func() {
task := s.getSuiteSyncTask()
task.WithMetaWriter(BrokerMetaWriter(s.broker))
task.WithMetaWriter(BrokerMetaWriter(s.broker, 1))
task.WithTimeRange(50, 100)
task.WithCheckpoint(&msgpb.MsgPosition{
ChannelName: s.channelName,
@ -205,7 +205,7 @@ func (s *SyncTaskSuite) TestRunNormal() {
s.Run("with_insert_delete_cp", func() {
task := s.getSuiteSyncTask()
task.WithTimeRange(50, 100)
task.WithMetaWriter(BrokerMetaWriter(s.broker))
task.WithMetaWriter(BrokerMetaWriter(s.broker, 1))
task.WithCheckpoint(&msgpb.MsgPosition{
ChannelName: s.channelName,
MsgID: []byte{1, 2, 3, 4},
@ -223,7 +223,7 @@ func (s *SyncTaskSuite) TestRunNormal() {
s.Run("with_statslog", func() {
task := s.getSuiteSyncTask()
task.WithTimeRange(50, 100)
task.WithMetaWriter(BrokerMetaWriter(s.broker))
task.WithMetaWriter(BrokerMetaWriter(s.broker, 1))
task.WithCheckpoint(&msgpb.MsgPosition{
ChannelName: s.channelName,
MsgID: []byte{1, 2, 3, 4},
@ -246,7 +246,7 @@ func (s *SyncTaskSuite) TestRunNormal() {
s.Run("with_delta_data", func() {
task := s.getSuiteSyncTask()
task.WithTimeRange(50, 100)
task.WithMetaWriter(BrokerMetaWriter(s.broker))
task.WithMetaWriter(BrokerMetaWriter(s.broker, 1))
task.WithCheckpoint(&msgpb.MsgPosition{
ChannelName: s.channelName,
MsgID: []byte{1, 2, 3, 4},
@ -278,7 +278,7 @@ func (s *SyncTaskSuite) TestRunL0Segment() {
Value: []byte("test_data"),
}
task.WithTimeRange(50, 100)
task.WithMetaWriter(BrokerMetaWriter(s.broker))
task.WithMetaWriter(BrokerMetaWriter(s.broker, 1))
task.WithCheckpoint(&msgpb.MsgPosition{
ChannelName: s.channelName,
MsgID: []byte{1, 2, 3, 4},
@ -315,7 +315,7 @@ func (s *SyncTaskSuite) TestCompactToNull() {
s.metacache.EXPECT().GetSegmentByID(s.segmentID).Return(seg, true)
task := s.getSuiteSyncTask()
task.WithMetaWriter(BrokerMetaWriter(s.broker))
task.WithMetaWriter(BrokerMetaWriter(s.broker, 1))
task.WithTimeRange(50, 100)
task.WithCheckpoint(&msgpb.MsgPosition{
ChannelName: s.channelName,
@ -379,7 +379,7 @@ func (s *SyncTaskSuite) TestRunError() {
s.broker.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything).Return(errors.New("mocked"))
task := s.getSuiteSyncTask()
task.WithMetaWriter(BrokerMetaWriter(s.broker, retry.Attempts(1)))
task.WithMetaWriter(BrokerMetaWriter(s.broker, 1, retry.Attempts(1)))
task.WithTimeRange(50, 100)
task.WithCheckpoint(&msgpb.MsgPosition{
ChannelName: s.channelName,

View File

@ -221,7 +221,7 @@ func (s *SyncTaskSuiteV2) TestRunNormal() {
s.Run("without_insert_delete", func() {
task := s.getSuiteSyncTask()
task.WithMetaWriter(BrokerMetaWriter(s.broker))
task.WithMetaWriter(BrokerMetaWriter(s.broker, 1))
task.WithTimeRange(50, 100)
task.WithCheckpoint(&msgpb.MsgPosition{
ChannelName: s.channelName,
@ -236,7 +236,7 @@ func (s *SyncTaskSuiteV2) TestRunNormal() {
s.Run("with_insert_delete_cp", func() {
task := s.getSuiteSyncTask()
task.WithTimeRange(50, 100)
task.WithMetaWriter(BrokerMetaWriter(s.broker))
task.WithMetaWriter(BrokerMetaWriter(s.broker, 1))
task.WithCheckpoint(&msgpb.MsgPosition{
ChannelName: s.channelName,
MsgID: []byte{1, 2, 3, 4},

View File

@ -43,10 +43,11 @@ type Client struct {
grpcClient grpcclient.GrpcClient[datapb.DataNodeClient]
sess *sessionutil.Session
addr string
serverID int64
}
// NewClient creates a client for DataNode.
func NewClient(ctx context.Context, addr string, nodeID int64) (*Client, error) {
func NewClient(ctx context.Context, addr string, serverID int64) (*Client, error) {
if addr == "" {
return nil, fmt.Errorf("address is empty")
}
@ -61,12 +62,13 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (*Client, error)
addr: addr,
grpcClient: grpcclient.NewClientBase[datapb.DataNodeClient](config, "milvus.proto.data.DataNode"),
sess: sess,
serverID: serverID,
}
// node shall specify node id
client.grpcClient.SetRole(fmt.Sprintf("%s-%d", typeutil.DataNodeRole, nodeID))
client.grpcClient.SetRole(fmt.Sprintf("%s-%d", typeutil.DataNodeRole, serverID))
client.grpcClient.SetGetAddrFunc(client.getAddr)
client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient)
client.grpcClient.SetNodeID(nodeID)
client.grpcClient.SetNodeID(serverID)
client.grpcClient.SetSession(sess)
return client, nil
@ -120,7 +122,7 @@ func (c *Client) WatchDmChannels(ctx context.Context, req *datapb.WatchDmChannel
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
commonpbutil.FillMsgBaseFromClient(c.serverID))
return wrapGrpcCall(ctx, c, func(client datapb.DataNodeClient) (*commonpb.Status, error) {
return client.WatchDmChannels(ctx, req)
})
@ -142,7 +144,7 @@ func (c *Client) FlushSegments(ctx context.Context, req *datapb.FlushSegmentsReq
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
commonpbutil.FillMsgBaseFromClient(c.serverID))
return wrapGrpcCall(ctx, c, func(client datapb.DataNodeClient) (*commonpb.Status, error) {
return client.FlushSegments(ctx, req)
})
@ -153,7 +155,7 @@ func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowCon
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
commonpbutil.FillMsgBaseFromClient(c.serverID))
return wrapGrpcCall(ctx, c, func(client datapb.DataNodeClient) (*internalpb.ShowConfigurationsResponse, error) {
return client.ShowConfigurations(ctx, req)
})
@ -164,7 +166,7 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
commonpbutil.FillMsgBaseFromClient(c.serverID))
return wrapGrpcCall(ctx, c, func(client datapb.DataNodeClient) (*milvuspb.GetMetricsResponse, error) {
return client.GetMetrics(ctx, req)
})
@ -181,7 +183,7 @@ func (c *Client) GetCompactionState(ctx context.Context, req *datapb.CompactionS
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
commonpbutil.FillMsgBaseFromClient(c.serverID))
return wrapGrpcCall(ctx, c, func(client datapb.DataNodeClient) (*datapb.CompactionStateResponse, error) {
return client.GetCompactionState(ctx, req)
})
@ -192,7 +194,7 @@ func (c *Client) Import(ctx context.Context, req *datapb.ImportTaskRequest, opts
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
commonpbutil.FillMsgBaseFromClient(c.serverID))
return wrapGrpcCall(ctx, c, func(client datapb.DataNodeClient) (*commonpb.Status, error) {
return client.Import(ctx, req)
})
@ -202,7 +204,7 @@ func (c *Client) ResendSegmentStats(ctx context.Context, req *datapb.ResendSegme
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
commonpbutil.FillMsgBaseFromClient(c.serverID))
return wrapGrpcCall(ctx, c, func(client datapb.DataNodeClient) (*datapb.ResendSegmentStatsResponse, error) {
return client.ResendSegmentStats(ctx, req)
})
@ -213,7 +215,7 @@ func (c *Client) AddImportSegment(ctx context.Context, req *datapb.AddImportSegm
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
commonpbutil.FillMsgBaseFromClient(c.serverID))
return wrapGrpcCall(ctx, c, func(client datapb.DataNodeClient) (*datapb.AddImportSegmentResponse, error) {
return client.AddImportSegment(ctx, req)
})
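
On the client side, a hedged sketch of dialing a specific DataNode with the renamed parameter; the package alias, import path, and helper name are assumptions for illustration:

package integration

import (
	"context"

	grpcdatanodeclient "github.com/milvus-io/milvus/internal/distributed/datanode/client" // assumed import path
)

// dialDataNode is a hypothetical helper around the renamed constructor
// argument: serverID identifies the target DataNode and is reused for the
// client's role label, its grpc node ID, and the MsgBase filled into requests.
func dialDataNode(ctx context.Context, addr string, serverID int64) (*grpcdatanodeclient.Client, error) {
	return grpcdatanodeclient.NewClient(ctx, addr, serverID)
}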

View File

@ -90,7 +90,8 @@ func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error)
},
}
s.datanode = dn.NewDataNode(s.ctx, s.factory)
s.serverID.Store(paramtable.GetNodeID())
s.datanode = dn.NewDataNode(s.ctx, s.factory, s.serverID.Load())
return s, nil
}
@ -246,6 +247,7 @@ func (s *Server) init() error {
s.SetEtcdClient(s.etcdCli)
s.datanode.SetAddress(Params.GetAddress())
log.Info("DataNode address", zap.String("address", Params.IP+":"+strconv.Itoa(Params.Port.GetAsInt())))
log.Info("DataNode serverID", zap.Int64("serverID", s.serverID.Load()))
err = s.startGrpc()
if err != nil {

View File

@ -91,6 +91,10 @@ func (m *MockDataNode) GetAddress() string {
return ""
}
func (m *MockDataNode) GetNodeID() int64 {
return 2
}
func (m *MockDataNode) SetRootCoordClient(rc types.RootCoordClient) error {
return m.err
}

View File

@ -41,6 +41,7 @@ type Client struct {
grpcClient grpcclient.GrpcClient[querypb.QueryNodeClient]
addr string
sess *sessionutil.Session
nodeID int64
}
// NewClient creates a new QueryNode client.
@ -59,6 +60,7 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (*Client, error)
addr: addr,
grpcClient: grpcclient.NewClientBase[querypb.QueryNodeClient](config, "milvus.proto.query.QueryNode"),
sess: sess,
nodeID: nodeID,
}
// node shall specify node id
client.grpcClient.SetRole(fmt.Sprintf("%s-%d", typeutil.QueryNodeRole, nodeID))
@ -122,7 +124,7 @@ func (c *Client) WatchDmChannels(ctx context.Context, req *querypb.WatchDmChanne
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
commonpbutil.FillMsgBaseFromClient(c.nodeID))
return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*commonpb.Status, error) {
return client.WatchDmChannels(ctx, req)
})
@ -133,7 +135,7 @@ func (c *Client) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmChannel
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
commonpbutil.FillMsgBaseFromClient(c.nodeID))
return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*commonpb.Status, error) {
return client.UnsubDmChannel(ctx, req)
})
@ -144,7 +146,7 @@ func (c *Client) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequ
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
commonpbutil.FillMsgBaseFromClient(c.nodeID))
return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*commonpb.Status, error) {
return client.LoadSegments(ctx, req)
})
@ -155,7 +157,7 @@ func (c *Client) ReleaseCollection(ctx context.Context, req *querypb.ReleaseColl
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
commonpbutil.FillMsgBaseFromClient(c.nodeID))
return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*commonpb.Status, error) {
return client.ReleaseCollection(ctx, req)
})
@ -166,7 +168,7 @@ func (c *Client) LoadPartitions(ctx context.Context, req *querypb.LoadPartitions
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
commonpbutil.FillMsgBaseFromClient(c.nodeID))
return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*commonpb.Status, error) {
return client.LoadPartitions(ctx, req)
})
@ -177,7 +179,7 @@ func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePart
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
commonpbutil.FillMsgBaseFromClient(c.nodeID))
return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*commonpb.Status, error) {
return client.ReleasePartitions(ctx, req)
})
@ -188,7 +190,7 @@ func (c *Client) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmen
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
commonpbutil.FillMsgBaseFromClient(c.nodeID))
return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*commonpb.Status, error) {
return client.ReleaseSegments(ctx, req)
})
@ -253,7 +255,7 @@ func (c *Client) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfo
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
commonpbutil.FillMsgBaseFromClient(c.nodeID))
return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*querypb.GetSegmentInfoResponse, error) {
return client.GetSegmentInfo(ctx, req)
})
@ -264,7 +266,7 @@ func (c *Client) SyncReplicaSegments(ctx context.Context, req *querypb.SyncRepli
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
commonpbutil.FillMsgBaseFromClient(c.nodeID))
return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*commonpb.Status, error) {
return client.SyncReplicaSegments(ctx, req)
})
@ -275,7 +277,7 @@ func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowCon
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
commonpbutil.FillMsgBaseFromClient(c.nodeID))
return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*internalpb.ShowConfigurationsResponse, error) {
return client.ShowConfigurations(ctx, req)
})
@ -286,7 +288,7 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
commonpbutil.FillMsgBaseFromClient(c.nodeID))
return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*milvuspb.GetMetricsResponse, error) {
return client.GetMetrics(ctx, req)
})
@ -302,7 +304,7 @@ func (c *Client) GetDataDistribution(ctx context.Context, req *querypb.GetDataDi
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
commonpbutil.FillMsgBaseFromClient(c.nodeID))
return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*querypb.GetDataDistributionResponse, error) {
return client.GetDataDistribution(ctx, req)
})
@ -312,7 +314,7 @@ func (c *Client) SyncDistribution(ctx context.Context, req *querypb.SyncDistribu
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
commonpbutil.FillMsgBaseFromClient(c.nodeID))
return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*commonpb.Status, error) {
return client.SyncDistribution(ctx, req)
})
@ -323,7 +325,7 @@ func (c *Client) Delete(ctx context.Context, req *querypb.DeleteRequest, _ ...gr
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()),
commonpbutil.FillMsgBaseFromClient(c.nodeID),
)
return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*commonpb.Status, error) {
return client.Delete(ctx, req)

View File

@ -132,6 +132,7 @@ func (s *Server) init() error {
log.Error("QueryNode init error: ", zap.Error(err))
return err
}
s.serverID.Store(s.querynode.GetNodeID())
return nil
}

View File

@ -91,6 +91,7 @@ func Test_NewServer(t *testing.T) {
mockQN.EXPECT().SetAddress(mock.Anything).Maybe()
mockQN.EXPECT().UpdateStateCode(mock.Anything).Maybe()
mockQN.EXPECT().Init().Return(nil).Maybe()
mockQN.EXPECT().GetNodeID().Return(2).Maybe()
server.querynode = mockQN
t.Run("Run", func(t *testing.T) {
@ -285,6 +286,7 @@ func Test_Run(t *testing.T) {
mockQN.EXPECT().SetAddress(mock.Anything).Maybe()
mockQN.EXPECT().UpdateStateCode(mock.Anything).Maybe()
mockQN.EXPECT().Init().Return(nil).Maybe()
mockQN.EXPECT().GetNodeID().Return(2).Maybe()
server.querynode = mockQN
err = server.Run()
assert.Error(t, err)

View File

@ -568,6 +568,47 @@ func (_c *MockDataNode_GetMetrics_Call) RunAndReturn(run func(context.Context, *
return _c
}
// GetNodeID provides a mock function with given fields:
func (_m *MockDataNode) GetNodeID() int64 {
ret := _m.Called()
var r0 int64
if rf, ok := ret.Get(0).(func() int64); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int64)
}
return r0
}
// MockDataNode_GetNodeID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetNodeID'
type MockDataNode_GetNodeID_Call struct {
*mock.Call
}
// GetNodeID is a helper method to define mock.On call
func (_e *MockDataNode_Expecter) GetNodeID() *MockDataNode_GetNodeID_Call {
return &MockDataNode_GetNodeID_Call{Call: _e.mock.On("GetNodeID")}
}
func (_c *MockDataNode_GetNodeID_Call) Run(run func()) *MockDataNode_GetNodeID_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockDataNode_GetNodeID_Call) Return(_a0 int64) *MockDataNode_GetNodeID_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockDataNode_GetNodeID_Call) RunAndReturn(run func() int64) *MockDataNode_GetNodeID_Call {
_c.Call.Return(run)
return _c
}
// GetStateCode provides a mock function with given fields:
func (_m *MockDataNode) GetStateCode() commonpb.StateCode {
ret := _m.Called()

View File

@ -291,6 +291,47 @@ func (_c *MockQueryNode_GetMetrics_Call) RunAndReturn(run func(context.Context,
return _c
}
// GetNodeID provides a mock function with given fields:
func (_m *MockQueryNode) GetNodeID() int64 {
ret := _m.Called()
var r0 int64
if rf, ok := ret.Get(0).(func() int64); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int64)
}
return r0
}
// MockQueryNode_GetNodeID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetNodeID'
type MockQueryNode_GetNodeID_Call struct {
*mock.Call
}
// GetNodeID is a helper method to define mock.On call
func (_e *MockQueryNode_Expecter) GetNodeID() *MockQueryNode_GetNodeID_Call {
return &MockQueryNode_GetNodeID_Call{Call: _e.mock.On("GetNodeID")}
}
func (_c *MockQueryNode_GetNodeID_Call) Run(run func()) *MockQueryNode_GetNodeID_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockQueryNode_GetNodeID_Call) Return(_a0 int64) *MockQueryNode_GetNodeID_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockQueryNode_GetNodeID_Call) RunAndReturn(run func() int64) *MockQueryNode_GetNodeID_Call {
_c.Call.Return(run)
return _c
}
// GetSegmentInfo provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNode) GetSegmentInfo(_a0 context.Context, _a1 *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
ret := _m.Called(_a0, _a1)

View File

@ -37,7 +37,6 @@ import (
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
)
@ -184,10 +183,10 @@ func (node *QueryNode) queryChannel(ctx context.Context, req *querypb.QueryReque
)
var err error
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel, metrics.Leader).Inc()
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel, metrics.Leader).Inc()
defer func() {
if err != nil {
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FailLabel, metrics.Leader).Inc()
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.FailLabel, metrics.Leader).Inc()
}
}()
@ -244,13 +243,13 @@ func (node *QueryNode) queryChannel(ctx context.Context, req *querypb.QueryReque
))
latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel, metrics.Leader).Inc()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel, metrics.Leader).Inc()
return resp, nil
}
func (node *QueryNode) queryChannelStream(ctx context.Context, req *querypb.QueryRequest, channel string, srv streamrpc.QueryStreamServer) error {
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel, metrics.Leader).Inc()
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel, metrics.Leader).Inc()
msgID := req.Req.Base.GetMsgID()
log := log.Ctx(ctx).With(
zap.Int64("msgID", msgID),
@ -262,7 +261,7 @@ func (node *QueryNode) queryChannelStream(ctx context.Context, req *querypb.Quer
var err error
defer func() {
if err != nil {
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FailLabel, metrics.Leader).Inc()
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.FailLabel, metrics.Leader).Inc()
}
}()
@ -344,10 +343,10 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq
defer node.lifetime.Done()
var err error
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.TotalLabel, metrics.Leader).Inc()
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.TotalLabel, metrics.Leader).Inc()
defer func() {
if err != nil {
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.FailLabel, metrics.Leader).Inc()
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.FailLabel, metrics.Leader).Inc()
}
}()
@ -394,10 +393,10 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq
// update metric to prometheus
latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel, metrics.Leader).Inc()
metrics.QueryNodeSearchNQ.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(req.Req.GetNq()))
metrics.QueryNodeSearchTopK.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(req.Req.GetTopk()))
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel, metrics.Leader).Inc()
metrics.QueryNodeSearchNQ.WithLabelValues(fmt.Sprint(node.GetNodeID())).Observe(float64(req.Req.GetNq()))
metrics.QueryNodeSearchTopK.WithLabelValues(fmt.Sprint(node.GetNodeID())).Observe(float64(req.Req.GetTopk()))
return resp, nil
}
@ -415,10 +414,10 @@ func (node *QueryNode) hybridSearchChannel(ctx context.Context, req *querypb.Hyb
defer node.lifetime.Done()
var err error
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.TotalLabel, metrics.Leader).Inc()
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.HybridSearchLabel, metrics.TotalLabel, metrics.Leader).Inc()
defer func() {
if err != nil {
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.FailLabel, metrics.Leader).Inc()
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.HybridSearchLabel, metrics.FailLabel, metrics.Leader).Inc()
}
}()
@ -449,11 +448,11 @@ func (node *QueryNode) hybridSearchChannel(ctx context.Context, req *querypb.Hyb
// update metric to prometheus
latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.SuccessLabel, metrics.Leader).Inc()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.HybridSearchLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.HybridSearchLabel, metrics.SuccessLabel, metrics.Leader).Inc()
for _, searchReq := range req.GetReq().GetReqs() {
metrics.QueryNodeSearchNQ.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(searchReq.GetNq()))
metrics.QueryNodeSearchTopK.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(searchReq.GetTopk()))
metrics.QueryNodeSearchNQ.WithLabelValues(fmt.Sprint(node.GetNodeID())).Observe(float64(searchReq.GetNq()))
metrics.QueryNodeSearchTopK.WithLabelValues(fmt.Sprint(node.GetNodeID())).Observe(float64(searchReq.GetTopk()))
}
return result, nil
}

View File

@ -114,7 +114,7 @@ func getQuotaMetrics(node *QueryNode) (*metricsinfo.QueryNodeQuotaMetrics, error
return seg.MemSize()
})
totalGrowingSize += size
metrics.QueryNodeEntitiesSize.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.QueryNodeEntitiesSize.WithLabelValues(fmt.Sprint(node.GetNodeID()),
fmt.Sprint(collection), segments.SegmentTypeGrowing.String()).Set(float64(size))
}
@ -126,7 +126,7 @@ func getQuotaMetrics(node *QueryNode) (*metricsinfo.QueryNodeQuotaMetrics, error
size := lo.SumBy(segs, func(seg segments.Segment) int64 {
return seg.MemSize()
})
metrics.QueryNodeEntitiesSize.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.QueryNodeEntitiesSize.WithLabelValues(fmt.Sprint(node.GetNodeID()),
fmt.Sprint(collection), segments.SegmentTypeSealed.String()).Set(float64(size))
}
@ -148,7 +148,7 @@ func getQuotaMetrics(node *QueryNode) (*metricsinfo.QueryNodeQuotaMetrics, error
QueryQueue: qqms,
GrowingSegmentsSize: totalGrowingSize,
Effect: metricsinfo.NodeEffect{
NodeID: paramtable.GetNodeID(),
NodeID: node.GetNodeID(),
CollectionIDs: collections.Collect(),
},
}, nil
@ -163,7 +163,7 @@ func getSystemInfoMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest,
if err != nil {
return &milvuspb.GetMetricsResponse{
Status: merr.Status(err),
ComponentName: metricsinfo.ConstructComponentName(typeutil.DataNodeRole, paramtable.GetNodeID()),
ComponentName: metricsinfo.ConstructComponentName(typeutil.DataNodeRole, node.GetNodeID()),
}, nil
}
hardwareInfos := metricsinfo.HardwareMetrics{
@ -179,7 +179,7 @@ func getSystemInfoMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest,
nodeInfos := metricsinfo.QueryNodeInfos{
BaseComponentInfos: metricsinfo.BaseComponentInfos{
Name: metricsinfo.ConstructComponentName(typeutil.QueryNodeRole, paramtable.GetNodeID()),
Name: metricsinfo.ConstructComponentName(typeutil.QueryNodeRole, node.GetNodeID()),
HardwareInfos: hardwareInfos,
SystemInfo: metricsinfo.DeployMetrics{},
CreatedTime: paramtable.GetCreateTime().String(),
@ -199,13 +199,13 @@ func getSystemInfoMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest,
return &milvuspb.GetMetricsResponse{
Status: merr.Status(err),
Response: "",
ComponentName: metricsinfo.ConstructComponentName(typeutil.QueryNodeRole, paramtable.GetNodeID()),
ComponentName: metricsinfo.ConstructComponentName(typeutil.QueryNodeRole, node.GetNodeID()),
}, nil
}
return &milvuspb.GetMetricsResponse{
Status: merr.Success(),
Response: resp,
ComponentName: metricsinfo.ConstructComponentName(typeutil.QueryNodeRole, paramtable.GetNodeID()),
ComponentName: metricsinfo.ConstructComponentName(typeutil.QueryNodeRole, node.GetNodeID()),
}, nil
}

View File

@ -105,6 +105,7 @@ type QueryNode struct {
subscribingChannels *typeutil.ConcurrentSet[string]
unsubscribingChannels *typeutil.ConcurrentSet[string]
delegators *typeutil.ConcurrentMap[string, delegator.ShardDelegator]
serverID int64
// segment loader
loader segments.Loader
@ -156,7 +157,8 @@ func (node *QueryNode) initSession() error {
node.session.Init(typeutil.QueryNodeRole, node.address, false, true)
sessionutil.SaveServerInfo(typeutil.QueryNodeRole, node.session.ServerID)
paramtable.SetNodeID(node.session.ServerID)
log.Info("QueryNode init session", zap.Int64("nodeID", paramtable.GetNodeID()), zap.String("node address", node.session.Address))
node.serverID = node.session.ServerID
log.Info("QueryNode init session", zap.Int64("nodeID", node.GetNodeID()), zap.String("node address", node.session.Address))
return nil
}
@ -164,13 +166,13 @@ func (node *QueryNode) initSession() error {
func (node *QueryNode) Register() error {
node.session.Register()
// start liveness check
metrics.NumNodes.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), typeutil.QueryNodeRole).Inc()
metrics.NumNodes.WithLabelValues(fmt.Sprint(node.GetNodeID()), typeutil.QueryNodeRole).Inc()
node.session.LivenessCheck(node.ctx, func() {
log.Error("Query Node disconnected from etcd, process will exit", zap.Int64("Server Id", paramtable.GetNodeID()))
log.Error("Query Node disconnected from etcd, process will exit", zap.Int64("Server Id", node.GetNodeID()))
if err := node.Stop(); err != nil {
log.Fatal("failed to stop server", zap.Error(err))
}
metrics.NumNodes.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), typeutil.QueryNodeRole).Dec()
metrics.NumNodes.WithLabelValues(fmt.Sprint(node.GetNodeID()), typeutil.QueryNodeRole).Dec()
// manually send signal to starter goroutine
if node.session.TriggerKill {
if p, err := os.FindProcess(os.Getpid()); err == nil {
@ -263,6 +265,10 @@ func getIndexEngineVersion() (minimal, current int32) {
return int32(cMinimal), int32(cCurrent)
}
func (node *QueryNode) GetNodeID() int64 {
return node.serverID
}
func (node *QueryNode) CloseSegcore() {
// safe stop
initcore.CleanRemoteChunkManager()
@ -301,7 +307,7 @@ func (node *QueryNode) Init() error {
initError = err
return
}
metrics.QueryNodeDiskUsedSize.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(float64(localUsedSize / 1024 / 1024))
metrics.QueryNodeDiskUsedSize.WithLabelValues(fmt.Sprint(node.GetNodeID())).Set(float64(localUsedSize / 1024 / 1024))
node.chunkManager, err = node.factory.NewPersistentStorageChunkManager(node.ctx)
if err != nil {
@ -317,7 +323,7 @@ func (node *QueryNode) Init() error {
log.Info("queryNode init scheduler", zap.String("policy", schedulePolicy))
node.clusterManager = cluster.NewWorkerManager(func(ctx context.Context, nodeID int64) (cluster.Worker, error) {
if nodeID == paramtable.GetNodeID() {
if nodeID == node.GetNodeID() {
return NewLocalWorker(node), nil
}
@ -350,7 +356,7 @@ func (node *QueryNode) Init() error {
} else {
node.loader = segments.NewLoader(node.manager, node.chunkManager)
}
node.dispClient = msgdispatcher.NewClient(node.factory, typeutil.QueryNodeRole, paramtable.GetNodeID())
node.dispClient = msgdispatcher.NewClient(node.factory, typeutil.QueryNodeRole, node.GetNodeID())
// init pipeline manager
node.pipelineManager = pipeline.NewManager(node.manager, node.tSafeManager, node.dispClient, node.delegators)
@ -373,7 +379,7 @@ func (node *QueryNode) Init() error {
}
log.Info("query node init successfully",
zap.Int64("queryNodeID", paramtable.GetNodeID()),
zap.Int64("queryNodeID", node.GetNodeID()),
zap.String("Address", node.address),
)
})
@ -392,9 +398,9 @@ func (node *QueryNode) Start() error {
mmapEnabled := len(mmapDirPath) > 0
node.UpdateStateCode(commonpb.StateCode_Healthy)
registry.GetInMemoryResolver().RegisterQueryNode(paramtable.GetNodeID(), node)
registry.GetInMemoryResolver().RegisterQueryNode(node.GetNodeID(), node)
log.Info("query node start successfully",
zap.Int64("queryNodeID", paramtable.GetNodeID()),
zap.Int64("queryNodeID", node.GetNodeID()),
zap.String("Address", node.address),
zap.Bool("mmapEnabled", mmapEnabled),
)
@ -432,7 +438,7 @@ func (node *QueryNode) Stop() error {
select {
case <-timeoutCh:
log.Warn("migrate data timed out", zap.Int64("ServerID", paramtable.GetNodeID()),
log.Warn("migrate data timed out", zap.Int64("ServerID", node.GetNodeID()),
zap.Int64s("sealedSegments", lo.Map(sealedSegments, func(s segments.Segment, i int) int64 {
return s.ID()
})),
@ -444,14 +450,14 @@ func (node *QueryNode) Stop() error {
break outer
case <-time.After(time.Second):
metrics.StoppingBalanceSegmentNum.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(float64(len(sealedSegments)))
metrics.StoppingBalanceChannelNum.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(float64(channelNum))
metrics.StoppingBalanceSegmentNum.WithLabelValues(fmt.Sprint(node.GetNodeID())).Set(float64(len(sealedSegments)))
metrics.StoppingBalanceChannelNum.WithLabelValues(fmt.Sprint(node.GetNodeID())).Set(float64(channelNum))
}
}
metrics.StoppingBalanceNodeNum.WithLabelValues().Set(0)
metrics.StoppingBalanceSegmentNum.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(0)
metrics.StoppingBalanceChannelNum.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(0)
metrics.StoppingBalanceSegmentNum.WithLabelValues(fmt.Sprint(node.GetNodeID())).Set(0)
metrics.StoppingBalanceChannelNum.WithLabelValues(fmt.Sprint(node.GetNodeID())).Set(0)
}
node.UpdateStateCode(commonpb.StateCode_Abnormal)
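The hunks above cache the session's ServerID on the QueryNode and route every metric label, log field, and worker lookup through node.GetNodeID() instead of the process-global paramtable.GetNodeID(). A minimal, self-contained sketch of why that matters once several querynodes run inside one process (fakeQueryNode below is a stand-in for illustration, not the real QueryNode type):

package main

import "fmt"

// stand-in for the QueryNode fields touched above; not the real type
type fakeQueryNode struct {
	serverID int64
}

func (n *fakeQueryNode) GetNodeID() int64 { return n.serverID }

func main() {
	// two querynodes registered in one process, as MiniClusterV2.AddQueryNode does below
	qn1 := &fakeQueryNode{serverID: 10001}
	qn2 := &fakeQueryNode{serverID: 10002}

	// a shared global (paramtable.GetNodeID()) would give both the same ID;
	// the per-instance accessor keeps their metric labels and logs distinct
	fmt.Println(qn1.GetNodeID(), qn2.GetNodeID()) // 10001 10002
}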

View File

@ -66,7 +66,7 @@ func (node *QueryNode) GetComponentStates(ctx context.Context, req *milvuspb.Get
log.Debug("QueryNode current state", zap.Int64("NodeID", nodeID), zap.String("StateCode", code.String()))
if node.session != nil && node.session.Registered() {
nodeID = paramtable.GetNodeID()
nodeID = node.GetNodeID()
}
info := &milvuspb.ComponentInfo{
NodeID: nodeID,
@ -112,7 +112,7 @@ func (node *QueryNode) GetStatistics(ctx context.Context, req *querypb.GetStatis
}
defer node.lifetime.Done()
err := merr.CheckTargetID(req.GetReq().GetBase())
err := merr.CheckTargetID(node.GetNodeID(), req.GetReq().GetBase())
if err != nil {
log.Warn("target ID check failed", zap.Error(err))
return &internalpb.GetStatisticsResponse{
@ -200,7 +200,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
log := log.Ctx(ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
zap.String("channel", channel.GetChannelName()),
zap.Int64("currentNodeID", paramtable.GetNodeID()),
zap.Int64("currentNodeID", node.GetNodeID()),
)
log.Info("received watch channel request",
@ -214,7 +214,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
defer node.lifetime.Done()
// check target matches
if err := merr.CheckTargetID(req.GetBase()); err != nil {
if err := merr.CheckTargetID(node.GetNodeID(), req.GetBase()); err != nil {
return merr.Status(err), nil
}
@ -347,7 +347,7 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC
log := log.Ctx(ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
zap.String("channel", req.GetChannelName()),
zap.Int64("currentNodeID", paramtable.GetNodeID()),
zap.Int64("currentNodeID", node.GetNodeID()),
)
log.Info("received unsubscribe channel request")
@ -359,7 +359,7 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC
defer node.lifetime.Done()
// check target matches
if err := merr.CheckTargetID(req.GetBase()); err != nil {
if err := merr.CheckTargetID(node.GetNodeID(), req.GetBase()); err != nil {
return merr.Status(err), nil
}
@ -412,7 +412,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, req *querypb.LoadSegmen
zap.Int64("partitionID", segment.GetPartitionID()),
zap.String("shard", segment.GetInsertChannel()),
zap.Int64("segmentID", segment.GetSegmentID()),
zap.Int64("currentNodeID", paramtable.GetNodeID()),
zap.Int64("currentNodeID", node.GetNodeID()),
)
log.Info("received load segments request",
@ -426,7 +426,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, req *querypb.LoadSegmen
defer node.lifetime.Done()
// check target matches
if err := merr.CheckTargetID(req.GetBase()); err != nil {
if err := merr.CheckTargetID(node.GetNodeID(), req.GetBase()); err != nil {
return merr.Status(err), nil
}
@ -529,7 +529,7 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, req *querypb.Release
zap.Int64("collectionID", req.GetCollectionID()),
zap.String("shard", req.GetShard()),
zap.Int64s("segmentIDs", req.GetSegmentIDs()),
zap.Int64("currentNodeID", paramtable.GetNodeID()),
zap.Int64("currentNodeID", node.GetNodeID()),
)
log.Info("received release segment request",
@ -544,7 +544,7 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, req *querypb.Release
defer node.lifetime.Done()
// check target matches
if err := merr.CheckTargetID(req.GetBase()); err != nil {
if err := merr.CheckTargetID(node.GetNodeID(), req.GetBase()); err != nil {
return merr.Status(err), nil
}
@ -630,8 +630,8 @@ func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmen
DmChannel: segment.Shard(),
PartitionID: segment.Partition(),
CollectionID: segment.Collection(),
NodeID: paramtable.GetNodeID(),
NodeIds: []int64{paramtable.GetNodeID()},
NodeID: node.GetNodeID(),
NodeIds: []int64{node.GetNodeID()},
MemSize: segment.MemSize(),
NumRows: segment.InsertCount(),
IndexName: indexName,
@ -669,10 +669,10 @@ func (node *QueryNode) SearchSegments(ctx context.Context, req *querypb.SearchRe
}
defer node.lifetime.Done()
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.TotalLabel, metrics.FromLeader).Inc()
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.TotalLabel, metrics.FromLeader).Inc()
defer func() {
if !merr.Ok(resp.GetStatus()) {
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.FailLabel, metrics.FromLeader).Inc()
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.FailLabel, metrics.FromLeader).Inc()
}
}()
@ -693,7 +693,7 @@ func (node *QueryNode) SearchSegments(ctx context.Context, req *querypb.SearchRe
return resp, nil
}
task := tasks.NewSearchTask(searchCtx, collection, node.manager, req)
task := tasks.NewSearchTask(searchCtx, collection, node.manager, req, node.serverID)
if err := node.scheduler.Add(task); err != nil {
log.Warn("failed to search channel", zap.Error(err))
resp.Status = merr.Status(err)
@ -713,8 +713,8 @@ func (node *QueryNode) SearchSegments(ctx context.Context, req *querypb.SearchRe
))
latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel, metrics.FromLeader).Inc()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel, metrics.FromLeader).Inc()
resp = task.Result()
resp.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds()
@ -750,7 +750,7 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
}
defer node.lifetime.Done()
err := merr.CheckTargetID(req.GetReq().GetBase())
err := merr.CheckTargetID(node.GetNodeID(), req.GetReq().GetBase())
if err != nil {
log.Warn("target ID check failed", zap.Error(err))
return &internalpb.SearchResults{
@ -807,12 +807,12 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
return resp, nil
}
reduceLatency := tr.RecordSpan()
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.ReduceShards).
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.ReduceShards).
Observe(float64(reduceLatency.Milliseconds()))
collector.Rate.Add(metricsinfo.NQPerSecond, float64(req.GetReq().GetNq()))
collector.Rate.Add(metricsinfo.SearchThroughput, float64(proto.Size(req)))
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.SearchLabel).
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(node.GetNodeID(), 10), metrics.SearchLabel).
Add(float64(proto.Size(req)))
if result.GetCostAggregation() != nil {
@ -836,19 +836,19 @@ func (node *QueryNode) HybridSearch(ctx context.Context, req *querypb.HybridSear
if err := node.lifetime.Add(merr.IsHealthy); err != nil {
return &querypb.HybridSearchResult{
Base: &commonpb.MsgBase{
SourceID: paramtable.GetNodeID(),
SourceID: node.GetNodeID(),
},
Status: merr.Status(err),
}, nil
}
defer node.lifetime.Done()
err := merr.CheckTargetID(req.GetReq().GetBase())
err := merr.CheckTargetID(node.GetNodeID(), req.GetReq().GetBase())
if err != nil {
log.Warn("target ID check failed", zap.Error(err))
return &querypb.HybridSearchResult{
Base: &commonpb.MsgBase{
SourceID: paramtable.GetNodeID(),
SourceID: node.GetNodeID(),
},
Status: merr.Status(err),
}, nil
@ -856,7 +856,7 @@ func (node *QueryNode) HybridSearch(ctx context.Context, req *querypb.HybridSear
resp := &querypb.HybridSearchResult{
Base: &commonpb.MsgBase{
SourceID: paramtable.GetNodeID(),
SourceID: node.GetNodeID(),
},
Status: merr.Success(),
}
@ -916,11 +916,11 @@ func (node *QueryNode) HybridSearch(ctx context.Context, req *querypb.HybridSear
resp.ChannelsMvcc = channelsMvcc
reduceLatency := tr.RecordSpan()
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.ReduceShards).
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.HybridSearchLabel, metrics.ReduceShards).
Observe(float64(reduceLatency.Milliseconds()))
collector.Rate.Add(metricsinfo.SearchThroughput, float64(proto.Size(req)))
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.HybridSearchLabel).
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(node.GetNodeID(), 10), metrics.HybridSearchLabel).
Add(float64(proto.Size(req)))
if resp.GetCostAggregation() != nil {
@ -950,10 +950,10 @@ func (node *QueryNode) QuerySegments(ctx context.Context, req *querypb.QueryRequ
}
defer node.lifetime.Done()
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel, metrics.FromLeader).Inc()
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel, metrics.FromLeader).Inc()
defer func() {
if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FailLabel, metrics.FromLeader).Inc()
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.FailLabel, metrics.FromLeader).Inc()
}
}()
@ -995,8 +995,8 @@ func (node *QueryNode) QuerySegments(ctx context.Context, req *querypb.QueryRequ
// TODO QueryNodeSQLatencyInQueue QueryNodeReduceLatency
latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel, metrics.FromLeader).Inc()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel, metrics.FromLeader).Inc()
result := task.Result()
result.GetCostAggregation().ResponseTime = latency.Milliseconds()
result.GetCostAggregation().TotalNQ = node.scheduler.GetWaitingTaskTotalNQ()
@ -1031,7 +1031,7 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
}
defer node.lifetime.Done()
err := merr.CheckTargetID(req.GetReq().GetBase())
err := merr.CheckTargetID(node.GetNodeID(), req.GetReq().GetBase())
if err != nil {
log.Warn("target ID check failed", zap.Error(err))
return &internalpb.RetrieveResults{
@ -1080,12 +1080,12 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
}, nil
}
reduceLatency := tr.RecordSpan()
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.ReduceShards).
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.ReduceShards).
Observe(float64(reduceLatency.Milliseconds()))
if !req.FromShardLeader {
collector.Rate.Add(metricsinfo.NQPerSecond, 1)
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.QueryLabel).Add(float64(proto.Size(req)))
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(node.GetNodeID(), 10), metrics.QueryLabel).Add(float64(proto.Size(req)))
}
if ret.GetCostAggregation() != nil {
@ -1116,7 +1116,7 @@ func (node *QueryNode) QueryStream(req *querypb.QueryRequest, srv querypb.QueryN
}
defer node.lifetime.Done()
err := merr.CheckTargetID(req.GetReq().GetBase())
err := merr.CheckTargetID(node.GetNodeID(), req.GetReq().GetBase())
if err != nil {
log.Warn("target ID check failed", zap.Error(err))
return err
@ -1151,7 +1151,7 @@ func (node *QueryNode) QueryStream(req *querypb.QueryRequest, srv querypb.QueryN
}
collector.Rate.Add(metricsinfo.NQPerSecond, 1)
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.QueryLabel).Add(float64(proto.Size(req)))
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(node.GetNodeID(), 10), metrics.QueryLabel).Add(float64(proto.Size(req)))
return nil
}
@ -1170,10 +1170,10 @@ func (node *QueryNode) QueryStreamSegments(req *querypb.QueryRequest, srv queryp
)
resp := &internalpb.RetrieveResults{}
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel, metrics.FromLeader).Inc()
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel, metrics.FromLeader).Inc()
defer func() {
if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FailLabel, metrics.FromLeader).Inc()
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.FailLabel, metrics.FromLeader).Inc()
}
}()
@ -1207,8 +1207,8 @@ func (node *QueryNode) QueryStreamSegments(req *querypb.QueryRequest, srv queryp
// TODO QueryNodeSQLatencyInQueue QueryNodeReduceLatency
latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel, metrics.FromLeader).Inc()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel, metrics.FromLeader).Inc()
return nil
}
@ -1221,7 +1221,7 @@ func (node *QueryNode) SyncReplicaSegments(ctx context.Context, req *querypb.Syn
func (node *QueryNode) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) {
if err := node.lifetime.Add(merr.IsHealthy); err != nil {
log.Warn("QueryNode.ShowConfigurations failed",
zap.Int64("nodeId", paramtable.GetNodeID()),
zap.Int64("nodeId", node.GetNodeID()),
zap.String("req", req.Pattern),
zap.Error(err))
@ -1251,7 +1251,7 @@ func (node *QueryNode) ShowConfigurations(ctx context.Context, req *internalpb.S
func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
if err := node.lifetime.Add(merr.IsHealthy); err != nil {
log.Warn("QueryNode.GetMetrics failed",
zap.Int64("nodeId", paramtable.GetNodeID()),
zap.Int64("nodeId", node.GetNodeID()),
zap.String("req", req.Request),
zap.Error(err))
@ -1265,7 +1265,7 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR
metricType, err := metricsinfo.ParseMetricType(req.Request)
if err != nil {
log.Warn("QueryNode.GetMetrics failed to parse metric type",
zap.Int64("nodeId", paramtable.GetNodeID()),
zap.Int64("nodeId", node.GetNodeID()),
zap.String("req", req.Request),
zap.Error(err))
@ -1278,7 +1278,7 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR
queryNodeMetrics, err := getSystemInfoMetrics(ctx, req, node)
if err != nil {
log.Warn("QueryNode.GetMetrics failed",
zap.Int64("nodeId", paramtable.GetNodeID()),
zap.Int64("nodeId", node.GetNodeID()),
zap.String("req", req.Request),
zap.String("metricType", metricType),
zap.Error(err))
@ -1287,7 +1287,7 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR
}, nil
}
log.RatedDebug(50, "QueryNode.GetMetrics",
zap.Int64("nodeID", paramtable.GetNodeID()),
zap.Int64("nodeID", node.GetNodeID()),
zap.String("req", req.Request),
zap.String("metricType", metricType),
zap.Any("queryNodeMetrics", queryNodeMetrics))
@ -1296,7 +1296,7 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR
}
log.Debug("QueryNode.GetMetrics failed, request metric type is not implemented yet",
zap.Int64("nodeID", paramtable.GetNodeID()),
zap.Int64("nodeID", node.GetNodeID()),
zap.String("req", req.Request),
zap.String("metricType", metricType))
@ -1308,7 +1308,7 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR
func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.GetDataDistributionRequest) (*querypb.GetDataDistributionResponse, error) {
log := log.Ctx(ctx).With(
zap.Int64("msgID", req.GetBase().GetMsgID()),
zap.Int64("nodeID", paramtable.GetNodeID()),
zap.Int64("nodeID", node.GetNodeID()),
)
if err := node.lifetime.Add(merr.IsHealthy); err != nil {
log.Warn("QueryNode.GetDataDistribution failed",
@ -1321,7 +1321,7 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get
defer node.lifetime.Done()
// check target matches
if err := merr.CheckTargetID(req.GetBase()); err != nil {
if err := merr.CheckTargetID(node.GetNodeID(), req.GetBase()); err != nil {
return &querypb.GetDataDistributionResponse{
Status: merr.Status(err),
}, nil
@ -1393,7 +1393,7 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get
return &querypb.GetDataDistributionResponse{
Status: merr.Success(),
NodeID: paramtable.GetNodeID(),
NodeID: node.GetNodeID(),
Segments: segmentVersionInfos,
Channels: channelVersionInfos,
LeaderViews: leaderViews,
@ -1402,7 +1402,7 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get
func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDistributionRequest) (*commonpb.Status, error) {
log := log.Ctx(ctx).With(zap.Int64("collectionID", req.GetCollectionID()),
zap.String("channel", req.GetChannel()), zap.Int64("currentNodeID", paramtable.GetNodeID()))
zap.String("channel", req.GetChannel()), zap.Int64("currentNodeID", node.GetNodeID()))
// check node healthy
if err := node.lifetime.Add(merr.IsHealthy); err != nil {
return merr.Status(err), nil
@ -1410,7 +1410,7 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi
defer node.lifetime.Done()
// check target matches
if err := merr.CheckTargetID(req.GetBase()); err != nil {
if err := merr.CheckTargetID(node.GetNodeID(), req.GetBase()); err != nil {
return merr.Status(err), nil
}
@ -1510,7 +1510,7 @@ func (node *QueryNode) Delete(ctx context.Context, req *querypb.DeleteRequest) (
defer node.lifetime.Done()
// check target matches
if err := merr.CheckTargetID(req.GetBase()); err != nil {
if err := merr.CheckTargetID(node.GetNodeID(), req.GetBase()); err != nil {
return merr.Status(err), nil
}

View File

@ -48,6 +48,7 @@ type SearchTask struct {
originNqs []int64
others []*SearchTask
notifier chan error
serverID int64
tr *timerecord.TimeRecorder
scheduleSpan trace.Span
@ -57,6 +58,7 @@ func NewSearchTask(ctx context.Context,
collection *segments.Collection,
manager *segments.Manager,
req *querypb.SearchRequest,
serverID int64,
) *SearchTask {
ctx, span := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "schedule")
return &SearchTask{
@ -74,6 +76,7 @@ func NewSearchTask(ctx context.Context,
notifier: make(chan error, 1),
tr: timerecord.NewTimeRecorderWithTrace(ctx, "searchTask"),
scheduleSpan: span,
serverID: serverID,
}
}
@ -83,13 +86,17 @@ func (t *SearchTask) Username() string {
return t.req.Req.GetUsername()
}
func (t *SearchTask) GetNodeID() int64 {
return t.serverID
}
func (t *SearchTask) IsGpuIndex() bool {
return t.collection.IsGpuIndex()
}
func (t *SearchTask) PreExecute() error {
// Update task wait time metric before execute
nodeID := strconv.FormatInt(paramtable.GetNodeID(), 10)
nodeID := strconv.FormatInt(t.GetNodeID(), 10)
inQueueDuration := t.tr.ElapseSpan()
// Update in queue metric for prometheus.
@ -180,7 +187,7 @@ func (t *SearchTask) Execute() error {
task.result = &internalpb.SearchResults{
Base: &commonpb.MsgBase{
SourceID: paramtable.GetNodeID(),
SourceID: t.GetNodeID(),
},
Status: merr.Success(),
MetricType: metricType,
@ -211,7 +218,7 @@ func (t *SearchTask) Execute() error {
}
defer segments.DeleteSearchResultDataBlobs(blobs)
metrics.QueryNodeReduceLatency.WithLabelValues(
fmt.Sprint(paramtable.GetNodeID()),
fmt.Sprint(t.GetNodeID()),
metrics.SearchLabel,
metrics.ReduceSegments).
Observe(float64(tr.RecordSpan().Milliseconds()))
@ -234,7 +241,7 @@ func (t *SearchTask) Execute() error {
task.result = &internalpb.SearchResults{
Base: &commonpb.MsgBase{
SourceID: paramtable.GetNodeID(),
SourceID: t.GetNodeID(),
},
Status: merr.Success(),
MetricType: metricType,
@ -294,9 +301,9 @@ func (t *SearchTask) Merge(other *SearchTask) bool {
func (t *SearchTask) Done(err error) {
if !t.merged {
metrics.QueryNodeSearchGroupSize.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(t.groupSize))
metrics.QueryNodeSearchGroupNQ.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(t.nq))
metrics.QueryNodeSearchGroupTopK.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(t.topk))
metrics.QueryNodeSearchGroupSize.WithLabelValues(fmt.Sprint(t.GetNodeID())).Observe(float64(t.groupSize))
metrics.QueryNodeSearchGroupNQ.WithLabelValues(fmt.Sprint(t.GetNodeID())).Observe(float64(t.nq))
metrics.QueryNodeSearchGroupTopK.WithLabelValues(fmt.Sprint(t.GetNodeID())).Observe(float64(t.topk))
}
t.notifier <- err
for _, other := range t.others {
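The search-path change applies the same idea at the task level: NewSearchTask now receives the owning node's serverID, and the task reports it through its own GetNodeID() for the result SourceID and metric labels. A stripped-down, runnable sketch of that threading pattern (searchTask and newSearchTask are stand-ins for tasks.SearchTask, not the real types):

package main

import (
	"fmt"
	"strconv"
)

// stand-in for tasks.SearchTask: the task remembers which node created it
type searchTask struct {
	serverID int64
}

// mirrors the extra serverID parameter added to tasks.NewSearchTask
func newSearchTask(serverID int64) *searchTask {
	return &searchTask{serverID: serverID}
}

func (t *searchTask) GetNodeID() int64 { return t.serverID }

func main() {
	// the handler passes its own ID, as in
	//   tasks.NewSearchTask(searchCtx, collection, node.manager, req, node.serverID)
	task := newSearchTask(10001)

	// metric labels and the result SourceID are then derived from the task's owner,
	// as strconv.FormatInt(t.GetNodeID(), 10) does in PreExecute
	fmt.Println("metric label:", strconv.FormatInt(task.GetNodeID(), 10))
}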

View File

@ -75,6 +75,7 @@ type DataNodeComponent interface {
SetAddress(address string)
GetAddress() string
GetNodeID() int64
// SetEtcdClient set etcd client for DataNode
SetEtcdClient(etcdClient *clientv3.Client)
@ -283,6 +284,7 @@ type QueryNodeComponent interface {
SetAddress(address string)
GetAddress() string
GetNodeID() int64
// SetEtcdClient set etcd client for QueryNode
SetEtcdClient(etcdClient *clientv3.Client)

View File

@ -293,9 +293,9 @@ func AnalyzeState(role string, nodeID int64, state *milvuspb.ComponentStates) er
return nil
}
func CheckTargetID(msg *commonpb.MsgBase) error {
if msg.GetTargetID() != paramtable.GetNodeID() {
return WrapErrNodeNotMatch(paramtable.GetNodeID(), msg.GetTargetID())
func CheckTargetID(actualNodeID int64, msg *commonpb.MsgBase) error {
if msg.GetTargetID() != actualNodeID {
return WrapErrNodeNotMatch(actualNodeID, msg.GetTargetID())
}
return nil
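With the extra parameter, a handler validates requests against the ID it was constructed with rather than whatever the process-global paramtable holds. A self-contained sketch of the new contract (msgBase and checkTargetID below are stand-ins for commonpb.MsgBase and merr.CheckTargetID):

package main

import "fmt"

// stand-ins for commonpb.MsgBase and merr.CheckTargetID, showing the new contract
type msgBase struct {
	targetID int64
}

func (m *msgBase) GetTargetID() int64 { return m.targetID }

func checkTargetID(actualNodeID int64, msg *msgBase) error {
	if msg.GetTargetID() != actualNodeID {
		return fmt.Errorf("node not match: expected node %d, actual node %d", msg.GetTargetID(), actualNodeID)
	}
	return nil
}

func main() {
	// a request addressed to node 10002 arriving at node 10001 is rejected
	fmt.Println(checkTargetID(10001, &msgBase{targetID: 10002}))

	// a correctly addressed request passes
	fmt.Println(checkTargetID(10002, &msgBase{targetID: 10002})) // <nil>
}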

View File

@ -0,0 +1,309 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package datanode
import (
"context"
"fmt"
"math/rand"
"strconv"
"sync"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/tests/integration"
)
type DataNodeSuite struct {
integration.MiniClusterSuite
maxGoRoutineNum int
dim int
numCollections int
rowsPerCollection int
waitTimeInSec time.Duration
prefix string
}
func (s *DataNodeSuite) setupParam() {
s.maxGoRoutineNum = 100
s.dim = 128
s.numCollections = 2
s.rowsPerCollection = 100
s.waitTimeInSec = time.Second * 1
}
func (s *DataNodeSuite) loadCollection(collectionName string) {
c := s.Cluster
dbName := ""
schema := integration.ConstructSchema(collectionName, s.dim, true)
marshaledSchema, err := proto.Marshal(schema)
s.NoError(err)
createCollectionStatus, err := c.Proxy.CreateCollection(context.TODO(), &milvuspb.CreateCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: common.DefaultShardsNum,
})
s.NoError(err)
err = merr.Error(createCollectionStatus)
s.NoError(err)
showCollectionsResp, err := c.Proxy.ShowCollections(context.TODO(), &milvuspb.ShowCollectionsRequest{})
s.NoError(err)
s.True(merr.Ok(showCollectionsResp.GetStatus()))
batchSize := 500000
for start := 0; start < s.rowsPerCollection; start += batchSize {
rowNum := batchSize
if start+batchSize > s.rowsPerCollection {
rowNum = s.rowsPerCollection - start
}
fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, s.dim)
hashKeys := integration.GenerateHashKeys(rowNum)
insertResult, err := c.Proxy.Insert(context.TODO(), &milvuspb.InsertRequest{
DbName: dbName,
CollectionName: collectionName,
FieldsData: []*schemapb.FieldData{fVecColumn},
HashKeys: hashKeys,
NumRows: uint32(rowNum),
})
s.NoError(err)
s.True(merr.Ok(insertResult.GetStatus()))
}
log.Info("=========================Data insertion finished=========================")
// flush
flushResp, err := c.Proxy.Flush(context.TODO(), &milvuspb.FlushRequest{
DbName: dbName,
CollectionNames: []string{collectionName},
})
s.NoError(err)
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
ids := segmentIDs.GetData()
s.Require().NotEmpty(segmentIDs)
s.Require().True(has)
flushTs, has := flushResp.GetCollFlushTs()[collectionName]
s.True(has)
segments, err := c.MetaWatcher.ShowSegments()
s.NoError(err)
s.NotEmpty(segments)
s.WaitForFlush(context.TODO(), ids, flushTs, dbName, collectionName)
log.Info("=========================Data flush finished=========================")
// create index
createIndexStatus, err := c.Proxy.CreateIndex(context.TODO(), &milvuspb.CreateIndexRequest{
CollectionName: collectionName,
FieldName: integration.FloatVecField,
IndexName: "_default",
ExtraParams: integration.ConstructIndexParam(s.dim, integration.IndexFaissIvfFlat, metric.IP),
})
s.NoError(err)
err = merr.Error(createIndexStatus)
s.NoError(err)
s.WaitForIndexBuilt(context.TODO(), collectionName, integration.FloatVecField)
log.Info("=========================Index created=========================")
// load
loadStatus, err := c.Proxy.LoadCollection(context.TODO(), &milvuspb.LoadCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
})
s.NoError(err)
err = merr.Error(loadStatus)
s.NoError(err)
s.WaitForLoad(context.TODO(), collectionName)
log.Info("=========================Collection loaded=========================")
}
func (s *DataNodeSuite) checkCollections() bool {
req := &milvuspb.ShowCollectionsRequest{
DbName: "",
TimeStamp: 0, // means now
}
resp, err := s.Cluster.Proxy.ShowCollections(context.TODO(), req)
s.NoError(err)
s.Equal(len(resp.CollectionIds), s.numCollections)
notLoaded := 0
loaded := 0
for _, name := range resp.CollectionNames {
loadProgress, err := s.Cluster.Proxy.GetLoadingProgress(context.TODO(), &milvuspb.GetLoadingProgressRequest{
DbName: "",
CollectionName: name,
})
s.NoError(err)
if loadProgress.GetProgress() != int64(100) {
notLoaded++
} else {
loaded++
}
}
log.Info(fmt.Sprintf("loading status: %d/%d", loaded, len(resp.GetCollectionNames())))
return notLoaded == 0
}
func (s *DataNodeSuite) search(collectionName string) {
c := s.Cluster
var err error
// Query
queryReq := &milvuspb.QueryRequest{
Base: nil,
CollectionName: collectionName,
PartitionNames: nil,
Expr: "",
OutputFields: []string{"count(*)"},
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
}
queryResult, err := c.Proxy.Query(context.TODO(), queryReq)
s.NoError(err)
s.Equal(len(queryResult.FieldsData), 1)
numEntities := queryResult.FieldsData[0].GetScalars().GetLongData().Data[0]
s.Equal(numEntities, int64(s.rowsPerCollection))
// Search
expr := fmt.Sprintf("%s > 0", integration.Int64Field)
nq := 10
topk := 10
roundDecimal := -1
radius := 10
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.IP)
params["radius"] = radius
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, s.dim, topk, roundDecimal)
searchResult, _ := c.Proxy.Search(context.TODO(), searchReq)
err = merr.Error(searchResult.GetStatus())
s.NoError(err)
}
func (s *DataNodeSuite) insertBatchCollections(prefix string, collectionBatchSize, idxStart int, wg *sync.WaitGroup) {
for idx := 0; idx < collectionBatchSize; idx++ {
collectionName := prefix + "_" + strconv.Itoa(idxStart+idx)
s.loadCollection(collectionName)
}
wg.Done()
}
func (s *DataNodeSuite) setupData() {
// Add the second data node
s.Cluster.AddDataNode()
goRoutineNum := s.maxGoRoutineNum
if goRoutineNum > s.numCollections {
goRoutineNum = s.numCollections
}
collectionBatchSize := s.numCollections / goRoutineNum
log.Info(fmt.Sprintf("=========================test with dim=%d, s.rowsPerCollection=%d, s.numCollections=%d, goRoutineNum=%d==================", s.dim, s.rowsPerCollection, s.numCollections, goRoutineNum))
log.Info("=========================Start to inject data=========================")
s.prefix = "TestDataNodeUtil" + funcutil.GenRandomStr()
searchName := s.prefix + "_0"
wg := sync.WaitGroup{}
for idx := 0; idx < goRoutineNum; idx++ {
wg.Add(1)
go s.insertBatchCollections(s.prefix, collectionBatchSize, idx*collectionBatchSize, &wg)
}
wg.Wait()
log.Info("=========================Data injection finished=========================")
s.checkCollections()
log.Info(fmt.Sprintf("=========================start to search %s=========================", searchName))
s.search(searchName)
log.Info("=========================Search finished=========================")
time.Sleep(s.waitTimeInSec)
s.checkCollections()
log.Info(fmt.Sprintf("=========================start to search2 %s=========================", searchName))
s.search(searchName)
log.Info("=========================Search2 finished=========================")
s.checkAllCollectionsReady()
}
func (s *DataNodeSuite) checkAllCollectionsReady() {
goRoutineNum := s.maxGoRoutineNum
if goRoutineNum > s.numCollections {
goRoutineNum = s.numCollections
}
collectionBatchSize := s.numCollections / goRoutineNum
for i := 0; i < goRoutineNum; i++ {
for idx := 0; idx < collectionBatchSize; idx++ {
collectionName := s.prefix + "_" + strconv.Itoa(i*collectionBatchSize+idx)
s.search(collectionName)
queryReq := &milvuspb.QueryRequest{
CollectionName: collectionName,
Expr: "",
OutputFields: []string{"count(*)"},
}
_, err := s.Cluster.Proxy.Query(context.TODO(), queryReq)
s.NoError(err)
}
}
}
func (s *DataNodeSuite) checkDNRestarts(idx int) {
// Stop all data nodes
s.Cluster.StopAllDataNodes()
// Add new data nodes.
dn1 := s.Cluster.AddDataNode()
dn2 := s.Cluster.AddDataNode()
time.Sleep(s.waitTimeInSec)
cn := fmt.Sprintf("new_collection_r_%d", idx)
s.loadCollection(cn)
s.search(cn)
// Randomly stop one data node.
if rand.Intn(2) == 0 {
dn1.Stop()
} else {
dn2.Stop()
}
time.Sleep(s.waitTimeInSec)
cn = fmt.Sprintf("new_collection_x_%d", idx)
s.loadCollection(cn)
s.search(cn)
}
func (s *DataNodeSuite) TestSwapDN() {
s.setupParam()
s.setupData()
// Test case with new data nodes added
s.Cluster.AddDataNode()
s.Cluster.AddDataNode()
time.Sleep(s.waitTimeInSec)
cn := "new_collection_a"
s.loadCollection(cn)
s.search(cn)
// Test case with all data nodes replaced
for idx := 0; idx < 5; idx++ {
s.checkDNRestarts(idx)
}
}
func TestDataNodeUtil(t *testing.T) {
suite.Run(t, new(DataNodeSuite))
}

View File

@ -26,6 +26,7 @@ import (
"github.com/cockroachdb/errors"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/atomic"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
@ -118,13 +119,20 @@ type MiniClusterV2 struct {
IndexNode *grpcindexnode.Server
MetaWatcher MetaWatcher
ptmu sync.Mutex
querynodes []*grpcquerynode.Server
qnid atomic.Int64
datanodes []*grpcdatanode.Server
dnid atomic.Int64
}
type OptionV2 func(cluster *MiniClusterV2)
func StartMiniClusterV2(ctx context.Context, opts ...OptionV2) (*MiniClusterV2, error) {
cluster := &MiniClusterV2{
ctx: ctx,
ctx: ctx,
qnid: *atomic.NewInt64(10000),
dnid: *atomic.NewInt64(20000),
}
paramtable.Init()
cluster.params = DefaultParams()
@ -238,6 +246,62 @@ func StartMiniClusterV2(ctx context.Context, opts ...OptionV2) (*MiniClusterV2,
return cluster, nil
}
func (cluster *MiniClusterV2) AddQueryNode() *grpcquerynode.Server {
cluster.ptmu.Lock()
defer cluster.ptmu.Unlock()
cluster.qnid.Inc()
id := cluster.qnid.Load()
oid := paramtable.GetNodeID()
log.Info(fmt.Sprintf("adding extra querynode with id:%d", id))
paramtable.SetNodeID(id)
node, err := grpcquerynode.NewServer(context.TODO(), cluster.factory)
if err != nil {
return nil
}
err = node.Run()
if err != nil {
return nil
}
paramtable.SetNodeID(oid)
req := &milvuspb.GetComponentStatesRequest{}
resp, err := node.GetComponentStates(context.TODO(), req)
if err != nil {
return nil
}
log.Info(fmt.Sprintf("querynode %d ComponentStates:%v", id, resp))
cluster.querynodes = append(cluster.querynodes, node)
return node
}
func (cluster *MiniClusterV2) AddDataNode() *grpcdatanode.Server {
cluster.ptmu.Lock()
defer cluster.ptmu.Unlock()
cluster.dnid.Inc()
id := cluster.dnid.Load()
oid := paramtable.GetNodeID()
log.Info(fmt.Sprintf("adding extra datanode with id:%d", id))
paramtable.SetNodeID(id)
node, err := grpcdatanode.NewServer(context.TODO(), cluster.factory)
if err != nil {
return nil
}
err = node.Run()
if err != nil {
return nil
}
paramtable.SetNodeID(oid)
req := &milvuspb.GetComponentStatesRequest{}
resp, err := node.GetComponentStates(context.TODO(), req)
if err != nil {
return nil
}
log.Info(fmt.Sprintf("datanode %d ComponentStates:%v", id, resp))
cluster.datanodes = append(cluster.datanodes, node)
return node
}
func (cluster *MiniClusterV2) Start() error {
log.Info("mini cluster start")
err := cluster.RootCoord.Run()
@ -301,10 +365,8 @@ func (cluster *MiniClusterV2) Stop() error {
cluster.Proxy.Stop()
log.Info("mini cluster proxy stopped")
cluster.DataNode.Stop()
log.Info("mini cluster dataNode stopped")
cluster.QueryNode.Stop()
log.Info("mini cluster queryNode stopped")
cluster.StopAllDataNodes()
cluster.StopAllQueryNodes()
cluster.IndexNode.Stop()
log.Info("mini cluster indexNode stopped")
@ -323,6 +385,26 @@ func (cluster *MiniClusterV2) Stop() error {
return nil
}
func (cluster *MiniClusterV2) StopAllQueryNodes() {
cluster.QueryNode.Stop()
log.Info("mini cluster main queryNode stopped")
numExtraQN := len(cluster.querynodes)
for _, node := range cluster.querynodes {
node.Stop()
}
log.Info(fmt.Sprintf("mini cluster stoped %d extra querynode", numExtraQN))
}
func (cluster *MiniClusterV2) StopAllDataNodes() {
cluster.DataNode.Stop()
log.Info("mini cluster main dataNode stopped")
numExtraDN := len(cluster.datanodes)
for _, node := range cluster.datanodes {
node.Stop()
}
log.Info(fmt.Sprintf("mini cluster stopped %d extra datanodes", numExtraDN))
}
func (cluster *MiniClusterV2) GetContext() context.Context {
return cluster.ctx
}
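These helpers are what the new integration suites build on: AddQueryNode and AddDataNode briefly override the paramtable node ID under ptmu so each extra server starts with a distinct ID drawn from its own counter, and the StopAll* methods tear down the main node plus every extra one. A condensed usage sketch in a MiniClusterSuite-based test (the package, suite, and test names here are illustrative only):

package extranodes // illustrative package name

import (
	"testing"
	"time"

	"github.com/stretchr/testify/suite"

	"github.com/milvus-io/milvus/tests/integration"
)

type ExtraNodesSuite struct {
	integration.MiniClusterSuite
}

func (s *ExtraNodesSuite) TestAddAndStopExtraNodes() {
	// spin up one extra querynode and one extra datanode, each with its own node ID
	qn := s.Cluster.AddQueryNode()
	dn := s.Cluster.AddDataNode()
	s.NotNil(qn) // Add* returns nil if the extra server failed to start or run
	s.NotNil(dn)

	time.Sleep(time.Second)

	// stop the main querynode/datanode plus every extra one added above
	s.Cluster.StopAllQueryNodes()
	s.Cluster.StopAllDataNodes()
}

func TestExtraNodes(t *testing.T) {
	suite.Run(t, new(ExtraNodesSuite))
}

Note that AddQueryNode and AddDataNode return nil when the extra server fails to start, so tests should nil-check the result before using it.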

View File

@ -0,0 +1,305 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querynode
import (
"context"
"fmt"
"strconv"
"sync"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/tests/integration"
)
type QueryNodeSuite struct {
integration.MiniClusterSuite
maxGoRoutineNum int
dim int
numCollections int
rowsPerCollection int
waitTimeInSec time.Duration
prefix string
}
func (s *QueryNodeSuite) setupParam() {
s.maxGoRoutineNum = 100
s.dim = 128
s.numCollections = 2
s.rowsPerCollection = 100
s.waitTimeInSec = time.Second * 10
}
func (s *QueryNodeSuite) loadCollection(collectionName string, dim int) {
c := s.Cluster
dbName := ""
schema := integration.ConstructSchema(collectionName, dim, true)
marshaledSchema, err := proto.Marshal(schema)
s.NoError(err)
createCollectionStatus, err := c.Proxy.CreateCollection(context.TODO(), &milvuspb.CreateCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: common.DefaultShardsNum,
})
s.NoError(err)
err = merr.Error(createCollectionStatus)
s.NoError(err)
showCollectionsResp, err := c.Proxy.ShowCollections(context.TODO(), &milvuspb.ShowCollectionsRequest{})
s.NoError(err)
s.True(merr.Ok(showCollectionsResp.GetStatus()))
batchSize := 500000
for start := 0; start < s.rowsPerCollection; start += batchSize {
rowNum := batchSize
if start+batchSize > s.rowsPerCollection {
rowNum = s.rowsPerCollection - start
}
fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim)
hashKeys := integration.GenerateHashKeys(rowNum)
insertResult, err := c.Proxy.Insert(context.TODO(), &milvuspb.InsertRequest{
DbName: dbName,
CollectionName: collectionName,
FieldsData: []*schemapb.FieldData{fVecColumn},
HashKeys: hashKeys,
NumRows: uint32(rowNum),
})
s.NoError(err)
s.True(merr.Ok(insertResult.GetStatus()))
}
log.Info("=========================Data insertion finished=========================")
// flush
flushResp, err := c.Proxy.Flush(context.TODO(), &milvuspb.FlushRequest{
DbName: dbName,
CollectionNames: []string{collectionName},
})
s.NoError(err)
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
ids := segmentIDs.GetData()
s.Require().NotEmpty(segmentIDs)
s.Require().True(has)
flushTs, has := flushResp.GetCollFlushTs()[collectionName]
s.True(has)
segments, err := c.MetaWatcher.ShowSegments()
s.NoError(err)
s.NotEmpty(segments)
s.WaitForFlush(context.TODO(), ids, flushTs, dbName, collectionName)
log.Info("=========================Data flush finished=========================")
// create index
createIndexStatus, err := c.Proxy.CreateIndex(context.TODO(), &milvuspb.CreateIndexRequest{
CollectionName: collectionName,
FieldName: integration.FloatVecField,
IndexName: "_default",
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.IP),
})
s.NoError(err)
err = merr.Error(createIndexStatus)
s.NoError(err)
s.WaitForIndexBuilt(context.TODO(), collectionName, integration.FloatVecField)
log.Info("=========================Index created=========================")
// load
loadStatus, err := c.Proxy.LoadCollection(context.TODO(), &milvuspb.LoadCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
})
s.NoError(err)
err = merr.Error(loadStatus)
s.NoError(err)
s.WaitForLoad(context.TODO(), collectionName)
log.Info("=========================Collection loaded=========================")
}
func (s *QueryNodeSuite) checkCollections() bool {
req := &milvuspb.ShowCollectionsRequest{
DbName: "",
TimeStamp: 0, // means now
}
resp, err := s.Cluster.Proxy.ShowCollections(context.TODO(), req)
s.NoError(err)
s.Equal(len(resp.CollectionIds), s.numCollections)
notLoaded := 0
loaded := 0
for _, name := range resp.CollectionNames {
loadProgress, err := s.Cluster.Proxy.GetLoadingProgress(context.TODO(), &milvuspb.GetLoadingProgressRequest{
DbName: "",
CollectionName: name,
})
s.NoError(err)
if loadProgress.GetProgress() != int64(100) {
notLoaded++
} else {
loaded++
}
}
log.Info(fmt.Sprintf("loading status: %d/%d", loaded, len(resp.GetCollectionNames())))
return notLoaded == 0
}
func (s *QueryNodeSuite) search(collectionName string, dim int) {
c := s.Cluster
var err error
// Query
queryReq := &milvuspb.QueryRequest{
Base: nil,
CollectionName: collectionName,
PartitionNames: nil,
Expr: "",
OutputFields: []string{"count(*)"},
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
}
queryResult, err := c.Proxy.Query(context.TODO(), queryReq)
s.NoError(err)
s.Equal(len(queryResult.FieldsData), 1)
numEntities := queryResult.FieldsData[0].GetScalars().GetLongData().Data[0]
s.Equal(numEntities, int64(s.rowsPerCollection))
// Search
expr := fmt.Sprintf("%s > 0", integration.Int64Field)
nq := 10
topk := 10
roundDecimal := -1
radius := 10
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.IP)
params["radius"] = radius
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, dim, topk, roundDecimal)
searchResult, _ := c.Proxy.Search(context.TODO(), searchReq)
err = merr.Error(searchResult.GetStatus())
s.NoError(err)
}
func (s *QueryNodeSuite) insertBatchCollections(prefix string, collectionBatchSize, idxStart, dim int, wg *sync.WaitGroup) {
for idx := 0; idx < collectionBatchSize; idx++ {
collectionName := prefix + "_" + strconv.Itoa(idxStart+idx)
s.loadCollection(collectionName, dim)
}
wg.Done()
}
func (s *QueryNodeSuite) setupData() {
// Add the second query node
s.Cluster.AddQueryNode()
goRoutineNum := s.maxGoRoutineNum
if goRoutineNum > s.numCollections {
goRoutineNum = s.numCollections
}
collectionBatchSize := s.numCollections / goRoutineNum
log.Info(fmt.Sprintf("=========================test with s.dim=%d, s.rowsPerCollection=%d, s.numCollections=%d, goRoutineNum=%d==================", s.dim, s.rowsPerCollection, s.numCollections, goRoutineNum))
log.Info("=========================Start to inject data=========================")
s.prefix = "TestQueryNodeUtil" + funcutil.GenRandomStr()
searchName := s.prefix + "_0"
wg := sync.WaitGroup{}
for idx := 0; idx < goRoutineNum; idx++ {
wg.Add(1)
go s.insertBatchCollections(s.prefix, collectionBatchSize, idx*collectionBatchSize, s.dim, &wg)
}
wg.Wait()
log.Info("=========================Data injection finished=========================")
s.checkCollections()
log.Info(fmt.Sprintf("=========================start to search %s=========================", searchName))
s.search(searchName, s.dim)
log.Info("=========================Search finished=========================")
time.Sleep(s.waitTimeInSec)
s.checkCollections()
log.Info(fmt.Sprintf("=========================start to search2 %s=========================", searchName))
s.search(searchName, s.dim)
log.Info("=========================Search2 finished=========================")
s.checkAllCollectionsReady()
}
func (s *QueryNodeSuite) checkAllCollectionsReady() {
goRoutineNum := s.maxGoRoutineNum
if goRoutineNum > s.numCollections {
goRoutineNum = s.numCollections
}
collectionBatchSize := s.numCollections / goRoutineNum
for i := 0; i < goRoutineNum; i++ {
for idx := 0; idx < collectionBatchSize; idx++ {
collectionName := s.prefix + "_" + strconv.Itoa(i*collectionBatchSize+idx)
s.search(collectionName, s.dim)
queryReq := &milvuspb.QueryRequest{
CollectionName: collectionName,
Expr: "",
OutputFields: []string{"count(*)"},
}
_, err := s.Cluster.Proxy.Query(context.TODO(), queryReq)
s.NoError(err)
}
}
}
func (s *QueryNodeSuite) checkQNRestarts() {
// Stop all query nodes
s.Cluster.StopAllQueryNodes()
// Add new Query nodes.
s.Cluster.AddQueryNode()
s.Cluster.AddQueryNode()
time.Sleep(s.waitTimeInSec)
for i := 0; i < 1000; i++ {
time.Sleep(s.waitTimeInSec)
if s.checkCollections() {
break
}
}
s.checkAllCollectionsReady()
}
func (s *QueryNodeSuite) TestSwapQN() {
s.setupParam()
s.setupData()
// Test case with one query node stopped
s.Cluster.QueryNode.Stop()
time.Sleep(s.waitTimeInSec)
s.checkAllCollectionsReady()
// Test case with new Query nodes added
s.Cluster.AddQueryNode()
s.Cluster.AddQueryNode()
time.Sleep(s.waitTimeInSec)
s.checkAllCollectionsReady()
// Test case with all query nodes replaced
for idx := 0; idx < 2; idx++ {
s.checkQNRestarts()
}
}
func TestQueryNodeUtil(t *testing.T) {
suite.Run(t, new(QueryNodeSuite))
}