From c8fa7026aa9caa8188113ca7a368d277a05168b3 Mon Sep 17 00:00:00 2001 From: ThreadDao Date: Thu, 7 Oct 2021 22:38:43 +0800 Subject: [PATCH] Remove import star of test file (#9425) Signed-off-by: ThreadDao --- .../testcases/entity/test_insert.py | 159 +++++++++--------- 1 file changed, 83 insertions(+), 76 deletions(-) diff --git a/tests/python_client/testcases/entity/test_insert.py b/tests/python_client/testcases/entity/test_insert.py index f833b04895..9c3e1d68fb 100644 --- a/tests/python_client/testcases/entity/test_insert.py +++ b/tests/python_client/testcases/entity/test_insert.py @@ -1,18 +1,25 @@ +import copy +import logging +import threading + import pytest from pymilvus import DataType, ParamError, BaseException -from utils.utils import * +from utils import utils as ut from common.constants import default_entity, default_entities, default_binary_entity, default_binary_entities, \ default_fields from common.common_type import CaseLabel ADD_TIMEOUT = 60 uid = "test_insert" -field_name = default_float_vec_field_name -binary_field_name = default_binary_vec_field_name +field_name = ut.default_float_vec_field_name +binary_field_name = ut.default_binary_vec_field_name +default_nb = ut.default_nb +row_count = ut.row_count +default_tag = ut.default_tag default_single_query = { "bool": { "must": [ - {"vector": {field_name: {"topk": 10, "query": gen_vectors(1, default_dim), "metric_type": "L2", + {"vector": {field_name: {"topk": 10, "query": ut.gen_vectors(1, ut.default_dim), "metric_type": "L2", "params": {"nprobe": 10}}}} ] } @@ -28,25 +35,25 @@ class TestInsertBase: @pytest.fixture( scope="function", - params=gen_simple_index() + params=ut.gen_simple_index() ) def get_simple_index(self, request, connect): # if str(connect._cmd("mode")) == "CPU": - if request.param["index_type"] in index_cpu_not_support(): + if request.param["index_type"] in ut.index_cpu_not_support(): pytest.skip("CPU not support index_type: ivf_sq8h") logging.getLogger().info(request.param) return request.param @pytest.fixture( scope="function", - params=gen_single_filter_fields() + params=ut.gen_single_filter_fields() ) def get_filter_field(self, request): yield request.param @pytest.fixture( scope="function", - params=gen_single_vector_fields() + params=ut.gen_single_vector_fields() ) def get_vector_field(self, request): yield request.param @@ -81,7 +88,7 @@ class TestInsertBase: method: insert entity into a random named collection expected: raise a BaseException """ - collection_name = gen_unique_str(uid) + collection_name = ut.gen_unique_str(uid) with pytest.raises(BaseException) as e: connect.insert(collection_name, default_entities) @@ -136,7 +143,7 @@ class TestInsertBase: connect.create_index(collection, field_name, get_simple_index) if get_simple_index["index_type"] != "FLAT": index = connect.describe_index(collection, "") - create_target_index(get_simple_index, field_name) + ut.create_target_index(get_simple_index, field_name) assert index == get_simple_index @pytest.mark.timeout(ADD_TIMEOUT) @@ -152,7 +159,7 @@ class TestInsertBase: assert len(result.primary_keys) == default_nb if get_simple_index["index_type"] != "FLAT": index = connect.describe_index(collection, "") - create_target_index(get_simple_index, field_name) + ut.create_target_index(get_simple_index, field_name) assert index == get_simple_index @pytest.mark.timeout(ADD_TIMEOUT) @@ -167,18 +174,18 @@ class TestInsertBase: connect.flush([collection]) connect.load_collection(collection) res = connect.search(collection, default_single_query) - assert len(res[0]) == default_top_k + assert len(res[0]) == ut.default_top_k @pytest.mark.tags(CaseLabel.L2) def _test_insert_segment_row_count(self, connect, collection): - nb = default_segment_row_limit + 1 - result = connect.insert(collection, gen_entities(nb)) + nb = ut.default_segment_row_limit + 1 + result = connect.insert(collection, ut.gen_entities(nb)) connect.flush([collection]) assert len(result.primary_keys) == nb stats = connect.get_collection_stats(collection) assert len(stats['partitions'][0]['segments']) == 2 for segment in stats['partitions'][0]['segments']: - assert segment['row_count'] in [default_segment_row_limit, 1] + assert segment['row_count'] in [ut.default_segment_row_limit, 1] @pytest.fixture( scope="function", @@ -200,7 +207,7 @@ class TestInsertBase: """ nb = insert_count ids = [i for i in range(nb)] - entities = gen_entities(nb) + entities = ut.gen_entities(nb) entities[0]["values"] = ids result = connect.insert(id_collection, entities) connect.flush([id_collection]) @@ -219,7 +226,7 @@ class TestInsertBase: """ nb = insert_count ids = [1 for i in range(nb)] - entities = gen_entities(nb) + entities = ut.gen_entities(nb) entities[0]["values"] = ids result = connect.insert(id_collection, entities) connect.flush([id_collection]) @@ -239,14 +246,14 @@ class TestInsertBase: nb = 5 filter_field = get_filter_field vector_field = get_vector_field - collection_name = gen_unique_str("test_collection") + collection_name = ut.gen_unique_str("test_collection") fields = { - "fields": [gen_primary_field(), filter_field, vector_field], + "fields": [ut.gen_primary_field(), filter_field, vector_field], "auto_id": False } connect.create_collection(collection_name, fields) ids = [i for i in range(nb)] - entities = gen_entities_by_fields(fields["fields"], nb, default_dim, ids) + entities = ut.gen_entities_by_fields(fields["fields"], nb, ut.default_dim, ids) logging.getLogger().info(entities) result = connect.insert(collection_name, entities) assert result.primary_keys == ids @@ -264,7 +271,7 @@ class TestInsertBase: """ nb = insert_count with pytest.raises(Exception) as e: - entities = gen_entities(nb) + entities = ut.gen_entities(nb) del entities[0] connect.insert(id_collection, entities) @@ -355,7 +362,7 @@ class TestInsertBase: """ connect.create_partition(id_collection, default_tag) ids = [i for i in range(default_nb)] - entities = gen_entities(default_nb) + entities = ut.gen_entities(default_nb) entities[0]["values"] = ids result = connect.insert(id_collection, entities, partition_name=default_tag) assert result.primary_keys == ids @@ -369,7 +376,7 @@ class TestInsertBase: method: create partition and insert info collection without tag params expected: the collection row count equals to nb """ - result = connect.insert(collection, default_entities, partition_name=default_partition_name) + result = connect.insert(collection, default_entities, partition_name=ut.default_partition_name) assert len(result.primary_keys) == default_nb connect.flush([collection]) stats = connect.get_collection_stats(collection) @@ -383,7 +390,7 @@ class TestInsertBase: method: create collection and insert entities in it, with the not existed partition_name param expected: error raised """ - tag = gen_unique_str() + tag = ut.gen_unique_str() with pytest.raises(Exception) as e: connect.insert(collection, default_entities, partition_name=tag) @@ -409,7 +416,7 @@ class TestInsertBase: method: the entities dimension is half of the collection dimension, check the status expected: error raised """ - vectors = gen_vectors(default_nb, int(default_dim) // 2) + vectors = ut.gen_vectors(default_nb, int(ut.default_dim) // 2) insert_entities = copy.deepcopy(default_entities) insert_entities[-1]["values"] = vectors with pytest.raises(Exception) as e: @@ -422,7 +429,7 @@ class TestInsertBase: method: update entity field name expected: error raised """ - tmp_entity = update_field_name(copy.deepcopy(default_entity), "int64", "int64new") + tmp_entity = ut.update_field_name(copy.deepcopy(default_entity), "int64", "int64new") with pytest.raises(Exception): connect.insert(collection, tmp_entity) @@ -433,7 +440,7 @@ class TestInsertBase: method: update entity field type expected: error raised """ - tmp_entity = update_field_type(copy.deepcopy(default_entity), "int64", DataType.FLOAT) + tmp_entity = ut.update_field_type(copy.deepcopy(default_entity), "int64", DataType.FLOAT) with pytest.raises(Exception): connect.insert(collection, tmp_entity) @@ -444,7 +451,7 @@ class TestInsertBase: method: update entity field value expected: error raised """ - tmp_entity = update_field_value(copy.deepcopy(default_entity), DataType.FLOAT, 's') + tmp_entity = ut.update_field_value(copy.deepcopy(default_entity), DataType.FLOAT, 's') with pytest.raises(Exception): connect.insert(collection, tmp_entity) @@ -455,7 +462,7 @@ class TestInsertBase: method: add entity field expected: error raised """ - tmp_entity = add_field(copy.deepcopy(default_entity)) + tmp_entity = ut.add_field(copy.deepcopy(default_entity)) with pytest.raises(Exception): connect.insert(collection, tmp_entity) @@ -466,7 +473,7 @@ class TestInsertBase: method: add entity vector field expected: error raised """ - tmp_entity = add_vector_field(default_nb, default_dim) + tmp_entity = ut.add_vector_field(default_nb, ut.default_dim) with pytest.raises(Exception): connect.insert(collection, tmp_entity) @@ -477,7 +484,7 @@ class TestInsertBase: method: remove entity field expected: error raised """ - tmp_entity = remove_field(copy.deepcopy(default_entity)) + tmp_entity = ut.remove_field(copy.deepcopy(default_entity)) with pytest.raises(Exception): connect.insert(collection, tmp_entity) @@ -488,7 +495,7 @@ class TestInsertBase: method: remove entity vector field expected: error raised """ - tmp_entity = remove_vector_field(copy.deepcopy(default_entity)) + tmp_entity = ut.remove_vector_field(copy.deepcopy(default_entity)) with pytest.raises(Exception): connect.insert(collection, tmp_entity) @@ -542,7 +549,7 @@ class TestInsertBase: pytest.skip("Skip test in http mode") thread_num = 8 threads = [] - milvus = get_milvus(host=args["ip"], port=args["port"], handler=args["handler"], try_connect=False) + milvus = ut.get_milvus(host=args["ip"], port=args["port"], handler=args["handler"], try_connect=False) def insert(thread_i): logging.getLogger().info("In thread-%d" % thread_i) @@ -567,7 +574,7 @@ class TestInsertBase: expected: the count is equal to 0 """ delete_nums = 500 - disable_flush(connect) + ut.disable_flush(connect) result = connect.insert(collection, default_entities) ids = result.primary_keys res = connect.get_entity_by_id(collection, ids[:delete_nums]) @@ -578,7 +585,7 @@ class TestInsertBase: class TestInsertBinary: @pytest.fixture( scope="function", - params=gen_binary_index() + params=ut.gen_binary_index() ) def get_binary_index(self, request): request.param["metric_type"] = "JACCARD" @@ -638,7 +645,7 @@ class TestInsertBinary: assert len(result.primary_keys) == default_nb connect.flush([binary_collection]) index = connect.describe_index(binary_collection, "") - create_target_index(get_binary_index, binary_field_name) + ut.create_target_index(get_binary_index, binary_field_name) assert index == get_binary_index @pytest.mark.timeout(ADD_TIMEOUT) @@ -654,7 +661,7 @@ class TestInsertBinary: connect.flush([binary_collection]) connect.create_index(binary_collection, binary_field_name, get_binary_index) index = connect.describe_index(binary_collection, "") - create_target_index(get_binary_index, binary_field_name) + ut.create_target_index(get_binary_index, binary_field_name) assert index == get_binary_index @pytest.mark.tags(CaseLabel.L0) @@ -666,12 +673,12 @@ class TestInsertBinary: """ result = connect.insert(binary_collection, default_binary_entities) connect.flush([binary_collection]) - query, vecs = gen_query_vectors(binary_field_name, default_binary_entities, default_top_k, 1, + query, vecs = ut.gen_query_vectors(binary_field_name, default_binary_entities, ut.default_top_k, 1, metric_type="JACCARD") connect.load_collection(binary_collection) res = connect.search(binary_collection, query) logging.getLogger().debug(res) - assert len(res[0]) == default_top_k + assert len(res[0]) == ut.default_top_k class TestInsertAsync: @@ -706,7 +713,7 @@ class TestInsertAsync: expected: length of ids is equal to the length of vectors """ nb = insert_count - future = connect.insert(collection, gen_entities(nb), _async=True) + future = connect.insert(collection, ut.gen_entities(nb), _async=True) ids = future.result().primary_keys connect.flush([collection]) assert len(ids) == nb @@ -719,7 +726,7 @@ class TestInsertAsync: expected: length of ids is equal to the length of vectors """ nb = insert_count - result = connect.insert(collection, gen_entities(nb), _async=False) + result = connect.insert(collection, ut.gen_entities(nb), _async=False) # ids = future.result() connect.flush([collection]) assert len(result.primary_keys) == nb @@ -732,7 +739,7 @@ class TestInsertAsync: expected: length of ids is equal to the length of vectors """ nb = insert_count - future = connect.insert(collection, gen_entities(nb), _async=True, _callback=self.check_result) + future = connect.insert(collection, ut.gen_entities(nb), _async=True, _callback=self.check_result) future.done() ids = future.result().primary_keys assert len(ids) == nb @@ -745,7 +752,7 @@ class TestInsertAsync: expected: length of ids is equal to the length of vectors """ nb = 50000 - future = connect.insert(collection, gen_entities(nb), _async=True, _callback=self.check_result) + future = connect.insert(collection, ut.gen_entities(nb), _async=True, _callback=self.check_result) result = future.result() assert len(result.primary_keys) == nb connect.flush([collection]) @@ -761,7 +768,7 @@ class TestInsertAsync: expected: length of ids is equal to the length of vectors """ nb = 100000 - future = connect.insert(collection, gen_entities(nb), _async=True, _callback=self.check_status, timeout=1) + future = connect.insert(collection, ut.gen_entities(nb), _async=True, _callback=self.check_status, timeout=1) with pytest.raises(Exception) as e: result = future.result() @@ -772,7 +779,7 @@ class TestInsertAsync: method: set different vectors as insert method params expected: length of ids is equal to the length of vectors """ - collection_new = gen_unique_str() + collection_new = ut.gen_unique_str() future = connect.insert(collection_new, default_entities, _async=True) future.done() with pytest.raises(Exception) as e: @@ -802,7 +809,7 @@ class TestInsertMultiCollections: @pytest.fixture( scope="function", - params=gen_simple_index() + params=ut.gen_simple_index() ) def get_simple_index(self, request, connect): logging.getLogger().info(request.param) @@ -821,7 +828,7 @@ class TestInsertMultiCollections: collection_num = 10 collection_list = [] for i in range(collection_num): - collection_name = gen_unique_str(uid) + collection_name = ut.gen_unique_str(uid) collection_list.append(collection_name) connect.create_collection(collection_name, default_fields) result = connect.insert(collection_name, default_entities) @@ -840,7 +847,7 @@ class TestInsertMultiCollections: method: delete collection_2 and insert vector to collection_1 expected: row count equals the length of entities inserted """ - collection_name = gen_unique_str(uid) + collection_name = ut.gen_unique_str(uid) connect.create_collection(collection_name, default_fields) connect.drop_collection(collection) result = connect.insert(collection_name, default_entity) @@ -855,14 +862,14 @@ class TestInsertMultiCollections: method: build index and insert vector expected: status ok """ - collection_name = gen_unique_str(uid) + collection_name = ut.gen_unique_str(uid) connect.create_collection(collection_name, default_fields) connect.create_index(collection, field_name, get_simple_index) result = connect.insert(collection_name, default_entity) assert len(result.primary_keys) == 1 if get_simple_index["index_type"] != "FLAT": index = connect.describe_index(collection, "") - create_target_index(get_simple_index, field_name) + ut.create_target_index(get_simple_index, field_name) assert index == get_simple_index connect.drop_collection(collection_name) @@ -874,14 +881,14 @@ class TestInsertMultiCollections: method: build index and insert vector expected: status ok """ - collection_name = gen_unique_str(uid) + collection_name = ut.gen_unique_str(uid) connect.create_collection(collection_name, default_fields) result = connect.insert(collection, default_entity) connect.flush([collection]) connect.create_index(collection_name, field_name, get_simple_index) if get_simple_index["index_type"] != "FLAT": index = connect.describe_index(collection_name, "") - create_target_index(get_simple_index, field_name) + ut.create_target_index(get_simple_index, field_name) assert index == get_simple_index stats = connect.get_collection_stats(collection) assert stats[row_count] == 1 @@ -894,7 +901,7 @@ class TestInsertMultiCollections: method: build index and insert vector expected: status ok """ - collection_name = gen_unique_str(uid) + collection_name = ut.gen_unique_str(uid) connect.create_collection(collection_name, default_fields) result = connect.insert(collection, default_entity) connect.flush([collection]) @@ -910,7 +917,7 @@ class TestInsertMultiCollections: method: search collection and insert entity expected: status ok """ - collection_name = gen_unique_str(uid) + collection_name = ut.gen_unique_str(uid) connect.create_collection(collection_name, default_fields) connect.load_collection(collection) res = connect.search(collection, default_single_query) @@ -928,7 +935,7 @@ class TestInsertMultiCollections: method: search collection and insert entity expected: status ok """ - collection_name = gen_unique_str(uid) + collection_name = ut.gen_unique_str(uid) connect.create_collection(collection_name, default_fields) result = connect.insert(collection, default_entity) connect.flush([collection]) @@ -945,7 +952,7 @@ class TestInsertMultiCollections: method: search collection, sleep, and insert entity expected: status ok """ - collection_name = gen_unique_str(uid) + collection_name = ut.gen_unique_str(uid) connect.create_collection(collection_name, default_fields) result = connect.insert(collection, default_entity) connect.flush([collection]) @@ -982,49 +989,49 @@ class TestInsertInvalid(object): @pytest.fixture( scope="function", - params=gen_invalid_strs() + params=ut.gen_invalid_strs() ) def get_collection_name(self, request): yield request.param @pytest.fixture( scope="function", - params=gen_invalid_strs() + params=ut.gen_invalid_strs() ) def get_tag_name(self, request): yield request.param @pytest.fixture( scope="function", - params=gen_invalid_strs() + params=ut.gen_invalid_strs() ) def get_field_name(self, request): yield request.param @pytest.fixture( scope="function", - params=gen_invalid_strs() + params=ut.gen_invalid_strs() ) def get_field_type(self, request): yield request.param @pytest.fixture( scope="function", - params=gen_invalid_strs() + params=ut.gen_invalid_strs() ) def get_field_int_value(self, request): yield request.param @pytest.fixture( scope="function", - params=gen_invalid_ints() + params=ut.gen_invalid_ints() ) def get_entity_id(self, request): yield request.param @pytest.fixture( scope="function", - params=gen_invalid_vectors() + params=ut.gen_invalid_vectors() ) def get_field_vectors_value(self, request): yield request.param @@ -1069,21 +1076,21 @@ class TestInsertInvalid(object): @pytest.mark.tags(CaseLabel.L2) def test_insert_with_invalid_field_name(self, connect, collection, get_field_name): - tmp_entity = update_field_name(copy.deepcopy(default_entity), "int64", get_field_name) + tmp_entity = ut.update_field_name(copy.deepcopy(default_entity), "int64", get_field_name) with pytest.raises(Exception): connect.insert(collection, tmp_entity) @pytest.mark.tags(CaseLabel.L2) def test_insert_with_invalid_field_type(self, connect, collection, get_field_type): field_type = get_field_type - tmp_entity = update_field_type(copy.deepcopy(default_entity), 'float', field_type) + tmp_entity = ut.update_field_type(copy.deepcopy(default_entity), 'float', field_type) with pytest.raises(Exception): connect.insert(collection, tmp_entity) @pytest.mark.tags(CaseLabel.L2) def test_insert_with_invalid_field_value(self, connect, collection, get_field_int_value): field_value = get_field_int_value - tmp_entity = update_field_type(copy.deepcopy(default_entity), 'int64', field_value) + tmp_entity = ut.update_field_type(copy.deepcopy(default_entity), 'int64', field_value) with pytest.raises(Exception): connect.insert(collection, tmp_entity) @@ -1108,49 +1115,49 @@ class TestInsertInvalidBinary(object): @pytest.fixture( scope="function", - params=gen_invalid_strs() + params=ut.gen_invalid_strs() ) def get_collection_name(self, request): yield request.param @pytest.fixture( scope="function", - params=gen_invalid_strs() + params=ut.gen_invalid_strs() ) def get_tag_name(self, request): yield request.param @pytest.fixture( scope="function", - params=gen_invalid_strs() + params=ut.gen_invalid_strs() ) def get_field_name(self, request): yield request.param @pytest.fixture( scope="function", - params=gen_invalid_strs() + params=ut.gen_invalid_strs() ) def get_field_type(self, request): yield request.param @pytest.fixture( scope="function", - params=gen_invalid_strs() + params=ut.gen_invalid_strs() ) def get_field_int_value(self, request): yield request.param @pytest.fixture( scope="function", - params=gen_invalid_ints() + params=ut.gen_invalid_ints() ) def get_entity_id(self, request): yield request.param @pytest.fixture( scope="function", - params=gen_invalid_vectors() + params=ut.gen_invalid_vectors() ) def get_field_vectors_value(self, request): yield request.param @@ -1162,13 +1169,13 @@ class TestInsertInvalidBinary(object): method: insert with invalid field name expected: raise exception """ - tmp_entity = update_field_name(copy.deepcopy(default_binary_entity), "int64", get_field_name) + tmp_entity = ut.update_field_name(copy.deepcopy(default_binary_entity), "int64", get_field_name) with pytest.raises(Exception): connect.insert(binary_collection, tmp_entity) @pytest.mark.tags(CaseLabel.L2) def test_insert_with_invalid_field_value(self, connect, binary_collection, get_field_int_value): - tmp_entity = update_field_type(copy.deepcopy(default_binary_entity), 'int64', get_field_int_value) + tmp_entity = ut.update_field_type(copy.deepcopy(default_binary_entity), 'int64', get_field_int_value) with pytest.raises(Exception): connect.insert(binary_collection, tmp_entity) @@ -1205,7 +1212,7 @@ class TestInsertInvalidBinary(object): expected: raise exception """ field_type = get_field_type - tmp_entity = update_field_type(copy.deepcopy(default_binary_entity), 'int64', field_type) + tmp_entity = ut.update_field_type(copy.deepcopy(default_binary_entity), 'int64', field_type) with pytest.raises(Exception): connect.insert(binary_collection, tmp_entity)