mirror of https://github.com/milvus-io/milvus.git
Update query cases for output vector field (#6633)
* fix check query result with vec fields
  Signed-off-by: ThreadDao <yufen.zong@zilliz.com>
* update query cases for output vec field
  Signed-off-by: ThreadDao <yufen.zong@zilliz.com>
pull/7458/head
parent
234954931f
commit
860ca4b40e
|
@ -202,8 +202,9 @@ class ResponseChecker:
|
|||
if len(check_items) == 0:
|
||||
raise Exception("No expect values found in the check task")
|
||||
exp_res = check_items.get("exp_res", None)
|
||||
with_vec = check_items.get("with_vec", False)
|
||||
if exp_res and isinstance(query_res, list):
|
||||
assert pc.equal_entities_list(exp=exp_res, actual=query_res)
|
||||
assert pc.equal_entities_list(exp=exp_res, actual=query_res, with_vec=with_vec)
|
||||
# assert len(exp_res) == len(query_res)
|
||||
# for i in range(len(exp_res)):
|
||||
# assert_entity_equal(exp=exp_res[i], actual=query_res[i])
|
||||
|
|
|
@ -151,10 +151,11 @@ def equal_entity(exp, actual):
|
|||
for field, value in exp.items():
|
||||
if isinstance(value, list):
|
||||
assert len(actual[field]) == len(exp[field])
|
||||
for i in range(len(exp[field])):
|
||||
for i in range(0, len(exp[field]), 2):
|
||||
assert abs(actual[field][i] - exp[field][i]) < ct.epsilon
|
||||
else:
|
||||
assert actual[field] == exp[field]
|
||||
return True
|
||||
|
||||
|
||||
def entity_in(entity, entities, primary_field=ct.default_int64_field_name):
|
||||
|
@ -173,7 +174,7 @@ def entity_in(entity, entities, primary_field=ct.default_int64_field_name):
|
|||
primary_keys.append(e[primary_field])
|
||||
if primary_key not in primary_keys:
|
||||
return False
|
||||
index = primary_key.index(primary_key)
|
||||
index = primary_keys.index(primary_key)
|
||||
return equal_entity(entities[index], entity)
|
||||
|
||||
|
||||
|
@ -196,9 +197,10 @@ def remove_entity(entity, entities, primary_field=ct.default_int64_field_name):
|
|||
return entities
|
||||
|
||||
|
||||
def equal_entities_list(exp, actual):
|
||||
def equal_entities_list(exp, actual, with_vec=False):
|
||||
"""
|
||||
compare two entities lists in inconsistent order
|
||||
:param with_vec: whether entities with vec field
|
||||
:param exp: exp entities list, list of dict
|
||||
:param actual: actual entities list, list of dict
|
||||
:return: True or False
|
||||
|
@ -209,14 +211,21 @@ def equal_entities_list(exp, actual):
|
|||
"""
|
||||
if len(exp) != len(actual):
|
||||
return False
|
||||
for a in actual:
|
||||
# if vec field returned in query res
|
||||
# if entity_in_entities(a, exp):
|
||||
if a in exp:
|
||||
try:
|
||||
exp.remove(a)
|
||||
# if vec field returned in query res
|
||||
# remove_entity(a, exp)
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
|
||||
if with_vec:
|
||||
for a in actual:
|
||||
# if vec field returned in query res
|
||||
if entity_in(a, exp):
|
||||
try:
|
||||
# if vec field returned in query res
|
||||
remove_entity(a, exp)
|
||||
except Exception as ex:
|
||||
log.error(ex)
|
||||
else:
|
||||
for a in actual:
|
||||
if a in exp:
|
||||
try:
|
||||
exp.remove(a)
|
||||
except Exception as ex:
|
||||
log.error(ex)
|
||||
return True if len(exp) == 0 else False
|
||||
|
|
|
@ -13,6 +13,8 @@ from utils.util_log import test_log as log
|
|||
prefix = "query"
|
||||
exp_res = "exp_res"
|
||||
default_term_expr = f'{ct.default_int64_field_name} in [0, 1]'
|
||||
default_index_params = {"index_type": "IVF_SQ8", "metric_type": "L2", "params": {"nlist": 64}}
|
||||
binary_index_params = {"index_type": "BIN_IVF_FLAT", "metric_type": "JACCARD", "params": {"nlist": 64}}
|
||||
|
||||
|
||||
class TestQueryBase(TestcaseBase):
|
||||
|
@ -52,20 +54,30 @@ class TestQueryBase(TestcaseBase):
|
|||
@pytest.mark.tags(CaseLabel.L0)
|
||||
def test_query_auto_id_collection(self):
|
||||
"""
|
||||
target: test query on collection that primary field auto_id=True
|
||||
method: 1.create collection with auto_id=True 2.query on primary field
|
||||
expected: verify primary field values of query result
|
||||
target: test query with auto_id=True collection
|
||||
method: test query with auto id
|
||||
expected: query result is correct
|
||||
"""
|
||||
schema = cf.gen_default_collection_schema(auto_id=True)
|
||||
collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix), schema=schema)
|
||||
self._connect()
|
||||
df = cf.gen_default_dataframe_data(ct.default_nb)
|
||||
df.drop(ct.default_int64_field_name, axis=1, inplace=True)
|
||||
mutation_res, _ = collection_w.insert(data=df)
|
||||
assert collection_w.num_entities == ct.default_nb
|
||||
collection_w.load()
|
||||
term_expr = f'{ct.default_int64_field_name} in [{mutation_res.primary_keys[0]}]'
|
||||
res, _ = collection_w.query(term_expr)
|
||||
assert res[0][ct.default_int64_field_name] == mutation_res.primary_keys[0]
|
||||
df[ct.default_int64_field_name] = None
|
||||
res, _, = self.collection_wrap.construct_from_dataframe(cf.gen_unique_str(prefix), df,
|
||||
primary_field=ct.default_int64_field_name, auto_id=True)
|
||||
assert self.collection_wrap.num_entities == ct.default_nb
|
||||
ids = res[1].primary_keys
|
||||
res = df.iloc[:2, :2].to_dict('records')
|
||||
self.collection_wrap.load()
|
||||
|
||||
# query with all primary keys
|
||||
term_expr_1 = f'{ct.default_int64_field_name} in {ids[:2]}'
|
||||
for i in range(2):
|
||||
res[i][ct.default_int64_field_name] = ids[i]
|
||||
self.collection_wrap.query(term_expr_1, check_task=CheckTasks.check_query_results, check_items={exp_res: res})
|
||||
|
||||
# query with part primary keys
|
||||
term_expr_2 = f'{ct.default_int64_field_name} in {[ids[0], 0]}'
|
||||
self.collection_wrap.query(term_expr_2, check_task=CheckTasks.check_query_results,
|
||||
check_items={exp_res: res[:1]})
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
def test_query_auto_id_not_existed_primary_key(self):
|
||||
|
@ -97,20 +109,20 @@ class TestQueryBase(TestcaseBase):
|
|||
collection_w.query(None, check_task=CheckTasks.err_res, check_items=error)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.parametrize("expr", [1, 2., [], {}, ()])
|
||||
def test_query_expr_non_string(self, expr):
|
||||
def test_query_expr_non_string(self):
|
||||
"""
|
||||
target: test query with non-string expr
|
||||
method: query with non-string expr, eg 1, [] ..
|
||||
expected: raise exception
|
||||
"""
|
||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
|
||||
exprs = [1, 2., [], {}, ()]
|
||||
error = {ct.err_code: 0, ct.err_msg: "The type of expr must be string"}
|
||||
collection_w.query(expr, check_task=CheckTasks.err_res, check_items=error)
|
||||
for expr in exprs:
|
||||
collection_w.query(expr, check_task=CheckTasks.err_res, check_items=error)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.parametrize("expr", ["12-s", "中文", "a", " "])
|
||||
def test_query_expr_invalid_string(self, expr):
|
||||
def test_query_expr_invalid_string(self):
|
||||
"""
|
||||
target: test query with invalid expr
|
||||
method: query with invalid string expr
|
||||
|
@ -118,10 +130,12 @@ class TestQueryBase(TestcaseBase):
|
|||
"""
|
||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
|
||||
error = {ct.err_code: 1, ct.err_msg: "Invalid expression!"}
|
||||
collection_w.query(expr, check_task=CheckTasks.err_res, check_items=error)
|
||||
exprs = ["12-s", "中文", "a", " "]
|
||||
for expr in exprs:
|
||||
collection_w.query(expr, check_task=CheckTasks.err_res, check_items=error)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
def test_query_expr_term(self):
|
||||
def _test_query_expr_term(self):
|
||||
"""
|
||||
target: test query with TermExpr
|
||||
method: query with TermExpr
|
||||
|
@ -132,33 +146,30 @@ class TestQueryBase(TestcaseBase):
|
|||
collection_w.query(default_term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res})
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.xfail(reason="issue #6259")
|
||||
def test_query_expr_not_existed_field(self):
|
||||
"""
|
||||
target: test query with not existed field
|
||||
method: query by term expr with fake field
|
||||
expected: raise exception
|
||||
"""
|
||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
|
||||
collection_w = self.init_collection_wrap(cf.gen_unique_str(prefix))
|
||||
term_expr = 'field in [1, 2]'
|
||||
error = {ct.err_code: 1, ct.err_msg: "fieldName(field) not found"}
|
||||
collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.xfail(reason="issue #6259")
|
||||
def test_query_expr_unsupported_field(self):
|
||||
"""
|
||||
target: test query on unsupported field
|
||||
method: query on float field
|
||||
expected: raise exception
|
||||
"""
|
||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
|
||||
collection_w = self.init_collection_wrap(cf.gen_unique_str(prefix))
|
||||
term_expr = f'{ct.default_float_field_name} in [1., 2.]'
|
||||
error = {ct.err_code: 1, ct.err_msg: "column is not int64"}
|
||||
collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.xfail(reason="issue #6259")
|
||||
def test_query_expr_non_primary_field(self):
|
||||
"""
|
||||
target: test query on non-primary field
|
||||
|
@ -177,7 +188,6 @@ class TestQueryBase(TestcaseBase):
|
|||
collection_w.query(default_term_expr, check_task=CheckTasks.err_res, check_items=error)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.xfail(reason="issue #6259")
|
||||
def test_query_expr_wrong_term_keyword(self):
|
||||
"""
|
||||
target: test query with wrong term expr keyword
|
||||
|
@ -198,19 +208,19 @@ class TestQueryBase(TestcaseBase):
|
|||
collection_w.query(expr_3, check_task=CheckTasks.err_res, check_items=error_3)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.xfail(reason="issue #6259")
|
||||
@pytest.mark.parametrize("expr", [f'{ct.default_int64_field_name} in 1',
|
||||
f'{ct.default_int64_field_name} in "in"',
|
||||
f'{ct.default_int64_field_name} in (mn)'])
|
||||
def test_query_expr_non_array_term(self, expr):
|
||||
def test_query_expr_non_array_term(self):
|
||||
"""
|
||||
target: test query with non-array term expr
|
||||
method: query with non-array term expr
|
||||
expected: raise exception
|
||||
"""
|
||||
exprs = [f'{ct.default_int64_field_name} in 1',
|
||||
f'{ct.default_int64_field_name} in "in"',
|
||||
f'{ct.default_int64_field_name} in (mn)']
|
||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
|
||||
error = {ct.err_code: 1, ct.err_msg: "right operand of the InExpr must be array"}
|
||||
collection_w.query(expr, check_task=CheckTasks.err_res, check_items=error)
|
||||
for expr in exprs:
|
||||
collection_w.query(expr, check_task=CheckTasks.err_res, check_items=error)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
def test_query_expr_empty_term_array(self):
|
||||
|
@ -225,47 +235,33 @@ class TestQueryBase(TestcaseBase):
|
|||
assert len(res) == 0
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.xfail(reason="issue #6259")
|
||||
def test_query_expr_inconstant_term_array(self):
|
||||
def test_query_expr_inconsistent_mix_term_array(self):
|
||||
"""
|
||||
target: test query with term expr that field and array are inconsistent
|
||||
method: query with int field and float values
|
||||
target: test query with term expr that field and array are inconsistent or mix type
|
||||
method: 1.query with int field and float values
|
||||
2.query with term expr that has int and float type value
|
||||
expected: raise exception
|
||||
"""
|
||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
|
||||
int_values = [1., 2.]
|
||||
term_expr = f'{ct.default_int64_field_name} in {int_values}'
|
||||
collection_w = self.init_collection_wrap(cf.gen_unique_str(prefix))
|
||||
int_values = [[1., 2.], [1, 2.]]
|
||||
error = {ct.err_code: 1, ct.err_msg: "type mismatch"}
|
||||
collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error)
|
||||
for values in int_values:
|
||||
term_expr = f'{ct.default_int64_field_name} in {values}'
|
||||
collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.xfail(reason="issue #6259")
|
||||
def test_query_expr_mix_term_array(self):
|
||||
"""
|
||||
target: test query with mix type value expr
|
||||
method: query with term expr that has int and float type value
|
||||
expected: raise exception
|
||||
"""
|
||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
|
||||
int_values = [1, 2.]
|
||||
term_expr = f'{ct.default_int64_field_name} in {int_values}'
|
||||
error = {ct.err_code: 1, ct.err_msg: "type mismatch"}
|
||||
collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.xfail(reason="issue #6259")
|
||||
@pytest.mark.parametrize("constant", [[1], (), {}])
|
||||
def test_query_expr_non_constant_array_term(self, constant):
|
||||
def test_query_expr_non_constant_array_term(self):
|
||||
"""
|
||||
target: test query with non-constant array term expr
|
||||
method: query with non-constant array expr
|
||||
expected: raise exception
|
||||
"""
|
||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
|
||||
term_expr = f'{ct.default_int64_field_name} in [{constant}]'
|
||||
log.debug(term_expr)
|
||||
constants = [[1], (), {}]
|
||||
error = {ct.err_code: 1, ct.err_msg: "unsupported leaf node"}
|
||||
collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error)
|
||||
for constant in constants:
|
||||
term_expr = f'{ct.default_int64_field_name} in [{constant}]'
|
||||
collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
def test_query_output_field_none(self):
|
||||
|
@ -287,8 +283,8 @@ class TestQueryBase(TestcaseBase):
|
|||
expected: return one field
|
||||
"""
|
||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
|
||||
res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_int64_field_name])
|
||||
assert set(res[0].keys()) == set([ct.default_int64_field_name])
|
||||
res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_float_field_name])
|
||||
assert set(res[0].keys()) == set([ct.default_int64_field_name, ct.default_float_field_name])
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
def test_query_output_all_fields(self):
|
||||
|
@ -297,27 +293,52 @@ class TestQueryBase(TestcaseBase):
|
|||
method: query with output field=None
|
||||
expected: return all fields
|
||||
"""
|
||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
|
||||
fields = [ct.default_int64_field_name, ct.default_float_field_name]
|
||||
res, _ = collection_w.query(default_term_expr, output_fields=fields)
|
||||
assert set(res[0].keys()) == set(fields)
|
||||
res_1, _ = collection_w.query(default_term_expr, output_fields=[ct.default_float_field_name])
|
||||
assert set(res_1[0].keys()) == set(fields)
|
||||
collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
|
||||
df = cf.gen_default_dataframe_data()
|
||||
collection_w.insert(df)
|
||||
assert collection_w.num_entities == ct.default_nb
|
||||
all_fields = [ct.default_int64_field_name, ct.default_float_field_name, ct.default_float_vec_field_name]
|
||||
res = df.iloc[:2].to_dict('records')
|
||||
collection_w.load()
|
||||
actual_res, _ = collection_w.query(default_term_expr, output_fields=all_fields,
|
||||
check_task=CheckTasks.check_query_results,
|
||||
check_items={exp_res: res, "with_vec": True})
|
||||
assert set(actual_res[0].keys()) == set(all_fields)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.xfail(reason="issue #6299")
|
||||
def test_query_output_vec_field(self):
|
||||
"""
|
||||
target: test query with vec output field
|
||||
method: specify vec field as output field
|
||||
expected: raise exception
|
||||
expected: return primary field and vec field
|
||||
"""
|
||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
|
||||
collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
|
||||
df = cf.gen_default_dataframe_data()
|
||||
collection_w.insert(df)
|
||||
assert collection_w.num_entities == ct.default_nb
|
||||
fields = [[ct.default_float_vec_field_name], [ct.default_int64_field_name, ct.default_float_vec_field_name]]
|
||||
error = {ct.err_code: 1, ct.err_msg: "Query does not support vector field currently"}
|
||||
res = df.loc[:1, [ct.default_int64_field_name, ct.default_float_vec_field_name]].to_dict('records')
|
||||
collection_w.load()
|
||||
for output_fields in fields:
|
||||
collection_w.query(default_term_expr, output_fields=output_fields,
|
||||
check_task=CheckTasks.err_res, check_items=error)
|
||||
check_task=CheckTasks.check_query_results,
|
||||
check_items={exp_res: res, "with_vec": True})
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.xfail(reason="issue #6594")
|
||||
# todo
|
||||
def test_query_output_binary_vec_field(self):
|
||||
"""
|
||||
target: test query with binary vec output field
|
||||
method: specify binary vec field as output field
|
||||
expected: return primary field and binary vec field
|
||||
"""
|
||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True, is_binary=True)[0:2]
|
||||
log.debug(collection_w.schema)
|
||||
fields = [[ct.default_binary_vec_field_name], [ct.default_int64_field_name, ct.default_binary_vec_field_name]]
|
||||
for output_fields in fields:
|
||||
res, _ = collection_w.query(default_term_expr, output_fields=output_fields)
|
||||
assert list(res[0].keys()) == fields[-1]
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
def test_query_output_primary_field(self):
|
||||
|
@ -331,9 +352,7 @@ class TestQueryBase(TestcaseBase):
|
|||
assert list(res[0].keys()) == [ct.default_int64_field_name]
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.parametrize("output_fields", [["int"],
|
||||
[ct.default_int64_field_name, "int"]])
|
||||
def test_query_output_not_existed_field(self, output_fields):
|
||||
def test_query_output_not_existed_field(self):
|
||||
"""
|
||||
target: test query output not existed field
|
||||
method: query with not existed output field
|
||||
|
@ -341,7 +360,10 @@ class TestQueryBase(TestcaseBase):
|
|||
"""
|
||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
|
||||
error = {ct.err_code: 1, ct.err_msg: 'Field int not exist'}
|
||||
collection_w.query(default_term_expr, output_fields=output_fields, check_task=CheckTasks.err_res, check_items=error)
|
||||
output_fields = [["int"], [ct.default_int64_field_name, "int"]]
|
||||
for fields in output_fields:
|
||||
collection_w.query(default_term_expr, output_fields=fields, check_task=CheckTasks.err_res,
|
||||
check_items=error)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
def test_query_empty_output_fields(self):
|
||||
|
@ -355,18 +377,20 @@ class TestQueryBase(TestcaseBase):
|
|||
fields = [ct.default_int64_field_name, ct.default_float_field_name]
|
||||
assert list(query_res[0].keys()) == fields
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
@pytest.mark.xfail(reason="exception not MilvusException")
|
||||
@pytest.mark.parametrize("output_fields", ["12-s", 1, [1, "2", 3], (1,), {1: 1}])
|
||||
def test_query_invalid_output_fields(self, output_fields):
|
||||
def test_query_invalid_output_fields(self):
|
||||
"""
|
||||
target: test query with invalid output fields
|
||||
method: query with invalid field fields
|
||||
expected: raise exception
|
||||
"""
|
||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
|
||||
output_fields = ["12-s", 1, [1, "2", 3], (1,), {1: 1}]
|
||||
error = {ct.err_code: 0, ct.err_msg: f'Invalid query format. \'output_fields\' must be a list'}
|
||||
collection_w.query(default_term_expr, output_fields=output_fields, check_task=CheckTasks.err_res, check_items=error)
|
||||
for fields in output_fields:
|
||||
collection_w.query(default_term_expr, output_fields=fields, check_task=CheckTasks.err_res,
|
||||
check_items=error)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
def test_query_partition(self):
|
||||
|
@ -568,7 +592,7 @@ class TestQueryOperation(TestcaseBase):
|
|||
"""
|
||||
target: test query with repeated term array on primary field with unique value
|
||||
method: query with repeated array value
|
||||
expected: verify query result
|
||||
expected: todo
|
||||
"""
|
||||
collection_w, vectors, binary_raw_vectors = self.init_collection_general(prefix, insert_data=True)[0:3]
|
||||
int_values = [0, 0, 0, 0]
|
||||
|
@ -577,6 +601,25 @@ class TestQueryOperation(TestcaseBase):
|
|||
assert len(res) == 1
|
||||
assert res[0][ct.default_int64_field_name] == int_values[0]
|
||||
|
||||
@pytest.mark.tags(ct.CaseLabel.L1)
|
||||
@pytest.mark.xfail(reason="issue #6624")
|
||||
def test_query_dup_ids_dup_term_array(self):
|
||||
"""
|
||||
target: test query on duplicate primary keys with dup term array
|
||||
method: 1.create collection and insert dup primary keys
|
||||
2.query with dup term array
|
||||
expected: todo
|
||||
"""
|
||||
collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
|
||||
df = cf.gen_default_dataframe_data(nb=ct.default_nb)
|
||||
df[ct.default_int64_field_name] = 0
|
||||
mutation_res, _ = collection_w.insert(df)
|
||||
assert mutation_res.primary_keys == df[ct.default_int64_field_name].tolist()
|
||||
collection_w.load()
|
||||
term_expr = f'{ct.default_int64_field_name} in {[0, 0, 0]}'
|
||||
res, _ = collection_w.query(term_expr)
|
||||
log.debug(res)
|
||||
|
||||
@pytest.mark.tags(ct.CaseLabel.L0)
|
||||
def test_query_after_index(self):
|
||||
"""
|
||||
|
@ -587,9 +630,7 @@ class TestQueryOperation(TestcaseBase):
|
|||
collection_w, vectors, binary_raw_vectors = self.init_collection_general(prefix, insert_data=True)[0:3]
|
||||
|
||||
default_field_name = ct.default_float_vec_field_name
|
||||
default_index_params = {"index_type": "IVF_SQ8", "metric_type": "L2", "params": {"nlist": 64}}
|
||||
index_name = ct.default_index_name
|
||||
collection_w.create_index(default_field_name, default_index_params, index_name=index_name)
|
||||
collection_w.create_index(default_field_name, default_index_params)
|
||||
|
||||
collection_w.load()
|
||||
|
||||
|
@ -625,6 +666,42 @@ class TestQueryOperation(TestcaseBase):
|
|||
check_vec = vectors[0].iloc[:, [0, 1]][0:2].to_dict('records')
|
||||
collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec})
|
||||
|
||||
@pytest.mark.tags(ct.CaseLabel.L1)
|
||||
def test_query_output_vec_field_after_index(self):
|
||||
"""
|
||||
target: test query output vec field after index
|
||||
method: create index and specify vec field as output field
|
||||
expected: return primary field and vec field
|
||||
"""
|
||||
collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
|
||||
df = cf.gen_default_dataframe_data(nb=5000)
|
||||
collection_w.insert(df)
|
||||
assert collection_w.num_entities == 5000
|
||||
fields = [ct.default_int64_field_name, ct.default_float_vec_field_name]
|
||||
collection_w.create_index(ct.default_float_vec_field_name, default_index_params)
|
||||
assert collection_w.has_index()[0]
|
||||
res = df.loc[:1, [ct.default_int64_field_name, ct.default_float_vec_field_name]].to_dict('records')
|
||||
collection_w.load()
|
||||
collection_w.query(default_term_expr, output_fields=fields,
|
||||
check_task=CheckTasks.check_query_results,
|
||||
check_items={exp_res: res, "with_vec": True})
|
||||
|
||||
@pytest.mark.tags(ct.CaseLabel.L2)
|
||||
@pytest.mark.xfail(reason="issue #6594")
|
||||
# todo
|
||||
def test_query_output_binary_vec_field_after_index(self):
|
||||
"""
|
||||
target: test query output vec field after index
|
||||
method: create index and specify vec field as output field
|
||||
expected: return primary field and vec field
|
||||
"""
|
||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True, is_binary=True)[0:2]
|
||||
fields = [ct.default_int64_field_name, ct.default_binary_vec_field_name]
|
||||
collection_w.create_index(ct.default_binary_vec_field_name, binary_index_params)
|
||||
assert collection_w.has_index()[0]
|
||||
res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_binary_vec_field_name])
|
||||
assert list(res[0].keys()) == fields
|
||||
|
||||
@pytest.mark.tags(ct.CaseLabel.L2)
|
||||
def test_query_partition_repeatedly(self):
|
||||
"""
|
||||
|
@ -705,22 +782,3 @@ class TestQueryOperation(TestcaseBase):
|
|||
res, _ = collection_w.query(term_expr, partition_names=[ct.default_partition_name, partition_w.name])
|
||||
assert len(res) == 1
|
||||
assert res[0][ct.default_int64_field_name] == half
|
||||
|
||||
# def insert_entities_into_two_partitions_in_half(self, half):
|
||||
# """
|
||||
# insert default entities into two partitions(partition_w and _default) in half(int64 and float fields values)
|
||||
# :param half: half of nb
|
||||
# :return: collection wrap and partition wrap
|
||||
# """
|
||||
# conn = self._connect()
|
||||
# collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
|
||||
# partition_w = self.init_partition_wrap(collection_wrap=collection_w)
|
||||
# # insert [0, half) into partition_w
|
||||
# df_partition = cf.gen_default_dataframe_data(nb=half, start=0)
|
||||
# partition_w.insert(df_partition)
|
||||
# # insert [half, nb) into _default
|
||||
# df_default = cf.gen_default_dataframe_data(nb=half, start=half)
|
||||
# collection_w.insert(df_default)
|
||||
# conn.flush([collection_w.name])
|
||||
# collection_w.load()
|
||||
# return collection_w, partition_w, df_partition, df_default
|
||||
|
|
Loading…
Reference in New Issue