mirror of https://github.com/milvus-io/milvus.git
Add test cases of upsert and diskann (#22780)
Signed-off-by: nico <cheng.yuan@zilliz.com>
pull/22834/head
parent
6fa217568e
commit
fb17b915d9
|
@ -1362,6 +1362,25 @@ class TestUpsertValid(TestcaseBase):
|
|||
collection_w.upsert([[" a", "b "], vectors])
|
||||
assert collection_w.num_entities == 4
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
def test_upsert_binary_data(self):
    """
    target: verify that rows carrying binary vectors can be upserted
    method: 1. build a binary-vector collection that already holds data
            2. upsert 500 rows given as column lists
               (integer ids, float32 values, string values, binary vectors)
            3. query the binary vector field back
    expected: no exception is raised and the first returned vector
              equals the first upserted vector
    """
    row_count = 500
    collection_w = self.init_collection_general(cf.gen_unique_str(pre_upsert), True, is_binary=True)[0]

    # Only element [1] of the generator's result is sent to the server.
    binary_vectors = cf.gen_binary_vectors(row_count, ct.default_dim)[1]

    # Column-wise payload; every column is derived from the same id range.
    ids = range(row_count)
    columns = [
        list(ids),
        list(map(np.float32, ids)),
        list(map(str, ids)),
        binary_vectors,
    ]
    collection_w.upsert(columns)

    vec_field = ct.default_binary_vec_field_name
    rows = collection_w.query("int64 >= 0", [vec_field])[0]
    assert binary_vectors[0] == rows[0][vec_field]
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
def test_upsert_same_with_inserted_data(self):
|
||||
"""
|
||||
|
@ -1494,12 +1513,12 @@ class TestUpsertValid(TestcaseBase):
|
|||
|
||||
# check the result
|
||||
exp = f"int64 >= 0 && int64 <= {upsert_nb}"
|
||||
res = collection_w.query(exp, output_fields=[default_float_name])[0]
|
||||
res = collection_w.query(exp, [default_float_name], consistency_level="Strong")[0]
|
||||
res = [res[i][default_float_name] for i in range(upsert_nb)]
|
||||
if not res == float_values1.to_list() or res == float_values2.to_list():
|
||||
if not (res == float_values1.to_list() or res == float_values2.to_list()):
|
||||
assert False
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
def test_upsert_multiple_times(self):
|
||||
"""
|
||||
target: test upsert multiple times
|
||||
|
@ -1517,31 +1536,25 @@ class TestUpsertValid(TestcaseBase):
|
|||
assert collection_w.num_entities == upsert_nb*10 + ct.default_nb
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
@pytest.mark.skip(reason="Something wrong with get query segment info")
|
||||
def test_upsert_during_index(self):
|
||||
def test_upsert_pk_string_multiple_times(self):
|
||||
"""
|
||||
target: test upsert during index
|
||||
target: test upsert multiple times
|
||||
method: 1. create a collection and insert data
|
||||
2. upsert during creating index
|
||||
expected: get query segment info
|
||||
2. upsert repeatedly
|
||||
expected: not raise exception
|
||||
"""
|
||||
# initialize a collection
|
||||
upsert_nb = 1000
|
||||
c_name = cf.gen_unique_str(pre_upsert)
|
||||
collection_w = self.init_collection_wrap(name=c_name)
|
||||
cf.insert_data(collection_w)
|
||||
schema = cf.gen_string_pk_default_collection_schema()
|
||||
name = cf.gen_unique_str(pre_upsert)
|
||||
collection_w = self.init_collection_wrap(name, schema)
|
||||
collection_w.insert(cf.gen_default_list_data())
|
||||
|
||||
def create_index():
|
||||
collection_w.create_index(ct.default_float_vec_field_name, default_index_params)
|
||||
|
||||
t = threading.Thread(target=create_index, args=())
|
||||
t.start()
|
||||
data = cf.gen_default_data_for_upsert(upsert_nb)[0]
|
||||
collection_w.upsert(data=data)
|
||||
t.join()
|
||||
|
||||
res = self.utility_wrap.get_query_segment_info(collection_w.name)
|
||||
assert len(res) >= 1
|
||||
# upsert
|
||||
for i in range(10):
|
||||
data = cf.gen_default_list_data(upsert_nb, start=i * 500)
|
||||
collection_w.upsert(data)
|
||||
assert collection_w.num_entities == upsert_nb * 10 + ct.default_nb
|
||||
|
||||
|
||||
class TestUpsertInvalid(TestcaseBase):
|
||||
|
@ -1610,6 +1623,35 @@ class TestUpsertInvalid(TestcaseBase):
|
|||
"expected: ['int64', 'float', 'varchar', 'float_vector']"}
|
||||
collection_w.upsert(data=[data], check_task=CheckTasks.err_res, check_items=error)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("dim", [127, 129, 200])
@pytest.mark.xfail(reason="issue #22777")
def test_upsert_binary_dim_unmatch(self, dim):
    """
    target: verify upsert rejects binary data whose vector dim differs from the collection
    method: 1. create a binary collection with the default dim (128)
            2. upsert a binary dataframe generated with a different dim
    expected: the upsert fails with the dim-mismatch error
    """
    # Expected failure, checked by the err_res check task.
    expected_error = {
        ct.err_code: 1,
        ct.err_msg: f"Collection field dim is 128, but entities field dim is {dim}",
    }

    collection_w = self.init_collection_general(pre_upsert, True, is_binary=True)[0]
    mismatched_df = cf.gen_default_binary_dataframe_data(dim=dim)[0]
    collection_w.upsert(
        data=mismatched_df,
        check_task=CheckTasks.err_res,
        check_items=expected_error,
    )
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("dim", [127, 129, 200])
def test_upsert_dim_unmatch(self, dim):
    """
    target: verify upsert rejects data whose vector dim differs from the collection
    method: 1. create a collection with the default dim (128)
            2. upsert data generated with a different dim
    expected: the upsert fails with the dim-mismatch error
    """
    # Expected failure, checked by the err_res check task.
    expected_error = {
        ct.err_code: 1,
        ct.err_msg: f"Collection field dim is 128, but entities field dim is {dim}",
    }

    collection_w = self.init_collection_general(pre_upsert, True)[0]
    mismatched_data = cf.gen_default_data_for_upsert(dim=dim)[0]
    collection_w.upsert(
        data=mismatched_data,
        check_task=CheckTasks.err_res,
        check_items=expected_error,
    )
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
@pytest.mark.parametrize("partition_name", ct.get_invalid_strs)
|
||||
def test_upsert_partition_name_invalid(self, partition_name):
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import threading
|
||||
import pytest
|
||||
import time
|
||||
|
||||
from base.partition_wrapper import ApiPartitionWrapper
|
||||
from base.client_base import TestcaseBase
|
||||
|
@ -539,6 +540,7 @@ class TestPartitionParams(TestcaseBase):
|
|||
assert partition_w.num_entities == (nums + nums)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.skip(reason="not stable")
|
||||
def test_partition_upsert(self):
|
||||
"""
|
||||
target: verify upsert entities multiple times
|
||||
|
@ -562,14 +564,16 @@ class TestPartitionParams(TestcaseBase):
|
|||
data, values = cf.gen_default_data_for_upsert(nb=upsert_nb, start=2000)
|
||||
partition_w.upsert(data)
|
||||
res = partition_w.query("int64 >= 2000 && int64 < 3000", [ct.default_float_field_name])[0]
|
||||
assert partition_w.num_entities == upsert_nb + ct.default_nb // 2
|
||||
time.sleep(5)
|
||||
assert partition_w.num_entities == ct.default_nb // 2
|
||||
assert [res[i][ct.default_float_field_name] for i in range(upsert_nb)] == values.to_list()
|
||||
|
||||
# upsert data
|
||||
data, values = cf.gen_default_data_for_upsert(nb=upsert_nb, start=ct.default_nb)
|
||||
partition_w.upsert(data)
|
||||
res = partition_w.query("int64 >= 3000 && int64 < 4000", [ct.default_float_field_name])[0]
|
||||
assert partition_w.num_entities == upsert_nb * 2 + ct.default_nb // 2
|
||||
time.sleep(5)
|
||||
assert partition_w.num_entities == upsert_nb + ct.default_nb // 2
|
||||
assert [res[i][ct.default_float_field_name] for i in range(upsert_nb)] == values.to_list()
|
||||
|
||||
|
||||
|
|
|
@ -4850,6 +4850,41 @@ class TestSearchDiskann(TestcaseBase):
|
|||
"_async": _async}
|
||||
)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("limit", [10, 100, 1000])
def test_search_diskann_search_list_equal_to_limit(self, dim, auto_id, limit, _async):
    """
    target: verify a DISKANN search succeeds when search_list equals the limit
    method: 1. create a collection without an index and insert data
            2. create a DISKANN index (L2 metric), then load
            3. search with search_list set to the same value as limit
    expected: the search succeeds and passes the search-result check
    """
    # 1. collection with data but no index yet
    init_result = self.init_collection_general(
        prefix, True, auto_id=auto_id, dim=dim, is_index=False
    )[0:4]
    collection_w, _unused_1, _unused_2, insert_ids = init_result

    # 2. build the DISKANN index and load the collection
    diskann_index = {"index_type": "DISKANN", "metric_type": "L2", "params": {}}
    collection_w.create_index(ct.default_float_vec_field_name, diskann_index)
    collection_w.load()

    # 3. search with search_list == limit
    search_params = {"metric_type": "L2", "params": {"search_list": limit}}
    # Exactly default_nq random query vectors, so no slicing is needed below.
    query_vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)]
    output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name]
    expected = {
        "nq": default_nq,
        "ids": insert_ids,
        "limit": limit,
        "_async": _async,
    }
    collection_w.search(
        query_vectors,
        default_search_field,
        search_params,
        limit,
        default_search_exp,
        output_fields=output_fields,
        _async=_async,
        travel_timestamp=0,
        check_task=CheckTasks.check_search_results,
        check_items=expected,
    )
|
||||
|
||||
|
||||
class TestCollectionRangeSearch(TestcaseBase):
|
||||
""" Test case of range search interface """
|
||||
|
@ -4941,7 +4976,7 @@ class TestCollectionRangeSearch(TestcaseBase):
|
|||
check_task=CheckTasks.check_search_results,
|
||||
check_items={"nq": default_nq,
|
||||
"ids": insert_ids,
|
||||
"limit": default_limit})[0]
|
||||
"limit": default_limit})
|
||||
# 4. range search with IP
|
||||
range_search_params = {"metric_type": "IP", "params": {"nprobe": 10, "range_filter": 1}}
|
||||
collection_w.search(vectors[:default_nq], default_search_field,
|
||||
|
@ -4950,7 +4985,7 @@ class TestCollectionRangeSearch(TestcaseBase):
|
|||
check_task=CheckTasks.check_search_results,
|
||||
check_items={"nq": default_nq,
|
||||
"ids": insert_ids,
|
||||
"limit": default_limit})[0]
|
||||
"limit": default_limit})
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
def test_range_search_only_radius(self):
|
||||
|
@ -4972,7 +5007,7 @@ class TestCollectionRangeSearch(TestcaseBase):
|
|||
check_task=CheckTasks.check_search_results,
|
||||
check_items={"nq": default_nq,
|
||||
"ids": [],
|
||||
"limit": 0})[0]
|
||||
"limit": 0})
|
||||
# 4. range search with IP
|
||||
range_search_params = {"metric_type": "IP", "params": {"nprobe": 10, "radius": 1}}
|
||||
collection_w.search(vectors[:default_nq], default_search_field,
|
||||
|
@ -4981,7 +5016,7 @@ class TestCollectionRangeSearch(TestcaseBase):
|
|||
check_task=CheckTasks.check_search_results,
|
||||
check_items={"nq": default_nq,
|
||||
"ids": [],
|
||||
"limit": 0})[0]
|
||||
"limit": 0})
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
def test_range_search_radius_range_filter_not_in_params(self):
|
||||
|
@ -5003,7 +5038,7 @@ class TestCollectionRangeSearch(TestcaseBase):
|
|||
check_task=CheckTasks.check_search_results,
|
||||
check_items={"nq": default_nq,
|
||||
"ids": insert_ids,
|
||||
"limit": default_limit})[0]
|
||||
"limit": default_limit})
|
||||
# 4. range search with IP
|
||||
range_search_params = {"metric_type": "IP", "params": {"nprobe": 10}, "radius": 1, "range_filter": 0}
|
||||
collection_w.search(vectors[:default_nq], default_search_field,
|
||||
|
@ -5012,7 +5047,7 @@ class TestCollectionRangeSearch(TestcaseBase):
|
|||
check_task=CheckTasks.check_search_results,
|
||||
check_items={"nq": default_nq,
|
||||
"ids": insert_ids,
|
||||
"limit": default_limit})[0]
|
||||
"limit": default_limit})
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.parametrize("dup_times", [1, 2])
|
||||
|
|
Loading…
Reference in New Issue