mirror of https://github.com/milvus-io/milvus.git
Add test cases support for random primary keys (#25840)
Signed-off-by: binbin lv <binbin.lv@zilliz.com>pull/25972/head
parent
e24a8b3606
commit
4ba922876e
|
@ -228,7 +228,7 @@ class TestcaseBase(Base):
|
|||
partition_num=0, is_binary=False, is_all_data_type=False,
|
||||
auto_id=False, dim=ct.default_dim, is_index=True,
|
||||
primary_field=ct.default_int64_field_name, is_flush=True, name=None,
|
||||
enable_dynamic_field=False, with_json=True, **kwargs):
|
||||
enable_dynamic_field=False, with_json=True, random_primary_key=False, **kwargs):
|
||||
"""
|
||||
target: create specified collections
|
||||
method: 1. create collections (binary/non-binary, default/all data type, auto_id or not)
|
||||
|
@ -268,7 +268,8 @@ class TestcaseBase(Base):
|
|||
if insert_data:
|
||||
collection_w, vectors, binary_raw_vectors, insert_ids, time_stamp = \
|
||||
cf.insert_data(collection_w, nb, is_binary, is_all_data_type, auto_id=auto_id,
|
||||
dim=dim, enable_dynamic_field=enable_dynamic_field, with_json=with_json)
|
||||
dim=dim, enable_dynamic_field=enable_dynamic_field, with_json=with_json,
|
||||
random_primary_key=random_primary_key)
|
||||
if is_flush:
|
||||
assert collection_w.is_empty is False
|
||||
assert collection_w.num_entities == nb
|
||||
|
|
|
@ -242,5 +242,7 @@ def output_field_value_check(search_res, original):
|
|||
for order in range(0, len(entity[field]), 4):
|
||||
assert abs(original[field][_id][order] - entity[field][order]) < ct.epsilon
|
||||
else:
|
||||
assert original[field][_id] == entity[field]
|
||||
num = original[original[ct.default_int64_field_name] == _id].index.to_list()[0]
|
||||
assert original[field][num] == entity[field]
|
||||
|
||||
return True
|
||||
|
|
|
@ -299,8 +299,12 @@ def gen_binary_vectors(num, dim):
|
|||
return raw_vectors, binary_vectors
|
||||
|
||||
|
||||
def gen_default_dataframe_data(nb=ct.default_nb, dim=ct.default_dim, start=0, with_json=True):
|
||||
int_values = pd.Series(data=[i for i in range(start, start + nb)])
|
||||
def gen_default_dataframe_data(nb=ct.default_nb, dim=ct.default_dim, start=0, with_json=True,
|
||||
random_primary_key=False):
|
||||
if not random_primary_key:
|
||||
int_values = pd.Series(data=[i for i in range(start, start + nb)])
|
||||
else:
|
||||
int_values = pd.Series(data=random.sample(range(start, start + nb), nb))
|
||||
float_values = pd.Series(data=[np.float32(i) for i in range(start, start + nb)], dtype="float32")
|
||||
string_values = pd.Series(data=[str(i) for i in range(start, start + nb)], dtype="string")
|
||||
json_values = [{"number": i, "float": i*1.0} for i in range(start, start + nb)]
|
||||
|
@ -399,8 +403,11 @@ def gen_dataframe_multi_string_fields(string_fields, nb=ct.default_nb):
|
|||
return df
|
||||
|
||||
|
||||
def gen_dataframe_all_data_type(nb=ct.default_nb, dim=ct.default_dim, start=0, with_json=True):
|
||||
int64_values = pd.Series(data=[i for i in range(start, start + nb)])
|
||||
def gen_dataframe_all_data_type(nb=ct.default_nb, dim=ct.default_dim, start=0, with_json=True, random_primary_key=False):
|
||||
if not random_primary_key:
|
||||
int64_values = pd.Series(data=[i for i in range(start, start + nb)])
|
||||
else:
|
||||
int64_values = pd.Series(data=random.sample(range(start, start + nb), nb))
|
||||
int32_values = pd.Series(data=[np.int32(i) for i in range(start, start + nb)], dtype="int32")
|
||||
int16_values = pd.Series(data=[np.int16(i) for i in range(start, start + nb)], dtype="int16")
|
||||
int8_values = pd.Series(data=[np.int8(i) for i in range(start, start + nb)], dtype="int8")
|
||||
|
@ -1001,7 +1008,8 @@ def gen_partitions(collection_w, partition_num=1):
|
|||
|
||||
|
||||
def insert_data(collection_w, nb=ct.default_nb, is_binary=False, is_all_data_type=False,
|
||||
auto_id=False, dim=ct.default_dim, insert_offset=0, enable_dynamic_field=False, with_json=True):
|
||||
auto_id=False, dim=ct.default_dim, insert_offset=0, enable_dynamic_field=False, with_json=True,
|
||||
random_primary_key=False):
|
||||
"""
|
||||
target: insert non-binary/binary data
|
||||
method: insert non-binary/binary data into partitions if any
|
||||
|
@ -1016,14 +1024,16 @@ def insert_data(collection_w, nb=ct.default_nb, is_binary=False, is_all_data_typ
|
|||
log.info(f"inserted {nb} data into collection {collection_w.name}")
|
||||
for i in range(num):
|
||||
log.debug("Dynamic field is enabled: %s" % enable_dynamic_field)
|
||||
default_data = gen_default_dataframe_data(nb // num, dim=dim, start=start, with_json=with_json)
|
||||
default_data = gen_default_dataframe_data(nb // num, dim=dim, start=start, with_json=with_json,
|
||||
random_primary_key=random_primary_key)
|
||||
if enable_dynamic_field:
|
||||
default_data = gen_default_rows_data(nb // num, dim=dim, start=start, with_json=with_json)
|
||||
if is_binary:
|
||||
default_data, binary_raw_data = gen_default_binary_dataframe_data(nb // num, dim=dim, start=start)
|
||||
binary_raw_vectors.extend(binary_raw_data)
|
||||
if is_all_data_type:
|
||||
default_data = gen_dataframe_all_data_type(nb // num, dim=dim, start=start, with_json=with_json)
|
||||
default_data = gen_dataframe_all_data_type(nb // num, dim=dim, start=start, with_json=with_json,
|
||||
random_primary_key=random_primary_key)
|
||||
if enable_dynamic_field:
|
||||
default_data = gen_default_rows_data_all_data_type(nb // num, dim=dim, start=start, with_json=with_json)
|
||||
if auto_id:
|
||||
|
|
|
@ -742,3 +742,4 @@ class TestBulkInsert(TestcaseBaseBulkInsert):
|
|||
check_task=CheckTasks.check_search_results,
|
||||
check_items={"nq": 1, "limit": 1},
|
||||
)
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import pytest
|
|||
import random
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
pd.set_option("expand_frame_repr", False)
|
||||
from pymilvus import DefaultConfig
|
||||
import threading
|
||||
from pymilvus.orm.types import CONSISTENCY_STRONG, CONSISTENCY_BOUNDED, CONSISTENCY_EVENTUALLY
|
||||
|
@ -46,6 +47,10 @@ class TestQueryParams(TestcaseBase):
|
|||
def enable_dynamic_field(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(scope="function", params=[True, False])
|
||||
def random_primary_key(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
def test_query_invalid(self):
|
||||
"""
|
||||
|
@ -708,18 +713,17 @@ class TestQueryParams(TestcaseBase):
|
|||
assert set(res[0].keys()) == {ct.default_int64_field_name, ct.default_float_field_name}
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.xfail(reason="issue 24637")
|
||||
def test_query_output_all_fields(self, enable_dynamic_field):
|
||||
def test_query_output_all_fields(self, enable_dynamic_field, random_primary_key):
|
||||
"""
|
||||
target: test query with none output field
|
||||
method: query with output field=None
|
||||
expected: return all fields
|
||||
"""
|
||||
# 1. initialize with data
|
||||
collection_w, df, _, insert_ids = self.init_collection_general(prefix, True, nb=10,
|
||||
is_all_data_type=True,
|
||||
enable_dynamic_field=
|
||||
enable_dynamic_field)[0:4]
|
||||
collection_w, df, _, insert_ids = \
|
||||
self.init_collection_general(prefix, True, nb=10, is_all_data_type=True,
|
||||
enable_dynamic_field=enable_dynamic_field,
|
||||
random_primary_key=random_primary_key)[0:4]
|
||||
all_fields = [ct.default_int64_field_name, ct.default_int32_field_name, ct.default_int16_field_name,
|
||||
ct.default_int8_field_name, ct.default_bool_field_name, ct.default_float_field_name,
|
||||
ct.default_double_field_name, ct.default_string_field_name, ct.default_json_field_name,
|
||||
|
@ -727,7 +731,10 @@ class TestQueryParams(TestcaseBase):
|
|||
if enable_dynamic_field:
|
||||
res = df[0][:2]
|
||||
else:
|
||||
res = df[0].iloc[:2].to_dict('records')
|
||||
res = []
|
||||
for id in range(2):
|
||||
num = df[0][df[0][ct.default_int64_field_name] == id].index.to_list()[0]
|
||||
res.append(df[0].iloc[num].to_dict())
|
||||
log.info(res)
|
||||
collection_w.load()
|
||||
actual_res, _ = collection_w.query(default_term_expr, output_fields=all_fields,
|
||||
|
|
|
@ -5,6 +5,9 @@ import numpy
|
|||
import threading
|
||||
import pytest
|
||||
import pandas as pd
|
||||
pd.set_option("expand_frame_repr", False)
|
||||
import decimal
|
||||
from decimal import Decimal, getcontext
|
||||
from time import sleep
|
||||
import heapq
|
||||
|
||||
|
@ -1242,6 +1245,10 @@ class TestCollectionSearch(TestcaseBase):
|
|||
def metric_type(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(scope="function", params=[True, False])
|
||||
def random_primary_key(self, request):
|
||||
yield request.param
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
# The following are valid base cases
|
||||
|
@ -1385,6 +1392,35 @@ class TestCollectionSearch(TestcaseBase):
|
|||
# verify that top 1 hit is itself,so min distance is 0
|
||||
assert 1.0 - hits.distances[0] <= epsilon
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
def test_search_random_primary_key(self, random_primary_key):
|
||||
"""
|
||||
target: test search for collection with random primary keys
|
||||
method: create connection, collection, insert and search
|
||||
expected: Search without errors and data consistency
|
||||
"""
|
||||
# 1. initialize collection with random primary key
|
||||
|
||||
collection_w, _vectors, _, insert_ids, time_stamp = \
|
||||
self.init_collection_general(prefix, True, 10, random_primary_key=random_primary_key)[0:5]
|
||||
# 2. search
|
||||
log.info("test_search_random_primary_key: searching collection %s" % collection_w.name)
|
||||
vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)]
|
||||
collection_w.search(vectors[:default_nq], default_search_field,
|
||||
default_search_params, default_limit,
|
||||
default_search_exp,
|
||||
output_fields=[default_int64_field_name,
|
||||
default_float_field_name,
|
||||
default_json_field_name],
|
||||
check_task=CheckTasks.check_search_results,
|
||||
check_items={"nq": default_nq,
|
||||
"ids": insert_ids,
|
||||
"limit": 10,
|
||||
"original_entities": _vectors,
|
||||
"output_fields": [default_int64_field_name,
|
||||
default_float_field_name,
|
||||
default_json_field_name]})
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.parametrize("dup_times", [1, 2, 3])
|
||||
def test_search_with_dup_primary_key(self, dim, auto_id, _async, dup_times):
|
||||
|
|
Loading…
Reference in New Issue