Update testceses of query (#6154)

* [skip ci] Update query test cases

Signed-off-by: wangting0128 <ting.wang@zilliz.com>

* [skip ci] Conflict resolution

Signed-off-by: wangting0128 <ting.wang@zilliz.com>

* [skip ci] Update teardown

Signed-off-by: wangting0128 <ting.wang@zilliz.com>
pull/6121/head^2
紫晴 2021-06-28 15:54:11 +08:00 committed by GitHub
parent b87baa108a
commit 84319ed311
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 218 additions and 74 deletions

View File

@ -41,6 +41,7 @@ class Base:
utility_wrap = None
collection_schema_wrap = None
field_schema_wrap = None
collection_object_list = []
def setup_class(self):
log.info("[setup_class] Start setup class...")
@ -52,6 +53,7 @@ class Base:
log.info(("*" * 35) + " setup " + ("*" * 35))
self.connection_wrap = ApiConnectionsWrapper()
self.collection_wrap = ApiCollectionWrapper()
self.collection_object_list.append(self.collection_wrap)
self.partition_wrap = ApiPartitionWrapper()
self.index_wrap = ApiIndexWrapper()
self.utility_wrap = ApiUtilityWrapper()
@ -63,14 +65,20 @@ class Base:
try:
""" Drop collection before disconnect """
# if self.collection_wrap is not None and self.collection_wrap.collection is not None:
# self.collection_wrap.drop()
if self.collection_wrap is not None:
collection_list = self.utility_wrap.list_collections()[0]
for i in collection_list:
collection_wrap = ApiCollectionWrapper()
collection_wrap.init_collection(name=i)
collection_wrap.drop()
if self.connection_wrap.get_connection(alias=DefaultConfig.DEFAULT_USING)[0] is None:
self.connection_wrap.connect(alias=DefaultConfig.DEFAULT_USING, host=param_info.param_host,
port=param_info.param_port)
for collection_object in self.collection_object_list:
if collection_object is not None and collection_object.collection is not None:
collection_object.drop()
# if self.collection_wrap is not None:
# collection_list = self.utility_wrap.list_collections()[0]
# for i in collection_list:
# collection_wrap = ApiCollectionWrapper()
# collection_wrap.init_collection(name=i)
# collection_wrap.drop()
except Exception as e:
pass
@ -111,13 +119,14 @@ class TestcaseBase(Base):
Public methods that can be used to add cases.
"""
@pytest.fixture(scope="module", params=ct.get_invalid_strs)
def get_invalid_string(self, request):
yield request.param
@pytest.fixture(scope="module", params=cf.gen_simple_index())
def get_index_param(self, request):
yield request.param
# move to conftest.py
# @pytest.fixture(scope="module", params=ct.get_invalid_strs)
# def get_invalid_string(self, request):
# yield request.param
#
# @pytest.fixture(scope="module", params=cf.gen_simple_index())
# def get_index_param(self, request):
# yield request.param
def _connect(self):
""" Add an connection and create the connect """
@ -131,6 +140,7 @@ class TestcaseBase(Base):
if self.connection_wrap.get_connection(alias=DefaultConfig.DEFAULT_USING)[0] is None:
self._connect()
collection_w = ApiCollectionWrapper()
self.collection_object_list.append(collection_w)
collection_w.init_collection(name=name, schema=schema, check_task=check_task, check_items=check_items, **kwargs)
return collection_w
@ -176,7 +186,7 @@ class TestcaseBase(Base):
if insert_data:
collection_w, vectors, binary_raw_vectors = \
cf.insert_data(collection_w, nb, is_binary, is_all_data_type)
assert collection_w.is_empty == False
assert collection_w.is_empty is False
assert collection_w.num_entities == nb
collection_w.load()

View File

@ -61,7 +61,7 @@ class ResponseChecker:
assert len(error_dict) > 0
if isinstance(res, Error):
error_code = error_dict[ct.err_code]
assert res.code == error_code or error_dict[ct.err_msg] in res.message
assert res.code == error_code and error_dict[ct.err_msg] in res.message
else:
log.error("[CheckFunc] Response of API is not an error: %s" % str(res))
assert False
@ -190,3 +190,4 @@ class ResponseChecker:
# assert len(exp_res) == len(query_res)
# for i in range(len(exp_res)):
# assert_entity_equal(exp=exp_res[i], actual=query_res[i])

View File

@ -11,6 +11,16 @@ ErrorMessage = {ErrorCode.ErrorOk: "",
ErrorCode.Error: "is illegal"}
class ErrorMap:
def __init__(self, err_code, err_msg):
self.err_code = err_code
self.err_msg = err_msg
class ConnectionErrorMessage(ExceptionsMessage):
FailConnect = "Fail connecting to server on %s:%s. Timeout"
ConnectExist = "The connection named %s already creating, but passed parameters don't match the configured parameters"
class CollectionErrorMessage(ExceptionsMessage):
CollNotLoaded = "collection %s was not loaded into memory"

View File

@ -1,5 +1,8 @@
import pytest
import common.common_type as ct
import common.common_func as cf
def pytest_addoption(parser):
parser.addoption("--ip", action="store", default="localhost", help="service's ip")
@ -20,6 +23,8 @@ def pytest_addoption(parser):
parser.addoption('--clean_log', action='store_true', default=False, help="clean log before testing")
parser.addoption('--schema', action='store', default="schema", help="schema of test interface")
parser.addoption('--err_msg', action='store', default="err_msg", help="error message of test")
parser.addoption('--term_expr', action='store', default="term_expr", help="expr of query quest")
parser.addoption('--check_content', action='store', default="check_content", help="content of check")
@pytest.fixture
@ -110,3 +115,26 @@ def schema(request):
@pytest.fixture
def err_msg(request):
return request.config.getoption("--err_msg")
@pytest.fixture
def term_expr(request):
return request.config.getoption("--term_expr")
@pytest.fixture
def check_content(request):
return request.config.getoption("--check_content")
""" fixture func """
@pytest.fixture(params=ct.get_invalid_strs)
def get_invalid_string(request):
yield request.param
@pytest.fixture(params=cf.gen_simple_index())
def get_index_param(request):
yield request.param

View File

@ -1,5 +1,5 @@
[pytest]
addopts = --host 192.168.1.239 --html=/tmp/ci_logs/report.html --self-contained-html
addopts = --host 10.98.0.11 --html=/tmp/ci_logs/report.html --self-contained-html
# -;addopts = --host 172.28.255.155 --html=/tmp/report.html
# python3 -W ignore -m pytest

View File

@ -1,8 +1,10 @@
import pytest
import random
from pymilvus_orm.default_config import DefaultConfig
from base.client_base import TestcaseBase
from common.code_mapping import ConnectionErrorMessage as cem
from common.code_mapping import CollectionErrorMessage as clem
from common import common_func as cf
from common import common_type as ct
from common.common_type import CaseLabel, CheckTasks
@ -433,7 +435,6 @@ class TestQueryBase(TestcaseBase):
check_items=CheckTasks.err_res, check_task=error)
# @pytest.mark.skip(reason="waiting for debug")
class TestQueryOperation(TestcaseBase):
"""
******************************************************************
@ -463,106 +464,146 @@ class TestQueryOperation(TestcaseBase):
collection_w.query(default_term_expr, check_task=CheckTasks.err_res,
check_items={ct.err_code: 0, ct.err_msg: cem.ConnectFirst})
def test_query_without_loading(self):
@pytest.mark.tags(ct.CaseLabel.L3)
@pytest.mark.parametrize("collection_name, data",
[(cf.gen_unique_str(prefix), cf.gen_default_list_data(ct.default_nb))])
def test_query_without_loading(self, collection_name, data):
"""
target: test query without loading
method: no loading before query
expected: raise exception
"""
c_name = cf.gen_unique_str(prefix)
collection_w = self.init_collection_wrap(name=c_name)
data = cf.gen_default_list_data(ct.default_nb)
# init a collection with default connection
collection_w = self.init_collection_wrap(name=collection_name)
# insert data to collection
collection_w.insert(data=data)
conn, _ = self.connection_wrap.get_connection()
conn.flush([c_name])
# check number of entities and that method calls the flush interface
assert collection_w.num_entities == ct.default_nb
error = {ct.err_code: 1, ct.err_msg: "can not find collection"}
collection_w.query(default_term_expr, check_task=CheckTasks.err_res, check_items=error)
def test_query_expr_single_term_array(self):
# query without load
collection_w.query(default_term_expr, check_task=CheckTasks.err_res,
check_items={ct.err_code: 1, ct.err_msg: clem.CollNotLoaded % collection_name})
@pytest.mark.tags(ct.CaseLabel.L3)
@pytest.mark.parametrize("term_expr", [f'{ct.default_int64_field_name} in [0]'])
def test_query_expr_single_term_array(self, term_expr):
"""
target: test query with single array term expr
method: query with single array value
expected: query result is one entity
"""
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
term_expr = f'{ct.default_int64_field_name} in [0]'
res, _ = collection_w.query(term_expr)
assert len(res) == 1
df = vectors[0]
assert res[0][ct.default_int64_field_name] == df[ct.default_int64_field_name].values.tolist()[0]
assert res[1][ct.default_float_field_name] == df[ct.default_float_field_name].values.tolist()[0]
assert res[2][ct.default_float_vec_field_name] == df[ct.default_float_vec_field_name].values.tolist()[0]
def test_query_binary_expr_single_term_array(self):
# init a collection and insert data
collection_w, vectors, binary_raw_vectors = self.init_collection_general(prefix, insert_data=True)
# query the first row of data
check_vec = vectors[0].iloc[:, [0, 1]][0:1].to_dict('records')
collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec})
@pytest.mark.tags(ct.CaseLabel.L3)
@pytest.mark.parametrize("term_expr", [f'{ct.default_int64_field_name} in [0]'])
def test_query_binary_expr_single_term_array(self, term_expr, check_content):
"""
target: test query with single array term expr
method: query with single array value
expected: query result is one entity
"""
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True, is_binary=True)
term_expr = f'{ct.default_int64_field_name} in [0]'
res, _ = collection_w.query(term_expr)
assert len(res) == 1
int_values = vectors[0][ct.default_int64_field_name].values.tolist()
float_values = vectors[0][ct.default_float_field_name].values.tolist()
vec_values = vectors[0][ct.default_float_vec_field_name].values.tolist()
assert res[0][ct.default_int64_field_name] == int_values[0]
assert res[1][ct.default_float_field_name] == float_values[0]
assert res[2][ct.default_float_vec_field_name] == vec_values[0]
# init a collection and insert data
collection_w, vectors, binary_raw_vectors = self.init_collection_general(prefix, insert_data=True,
is_binary=True)
# query the first row of data
check_vec = vectors[0].iloc[:, [0, 1]][0:1].to_dict('records')
collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec})
@pytest.mark.tags(ct.CaseLabel.L3)
def test_query_expr_all_term_array(self):
"""
target: test query with all array term expr
method: query with all array value
expected: verify query result
"""
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
# init a collection and insert data
collection_w, vectors, binary_raw_vectors = self.init_collection_general(prefix, insert_data=True)
# data preparation
int_values = vectors[0][ct.default_int64_field_name].values.tolist()
term_expr = f'{ct.default_int64_field_name} in {int_values}'
res, _ = collection_w.query(term_expr)
assert len(res) == ct.default_nb
for i in ct.default_nb:
assert res[i][ct.default_int64_field_name] == int_values[i]
check_vec = vectors[0].iloc[:, [0, 1]][0:len(int_values)].to_dict('records')
# query all array value
collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec})
@pytest.mark.tags(ct.CaseLabel.L3)
def test_query_expr_half_term_array(self):
"""
target: test query with half array term expr
method: query with half array value
expected: verify query result
"""
half = ct.default_nb // 2
collection_w, partition_w, _, df_default = self.insert_entities_into_two_partitions_in_half(half)
collection_w, partition_w, df_partition, df_default = self.insert_entities_into_two_partitions_in_half(half)
int_values = df_default[ct.default_int64_field_name].values.tolist()
float_values = df_default[ct.default_float_field_name].values.tolist()
vec_values = df_default[ct.default_float_vec_field_name].values.tolist()
term_expr = f'{ct.default_int64_field_name} in {int_values}'
res, _ = collection_w.query(term_expr)
assert len(res) == half
for i in half:
assert res[i][ct.default_int64_field_name] == int_values[i]
assert res[i][ct.default_float_field_name] == float_values[i]
assert res[i][ct.default_float_vec_field_name] == vec_values[i]
assert len(res) == len(int_values)
# half = ct.default_nb // 2
# collection_w, partition_w, _, df_default = self.insert_entities_into_two_partitions_in_half(half)
# int_values = df_default[ct.default_int64_field_name].values.tolist()
# float_values = df_default[ct.default_float_field_name].values.tolist()
# vec_values = df_default[ct.default_float_vec_field_name].values.tolist()
# term_expr = f'{ct.default_int64_field_name} in {int_values}'
# res, _ = collection_w.query(term_expr)
# assert len(res) == half
# for i in half:
# assert res[i][ct.default_int64_field_name] == int_values[i]
# assert res[i][ct.default_float_field_name] == float_values[i]
# assert res[i][ct.default_float_vec_field_name] == vec_values[i]
@pytest.mark.xfail(reason="fail")
@pytest.mark.tags(ct.CaseLabel.L3)
def test_query_expr_repeated_term_array(self):
"""
target: test query with repeated term array on primary field with unique value
method: query with repeated array value
expected: verify query result
"""
collection_w, vectors, _, = self.init_collection_general(prefix, insert_data=True)
int_values = [0, 0]
collection_w, vectors, binary_raw_vectors = self.init_collection_general(prefix, insert_data=True)
int_values = [0, 0, 0, 0]
term_expr = f'{ct.default_int64_field_name} in {int_values}'
res, _ = collection_w.query(term_expr)
assert len(res) == 1
assert res[0][ct.default_int64_field_name] == int_values[0]
def test_query_after_index(self, get_simple_index):
@pytest.mark.tags(ct.CaseLabel.L3)
def test_query_after_index(self):
"""
target: test query after creating index
method: query after index
expected: query result is correct
"""
collection_w, vectors, binary_raw_vectors = self.init_collection_general(prefix, insert_data=True)
default_field_name = ct.default_float_vec_field_name
default_index_params = {"index_type": "IVF_SQ8", "metric_type": "L2", "params": {"nlist": 64}}
index_name = ct.default_index_name
collection_w.create_index(default_field_name, default_index_params, index_name=index_name)
collection_w.load()
int_values = [0]
term_expr = f'{ct.default_int64_field_name} in {int_values}'
check_vec = vectors[0].iloc[:, [0, 1]][0:len(int_values)].to_dict('records')
collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec})
# entities, ids = init_data(connect, collection)
# assert len(ids) == ut.default_nb
# connect.create_index(collection, ut.default_float_vec_field_name, get_simple_index)
@ -570,12 +611,33 @@ class TestQueryOperation(TestcaseBase):
# res = connect.query(collection, default_term_expr)
# logging.getLogger().info(res)
@pytest.mark.xfail(reason='')
@pytest.mark.tags(ct.CaseLabel.L3)
def test_query_after_search(self):
"""
target: test query after search
method: query after search
expected: query result is correct
"""
limit = 1000
nb_old = 500
collection_w, vectors, binary_raw_vectors = self.init_collection_general(prefix, True, nb_old)
# 2. search for original data after load
vectors_s = [[random.random() for _ in range(ct.default_dim)] for _ in range(ct.default_nq)]
collection_w.search(vectors_s[:ct.default_nq], ct.default_float_vec_field_name,
ct.default_search_params, limit, "int64 >= 0",
check_task=CheckTasks.check_search_results,
check_items={"nq": ct.default_nq, "limit": nb_old})
# check number of entities and that method calls the flush interface
assert collection_w.num_entities == nb_old
term_expr = f'{ct.default_int64_field_name} in {default_term_expr}'
check_vec = vectors[0].iloc[:, [0, 1]][0:len(default_term_expr)].to_dict('records')
collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec})
# entities, ids = init_data(connect, collection)
# assert len(ids) == ut.default_nb
# top_k = 10
@ -588,23 +650,50 @@ class TestQueryOperation(TestcaseBase):
# query_res = connect.query(collection, default_term_expr)
# logging.getLogger().info(query_res)
@pytest.mark.tags(ct.CaseLabel.L3)
def test_query_partition_repeatedly(self):
"""
target: test query repeatedly on partition
method: query on partition twice
expected: verify query result
"""
conn = self._connect()
# create connection
self._connect()
# init collection
collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
# init partition
partition_w = self.init_partition_wrap(collection_wrap=collection_w)
# insert data to partition
df = cf.gen_default_dataframe_data(ct.default_nb)
partition_w.insert(df)
conn.flush([collection_w.name])
# check number of entities and that method calls the flush interface
assert collection_w.num_entities == ct.default_nb
# load partition
partition_w.load()
# query twice
res_one, _ = collection_w.query(default_term_expr, partition_names=[partition_w.name])
res_two, _ = collection_w.query(default_term_expr, partition_names=[partition_w.name])
assert res_one == res_two
# conn = self._connect()
# collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
# partition_w = self.init_partition_wrap(collection_wrap=collection_w)
# df = cf.gen_default_dataframe_data(ct.default_nb)
# partition_w.insert(df)
# conn.flush([collection_w.name])
# partition_w.load()
# res_one, _ = collection_w.query(default_term_expr, partition_names=[partition_w.name])
# res_two, _ = collection_w.query(default_term_expr, partition_names=[partition_w.name])
# assert res_one == res_two
@pytest.mark.tags(ct.CaseLabel.L3)
def test_query_another_partition(self):
"""
target: test query another partition
@ -614,11 +703,13 @@ class TestQueryOperation(TestcaseBase):
"""
half = ct.default_nb // 2
collection_w, partition_w, _, _ = self.insert_entities_into_two_partitions_in_half(half)
term_expr = f'{ct.default_int64_field_name} in [{half}]'
# half entity in _default partition rather than partition_w
res, _ = collection_w.query(term_expr, partition_names=[partition_w.name])
assert len(res) == 0
collection_w.query(term_expr, partition_names=[partition_w.name], check_task=CheckTasks.check_query_results,
check_items={exp_res: []})
@pytest.mark.tags(ct.CaseLabel.L3)
def test_query_multi_partitions_multi_results(self):
"""
target: test query on multi partitions and get multi results
@ -628,11 +719,13 @@ class TestQueryOperation(TestcaseBase):
"""
half = ct.default_nb // 2
collection_w, partition_w, _, _ = self.insert_entities_into_two_partitions_in_half(half)
term_expr = f'{ct.default_int64_field_name} in [{half - 1}, {half}]'
# half entity in _default, half-1 entity in partition_w
res, _ = collection_w.query(term_expr, partition_names=[ct.default_partition_name, partition_w.name])
assert len(res) == 2
@pytest.mark.tags(ct.CaseLabel.L3)
def test_query_multi_partitions_single_result(self):
"""
target: test query on multi partitions and get single result
@ -641,7 +734,8 @@ class TestQueryOperation(TestcaseBase):
expected: query from two partitions and get single result
"""
half = ct.default_nb // 2
collection_w, partition_w = self.insert_entities_into_two_partitions_in_half(half)
collection_w, partition_w, df_partition, df_default = self.insert_entities_into_two_partitions_in_half(half)
term_expr = f'{ct.default_int64_field_name} in [{half}]'
# half entity in _default
res, _ = collection_w.query(term_expr, partition_names=[ct.default_partition_name, partition_w.name])

View File

@ -8,7 +8,7 @@ class Error:
self.message = getattr(error, 'message', str(error))
log_row_length = 300
log_row_length = 3000
def api_request_catch():
@ -16,12 +16,13 @@ def api_request_catch():
def inner_wrapper(*args, **kwargs):
try:
res = func(*args, **kwargs)
# log.debug("(api_response) Response : %s " % str(res)[0:log_row_length])
log_res = str(res)[0:log_row_length] + '......' if len(str(res)) > log_row_length else str(res)
log.debug("(api_response) Response : %s " % log_res)
return res, True
except Exception as e:
log_e = str(e)[0:log_row_length] + '......' if len(str(e)) > log_row_length else str(e)
log.error(traceback.format_exc())
log.error("(api_response) [Milvus API Exception]%s: %s"
% (str(func), str(e)[0:log_row_length]))
log.error("(api_response) [Milvus API Exception]%s: %s" % (str(func), log_e))
return Error(e), False
return inner_wrapper
return wrapper
@ -36,8 +37,8 @@ def api_request(_list, **kwargs):
if len(_list) > 1:
for a in _list[1:]:
arg.append(a)
# log.debug("(api_request) Request: [%s] args: %s, kwargs: %s"
# % (str(func), str(arg)[0:log_row_length], str(kwargs)))
log_arg = str(arg)[0:log_row_length] + '......' if len(str(arg)) > log_row_length else str(arg)
log.debug("(api_request) Request: [%s] args: %s, kwargs: %s" % (str(func), log_arg, str(kwargs)))
return func(*arg, **kwargs)
return False, False