mirror of https://github.com/milvus-io/milvus.git
Fix GetQuerySegmentInfo() returns incorrect result after LoadBalance (#17190)
Signed-off-by: yah01 <yang.cen@zilliz.com>pull/17214/head
parent
5e1e7a6896
commit
de0ba6d495
|
@ -3552,11 +3552,7 @@ func (node *Proxy) GetQuerySegmentInfo(ctx context.Context, req *milvuspb.GetQue
|
|||
resp.Status = unhealthyStatus()
|
||||
return resp, nil
|
||||
}
|
||||
segments, err := node.getSegmentsOfCollection(ctx, req.DbName, req.CollectionName)
|
||||
if err != nil {
|
||||
resp.Status.Reason = err.Error()
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
collID, err := globalMetaCache.GetCollectionID(ctx, req.CollectionName)
|
||||
if err != nil {
|
||||
resp.Status.Reason = err.Error()
|
||||
|
@ -3570,11 +3566,10 @@ func (node *Proxy) GetQuerySegmentInfo(ctx context.Context, req *milvuspb.GetQue
|
|||
SourceID: Params.ProxyCfg.GetNodeID(),
|
||||
},
|
||||
CollectionID: collID,
|
||||
SegmentIDs: segments,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error("Failed to get segment info from QueryCoord",
|
||||
zap.Int64s("segmentIDs", segments), zap.Error(err))
|
||||
zap.Error(err))
|
||||
resp.Status.Reason = err.Error()
|
||||
return resp, nil
|
||||
}
|
||||
|
|
|
@ -224,10 +224,10 @@ func (c *queryNodeCluster) loadSegments(ctx context.Context, nodeID int64, in *q
|
|||
return fmt.Errorf("loadSegments: can't find QueryNode by nodeID, nodeID = %d", nodeID)
|
||||
}
|
||||
|
||||
func (c *queryNodeCluster) releaseSegments(ctx context.Context, nodeID int64, in *querypb.ReleaseSegmentsRequest) error {
|
||||
func (c *queryNodeCluster) releaseSegments(ctx context.Context, leaderID int64, in *querypb.ReleaseSegmentsRequest) error {
|
||||
c.RLock()
|
||||
var targetNode Node
|
||||
if node, ok := c.nodes[nodeID]; ok {
|
||||
if node, ok := c.nodes[leaderID]; ok {
|
||||
targetNode = node
|
||||
}
|
||||
c.RUnlock()
|
||||
|
@ -239,14 +239,14 @@ func (c *queryNodeCluster) releaseSegments(ctx context.Context, nodeID int64, in
|
|||
|
||||
err := targetNode.releaseSegments(ctx, in)
|
||||
if err != nil {
|
||||
log.Warn("releaseSegments: queryNode release segments error", zap.Int64("nodeID", nodeID), zap.String("error info", err.Error()))
|
||||
log.Warn("releaseSegments: queryNode release segments error", zap.Int64("leaderID", leaderID), zap.Int64("nodeID", in.NodeID), zap.String("error info", err.Error()))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("releaseSegments: can't find QueryNode by nodeID, nodeID = %d", nodeID)
|
||||
return fmt.Errorf("releaseSegments: can't find QueryNode by nodeID, nodeID = %d", leaderID)
|
||||
}
|
||||
|
||||
func (c *queryNodeCluster) watchDmChannels(ctx context.Context, nodeID int64, in *querypb.WatchDmChannelsRequest) error {
|
||||
|
@ -443,18 +443,40 @@ func (c *queryNodeCluster) getSegmentInfoByID(ctx context.Context, segmentID Uni
|
|||
}
|
||||
|
||||
func (c *queryNodeCluster) getSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest) ([]*querypb.SegmentInfo, error) {
|
||||
|
||||
type respTuple struct {
|
||||
res *querypb.GetSegmentInfoResponse
|
||||
err error
|
||||
}
|
||||
|
||||
var (
|
||||
segmentInfos []*querypb.SegmentInfo
|
||||
)
|
||||
|
||||
// Fetch sealed segments from Meta
|
||||
if len(in.SegmentIDs) > 0 {
|
||||
for _, segmentID := range in.SegmentIDs {
|
||||
segment, err := c.clusterMeta.getSegmentInfoByID(segmentID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
segmentInfos = append(segmentInfos, segment)
|
||||
}
|
||||
} else {
|
||||
allSegments := c.clusterMeta.showSegmentInfos(in.CollectionID, nil)
|
||||
for _, segment := range allSegments {
|
||||
if in.CollectionID == 0 || segment.CollectionID == in.CollectionID {
|
||||
segmentInfos = append(segmentInfos, segment)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch growing segments
|
||||
c.RLock()
|
||||
var wg sync.WaitGroup
|
||||
cnt := len(c.nodes)
|
||||
resChan := make(chan respTuple, cnt)
|
||||
wg.Add(cnt)
|
||||
|
||||
for _, node := range c.nodes {
|
||||
go func(node Node) {
|
||||
defer wg.Done()
|
||||
|
@ -468,13 +490,18 @@ func (c *queryNodeCluster) getSegmentInfo(ctx context.Context, in *querypb.GetSe
|
|||
c.RUnlock()
|
||||
wg.Wait()
|
||||
close(resChan)
|
||||
var segmentInfos []*querypb.SegmentInfo
|
||||
|
||||
for tuple := range resChan {
|
||||
if tuple.err != nil {
|
||||
return nil, tuple.err
|
||||
}
|
||||
segmentInfos = append(segmentInfos, tuple.res.GetInfos()...)
|
||||
|
||||
segments := tuple.res.GetInfos()
|
||||
for _, segment := range segments {
|
||||
if segment.SegmentState != commonpb.SegmentState_Sealed {
|
||||
segmentInfos = append(segmentInfos, segment)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//TODO::update meta
|
||||
|
|
|
@ -897,6 +897,7 @@ func (qc *QueryCoord) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmen
|
|||
Status: status,
|
||||
}, nil
|
||||
}
|
||||
|
||||
for _, info := range segmentInfos {
|
||||
totalNumRows += info.NumRows
|
||||
totalMemSize += info.MemSize
|
||||
|
|
|
@ -300,6 +300,10 @@ func TestGrpcTask(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("Test GetSegmentInfo", func(t *testing.T) {
|
||||
err := waitLoadCollectionDone(ctx, queryCoord, defaultCollectionID)
|
||||
assert.NoError(t, err)
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
res, err := queryCoord.GetSegmentInfo(ctx, &querypb.GetSegmentInfoRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_SegmentInfo,
|
||||
|
|
|
@ -1315,6 +1315,7 @@ type releaseSegmentTask struct {
|
|||
*baseTask
|
||||
*querypb.ReleaseSegmentsRequest
|
||||
cluster Cluster
|
||||
leaderID UniqueID
|
||||
}
|
||||
|
||||
func (rst *releaseSegmentTask) msgBase() *commonpb.MsgBase {
|
||||
|
@ -1354,7 +1355,7 @@ func (rst *releaseSegmentTask) preExecute(context.Context) error {
|
|||
func (rst *releaseSegmentTask) execute(ctx context.Context) error {
|
||||
defer rst.reduceRetryCount()
|
||||
|
||||
err := rst.cluster.releaseSegments(rst.ctx, rst.NodeID, rst.ReleaseSegmentsRequest)
|
||||
err := rst.cluster.releaseSegments(rst.ctx, rst.leaderID, rst.ReleaseSegmentsRequest)
|
||||
if err != nil {
|
||||
log.Warn("releaseSegmentTask: releaseSegment occur error", zap.Int64("taskID", rst.getTaskID()))
|
||||
rst.setResultInfo(err)
|
||||
|
@ -2124,6 +2125,53 @@ func (lbt *loadBalanceTask) execute(ctx context.Context) error {
|
|||
balancedSegmentIDs = lbt.SealedSegmentIDs
|
||||
}
|
||||
|
||||
// TODO(yah01): release balanced segments in source nodes
|
||||
// balancedSegmentSet := make(typeutil.UniqueSet)
|
||||
// balancedSegmentSet.Insert(balancedSegmentIDs...)
|
||||
|
||||
// for _, nodeID := range lbt.SourceNodeIDs {
|
||||
// segments := lbt.meta.getSegmentInfosByNode(nodeID)
|
||||
|
||||
// shardSegments := make(map[string][]UniqueID)
|
||||
// for _, segment := range segments {
|
||||
// if !balancedSegmentSet.Contain(segment.SegmentID) {
|
||||
// continue
|
||||
// }
|
||||
|
||||
// shardSegments[segment.DmChannel] = append(shardSegments[segment.DmChannel], segment.SegmentID)
|
||||
// }
|
||||
|
||||
// for dmc, segmentIDs := range shardSegments {
|
||||
// shardLeader, err := getShardLeaderByNodeID(lbt.meta, lbt.replicaID, dmc)
|
||||
// if err != nil {
|
||||
// log.Error("failed to get shardLeader",
|
||||
// zap.Int64("replicaID", lbt.replicaID),
|
||||
// zap.Int64("nodeID", nodeID),
|
||||
// zap.String("dmChannel", dmc),
|
||||
// zap.Error(err))
|
||||
// lbt.setResultInfo(err)
|
||||
|
||||
// return err
|
||||
// }
|
||||
|
||||
// releaseSegmentReq := &querypb.ReleaseSegmentsRequest{
|
||||
// Base: &commonpb.MsgBase{
|
||||
// MsgType: commonpb.MsgType_ReleaseSegments,
|
||||
// },
|
||||
|
||||
// NodeID: nodeID,
|
||||
// SegmentIDs: segmentIDs,
|
||||
// }
|
||||
// baseTask := newBaseTask(ctx, querypb.TriggerCondition_LoadBalance)
|
||||
// lbt.addChildTask(&releaseSegmentTask{
|
||||
// baseTask: baseTask,
|
||||
// ReleaseSegmentsRequest: releaseSegmentReq,
|
||||
// cluster: lbt.cluster,
|
||||
// leaderID: shardLeader,
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
|
||||
col2PartitionIDs := make(map[UniqueID][]UniqueID)
|
||||
par2Segments := make(map[UniqueID][]*querypb.SegmentInfo)
|
||||
for _, segmentID := range balancedSegmentIDs {
|
||||
|
|
|
@ -924,7 +924,9 @@ func updateSegmentInfoFromTask(ctx context.Context, triggerTask task, meta Meta)
|
|||
|
||||
default:
|
||||
// save new segmentInfo when load segment
|
||||
segments := make(map[UniqueID]*querypb.SegmentInfo)
|
||||
var (
|
||||
segments = make(map[UniqueID]*querypb.SegmentInfo)
|
||||
)
|
||||
|
||||
for _, childTask := range triggerTask.getChildTask() {
|
||||
if childTask.msgType() == commonpb.MsgType_LoadSegments {
|
||||
|
@ -934,27 +936,29 @@ func updateSegmentInfoFromTask(ctx context.Context, triggerTask task, meta Meta)
|
|||
collectionID := loadInfo.CollectionID
|
||||
segmentID := loadInfo.SegmentID
|
||||
|
||||
segment, err := meta.getSegmentInfoByID(segmentID)
|
||||
segment, saved := segments[segmentID]
|
||||
if !saved {
|
||||
segment, err = meta.getSegmentInfoByID(segmentID)
|
||||
if err != nil {
|
||||
segment = &querypb.SegmentInfo{
|
||||
SegmentID: segmentID,
|
||||
CollectionID: loadInfo.CollectionID,
|
||||
PartitionID: loadInfo.PartitionID,
|
||||
NodeID: dstNodeID,
|
||||
DmChannel: loadInfo.InsertChannel,
|
||||
SegmentState: commonpb.SegmentState_Sealed,
|
||||
CompactionFrom: loadInfo.CompactionFrom,
|
||||
ReplicaIds: []UniqueID{req.ReplicaID},
|
||||
NodeIds: []UniqueID{dstNodeID},
|
||||
ReplicaIds: []UniqueID{},
|
||||
NodeIds: []UniqueID{},
|
||||
NumRows: loadInfo.NumOfRows,
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
segment.ReplicaIds = append(segment.ReplicaIds, req.ReplicaID)
|
||||
segment.ReplicaIds = removeFromSlice(segment.GetReplicaIds())
|
||||
|
||||
segment.NodeIds = append(segment.NodeIds, dstNodeID)
|
||||
segment.NodeID = dstNodeID
|
||||
}
|
||||
_, saved := segments[segmentID]
|
||||
|
||||
segments[segmentID] = segment
|
||||
|
||||
if _, ok := segmentInfosToSave[collectionID]; !ok {
|
||||
|
|
|
@ -128,6 +128,7 @@ func genReleaseSegmentTask(ctx context.Context, queryCoord *QueryCoord, nodeID i
|
|||
baseTask: baseTask,
|
||||
ReleaseSegmentsRequest: req,
|
||||
cluster: queryCoord.cluster,
|
||||
leaderID: nodeID,
|
||||
}
|
||||
return releaseSegmentTask
|
||||
}
|
||||
|
@ -1100,7 +1101,7 @@ func TestLoadBalanceIndexedSegmentsAfterNodeDown(t *testing.T) {
|
|||
}
|
||||
log.Debug("node still has segments",
|
||||
zap.Int64("nodeID", node1.queryNodeID))
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
for {
|
||||
|
|
|
@ -218,3 +218,18 @@ func getReplicaAvailableMemory(cluster Cluster, replica *milvuspb.ReplicaInfo) u
|
|||
|
||||
return availableMemory
|
||||
}
|
||||
|
||||
// func getShardLeaderByNodeID(meta Meta, replicaID UniqueID, dmChannel string) (UniqueID, error) {
|
||||
// replica, err := meta.getReplicaByID(replicaID)
|
||||
// if err != nil {
|
||||
// return 0, err
|
||||
// }
|
||||
|
||||
// for _, shard := range replica.ShardReplicas {
|
||||
// if shard.DmChannelName == dmChannel {
|
||||
// return shard.LeaderID, nil
|
||||
// }
|
||||
// }
|
||||
|
||||
// return 0, fmt.Errorf("shard leader not found in replica %v and dm channel %s", replicaID, dmChannel)
|
||||
// }
|
||||
|
|
|
@ -641,15 +641,16 @@ def get_segment_distribution(res):
|
|||
from collections import defaultdict
|
||||
segment_distribution = defaultdict(lambda: {"growing": [], "sealed": []})
|
||||
for r in res:
|
||||
if r.nodeID not in segment_distribution:
|
||||
segment_distribution[r.nodeID] = {
|
||||
for node_id in r.nodeIds:
|
||||
if node_id not in segment_distribution:
|
||||
segment_distribution[node_id] = {
|
||||
"growing": [],
|
||||
"sealed": []
|
||||
}
|
||||
if r.state == 3:
|
||||
segment_distribution[r.nodeID]["sealed"].append(r.segmentID)
|
||||
segment_distribution[node_id]["sealed"].append(r.segmentID)
|
||||
if r.state == 2:
|
||||
segment_distribution[r.nodeID]["growing"].append(r.segmentID)
|
||||
segment_distribution[node_id]["growing"].append(r.segmentID)
|
||||
|
||||
return segment_distribution
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ allure-pytest==2.7.0
|
|||
pytest-print==0.2.1
|
||||
pytest-level==0.1.1
|
||||
pytest-xdist==2.2.1
|
||||
pymilvus==2.1.0.dev61
|
||||
pymilvus==2.1.0.dev66
|
||||
pytest-rerunfailures==9.1.1
|
||||
git+https://github.com/Projectplace/pytest-tags
|
||||
ndg-httpsclient
|
||||
|
|
|
@ -1516,7 +1516,6 @@ class TestUtilityAdvanced(TestcaseBase):
|
|||
assert cnt == nb
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L3)
|
||||
@pytest.mark.xfail(reason="need newer SDK")
|
||||
def test_load_balance_normal(self):
|
||||
"""
|
||||
target: test load balance of collection
|
||||
|
@ -1558,7 +1557,6 @@ class TestUtilityAdvanced(TestcaseBase):
|
|||
assert set(sealed_segment_ids).issubset(des_sealed_segment_ids)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.xfail(reason="need newer SDK")
|
||||
def test_load_balance_with_src_node_not_exist(self):
|
||||
"""
|
||||
target: test load balance of collection
|
||||
|
@ -1595,7 +1593,6 @@ class TestUtilityAdvanced(TestcaseBase):
|
|||
check_items={ct.err_code: 1, ct.err_msg: "is not exist to balance"})
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.xfail(reason="need newer SDK")
|
||||
def test_load_balance_with_all_dst_node_not_exist(self):
|
||||
"""
|
||||
target: test load balance of collection
|
||||
|
@ -1631,7 +1628,6 @@ class TestUtilityAdvanced(TestcaseBase):
|
|||
check_items={ct.err_code: 1, ct.err_msg: "no available queryNode to allocate"})
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.xfail(reason="need newer SDK")
|
||||
def test_load_balance_with_one_sealed_segment_id_not_exist(self):
|
||||
"""
|
||||
target: test load balance of collection
|
||||
|
@ -1672,7 +1668,6 @@ class TestUtilityAdvanced(TestcaseBase):
|
|||
check_items={ct.err_code: 1, ct.err_msg: "is not exist"})
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L3)
|
||||
@pytest.mark.xfail(reason="need newer SDK")
|
||||
def test_load_balance_in_one_group(self):
|
||||
"""
|
||||
target: test load balance of collection in one group
|
||||
|
@ -1720,7 +1715,6 @@ class TestUtilityAdvanced(TestcaseBase):
|
|||
assert set(sealed_segment_ids).issubset(des_sealed_segment_ids)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L3)
|
||||
@pytest.mark.xfail(reason="need newer SDK")
|
||||
def test_load_balance_not_in_one_group(self):
|
||||
"""
|
||||
target: test load balance of collection in one group
|
||||
|
|
Loading…
Reference in New Issue