[test] Update load replica case as segment info changed (#17249)

Signed-off-by: ThreadDao <yufen.zong@zilliz.com>
pull/17261/head
ThreadDao 2022-05-28 17:28:04 +08:00 committed by GitHub
parent 731870211a
commit acb07720fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 35 deletions

View File

@ -1,5 +1,3 @@
from functools import reduce
import numpy import numpy
import pandas as pd import pandas as pd
import pytest import pytest
@ -2227,6 +2225,9 @@ class TestLoadCollection(TestcaseBase):
assert collection_w.num_entities == ct.default_nb assert collection_w.num_entities == ct.default_nb
collection_w.load(replica_number=1) collection_w.load(replica_number=1)
for seg in self.utility_wrap.get_query_segment_info(collection_w.name)[0]:
assert len(seg.nodeIds) == 1
collection_w.query(expr=f"{ct.default_int64_field_name} in [0]") collection_w.query(expr=f"{ct.default_int64_field_name} in [0]")
loading_progress, _ = self.utility_wrap.loading_progress(collection_w.name) loading_progress, _ = self.utility_wrap.loading_progress(collection_w.name)
assert loading_progress == {'loading_progress': '100%', 'num_loaded_partitions': 1, 'not_loaded_partitions': []} assert loading_progress == {'loading_progress': '100%', 'num_loaded_partitions': 1, 'not_loaded_partitions': []}
@ -2248,11 +2249,12 @@ class TestLoadCollection(TestcaseBase):
check_items={'exp_res': [{'int64': 0}]}) check_items={'exp_res': [{'int64': 0}]})
# verify loaded segments included 2 replicas and twice num entities # verify loaded segments included 2 replicas and twice num entities
seg_info, _ = self.utility_wrap.get_query_segment_info(collection_w.name) seg_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0]
seg_ids = list(map(lambda seg: seg.segmentID, seg_info)) num_entities = 0
num_entities = list(map(lambda seg: seg.num_rows, seg_info)) for seg in seg_info:
assert reduce(lambda x, y: x ^ y, seg_ids) == 0 assert len(seg.nodeIds) == 2
assert reduce(lambda x, y: x + y, num_entities) == ct.default_nb * 2 num_entities += seg.num_rows
assert num_entities == ct.default_nb
@pytest.mark.tags(CaseLabel.ClusterOnly) @pytest.mark.tags(CaseLabel.ClusterOnly)
def test_load_replica_multi(self): def test_load_replica_multi(self):
@ -2277,6 +2279,9 @@ class TestLoadCollection(TestcaseBase):
replicas = collection_w.get_replicas()[0] replicas = collection_w.get_replicas()[0]
assert len(replicas.groups) == replica_number assert len(replicas.groups) == replica_number
for seg in self.utility_wrap.get_query_segment_info(collection_w.name)[0]:
assert len(seg.nodeIds) == replica_number
query_res, _ = collection_w.query(expr=f"{ct.default_int64_field_name} in [0, {tmp_nb}]") query_res, _ = collection_w.query(expr=f"{ct.default_int64_field_name} in [0, {tmp_nb}]")
assert len(query_res) == 2 assert len(query_res) == 2
search_res, _ = collection_w.search(vectors, default_search_field, default_search_params, default_limit) search_res, _ = collection_w.search(vectors, default_search_field, default_search_params, default_limit)
@ -2302,6 +2307,8 @@ class TestLoadCollection(TestcaseBase):
assert collection_w.num_entities == ct.default_nb * 2 assert collection_w.num_entities == ct.default_nb * 2
collection_w.load([partition_w.name], replica_number=2) collection_w.load([partition_w.name], replica_number=2)
for seg in self.utility_wrap.get_query_segment_info(collection_w.name)[0]:
assert len(seg.nodeIds) == 2
# default tag query 0 empty # default tag query 0 empty
collection_w.query(expr=f"{ct.default_int64_field_name} in [0]", partition_names=[ct.default_tag], collection_w.query(expr=f"{ct.default_int64_field_name} in [0]", partition_names=[ct.default_tag],
check_tasks=CheckTasks.check_query_empty) check_tasks=CheckTasks.check_query_empty)
@ -2342,20 +2349,18 @@ class TestLoadCollection(TestcaseBase):
# verify there are 2 groups (2 replicas) # verify there are 2 groups (2 replicas)
assert len(replicas.groups) == 2 assert len(replicas.groups) == 2
log.debug(replicas) log.debug(replicas)
all_group_nodes = []
for group in replicas.groups: for group in replicas.groups:
# verify each group have 3 shards # verify each group have 3 shards
assert len(group.shards) == 2 assert len(group.shards) == 2
shard_leaders = [] all_group_nodes.extend(group.group_nodes)
# verify one group has 3 querynodes, and one of the querynode isn't shard leader # verify all groups has 5 querynodes
if len(group.group_nodes) == 3: assert len(all_group_nodes) == 5
for shard in group.shards:
shard_leaders.append(shard.shard_leader)
assert len(shard_leaders) == 2
# Verify 2 replicas segments loaded # Verify 2 replicas segments loaded
seg_info, _ = self.utility_wrap.get_query_segment_info(collection_w.name) seg_info, _ = self.utility_wrap.get_query_segment_info(collection_w.name)
seg_ids = list(map(lambda seg: seg.segmentID, seg_info)) for seg in seg_info:
assert reduce(lambda x, y: x ^ y, seg_ids) == 0 assert len(seg.nodeIds) == 2
# verify search successfully # verify search successfully
res, _ = collection_w.search(vectors, default_search_field, default_search_params, default_limit) res, _ = collection_w.search(vectors, default_search_field, default_search_params, default_limit)
@ -2393,20 +2398,18 @@ class TestLoadCollection(TestcaseBase):
replicas, _ = collection_w.get_replicas() replicas, _ = collection_w.get_replicas()
log.debug(replicas) log.debug(replicas)
assert len(replicas.groups) == 2 assert len(replicas.groups) == 2
all_group_nodes = []
for group in replicas.groups: for group in replicas.groups:
# verify each group have 3 shards # verify each group have 3 shards
assert len(group.shards) == 3 assert len(group.shards) == 3
# verify one group has 2 querynodes, and one of the querynode subscripe 2 dml channel all_group_nodes.extend(group.group_nodes)
shard_leaders = [] # verify all groups has 5 querynodes
if len(group.group_nodes) == 2: assert len(all_group_nodes) == 5
for shard in group.shards:
shard_leaders.append(shard.shard_leader)
assert len(shard_leaders) == 3 and len(set(shard_leaders)) == 2
# Verify 2 replicas segments loaded # Verify 2 replicas segments loaded
seg_info, _ = self.utility_wrap.get_query_segment_info(collection_w.name) seg_info, _ = self.utility_wrap.get_query_segment_info(collection_w.name)
seg_ids = list(map(lambda seg: seg.segmentID, seg_info)) for seg in seg_info:
assert reduce(lambda x, y: x ^ y, seg_ids) == 0 assert len(seg.nodeIds) == 2
# Verify search successfully # Verify search successfully
res, _ = collection_w.search(vectors, default_search_field, default_search_params, default_limit) res, _ = collection_w.search(vectors, default_search_field, default_search_params, default_limit)

View File

@ -1,5 +1,3 @@
from functools import reduce
from os import name
import threading import threading
import pytest import pytest
@ -386,11 +384,12 @@ class TestPartitionParams(TestcaseBase):
assert len(two_replicas.groups) == 2 assert len(two_replicas.groups) == 2
# verify loaded segments included 2 replicas and twice num entities # verify loaded segments included 2 replicas and twice num entities
seg_info, _ = self.utility_wrap.get_query_segment_info(collection_w.name) seg_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0]
seg_ids = list(map(lambda seg: seg.segmentID, seg_info)) num_entities = 0
num_entities = list(map(lambda seg: seg.num_rows, seg_info)) for seg in seg_info:
assert reduce(lambda x, y: x ^ y, seg_ids) == 0 assert len(seg.nodeIds) == 2
assert reduce(lambda x, y: x + y, num_entities) == ct.default_nb * 2 num_entities += seg.num_rows
assert num_entities == ct.default_nb
@pytest.mark.tags(CaseLabel.ClusterOnly) @pytest.mark.tags(CaseLabel.ClusterOnly)
def test_partition_replicas_change_cross_partitions(self): def test_partition_replicas_change_cross_partitions(self):
@ -421,11 +420,12 @@ class TestPartitionParams(TestcaseBase):
assert group1_ids.sort() == group2_ids.sort() assert group1_ids.sort() == group2_ids.sort()
# verify loaded segments included 2 replicas and 1 partition # verify loaded segments included 2 replicas and 1 partition
seg_info, _ = self.utility_wrap.get_query_segment_info(collection_w.name) seg_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0]
seg_ids = list(map(lambda seg: seg.segmentID, seg_info)) num_entities = 0
num_entities = list(map(lambda seg: seg.num_rows, seg_info)) for seg in seg_info:
assert reduce(lambda x, y: x ^ y, seg_ids) == 0 assert len(seg.nodeIds) == 2
assert reduce(lambda x, y: x + y, num_entities) == ct.default_nb * 2 num_entities += seg.num_rows
assert num_entities == ct.default_nb
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
def test_partition_release(self): def test_partition_release(self):