diff --git a/internal/querynode/impl.go b/internal/querynode/impl.go index f74eef0480..5a07c1190a 100644 --- a/internal/querynode/impl.go +++ b/internal/querynode/impl.go @@ -489,10 +489,14 @@ func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *queryPb.GetSegmen } return res, nil } - infos := make([]*queryPb.SegmentInfo, 0) + var segmentInfos []*queryPb.SegmentInfo + + segmentIDs := make(map[int64]struct{}) + for _, segmentID := range in.GetSegmentIDs() { + segmentIDs[segmentID] = struct{}{} + } // get info from historical - // node.historical.replica.printReplica() historicalSegmentInfos, err := node.historical.replica.getSegmentInfosByColID(in.CollectionID) if err != nil { log.Debug("GetSegmentInfo: get historical segmentInfo failed", zap.Int64("collectionID", in.CollectionID), zap.Error(err)) @@ -504,10 +508,9 @@ func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *queryPb.GetSegmen } return res, nil } - infos = append(infos, historicalSegmentInfos...) + segmentInfos = append(segmentInfos, filterSegmentInfo(historicalSegmentInfos, segmentIDs)...) // get info from streaming - // node.streaming.replica.printReplica() streamingSegmentInfos, err := node.streaming.replica.getSegmentInfosByColID(in.CollectionID) if err != nil { log.Debug("GetSegmentInfo: get streaming segmentInfo failed", zap.Int64("collectionID", in.CollectionID), zap.Error(err)) @@ -519,17 +522,32 @@ func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *queryPb.GetSegmen } return res, nil } - infos = append(infos, streamingSegmentInfos...) - // log.Debug("GetSegmentInfo: get segment info from query node", zap.Int64("nodeID", node.session.ServerID), zap.Any("segment infos", infos)) + segmentInfos = append(segmentInfos, filterSegmentInfo(streamingSegmentInfos, segmentIDs)...) return &queryPb.GetSegmentInfoResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, }, - Infos: infos, + Infos: segmentInfos, }, nil } +// filterSegmentInfo returns segment info which segment id in segmentIDs map +func filterSegmentInfo(segmentInfos []*queryPb.SegmentInfo, segmentIDs map[int64]struct{}) []*queryPb.SegmentInfo { + if len(segmentIDs) == 0 { + return segmentInfos + } + filtered := make([]*queryPb.SegmentInfo, 0, len(segmentIDs)) + for _, info := range segmentInfos { + _, ok := segmentIDs[info.GetSegmentID()] + if !ok { + continue + } + filtered = append(filtered, info) + } + return filtered +} + // isHealthy checks if QueryNode is healthy func (node *QueryNode) isHealthy() bool { code := node.stateCode.Load().(internalpb.StateCode) diff --git a/internal/querynode/impl_test.go b/internal/querynode/impl_test.go index 3b0ed7e2d2..77b360ca7d 100644 --- a/internal/querynode/impl_test.go +++ b/internal/querynode/impl_test.go @@ -230,10 +230,7 @@ func TestImpl_GetSegmentInfo(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - var wg sync.WaitGroup - wg.Add(1) t.Run("test GetSegmentInfo", func(t *testing.T) { - defer wg.Done() node, err := genSimpleQueryNode(ctx) assert.NoError(t, err) @@ -242,7 +239,7 @@ func TestImpl_GetSegmentInfo(t *testing.T) { MsgType: commonpb.MsgType_WatchQueryChannels, MsgID: rand.Int63(), }, - SegmentIDs: []UniqueID{defaultSegmentID}, + SegmentIDs: []UniqueID{}, CollectionID: defaultCollectionID, } @@ -250,15 +247,19 @@ func TestImpl_GetSegmentInfo(t *testing.T) { assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, rsp.Status.ErrorCode) + req.SegmentIDs = []UniqueID{-1} + rsp, err = node.GetSegmentInfo(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, rsp.Status.ErrorCode) + assert.Equal(t, 0, len(rsp.GetInfos())) + node.UpdateStateCode(internalpb.StateCode_Abnormal) rsp, err = node.GetSegmentInfo(ctx, req) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_UnexpectedError, rsp.Status.ErrorCode) }) - wg.Add(1) t.Run("test no collection in historical", func(t *testing.T) { - defer wg.Done() node, err := genSimpleQueryNode(ctx) assert.NoError(t, err) @@ -279,9 +280,7 @@ func TestImpl_GetSegmentInfo(t *testing.T) { assert.Equal(t, commonpb.ErrorCode_Success, rsp.Status.ErrorCode) }) - wg.Add(1) t.Run("test no collection in streaming", func(t *testing.T) { - defer wg.Done() node, err := genSimpleQueryNode(ctx) assert.NoError(t, err) @@ -302,9 +301,7 @@ func TestImpl_GetSegmentInfo(t *testing.T) { assert.Equal(t, commonpb.ErrorCode_Success, rsp.Status.ErrorCode) }) - wg.Add(1) t.Run("test different segment type", func(t *testing.T) { - defer wg.Done() node, err := genSimpleQueryNode(ctx) assert.NoError(t, err) @@ -346,9 +343,7 @@ func TestImpl_GetSegmentInfo(t *testing.T) { assert.Equal(t, commonpb.ErrorCode_Success, rsp.Status.ErrorCode) }) - wg.Add(1) t.Run("test GetSegmentInfo with indexed segment", func(t *testing.T) { - defer wg.Done() node, err := genSimpleQueryNode(ctx) assert.NoError(t, err) @@ -380,9 +375,7 @@ func TestImpl_GetSegmentInfo(t *testing.T) { assert.Equal(t, commonpb.ErrorCode_UnexpectedError, rsp.Status.ErrorCode) }) - wg.Add(1) t.Run("test GetSegmentInfo without streaming partition", func(t *testing.T) { - defer wg.Done() node, err := genSimpleQueryNode(ctx) assert.NoError(t, err) @@ -391,7 +384,7 @@ func TestImpl_GetSegmentInfo(t *testing.T) { MsgType: commonpb.MsgType_WatchQueryChannels, MsgID: rand.Int63(), }, - SegmentIDs: []UniqueID{defaultSegmentID}, + SegmentIDs: []UniqueID{}, CollectionID: defaultCollectionID, } @@ -401,9 +394,7 @@ func TestImpl_GetSegmentInfo(t *testing.T) { assert.Equal(t, commonpb.ErrorCode_UnexpectedError, rsp.Status.ErrorCode) }) - wg.Add(1) t.Run("test GetSegmentInfo without streaming segment", func(t *testing.T) { - defer wg.Done() node, err := genSimpleQueryNode(ctx) assert.NoError(t, err) @@ -412,7 +403,7 @@ func TestImpl_GetSegmentInfo(t *testing.T) { MsgType: commonpb.MsgType_WatchQueryChannels, MsgID: rand.Int63(), }, - SegmentIDs: []UniqueID{defaultSegmentID}, + SegmentIDs: []UniqueID{}, CollectionID: defaultCollectionID, } @@ -422,9 +413,7 @@ func TestImpl_GetSegmentInfo(t *testing.T) { assert.Equal(t, commonpb.ErrorCode_UnexpectedError, rsp.Status.ErrorCode) }) - wg.Add(1) t.Run("test GetSegmentInfo without historical partition", func(t *testing.T) { - defer wg.Done() node, err := genSimpleQueryNode(ctx) assert.NoError(t, err) @@ -433,7 +422,7 @@ func TestImpl_GetSegmentInfo(t *testing.T) { MsgType: commonpb.MsgType_WatchQueryChannels, MsgID: rand.Int63(), }, - SegmentIDs: []UniqueID{defaultSegmentID}, + SegmentIDs: []UniqueID{}, CollectionID: defaultCollectionID, } @@ -443,9 +432,7 @@ func TestImpl_GetSegmentInfo(t *testing.T) { assert.Equal(t, commonpb.ErrorCode_UnexpectedError, rsp.Status.ErrorCode) }) - wg.Add(1) t.Run("test GetSegmentInfo without historical segment", func(t *testing.T) { - defer wg.Done() node, err := genSimpleQueryNode(ctx) assert.NoError(t, err) @@ -454,7 +441,7 @@ func TestImpl_GetSegmentInfo(t *testing.T) { MsgType: commonpb.MsgType_WatchQueryChannels, MsgID: rand.Int63(), }, - SegmentIDs: []UniqueID{defaultSegmentID}, + SegmentIDs: []UniqueID{}, CollectionID: defaultCollectionID, } @@ -463,7 +450,6 @@ func TestImpl_GetSegmentInfo(t *testing.T) { assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_UnexpectedError, rsp.Status.ErrorCode) }) - wg.Wait() } func TestImpl_isHealthy(t *testing.T) {