add case nprobe>2048 (#2906)

pull/2933/head
ThreadDao 2020-07-18 19:31:28 +08:00 committed by GitHub
parent f31a81ab16
commit bd02b19a71
1 changed file with 59 additions and 23 deletions


@ -34,7 +34,7 @@ class TestSearchBase:
global vectors
if nb == 6000:
add_vectors = vectors
else:
add_vectors = gen_vectors(nb, dim)
add_vectors = sklearn.preprocessing.normalize(add_vectors, axis=1, norm='l2')
add_vectors = add_vectors.tolist()
@ -57,7 +57,7 @@ class TestSearchBase:
if nb == 6000:
add_vectors = binary_vectors
add_raw_vectors = raw_vectors
else:
add_raw_vectors, add_vectors = gen_binary_vectors(nb, dim)
if insert is True:
if partition_tags is None:
@ -72,6 +72,7 @@ class TestSearchBase:
"""
generate valid create_index params
"""
@pytest.fixture(
scope="function",
params=gen_index()
@ -128,6 +129,7 @@ class TestSearchBase:
"""
generate top-k params
"""
@pytest.fixture(
scope="function",
params=[1, 99, 1024, 2049]
@ -135,7 +137,6 @@ class TestSearchBase:
def get_top_k(self, request):
yield request.param
def test_search_top_k_flat_index(self, connect, collection, get_top_k):
'''
target: test basic search function, all the search params are correct, change top-k value
@ -301,7 +302,8 @@ class TestSearchBase:
query_vec = [vectors[0]]
top_k = 10
search_param = get_search_param(index_type)
status, result = connect.search(collection, top_k, query_vec, partition_tags=[tag, "new_tag"],
params=search_param)
logging.getLogger().info(result)
assert status.OK()
assert len(result[0]) == min(len(vectors), top_k)
@ -349,7 +351,8 @@ class TestSearchBase:
status = connect.create_index(collection, index_type, index_param)
query_vec = [vectors[0], new_vectors[0]]
search_param = get_search_param(index_type)
status, result = connect.search(collection, top_k, query_vec, partition_tags=[tag, new_tag],
params=search_param)
logging.getLogger().info(result)
assert status.OK()
assert len(result[0]) == min(len(vectors), top_k)
@ -432,7 +435,7 @@ class TestSearchBase:
vectors, ids = self.init_data(connect, ip_collection)
status = connect.create_index(ip_collection, index_type, index_param)
query_vec = []
for i in range(1200):
query_vec.append(vectors[i])
top_k = 10
search_param = get_search_param(index_type)
@ -532,7 +535,7 @@ class TestSearchBase:
collection_name = None
nprobe = 1
query_vecs = [vectors[0]]
with pytest.raises(Exception) as e:
status, result = connect.search(collection_name, top_k, query_vecs)
def test_search_top_k_query_records(self, connect, collection):
@ -543,7 +546,7 @@ class TestSearchBase:
'''
top_k = 10
vectors, ids = self.init_data(connect, collection)
query_vecs = [vectors[0], vectors[55], vectors[99]]
status, result = connect.search(collection, top_k, query_vecs)
assert status.OK()
assert len(result) == len(query_vecs)
@ -563,7 +566,8 @@ class TestSearchBase:
distance_0 = numpy.linalg.norm(numpy.array(query_vecs[0]) - numpy.array(vectors[0]))
distance_1 = numpy.linalg.norm(numpy.array(query_vecs[0]) - numpy.array(vectors[1]))
status, result = connect.search(collection, top_k, query_vecs)
assert abs(numpy.sqrt(result[0][0].distance) - min(distance_0, distance_1)) <= gen_inaccuracy(
result[0][0].distance)
def test_search_distance_ip_flat_index(self, connect, ip_collection):
'''
@ -653,7 +657,8 @@ class TestSearchBase:
connect.create_index(substructure_collection, index_type, index_param)
logging.getLogger().info(connect.get_collection_info(substructure_collection))
logging.getLogger().info(connect.get_index_info(substructure_collection))
query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, substructure_collection, nb=1,
insert=False)
distance_0 = substructure(query_int_vectors[0], int_vectors[0])
distance_1 = substructure(query_int_vectors[0], int_vectors[1])
search_param = get_search_param(index_type)
@ -683,7 +688,7 @@ class TestSearchBase:
search_param = get_search_param(index_type)
status, result = connect.search(substructure_collection, top_k, query_vecs, params=search_param)
logging.getLogger().info(status)
logging.getLogger().info(result)
assert len(result[0]) == 1
assert len(result[1]) == 1
assert result[0][0].distance <= epsilon
@ -707,7 +712,8 @@ class TestSearchBase:
connect.create_index(superstructure_collection, index_type, index_param)
logging.getLogger().info(connect.get_collection_info(superstructure_collection))
logging.getLogger().info(connect.get_index_info(superstructure_collection))
query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, superstructure_collection, nb=1,
insert=False)
distance_0 = superstructure(query_int_vectors[0], int_vectors[0])
distance_1 = superstructure(query_int_vectors[0], int_vectors[1])
search_param = get_search_param(index_type)
@ -843,7 +849,8 @@ class TestSearchBase:
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
milvus.create_collection(param)
vectors, ids = self.init_data(milvus, collection, nb=nb)
query_vecs = vectors[nb // 2:nb]
def search(milvus):
status, result = milvus.search(collection, top_k, query_vecs)
assert len(result) == len(query_vecs)
@ -853,7 +860,7 @@ class TestSearchBase:
for i in range(threads_num):
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
t = threading.Thread(target=search, args=(milvus,))
threads.append(t)
t.start()
time.sleep(0.2)
@ -875,14 +882,15 @@ class TestSearchBase:
collection = gen_unique_str("test_search_concurrent_multiprocessing")
uri = "tcp://%s:%s" % (args["ip"], args["port"])
param = {'collection_name': collection,
'dimension': dim,
'index_type': IndexType.FLAT,
'store_raw_vector': False}
# create collection
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
milvus.create_collection(param)
vectors, ids = self.init_data(milvus, collection, nb=nb)
query_vecs = vectors[nb // 2:nb]
def search(milvus):
status, result = milvus.search(collection, top_k, query_vecs)
assert len(result) == len(query_vecs)
@ -892,7 +900,7 @@ class TestSearchBase:
for i in range(process_num):
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
p = Process(target=search, args=(milvus,))
processes.append(p)
p.start()
time.sleep(0.2)
@ -978,6 +986,8 @@ class TestSearchBase:
assert len(result[j]) == top_k
for j in range(len(query_vecs)):
assert check_result(result[j], idx[3 * i + j])
"""
******************************************************************
# The following cases are used to test `search_vectors` function
@ -985,6 +995,7 @@ class TestSearchBase:
******************************************************************
"""
class TestSearchParamsInvalid(object):
nlist = 16384
index_type = IndexType.IVF_SQ8
@ -998,15 +1009,16 @@ class TestSearchParamsInvalid(object):
global vectors
if nb == 6000:
insert = vectors
else:
insert = gen_vectors(nb, dim)
status, ids = connect.insert(collection, insert)
sleep(add_interval_time)
connect.flush([collection])
return insert, ids
"""
Test search collection with invalid collection names
"""
@pytest.fixture(
scope="function",
params=gen_invalid_collection_names()
@ -1018,14 +1030,14 @@ class TestSearchParamsInvalid(object):
def test_search_with_invalid_collectionname(self, connect, get_collection_name):
collection_name = get_collection_name
logging.getLogger().info(collection_name)
nprobe = 1
query_vecs = gen_vectors(1, dim)
status, result = connect.search(collection_name, top_k, query_vecs)
assert not status.OK()
@pytest.mark.level(1)
def test_search_with_invalid_tag_format(self, connect, collection):
nprobe = 1
query_vecs = gen_vectors(1, dim)
with pytest.raises(Exception) as e:
status, result = connect.search(collection, top_k, query_vecs, partition_tags="tag")
@ -1042,6 +1054,7 @@ class TestSearchParamsInvalid(object):
"""
Test search collection with invalid top-k
"""
@pytest.fixture(
scope="function",
params=gen_invalid_top_ks()
@ -1084,9 +1097,11 @@ class TestSearchParamsInvalid(object):
else:
with pytest.raises(Exception) as e:
status, result = connect.search(ip_collection, top_k, query_vecs)
"""
Test search collection with invalid nprobe
"""
@pytest.fixture(
scope="function",
params=gen_invalid_nprobes()
@ -1137,6 +1152,26 @@ class TestSearchParamsInvalid(object):
# with pytest.raises(Exception) as e:
# status, result = connect.search(ip_collection, top_k, query_vecs, params=search_param)
def test_search_with_2049_nprobe(self, connect, collection):
'''
target: test search function, with 2049 nprobe in GPU mode
method: search with nprobe = 2049 (above 2048)
expected: status not ok
'''
if str(connect._cmd("mode")[1]) == "CPU":
pytest.skip("Only support GPU mode")
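# restrict the check to the IVF-family index types listed below, which use the nprobe search parameter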
for index in gen_simple_index():
if index["index_type"] in [IndexType.IVF_PQ, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H]:
index_type = index["index_type"]
index_param = index["index_param"]
self.init_data(connect, collection)
connect.create_index(collection, index_type, index_param)
nprobe = 2049
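# 2049 sits just above the nprobe ceiling of 2048 exercised by this case, so the search status is expected to be not OK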
search_param = {"nprobe": nprobe}
query_vecs = gen_vectors(nprobe, dim)
status, result = connect.search(collection, top_k, query_vecs, params=search_param)
assert not status.OK()
@pytest.fixture(
scope="function",
params=gen_simple_index()
@ -1197,6 +1232,7 @@ class TestSearchParamsInvalid(object):
status, result = connect.search(collection, top_k, query_vecs, params=search_param)
assert not status.OK()
def check_result(result, id):
if len(result) >= 5:
return id in [result[0].id, result[1].id, result[2].id, result[3].id, result[4].id]