Add loadbalance testcases (#12180)

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
pull/12228/head
zhuwenxing 2021-11-23 11:57:15 +08:00 committed by GitHub
parent 3be9442c0f
commit 17eaffb790
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 62 additions and 0 deletions

View File

@ -207,3 +207,4 @@ class CaseLabel:
L1 = "L1"
L2 = "L2"
L3 = "L3"
Loadbalance = "Loadbalance" # loadbalance testcases which need to be run in multi querynodes

View File

@ -7,6 +7,7 @@ from utils.util_log import test_log as log
from common import common_func as cf
from common import common_type as ct
from common.common_type import CaseLabel, CheckTasks
from common.milvus_sys import MilvusSys
prefix = "utility"
default_schema = cf.gen_default_collection_schema()
@ -18,6 +19,25 @@ default_nb = ct.default_nb
num_loaded_entities = "num_loaded_entities"
num_total_entities = "num_total_entities"
def get_segment_distribution(res):
"""
Get segment distribution
"""
from collections import defaultdict
segment_distribution = defaultdict(lambda: {"growing": [], "sealed": []})
for r in res:
if r.nodeID not in segment_distribution:
segment_distribution[r.nodeID] = {
"growing": [],
"sealed": []
}
if r.state == 3:
segment_distribution[r.nodeID]["sealed"].append(r.segmentID)
if r.state == 2:
segment_distribution[r.nodeID]["growing"].append(r.segmentID)
return segment_distribution
class TestUtilityParams(TestcaseBase):
""" Test case of index interface """
@ -1380,3 +1400,44 @@ class TestUtilityAdvanced(TestcaseBase):
for r in res:
cnt += r.num_rows
assert cnt == nb
@pytest.mark.tags(CaseLabel.Loadbalance)
def test_load_balance_normal(self):
"""
target: test load balance of collection
method: init a collection and load balance
expected: sealed_segment_ids is subset of des_sealed_segment_ids
"""
# init a collection
c_name = cf.gen_unique_str(prefix)
collection_w = self.init_collection_wrap(name=c_name)
ms = MilvusSys()
nb = 3000
df = cf.gen_default_dataframe_data(nb)
collection_w.insert(df)
# get sealed segments
collection_w.num_entities
# get growing segments
collection_w.insert(df)
collection_w.load()
# prepare load balance params
res, _ = self.utility_wrap.get_query_segment_info(c_name)
segment_distribution = get_segment_distribution(res)
all_querynodes = [node["identifier"]for node in ms.query_nodes]
assert len(all_querynodes) > 1
all_querynodes = sorted(all_querynodes,
key=lambda x: len(segment_distribution[x]["sealed"]) \
if x in segment_distribution else 0, reverse=True)
src_node_id = all_querynodes[0]
des_node_ids = all_querynodes[1:]
sealed_segment_ids = segment_distribution[src_node_id]["sealed"]
# load balance
self.utility_wrap.load_balance(src_node_id,des_node_ids,sealed_segment_ids)
# get segments distribution after load balance
res, _ = self.utility_wrap.get_query_segment_info(c_name)
segment_distribution = get_segment_distribution(res)
des_sealed_segment_ids = []
for des_node_id in des_node_ids:
des_sealed_segment_ids += segment_distribution[des_node_id]["sealed"]
# assert sealed_segment_ids is subset of des_sealed_segment_ids
assert set(sealed_segment_ids).issubset(des_sealed_segment_ids)