From 6c520a725df2188619c6fd77be2021ca7f89cdd4 Mon Sep 17 00:00:00 2001
From: "cai.zhang"
Date: Wed, 2 Dec 2020 14:39:29 +0800
Subject: [PATCH] Update regression tests

Signed-off-by: cai.zhang
---
 tests/python/requirements.txt          |    2 +-
 tests/python/test_bulk_insert.py       |    4 +-
 tests/python/test_create_collection.py |    1 -
 tests/python/test_has_collection.py    |    2 +-
 tests/python/test_partition.py         |    1 +
 tests/python/test_search.py            | 1831 ++++++++++++++++++++++++
 tests/python/utils.py                  |   14 +-
 7 files changed, 1843 insertions(+), 12 deletions(-)
 create mode 100644 tests/python/test_search.py

diff --git a/tests/python/requirements.txt b/tests/python/requirements.txt
index 9c72eafacf..15a3d0d6c1 100644
--- a/tests/python/requirements.txt
+++ b/tests/python/requirements.txt
@@ -4,5 +4,5 @@ numpy==1.18.1
 pytest==5.3.4
 pytest-cov==2.8.1
 pytest-timeout==1.3.4
-pymilvus-distributed==0.0.1
+pymilvus-distributed==0.0.2
 sklearn==0.0
diff --git a/tests/python/test_bulk_insert.py b/tests/python/test_bulk_insert.py
index d0a7ce8ad1..1f0e9d8675 100644
--- a/tests/python/test_bulk_insert.py
+++ b/tests/python/test_bulk_insert.py
@@ -563,7 +563,7 @@ class TestInsertBase:
         milvus.flush([collection])
 
         for i in range(thread_num):
-            t = TestThread(target=insert, args=(i,))
+            t = MilvusTestThread(target=insert, args=(i,))
             threads.append(t)
             t.start()
         for t in threads:
@@ -817,7 +817,6 @@ class TestInsertAsync:
         future.result()
 
 
-@pytest.mark.skip
 class TestInsertMultiCollections:
     """
     ******************************************************************
@@ -959,7 +958,6 @@ class TestInsertMultiCollections:
         result = connect.search(collection_name, default_single_query)
 
 
-@pytest.mark.skip
 class TestInsertInvalid(object):
     """
     Test inserting vectors with invalid collection names
diff --git a/tests/python/test_create_collection.py b/tests/python/test_create_collection.py
index 00dd6665b1..10da49dd62 100644
--- a/tests/python/test_create_collection.py
+++ b/tests/python/test_create_collection.py
@@ -4,7 +4,6 @@ from .constants import *
 
 uid = "create_collection"
 
-
 class TestCreateCollection:
     """
     ******************************************************************
diff --git a/tests/python/test_has_collection.py b/tests/python/test_has_collection.py
index d846203b8a..0b5c740637 100644
--- a/tests/python/test_has_collection.py
+++ b/tests/python/test_has_collection.py
@@ -55,7 +55,7 @@ class TestHasCollection:
             assert connect.has_collection(collection_name)
             # assert not assert_collection(connect, collection_name)
         for i in range(threads_num):
-            t = TestThread(target=has, args=())
+            t = MilvusTestThread(target=has, args=())
             threads.append(t)
             t.start()
             time.sleep(0.2)
diff --git a/tests/python/test_partition.py b/tests/python/test_partition.py
index ee1586ac2e..5b21be5c47 100644
--- a/tests/python/test_partition.py
+++ b/tests/python/test_partition.py
@@ -20,6 +20,7 @@ class TestCreateBase:
 
     @pytest.mark.level(2)
     @pytest.mark.timeout(600)
+    @pytest.mark.skip
     def test_create_partition_limit(self, connect, collection, args):
         '''
         target: test create partitions, check status returned
diff --git a/tests/python/test_search.py b/tests/python/test_search.py
new file mode 100644
index 0000000000..e883d930ae
--- /dev/null
+++ b/tests/python/test_search.py
@@ -0,0 +1,1831 @@
+import time
+import pdb
+import copy
+import logging
+from multiprocessing import Pool, Process
+import pytest
+import numpy as np
+
+from milvus import DataType
+from .utils import *
+from .constants import *
+
+uid = "test_search"
+nq = 1
+epsilon = 0.001
+field_name = default_float_vec_field_name
+binary_field_name = default_binary_vec_field_name
+search_param = {"nprobe": 1}
+
+entity = gen_entities(1, is_normal=True)
+entities = gen_entities(default_nb, is_normal=True)
+raw_vectors, binary_entities = gen_binary_entities(default_nb)
+default_query, default_query_vecs = gen_query_vectors(field_name, entities, default_top_k, nq)
+default_binary_query, default_binary_query_vecs = gen_query_vectors(binary_field_name, binary_entities, default_top_k,
+                                                                    nq)
+
+
+def init_data(connect, collection, nb=1200, partition_tags=None, auto_id=True):
+    '''
+    Generate entities and insert them into the collection
+    '''
+    global entities
+    if nb == 1200:
+        insert_entities = entities
+    else:
+        insert_entities = gen_entities(nb, is_normal=True)
+    if partition_tags is None:
+        if auto_id:
+            ids = connect.bulk_insert(collection, insert_entities)
+        else:
+            ids = connect.bulk_insert(collection, insert_entities, ids=[i for i in range(nb)])
+    else:
+        if auto_id:
+            ids = connect.bulk_insert(collection, insert_entities, partition_tag=partition_tags)
+        else:
+            ids = connect.bulk_insert(collection, insert_entities, ids=[i for i in range(nb)], partition_tag=partition_tags)
+    # connect.flush([collection])
+    return insert_entities, ids
+
+
+def init_binary_data(connect, collection, nb=1200, insert=True, partition_tags=None):
+    '''
+    Generate binary entities and insert them into the collection
+    '''
+    ids = []
+    global binary_entities
+    global raw_vectors
+    if nb == 1200:
+        insert_entities = binary_entities
+        insert_raw_vectors = raw_vectors
+    else:
+        insert_raw_vectors, insert_entities = gen_binary_entities(nb)
+    if insert is True:
+        if partition_tags is None:
+            ids = connect.bulk_insert(collection, insert_entities)
+        else:
+            ids = connect.bulk_insert(collection, insert_entities, partition_tag=partition_tags)
+        connect.flush([collection])
+    return insert_raw_vectors, insert_entities, ids
+
+
+class TestSearchBase:
+    """
+    generate valid create_index params
+    """
+
+    @pytest.fixture(
+        scope="function",
+        params=gen_index()
+    )
+    def get_index(self, request, connect):
+        if str(connect._cmd("mode")) == "CPU":
+            if request.param["index_type"] in index_cpu_not_support():
+                pytest.skip("sq8h not supported in CPU mode")
+        return request.param
+
+    @pytest.fixture(
+        scope="function",
+        params=gen_simple_index()
+    )
+    def get_simple_index(self, request, connect):
+        if str(connect._cmd("mode")) == "CPU":
+            if request.param["index_type"] in index_cpu_not_support():
+                pytest.skip("sq8h not supported in CPU mode")
+        return request.param
+
+    @pytest.fixture(
+        scope="function",
+        params=gen_binary_index()
+    )
+    def get_jaccard_index(self, request, connect):
+        logging.getLogger().info(request.param)
+        if request.param["index_type"] in binary_support():
+            return request.param
+        else:
+            pytest.skip("Skip index temporarily")
+
+    @pytest.fixture(
+        scope="function",
+        params=gen_binary_index()
+    )
+    def get_hamming_index(self, request, connect):
+        logging.getLogger().info(request.param)
+        if request.param["index_type"] in binary_support():
+            return request.param
+        else:
+            pytest.skip("Skip index temporarily")
+
+    @pytest.fixture(
+        scope="function",
+        params=gen_binary_index()
+    )
+    def get_structure_index(self, request, connect):
+        logging.getLogger().info(request.param)
+        if request.param["index_type"] == "FLAT":
+            return request.param
+        else:
+            pytest.skip("Skip index temporarily")
+
+    """
+    generate top-k params
+    """
+
+    @pytest.fixture(
+        scope="function",
+        params=[1, 10]
+    )
+    def get_top_k(self, request):
+        yield request.param
+
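+    # --- illustrative reference helpers (not used by the tests) -------------
+    # The distance checks in this class compare Milvus results against
+    # plain-numpy helpers imported from .utils (l2, ip, jaccard, hamming, ...).
+    # The two sketches below show the *assumed* semantics of those helpers;
+    # the actual implementations live in utils.py.
+    @staticmethod
+    def _ref_l2(x, y):
+        # Euclidean distance; the asserts below compare it against
+        # np.sqrt(res[0]._distances[0]), i.e. Milvus is assumed to report
+        # squared L2.
+        return np.linalg.norm(np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64))
+
+    @staticmethod
+    def _ref_jaccard(x, y):
+        # Jaccard distance over 0/1 vectors: 1 - |x AND y| / |x OR y|
+        x = np.asarray(x, dtype=np.bool_)
+        y = np.asarray(y, dtype=np.bool_)
+        return 1.0 - np.double(np.logical_and(x, y).sum()) / np.double(np.logical_or(x, y).sum())
+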
+    @pytest.fixture(
+        scope="function",
+        params=[1, 10, 1100]
+    )
+    def get_nq(self, request):
+        yield request.param
+
+    # PASS
+    @pytest.mark.skip("should pass")
+    def test_search_flat(self, connect, collection, get_top_k, get_nq):
+        '''
+        target: test basic search function, all the search params are correct, change top-k value
+        method: search with the given vectors, check the result
+        expected: the length of the result is top_k
+        '''
+        top_k = get_top_k
+        nq = get_nq
+        entities, ids = init_data(connect, collection)
+        query, vecs = gen_query_vectors(field_name, entities, top_k, nq)
+        if top_k <= max_top_k:
+            res = connect.search(collection, query)
+            assert len(res[0]) == top_k
+            assert res[0]._distances[0] <= epsilon
+            assert check_id_result(res[0], ids[0])
+        else:
+            with pytest.raises(Exception) as e:
+                res = connect.search(collection, query)
+
+    # milvus-distributed does not have the top_k limitation
+    def test_search_flat_top_k(self, connect, collection, get_nq):
+        '''
+        target: test basic search function, all the search params are correct, change top-k value
+        method: search with the given vectors, check the result
+        expected: the length of the result is top_k
+        '''
+        top_k = 16385
+        nq = get_nq
+        entities, ids = init_data(connect, collection)
+        query, vecs = gen_query_vectors(field_name, entities, top_k, nq)
+        if top_k <= max_top_k:
+            res = connect.search(collection, query)
+            assert len(res[0]) == top_k
+            assert res[0]._distances[0] <= epsilon
+            assert check_id_result(res[0], ids[0])
+        else:
+            with pytest.raises(Exception) as e:
+                res = connect.search(collection, query)
+
+    # TODO: reopen after we support targetEntry
+    @pytest.mark.skip("search_field")
+    def test_search_field(self, connect, collection, get_top_k, get_nq):
+        '''
+        target: test basic search function, all the search params are correct, change top-k value
+        method: search with the given vectors, check the result
+        expected: the length of the result is top_k
+        '''
+        top_k = get_top_k
+        nq = get_nq
+        entities, ids = init_data(connect, collection)
+        query, vecs = gen_query_vectors(field_name, entities, top_k, nq)
+        if top_k <= max_top_k:
+            res = connect.search(collection, query, fields=["float_vector"])
+            assert len(res[0]) == top_k
+            assert res[0]._distances[0] <= epsilon
+            assert check_id_result(res[0], ids[0])
+            res = connect.search(collection, query, fields=["float"])
+            for i in range(nq):
+                assert entities[1]["values"][:nq][i] in [r.entity.get('float') for r in res[i]]
+        else:
+            with pytest.raises(Exception):
+                connect.search(collection, query)
+
+    @pytest.mark.skip("search_after_delete")
+    def test_search_after_delete(self, connect, collection, get_top_k, get_nq):
+        '''
+        target: test basic search function before and after deletion, all the search params are
+                correct, change top-k value.
+                check issue #4200
+        method: search with the given vectors, check the result
+        expected: the deleted entities do not exist in the result.
+        '''
+        top_k = get_top_k
+        nq = get_nq
+
+        entities, ids = init_data(connect, collection, nb=10000)
+        first_int64_value = entities[0]["values"][0]
+        first_vector = entities[2]["values"][0]
+
+        search_param = get_search_param("FLAT")
+        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
+        vecs[:] = []
+        vecs.append(first_vector)
+
+        res = None
+        if top_k > max_top_k:
+            with pytest.raises(Exception):
+                connect.search(collection, query, fields=['int64'])
+            pytest.skip("top_k value is larger than max_top_k")
+        else:
+            res = connect.search(collection, query, fields=['int64'])
+            assert len(res) == 1
+            assert len(res[0]) >= top_k
+            assert res[0][0].id == ids[0]
+            assert res[0][0].entity.get("int64") == first_int64_value
+            assert res[0]._distances[0] < epsilon
+            assert check_id_result(res[0], ids[0])
+
+        connect.delete_entity_by_id(collection, ids[:1])
+        connect.flush([collection])
+
+        res2 = connect.search(collection, query, fields=['int64'])
+        assert len(res2) == 1
+        assert len(res2[0]) >= top_k
+        assert res2[0][0].id != ids[0]
+        if top_k > 1:
+            assert res2[0][0].id == res[0][1].id
+            assert res2[0][0].entity.get("int64") == res[0][1].entity.get("int64")
+
+    @pytest.mark.skip("search_after_index")
+    @pytest.mark.level(2)
+    def test_search_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
+        '''
+        target: test basic search function, all the search params are correct, test all index params, and build
+        method: search with the given vectors, check the result
+        expected: the length of the result is top_k
+        '''
+        top_k = get_top_k
+        nq = get_nq
+
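+        # Flow for the index-backed searches below: build the index on the
+        # vector field, derive search params that match the index type, then
+        # search and sanity-check ids and distances.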
+        index_type = get_simple_index["index_type"]
+        if index_type in skip_pq():
+            pytest.skip("Skip PQ")
+        entities, ids = init_data(connect, collection)
+        connect.create_index(collection, field_name, get_simple_index)
+        search_param = get_search_param(index_type)
+        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
+        if top_k > max_top_k:
+            with pytest.raises(Exception) as e:
+                res = connect.search(collection, query)
+        else:
+            res = connect.search(collection, query)
+            assert len(res) == nq
+            assert len(res[0]) >= top_k
+            assert res[0]._distances[0] < epsilon
+            assert check_id_result(res[0], ids[0])
+
+    @pytest.mark.skip("search_after_index_different_metric_type")
+    def test_search_after_index_different_metric_type(self, connect, collection, get_simple_index):
+        '''
+        target: test search with different metric_type
+        method: build index with L2, and search using IP
+        expected: search ok
+        '''
+        search_metric_type = "IP"
+        index_type = get_simple_index["index_type"]
+        entities, ids = init_data(connect, collection)
+        connect.create_index(collection, field_name, get_simple_index)
+        search_param = get_search_param(index_type)
+        query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, metric_type=search_metric_type,
+                                        search_params=search_param)
+        res = connect.search(collection, query)
+        assert len(res) == nq
+        assert len(res[0]) == default_top_k
+
+    @pytest.mark.skip("search_index_partition")
+    @pytest.mark.level(2)
+    def test_search_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
+        '''
+        target: test basic search function, all the search params are correct, test all index params, and build
+        method: add vectors into collection, search with the given vectors, check the result
+        expected: the length of the result is top_k, searching the collection with a partition tag returns empty
+        '''
+        top_k = get_top_k
+        nq = get_nq
+
+        index_type = get_simple_index["index_type"]
+        if index_type in skip_pq():
+            pytest.skip("Skip PQ")
+        connect.create_partition(collection, default_tag)
+        entities, ids = init_data(connect, collection)
+        connect.create_index(collection, field_name, get_simple_index)
+        search_param = get_search_param(index_type)
+        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
+        if top_k > max_top_k:
+            with pytest.raises(Exception) as e:
+                res = connect.search(collection, query)
+        else:
+            res = connect.search(collection, query)
+            assert len(res) == nq
+            assert len(res[0]) >= top_k
+            assert res[0]._distances[0] < epsilon
+            assert check_id_result(res[0], ids[0])
+            res = connect.search(collection, query, partition_tags=[default_tag])
+            assert len(res) == nq
+
+    @pytest.mark.skip("search_index_partition_B")
+    @pytest.mark.level(2)
+    def test_search_index_partition_B(self, connect, collection, get_simple_index, get_top_k, get_nq):
+        '''
+        target: test basic search function, all the search params are correct, test all index params, and build
+        method: search with the given vectors, check the result
+        expected: the length of the result is top_k
+        '''
+        top_k = get_top_k
+        nq = get_nq
+
+        index_type = get_simple_index["index_type"]
+        if index_type in skip_pq():
+            pytest.skip("Skip PQ")
+        connect.create_partition(collection, default_tag)
+        entities, ids = init_data(connect, collection, partition_tags=default_tag)
+        connect.create_index(collection, field_name, get_simple_index)
+        search_param = get_search_param(index_type)
+        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
+        for tags in [[default_tag], [default_tag, "new_tag"]]:
+            if top_k > max_top_k:
+                with pytest.raises(Exception) as e:
+                    res = connect.search(collection, query, partition_tags=tags)
+            else:
+                res = connect.search(collection, query, partition_tags=tags)
+                assert len(res) == nq
+                assert len(res[0]) >= top_k
+                assert res[0]._distances[0] < epsilon
+                assert check_id_result(res[0], ids[0])
+
+    @pytest.mark.skip("search_index_partition_C")
+    @pytest.mark.level(2)
+    def test_search_index_partition_C(self, connect, collection, get_top_k, get_nq):
+        '''
+        target: test basic search function, all the search params are correct, test all index params, and build
+        method: search with the given vectors and a tag that does not exist in the collection, check the result
+        expected: error raised
+        '''
+        top_k = get_top_k
+        nq = get_nq
+        entities, ids = init_data(connect, collection)
+        query, vecs = gen_query_vectors(field_name, entities, top_k, nq)
+        if top_k > max_top_k:
+            with pytest.raises(Exception) as e:
+                res = connect.search(collection, query, partition_tags=["new_tag"])
+        else:
+            res = connect.search(collection, query, partition_tags=["new_tag"])
+            assert len(res) == nq
+            assert len(res[0]) == 0
+
+    @pytest.mark.skip("search_index_partitions")
+    @pytest.mark.level(2)
+    def test_search_index_partitions(self, connect, collection, get_simple_index, get_top_k):
+        '''
+        target: test basic search function, all the search params are correct, test all index params, and build
+        method: search collection with the given vectors and tags, check the result
+        expected: the length of the result is top_k
+        '''
+        top_k = get_top_k
+        nq = 2
+        new_tag = "new_tag"
+        index_type = get_simple_index["index_type"]
+        if index_type in skip_pq():
+            pytest.skip("Skip PQ")
+        connect.create_partition(collection, default_tag)
+        connect.create_partition(collection, new_tag)
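+        # Two partitions are populated: default_tag gets the stock entity set,
+        # new_tag gets a second, larger batch (nb=6001).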
+        entities, ids = init_data(connect, collection, partition_tags=default_tag)
+        new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
+        connect.create_index(collection, field_name, get_simple_index)
+        search_param = get_search_param(index_type)
+        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
+        if top_k > max_top_k:
+            with pytest.raises(Exception) as e:
+                res = connect.search(collection, query)
+        else:
+            res = connect.search(collection, query)
+            assert check_id_result(res[0], ids[0])
+            assert not check_id_result(res[1], new_ids[0])
+            assert res[0]._distances[0] < epsilon
+            assert res[1]._distances[0] < epsilon
+            res = connect.search(collection, query, partition_tags=["new_tag"])
+            assert res[0]._distances[0] > epsilon
+            assert res[1]._distances[0] > epsilon
+
+    @pytest.mark.skip("search_index_partitions_B")
+    @pytest.mark.level(2)
+    def test_search_index_partitions_B(self, connect, collection, get_simple_index, get_top_k):
+        '''
+        target: test basic search function, all the search params are correct, test all index params, and build
+        method: search collection with the given vectors and tags, check the result
+        expected: the length of the result is top_k
+        '''
+        top_k = get_top_k
+        nq = 2
+        tag = "tag"
+        new_tag = "new_tag"
+        index_type = get_simple_index["index_type"]
+        if index_type in skip_pq():
+            pytest.skip("Skip PQ")
+        connect.create_partition(collection, tag)
+        connect.create_partition(collection, new_tag)
+        entities, ids = init_data(connect, collection, partition_tags=tag)
+        new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
+        connect.create_index(collection, field_name, get_simple_index)
+        search_param = get_search_param(index_type)
+        query, vecs = gen_query_vectors(field_name, new_entities, top_k, nq, search_params=search_param)
+        if top_k > max_top_k:
+            with pytest.raises(Exception) as e:
+                res = connect.search(collection, query)
+        else:
+            res = connect.search(collection, query, partition_tags=["(.*)tag"])
+            assert not check_id_result(res[0], ids[0])
+            assert res[0]._distances[0] < epsilon
+            assert res[1]._distances[0] < epsilon
+            res = connect.search(collection, query, partition_tags=["new(.*)"])
+            assert res[0]._distances[0] < epsilon
+            assert res[1]._distances[0] < epsilon
+
+    #
+    # test for ip metric
+    #
+    # TODO: reopen after we support ip flat
+    @pytest.mark.skip("search_ip_flat")
+    @pytest.mark.level(2)
+    def test_search_ip_flat(self, connect, collection, get_simple_index, get_top_k, get_nq):
+        '''
+        target: test basic search function, all the search params are correct, change top-k value
+        method: search with the given vectors, check the result
+        expected: the length of the result is top_k
+        '''
+        top_k = get_top_k
+        nq = get_nq
+        entities, ids = init_data(connect, collection)
+        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP")
+        if top_k <= max_top_k:
+            res = connect.search(collection, query)
+            assert len(res[0]) == top_k
+            assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
+            assert check_id_result(res[0], ids[0])
+        else:
+            with pytest.raises(Exception) as e:
+                res = connect.search(collection, query)
+
+    @pytest.mark.skip("search_ip_after_index")
+    @pytest.mark.level(2)
+    def test_search_ip_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
+        '''
+        target: test basic search function, all the search params are correct, test all index params, and build
+        method: search with the given vectors, check the result
+        expected: the length of the result is top_k
+        '''
+        top_k = get_top_k
+        nq = get_nq
+
+        index_type = get_simple_index["index_type"]
+        if index_type in skip_pq():
+            pytest.skip("Skip PQ")
+        entities, ids = init_data(connect, collection)
+        get_simple_index["metric_type"] = "IP"
+        connect.create_index(collection, field_name, get_simple_index)
+        search_param = get_search_param(index_type)
+        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP", search_params=search_param)
+        if top_k > max_top_k:
+            with pytest.raises(Exception) as e:
+                res = connect.search(collection, query)
+        else:
+            res = connect.search(collection, query)
+            assert len(res) == nq
+            assert len(res[0]) >= top_k
+            assert check_id_result(res[0], ids[0])
+            assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
+
+    @pytest.mark.skip("search_ip_index_partition")
+    @pytest.mark.level(2)
+    def test_search_ip_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
+        '''
+        target: test basic search function, all the search params are correct, test all index params, and build
+        method: add vectors into collection, search with the given vectors, check the result
+        expected: the length of the result is top_k, searching the collection with a partition tag returns empty
+        '''
+        top_k = get_top_k
+        nq = get_nq
+        metric_type = "IP"
+        index_type = get_simple_index["index_type"]
+        if index_type in skip_pq():
+            pytest.skip("Skip PQ")
+        connect.create_partition(collection, default_tag)
+        entities, ids = init_data(connect, collection)
+        get_simple_index["metric_type"] = metric_type
+        connect.create_index(collection, field_name, get_simple_index)
+        search_param = get_search_param(index_type)
+        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type=metric_type,
+                                        search_params=search_param)
+        if top_k > max_top_k:
+            with pytest.raises(Exception) as e:
+                res = connect.search(collection, query)
+        else:
+            res = connect.search(collection, query)
+            assert len(res) == nq
+            assert len(res[0]) >= top_k
+            assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
+            assert check_id_result(res[0], ids[0])
+            res = connect.search(collection, query, partition_tags=[default_tag])
+            assert len(res) == nq
+
+    @pytest.mark.skip("search_ip_index_partitions")
+    @pytest.mark.level(2)
+    def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k):
+        '''
+        target: test basic search function, all the search params are correct, test all index params, and build
+        method: search collection with the given vectors and tags, check the result
+        expected: the length of the result is top_k
+        '''
+        top_k = get_top_k
+        nq = 2
+        metric_type = "IP"
+        new_tag = "new_tag"
+        index_type = get_simple_index["index_type"]
+        if index_type in skip_pq():
+            pytest.skip("Skip PQ")
+        connect.create_partition(collection, default_tag)
+        connect.create_partition(collection, new_tag)
+        entities, ids = init_data(connect, collection, partition_tags=default_tag)
+        new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
+        get_simple_index["metric_type"] = metric_type
+        connect.create_index(collection, field_name, get_simple_index)
+        search_param = get_search_param(index_type)
+        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP", search_params=search_param)
+        if top_k > max_top_k:
+            with pytest.raises(Exception) as e:
+                res = connect.search(collection, query)
+        else:
+            res = connect.search(collection, query)
+            assert check_id_result(res[0], ids[0])
+            assert not check_id_result(res[1], new_ids[0])
+            assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
+            assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0])
+            res = connect.search(collection, query, partition_tags=["new_tag"])
+            assert res[0]._distances[0] < 1 - gen_inaccuracy(res[0]._distances[0])
+            # TODO:
+            # assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0])
+
+    # PASS
+    @pytest.mark.level(2)
+    def test_search_without_connect(self, dis_connect, collection):
+        '''
+        target: test searching vectors without a connection
+        method: use a disconnected instance and call the search method
+        expected: raise exception
+        '''
+        with pytest.raises(Exception) as e:
+            res = dis_connect.search(collection, default_query)
+
+    # PASS
+    # TODO: proxy or SDK checks if collection exists
+    def test_search_collection_name_not_existed(self, connect):
+        '''
+        target: search a collection that does not exist
+        method: search with a random collection_name, which is not in the db
+        expected: status not ok
+        '''
+        collection_name = gen_unique_str(uid)
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection_name, default_query)
+
+    # PASS
+    @pytest.mark.skip("should pass")
+    def test_search_distance_l2(self, connect, collection):
+        '''
+        target: search collection, and check the result: distance
+        method: compare the returned distance value with the value computed with the Euclidean metric
+        expected: the returned distance equals the computed value
+        '''
+        nq = 2
+        search_param = {"nprobe": 1}
+        entities, ids = init_data(connect, collection, nb=nq)
+        query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True,
+                                        search_params=search_param)
+        inside_query, inside_vecs = gen_query_vectors(field_name, entities, default_top_k, nq,
+                                                      search_params=search_param)
+        distance_0 = l2(vecs[0], inside_vecs[0])
+        distance_1 = l2(vecs[0], inside_vecs[1])
+        res = connect.search(collection, query)
+        assert abs(np.sqrt(res[0]._distances[0]) - min(distance_0, distance_1)) <= gen_inaccuracy(res[0]._distances[0])
+
+    @pytest.mark.skip("search_distance_l2_after_index")
+    def test_search_distance_l2_after_index(self, connect, id_collection, get_simple_index):
+        '''
+        target: search collection, and check the result: distance
+        method: compare the returned distance value with the value computed with the L2 metric
+        expected: the returned distance equals the computed value
+        '''
+        index_type = get_simple_index["index_type"]
+        nq = 2
+        entities, ids = init_data(connect, id_collection, auto_id=False)
+        connect.create_index(id_collection, field_name, get_simple_index)
+        search_param = get_search_param(index_type)
+        query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True,
+                                        search_params=search_param)
+        inside_vecs = entities[-1]["values"]
+        min_distance = 1.0
+        min_id = None
+        for i in range(default_nb):
+            tmp_dis = l2(vecs[0], inside_vecs[i])
+            if min_distance > tmp_dis:
+                min_distance = tmp_dis
+                min_id = ids[i]
+        res = connect.search(id_collection, query)
+        tmp_epsilon = epsilon
+        check_id_result(res[0], min_id)
+        # if index_type in ["ANNOY", "IVF_PQ"]:
+        #     tmp_epsilon = 0.1
+        # TODO:
+        # assert abs(np.sqrt(res[0]._distances[0]) - min_distance) <= tmp_epsilon
+
+    # TODO: reopen after we support ip flat
+    @pytest.mark.skip("search_distance_ip")
+    @pytest.mark.level(2)
+    def test_search_distance_ip(self, connect, collection):
+        '''
+        target: search collection, and check
the result: distance
+        method: compare the returned distance value with the value computed with inner product
+        expected: the returned distance equals the computed value
+        '''
+        nq = 2
+        metric_type = "IP"
+        search_param = {"nprobe": 1}
+        entities, ids = init_data(connect, collection, nb=nq)
+        query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True,
+                                        metric_type=metric_type,
+                                        search_params=search_param)
+        inside_query, inside_vecs = gen_query_vectors(field_name, entities, default_top_k, nq,
+                                                      search_params=search_param)
+        distance_0 = ip(vecs[0], inside_vecs[0])
+        distance_1 = ip(vecs[0], inside_vecs[1])
+        res = connect.search(collection, query)
+        assert abs(res[0]._distances[0] - max(distance_0, distance_1)) <= epsilon
+
+    @pytest.mark.skip("search_distance_ip_after_index")
+    def test_search_distance_ip_after_index(self, connect, id_collection, get_simple_index):
+        '''
+        target: search collection, and check the result: distance
+        method: compare the returned distance value with the value computed with inner product
+        expected: the returned distance equals the computed value
+        '''
+        index_type = get_simple_index["index_type"]
+        nq = 2
+        metric_type = "IP"
+        entities, ids = init_data(connect, id_collection, auto_id=False)
+        get_simple_index["metric_type"] = metric_type
+        connect.create_index(id_collection, field_name, get_simple_index)
+        search_param = get_search_param(index_type)
+        query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True,
+                                        metric_type=metric_type,
+                                        search_params=search_param)
+        inside_vecs = entities[-1]["values"]
+        max_distance = 0
+        max_id = None
+        for i in range(default_nb):
+            tmp_dis = ip(vecs[0], inside_vecs[i])
+            if max_distance < tmp_dis:
+                max_distance = tmp_dis
+                max_id = ids[i]
+        res = connect.search(id_collection, query)
+        tmp_epsilon = epsilon
+        check_id_result(res[0], max_id)
+        # if index_type in ["ANNOY", "IVF_PQ"]:
+        #     tmp_epsilon = 0.1
+        # TODO:
+        # assert abs(res[0]._distances[0] - max_distance) <= tmp_epsilon
+
+    @pytest.mark.skip("search_distance_jaccard_flat_index")
+    def test_search_distance_jaccard_flat_index(self, connect, binary_collection):
+        '''
+        target: search binary_collection, and check the result: distance
+        method: compare the returned distance value with the value computed with the JACCARD metric
+        expected: the returned distance equals the computed value
+        '''
+        nq = 1
+        int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
+        query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
+        distance_0 = jaccard(query_int_vectors[0], int_vectors[0])
+        distance_1 = jaccard(query_int_vectors[0], int_vectors[1])
+        query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="JACCARD")
+        res = connect.search(binary_collection, query)
+        assert abs(res[0]._distances[0] - min(distance_0, distance_1)) <= epsilon
+
+    @pytest.mark.skip("search_distance_jaccard_flat_index_L2")
+    @pytest.mark.level(2)
+    def test_search_distance_jaccard_flat_index_L2(self, connect, binary_collection):
+        '''
+        target: search binary_collection, and check the result: distance
+        method: search binary vectors with the mismatched L2 metric type
+        expected: exception raised
+        '''
+        nq = 1
+        int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
+        query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
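+        # Ground-truth Jaccard distances are computed on the raw bit vectors;
+        # searching the binary field with an L2 metric type is expected to fail.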
+        distance_0 = jaccard(query_int_vectors[0], int_vectors[0])
+        distance_1 = jaccard(query_int_vectors[0], int_vectors[1])
+        query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="L2")
+        with pytest.raises(Exception) as e:
+            res = connect.search(binary_collection, query)
+
+    @pytest.mark.skip("search_distance_hamming_flat_index")
+    @pytest.mark.level(2)
+    def test_search_distance_hamming_flat_index(self, connect, binary_collection):
+        '''
+        target: search binary_collection, and check the result: distance
+        method: compare the returned distance value with the value computed with the HAMMING metric
+        expected: the returned distance equals the computed value
+        '''
+        nq = 1
+        int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
+        query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
+        distance_0 = hamming(query_int_vectors[0], int_vectors[0])
+        distance_1 = hamming(query_int_vectors[0], int_vectors[1])
+        query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="HAMMING")
+        res = connect.search(binary_collection, query)
+        assert abs(res[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon
+
+    @pytest.mark.skip("search_distance_substructure_flat_index")
+    @pytest.mark.level(2)
+    def test_search_distance_substructure_flat_index(self, connect, binary_collection):
+        '''
+        target: search binary_collection, and check the result: distance
+        method: search with the SUBSTRUCTURE metric and random query vectors
+        expected: the result is empty
+        '''
+        nq = 1
+        int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
+        query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
+        distance_0 = substructure(query_int_vectors[0], int_vectors[0])
+        distance_1 = substructure(query_int_vectors[0], int_vectors[1])
+        query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq,
+                                        metric_type="SUBSTRUCTURE")
+        res = connect.search(binary_collection, query)
+        assert len(res[0]) == 0
+
+    @pytest.mark.skip("search_distance_substructure_flat_index_B")
+    @pytest.mark.level(2)
+    def test_search_distance_substructure_flat_index_B(self, connect, binary_collection):
+        '''
+        target: search binary_collection, and check the result: distance
+        method: compare the returned distance value with the value computed with the SUBSTRUCTURE metric
+        expected: the returned distance equals the computed value
+        '''
+        top_k = 3
+        int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
+        query_int_vectors, query_vecs = gen_binary_sub_vectors(int_vectors, 2)
+        query, vecs = gen_query_vectors(binary_field_name, entities, top_k, nq, metric_type="SUBSTRUCTURE",
+                                        replace_vecs=query_vecs)
+        res = connect.search(binary_collection, query)
+        assert res[0][0].distance <= epsilon
+        assert res[0][0].id == ids[0]
+        assert res[1][0].distance <= epsilon
+        assert res[1][0].id == ids[1]
+
+    @pytest.mark.skip("search_distance_superstructure_flat_index")
+    @pytest.mark.level(2)
+    def test_search_distance_superstructure_flat_index(self, connect, binary_collection):
+        '''
+        target: search binary_collection, and check the result: distance
+        method: search with the SUPERSTRUCTURE metric and random query vectors
+        expected: the result is empty
+        '''
+        nq = 1
+        int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
+        query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
+        distance_0 = superstructure(query_int_vectors[0], int_vectors[0])
+        distance_1 = superstructure(query_int_vectors[0], int_vectors[1])
+        query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq,
+                                        metric_type="SUPERSTRUCTURE")
+        res = connect.search(binary_collection, query)
+        assert len(res[0]) == 0
+
+    @pytest.mark.skip("search_distance_superstructure_flat_index_B")
+    @pytest.mark.level(2)
+    def test_search_distance_superstructure_flat_index_B(self, connect, binary_collection):
+        '''
+        target: search binary_collection, and check the result: distance
+        method: compare the returned distance value with the value computed with the SUPERSTRUCTURE metric
+        expected: the returned distance equals the computed value
+        '''
+        top_k = 3
+        int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
+        query_int_vectors, query_vecs = gen_binary_super_vectors(int_vectors, 2)
+        query, vecs = gen_query_vectors(binary_field_name, entities, top_k, nq, metric_type="SUPERSTRUCTURE",
+                                        replace_vecs=query_vecs)
+        res = connect.search(binary_collection, query)
+        assert len(res[0]) == 2
+        assert len(res[1]) == 2
+        assert res[0][0].id in ids
+        assert res[0][0].distance <= epsilon
+        assert res[1][0].id in ids
+        assert res[1][0].distance <= epsilon
+
+    @pytest.mark.skip("search_distance_tanimoto_flat_index")
+    @pytest.mark.level(2)
+    def test_search_distance_tanimoto_flat_index(self, connect, binary_collection):
+        '''
+        target: search binary_collection, and check the result: distance
+        method: compare the returned distance value with the value computed with the TANIMOTO metric
+        expected: the returned distance equals the computed value
+        '''
+        nq = 1
+        int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
+        query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
+        distance_0 = tanimoto(query_int_vectors[0], int_vectors[0])
+        distance_1 = tanimoto(query_int_vectors[0], int_vectors[1])
+        query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="TANIMOTO")
+        res = connect.search(binary_collection, query)
+        assert abs(res[0][0].distance - min(distance_0, distance_1)) <= epsilon
+
+    # PASS
+    @pytest.mark.level(2)
+    @pytest.mark.timeout(30)
+    @pytest.mark.skip("should pass")
+    def test_search_concurrent_multithreads(self, connect, args):
+        '''
+        target: test concurrent search with multiple threads
+        method: search with 4 threads, each thread using its own connection
+        expected: status ok and the returned vectors should be query_records
+        '''
+        nb = 100
+        top_k = 10
+        threads_num = 4
+        threads = []
+        collection = gen_unique_str(uid)
+        uri = "tcp://%s:%s" % (args["ip"], args["port"])
+        # create collection
+        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
+        milvus.create_collection(collection, default_fields)
+        entities, ids = init_data(milvus, collection)
+
+        def search(milvus):
+            res = milvus.search(collection, default_query)
+            assert len(res) == 1
+            assert res[0]._entities[0].id in ids
+            assert res[0]._distances[0] < epsilon
+
+        for i in range(threads_num):
+            milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
+            t = MilvusTestThread(target=search, args=(milvus,))
+            threads.append(t)
+            t.start()
+            time.sleep(0.2)
+        for t in threads:
+            t.join()
+
+    # PASS
+    @pytest.mark.level(2)
+    @pytest.mark.timeout(30)
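+    # Same scenario as the test above, but all threads share one connection
+    # instead of each opening its own.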
+    @pytest.mark.skip("should pass")
+    def test_search_concurrent_multithreads_single_connection(self, connect, args):
+        '''
+        target: test concurrent search with multiple threads
+        method: search with 4 threads, all threads sharing a single connection
+        expected: status ok and the returned vectors should be query_records
+        '''
+        nb = 100
+        top_k = 10
+        threads_num = 4
+        threads = []
+        collection = gen_unique_str(uid)
+        uri = "tcp://%s:%s" % (args["ip"], args["port"])
+        # create collection
+        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
+        milvus.create_collection(collection, default_fields)
+        entities, ids = init_data(milvus, collection)
+
+        def search(milvus):
+            res = milvus.search(collection, default_query)
+            assert len(res) == 1
+            assert res[0]._entities[0].id in ids
+            assert res[0]._distances[0] < epsilon
+
+        for i in range(threads_num):
+            t = MilvusTestThread(target=search, args=(milvus,))
+            threads.append(t)
+            t.start()
+            time.sleep(0.2)
+        for t in threads:
+            t.join()
+
+    # PASS
+    @pytest.mark.level(2)
+    @pytest.mark.skip("should pass")
+    def test_search_multi_collections(self, connect, args):
+        '''
+        target: test searching multiple collections with L2 metric
+        method: add vectors into 10 collections, and search each of them
+        expected: search status ok, the length of the result is top_k
+        '''
+        num = 10
+        top_k = 10
+        nq = 20
+        for i in range(num):
+            collection = gen_unique_str(uid + str(i))
+            connect.create_collection(collection, default_fields)
+            entities, ids = init_data(connect, collection)
+            assert len(ids) == default_nb
+            query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
+            res = connect.search(collection, query)
+            assert len(res) == nq
+            for i in range(nq):
+                assert check_id_result(res[i], ids[i])
+                assert res[i]._distances[0] < epsilon
+                assert res[i]._distances[1] > epsilon
+
+    @pytest.mark.skip("query_entities_with_field_less_than_top_k")
+    def test_query_entities_with_field_less_than_top_k(self, connect, id_collection):
+        """
+        target: test search with field, and let the returned entities be fewer than topk
+        method: insert entities and build an IVF index, then search with field, nprobe=1
+        expected:
+        """
+        entities, ids = init_data(connect, id_collection, auto_id=False)
+        simple_index = {"index_type": "IVF_FLAT", "params": {"nlist": 200}, "metric_type": "L2"}
+        connect.create_index(id_collection, field_name, simple_index)
+        # logging.getLogger().info(connect.get_collection_info(id_collection))
+        top_k = 300
+        default_query, default_query_vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params={"nprobe": 1})
+        expr = {"must": [gen_default_vector_expr(default_query)]}
+        query = update_query_expr(default_query, expr=expr)
+        res = connect.search(id_collection, query, fields=["int64"])
+        assert len(res) == nq
+        for r in res[0]:
+            assert getattr(r.entity, "int64") == getattr(r.entity, "id")
+
+
+class TestSearchDSL(object):
+    """
+    ******************************************************************
+    # The following cases are used to build invalid query expr
+    ******************************************************************
+    """
+
+    @pytest.mark.skip("query_no_must")
+    def test_query_no_must(self, connect, collection):
+        '''
+        method: build query without must expr
+        expected: error raised
+        '''
+        # entities, ids = init_data(connect, collection)
+        query = update_query_expr(default_query, keep_old=False)
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, query)
+
+    @pytest.mark.skip("query_no_vector_term_only")
+    def test_query_no_vector_term_only(self, connect, collection):
+        '''
+        method: build query with only a term expr and no vector expr
+        expected: error raised
+        '''
+        # entities, ids = init_data(connect, collection)
+        expr = {
+            "must": [gen_default_term_expr]
+        }
+        query = update_query_expr(default_query, keep_old=False, expr=expr)
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, query)
+
+    # PASS
+    def test_query_no_vector_range_only(self, connect, collection):
+        '''
+        method: build query with only a range expr and no vector expr
+        expected: error raised
+        '''
+        # entities, ids = init_data(connect, collection)
+        expr = {
+            "must": [gen_default_range_expr]
+        }
+        query = update_query_expr(default_query, keep_old=False, expr=expr)
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, query)
+
+    # PASS
+    @pytest.mark.skip("should pass")
+    def test_query_vector_only(self, connect, collection):
+        entities, ids = init_data(connect, collection)
+        res = connect.search(collection, default_query)
+        assert len(res) == nq
+        assert len(res[0]) == default_top_k
+
+    @pytest.mark.skip("query_wrong_format")
+    def test_query_wrong_format(self, connect, collection):
+        '''
+        method: build query without must expr, with a wrong expr name
+        expected: error raised
+        '''
+        # entities, ids = init_data(connect, collection)
+        expr = {
+            "must1": [gen_default_term_expr]
+        }
+        query = update_query_expr(default_query, keep_old=False, expr=expr)
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, query)
+
+    # PASS
+    def test_query_empty(self, connect, collection):
+        '''
+        method: search with an empty query
+        expected: error raised
+        '''
+        query = {}
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, query)
+
+    """
+    ******************************************************************
+    # The following cases are used to build valid query expr
+    ******************************************************************
+    """
+
+    @pytest.mark.skip("query_term_value_not_in")
+    @pytest.mark.level(2)
+    def test_query_term_value_not_in(self, connect, collection):
+        '''
+        method: build query with vector and term expr, where no entity matches the term value
+        expected: filter pass
+        '''
+        entities, ids = init_data(connect, collection)
+        expr = {
+            "must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[100000])]}
+        query = update_query_expr(default_query, expr=expr)
+        res = connect.search(collection, query)
+        assert len(res) == nq
+        assert len(res[0]) == 0
+        # TODO:
+
+    # TODO:
+    @pytest.mark.skip("query_term_value_all_in")
+    @pytest.mark.level(2)
+    def test_query_term_value_all_in(self, connect, collection):
+        '''
+        method: build query with vector and term expr, where all term values match
+        expected: filter pass
+        '''
+        entities, ids = init_data(connect, collection)
+        expr = {"must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[1])]}
+        query = update_query_expr(default_query, expr=expr)
+        res = connect.search(collection, query)
+        assert len(res) == nq
+        assert len(res[0]) == 1
+        # TODO:
+
+    # TODO:
+    @pytest.mark.skip("query_term_values_not_in")
+    @pytest.mark.level(2)
+    def test_query_term_values_not_in(self, connect, collection):
+        '''
+        method: build query with vector and term expr, where no entity matches the term values
+        expected: filter pass
+        '''
+        entities, ids = init_data(connect, collection)
+        expr = {"must": [gen_default_vector_expr(default_query),
+                         gen_default_term_expr(values=[i for i in range(100000, 100010)])]}
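+        # None of the term values (100000-100009) exist in the collection, so
+        # the filter is expected to drop every candidate.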
+        query = update_query_expr(default_query, expr=expr)
+        res = connect.search(collection, query)
+        assert len(res) == nq
+        assert len(res[0]) == 0
+        # TODO:
+
+    @pytest.mark.skip("query_term_values_all_in")
+    def test_query_term_values_all_in(self, connect, collection):
+        '''
+        method: build query with vector and term expr, where all term values match
+        expected: filter pass
+        '''
+        entities, ids = init_data(connect, collection)
+        expr = {"must": [gen_default_vector_expr(default_query), gen_default_term_expr()]}
+        query = update_query_expr(default_query, expr=expr)
+        res = connect.search(collection, query)
+        assert len(res) == nq
+        assert len(res[0]) == default_top_k
+        limit = default_nb // 2
+        for i in range(nq):
+            for result in res[i]:
+                logging.getLogger().info(result.id)
+                assert result.id in ids[:limit]
+        # TODO:
+
+    @pytest.mark.skip("query_term_values_parts_in")
+    def test_query_term_values_parts_in(self, connect, collection):
+        '''
+        method: build query with vector and term expr, where only part of the term values match
+        expected: filter pass
+        '''
+        entities, ids = init_data(connect, collection)
+        expr = {"must": [gen_default_vector_expr(default_query),
+                         gen_default_term_expr(
+                             values=[i for i in range(default_nb // 2, default_nb + default_nb // 2)])]}
+        query = update_query_expr(default_query, expr=expr)
+        res = connect.search(collection, query)
+        assert len(res) == nq
+        assert len(res[0]) == default_top_k
+        # TODO:
+
+    # TODO:
+    @pytest.mark.skip("query_term_values_repeat")
+    @pytest.mark.level(2)
+    def test_query_term_values_repeat(self, connect, collection):
+        '''
+        method: build query with vector and term expr, with repeated term values
+        expected: filter pass
+        '''
+        entities, ids = init_data(connect, collection)
+        expr = {
+            "must": [gen_default_vector_expr(default_query),
+                     gen_default_term_expr(values=[1 for i in range(1, default_nb)])]}
+        query = update_query_expr(default_query, expr=expr)
+        res = connect.search(collection, query)
+        assert len(res) == nq
+        assert len(res[0]) == 1
+        # TODO:
+
+    @pytest.mark.skip("query_term_value_empty")
+    def test_query_term_value_empty(self, connect, collection):
+        '''
+        method: build query with an empty term value list
+        expected: return null
+        '''
+        expr = {"must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[])]}
+        query = update_query_expr(default_query, expr=expr)
+        res = connect.search(collection, query)
+        assert len(res) == nq
+        assert len(res[0]) == 0
+
+    @pytest.mark.skip("query_complex_dsl")
+    def test_query_complex_dsl(self, connect, collection):
+        '''
+        method: query with complicated dsl
+        expected: no error raised
+        '''
+        expr = {"must": [
+            {"must": [{"should": [gen_default_term_expr(values=[1]), gen_default_range_expr()]}]},
+            {"must": [gen_default_vector_expr(default_query)]}
+        ]}
+        logging.getLogger().info(expr)
+        query = update_query_expr(default_query, expr=expr)
+        logging.getLogger().info(query)
+        res = connect.search(collection, query)
+        logging.getLogger().info(res)
+
+    """
+    ******************************************************************
+    # The following cases are used to build invalid term query expr
+    ******************************************************************
+    """
+
+    # TODO
+    @pytest.mark.skip("query_term_key_error")
+    @pytest.mark.level(2)
+    def test_query_term_key_error(self, connect, collection):
+        '''
+        method: build query with a wrong term key
+        expected: Exception raised
+        '''
+        expr = {"must": [gen_default_vector_expr(default_query),
+                         gen_default_term_expr(keyword="terrm", values=[i for i in range(default_nb // 2)])]}
+        query = update_query_expr(default_query, expr=expr)
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, query)
+
+    @pytest.fixture(
+        scope="function",
+        params=gen_invalid_term()
+    )
+    def get_invalid_term(self, request):
+        return request.param
+
+    @pytest.mark.skip("query_term_wrong_format")
+    @pytest.mark.level(2)
+    def test_query_term_wrong_format(self, connect, collection, get_invalid_term):
+        '''
+        method: build query with a wrongly formatted term
+        expected: Exception raised
+        '''
+        entities, ids = init_data(connect, collection)
+        term = get_invalid_term
+        expr = {"must": [gen_default_vector_expr(default_query), term]}
+        query = update_query_expr(default_query, expr=expr)
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, query)
+
+    # TODO
+    @pytest.mark.skip("query_term_field_named_term")
+    @pytest.mark.level(2)
+    def test_query_term_field_named_term(self, connect, collection):
+        '''
+        method: build query with a field named "term"
+        expected: error raised
+        '''
+        term_fields = add_field_default(default_fields, field_name="term")
+        collection_term = gen_unique_str("term")
+        connect.create_collection(collection_term, term_fields)
+        term_entities = add_field(entities, field_name="term")
+        ids = connect.bulk_insert(collection_term, term_entities)
+        assert len(ids) == default_nb
+        connect.flush([collection_term])
+        count = connect.count_entities(collection_term)
+        assert count == default_nb
+        term_param = {"term": {"term": {"values": [i for i in range(default_nb // 2)]}}}
+        expr = {"must": [gen_default_vector_expr(default_query),
+                         term_param]}
+        query = update_query_expr(default_query, expr=expr)
+        res = connect.search(collection_term, query)
+        assert len(res) == nq
+        assert len(res[0]) == default_top_k
+        connect.drop_collection(collection_term)
+
+    @pytest.mark.skip("query_term_one_field_not_existed")
+    @pytest.mark.level(2)
+    def test_query_term_one_field_not_existed(self, connect, collection):
+        '''
+        method: build query with a term expr on two fields, one of which does not exist
+        expected: exception raised
+        '''
+        entities, ids = init_data(connect, collection)
+        term = gen_default_term_expr()
+        term["term"].update({"a": [0]})
+        expr = {"must": [gen_default_vector_expr(default_query), term]}
+        query = update_query_expr(default_query, expr=expr)
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, query)
+
+    """
+    ******************************************************************
+    # The following cases are used to build valid range query expr
+    ******************************************************************
+    """
+
+    # PASS
+    # TODO
+    def test_query_range_key_error(self, connect, collection):
+        '''
+        method: build query with a wrong range key
+        expected: Exception raised
+        '''
+        range = gen_default_range_expr(keyword="ranges")
+        expr = {"must": [gen_default_vector_expr(default_query), range]}
+        query = update_query_expr(default_query, expr=expr)
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, query)
+
+    @pytest.fixture(
+        scope="function",
+        params=gen_invalid_range()
+    )
+    def get_invalid_range(self, request):
+        return request.param
+
+    # PASS
+    # TODO
+    @pytest.mark.level(2)
+    def test_query_range_wrong_format(self, connect, collection, get_invalid_range):
+        '''
+        method: build query with a wrongly formatted range
+        expected: Exception raised
+        '''
+        entities, ids = init_data(connect, collection)
+        range = get_invalid_range
+        expr = {"must": [gen_default_vector_expr(default_query), range]}
+        query = update_query_expr(default_query, expr=expr)
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, query)
+
+    # PASS
+    @pytest.mark.level(2)
+    def test_query_range_string_ranges(self, connect, collection):
+        '''
+        method: build query with string range bounds
+        expected: raise Exception
+        '''
+        entities, ids = init_data(connect, collection)
+        ranges = {"GT": "0", "LT": "1000"}
+        range = gen_default_range_expr(ranges=ranges)
+        expr = {"must": [gen_default_vector_expr(default_query), range]}
+        query = update_query_expr(default_query, expr=expr)
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, query)
+
+    # PASS
+    @pytest.mark.level(2)
+    @pytest.mark.skip("should pass")
+    def test_query_range_invalid_ranges(self, connect, collection):
+        '''
+        method: build query with inverted range bounds (GT > LT)
+        expected: error raised
+        '''
+        entities, ids = init_data(connect, collection)
+        ranges = {"GT": default_nb, "LT": 0}
+        range = gen_default_range_expr(ranges=ranges)
+        expr = {"must": [gen_default_vector_expr(default_query), range]}
+        query = update_query_expr(default_query, expr=expr)
+        with pytest.raises(Exception):
+            res = connect.search(collection, query)
+            assert len(res[0]) == 0
+
+    @pytest.fixture(
+        scope="function",
+        params=gen_valid_ranges()
+    )
+    def get_valid_ranges(self, request):
+        return request.param
+
+    # PASS
+    @pytest.mark.level(2)
+    def test_query_range_valid_ranges(self, connect, collection, get_valid_ranges):
+        '''
+        method: build query with valid ranges
+        expected: pass
+        '''
+        entities, ids = init_data(connect, collection)
+        ranges = get_valid_ranges
+        range = gen_default_range_expr(ranges=ranges)
+        expr = {"must": [gen_default_vector_expr(default_query), range]}
+        query = update_query_expr(default_query, expr=expr)
+        res = connect.search(collection, query)
+        assert len(res) == nq
+        assert len(res[0]) == default_top_k
+
+    @pytest.mark.skip("query_range_one_field_not_existed")
+    def test_query_range_one_field_not_existed(self, connect, collection):
+        '''
+        method: build query with ranges on two fields, one of which does not exist
+        expected: exception raised
+        '''
+        entities, ids = init_data(connect, collection)
+        range = gen_default_range_expr()
+        range["range"].update({"a": {"GT": 1, "LT": default_nb // 2}})
+        expr = {"must": [gen_default_vector_expr(default_query), range]}
+        query = update_query_expr(default_query, expr=expr)
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, query)
+
+    """
+    ************************************************************************
+    # The following cases are used to build query expr multi range and term
+    ************************************************************************
+    """
+
+    # TODO
+    @pytest.mark.skip("query_multi_term_has_common")
+    @pytest.mark.level(2)
+    def test_query_multi_term_has_common(self, connect, collection):
+        '''
+        method: build query with multi term exprs on the same field, with common values
+        expected: pass
+        '''
+        entities, ids = init_data(connect, collection)
+        term_first = gen_default_term_expr()
+        term_second = gen_default_term_expr(values=[i for i in range(default_nb // 3)])
+        expr = {"must": [gen_default_vector_expr(default_query), term_first, term_second]}
+        query = update_query_expr(default_query, expr=expr)
+        res = connect.search(collection, query)
+        assert len(res) == nq
+        assert len(res[0]) == default_top_k
+
+    # TODO
+    @pytest.mark.skip("query_multi_term_no_common")
+    @pytest.mark.level(2)
+    def test_query_multi_term_no_common(self, connect, collection):
+        '''
+        method: build query with multi term exprs on the same field, with no common values
+        expected: pass
+        '''
+        entities, ids = init_data(connect, collection)
+        term_first = gen_default_term_expr()
+        term_second = gen_default_term_expr(values=[i for i in range(default_nb // 2, default_nb + default_nb // 2)])
+        expr = {"must": [gen_default_vector_expr(default_query), term_first, term_second]}
+        query = update_query_expr(default_query, expr=expr)
+        res = connect.search(collection, query)
+        assert len(res) == nq
+        assert len(res[0]) == 0
+
+    # TODO
+    @pytest.mark.skip("query_multi_term_different_fields")
+    def test_query_multi_term_different_fields(self, connect, collection):
+        '''
+        method: build query with multi term exprs on different fields, with no common matches
+        expected: pass
+        '''
+        entities, ids = init_data(connect, collection)
+        term_first = gen_default_term_expr()
+        term_second = gen_default_term_expr(field="float",
+                                            values=[float(i) for i in range(default_nb // 2, default_nb)])
+        expr = {"must": [gen_default_vector_expr(default_query), term_first, term_second]}
+        query = update_query_expr(default_query, expr=expr)
+        res = connect.search(collection, query)
+        assert len(res) == nq
+        assert len(res[0]) == 0
+
+    # TODO
+    @pytest.mark.skip("query_single_term_multi_fields")
+    @pytest.mark.level(2)
+    def test_query_single_term_multi_fields(self, connect, collection):
+        '''
+        method: build query with a single term expr covering multiple fields
+        expected: error raised
+        '''
+        entities, ids = init_data(connect, collection)
+        term_first = {"int64": {"values": [i for i in range(default_nb // 2)]}}
+        term_second = {"float": {"values": [float(i) for i in range(default_nb // 2, default_nb)]}}
+        term = update_term_expr({"term": {}}, [term_first, term_second])
+        expr = {"must": [gen_default_vector_expr(default_query), term]}
+        query = update_query_expr(default_query, expr=expr)
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, query)
+
+    # TODO
+    @pytest.mark.skip("query_multi_range_has_common")
+    @pytest.mark.level(2)
+    def test_query_multi_range_has_common(self, connect, collection):
+        '''
+        method: build query with multi range exprs on the same field, with overlapping ranges
+        expected: pass
+        '''
+        entities, ids = init_data(connect, collection)
+        range_one = gen_default_range_expr()
+        range_two = gen_default_range_expr(ranges={"GT": 1, "LT": default_nb // 3})
+        expr = {"must": [gen_default_vector_expr(default_query), range_one, range_two]}
+        query = update_query_expr(default_query, expr=expr)
+        res = connect.search(collection, query)
+        assert len(res) == nq
+        assert len(res[0]) == default_top_k
+
+    # TODO
+    @pytest.mark.skip("query_multi_range_no_common")
+    @pytest.mark.level(2)
+    def test_query_multi_range_no_common(self, connect, collection):
+        '''
+        method: build query with multi range exprs on the same field, with disjoint ranges
+        expected: pass
+        '''
+        entities, ids = init_data(connect, collection)
+        range_one = gen_default_range_expr()
+        range_two = gen_default_range_expr(ranges={"GT": default_nb // 2, "LT": default_nb})
+        expr = {"must": [gen_default_vector_expr(default_query), range_one, range_two]}
+        query = update_query_expr(default_query, expr=expr)
+        res = connect.search(collection, query)
+        assert len(res) == nq
+        assert len(res[0]) == 0
+
+    # TODO
+    @pytest.mark.skip("query_multi_range_different_fields")
+    @pytest.mark.level(2)
+    def test_query_multi_range_different_fields(self, connect, collection):
+        '''
+        method: build query with multi range exprs, one field per range
+        expected: pass
+        '''
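+        # The two ranges select disjoint sets of entities, and every must
+        # clause has to hold at once, so 0 results are expected.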
+
+    # TODO
+    @pytest.mark.skip("query_multi_range_different_fields")
+    @pytest.mark.level(2)
+    def test_query_multi_range_different_fields(self, connect, collection):
+        '''
+        method: build query with multi range exprs, a different field in each range
+        expected: 0 results returned
+        '''
+        entities, ids = init_data(connect, collection)
+        range_first = gen_default_range_expr()
+        range_second = gen_default_range_expr(field="float", ranges={"GT": default_nb // 2, "LT": default_nb})
+        expr = {"must": [gen_default_vector_expr(default_query), range_first, range_second]}
+        query = update_query_expr(default_query, expr=expr)
+        res = connect.search(collection, query)
+        assert len(res) == nq
+        assert len(res[0]) == 0
+
+    # TODO
+    @pytest.mark.skip("query_single_range_multi_fields")
+    @pytest.mark.level(2)
+    def test_query_single_range_multi_fields(self, connect, collection):
+        '''
+        method: build query with a single range expr containing multiple fields
+        expected: exception raised
+        '''
+        entities, ids = init_data(connect, collection)
+        range_first = {"int64": {"GT": 0, "LT": default_nb // 2}}
+        range_second = {"float": {"GT": default_nb / 2, "LT": float(default_nb)}}
+        range = update_range_expr({"range": {}}, [range_first, range_second])
+        expr = {"must": [gen_default_vector_expr(default_query), range]}
+        query = update_query_expr(default_query, expr=expr)
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, query)
+
+    """
+    ******************************************************************
+    # The following cases are used to build query expr both term and range
+    ******************************************************************
+    """
+
+    # TODO
+    @pytest.mark.skip("query_single_term_range_has_common")
+    @pytest.mark.level(2)
+    def test_query_single_term_range_has_common(self, connect, collection):
+        '''
+        method: build query with a single term and a single range with common matches
+        expected: pass
+        '''
+        entities, ids = init_data(connect, collection)
+        term = gen_default_term_expr()
+        range = gen_default_range_expr(ranges={"GT": -1, "LT": default_nb // 2})
+        expr = {"must": [gen_default_vector_expr(default_query), term, range]}
+        query = update_query_expr(default_query, expr=expr)
+        res = connect.search(collection, query)
+        assert len(res) == nq
+        assert len(res[0]) == default_top_k
+
+    # TODO
+    @pytest.mark.skip("query_single_term_range_no_common")
+    def test_query_single_term_range_no_common(self, connect, collection):
+        '''
+        method: build query with a single term and a single range with no common matches
+        expected: 0 results returned
+        '''
+        entities, ids = init_data(connect, collection)
+        term = gen_default_term_expr()
+        range = gen_default_range_expr(ranges={"GT": default_nb // 2, "LT": default_nb})
+        expr = {"must": [gen_default_vector_expr(default_query), term, range]}
+        query = update_query_expr(default_query, expr=expr)
+        res = connect.search(collection, query)
+        assert len(res) == nq
+        assert len(res[0]) == 0
+
+    """
+    ******************************************************************
+    # The following cases are used to build multi vectors query expr
+    ******************************************************************
+    """
+
+    # PASS
+    # TODO
+    def test_query_multi_vectors_same_field(self, connect, collection):
+        '''
+        method: build query with two vector exprs on the same field
+        expected: error raised
+        '''
+        entities, ids = init_data(connect, collection)
+        query2, _ = gen_query_vectors(field_name, entities, default_top_k, nq=2)
+        vector1 = gen_default_vector_expr(default_query)
+        vector2 = gen_default_vector_expr(query2)
+        expr = {
+            "must": [vector1, vector2]
+        }
+        query = update_query_expr(default_query, expr=expr)
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, query)
+
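+
+# A minimal sketch (illustrative only, not used by the tests) of the
+# bool-query layout the DSL cases above compose via update_query_expr:
+# one mandatory vector expr plus any number of term/range filters, all
+# ANDed under "must".
+def _example_composed_query(vector_expr, *filter_exprs):
+    # e.g. _example_composed_query(gen_default_vector_expr(default_query),
+    #                              gen_default_term_expr(),
+    #                              gen_default_range_expr())
+    return {"bool": {"must": [vector_expr, *filter_exprs]}}
+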
+
+class TestSearchDSLBools(object):
+    """
+    ******************************************************************
+    # The following cases are used to build invalid query expr
+    ******************************************************************
+    """
+
+    # PASS
+    @pytest.mark.level(2)
+    def test_query_no_bool(self, connect, collection):
+        '''
+        method: build query whose top-level key is not "bool"
+        expected: error raised
+        '''
+        entities, ids = init_data(connect, collection)
+        expr = {"bool1": {}}
+        query = expr
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, query)
+
+    @pytest.mark.skip("query_should_only_term")
+    def test_query_should_only_term(self, connect, collection):
+        '''
+        method: build query without must, with should.term instead
+        expected: error raised
+        '''
+        expr = {"should": gen_default_term_expr()}
+        query = update_query_expr(default_query, keep_old=False, expr=expr)
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, query)
+
+    @pytest.mark.skip("query_should_only_vector")
+    def test_query_should_only_vector(self, connect, collection):
+        '''
+        method: build query without must, with should.vector instead
+        expected: error raised
+        '''
+        expr = {"should": default_query["bool"]["must"]}
+        query = update_query_expr(default_query, keep_old=False, expr=expr)
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, query)
+
+    @pytest.mark.skip("query_must_not_only_term")
+    def test_query_must_not_only_term(self, connect, collection):
+        '''
+        method: build query without must, with must_not.term instead
+        expected: error raised
+        '''
+        expr = {"must_not": gen_default_term_expr()}
+        query = update_query_expr(default_query, keep_old=False, expr=expr)
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, query)
+
+    @pytest.mark.skip("query_must_not_vector")
+    def test_query_must_not_vector(self, connect, collection):
+        '''
+        method: build query without must, with must_not.vector instead
+        expected: error raised
+        '''
+        expr = {"must_not": default_query["bool"]["must"]}
+        query = update_query_expr(default_query, keep_old=False, expr=expr)
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, query)
+
+    @pytest.mark.skip("query_must_should")
+    def test_query_must_should(self, connect, collection):
+        '''
+        method: build query with must, plus an extra should.term
+        expected: error raised
+        '''
+        expr = {"should": gen_default_term_expr()}
+        query = update_query_expr(default_query, keep_old=True, expr=expr)
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, query)
+
+
+"""
+******************************************************************
+# The following cases are used to test `search` function
+# with invalid collection_name, or invalid query expr
+******************************************************************
+"""
+
+
+class TestSearchInvalid(object):
+    """
+    Test search collection with invalid collection names
+    """
+
+    @pytest.fixture(
+        scope="function",
+        params=gen_invalid_strs()
+    )
+    def get_collection_name(self, request):
+        yield request.param
+
+    @pytest.fixture(
+        scope="function",
+        params=gen_invalid_strs()
+    )
+    def get_invalid_tag(self, request):
+        yield request.param
+
+    @pytest.fixture(
+        scope="function",
+        params=gen_invalid_strs()
+    )
+    def get_invalid_field(self, request):
+        yield request.param
+
+    @pytest.fixture(
+        scope="function",
+        params=gen_simple_index()
+    )
+    def get_simple_index(self, request, connect):
+        if str(connect._cmd("mode")) == "CPU":
+            if request.param["index_type"] in index_cpu_not_support():
+                pytest.skip("sq8h not support in CPU mode")
+        return request.param
+
+    # PASS
+    @pytest.mark.level(2)
+    def test_search_with_invalid_collection(self, connect, get_collection_name):
+        collection_name = get_collection_name
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection_name, default_query)
+
+    # PASS
+    # TODO(yukun)
+    @pytest.mark.level(2)
+    def test_search_with_invalid_tag(self, connect, collection):
+        tag = " "
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, default_query, partition_tags=tag)
+
+    # TODO: reopen after we supporting targetEntry
+    @pytest.mark.skip("search_with_invalid_field_name")
+    @pytest.mark.level(2)
+    def test_search_with_invalid_field_name(self, connect, collection, get_invalid_field):
+        fields = [get_invalid_field]
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, default_query, fields=fields)
+
+    # TODO: reopen after we supporting targetEntry
+    @pytest.mark.skip("search_with_not_existed_field_name")
+    @pytest.mark.level(1)
+    def test_search_with_not_existed_field_name(self, connect, collection):
+        fields = [gen_unique_str("field_name")]
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, default_query, fields=fields)
+
+    """
+    Test search collection with invalid query
+    """
+
+    @pytest.fixture(
+        scope="function",
+        params=gen_invalid_ints()
+    )
+    def get_top_k(self, request):
+        yield request.param
+
+    @pytest.mark.level(1)
+    def test_search_with_invalid_top_k(self, connect, collection, get_top_k):
+        '''
+        target: test search function with an invalid top_k
+        method: search with an invalid top_k value
+        expected: raise an error, and the connection is normal
+        '''
+        top_k = get_top_k
+        # Use a copy so the shared default_query is not mutated across tests.
+        query = copy.deepcopy(default_query)
+        query["bool"]["must"][0]["vector"][field_name]["topk"] = top_k
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, query)
+
+    """
+    Test search collection with invalid search params
+    """
+
+    # NB: "invaild" below is the helper's actual (misspelled) name in utils.py.
+    @pytest.fixture(
+        scope="function",
+        params=gen_invaild_search_params()
+    )
+    def get_search_params(self, request):
+        yield request.param
+
+    # TODO: reopen after we supporting create index
+    @pytest.mark.skip("search_with_invalid_params")
+    @pytest.mark.level(2)
+    def test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params):
+        '''
+        target: test search function with invalid search params
+        method: create index, then search with invalid params
+        expected: raise an error, and the connection is normal
+        '''
+        search_params = get_search_params
+        index_type = get_simple_index["index_type"]
+        if index_type in ["FLAT"]:
+            pytest.skip("skip in FLAT index")
+        if index_type != search_params["index_type"]:
+            pytest.skip("skip if index_type not matched")
+        entities, ids = init_data(connect, collection)
+        connect.create_index(collection, field_name, get_simple_index)
+        query, vecs = gen_query_vectors(field_name, entities, default_top_k, 1,
+                                        search_params=search_params["search_params"])
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, query)
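+
+    # Note (illustrative): for IVF-family indexes a valid nprobe is a
+    # positive integer no larger than the index's nlist; the binary case
+    # below passes nprobe=0 on purpose and expects the search to raise.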
+
+    # TODO: reopen after we supporting binary type
+    @pytest.mark.skip("search_with_invalid_params_binary")
+    @pytest.mark.level(2)
+    def test_search_with_invalid_params_binary(self, connect, binary_collection):
+        '''
+        target: test search on a binary collection with an invalid nprobe
+        method: create a BIN_IVF_FLAT index, then search with nprobe=0
+        expected: raise an error, and the connection is normal
+        '''
+        nq = 1
+        index_type = "BIN_IVF_FLAT"
+        int_vectors, entities, ids = init_binary_data(connect, binary_collection)
+        query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
+        connect.create_index(binary_collection, binary_field_name,
+                             {"index_type": index_type, "metric_type": "JACCARD", "params": {"nlist": 128}})
+        query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq,
+                                        search_params={"nprobe": 0}, metric_type="JACCARD")
+        with pytest.raises(Exception) as e:
+            res = connect.search(binary_collection, query)
+
+    @pytest.mark.skip("search_with_empty_params")
+    @pytest.mark.level(2)
+    def test_search_with_empty_params(self, connect, collection, args, get_simple_index):
+        '''
+        target: test search function with empty search params
+        method: create index, then search with an empty search params dict
+        expected: raise an error, and the connection is normal
+        '''
+        index_type = get_simple_index["index_type"]
+        if args["handler"] == "HTTP":
+            pytest.skip("skip in http mode")
+        if index_type == "FLAT":
+            pytest.skip("skip in FLAT index")
+        entities, ids = init_data(connect, collection)
+        connect.create_index(collection, field_name, get_simple_index)
+        query, vecs = gen_query_vectors(field_name, entities, default_top_k, 1, search_params={})
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, query)
+
+
+def check_id_result(result, id):
+    # Check whether `id` appears among the returned hits; when at least
+    # `limit_in` hits are returned, only the top `limit_in` are inspected.
+    limit_in = 5
+    ids = [entity.id for entity in result]
+    if len(result) >= limit_in:
+        return id in ids[:limit_in]
+    else:
+        return id in ids
diff --git a/tests/python/utils.py b/tests/python/utils.py
index 722d51ec74..2f42075ae4 100644
--- a/tests/python/utils.py
+++ b/tests/python/utils.py
@@ -235,7 +235,7 @@ def gen_single_filter_fields():
 def gen_single_vector_fields():
     fields = []
     for data_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
-        field = {"name": data_type.name, "type": data_type, "params": {"dim": default_dim}}
+        field = {"name": data_type.name, "type": data_type, "params": {"dim": default_dim}, "indexes": [{"metric_type": "L2"}]}
         fields.append(field)
     return fields
 
@@ -243,9 +243,11 @@ def gen_single_vector_fields():
 def gen_default_fields(auto_id=True):
     default_fields = {
         "fields": [
-            {"name": "int64", "type": DataType.INT64},
+            {"name": "int64", "type": DataType.INT64, "is_primary_key": not auto_id},
             {"name": "float", "type": DataType.FLOAT},
-            {"name": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR, "params": {"dim": default_dim}},
+            {"name": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR,
+             "params": {"dim": default_dim},
+             "indexes": [{"metric_type": "L2"}]},
         ],
         "segment_row_limit": default_segment_row_limit,
         "auto_id": auto_id
@@ -974,19 +976,19 @@ def restart_server(helm_release_name):
     return res
 
 
-class TestThread(threading.Thread):
+class MilvusTestThread(threading.Thread):
    def __init__(self, target, args=()):
        threading.Thread.__init__(self, target=target, args=args)
 
    def run(self):
        self.exc = None
        try:
-            super(TestThread, self).run()
+            super(MilvusTestThread, self).run()
        except BaseException as e:
            self.exc = e
 
    def join(self):
-        super(TestThread, self).join()
+        super(MilvusTestThread, self).join()
        if self.exc:
            raise self.exc
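+
+
+# Usage sketch (illustrative): MilvusTestThread stores any exception raised
+# by the target and re-raises it from join(), so failures inside worker
+# threads propagate to the calling test instead of being silently dropped:
+#
+#   t = MilvusTestThread(target=insert, args=(0,))  # `insert` is any worker fn
+#   t.start()
+#   t.join()  # re-raises here if insert(0) raised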