mirror of https://github.com/milvus-io/milvus.git
add case nprobe>2048 (#2906)
parent
f31a81ab16
commit
bd02b19a71
|
@ -34,7 +34,7 @@ class TestSearchBase:
|
|||
global vectors
|
||||
if nb == 6000:
|
||||
add_vectors = vectors
|
||||
else:
|
||||
else:
|
||||
add_vectors = gen_vectors(nb, dim)
|
||||
add_vectors = sklearn.preprocessing.normalize(add_vectors, axis=1, norm='l2')
|
||||
add_vectors = add_vectors.tolist()
|
||||
|
@ -57,7 +57,7 @@ class TestSearchBase:
|
|||
if nb == 6000:
|
||||
add_vectors = binary_vectors
|
||||
add_raw_vectors = raw_vectors
|
||||
else:
|
||||
else:
|
||||
add_raw_vectors, add_vectors = gen_binary_vectors(nb, dim)
|
||||
if insert is True:
|
||||
if partition_tags is None:
|
||||
|
@ -72,6 +72,7 @@ class TestSearchBase:
|
|||
"""
|
||||
generate valid create_index params
|
||||
"""
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_index()
|
||||
|
@ -128,6 +129,7 @@ class TestSearchBase:
|
|||
"""
|
||||
generate top-k params
|
||||
"""
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=[1, 99, 1024, 2049]
|
||||
|
@ -135,7 +137,6 @@ class TestSearchBase:
|
|||
def get_top_k(self, request):
|
||||
yield request.param
|
||||
|
||||
|
||||
def test_search_top_k_flat_index(self, connect, collection, get_top_k):
|
||||
'''
|
||||
target: test basic search fuction, all the search params is corrent, change top-k value
|
||||
|
@ -301,7 +302,8 @@ class TestSearchBase:
|
|||
query_vec = [vectors[0]]
|
||||
top_k = 10
|
||||
search_param = get_search_param(index_type)
|
||||
status, result = connect.search(collection, top_k, query_vec, partition_tags=[tag, "new_tag"], params=search_param)
|
||||
status, result = connect.search(collection, top_k, query_vec, partition_tags=[tag, "new_tag"],
|
||||
params=search_param)
|
||||
logging.getLogger().info(result)
|
||||
assert status.OK()
|
||||
assert len(result[0]) == min(len(vectors), top_k)
|
||||
|
@ -349,7 +351,8 @@ class TestSearchBase:
|
|||
status = connect.create_index(collection, index_type, index_param)
|
||||
query_vec = [vectors[0], new_vectors[0]]
|
||||
search_param = get_search_param(index_type)
|
||||
status, result = connect.search(collection, top_k, query_vec, partition_tags=[tag, new_tag], params=search_param)
|
||||
status, result = connect.search(collection, top_k, query_vec, partition_tags=[tag, new_tag],
|
||||
params=search_param)
|
||||
logging.getLogger().info(result)
|
||||
assert status.OK()
|
||||
assert len(result[0]) == min(len(vectors), top_k)
|
||||
|
@ -432,7 +435,7 @@ class TestSearchBase:
|
|||
vectors, ids = self.init_data(connect, ip_collection)
|
||||
status = connect.create_index(ip_collection, index_type, index_param)
|
||||
query_vec = []
|
||||
for i in range (1200):
|
||||
for i in range(1200):
|
||||
query_vec.append(vectors[i])
|
||||
top_k = 10
|
||||
search_param = get_search_param(index_type)
|
||||
|
@ -532,7 +535,7 @@ class TestSearchBase:
|
|||
collection_name = None
|
||||
nprobe = 1
|
||||
query_vecs = [vectors[0]]
|
||||
with pytest.raises(Exception) as e:
|
||||
with pytest.raises(Exception) as e:
|
||||
status, result = connect.search(collection_name, top_k, query_vecs)
|
||||
|
||||
def test_search_top_k_query_records(self, connect, collection):
|
||||
|
@ -543,7 +546,7 @@ class TestSearchBase:
|
|||
'''
|
||||
top_k = 10
|
||||
vectors, ids = self.init_data(connect, collection)
|
||||
query_vecs = [vectors[0],vectors[55],vectors[99]]
|
||||
query_vecs = [vectors[0], vectors[55], vectors[99]]
|
||||
status, result = connect.search(collection, top_k, query_vecs)
|
||||
assert status.OK()
|
||||
assert len(result) == len(query_vecs)
|
||||
|
@ -563,7 +566,8 @@ class TestSearchBase:
|
|||
distance_0 = numpy.linalg.norm(numpy.array(query_vecs[0]) - numpy.array(vectors[0]))
|
||||
distance_1 = numpy.linalg.norm(numpy.array(query_vecs[0]) - numpy.array(vectors[1]))
|
||||
status, result = connect.search(collection, top_k, query_vecs)
|
||||
assert abs(numpy.sqrt(result[0][0].distance) - min(distance_0, distance_1)) <= gen_inaccuracy(result[0][0].distance)
|
||||
assert abs(numpy.sqrt(result[0][0].distance) - min(distance_0, distance_1)) <= gen_inaccuracy(
|
||||
result[0][0].distance)
|
||||
|
||||
def test_search_distance_ip_flat_index(self, connect, ip_collection):
|
||||
'''
|
||||
|
@ -653,7 +657,8 @@ class TestSearchBase:
|
|||
connect.create_index(substructure_collection, index_type, index_param)
|
||||
logging.getLogger().info(connect.get_collection_info(substructure_collection))
|
||||
logging.getLogger().info(connect.get_index_info(substructure_collection))
|
||||
query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, substructure_collection, nb=1, insert=False)
|
||||
query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, substructure_collection, nb=1,
|
||||
insert=False)
|
||||
distance_0 = substructure(query_int_vectors[0], int_vectors[0])
|
||||
distance_1 = substructure(query_int_vectors[0], int_vectors[1])
|
||||
search_param = get_search_param(index_type)
|
||||
|
@ -683,7 +688,7 @@ class TestSearchBase:
|
|||
search_param = get_search_param(index_type)
|
||||
status, result = connect.search(substructure_collection, top_k, query_vecs, params=search_param)
|
||||
logging.getLogger().info(status)
|
||||
logging.getLogger().info(result)
|
||||
logging.getLogger().info(result)
|
||||
assert len(result[0]) == 1
|
||||
assert len(result[1]) == 1
|
||||
assert result[0][0].distance <= epsilon
|
||||
|
@ -707,7 +712,8 @@ class TestSearchBase:
|
|||
connect.create_index(superstructure_collection, index_type, index_param)
|
||||
logging.getLogger().info(connect.get_collection_info(superstructure_collection))
|
||||
logging.getLogger().info(connect.get_index_info(superstructure_collection))
|
||||
query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, superstructure_collection, nb=1, insert=False)
|
||||
query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, superstructure_collection, nb=1,
|
||||
insert=False)
|
||||
distance_0 = superstructure(query_int_vectors[0], int_vectors[0])
|
||||
distance_1 = superstructure(query_int_vectors[0], int_vectors[1])
|
||||
search_param = get_search_param(index_type)
|
||||
|
@ -843,7 +849,8 @@ class TestSearchBase:
|
|||
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
||||
milvus.create_collection(param)
|
||||
vectors, ids = self.init_data(milvus, collection, nb=nb)
|
||||
query_vecs = vectors[nb//2:nb]
|
||||
query_vecs = vectors[nb // 2:nb]
|
||||
|
||||
def search(milvus):
|
||||
status, result = milvus.search(collection, top_k, query_vecs)
|
||||
assert len(result) == len(query_vecs)
|
||||
|
@ -853,7 +860,7 @@ class TestSearchBase:
|
|||
|
||||
for i in range(threads_num):
|
||||
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
||||
t = threading.Thread(target=search, args=(milvus, ))
|
||||
t = threading.Thread(target=search, args=(milvus,))
|
||||
threads.append(t)
|
||||
t.start()
|
||||
time.sleep(0.2)
|
||||
|
@ -875,14 +882,15 @@ class TestSearchBase:
|
|||
collection = gen_unique_str("test_search_concurrent_multiprocessing")
|
||||
uri = "tcp://%s:%s" % (args["ip"], args["port"])
|
||||
param = {'collection_name': collection,
|
||||
'dimension': dim,
|
||||
'index_type': IndexType.FLAT,
|
||||
'store_raw_vector': False}
|
||||
'dimension': dim,
|
||||
'index_type': IndexType.FLAT,
|
||||
'store_raw_vector': False}
|
||||
# create collection
|
||||
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
||||
milvus.create_collection(param)
|
||||
vectors, ids = self.init_data(milvus, collection, nb=nb)
|
||||
query_vecs = vectors[nb//2:nb]
|
||||
query_vecs = vectors[nb // 2:nb]
|
||||
|
||||
def search(milvus):
|
||||
status, result = milvus.search(collection, top_k, query_vecs)
|
||||
assert len(result) == len(query_vecs)
|
||||
|
@ -892,7 +900,7 @@ class TestSearchBase:
|
|||
|
||||
for i in range(process_num):
|
||||
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
||||
p = Process(target=search, args=(milvus, ))
|
||||
p = Process(target=search, args=(milvus,))
|
||||
processes.append(p)
|
||||
p.start()
|
||||
time.sleep(0.2)
|
||||
|
@ -978,6 +986,8 @@ class TestSearchBase:
|
|||
assert len(result[j]) == top_k
|
||||
for j in range(len(query_vecs)):
|
||||
assert check_result(result[j], idx[3 * i + j])
|
||||
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
# The following cases are used to test `search_vectors` function
|
||||
|
@ -985,6 +995,7 @@ class TestSearchBase:
|
|||
******************************************************************
|
||||
"""
|
||||
|
||||
|
||||
class TestSearchParamsInvalid(object):
|
||||
nlist = 16384
|
||||
index_type = IndexType.IVF_SQ8
|
||||
|
@ -998,15 +1009,16 @@ class TestSearchParamsInvalid(object):
|
|||
global vectors
|
||||
if nb == 6000:
|
||||
insert = vectors
|
||||
else:
|
||||
else:
|
||||
insert = gen_vectors(nb, dim)
|
||||
status, ids = connect.insert(collection, insert)
|
||||
sleep(add_interval_time)
|
||||
connect.flush([collection])
|
||||
return insert, ids
|
||||
|
||||
"""
|
||||
Test search collection with invalid collection names
|
||||
"""
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_collection_names()
|
||||
|
@ -1018,14 +1030,14 @@ class TestSearchParamsInvalid(object):
|
|||
def test_search_with_invalid_collectionname(self, connect, get_collection_name):
|
||||
collection_name = get_collection_name
|
||||
logging.getLogger().info(collection_name)
|
||||
nprobe = 1
|
||||
nprobe = 1
|
||||
query_vecs = gen_vectors(1, dim)
|
||||
status, result = connect.search(collection_name, top_k, query_vecs)
|
||||
assert not status.OK()
|
||||
|
||||
@pytest.mark.level(1)
|
||||
def test_search_with_invalid_tag_format(self, connect, collection):
|
||||
nprobe = 1
|
||||
nprobe = 1
|
||||
query_vecs = gen_vectors(1, dim)
|
||||
with pytest.raises(Exception) as e:
|
||||
status, result = connect.search(collection, top_k, query_vecs, partition_tags="tag")
|
||||
|
@ -1042,6 +1054,7 @@ class TestSearchParamsInvalid(object):
|
|||
"""
|
||||
Test search collection with invalid top-k
|
||||
"""
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_top_ks()
|
||||
|
@ -1084,9 +1097,11 @@ class TestSearchParamsInvalid(object):
|
|||
else:
|
||||
with pytest.raises(Exception) as e:
|
||||
status, result = connect.search(ip_collection, top_k, query_vecs)
|
||||
|
||||
"""
|
||||
Test search collection with invalid nprobe
|
||||
"""
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_nprobes()
|
||||
|
@ -1137,6 +1152,26 @@ class TestSearchParamsInvalid(object):
|
|||
# with pytest.raises(Exception) as e:
|
||||
# status, result = connect.search(ip_collection, top_k, query_vecs, params=search_param)
|
||||
|
||||
def test_search_with_2049_nprobe(self, connect, collection):
|
||||
'''
|
||||
target: test search function, with 2049 nprobe in GPU mode
|
||||
method: search with nprobe
|
||||
expected: status not ok
|
||||
'''
|
||||
if str(connect._cmd("mode")[1]) == "CPU":
|
||||
pytest.skip("Only support GPU mode")
|
||||
for index in gen_simple_index():
|
||||
if index["index_type"] in [IndexType.IVF_PQ, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H]:
|
||||
index_type = index["index_type"]
|
||||
index_param = index["index_param"]
|
||||
self.init_data(connect, collection)
|
||||
connect.create_index(collection, index_type, index_param)
|
||||
nprobe = 2049
|
||||
search_param = {"nprobe": nprobe}
|
||||
query_vecs = gen_vectors(nprobe, dim)
|
||||
status, result = connect.search(collection, top_k, query_vecs, params=search_param)
|
||||
assert not status.OK()
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_simple_index()
|
||||
|
@ -1197,6 +1232,7 @@ class TestSearchParamsInvalid(object):
|
|||
status, result = connect.search(collection, top_k, query_vecs, params=search_param)
|
||||
assert not status.OK()
|
||||
|
||||
|
||||
def check_result(result, id):
|
||||
if len(result) >= 5:
|
||||
return id in [result[0].id, result[1].id, result[2].id, result[3].id, result[4].id]
|
||||
|
|
Loading…
Reference in New Issue