[skip ci]Update get_segment_distribution function (#12814)

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
pull/12831/head
zhuwenxing 2021-12-06 21:04:14 +08:00 committed by GitHub
parent 67cab4e915
commit 05e00b5113
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 22 deletions

View File

@ -528,3 +528,23 @@ def _check_primary_keys(primary_keys, nb):
if primary_keys[i] >= primary_keys[i + 1]:
return False
return True
def get_segment_distribution(res):
"""
Get segment distribution
"""
from collections import defaultdict
segment_distribution = defaultdict(lambda: {"growing": [], "sealed": []})
for r in res:
if r.nodeID not in segment_distribution:
segment_distribution[r.nodeID] = {
"growing": [],
"sealed": []
}
if r.state == 3:
segment_distribution[r.nodeID]["sealed"].append(r.segmentID)
if r.state == 2:
segment_distribution[r.nodeID]["growing"].append(r.segmentID)
return segment_distribution

View File

@ -20,26 +20,6 @@ num_loaded_entities = "num_loaded_entities"
num_total_entities = "num_total_entities"
def get_segment_distribution(res):
"""
Get segment distribution
"""
from collections import defaultdict
segment_distribution = defaultdict(lambda: {"growing": [], "sealed": []})
for r in res:
if r.nodeID not in segment_distribution:
segment_distribution[r.nodeID] = {
"growing": [],
"sealed": []
}
if r.state == 3:
segment_distribution[r.nodeID]["sealed"].append(r.segmentID)
if r.state == 2:
segment_distribution[r.nodeID]["growing"].append(r.segmentID)
return segment_distribution
class TestUtilityParams(TestcaseBase):
""" Test case of index interface """
@ -1487,7 +1467,7 @@ class TestUtilityAdvanced(TestcaseBase):
collection_w.load()
# prepare load balance params
res, _ = self.utility_wrap.get_query_segment_info(c_name)
segment_distribution = get_segment_distribution(res)
segment_distribution = cf.get_segment_distribution(res)
all_querynodes = [node["identifier"] for node in ms.query_nodes]
assert len(all_querynodes) > 1
all_querynodes = sorted(all_querynodes,
@ -1500,7 +1480,7 @@ class TestUtilityAdvanced(TestcaseBase):
self.utility_wrap.load_balance(src_node_id, des_node_ids, sealed_segment_ids)
# get segments distribution after load balance
res, _ = self.utility_wrap.get_query_segment_info(c_name)
segment_distribution = get_segment_distribution(res)
segment_distribution = cf.get_segment_distribution(res)
des_sealed_segment_ids = []
for des_node_id in des_node_ids:
des_sealed_segment_ids += segment_distribution[des_node_id]["sealed"]