diff --git a/internal/querynode/impl.go b/internal/querynode/impl.go index a50a7248e7..239b016b37 100644 --- a/internal/querynode/impl.go +++ b/internal/querynode/impl.go @@ -1233,12 +1233,10 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi case querypb.SyncType_Remove: shardCluster.forceRemoveSegment(action.GetSegmentID()) case querypb.SyncType_Set: - shardCluster.updateSegment(shardSegmentInfo{ - segmentID: action.GetSegmentID(), - partitionID: action.GetPartitionID(), - nodeID: action.GetNodeID(), - state: segmentStateLoaded, - }) + shardCluster.SyncSegments([]*querypb.ReplicaSegmentsInfo{ + {NodeId: action.GetNodeID(), PartitionId: action.GetPartitionID(), SegmentIds: []int64{action.GetSegmentID()}}, + }, segmentStateLoaded) + default: return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, diff --git a/internal/querynode/impl_test.go b/internal/querynode/impl_test.go index 52a1d31017..aad8067a56 100644 --- a/internal/querynode/impl_test.go +++ b/internal/querynode/impl_test.go @@ -763,3 +763,78 @@ func TestImpl_SyncReplicaSegments(t *testing.T) { }) } + +func TestSyncDistribution(t *testing.T) { + t.Run("QueryNode not healthy", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + node, err := genSimpleQueryNode(ctx) + defer node.Stop() + assert.NoError(t, err) + + node.UpdateStateCode(internalpb.StateCode_Abnormal) + + resp, err := node.SyncDistribution(ctx, &querypb.SyncDistributionRequest{}) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetErrorCode()) + }) + + t.Run("Sync non-exist channel", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + node, err := genSimpleQueryNode(ctx) + defer node.Stop() + assert.NoError(t, err) + + resp, err := node.SyncDistribution(ctx, &querypb.SyncDistributionRequest{ + CollectionID: defaultCollectionID, + Channel: defaultDMLChannel, + Actions: []*querypb.SyncAction{ + { + Type: querypb.SyncType_Set, + PartitionID: defaultPartitionID, + SegmentID: defaultSegmentID, + NodeID: 99, + }, + }, + }) + + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetErrorCode()) + }) + + t.Run("Normal sync segments", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + node, err := genSimpleQueryNode(ctx) + defer node.Stop() + assert.NoError(t, err) + + node.ShardClusterService.addShardCluster(defaultCollectionID, defaultReplicaID, defaultDMLChannel) + + resp, err := node.SyncDistribution(ctx, &querypb.SyncDistributionRequest{ + CollectionID: defaultCollectionID, + Channel: defaultDMLChannel, + Actions: []*querypb.SyncAction{ + { + Type: querypb.SyncType_Set, + PartitionID: defaultPartitionID, + SegmentID: defaultSegmentID, + NodeID: 99, + }, + }, + }) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode()) + + cs, ok := node.ShardClusterService.getShardCluster(defaultDMLChannel) + require.True(t, ok) + segment, ok := cs.getSegment(defaultSegmentID) + require.True(t, ok) + assert.Equal(t, common.InvalidNodeID, segment.nodeID) + assert.Equal(t, defaultPartitionID, segment.partitionID) + assert.Equal(t, segmentStateLoaded, segment.state) + + }) + +}