mirror of https://github.com/milvus-io/milvus.git
[skip ci] Filter id:-1 (#2274)
* add async Signed-off-by: zw <zw@zilliz.com> * Update case Signed-off-by: zw <zw@zilliz.com> * Update some cases Signed-off-by: zw <zw@zilliz.com> * add has_partition case Signed-off-by: zw <zw@zilliz.com> * add search_by_id case: ids duplicate Signed-off-by: zw <zw@zilliz.com> * add async add Signed-off-by: zw <zw@zilliz.com> * add flush case Signed-off-by: zw <zw@zilliz.com> * update case Signed-off-by: zw <zw@zilliz.com> * fix search by id case Signed-off-by: zw <zw@zilliz.com> * update case Signed-off-by: zw <zw@zilliz.com> * filter id:-1 Signed-off-by: zw <zw@zilliz.com> * pq distance Signed-off-by: zw <zw@zilliz.com> * [skip ci] skip ci Signed-off-by: zw <zw@zilliz.com>pull/2234/head
parent
d49b609a14
commit
7f744c036b
|
@ -1,4 +1,5 @@
|
|||
import time
|
||||
import pdb
|
||||
import threading
|
||||
import logging
|
||||
import threading
|
||||
|
@ -676,6 +677,121 @@ class TestAddBase:
|
|||
status, ids = connect.add_vectors(collection_name=collection_list[i], records=vectors)
|
||||
assert status.OK()
|
||||
|
||||
class TestAddAsync:
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=[
|
||||
1,
|
||||
1000
|
||||
],
|
||||
)
|
||||
def insert_count(self, request):
|
||||
yield request.param
|
||||
|
||||
def check_status(self, status, result):
|
||||
logging.getLogger().info("In callback check status")
|
||||
assert status.OK()
|
||||
|
||||
def check_status_not_ok(self, status, result):
|
||||
logging.getLogger().info("In callback check status")
|
||||
assert not status.OK()
|
||||
|
||||
|
||||
def test_insert_async(self, connect, collection, insert_count):
|
||||
'''
|
||||
target: test add vectors with different length of vectors
|
||||
method: set different vectors as add method params
|
||||
expected: length of ids is equal to the length of vectors
|
||||
'''
|
||||
nb = insert_count
|
||||
insert_vec_list = gen_vectors(nb, dim)
|
||||
future = connect.add_vectors(collection, insert_vec_list, _async=True)
|
||||
status, ids = future.result()
|
||||
connect.flush([collection])
|
||||
assert len(ids) == nb
|
||||
assert status.OK()
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_insert_async_false(self, connect, collection, insert_count):
|
||||
'''
|
||||
target: test add vectors with different length of vectors
|
||||
method: set different vectors as add method params
|
||||
expected: length of ids is equal to the length of vectors
|
||||
'''
|
||||
nb = insert_count
|
||||
insert_vec_list = gen_vectors(nb, dim)
|
||||
status, ids = connect.add_vectors(collection, insert_vec_list, _async=False)
|
||||
connect.flush([collection])
|
||||
assert len(ids) == nb
|
||||
assert status.OK()
|
||||
|
||||
def test_insert_async_callback(self, connect, collection, insert_count):
|
||||
'''
|
||||
target: test add vectors with different length of vectors
|
||||
method: set different vectors as add method params
|
||||
expected: length of ids is equal to the length of vectors
|
||||
'''
|
||||
nb = insert_count
|
||||
insert_vec_list = gen_vectors(nb, dim)
|
||||
future = connect.add_vectors(collection, insert_vec_list, _async=True, _callback=self.check_status)
|
||||
future.done()
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_insert_async_long(self, connect, collection):
|
||||
'''
|
||||
target: test add vectors with different length of vectors
|
||||
method: set different vectors as add method params
|
||||
expected: length of ids is equal to the length of vectors
|
||||
'''
|
||||
nb = 50000
|
||||
insert_vec_list = gen_vectors(nb, dim)
|
||||
future = connect.add_vectors(collection, insert_vec_list, _async=True, _callback=self.check_status)
|
||||
status, result = future.result()
|
||||
assert status.OK()
|
||||
assert len(result) == nb
|
||||
connect.flush([collection])
|
||||
status, count = connect.count_collection(collection)
|
||||
assert status.OK()
|
||||
logging.getLogger().info(status)
|
||||
logging.getLogger().info(count)
|
||||
assert count == nb
|
||||
|
||||
def test_insert_async_callback_timeout(self, connect, collection):
|
||||
'''
|
||||
target: test add vectors with different length of vectors
|
||||
method: set different vectors as add method params
|
||||
expected: length of ids is equal to the length of vectors
|
||||
'''
|
||||
nb = 100000
|
||||
insert_vec_list = gen_vectors(nb, dim)
|
||||
future = connect.add_vectors(collection, insert_vec_list, _async=True, _callback=self.check_status, timeout=1)
|
||||
future.done()
|
||||
|
||||
def test_insert_async_invalid_params(self, connect, collection):
|
||||
'''
|
||||
target: test add vectors with different length of vectors
|
||||
method: set different vectors as add method params
|
||||
expected: length of ids is equal to the length of vectors
|
||||
'''
|
||||
insert_vec_list = gen_vectors(nb, dim)
|
||||
collection_new = gen_unique_str()
|
||||
future = connect.add_vectors(collection_new, insert_vec_list, _async=True)
|
||||
status, result = future.result()
|
||||
assert not status.OK()
|
||||
|
||||
# TODO: add assertion
|
||||
def test_insert_async_invalid_params_raise_exception(self, connect, collection):
|
||||
'''
|
||||
target: test add vectors with different length of vectors
|
||||
method: set different vectors as add method params
|
||||
expected: length of ids is equal to the length of vectors
|
||||
'''
|
||||
insert_vec_list = []
|
||||
collection_new = gen_unique_str()
|
||||
with pytest.raises(Exception) as e:
|
||||
future = connect.add_vectors(collection_new, insert_vec_list, _async=True)
|
||||
|
||||
|
||||
class TestAddIP:
|
||||
"""
|
||||
******************************************************************
|
||||
|
|
|
@ -233,6 +233,44 @@ class TestFlushBase:
|
|||
assert res == 0
|
||||
|
||||
|
||||
class TestFlushAsync:
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `flush` function
|
||||
******************************************************************
|
||||
"""
|
||||
def check_status(self, status, result):
|
||||
logging.getLogger().info("In callback check status")
|
||||
assert status.OK()
|
||||
|
||||
def test_flush_empty_collection(self, connect, collection):
|
||||
'''
|
||||
method: flush collection with no vectors
|
||||
expected: status ok
|
||||
'''
|
||||
future = connect.flush([collection], _async=True)
|
||||
status = future.result()
|
||||
assert status.OK()
|
||||
|
||||
def test_flush_async(self, connect, collection):
|
||||
vectors = gen_vectors(nb, dim)
|
||||
status, ids = connect.add_vectors(collection, vectors)
|
||||
future = connect.flush([collection], _async=True)
|
||||
status = future.result()
|
||||
assert status.OK()
|
||||
|
||||
def test_flush_async(self, connect, collection):
|
||||
nb = 100000
|
||||
vectors = gen_vectors(nb, dim)
|
||||
connect.add_vectors(collection, vectors)
|
||||
logging.getLogger().info("before")
|
||||
future = connect.flush([collection], _async=True, _callback=self.check_status)
|
||||
logging.getLogger().info("after")
|
||||
future.done()
|
||||
status = future.result()
|
||||
assert status.OK()
|
||||
|
||||
|
||||
class TestCollectionNameInvalid(object):
|
||||
"""
|
||||
Test adding vectors with invalid collection names
|
||||
|
|
|
@ -1806,3 +1806,75 @@ class TestCreateIndexParamsInvalid(object):
|
|||
logging.getLogger().info(result)
|
||||
assert result._collection_name == collection
|
||||
assert result._index_type == IndexType.FLAT
|
||||
|
||||
class TestIndexAsync:
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `create_index` function
|
||||
******************************************************************
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_index()
|
||||
)
|
||||
def get_index(self, request, connect):
|
||||
if str(connect._cmd("mode")[1]) == "CPU":
|
||||
if request.param["index_type"] == IndexType.IVF_SQ8H:
|
||||
pytest.skip("sq8h not support in CPU mode")
|
||||
if str(connect._cmd("mode")[1]) == "GPU":
|
||||
if request.param["index_type"] == IndexType.IVF_PQ:
|
||||
pytest.skip("ivfpq not support in GPU mode")
|
||||
return request.param
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_simple_index()
|
||||
)
|
||||
def get_simple_index(self, request, connect):
|
||||
if str(connect._cmd("mode")[1]) == "CPU":
|
||||
if request.param["index_type"] == IndexType.IVF_SQ8H:
|
||||
pytest.skip("sq8h not support in CPU mode")
|
||||
if str(connect._cmd("mode")[1]) == "GPU":
|
||||
# if request.param["index_type"] == IndexType.IVF_PQ:
|
||||
if request.param["index_type"] not in [IndexType.IVF_FLAT]:
|
||||
# pytest.skip("ivfpq not support in GPU mode")
|
||||
pytest.skip("debug ivf_flat in GPU mode")
|
||||
return request.param
|
||||
|
||||
def check_status(self, status):
|
||||
logging.getLogger().info("In callback check status")
|
||||
assert status.OK()
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `create_index` function
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test create index interface
|
||||
method: create collection and add vectors in it, create index
|
||||
expected: return code equals to 0, and search success
|
||||
'''
|
||||
index_param = get_simple_index["index_param"]
|
||||
index_type = get_simple_index["index_type"]
|
||||
logging.getLogger().info(get_simple_index)
|
||||
vectors = gen_vectors(nb, dim)
|
||||
status, ids = connect.add_vectors(collection, vectors)
|
||||
logging.getLogger().info("start index")
|
||||
# future = connect.create_index(collection, index_type, index_param, _async=True, _callback=self.check_status)
|
||||
future = connect.create_index(collection, index_type, index_param, _async=True)
|
||||
logging.getLogger().info("before result")
|
||||
status = future.result()
|
||||
assert status.OK()
|
||||
|
||||
def test_create_index_with_invalid_collectionname(self, connect):
|
||||
collection_name = " "
|
||||
nlist = NLIST
|
||||
index_param = {"nlist": nlist}
|
||||
future = connect.create_index(collection_name, IndexType.IVF_SQ8, index_param, _async=True)
|
||||
status = future.result()
|
||||
assert not status.OK()
|
||||
|
||||
|
|
|
@ -228,6 +228,78 @@ class TestShowBase:
|
|||
assert status.OK()
|
||||
|
||||
|
||||
class TestHasBase:
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `has_partition` function
|
||||
******************************************************************
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_collection_names()
|
||||
)
|
||||
def get_tag_name(self, request):
|
||||
yield request.param
|
||||
|
||||
def test_has_partition(self, connect, collection):
|
||||
'''
|
||||
target: test has_partition, check status and result
|
||||
method: create partition first, then call function: has_partition
|
||||
expected: status ok, result true
|
||||
'''
|
||||
status = connect.create_partition(collection, tag)
|
||||
status, res = connect.has_partition(collection, tag)
|
||||
assert status.OK()
|
||||
logging.getLogger().info(res)
|
||||
assert res
|
||||
|
||||
def test_has_partition_multi_partitions(self, connect, collection):
|
||||
'''
|
||||
target: test has_partition, check status and result
|
||||
method: create partition first, then call function: has_partition
|
||||
expected: status ok, result true
|
||||
'''
|
||||
for tag_name in [tag, "tag_new", "tag_new_new"]:
|
||||
status = connect.create_partition(collection, tag_name)
|
||||
for tag_name in [tag, "tag_new", "tag_new_new"]:
|
||||
status, res = connect.has_partition(collection, tag_name)
|
||||
assert status.OK()
|
||||
assert res
|
||||
|
||||
def test_has_partition_tag_not_existed(self, connect, collection):
|
||||
'''
|
||||
target: test has_partition, check status and result
|
||||
method: then call function: has_partition, with tag not existed
|
||||
expected: status ok, result empty
|
||||
'''
|
||||
status, res = connect.has_partition(collection, tag)
|
||||
assert status.OK()
|
||||
logging.getLogger().info(res)
|
||||
assert not res
|
||||
|
||||
def test_has_partition_collection_not_existed(self, connect, collection):
|
||||
'''
|
||||
target: test has_partition, check status and result
|
||||
method: then call function: has_partition, with collection not existed
|
||||
expected: status not ok
|
||||
'''
|
||||
status, res = connect.has_partition("not_existed_collection", tag)
|
||||
assert not status.OK()
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_has_partition_with_invalid_tag_name(self, connect, collection, get_tag_name):
|
||||
'''
|
||||
target: test has partition, with invalid tag name, check status returned
|
||||
method: call function: has_partition
|
||||
expected: status ok
|
||||
'''
|
||||
tag_name = get_tag_name
|
||||
status = connect.create_partition(collection, tag)
|
||||
status, res = connect.has_partition(collection, tag_name)
|
||||
assert status.OK()
|
||||
|
||||
|
||||
class TestDropBase:
|
||||
|
||||
"""
|
||||
|
|
|
@ -29,12 +29,14 @@ raw_vectors, binary_vectors = gen_binary_vectors(6000, dim)
|
|||
|
||||
|
||||
class TestSearchBase:
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def skip_check(self, connect):
|
||||
if str(connect._cmd("mode")[1]) == "CPU" or str(connect._cmd("mode")[1]) == "GPU":
|
||||
reason = "GPU mode not support"
|
||||
logging.getLogger().info(reason)
|
||||
pytest.skip(reason)
|
||||
# @pytest.fixture(scope="function", autouse=True)
|
||||
# def skip_check(self, connect):
|
||||
# if str(connect._cmd("mode")[1]) == "CPU":
|
||||
# if request.param["index_type"] == IndexType.IVF_SQ8H:
|
||||
# pytest.skip("sq8h not support in CPU mode")
|
||||
# if str(connect._cmd("mode")[1]) == "GPU":
|
||||
# if request.param["index_type"] == IndexType.IVF_PQ:
|
||||
# pytest.skip("ivfpq not support in GPU mode")
|
||||
|
||||
def init_data(self, connect, collection, nb=6000):
|
||||
'''
|
||||
|
@ -82,16 +84,6 @@ class TestSearchBase:
|
|||
connect.flush([collection])
|
||||
return add_vectors, ids
|
||||
|
||||
def check_no_result(self, results):
|
||||
if len(results) == 0:
|
||||
return True
|
||||
flag = True
|
||||
for r in results:
|
||||
flag = flag and (r.id == -1)
|
||||
if not flag:
|
||||
return False
|
||||
return flag
|
||||
|
||||
def init_data_partition(self, connect, collection, partition_tag, nb=6000):
|
||||
'''
|
||||
Generate vectors and add it in collection, before search vectors
|
||||
|
@ -104,6 +96,7 @@ class TestSearchBase:
|
|||
add_vectors = sklearn.preprocessing.normalize(add_vectors, axis=1, norm='l2')
|
||||
add_vectors = add_vectors.tolist()
|
||||
status, ids = connect.add_vectors(collection, add_vectors, partition_tag=partition_tag)
|
||||
assert status.OK()
|
||||
connect.flush([collection])
|
||||
return add_vectors, ids
|
||||
|
||||
|
@ -178,6 +171,22 @@ class TestSearchBase:
|
|||
assert result[0][0].distance <= epsilon
|
||||
assert check_result(result[0], ids[0])
|
||||
|
||||
def test_search_flat_same_ids(self, connect, collection):
|
||||
'''
|
||||
target: test basic search fuction, all the search params is corrent, change top-k value
|
||||
method: search with the given vector id, check the result
|
||||
expected: search status ok, and the length of the result is top_k
|
||||
'''
|
||||
vectors, ids = self.init_data(connect, collection)
|
||||
query_ids = [ids[0], ids[0]]
|
||||
status, result = connect.search_by_ids(collection, query_ids, top_k, params={})
|
||||
assert status.OK()
|
||||
assert len(result[0]) == min(len(vectors), top_k)
|
||||
assert result[0][0].distance <= epsilon
|
||||
assert result[1][0].distance <= epsilon
|
||||
assert check_result(result[0], ids[0])
|
||||
assert check_result(result[1], ids[0])
|
||||
|
||||
def test_search_flat_max_topk(self, connect, collection):
|
||||
'''
|
||||
target: test basic search fuction, all the search params is corrent, change top-k value
|
||||
|
@ -186,7 +195,7 @@ class TestSearchBase:
|
|||
'''
|
||||
top_k = 2049
|
||||
vectors, ids = self.init_data(connect, collection)
|
||||
query_ids = ids[0]
|
||||
query_ids = [ids[0]]
|
||||
status, result = connect.search_by_ids(collection, query_ids, top_k, params={})
|
||||
assert not status.OK()
|
||||
|
||||
|
@ -200,7 +209,7 @@ class TestSearchBase:
|
|||
query_ids = non_exist_id
|
||||
status, result = connect.search_by_ids(collection, query_ids, top_k, params={})
|
||||
assert status.OK()
|
||||
assert len(result[0]) == min(len(vectors), top_k)
|
||||
assert len(result[0]) == 0
|
||||
|
||||
def test_search_collection_empty(self, connect, collection):
|
||||
'''
|
||||
|
@ -209,9 +218,11 @@ class TestSearchBase:
|
|||
expected: search status ok, and the length of the result is top_k
|
||||
'''
|
||||
query_ids = non_exist_id
|
||||
logging.getLogger().info(query_ids)
|
||||
logging.getLogger().info(collection)
|
||||
logging.getLogger().info(connect.describe_collection(collection))
|
||||
status, result = connect.search_by_ids(collection, query_ids, top_k, params={})
|
||||
assert status.OK()
|
||||
assert len(result) == 0
|
||||
assert not status.OK()
|
||||
|
||||
def test_search_index_l2(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
|
@ -221,6 +232,8 @@ class TestSearchBase:
|
|||
'''
|
||||
index_param = get_simple_index["index_param"]
|
||||
index_type = get_simple_index["index_type"]
|
||||
if index_type == IndexType.IVF_PQ:
|
||||
pytest.skip("skip pq")
|
||||
vectors, ids = self.init_data(connect, collection)
|
||||
status = connect.create_index(collection, index_type, index_param)
|
||||
query_ids = [ids[0]]
|
||||
|
@ -239,6 +252,8 @@ class TestSearchBase:
|
|||
'''
|
||||
index_param = get_simple_index["index_param"]
|
||||
index_type = get_simple_index["index_type"]
|
||||
if index_type == IndexType.IVF_PQ:
|
||||
pytest.skip("skip pq")
|
||||
vectors, ids = self.init_data(connect, collection)
|
||||
status = connect.create_index(collection, index_type, index_param)
|
||||
query_ids = ids[0:nq]
|
||||
|
@ -246,7 +261,7 @@ class TestSearchBase:
|
|||
status, result = connect.search_by_ids(collection, query_ids, top_k, params=search_param)
|
||||
assert status.OK()
|
||||
assert len(result) == nq
|
||||
for i in nq:
|
||||
for i in range(nq):
|
||||
assert len(result[i]) == min(len(vectors), top_k)
|
||||
assert result[i][0].distance <= epsilon
|
||||
assert check_result(result[i], ids[i])
|
||||
|
@ -259,17 +274,19 @@ class TestSearchBase:
|
|||
'''
|
||||
index_param = get_simple_index["index_param"]
|
||||
index_type = get_simple_index["index_type"]
|
||||
if index_type == IndexType.IVF_PQ:
|
||||
pytest.skip("skip pq")
|
||||
vectors, ids = self.init_data(connect, collection)
|
||||
status = connect.create_index(collection, index_type, index_param)
|
||||
query_ids = ids[0:nq]
|
||||
query_ids[0] = non_exist_id
|
||||
query_ids[0] = 1
|
||||
search_param = get_search_param(index_type)
|
||||
status, result = connect.search_by_ids(collection, [query_ids], top_k, params=search_param)
|
||||
status, result = connect.search_by_ids(collection, query_ids, top_k, params=search_param)
|
||||
assert status.OK()
|
||||
assert len(result) == nq
|
||||
for i in nq:
|
||||
for i in range(nq):
|
||||
if i == 0:
|
||||
assert result[i].id == -1
|
||||
assert len(result[i]) == 0
|
||||
else:
|
||||
assert len(result[i]) == min(len(vectors), top_k)
|
||||
assert result[i][0].distance <= epsilon
|
||||
|
@ -277,15 +294,16 @@ class TestSearchBase:
|
|||
|
||||
def test_search_index_delete(self, connect, collection):
|
||||
vectors, ids = self.init_data(connect, collection)
|
||||
query_ids = ids[0]
|
||||
status = connect.delete_by_id(collection, [query_ids])
|
||||
query_ids = ids[0:nq]
|
||||
status = connect.delete_by_id(collection, [query_ids[0]])
|
||||
assert status.OK()
|
||||
status = connect.flush(collection)
|
||||
status, result = connect.search_by_ids(collection, [query_ids], top_k, params={})
|
||||
status = connect.flush([collection])
|
||||
status, result = connect.search_by_ids(collection, query_ids, top_k, params={})
|
||||
assert status.OK()
|
||||
assert len(result) == 1
|
||||
assert result[0][0].distance <= epsilon
|
||||
assert result[0][0].id != ids[0]
|
||||
assert len(result) == nq
|
||||
assert len(result[0]) == 0
|
||||
assert len(result[1]) == top_k
|
||||
assert result[1][0].distance <= epsilon
|
||||
|
||||
def test_search_l2_partition_tag_not_existed(self, connect, collection):
|
||||
'''
|
||||
|
@ -295,28 +313,31 @@ class TestSearchBase:
|
|||
'''
|
||||
status = connect.create_partition(collection, tag)
|
||||
vectors, ids = self.init_data(connect, collection)
|
||||
query_ids = ids[0]
|
||||
status, result = connect.search_by_ids(collection, query_ids, top_k, partition_tags=[tag], params=search_param)
|
||||
assert status.OK()
|
||||
query_ids = [ids[0]]
|
||||
new_tag = gen_unique_str()
|
||||
status, result = connect.search_by_ids(collection, query_ids, top_k, partition_tags=[new_tag], params={})
|
||||
assert not status.OK()
|
||||
logging.getLogger().info(status)
|
||||
assert len(result) == 0
|
||||
|
||||
def test_search_l2_partition_other(self, connect, collection):
|
||||
tag = gen_unique_str()
|
||||
def test_search_l2_partition_empty(self, connect, collection):
|
||||
status = connect.create_partition(collection, tag)
|
||||
vectors, ids = self.init_data(connect, collection)
|
||||
query_ids = ids[0]
|
||||
status, result = connect.search_by_ids(collection, query_ids, top_k, partition_tags=[tag], params=search_param)
|
||||
assert status.OK()
|
||||
query_ids = [ids[0]]
|
||||
status, result = connect.search_by_ids(collection, query_ids, top_k, partition_tags=[tag], params={})
|
||||
assert not status.OK()
|
||||
logging.getLogger().info(status)
|
||||
assert len(result) == 0
|
||||
|
||||
def test_search_l2_partition(self, connect, collection):
|
||||
status = connect.create_partition(collection, tag)
|
||||
vectors, ids = self.init_data_partition(connect, collection, tag)
|
||||
query_ids = ids[-1]
|
||||
query_ids = ids[-1:]
|
||||
status, result = connect.search_by_ids(collection, query_ids, top_k, partition_tags=[tag])
|
||||
assert status.OK()
|
||||
assert len(result) == 1
|
||||
assert len(result[0]) == min(len(vectors), top_k)
|
||||
assert check_result(result[0], query_ids)
|
||||
assert check_result(result[0], query_ids[-1])
|
||||
|
||||
def test_search_l2_partition_B(self, connect, collection):
|
||||
status = connect.create_partition(collection, tag)
|
||||
|
@ -325,7 +346,7 @@ class TestSearchBase:
|
|||
status, result = connect.search_by_ids(collection, query_ids, top_k, partition_tags=[tag])
|
||||
assert status.OK()
|
||||
assert len(result) == nq
|
||||
for i in nq:
|
||||
for i in range(nq):
|
||||
assert len(result[i]) == min(len(vectors), top_k)
|
||||
assert result[i][0].distance <= epsilon
|
||||
assert check_result(result[i], ids[i])
|
||||
|
@ -338,14 +359,17 @@ class TestSearchBase:
|
|||
vectors, new_ids = self.init_data_partition(connect, collection, new_tag, nb=nb+1)
|
||||
tmp = 2
|
||||
query_ids = ids[0:tmp]
|
||||
query_ids.extend(new_ids[0:nq-tmp])
|
||||
query_ids.extend(new_ids[tmp:nq])
|
||||
status, result = connect.search_by_ids(collection, query_ids, top_k, partition_tags=[tag, new_tag], params={})
|
||||
assert status.OK()
|
||||
assert len(result) == nq
|
||||
for i in nq:
|
||||
for i in range(nq):
|
||||
assert len(result[i]) == min(len(vectors), top_k)
|
||||
assert result[i][0].distance <= epsilon
|
||||
assert check_result(result[i], ids[i])
|
||||
if i < tmp:
|
||||
assert result[i][0].id == ids[i]
|
||||
else:
|
||||
assert result[i][0].id == new_ids[i]
|
||||
|
||||
def test_search_l2_index_partitions_match_one_tag(self, connect, collection):
|
||||
new_tag = "new_tag"
|
||||
|
@ -355,18 +379,19 @@ class TestSearchBase:
|
|||
vectors, new_ids = self.init_data_partition(connect, collection, new_tag, nb=nb+1)
|
||||
tmp = 2
|
||||
query_ids = ids[0:tmp]
|
||||
query_ids.extend(new_ids[0:nq-tmp])
|
||||
query_ids.extend(new_ids[tmp:nq])
|
||||
status, result = connect.search_by_ids(collection, query_ids, top_k, partition_tags=[new_tag], params={})
|
||||
assert status.OK()
|
||||
assert len(result) == nq
|
||||
for i in nq:
|
||||
for i in range(nq):
|
||||
if i < tmp:
|
||||
assert result[i][0].distance > epsilon
|
||||
assert result[i][0].id != ids[i]
|
||||
else:
|
||||
assert len(result[i]) == min(len(vectors), top_k)
|
||||
assert result[i][0].distance <= epsilon
|
||||
assert check_result(result[i], ids[i])
|
||||
assert result[i][0].id == new_ids[i]
|
||||
assert result[i][1].distance > epsilon
|
||||
|
||||
# def test_search_by_ids_without_connect(self, dis_connect, collection):
|
||||
# '''
|
||||
|
@ -411,7 +436,7 @@ class TestSearchBase:
|
|||
status, result = connect.search_by_ids(jac_collection, query_ids, top_k, params=search_param)
|
||||
assert status.OK()
|
||||
assert len(result) == nq
|
||||
for i in nq:
|
||||
for i in range(nq):
|
||||
assert len(result[i]) == min(len(vectors), top_k)
|
||||
assert result[i][0].distance <= epsilon
|
||||
assert check_result(result[i], ids[i])
|
||||
|
@ -499,7 +524,7 @@ class TestSearchParamsInvalid(object):
|
|||
|
||||
|
||||
def check_result(result, id):
|
||||
if len(result) >= 5:
|
||||
return id in [x.id for x in result[:5]]
|
||||
if len(result) >= top_k:
|
||||
return id in [x.id for x in result[:top_k]]
|
||||
else:
|
||||
return id in (i.id for i in result)
|
||||
|
|
|
@ -666,7 +666,7 @@ class TestSearchBase:
|
|||
status, result = connect.search_vectors(substructure_collection, top_k, query_vecs, params=search_param)
|
||||
logging.getLogger().info(status)
|
||||
logging.getLogger().info(result)
|
||||
assert result[0][0].id == -1
|
||||
assert len(result[0]) == 0
|
||||
|
||||
def test_search_distance_substructure_flat_index_B(self, connect, substructure_collection):
|
||||
'''
|
||||
|
@ -690,12 +690,12 @@ class TestSearchBase:
|
|||
status, result = connect.search_vectors(substructure_collection, top_k, query_vecs, params=search_param)
|
||||
logging.getLogger().info(status)
|
||||
logging.getLogger().info(result)
|
||||
assert len(result[0]) == 1
|
||||
assert len(result[1]) == 1
|
||||
assert result[0][0].distance <= epsilon
|
||||
assert result[0][0].id == ids[0]
|
||||
assert result[1][0].distance <= epsilon
|
||||
assert result[1][0].id == ids[1]
|
||||
assert result[0][1].id == -1
|
||||
assert result[1][1].id == -1
|
||||
|
||||
def test_search_distance_superstructure_flat_index(self, connect, superstructure_collection):
|
||||
'''
|
||||
|
@ -720,7 +720,7 @@ class TestSearchBase:
|
|||
status, result = connect.search_vectors(superstructure_collection, top_k, query_vecs, params=search_param)
|
||||
logging.getLogger().info(status)
|
||||
logging.getLogger().info(result)
|
||||
assert result[0][0].id == -1
|
||||
assert len(result[0]) == 0
|
||||
|
||||
def test_search_distance_superstructure_flat_index_B(self, connect, superstructure_collection):
|
||||
'''
|
||||
|
@ -744,12 +744,12 @@ class TestSearchBase:
|
|||
status, result = connect.search_vectors(superstructure_collection, top_k, query_vecs, params=search_param)
|
||||
logging.getLogger().info(status)
|
||||
logging.getLogger().info(result)
|
||||
assert len(result[0]) == 2
|
||||
assert len(result[1]) == 2
|
||||
assert result[0][0].id in ids
|
||||
assert result[0][0].distance <= epsilon
|
||||
assert result[1][0].id in ids
|
||||
assert result[1][0].distance <= epsilon
|
||||
assert result[0][2].id == -1
|
||||
assert result[1][2].id == -1
|
||||
|
||||
def test_search_distance_tanimoto_flat_index(self, connect, tanimoto_collection):
|
||||
'''
|
||||
|
|
Loading…
Reference in New Issue