Fix GetQuerySegmentInfo() returns incorrect result after LoadBalance (#17190)

Signed-off-by: yah01 <yang.cen@zilliz.com>
pull/17214/head
yah01 2022-05-25 15:17:59 +08:00 committed by GitHub
parent 5e1e7a6896
commit de0ba6d495
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 144 additions and 54 deletions

View File

@ -3552,11 +3552,7 @@ func (node *Proxy) GetQuerySegmentInfo(ctx context.Context, req *milvuspb.GetQue
resp.Status = unhealthyStatus() resp.Status = unhealthyStatus()
return resp, nil return resp, nil
} }
segments, err := node.getSegmentsOfCollection(ctx, req.DbName, req.CollectionName)
if err != nil {
resp.Status.Reason = err.Error()
return resp, nil
}
collID, err := globalMetaCache.GetCollectionID(ctx, req.CollectionName) collID, err := globalMetaCache.GetCollectionID(ctx, req.CollectionName)
if err != nil { if err != nil {
resp.Status.Reason = err.Error() resp.Status.Reason = err.Error()
@ -3570,11 +3566,10 @@ func (node *Proxy) GetQuerySegmentInfo(ctx context.Context, req *milvuspb.GetQue
SourceID: Params.ProxyCfg.GetNodeID(), SourceID: Params.ProxyCfg.GetNodeID(),
}, },
CollectionID: collID, CollectionID: collID,
SegmentIDs: segments,
}) })
if err != nil { if err != nil {
log.Error("Failed to get segment info from QueryCoord", log.Error("Failed to get segment info from QueryCoord",
zap.Int64s("segmentIDs", segments), zap.Error(err)) zap.Error(err))
resp.Status.Reason = err.Error() resp.Status.Reason = err.Error()
return resp, nil return resp, nil
} }

View File

@ -224,10 +224,10 @@ func (c *queryNodeCluster) loadSegments(ctx context.Context, nodeID int64, in *q
return fmt.Errorf("loadSegments: can't find QueryNode by nodeID, nodeID = %d", nodeID) return fmt.Errorf("loadSegments: can't find QueryNode by nodeID, nodeID = %d", nodeID)
} }
func (c *queryNodeCluster) releaseSegments(ctx context.Context, nodeID int64, in *querypb.ReleaseSegmentsRequest) error { func (c *queryNodeCluster) releaseSegments(ctx context.Context, leaderID int64, in *querypb.ReleaseSegmentsRequest) error {
c.RLock() c.RLock()
var targetNode Node var targetNode Node
if node, ok := c.nodes[nodeID]; ok { if node, ok := c.nodes[leaderID]; ok {
targetNode = node targetNode = node
} }
c.RUnlock() c.RUnlock()
@ -239,14 +239,14 @@ func (c *queryNodeCluster) releaseSegments(ctx context.Context, nodeID int64, in
err := targetNode.releaseSegments(ctx, in) err := targetNode.releaseSegments(ctx, in)
if err != nil { if err != nil {
log.Warn("releaseSegments: queryNode release segments error", zap.Int64("nodeID", nodeID), zap.String("error info", err.Error())) log.Warn("releaseSegments: queryNode release segments error", zap.Int64("leaderID", leaderID), zap.Int64("nodeID", in.NodeID), zap.String("error info", err.Error()))
return err return err
} }
return nil return nil
} }
return fmt.Errorf("releaseSegments: can't find QueryNode by nodeID, nodeID = %d", nodeID) return fmt.Errorf("releaseSegments: can't find QueryNode by nodeID, nodeID = %d", leaderID)
} }
func (c *queryNodeCluster) watchDmChannels(ctx context.Context, nodeID int64, in *querypb.WatchDmChannelsRequest) error { func (c *queryNodeCluster) watchDmChannels(ctx context.Context, nodeID int64, in *querypb.WatchDmChannelsRequest) error {
@ -443,18 +443,40 @@ func (c *queryNodeCluster) getSegmentInfoByID(ctx context.Context, segmentID Uni
} }
func (c *queryNodeCluster) getSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest) ([]*querypb.SegmentInfo, error) { func (c *queryNodeCluster) getSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest) ([]*querypb.SegmentInfo, error) {
type respTuple struct { type respTuple struct {
res *querypb.GetSegmentInfoResponse res *querypb.GetSegmentInfoResponse
err error err error
} }
var (
segmentInfos []*querypb.SegmentInfo
)
// Fetch sealed segments from Meta
if len(in.SegmentIDs) > 0 {
for _, segmentID := range in.SegmentIDs {
segment, err := c.clusterMeta.getSegmentInfoByID(segmentID)
if err != nil {
return nil, err
}
segmentInfos = append(segmentInfos, segment)
}
} else {
allSegments := c.clusterMeta.showSegmentInfos(in.CollectionID, nil)
for _, segment := range allSegments {
if in.CollectionID == 0 || segment.CollectionID == in.CollectionID {
segmentInfos = append(segmentInfos, segment)
}
}
}
// Fetch growing segments
c.RLock() c.RLock()
var wg sync.WaitGroup var wg sync.WaitGroup
cnt := len(c.nodes) cnt := len(c.nodes)
resChan := make(chan respTuple, cnt) resChan := make(chan respTuple, cnt)
wg.Add(cnt) wg.Add(cnt)
for _, node := range c.nodes { for _, node := range c.nodes {
go func(node Node) { go func(node Node) {
defer wg.Done() defer wg.Done()
@ -468,13 +490,18 @@ func (c *queryNodeCluster) getSegmentInfo(ctx context.Context, in *querypb.GetSe
c.RUnlock() c.RUnlock()
wg.Wait() wg.Wait()
close(resChan) close(resChan)
var segmentInfos []*querypb.SegmentInfo
for tuple := range resChan { for tuple := range resChan {
if tuple.err != nil { if tuple.err != nil {
return nil, tuple.err return nil, tuple.err
} }
segmentInfos = append(segmentInfos, tuple.res.GetInfos()...)
segments := tuple.res.GetInfos()
for _, segment := range segments {
if segment.SegmentState != commonpb.SegmentState_Sealed {
segmentInfos = append(segmentInfos, segment)
}
}
} }
//TODO::update meta //TODO::update meta

View File

@ -897,6 +897,7 @@ func (qc *QueryCoord) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmen
Status: status, Status: status,
}, nil }, nil
} }
for _, info := range segmentInfos { for _, info := range segmentInfos {
totalNumRows += info.NumRows totalNumRows += info.NumRows
totalMemSize += info.MemSize totalMemSize += info.MemSize

View File

@ -300,6 +300,10 @@ func TestGrpcTask(t *testing.T) {
}) })
t.Run("Test GetSegmentInfo", func(t *testing.T) { t.Run("Test GetSegmentInfo", func(t *testing.T) {
err := waitLoadCollectionDone(ctx, queryCoord, defaultCollectionID)
assert.NoError(t, err)
time.Sleep(3 * time.Second)
res, err := queryCoord.GetSegmentInfo(ctx, &querypb.GetSegmentInfoRequest{ res, err := queryCoord.GetSegmentInfo(ctx, &querypb.GetSegmentInfoRequest{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_SegmentInfo, MsgType: commonpb.MsgType_SegmentInfo,

View File

@ -1315,6 +1315,7 @@ type releaseSegmentTask struct {
*baseTask *baseTask
*querypb.ReleaseSegmentsRequest *querypb.ReleaseSegmentsRequest
cluster Cluster cluster Cluster
leaderID UniqueID
} }
func (rst *releaseSegmentTask) msgBase() *commonpb.MsgBase { func (rst *releaseSegmentTask) msgBase() *commonpb.MsgBase {
@ -1354,7 +1355,7 @@ func (rst *releaseSegmentTask) preExecute(context.Context) error {
func (rst *releaseSegmentTask) execute(ctx context.Context) error { func (rst *releaseSegmentTask) execute(ctx context.Context) error {
defer rst.reduceRetryCount() defer rst.reduceRetryCount()
err := rst.cluster.releaseSegments(rst.ctx, rst.NodeID, rst.ReleaseSegmentsRequest) err := rst.cluster.releaseSegments(rst.ctx, rst.leaderID, rst.ReleaseSegmentsRequest)
if err != nil { if err != nil {
log.Warn("releaseSegmentTask: releaseSegment occur error", zap.Int64("taskID", rst.getTaskID())) log.Warn("releaseSegmentTask: releaseSegment occur error", zap.Int64("taskID", rst.getTaskID()))
rst.setResultInfo(err) rst.setResultInfo(err)
@ -2124,6 +2125,53 @@ func (lbt *loadBalanceTask) execute(ctx context.Context) error {
balancedSegmentIDs = lbt.SealedSegmentIDs balancedSegmentIDs = lbt.SealedSegmentIDs
} }
// TODO(yah01): release balanced segments in source nodes
// balancedSegmentSet := make(typeutil.UniqueSet)
// balancedSegmentSet.Insert(balancedSegmentIDs...)
// for _, nodeID := range lbt.SourceNodeIDs {
// segments := lbt.meta.getSegmentInfosByNode(nodeID)
// shardSegments := make(map[string][]UniqueID)
// for _, segment := range segments {
// if !balancedSegmentSet.Contain(segment.SegmentID) {
// continue
// }
// shardSegments[segment.DmChannel] = append(shardSegments[segment.DmChannel], segment.SegmentID)
// }
// for dmc, segmentIDs := range shardSegments {
// shardLeader, err := getShardLeaderByNodeID(lbt.meta, lbt.replicaID, dmc)
// if err != nil {
// log.Error("failed to get shardLeader",
// zap.Int64("replicaID", lbt.replicaID),
// zap.Int64("nodeID", nodeID),
// zap.String("dmChannel", dmc),
// zap.Error(err))
// lbt.setResultInfo(err)
// return err
// }
// releaseSegmentReq := &querypb.ReleaseSegmentsRequest{
// Base: &commonpb.MsgBase{
// MsgType: commonpb.MsgType_ReleaseSegments,
// },
// NodeID: nodeID,
// SegmentIDs: segmentIDs,
// }
// baseTask := newBaseTask(ctx, querypb.TriggerCondition_LoadBalance)
// lbt.addChildTask(&releaseSegmentTask{
// baseTask: baseTask,
// ReleaseSegmentsRequest: releaseSegmentReq,
// cluster: lbt.cluster,
// leaderID: shardLeader,
// })
// }
// }
col2PartitionIDs := make(map[UniqueID][]UniqueID) col2PartitionIDs := make(map[UniqueID][]UniqueID)
par2Segments := make(map[UniqueID][]*querypb.SegmentInfo) par2Segments := make(map[UniqueID][]*querypb.SegmentInfo)
for _, segmentID := range balancedSegmentIDs { for _, segmentID := range balancedSegmentIDs {

View File

@ -924,7 +924,9 @@ func updateSegmentInfoFromTask(ctx context.Context, triggerTask task, meta Meta)
default: default:
// save new segmentInfo when load segment // save new segmentInfo when load segment
segments := make(map[UniqueID]*querypb.SegmentInfo) var (
segments = make(map[UniqueID]*querypb.SegmentInfo)
)
for _, childTask := range triggerTask.getChildTask() { for _, childTask := range triggerTask.getChildTask() {
if childTask.msgType() == commonpb.MsgType_LoadSegments { if childTask.msgType() == commonpb.MsgType_LoadSegments {
@ -934,27 +936,29 @@ func updateSegmentInfoFromTask(ctx context.Context, triggerTask task, meta Meta)
collectionID := loadInfo.CollectionID collectionID := loadInfo.CollectionID
segmentID := loadInfo.SegmentID segmentID := loadInfo.SegmentID
segment, err := meta.getSegmentInfoByID(segmentID) segment, saved := segments[segmentID]
if !saved {
segment, err = meta.getSegmentInfoByID(segmentID)
if err != nil { if err != nil {
segment = &querypb.SegmentInfo{ segment = &querypb.SegmentInfo{
SegmentID: segmentID, SegmentID: segmentID,
CollectionID: loadInfo.CollectionID, CollectionID: loadInfo.CollectionID,
PartitionID: loadInfo.PartitionID, PartitionID: loadInfo.PartitionID,
NodeID: dstNodeID,
DmChannel: loadInfo.InsertChannel, DmChannel: loadInfo.InsertChannel,
SegmentState: commonpb.SegmentState_Sealed, SegmentState: commonpb.SegmentState_Sealed,
CompactionFrom: loadInfo.CompactionFrom, CompactionFrom: loadInfo.CompactionFrom,
ReplicaIds: []UniqueID{req.ReplicaID}, ReplicaIds: []UniqueID{},
NodeIds: []UniqueID{dstNodeID}, NodeIds: []UniqueID{},
NumRows: loadInfo.NumOfRows,
}
}
} }
} else {
segment.ReplicaIds = append(segment.ReplicaIds, req.ReplicaID) segment.ReplicaIds = append(segment.ReplicaIds, req.ReplicaID)
segment.ReplicaIds = removeFromSlice(segment.GetReplicaIds()) segment.ReplicaIds = removeFromSlice(segment.GetReplicaIds())
segment.NodeIds = append(segment.NodeIds, dstNodeID) segment.NodeIds = append(segment.NodeIds, dstNodeID)
segment.NodeID = dstNodeID segment.NodeID = dstNodeID
}
_, saved := segments[segmentID]
segments[segmentID] = segment segments[segmentID] = segment
if _, ok := segmentInfosToSave[collectionID]; !ok { if _, ok := segmentInfosToSave[collectionID]; !ok {

View File

@ -128,6 +128,7 @@ func genReleaseSegmentTask(ctx context.Context, queryCoord *QueryCoord, nodeID i
baseTask: baseTask, baseTask: baseTask,
ReleaseSegmentsRequest: req, ReleaseSegmentsRequest: req,
cluster: queryCoord.cluster, cluster: queryCoord.cluster,
leaderID: nodeID,
} }
return releaseSegmentTask return releaseSegmentTask
} }
@ -1100,7 +1101,7 @@ func TestLoadBalanceIndexedSegmentsAfterNodeDown(t *testing.T) {
} }
log.Debug("node still has segments", log.Debug("node still has segments",
zap.Int64("nodeID", node1.queryNodeID)) zap.Int64("nodeID", node1.queryNodeID))
time.Sleep(200 * time.Millisecond) time.Sleep(time.Second)
} }
for { for {

View File

@ -218,3 +218,18 @@ func getReplicaAvailableMemory(cluster Cluster, replica *milvuspb.ReplicaInfo) u
return availableMemory return availableMemory
} }
// func getShardLeaderByNodeID(meta Meta, replicaID UniqueID, dmChannel string) (UniqueID, error) {
// replica, err := meta.getReplicaByID(replicaID)
// if err != nil {
// return 0, err
// }
// for _, shard := range replica.ShardReplicas {
// if shard.DmChannelName == dmChannel {
// return shard.LeaderID, nil
// }
// }
// return 0, fmt.Errorf("shard leader not found in replica %v and dm channel %s", replicaID, dmChannel)
// }

View File

@ -641,15 +641,16 @@ def get_segment_distribution(res):
from collections import defaultdict from collections import defaultdict
segment_distribution = defaultdict(lambda: {"growing": [], "sealed": []}) segment_distribution = defaultdict(lambda: {"growing": [], "sealed": []})
for r in res: for r in res:
if r.nodeID not in segment_distribution: for node_id in r.nodeIds:
segment_distribution[r.nodeID] = { if node_id not in segment_distribution:
segment_distribution[node_id] = {
"growing": [], "growing": [],
"sealed": [] "sealed": []
} }
if r.state == 3: if r.state == 3:
segment_distribution[r.nodeID]["sealed"].append(r.segmentID) segment_distribution[node_id]["sealed"].append(r.segmentID)
if r.state == 2: if r.state == 2:
segment_distribution[r.nodeID]["growing"].append(r.segmentID) segment_distribution[node_id]["growing"].append(r.segmentID)
return segment_distribution return segment_distribution

View File

@ -9,7 +9,7 @@ allure-pytest==2.7.0
pytest-print==0.2.1 pytest-print==0.2.1
pytest-level==0.1.1 pytest-level==0.1.1
pytest-xdist==2.2.1 pytest-xdist==2.2.1
pymilvus==2.1.0.dev61 pymilvus==2.1.0.dev66
pytest-rerunfailures==9.1.1 pytest-rerunfailures==9.1.1
git+https://github.com/Projectplace/pytest-tags git+https://github.com/Projectplace/pytest-tags
ndg-httpsclient ndg-httpsclient

View File

@ -1516,7 +1516,6 @@ class TestUtilityAdvanced(TestcaseBase):
assert cnt == nb assert cnt == nb
@pytest.mark.tags(CaseLabel.L3) @pytest.mark.tags(CaseLabel.L3)
@pytest.mark.xfail(reason="need newer SDK")
def test_load_balance_normal(self): def test_load_balance_normal(self):
""" """
target: test load balance of collection target: test load balance of collection
@ -1558,7 +1557,6 @@ class TestUtilityAdvanced(TestcaseBase):
assert set(sealed_segment_ids).issubset(des_sealed_segment_ids) assert set(sealed_segment_ids).issubset(des_sealed_segment_ids)
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="need newer SDK")
def test_load_balance_with_src_node_not_exist(self): def test_load_balance_with_src_node_not_exist(self):
""" """
target: test load balance of collection target: test load balance of collection
@ -1595,7 +1593,6 @@ class TestUtilityAdvanced(TestcaseBase):
check_items={ct.err_code: 1, ct.err_msg: "is not exist to balance"}) check_items={ct.err_code: 1, ct.err_msg: "is not exist to balance"})
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="need newer SDK")
def test_load_balance_with_all_dst_node_not_exist(self): def test_load_balance_with_all_dst_node_not_exist(self):
""" """
target: test load balance of collection target: test load balance of collection
@ -1631,7 +1628,6 @@ class TestUtilityAdvanced(TestcaseBase):
check_items={ct.err_code: 1, ct.err_msg: "no available queryNode to allocate"}) check_items={ct.err_code: 1, ct.err_msg: "no available queryNode to allocate"})
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="need newer SDK")
def test_load_balance_with_one_sealed_segment_id_not_exist(self): def test_load_balance_with_one_sealed_segment_id_not_exist(self):
""" """
target: test load balance of collection target: test load balance of collection
@ -1672,7 +1668,6 @@ class TestUtilityAdvanced(TestcaseBase):
check_items={ct.err_code: 1, ct.err_msg: "is not exist"}) check_items={ct.err_code: 1, ct.err_msg: "is not exist"})
@pytest.mark.tags(CaseLabel.L3) @pytest.mark.tags(CaseLabel.L3)
@pytest.mark.xfail(reason="need newer SDK")
def test_load_balance_in_one_group(self): def test_load_balance_in_one_group(self):
""" """
target: test load balance of collection in one group target: test load balance of collection in one group
@ -1720,7 +1715,6 @@ class TestUtilityAdvanced(TestcaseBase):
assert set(sealed_segment_ids).issubset(des_sealed_segment_ids) assert set(sealed_segment_ids).issubset(des_sealed_segment_ids)
@pytest.mark.tags(CaseLabel.L3) @pytest.mark.tags(CaseLabel.L3)
@pytest.mark.xfail(reason="need newer SDK")
def test_load_balance_not_in_one_group(self): def test_load_balance_not_in_one_group(self):
""" """
target: test load balance of collection in one group target: test load balance of collection in one group