diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go
index b872f735f8..a17ee178b9 100644
--- a/internal/proxy/impl.go
+++ b/internal/proxy/impl.go
@@ -3552,11 +3552,7 @@ func (node *Proxy) GetQuerySegmentInfo(ctx context.Context, req *milvuspb.GetQue
 		resp.Status = unhealthyStatus()
 		return resp, nil
 	}
-	segments, err := node.getSegmentsOfCollection(ctx, req.DbName, req.CollectionName)
-	if err != nil {
-		resp.Status.Reason = err.Error()
-		return resp, nil
-	}
+
 	collID, err := globalMetaCache.GetCollectionID(ctx, req.CollectionName)
 	if err != nil {
 		resp.Status.Reason = err.Error()
@@ -3570,11 +3566,10 @@ func (node *Proxy) GetQuerySegmentInfo(ctx context.Context, req *milvuspb.GetQue
 			SourceID: Params.ProxyCfg.GetNodeID(),
 		},
 		CollectionID: collID,
-		SegmentIDs:   segments,
 	})
 	if err != nil {
 		log.Error("Failed to get segment info from QueryCoord",
-			zap.Int64s("segmentIDs", segments), zap.Error(err))
+			zap.Error(err))
 		resp.Status.Reason = err.Error()
 		return resp, nil
 	}
diff --git a/internal/querycoord/cluster.go b/internal/querycoord/cluster.go
index b6331d2b23..f6c84f565c 100644
--- a/internal/querycoord/cluster.go
+++ b/internal/querycoord/cluster.go
@@ -224,10 +224,10 @@ func (c *queryNodeCluster) loadSegments(ctx context.Context, nodeID int64, in *q
 	return fmt.Errorf("loadSegments: can't find QueryNode by nodeID, nodeID = %d", nodeID)
 }
 
-func (c *queryNodeCluster) releaseSegments(ctx context.Context, nodeID int64, in *querypb.ReleaseSegmentsRequest) error {
+func (c *queryNodeCluster) releaseSegments(ctx context.Context, leaderID int64, in *querypb.ReleaseSegmentsRequest) error {
 	c.RLock()
 	var targetNode Node
-	if node, ok := c.nodes[nodeID]; ok {
+	if node, ok := c.nodes[leaderID]; ok {
 		targetNode = node
 	}
 	c.RUnlock()
@@ -239,14 +239,14 @@ func (c *queryNodeCluster) releaseSegments(ctx context.Context, nodeID int64, in
 		err := targetNode.releaseSegments(ctx, in)
 		if err != nil {
-			log.Warn("releaseSegments: queryNode release segments error", zap.Int64("nodeID", nodeID), zap.String("error info", err.Error()))
+			log.Warn("releaseSegments: queryNode release segments error", zap.Int64("leaderID", leaderID), zap.Int64("nodeID", in.NodeID), zap.String("error info", err.Error()))
 			return err
 		}
 
 		return nil
 	}
 
-	return fmt.Errorf("releaseSegments: can't find QueryNode by nodeID, nodeID = %d", nodeID)
+	return fmt.Errorf("releaseSegments: can't find QueryNode by nodeID, nodeID = %d", leaderID)
 }
 
 func (c *queryNodeCluster) watchDmChannels(ctx context.Context, nodeID int64, in *querypb.WatchDmChannelsRequest) error {
@@ -443,18 +443,40 @@ func (c *queryNodeCluster) getSegmentInfoByID(ctx context.Context, segmentID Uni
 }
 
 func (c *queryNodeCluster) getSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest) ([]*querypb.SegmentInfo, error) {
-
 	type respTuple struct {
 		res *querypb.GetSegmentInfoResponse
 		err error
 	}
+	var (
+		segmentInfos []*querypb.SegmentInfo
+	)
+
+	// Fetch sealed segments from Meta
+	if len(in.SegmentIDs) > 0 {
+		for _, segmentID := range in.SegmentIDs {
+			segment, err := c.clusterMeta.getSegmentInfoByID(segmentID)
+			if err != nil {
+				return nil, err
+			}
+
+			segmentInfos = append(segmentInfos, segment)
+		}
+	} else {
+		allSegments := c.clusterMeta.showSegmentInfos(in.CollectionID, nil)
+		for _, segment := range allSegments {
+			if in.CollectionID == 0 || segment.CollectionID == in.CollectionID {
+				segmentInfos = append(segmentInfos, segment)
+			}
+		}
+	}
+
+	// Fetch growing segments
 	c.RLock()
 	var wg sync.WaitGroup
 	cnt := len(c.nodes)
 	resChan := make(chan respTuple, cnt)
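
Note on the cluster.go change above: sealed segment info is now answered from QueryCoord's own meta, and the QueryNodes are only consulted for growing segments, so sealed segments reported back by nodes must be filtered out to avoid double counting. The following is a minimal runnable sketch of that merge policy, with simplified stand-in types (not the actual querypb structs):

```go
package main

import "fmt"

// Simplified stand-ins for querypb.SegmentInfo and commonpb.SegmentState.
type SegmentState int

const (
	Growing SegmentState = 2
	Sealed  SegmentState = 3
)

type SegmentInfo struct {
	SegmentID    int64
	CollectionID int64
	State        SegmentState
}

// mergeSegmentInfos keeps meta as the authoritative source for sealed
// segments and accepts only non-sealed (growing) segments from node
// reports, so a sealed segment is never counted twice.
func mergeSegmentInfos(fromMeta, fromNodes []SegmentInfo) []SegmentInfo {
	merged := make([]SegmentInfo, 0, len(fromMeta)+len(fromNodes))
	merged = append(merged, fromMeta...)
	for _, seg := range fromNodes {
		if seg.State != Sealed { // drop sealed duplicates reported by nodes
			merged = append(merged, seg)
		}
	}
	return merged
}

func main() {
	meta := []SegmentInfo{{SegmentID: 1, State: Sealed}}
	nodes := []SegmentInfo{{SegmentID: 1, State: Sealed}, {SegmentID: 2, State: Growing}}
	fmt.Println(mergeSegmentInfos(meta, nodes)) // [{1 0 3} {2 0 2}]
}
```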
 	wg.Add(cnt)
-
 	for _, node := range c.nodes {
 		go func(node Node) {
 			defer wg.Done()
@@ -468,13 +490,18 @@ func (c *queryNodeCluster) getSegmentInfo(ctx context.Context, in *querypb.GetSe
 	c.RUnlock()
 	wg.Wait()
 	close(resChan)
-	var segmentInfos []*querypb.SegmentInfo
 	for tuple := range resChan {
 		if tuple.err != nil {
 			return nil, tuple.err
 		}
-		segmentInfos = append(segmentInfos, tuple.res.GetInfos()...)
+
+		segments := tuple.res.GetInfos()
+		for _, segment := range segments {
+			if segment.SegmentState != commonpb.SegmentState_Sealed {
+				segmentInfos = append(segmentInfos, segment)
+			}
+		}
 	}
 	//TODO::update meta
diff --git a/internal/querycoord/impl.go b/internal/querycoord/impl.go
index eb77c61fdb..e7cb94e6ca 100644
--- a/internal/querycoord/impl.go
+++ b/internal/querycoord/impl.go
@@ -897,6 +897,7 @@ func (qc *QueryCoord) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmen
 			Status: status,
 		}, nil
 	}
+
 	for _, info := range segmentInfos {
 		totalNumRows += info.NumRows
 		totalMemSize += info.MemSize
diff --git a/internal/querycoord/impl_test.go b/internal/querycoord/impl_test.go
index 33773a160e..01c82638cf 100644
--- a/internal/querycoord/impl_test.go
+++ b/internal/querycoord/impl_test.go
@@ -300,6 +300,10 @@ func TestGrpcTask(t *testing.T) {
 	})
 
 	t.Run("Test GetSegmentInfo", func(t *testing.T) {
+		err := waitLoadCollectionDone(ctx, queryCoord, defaultCollectionID)
+		assert.NoError(t, err)
+		time.Sleep(3 * time.Second)
+
 		res, err := queryCoord.GetSegmentInfo(ctx, &querypb.GetSegmentInfoRequest{
 			Base: &commonpb.MsgBase{
 				MsgType: commonpb.MsgType_SegmentInfo,
diff --git a/internal/querycoord/task.go b/internal/querycoord/task.go
index 1547daab65..54acc02d6f 100644
--- a/internal/querycoord/task.go
+++ b/internal/querycoord/task.go
@@ -1314,7 +1314,8 @@ func (lst *loadSegmentTask) reschedule(ctx context.Context) ([]task, error) {
 type releaseSegmentTask struct {
 	*baseTask
 	*querypb.ReleaseSegmentsRequest
-	cluster Cluster
+	cluster  Cluster
+	leaderID UniqueID
 }
 
 func (rst *releaseSegmentTask) msgBase() *commonpb.MsgBase {
@@ -1354,7 +1355,7 @@ func (rst *releaseSegmentTask) preExecute(context.Context) error {
 func (rst *releaseSegmentTask) execute(ctx context.Context) error {
 	defer rst.reduceRetryCount()
 
-	err := rst.cluster.releaseSegments(rst.ctx, rst.NodeID, rst.ReleaseSegmentsRequest)
+	err := rst.cluster.releaseSegments(rst.ctx, rst.leaderID, rst.ReleaseSegmentsRequest)
 	if err != nil {
 		log.Warn("releaseSegmentTask: releaseSegment occur error", zap.Int64("taskID", rst.getTaskID()))
 		rst.setResultInfo(err)
@@ -2124,6 +2125,53 @@ func (lbt *loadBalanceTask) execute(ctx context.Context) error {
 		balancedSegmentIDs = lbt.SealedSegmentIDs
 	}
 
+	// TODO(yah01): release balanced segments in source nodes
+	// balancedSegmentSet := make(typeutil.UniqueSet)
+	// balancedSegmentSet.Insert(balancedSegmentIDs...)
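
The task.go change above separates the release target from the release payload: ReleaseSegmentsRequest.NodeID still names the node whose segments should be dropped, while the new leaderID field selects the QueryNode that actually receives the RPC (the shard leader, per the commented-out TODO block that follows). A hedged sketch of that routing split, again with stand-in types rather than the real cluster interface:

```go
package main

import "fmt"

// ReleaseSegmentsRequest mirrors the payload fields that matter here.
type ReleaseSegmentsRequest struct {
	NodeID     int64 // node whose segments are to be released (payload)
	SegmentIDs []int64
}

// cluster is a stand-in for queryNodeCluster: leaderID -> node address.
type cluster map[int64]string

// releaseSegments routes the request to the shard leader, not to the
// node named in the payload; the leader handles the actual release.
func (c cluster) releaseSegments(leaderID int64, req *ReleaseSegmentsRequest) error {
	addr, ok := c[leaderID]
	if !ok {
		return fmt.Errorf("can't find QueryNode by nodeID, nodeID = %d", leaderID)
	}
	fmt.Printf("sending release of %v (owner node %d) to leader %d at %s\n",
		req.SegmentIDs, req.NodeID, leaderID, addr)
	return nil
}

func main() {
	c := cluster{7: "10.0.0.7:21123"} // hypothetical address
	_ = c.releaseSegments(7, &ReleaseSegmentsRequest{NodeID: 3, SegmentIDs: []int64{100, 101}})
}
```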
+
+	// for _, nodeID := range lbt.SourceNodeIDs {
+	// 	segments := lbt.meta.getSegmentInfosByNode(nodeID)
+
+	// 	shardSegments := make(map[string][]UniqueID)
+	// 	for _, segment := range segments {
+	// 		if !balancedSegmentSet.Contain(segment.SegmentID) {
+	// 			continue
+	// 		}
+
+	// 		shardSegments[segment.DmChannel] = append(shardSegments[segment.DmChannel], segment.SegmentID)
+	// 	}
+
+	// 	for dmc, segmentIDs := range shardSegments {
+	// 		shardLeader, err := getShardLeaderByNodeID(lbt.meta, lbt.replicaID, dmc)
+	// 		if err != nil {
+	// 			log.Error("failed to get shardLeader",
+	// 				zap.Int64("replicaID", lbt.replicaID),
+	// 				zap.Int64("nodeID", nodeID),
+	// 				zap.String("dmChannel", dmc),
+	// 				zap.Error(err))
+	// 			lbt.setResultInfo(err)
+
+	// 			return err
+	// 		}
+
+	// 		releaseSegmentReq := &querypb.ReleaseSegmentsRequest{
+	// 			Base: &commonpb.MsgBase{
+	// 				MsgType: commonpb.MsgType_ReleaseSegments,
+	// 			},
+
+	// 			NodeID:     nodeID,
+	// 			SegmentIDs: segmentIDs,
+	// 		}
+	// 		baseTask := newBaseTask(ctx, querypb.TriggerCondition_LoadBalance)
+	// 		lbt.addChildTask(&releaseSegmentTask{
+	// 			baseTask:               baseTask,
+	// 			ReleaseSegmentsRequest: releaseSegmentReq,
+	// 			cluster:                lbt.cluster,
+	// 			leaderID:               shardLeader,
+	// 		})
+	// 	}
+	// }
+
 	col2PartitionIDs := make(map[UniqueID][]UniqueID)
 	par2Segments := make(map[UniqueID][]*querypb.SegmentInfo)
 	for _, segmentID := range balancedSegmentIDs {
diff --git a/internal/querycoord/task_scheduler.go b/internal/querycoord/task_scheduler.go
index 284d621259..108919a4f1 100644
--- a/internal/querycoord/task_scheduler.go
+++ b/internal/querycoord/task_scheduler.go
@@ -924,7 +924,9 @@ func updateSegmentInfoFromTask(ctx context.Context, triggerTask task, meta Meta)
 
 	default:
 		// save new segmentInfo when load segment
-		segments := make(map[UniqueID]*querypb.SegmentInfo)
+		var (
+			segments = make(map[UniqueID]*querypb.SegmentInfo)
+		)
 
 		for _, childTask := range triggerTask.getChildTask() {
 			if childTask.msgType() == commonpb.MsgType_LoadSegments {
@@ -934,27 +936,29 @@ func updateSegmentInfoFromTask(ctx context.Context, triggerTask task, meta Meta)
 				collectionID := loadInfo.CollectionID
 				segmentID := loadInfo.SegmentID
-				segment, err := meta.getSegmentInfoByID(segmentID)
-				if err != nil {
-					segment = &querypb.SegmentInfo{
-						SegmentID:      segmentID,
-						CollectionID:   loadInfo.CollectionID,
-						PartitionID:    loadInfo.PartitionID,
-						NodeID:         dstNodeID,
-						DmChannel:      loadInfo.InsertChannel,
-						SegmentState:   commonpb.SegmentState_Sealed,
-						CompactionFrom: loadInfo.CompactionFrom,
-						ReplicaIds:     []UniqueID{req.ReplicaID},
-						NodeIds:        []UniqueID{dstNodeID},
+				segment, saved := segments[segmentID]
+				if !saved {
+					segment, err = meta.getSegmentInfoByID(segmentID)
+					if err != nil {
+						segment = &querypb.SegmentInfo{
+							SegmentID:      segmentID,
+							CollectionID:   loadInfo.CollectionID,
+							PartitionID:    loadInfo.PartitionID,
+							DmChannel:      loadInfo.InsertChannel,
+							SegmentState:   commonpb.SegmentState_Sealed,
+							CompactionFrom: loadInfo.CompactionFrom,
+							ReplicaIds:     []UniqueID{},
+							NodeIds:        []UniqueID{},
+							NumRows:        loadInfo.NumOfRows,
+						}
 					}
-				} else {
-					segment.ReplicaIds = append(segment.ReplicaIds, req.ReplicaID)
-					segment.ReplicaIds = removeFromSlice(segment.GetReplicaIds())
-
-					segment.NodeIds = append(segment.NodeIds, dstNodeID)
-					segment.NodeID = dstNodeID
 				}
-				_, saved := segments[segmentID]
+				segment.ReplicaIds = append(segment.ReplicaIds, req.ReplicaID)
+				segment.ReplicaIds = removeFromSlice(segment.GetReplicaIds())
+
+				segment.NodeIds = append(segment.NodeIds, dstNodeID)
+				segment.NodeID = dstNodeID
+
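
The task_scheduler.go rewrite above merges repeated loads of the same segment: child tasks are deduplicated through the segments map, and ReplicaIds/NodeIds are accumulated (with ReplicaIds deduplicated via removeFromSlice) instead of overwritten. A self-contained sketch of that bookkeeping pattern, with simplified types and removeFromSlice re-implemented here as a plain dedup, which is an assumption about the helper's behavior:

```go
package main

import "fmt"

// SegmentInfo stands in for querypb.SegmentInfo.
type SegmentInfo struct {
	SegmentID  int64
	ReplicaIds []int64
	NodeIds    []int64
}

// removeFromSlice deduplicates IDs, keeping first occurrences
// (assumed behavior of the helper used in the patch).
func removeFromSlice(ids []int64) []int64 {
	seen := make(map[int64]struct{}, len(ids))
	out := ids[:0]
	for _, id := range ids {
		if _, ok := seen[id]; !ok {
			seen[id] = struct{}{}
			out = append(out, id)
		}
	}
	return out
}

// recordLoad merges one load-segment child task into the map:
// an existing entry accumulates replica/node IDs rather than
// being replaced, matching the patched updateSegmentInfoFromTask.
func recordLoad(segments map[int64]*SegmentInfo, segmentID, replicaID, nodeID int64) {
	seg, ok := segments[segmentID]
	if !ok {
		seg = &SegmentInfo{SegmentID: segmentID}
	}
	seg.ReplicaIds = removeFromSlice(append(seg.ReplicaIds, replicaID))
	seg.NodeIds = append(seg.NodeIds, nodeID)
	segments[segmentID] = seg
}

func main() {
	segments := make(map[int64]*SegmentInfo)
	recordLoad(segments, 100, 1, 3)
	recordLoad(segments, 100, 1, 5) // same segment loaded on a second node
	fmt.Println(segments[100])      // &{100 [1] [3 5]}
}
```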
 				segments[segmentID] = segment
 
 				if _, ok := segmentInfosToSave[collectionID]; !ok {
diff --git a/internal/querycoord/task_test.go b/internal/querycoord/task_test.go
index 09913066cd..072f7eaeca 100644
--- a/internal/querycoord/task_test.go
+++ b/internal/querycoord/task_test.go
@@ -128,6 +128,7 @@ func genReleaseSegmentTask(ctx context.Context, queryCoord *QueryCoord, nodeID i
 		baseTask:               baseTask,
 		ReleaseSegmentsRequest: req,
 		cluster:                queryCoord.cluster,
+		leaderID:               nodeID,
 	}
 	return releaseSegmentTask
 }
@@ -1100,7 +1101,7 @@ func TestLoadBalanceIndexedSegmentsAfterNodeDown(t *testing.T) {
 		}
 
 		log.Debug("node still has segments", zap.Int64("nodeID", node1.queryNodeID))
-		time.Sleep(200 * time.Millisecond)
+		time.Sleep(time.Second)
 	}
 
 	for {
diff --git a/internal/querycoord/util.go b/internal/querycoord/util.go
index 7d266c932b..7d19423a82 100644
--- a/internal/querycoord/util.go
+++ b/internal/querycoord/util.go
@@ -218,3 +218,18 @@ func getReplicaAvailableMemory(cluster Cluster, replica *milvuspb.ReplicaInfo) u
 
 	return availableMemory
 }
+
+// func getShardLeaderByNodeID(meta Meta, replicaID UniqueID, dmChannel string) (UniqueID, error) {
+// 	replica, err := meta.getReplicaByID(replicaID)
+// 	if err != nil {
+// 		return 0, err
+// 	}
+
+// 	for _, shard := range replica.ShardReplicas {
+// 		if shard.DmChannelName == dmChannel {
+// 			return shard.LeaderID, nil
+// 		}
+// 	}
+
+// 	return 0, fmt.Errorf("shard leader not found in replica %v and dm channel %s", replicaID, dmChannel)
+// }
diff --git a/tests/python_client/common/common_func.py b/tests/python_client/common/common_func.py
index 4974f1f7d6..17e14e3194 100644
--- a/tests/python_client/common/common_func.py
+++ b/tests/python_client/common/common_func.py
@@ -641,15 +641,16 @@ def get_segment_distribution(res):
     from collections import defaultdict
     segment_distribution = defaultdict(lambda: {"growing": [], "sealed": []})
     for r in res:
-        if r.nodeID not in segment_distribution:
-            segment_distribution[r.nodeID] = {
-                "growing": [],
-                "sealed": []
-            }
-        if r.state == 3:
-            segment_distribution[r.nodeID]["sealed"].append(r.segmentID)
-        if r.state == 2:
-            segment_distribution[r.nodeID]["growing"].append(r.segmentID)
+        for node_id in r.nodeIds:
+            if node_id not in segment_distribution:
+                segment_distribution[node_id] = {
+                    "growing": [],
+                    "sealed": []
+                }
+            if r.state == 3:
+                segment_distribution[node_id]["sealed"].append(r.segmentID)
+            if r.state == 2:
+                segment_distribution[node_id]["growing"].append(r.segmentID)
 
     return segment_distribution
 
diff --git a/tests/python_client/requirements.txt b/tests/python_client/requirements.txt
index 3933b6f121..3989800df1 100644
--- a/tests/python_client/requirements.txt
+++ b/tests/python_client/requirements.txt
@@ -9,7 +9,7 @@ allure-pytest==2.7.0
 pytest-print==0.2.1
 pytest-level==0.1.1
 pytest-xdist==2.2.1
-pymilvus==2.1.0.dev61
+pymilvus==2.1.0.dev66
 pytest-rerunfailures==9.1.1
 git+https://github.com/Projectplace/pytest-tags
 ndg-httpsclient
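
The common_func.py change above follows from replica support: a SegmentInfo now reports every hosting node through the plural nodeIds field, so get_segment_distribution buckets each segment once per node instead of reading the removed singular nodeID. A Go rendering of the same bucketing logic (hypothetical simplified types; the state values 2 and 3 follow commonpb.SegmentState_Growing and _Sealed):

```go
package main

import "fmt"

// SegmentInfo is a stand-in for the proto message the test helper consumes.
type SegmentInfo struct {
	SegmentID int64
	State     int // 2 = growing, 3 = sealed
	NodeIds   []int64
}

type distribution struct {
	Growing []int64
	Sealed  []int64
}

// segmentDistribution groups segment IDs by hosting node and state,
// emitting one entry per node in NodeIds, as the updated helper does.
func segmentDistribution(infos []SegmentInfo) map[int64]*distribution {
	dist := make(map[int64]*distribution)
	for _, info := range infos {
		for _, nodeID := range info.NodeIds {
			d, ok := dist[nodeID]
			if !ok {
				d = &distribution{}
				dist[nodeID] = d
			}
			switch info.State {
			case 3:
				d.Sealed = append(d.Sealed, info.SegmentID)
			case 2:
				d.Growing = append(d.Growing, info.SegmentID)
			}
		}
	}
	return dist
}

func main() {
	infos := []SegmentInfo{{SegmentID: 9, State: 3, NodeIds: []int64{1, 2}}}
	fmt.Println(segmentDistribution(infos)[1]) // &{[] [9]}
}
```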
diff --git a/tests/python_client/testcases/test_utility.py b/tests/python_client/testcases/test_utility.py
index 95a3cbe412..9ae15299cd 100644
--- a/tests/python_client/testcases/test_utility.py
+++ b/tests/python_client/testcases/test_utility.py
@@ -1516,7 +1516,6 @@ class TestUtilityAdvanced(TestcaseBase):
         assert cnt == nb
 
     @pytest.mark.tags(CaseLabel.L3)
-    @pytest.mark.xfail(reason="need newer SDK")
     def test_load_balance_normal(self):
         """
         target: test load balance of collection
@@ -1558,7 +1557,6 @@ class TestUtilityAdvanced(TestcaseBase):
         assert set(sealed_segment_ids).issubset(des_sealed_segment_ids)
 
     @pytest.mark.tags(CaseLabel.L1)
-    @pytest.mark.xfail(reason="need newer SDK")
     def test_load_balance_with_src_node_not_exist(self):
         """
         target: test load balance of collection
@@ -1595,7 +1593,6 @@ class TestUtilityAdvanced(TestcaseBase):
                                  check_items={ct.err_code: 1, ct.err_msg: "is not exist to balance"})
 
     @pytest.mark.tags(CaseLabel.L1)
-    @pytest.mark.xfail(reason="need newer SDK")
    def test_load_balance_with_all_dst_node_not_exist(self):
         """
         target: test load balance of collection
@@ -1631,7 +1628,6 @@ class TestUtilityAdvanced(TestcaseBase):
                                  check_items={ct.err_code: 1, ct.err_msg: "no available queryNode to allocate"})
 
     @pytest.mark.tags(CaseLabel.L1)
-    @pytest.mark.xfail(reason="need newer SDK")
     def test_load_balance_with_one_sealed_segment_id_not_exist(self):
         """
         target: test load balance of collection
@@ -1672,7 +1668,6 @@ class TestUtilityAdvanced(TestcaseBase):
                                  check_items={ct.err_code: 1, ct.err_msg: "is not exist"})
 
     @pytest.mark.tags(CaseLabel.L3)
-    @pytest.mark.xfail(reason="need newer SDK")
     def test_load_balance_in_one_group(self):
         """
         target: test load balance of collection in one group
@@ -1720,7 +1715,6 @@ class TestUtilityAdvanced(TestcaseBase):
         assert set(sealed_segment_ids).issubset(des_sealed_segment_ids)
 
     @pytest.mark.tags(CaseLabel.L3)
-    @pytest.mark.xfail(reason="need newer SDK")
     def test_load_balance_not_in_one_group(self):
         """
         target: test load balance of collection in one group