diff --git a/tests/python_client/testcases/test_collection.py b/tests/python_client/testcases/test_collection.py index 2c27fb98b6..05910f30de 100644 --- a/tests/python_client/testcases/test_collection.py +++ b/tests/python_client/testcases/test_collection.py @@ -1,5 +1,3 @@ -from functools import reduce - import numpy import pandas as pd import pytest @@ -2227,6 +2225,9 @@ class TestLoadCollection(TestcaseBase): assert collection_w.num_entities == ct.default_nb collection_w.load(replica_number=1) + for seg in self.utility_wrap.get_query_segment_info(collection_w.name)[0]: + assert len(seg.nodeIds) == 1 + collection_w.query(expr=f"{ct.default_int64_field_name} in [0]") loading_progress, _ = self.utility_wrap.loading_progress(collection_w.name) assert loading_progress == {'loading_progress': '100%', 'num_loaded_partitions': 1, 'not_loaded_partitions': []} @@ -2248,11 +2249,12 @@ class TestLoadCollection(TestcaseBase): check_items={'exp_res': [{'int64': 0}]}) # verify loaded segments included 2 replicas and twice num entities - seg_info, _ = self.utility_wrap.get_query_segment_info(collection_w.name) - seg_ids = list(map(lambda seg: seg.segmentID, seg_info)) - num_entities = list(map(lambda seg: seg.num_rows, seg_info)) - assert reduce(lambda x, y: x ^ y, seg_ids) == 0 - assert reduce(lambda x, y: x + y, num_entities) == ct.default_nb * 2 + seg_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0] + num_entities = 0 + for seg in seg_info: + assert len(seg.nodeIds) == 2 + num_entities += seg.num_rows + assert num_entities == ct.default_nb @pytest.mark.tags(CaseLabel.ClusterOnly) def test_load_replica_multi(self): @@ -2277,6 +2279,9 @@ class TestLoadCollection(TestcaseBase): replicas = collection_w.get_replicas()[0] assert len(replicas.groups) == replica_number + for seg in self.utility_wrap.get_query_segment_info(collection_w.name)[0]: + assert len(seg.nodeIds) == replica_number + query_res, _ = collection_w.query(expr=f"{ct.default_int64_field_name} in [0, {tmp_nb}]") assert len(query_res) == 2 search_res, _ = collection_w.search(vectors, default_search_field, default_search_params, default_limit) @@ -2302,6 +2307,8 @@ class TestLoadCollection(TestcaseBase): assert collection_w.num_entities == ct.default_nb * 2 collection_w.load([partition_w.name], replica_number=2) + for seg in self.utility_wrap.get_query_segment_info(collection_w.name)[0]: + assert len(seg.nodeIds) == 2 # default tag query 0 empty collection_w.query(expr=f"{ct.default_int64_field_name} in [0]", partition_names=[ct.default_tag], check_tasks=CheckTasks.check_query_empty) @@ -2342,20 +2349,18 @@ class TestLoadCollection(TestcaseBase): # verify there are 2 groups (2 replicas) assert len(replicas.groups) == 2 log.debug(replicas) + all_group_nodes = [] for group in replicas.groups: # verify each group have 3 shards assert len(group.shards) == 2 - shard_leaders = [] - # verify one group has 3 querynodes, and one of the querynode isn't shard leader - if len(group.group_nodes) == 3: - for shard in group.shards: - shard_leaders.append(shard.shard_leader) - assert len(shard_leaders) == 2 + all_group_nodes.extend(group.group_nodes) + # verify all groups has 5 querynodes + assert len(all_group_nodes) == 5 # Verify 2 replicas segments loaded seg_info, _ = self.utility_wrap.get_query_segment_info(collection_w.name) - seg_ids = list(map(lambda seg: seg.segmentID, seg_info)) - assert reduce(lambda x, y: x ^ y, seg_ids) == 0 + for seg in seg_info: + assert len(seg.nodeIds) == 2 # verify search successfully res, _ = collection_w.search(vectors, default_search_field, default_search_params, default_limit) @@ -2393,20 +2398,18 @@ class TestLoadCollection(TestcaseBase): replicas, _ = collection_w.get_replicas() log.debug(replicas) assert len(replicas.groups) == 2 + all_group_nodes = [] for group in replicas.groups: # verify each group have 3 shards assert len(group.shards) == 3 - # verify one group has 2 querynodes, and one of the querynode subscripe 2 dml channel - shard_leaders = [] - if len(group.group_nodes) == 2: - for shard in group.shards: - shard_leaders.append(shard.shard_leader) - assert len(shard_leaders) == 3 and len(set(shard_leaders)) == 2 + all_group_nodes.extend(group.group_nodes) + # verify all groups has 5 querynodes + assert len(all_group_nodes) == 5 # Verify 2 replicas segments loaded seg_info, _ = self.utility_wrap.get_query_segment_info(collection_w.name) - seg_ids = list(map(lambda seg: seg.segmentID, seg_info)) - assert reduce(lambda x, y: x ^ y, seg_ids) == 0 + for seg in seg_info: + assert len(seg.nodeIds) == 2 # Verify search successfully res, _ = collection_w.search(vectors, default_search_field, default_search_params, default_limit) diff --git a/tests/python_client/testcases/test_partition.py b/tests/python_client/testcases/test_partition.py index 4e8f6f5c5f..db527451d5 100644 --- a/tests/python_client/testcases/test_partition.py +++ b/tests/python_client/testcases/test_partition.py @@ -1,5 +1,3 @@ -from functools import reduce -from os import name import threading import pytest @@ -386,11 +384,12 @@ class TestPartitionParams(TestcaseBase): assert len(two_replicas.groups) == 2 # verify loaded segments included 2 replicas and twice num entities - seg_info, _ = self.utility_wrap.get_query_segment_info(collection_w.name) - seg_ids = list(map(lambda seg: seg.segmentID, seg_info)) - num_entities = list(map(lambda seg: seg.num_rows, seg_info)) - assert reduce(lambda x, y: x ^ y, seg_ids) == 0 - assert reduce(lambda x, y: x + y, num_entities) == ct.default_nb * 2 + seg_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0] + num_entities = 0 + for seg in seg_info: + assert len(seg.nodeIds) == 2 + num_entities += seg.num_rows + assert num_entities == ct.default_nb @pytest.mark.tags(CaseLabel.ClusterOnly) def test_partition_replicas_change_cross_partitions(self): @@ -421,11 +420,12 @@ class TestPartitionParams(TestcaseBase): assert group1_ids.sort() == group2_ids.sort() # verify loaded segments included 2 replicas and 1 partition - seg_info, _ = self.utility_wrap.get_query_segment_info(collection_w.name) - seg_ids = list(map(lambda seg: seg.segmentID, seg_info)) - num_entities = list(map(lambda seg: seg.num_rows, seg_info)) - assert reduce(lambda x, y: x ^ y, seg_ids) == 0 - assert reduce(lambda x, y: x + y, num_entities) == ct.default_nb * 2 + seg_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0] + num_entities = 0 + for seg in seg_info: + assert len(seg.nodeIds) == 2 + num_entities += seg.num_rows + assert num_entities == ct.default_nb @pytest.mark.tags(CaseLabel.L1) def test_partition_release(self):