mirror of https://github.com/milvus-io/milvus.git
enable some cases (#3325)
Signed-off-by: zw <zw@milvus.io> Co-authored-by: zw <zw@milvus.io>pull/3332/head
parent
55f47defdf
commit
d9e9d52f3b
|
@ -1029,9 +1029,8 @@ class TestSearchDSL(object):
|
|||
def get_invalid_term(self, request):
|
||||
return request.param
|
||||
|
||||
# TODO
|
||||
@pytest.mark.level(2)
|
||||
def _test_query_term_wrong_format(self, connect, collection, get_invalid_term):
|
||||
def test_query_term_wrong_format(self, connect, collection, get_invalid_term):
|
||||
'''
|
||||
method: build query with wrong format term
|
||||
expected: Exception raised
|
||||
|
@ -1548,9 +1547,8 @@ class TestSearchInvalid(object):
|
|||
def get_search_params(self, request):
|
||||
yield request.param
|
||||
|
||||
# TODO: This case can all pass, but it's too slow
|
||||
@pytest.mark.level(2)
|
||||
def _test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params):
|
||||
def test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params):
|
||||
'''
|
||||
target: test search fuction, with the wrong nprobe
|
||||
method: search with nprobe
|
||||
|
@ -1560,8 +1558,6 @@ class TestSearchInvalid(object):
|
|||
index_type = get_simple_index["index_type"]
|
||||
entities, ids = init_data(connect, collection)
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
if search_params["index_type"] != index_type:
|
||||
pytest.skip("Skip case")
|
||||
query, vecs = gen_query_vectors(field_name, entities, top_k, 1, search_params=search_params["search_params"])
|
||||
with pytest.raises(Exception) as e:
|
||||
res = connect.search(collection, query)
|
||||
|
|
|
@ -249,36 +249,32 @@ class TestFlushBase:
|
|||
assert res
|
||||
|
||||
# TODO: CI fail, LOCAL pass
|
||||
def _test_collection_count_during_flush(self, connect, args):
|
||||
@pytest.mark.level(2)
|
||||
def test_collection_count_during_flush(self, connect, collection, args):
|
||||
'''
|
||||
method: flush collection at background, call `count_entities`
|
||||
expected: status ok
|
||||
expected: no timeout
|
||||
'''
|
||||
collection = gen_unique_str("test_flush")
|
||||
# param = {'collection_name': collection,
|
||||
# 'dimension': dim,
|
||||
# 'index_file_size': index_file_size,
|
||||
# 'metric_type': MetricType.L2}
|
||||
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
||||
milvus.create_collection(collection, default_fields)
|
||||
# vectors = gen_vector(nb, dim)
|
||||
ids = milvus.insert(collection, entities, ids=[i for i in range(nb)])
|
||||
|
||||
def flush(collection_name):
|
||||
ids = []
|
||||
for i in range(5):
|
||||
tmp_ids = connect.insert(collection, entities)
|
||||
connect.flush([collection])
|
||||
ids.extend(tmp_ids)
|
||||
disable_flush(connect)
|
||||
status = connect.delete_entity_by_id(collection, ids)
|
||||
def flush():
|
||||
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
||||
status = milvus.delete_entity_by_id(collection_name, [i for i in range(nb)])
|
||||
with pytest.raises(Exception) as e:
|
||||
milvus.flush([collection_name])
|
||||
|
||||
|
||||
p = Process(target=flush, args=(collection,))
|
||||
logging.error("start flush")
|
||||
milvus.flush([collection])
|
||||
logging.error("end flush")
|
||||
|
||||
p = threading.Thread(target=flush, args=())
|
||||
p.start()
|
||||
res = milvus.count_entities(collection)
|
||||
assert res == nb
|
||||
time.sleep(0.2)
|
||||
logging.error("start count")
|
||||
res = connect.count_entities(collection, timeout = 10)
|
||||
p.join()
|
||||
res = milvus.count_entities(collection)
|
||||
assert res == nb
|
||||
logging.getLogger().info(res)
|
||||
res = connect.count_entities(collection)
|
||||
assert res == 0
|
||||
|
||||
|
||||
|
|
|
@ -551,22 +551,19 @@ class TestIndexBinary:
|
|||
ids = connect.insert(binary_collection, binary_entities, partition_tag=tag)
|
||||
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
|
||||
|
||||
# TODO:
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def _test_create_index_search_with_query_vectors(self, connect, binary_collection, get_jaccard_index, get_nq):
|
||||
def test_create_index_search_with_query_vectors(self, connect, binary_collection, get_jaccard_index, get_nq):
|
||||
'''
|
||||
target: test create index interface, search with more query vectors
|
||||
method: create collection and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
nq = get_nq
|
||||
pdb.set_trace()
|
||||
ids = connect.insert(binary_collection, binary_entities)
|
||||
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
|
||||
query, vecs = gen_query_vectors(binary_field_name, binary_entities, top_k, nq)
|
||||
search_param = get_search_param(binary_collection["index_type"])
|
||||
search_param = get_search_param(get_jaccard_index["index_type"], metric_type="JACCARD")
|
||||
res = connect.search(binary_collection, query, search_params=search_param)
|
||||
logging.getLogger().info(res)
|
||||
assert len(res) == nq
|
||||
|
||||
"""
|
||||
|
@ -581,15 +578,18 @@ class TestIndexBinary:
|
|||
method: create collection and add entities in it, create index, call describe index
|
||||
expected: return code 0, and index instructure
|
||||
'''
|
||||
if get_jaccard_index["index_type"] == "BIN_FLAT":
|
||||
pytest.skip("GetCollectionStats skip BIN_FLAT")
|
||||
ids = connect.insert(binary_collection, binary_entities)
|
||||
connect.flush([binary_collection])
|
||||
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
|
||||
stats = connect.get_collection_stats(binary_collection)
|
||||
logging.getLogger().info(stats)
|
||||
# TODO
|
||||
# assert stats['partitions'][0]['segments'][0]['index_name'] == get_jaccard_index['index_type']
|
||||
assert stats["row_count"] == nb
|
||||
for partition in stats["partitions"]:
|
||||
segments = partition["segments"]
|
||||
if segments:
|
||||
for segment in segments:
|
||||
for file in segment["files"]:
|
||||
if "index_type" in file:
|
||||
assert file["index_type"] == get_jaccard_index["index_type"]
|
||||
|
||||
def test_get_index_info_partition(self, connect, binary_collection, get_jaccard_index):
|
||||
'''
|
||||
|
@ -597,16 +597,21 @@ class TestIndexBinary:
|
|||
method: create collection, create partition and add entities in it, create index, call describe index
|
||||
expected: return code 0, and index instructure
|
||||
'''
|
||||
if get_jaccard_index["index_type"] == "BIN_FLAT":
|
||||
pytest.skip("GetCollectionStats skip BIN_FLAT")
|
||||
connect.create_partition(binary_collection, tag)
|
||||
ids = connect.insert(binary_collection, binary_entities, partition_tag=tag)
|
||||
connect.flush([binary_collection])
|
||||
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
|
||||
stats = connect.get_collection_stats(binary_collection)
|
||||
logging.getLogger().info(stats)
|
||||
# TODO
|
||||
# assert stats['partitions'][1]['segments'][0]['index_name'] == get_jaccard_index['index_type']
|
||||
assert stats["row_count"] == nb
|
||||
assert len(stats["partitions"]) == 2
|
||||
for partition in stats["partitions"]:
|
||||
segments = partition["segments"]
|
||||
if segments:
|
||||
for segment in segments:
|
||||
for file in segment["files"]:
|
||||
if "index_type" in file:
|
||||
assert file["index_type"] == get_jaccard_index["index_type"]
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
|
@ -639,65 +644,18 @@ class TestIndexBinary:
|
|||
connect.flush([binary_collection])
|
||||
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
|
||||
stats = connect.get_collection_stats(binary_collection)
|
||||
logging.getLogger().info(stats)
|
||||
connect.drop_index(binary_collection, binary_field_name)
|
||||
stats = connect.get_collection_stats(binary_collection)
|
||||
logging.getLogger().info(stats)
|
||||
# TODO
|
||||
# assert stats["partitions"][1]["segments"][0]["index_name"] == default_index_type
|
||||
|
||||
|
||||
class TestIndexMultiCollections(object):
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def _test_create_index_multithread_multicollection(self, connect, args):
|
||||
'''
|
||||
target: test create index interface with multiprocess
|
||||
method: create collection and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
threads_num = 8
|
||||
loop_num = 8
|
||||
threads = []
|
||||
collection = []
|
||||
j = 0
|
||||
while j < (threads_num * loop_num):
|
||||
collection_name = gen_unique_str("test_create_index_multiprocessing")
|
||||
collection.append(collection_name)
|
||||
param = {'collection_name': collection_name,
|
||||
'dimension': dim,
|
||||
'index_type': IndexType.FLAT,
|
||||
'store_raw_vector': False}
|
||||
connect.create_collection(param)
|
||||
j = j + 1
|
||||
|
||||
def create_index():
|
||||
i = 0
|
||||
while i < loop_num:
|
||||
# assert connect.has_collection(collection[ids*process_num+i])
|
||||
ids = connect.insert(collection[ids * threads_num + i], vectors)
|
||||
connect.create_index(collection[ids * threads_num + i], IndexType.IVFLAT, {"nlist": NLIST, "metric_type": "L2"})
|
||||
assert status.OK()
|
||||
query_vec = [vectors[0]]
|
||||
top_k = 1
|
||||
search_param = {"nprobe": nprobe}
|
||||
status, result = connect.search(collection[ids * threads_num + i], top_k, query_vec,
|
||||
params=search_param)
|
||||
assert len(result) == 1
|
||||
assert len(result[0]) == top_k
|
||||
assert result[0][0].distance == 0.0
|
||||
i = i + 1
|
||||
|
||||
for i in range(threads_num):
|
||||
m = get_milvus(host=args["ip"], port=args["port"], handler=args["handler"])
|
||||
ids = i
|
||||
t = threading.Thread(target=create_index, args=(m, ids))
|
||||
threads.append(t)
|
||||
t.start()
|
||||
time.sleep(0.2)
|
||||
for t in threads:
|
||||
t.join()
|
||||
assert stats["row_count"] == nb
|
||||
for partition in stats["partitions"]:
|
||||
segments = partition["segments"]
|
||||
if segments:
|
||||
for segment in segments:
|
||||
for file in segment["files"]:
|
||||
if "index_type" not in file:
|
||||
continue
|
||||
if file["index_type"] == get_jaccard_index["index_type"]:
|
||||
assert False
|
||||
|
||||
|
||||
class TestIndexInvalid(object):
|
||||
|
|
|
@ -21,6 +21,7 @@ entity = gen_entities(1)
|
|||
entities = gen_entities(nb)
|
||||
raw_vector, binary_entity = gen_binary_entities(1)
|
||||
raw_vectors, binary_entities = gen_binary_entities(nb)
|
||||
default_fields = gen_default_fields()
|
||||
|
||||
|
||||
class TestCreateBase:
|
||||
|
@ -38,21 +39,33 @@ class TestCreateBase:
|
|||
'''
|
||||
connect.create_partition(collection, tag)
|
||||
|
||||
@pytest.mark.level(3)
|
||||
def _test_create_partition_limit(self, connect, collection, args):
|
||||
@pytest.mark.level(2)
|
||||
def test_create_partition_limit(self, connect, collection, args):
|
||||
'''
|
||||
target: test create partitions, check status returned
|
||||
method: call function: create_partition for 4097 times
|
||||
expected: status not ok
|
||||
expected: exception raised
|
||||
'''
|
||||
threads_num = 16
|
||||
threads = []
|
||||
if args["handler"] == "HTTP":
|
||||
pytest.skip("skip in http mode")
|
||||
|
||||
for i in range(4096):
|
||||
tag_tmp = gen_unique_str()
|
||||
connect.create_partition(collection, tag_tmp)
|
||||
def create(connect, threads_num):
|
||||
for i in range(4096 // threads_num):
|
||||
tag_tmp = gen_unique_str()
|
||||
connect.create_partition(collection, tag_tmp)
|
||||
|
||||
for i in range(threads_num):
|
||||
m = get_milvus(host=args["ip"], port=args["port"], handler=args["handler"])
|
||||
t = threading.Thread(target=create, args=(m, threads_num, ))
|
||||
threads.append(t)
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
tag_tmp = gen_unique_str()
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_partition(collection, tag)
|
||||
connect.create_partition(collection, tag_tmp)
|
||||
|
||||
def test_create_partition_repeat(self, connect, collection):
|
||||
'''
|
||||
|
@ -147,7 +160,8 @@ class TestCreateBase:
|
|||
res = connect.count_entities(id_collection)
|
||||
assert res == nb * 2
|
||||
|
||||
def _test_create_partition_insert_same_tags_two_collections(self, connect, collection):
|
||||
@pytest.mark.level(2)
|
||||
def test_create_partition_insert_same_tags_two_collections(self, connect, collection):
|
||||
'''
|
||||
target: test create two partitions, and insert vectors with the same tag to each collection, check status returned
|
||||
method: call function: create_partition
|
||||
|
@ -156,16 +170,13 @@ class TestCreateBase:
|
|||
connect.create_partition(collection, tag)
|
||||
collection_new = gen_unique_str()
|
||||
connect.create_collection(collection_new, default_fields)
|
||||
connect.create_collection(param)
|
||||
connect.create_partition(collection_new, tag)
|
||||
ids = [i for i in range(nb)]
|
||||
status, ids = connect.insert(collection, entities, ids, partition_tag=tag)
|
||||
ids = [(i+nb) for i in range(nq)]
|
||||
status, ids = connect.insert(collection_new, entities, ids, partition_tag=tag)
|
||||
ids = connect.insert(collection, entities, partition_tag=tag)
|
||||
ids = connect.insert(collection_new, entities, partition_tag=tag)
|
||||
connect.flush([collection, collection_new])
|
||||
status, res = connect.count_entities(collection)
|
||||
res = connect.count_entities(collection)
|
||||
assert res == nb
|
||||
status, res = connect.count_entities(collection_new)
|
||||
res = connect.count_entities(collection_new)
|
||||
assert res == nb
|
||||
|
||||
|
||||
|
|
|
@ -768,8 +768,8 @@ def gen_binary_index():
|
|||
return index_params
|
||||
|
||||
|
||||
def get_search_param(index_type):
|
||||
search_params = {"metric_type": "L2"}
|
||||
def get_search_param(index_type, metric_type="L2"):
|
||||
search_params = {"metric_type": metric_type}
|
||||
if index_type in ivf() or index_type in binary_support():
|
||||
search_params.update({"nprobe": 64})
|
||||
elif index_type == "HNSW":
|
||||
|
|
Loading…
Reference in New Issue