test: [2.5] add cases for async milvus client (#38699) (#38853)

pr: #38699
issue: #38697

Signed-off-by: ThreadDao <yufen.zong@zilliz.com>
ThreadDao committed via GitHub on 2024-12-31 11:30:52 +08:00
parent 71dea30d44
commit cf148f9bd2
9 changed files with 795 additions and 17 deletions


@@ -0,0 +1,174 @@
import asyncio
import sys
from typing import Optional, List, Union, Dict

from pymilvus import (
    AsyncMilvusClient,
    AnnSearchRequest,
    RRFRanker,
)
from pymilvus.orm.types import CONSISTENCY_STRONG
from pymilvus.orm.collection import CollectionSchema

from check.func_check import ResponseChecker
from utils.api_request import api_request, logger_interceptor


class AsyncMilvusClientWrapper:
    async_milvus_client = None

    def __init__(self, active_trace=False):
        self.active_trace = active_trace

    def init_async_client(self, uri: str = "http://localhost:19530",
                          user: str = "",
                          password: str = "",
                          db_name: str = "",
                          token: str = "",
                          timeout: Optional[float] = None,
                          active_trace=False,
                          check_task=None, check_items=None,
                          **kwargs):
        self.active_trace = active_trace
        # record the caller's name, in order to distinguish collections created with the same name
        func_name = sys._getframe().f_code.co_name
        res, is_succ = api_request([AsyncMilvusClient, uri, user, password, db_name, token,
                                    timeout], **kwargs)
        self.async_milvus_client = res if is_succ else None
        check_result = ResponseChecker(res, func_name, check_task, check_items, is_succ, **kwargs).run()
        return res, check_result

    @logger_interceptor()
    async def create_collection(self,
                                collection_name: str,
                                dimension: Optional[int] = None,
                                primary_field_name: str = "id",  # default is "id"
                                id_type: str = "int",  # or "string"
                                vector_field_name: str = "vector",  # default is "vector"
                                metric_type: str = "COSINE",
                                auto_id: bool = False,
                                timeout: Optional[float] = None,
                                schema: Optional[CollectionSchema] = None,
                                index_params=None,
                                **kwargs):
        kwargs["consistency_level"] = kwargs.get("consistency_level", CONSISTENCY_STRONG)
        return await self.async_milvus_client.create_collection(collection_name, dimension,
                                                                primary_field_name,
                                                                id_type, vector_field_name, metric_type,
                                                                auto_id,
                                                                timeout, schema, index_params, **kwargs)

    @logger_interceptor()
    async def drop_collection(self, collection_name: str, timeout: Optional[float] = None, **kwargs):
        return await self.async_milvus_client.drop_collection(collection_name, timeout, **kwargs)

    @logger_interceptor()
    async def load_collection(self, collection_name: str, timeout: Optional[float] = None, **kwargs):
        return await self.async_milvus_client.load_collection(collection_name, timeout, **kwargs)

    @logger_interceptor()
    async def create_index(self, collection_name: str, index_params, timeout: Optional[float] = None,
                           **kwargs):
        return await self.async_milvus_client.create_index(collection_name, index_params, timeout, **kwargs)

    @logger_interceptor()
    async def insert(self,
                     collection_name: str,
                     data: Union[Dict, List[Dict]],
                     timeout: Optional[float] = None,
                     partition_name: Optional[str] = "",
                     **kwargs):
        return await self.async_milvus_client.insert(collection_name, data, timeout, partition_name, **kwargs)

    @logger_interceptor()
    async def upsert(self,
                     collection_name: str,
                     data: Union[Dict, List[Dict]],
                     timeout: Optional[float] = None,
                     partition_name: Optional[str] = "",
                     **kwargs):
        return await self.async_milvus_client.upsert(collection_name, data, timeout, partition_name, **kwargs)

    @logger_interceptor()
    async def search(self,
                     collection_name: str,
                     data: Union[List[list], list],
                     filter: str = "",
                     limit: int = 10,
                     output_fields: Optional[List[str]] = None,
                     search_params: Optional[dict] = None,
                     timeout: Optional[float] = None,
                     partition_names: Optional[List[str]] = None,
                     anns_field: Optional[str] = None,
                     **kwargs):
        return await self.async_milvus_client.search(collection_name, data,
                                                     filter,
                                                     limit, output_fields, search_params,
                                                     timeout,
                                                     partition_names, anns_field, **kwargs)

    @logger_interceptor()
    async def hybrid_search(self,
                            collection_name: str,
                            reqs: List[AnnSearchRequest],
                            ranker: RRFRanker,
                            limit: int = 10,
                            output_fields: Optional[List[str]] = None,
                            timeout: Optional[float] = None,
                            partition_names: Optional[List[str]] = None,
                            **kwargs):
        return await self.async_milvus_client.hybrid_search(collection_name, reqs,
                                                            ranker,
                                                            limit, output_fields,
                                                            timeout, partition_names, **kwargs)

    @logger_interceptor()
    async def query(self,
                    collection_name: str,
                    filter: str = "",
                    output_fields: Optional[List[str]] = None,
                    timeout: Optional[float] = None,
                    ids: Optional[Union[List, str, int]] = None,
                    partition_names: Optional[List[str]] = None,
                    **kwargs):
        return await self.async_milvus_client.query(collection_name, filter,
                                                    output_fields, timeout,
                                                    ids, partition_names,
                                                    **kwargs)

    @logger_interceptor()
    async def get(self,
                  collection_name: str,
                  ids: Union[list, str, int],
                  output_fields: Optional[List[str]] = None,
                  timeout: Optional[float] = None,
                  partition_names: Optional[List[str]] = None,
                  **kwargs):
        return await self.async_milvus_client.get(collection_name, ids,
                                                  output_fields, timeout,
                                                  partition_names,
                                                  **kwargs)

    @logger_interceptor()
    async def delete(self,
                     collection_name: str,
                     ids: Optional[Union[list, str, int]] = None,
                     timeout: Optional[float] = None,
                     filter: Optional[str] = None,
                     partition_name: Optional[str] = None,
                     **kwargs):
        return await self.async_milvus_client.delete(collection_name, ids,
                                                     timeout, filter,
                                                     partition_name,
                                                     **kwargs)

    @classmethod
    def create_schema(cls, **kwargs):
        kwargs["check_fields"] = False  # do not check fields for now
        return CollectionSchema([], **kwargs)

    @logger_interceptor()
    async def close(self, **kwargs):
        return await self.async_milvus_client.close(**kwargs)
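For orientation, a minimal sketch of driving this wrapper directly; it assumes a reachable Milvus at the default URI, and the collection name is illustrative. In the suite the wrapper is reached through TestcaseBase.init_async_milvus_client() instead:

import asyncio

from base.async_milvus_client_wrapper import AsyncMilvusClientWrapper


async def main():
    wrapper = AsyncMilvusClientWrapper()
    # constructs the AsyncMilvusClient; must run while an event loop is live
    wrapper.init_async_client(uri="http://localhost:19530")
    # every @logger_interceptor() method returns a (result, check_result) tuple
    res, _ = await wrapper.create_collection("demo_coll", dimension=128)
    await wrapper.close()

asyncio.run(main())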


@@ -13,6 +13,7 @@ from base.index_wrapper import ApiIndexWrapper
from base.utility_wrapper import ApiUtilityWrapper
from base.schema_wrapper import ApiCollectionSchemaWrapper, ApiFieldSchemaWrapper
from base.high_level_api_wrapper import HighLevelApiWrapper
from base.async_milvus_client_wrapper import AsyncMilvusClientWrapper
from utils.util_log import test_log as log
from common import common_func as cf
from common import common_type as ct
@@ -35,6 +36,7 @@ class Base:
    collection_object_list = []
    resource_group_list = []
    high_level_api_wrap = None
    async_milvus_client_wrap = None
    skip_connection = False

    def setup_class(self):
@@ -59,6 +61,7 @@ class Base:
        self.field_schema_wrap = ApiFieldSchemaWrapper()
        self.database_wrap = ApiDatabaseWrapper()
        self.high_level_api_wrap = HighLevelApiWrapper()
        self.async_milvus_client_wrap = AsyncMilvusClientWrapper()

    def teardown_method(self, method):
        log.info(("*" * 35) + " teardown " + ("*" * 35))
@@ -166,6 +169,16 @@ class TestcaseBase(Base):
        log.info(f"server version: {server_version}")
        return res

    def init_async_milvus_client(self):
        uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}"
        kwargs = {
            "uri": uri,
            "user": cf.param_info.param_user,
            "password": cf.param_info.param_password,
            "token": cf.param_info.param_token,
        }
        self.async_milvus_client_wrap.init_async_client(**kwargs)

    def init_collection_wrap(self, name=None, schema=None, check_task=None, check_items=None,
                             enable_dynamic_field=False, with_json=True, **kwargs):
        name = cf.gen_unique_str('coll_') if name is None else name


@@ -1,5 +1,7 @@
import sys
import time
from typing import Optional

import timeout_decorator
from numpy import NaN
@@ -40,6 +42,13 @@ class HighLevelApiWrapper:
                              timeout=timeout, **kwargs).run()
        return res, check_result

    @trace()
    def close(self, client, check_task=None, check_items=None):
        func_name = sys._getframe().f_code.co_name
        res, is_succ = api_request([client.close])
        check_result = ResponseChecker(res, func_name, check_task, check_items, is_succ).run()
        return res, check_result

    @trace()
    def create_schema(self, client, timeout=None, check_task=None,
                      check_items=None, **kwargs):
@@ -103,6 +112,17 @@ class HighLevelApiWrapper:
                                       **kwargs).run()
        return res, check_result

    @trace()
    def get_collection_stats(self, client, collection_name, timeout=None, check_task=None, check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        kwargs.update({"timeout": timeout})
        func_name = sys._getframe().f_code.co_name
        res, check = api_request([client.get_collection_stats, collection_name], **kwargs)
        check_result = ResponseChecker(res, func_name, check_task, check_items, check,
                                       collection_name=collection_name, **kwargs).run()
        return res, check_result

    @trace()
    def search(self, client, collection_name, data, limit=10, filter=None, output_fields=None, search_params=None,
               timeout=None, check_task=None, check_items=None, **kwargs):
@@ -315,6 +335,16 @@ class HighLevelApiWrapper:
                                       **kwargs).run()
        return res, check_result

    @trace()
    def create_database(self, client, db_name, properties: Optional[dict] = None, check_task=None, check_items=None, **kwargs):
        func_name = sys._getframe().f_code.co_name
        res, check = api_request([client.create_database, db_name, properties], **kwargs)
        check_result = ResponseChecker(res, func_name, check_task,
                                       check_items, check,
                                       db_name=db_name, properties=properties,
                                       **kwargs).run()
        return res, check_result

    @trace()
    def create_partition(self, client, collection_name, partition_name, timeout=None, check_task=None, check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout


@@ -7,10 +7,21 @@ from common import common_func as cf
from common.common_type import CheckTasks, Connect_Object_Name
# from common.code_mapping import ErrorCode, ErrorMessage
from pymilvus import Collection, Partition, ResourceGroupInfo
-from utils.api_request import Error
import check.param_check as pc


class Error:
    def __init__(self, error):
        self.code = getattr(error, 'code', -1)
        self.message = getattr(error, 'message', str(error))

    def __str__(self):
        return f"Error(code={self.code}, message={self.message})"

    def __repr__(self):
        return f"Error(code={self.code}, message={self.message})"

class ResponseChecker:
    def __init__(self, response, func_name, check_task, check_items, is_succ=True, **kwargs):
        self.response = response  # response of api request


@@ -25,7 +25,7 @@ def pytest_addoption(parser):
    parser.addoption("--user", action="store", default="", help="user name for connection")
    parser.addoption("--password", action="store", default="", help="password for connection")
    parser.addoption("--db_name", action="store", default="default", help="database name for connection")
-    parser.addoption("--secure", type=bool, action="store", default=False, help="secure for connection")
+    parser.addoption("--secure", action="store", default=False, help="secure for connection")
    parser.addoption("--milvus_ns", action="store", default="chaos-testing", help="milvus_ns")
    parser.addoption("--http_port", action="store", default=19121, help="http's port")
    parser.addoption("--handler", action="store", default="GRPC", help="handler of request")
@@ -45,7 +45,7 @@ def pytest_addoption(parser):
    parser.addoption('--term_expr', action='store', default="term_expr", help="expr of query quest")
    parser.addoption('--check_content', action='store', default="check_content", help="content of check")
    parser.addoption('--field_name', action='store', default="field_name", help="field_name of index")
-    parser.addoption('--replica_num', type='int', action='store', default=ct.default_replica_num, help="memory replica number")
+    parser.addoption('--replica_num', action='store', default=ct.default_replica_num, help="memory replica number")
    parser.addoption('--minio_host', action='store', default="localhost", help="minio service's ip")
    parser.addoption('--uri', action='store', default="", help="uri for high level api")
    parser.addoption('--token', action='store', default="", help="token for high level api")


@@ -10,3 +10,6 @@ log_date_format = %Y-%m-%d %H:%M:%S
filterwarnings =
    ignore::DeprecationWarning
asyncio_default_fixture_loop_scope = function
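(pytest-asyncio 0.24 emits a deprecation warning when asyncio_default_fixture_loop_scope is left unset; pinning it to function keeps one event loop per test, which matches the per-test close() in the new test class's teardown.)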


@@ -4,7 +4,8 @@ requests==2.26.0
scikit-learn==1.1.3
timeout_decorator==0.5.0
ujson==5.5.0
-pytest==7.2.0
+pytest==8.3.4
pytest-asyncio==0.24.0
pytest-assume==2.4.3
pytest-timeout==1.3.3
pytest-repeat==0.8.0
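(The pytest pin moves from 7.2.0 to 8.3.4 along with the new dependency: pytest-asyncio 0.24.0 declares pytest >= 8.2 in its published requirements, so the old pin could not stay.)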


@@ -0,0 +1,509 @@
import random
import time
import pytest
import asyncio

from pymilvus.client.types import LoadState, DataType
from pymilvus import AnnSearchRequest, RRFRanker

from base.client_base import TestcaseBase
from common import common_func as cf
from common import common_type as ct
from common.common_type import CaseLabel, CheckTasks
from utils.util_log import test_log as log

pytestmark = pytest.mark.asyncio

prefix = "async"
async_default_nb = 5000
default_pk_name = "id"
default_vector_name = "vector"


class TestAsyncMilvusClient(TestcaseBase):

    def teardown_method(self, method):
        loop = asyncio.get_event_loop()
        loop.run_until_complete(self.async_milvus_client_wrap.close())
        super().teardown_method(method)

    @pytest.mark.tags(CaseLabel.L0)
    async def test_async_client_default(self):
        # init client
        milvus_client = self._connect(enable_milvus_client_api=True)
        self.init_async_milvus_client()

        # create collection
        c_name = cf.gen_unique_str(prefix)
        await self.async_milvus_client_wrap.create_collection(c_name, dimension=ct.default_dim)
        collections, _ = self.high_level_api_wrap.list_collections(milvus_client)
        assert c_name in collections

        # insert entities
        rows = [
            {default_pk_name: i, default_vector_name: [random.random() for _ in range(ct.default_dim)]}
            for i in range(async_default_nb)]
        start_time = time.time()
        tasks = []
        step = 1000
        for i in range(0, async_default_nb, step):
            task = self.async_milvus_client_wrap.insert(c_name, rows[i:i + step])
            tasks.append(task)
        insert_res = await asyncio.gather(*tasks)
        end_time = time.time()
        log.info("Total time: {:.2f} seconds".format(end_time - start_time))
        for r in insert_res:
            assert r[0]['insert_count'] == step

        # dql tasks
        tasks = []
        # search default
        vector = cf.gen_vectors(ct.default_nq, ct.default_dim)
        default_search_task = self.async_milvus_client_wrap.search(
            c_name, vector, limit=ct.default_limit,
            check_task=CheckTasks.check_search_results,
            check_items={"enable_milvus_client_api": True,
                         "nq": ct.default_nq,
                         "limit": ct.default_limit})
        tasks.append(default_search_task)
        # search with filter & search_params
        sp = {"metric_type": "COSINE", "params": {"ef": "96"}}
        filter_params_search_task = self.async_milvus_client_wrap.search(
            c_name, vector, limit=ct.default_limit,
            filter=f"{default_pk_name} > 10",
            search_params=sp,
            check_task=CheckTasks.check_search_results,
            check_items={"enable_milvus_client_api": True,
                         "nq": ct.default_nq,
                         "limit": ct.default_limit})
        tasks.append(filter_params_search_task)
        # search output fields
        output_search_task = self.async_milvus_client_wrap.search(
            c_name, vector, limit=ct.default_limit,
            output_fields=["*"],
            check_task=CheckTasks.check_search_results,
            check_items={"enable_milvus_client_api": True,
                         "nq": ct.default_nq,
                         "limit": ct.default_limit})
        tasks.append(output_search_task)
        # query with filter and default output "*"
        exp_query_res = [{default_pk_name: i} for i in range(ct.default_limit)]
        filter_query_task = self.async_milvus_client_wrap.query(
            c_name,
            filter=f"{default_pk_name} < {ct.default_limit}",
            output_fields=[default_pk_name],
            check_task=CheckTasks.check_query_results,
            check_items={"exp_res": exp_query_res,
                         "primary_field": default_pk_name})
        tasks.append(filter_query_task)
        # query with ids and output all fields
        ids_query_task = self.async_milvus_client_wrap.query(
            c_name,
            ids=[i for i in range(ct.default_limit)],
            output_fields=["*"],
            check_task=CheckTasks.check_query_results,
            check_items={"exp_res": rows[:ct.default_limit],
                         "with_vec": True,
                         "primary_field": default_pk_name})
        tasks.append(ids_query_task)
        # get with ids
        get_task = self.async_milvus_client_wrap.get(
            c_name,
            ids=[0, 1],
            output_fields=[default_pk_name, default_vector_name],
            check_task=CheckTasks.check_query_results,
            check_items={"exp_res": rows[:2], "with_vec": True,
                         "primary_field": default_pk_name})
        tasks.append(get_task)
        await asyncio.gather(*tasks)

    @pytest.mark.tags(CaseLabel.L0)
    async def test_async_client_partition(self):
        # init client
        milvus_client = self._connect(enable_milvus_client_api=True)
        self.init_async_milvus_client()

        # create collection & partition
        c_name = cf.gen_unique_str(prefix)
        p_name = cf.gen_unique_str("par")
        await self.async_milvus_client_wrap.create_collection(c_name, dimension=ct.default_dim)
        collections, _ = self.high_level_api_wrap.list_collections(milvus_client)
        assert c_name in collections
        self.high_level_api_wrap.create_partition(milvus_client, c_name, p_name)
        partitions, _ = self.high_level_api_wrap.list_partitions(milvus_client, c_name)
        assert p_name in partitions

        # insert entities
        rows = [
            {default_pk_name: i, default_vector_name: [random.random() for _ in range(ct.default_dim)]}
            for i in range(async_default_nb)]
        start_time = time.time()
        tasks = []
        step = 1000
        for i in range(0, async_default_nb, step):
            task = self.async_milvus_client_wrap.insert(c_name, rows[i:i + step], partition_name=p_name)
            tasks.append(task)
        insert_res = await asyncio.gather(*tasks)
        end_time = time.time()
        log.info("Total time: {:.2f} seconds".format(end_time - start_time))
        for r in insert_res:
            assert r[0]['insert_count'] == step

        # count from default partition
        count_res, _ = await self.async_milvus_client_wrap.query(
            c_name, output_fields=["count(*)"], partition_names=[ct.default_partition_name])
        assert count_res[0]["count(*)"] == 0

        # dql tasks
        tasks = []
        # search default
        vector = cf.gen_vectors(ct.default_nq, ct.default_dim)
        default_search_task = self.async_milvus_client_wrap.search(
            c_name, vector, limit=ct.default_limit,
            partition_names=[p_name],
            check_task=CheckTasks.check_search_results,
            check_items={"enable_milvus_client_api": True,
                         "nq": ct.default_nq,
                         "limit": ct.default_limit})
        tasks.append(default_search_task)
        # search with filter & search_params
        sp = {"metric_type": "COSINE", "params": {"ef": "96"}}
        filter_params_search_task = self.async_milvus_client_wrap.search(
            c_name, vector, limit=ct.default_limit,
            filter=f"{default_pk_name} > 10",
            search_params=sp,
            partition_names=[p_name],
            check_task=CheckTasks.check_search_results,
            check_items={"enable_milvus_client_api": True,
                         "nq": ct.default_nq,
                         "limit": ct.default_limit})
        tasks.append(filter_params_search_task)
        # search output fields
        output_search_task = self.async_milvus_client_wrap.search(
            c_name, vector, limit=ct.default_limit,
            output_fields=["*"],
            partition_names=[p_name],
            check_task=CheckTasks.check_search_results,
            check_items={"enable_milvus_client_api": True,
                         "nq": ct.default_nq,
                         "limit": ct.default_limit})
        tasks.append(output_search_task)
        # query with filter and default output "*"
        exp_query_res = [{default_pk_name: i} for i in range(ct.default_limit)]
        filter_query_task = self.async_milvus_client_wrap.query(
            c_name,
            filter=f"{default_pk_name} < {ct.default_limit}",
            output_fields=[default_pk_name],
            partition_names=[p_name],
            check_task=CheckTasks.check_query_results,
            check_items={"exp_res": exp_query_res,
                         "primary_field": default_pk_name})
        tasks.append(filter_query_task)
        # query with ids and output all fields
        ids_query_task = self.async_milvus_client_wrap.query(
            c_name,
            ids=[i for i in range(ct.default_limit)],
            output_fields=["*"],
            partition_names=[p_name],
            check_task=CheckTasks.check_query_results,
            check_items={"exp_res": rows[:ct.default_limit],
                         "with_vec": True,
                         "primary_field": default_pk_name})
        tasks.append(ids_query_task)
        # get with ids
        get_task = self.async_milvus_client_wrap.get(
            c_name,
            ids=[0, 1], partition_names=[p_name],
            output_fields=[default_pk_name, default_vector_name],
            check_task=CheckTasks.check_query_results,
            check_items={"exp_res": rows[:2], "with_vec": True,
                         "primary_field": default_pk_name})
        tasks.append(get_task)
        await asyncio.gather(*tasks)
    @pytest.mark.tags(CaseLabel.L0)
    async def test_async_client_with_schema(self):
        # init client
        milvus_client = self._connect(enable_milvus_client_api=True)
        self.init_async_milvus_client()

        # create collection
        c_name = cf.gen_unique_str(prefix)
        schema = self.async_milvus_client_wrap.create_schema(auto_id=False,
                                                             partition_key_field=ct.default_int64_field_name)
        schema.add_field(ct.default_string_field_name, DataType.VARCHAR, max_length=100, is_primary=True)
        schema.add_field(ct.default_int64_field_name, DataType.INT64, is_partition_key=True)
        schema.add_field(ct.default_float_vec_field_name, DataType.FLOAT_VECTOR, dim=ct.default_dim)
        schema.add_field(default_vector_name, DataType.FLOAT_VECTOR, dim=ct.default_dim)
        await self.async_milvus_client_wrap.create_collection(c_name, schema=schema)
        collections, _ = self.high_level_api_wrap.list_collections(milvus_client)
        assert c_name in collections

        # insert entities
        rows = [
            {ct.default_string_field_name: str(i),
             ct.default_int64_field_name: i,
             ct.default_float_vec_field_name: [random.random() for _ in range(ct.default_dim)],
             default_vector_name: [random.random() for _ in range(ct.default_dim)],
             } for i in range(async_default_nb)]
        start_time = time.time()
        tasks = []
        step = 1000
        for i in range(0, async_default_nb, step):
            task = self.async_milvus_client_wrap.insert(c_name, rows[i:i + step])
            tasks.append(task)
        insert_res = await asyncio.gather(*tasks)
        end_time = time.time()
        log.info("Total time: {:.2f} seconds".format(end_time - start_time))
        for r in insert_res:
            assert r[0]['insert_count'] == step

        # flush
        self.high_level_api_wrap.flush(milvus_client, c_name)
        stats, _ = self.high_level_api_wrap.get_collection_stats(milvus_client, c_name)
        assert stats["row_count"] == async_default_nb

        # create index -> load
        index_params, _ = self.high_level_api_wrap.prepare_index_params(
            milvus_client,
            field_name=ct.default_float_vec_field_name,
            index_type="HNSW", metric_type="COSINE", M=30,
            efConstruction=200)
        index_params.add_index(field_name=default_vector_name, index_type="IVF_SQ8",
                               metric_type="L2", nlist=32)
        await self.async_milvus_client_wrap.create_index(c_name, index_params)
        await self.async_milvus_client_wrap.load_collection(c_name)
        _index, _ = self.high_level_api_wrap.describe_index(milvus_client, c_name, default_vector_name)
        assert _index["indexed_rows"] == async_default_nb
        assert _index["state"] == "Finished"
        _load, _ = self.high_level_api_wrap.get_load_state(milvus_client, c_name)
        assert _load["state"] == LoadState.Loaded

        # dql tasks
        tasks = []
        # search default
        vector = cf.gen_vectors(ct.default_nq, ct.default_dim)
        default_search_task = self.async_milvus_client_wrap.search(
            c_name, vector, limit=ct.default_limit,
            anns_field=ct.default_float_vec_field_name,
            search_params={"metric_type": "COSINE",
                           "params": {"ef": "96"}},
            check_task=CheckTasks.check_search_results,
            check_items={"enable_milvus_client_api": True,
                         "nq": ct.default_nq,
                         "limit": ct.default_limit})
        tasks.append(default_search_task)
        # hybrid_search
        search_param = {
            "data": cf.gen_vectors(ct.default_nq, ct.default_dim, vector_data_type="FLOAT_VECTOR"),
            "anns_field": ct.default_float_vec_field_name,
            "param": {"metric_type": "COSINE", "params": {"ef": "96"}},
            "limit": ct.default_limit,
            "expr": f"{ct.default_int64_field_name} > 10"}
        req = AnnSearchRequest(**search_param)
        search_param2 = {
            "data": cf.gen_vectors(ct.default_nq, ct.default_dim, vector_data_type="FLOAT_VECTOR"),
            "anns_field": default_vector_name,
            "param": {"metric_type": "L2", "params": {"nprobe": "32"}},
            "limit": ct.default_limit
        }
        req2 = AnnSearchRequest(**search_param2)
        _output_fields = [ct.default_int64_field_name, ct.default_string_field_name]
        filter_params_search_task = self.async_milvus_client_wrap.hybrid_search(
            c_name, [req, req2], RRFRanker(),
            limit=5,
            check_task=CheckTasks.check_search_results,
            check_items={
                "enable_milvus_client_api": True,
                "nq": ct.default_nq,
                "limit": 5})
        tasks.append(filter_params_search_task)
        # get with ids
        get_task = self.async_milvus_client_wrap.get(
            c_name, ids=['0', '1'],
            output_fields=[ct.default_int64_field_name, ct.default_string_field_name])
        tasks.append(get_task)
        await asyncio.gather(*tasks)

    @pytest.mark.tags(CaseLabel.L0)
    async def test_async_client_dml(self):
        # init client
        milvus_client = self._connect(enable_milvus_client_api=True)
        self.init_async_milvus_client()

        # create collection
        c_name = cf.gen_unique_str(prefix)
        await self.async_milvus_client_wrap.create_collection(c_name, dimension=ct.default_dim)
        collections, _ = self.high_level_api_wrap.list_collections(milvus_client)
        assert c_name in collections

        # insert entities
        rows = [
            {default_pk_name: i, default_vector_name: [random.random() for _ in range(ct.default_dim)]}
            for i in range(ct.default_nb)]
        start_time = time.time()
        tasks = []
        step = 1000
        for i in range(0, ct.default_nb, step):
            task = self.async_milvus_client_wrap.insert(c_name, rows[i:i + step])
            tasks.append(task)
        insert_res = await asyncio.gather(*tasks)
        end_time = time.time()
        log.info("Total time: {:.2f} seconds".format(end_time - start_time))
        for r in insert_res:
            assert r[0]['insert_count'] == step

        # dml tasks
        # query id -> upsert id -> query id -> delete id -> query id
        _id = 10
        get_res, _ = await self.async_milvus_client_wrap.get(c_name, ids=[_id],
                                                             output_fields=[default_pk_name, default_vector_name])
        assert len(get_res) == 1
        # upsert
        upsert_row = [{
            default_pk_name: _id, default_vector_name: [random.random() for _ in range(ct.default_dim)]
        }]
        upsert_res, _ = await self.async_milvus_client_wrap.upsert(c_name, upsert_row)
        assert upsert_res["upsert_count"] == 1
        # get _id after upsert
        get_res, _ = await self.async_milvus_client_wrap.get(c_name, ids=[_id],
                                                             output_fields=[default_pk_name, default_vector_name])
        for j in range(5):
            assert abs(get_res[0][default_vector_name][j] - upsert_row[0][default_vector_name][j]) < ct.epsilon
        # delete
        del_res, _ = await self.async_milvus_client_wrap.delete(c_name, ids=[_id])
        assert del_res["delete_count"] == 1
        # query after delete
        get_res, _ = await self.async_milvus_client_wrap.get(c_name, ids=[_id],
                                                             output_fields=[default_pk_name, default_vector_name])
        assert len(get_res) == 0

    @pytest.mark.tags(CaseLabel.L2)
    async def test_async_client_with_db(self):
        # init client
        milvus_client = self._connect(enable_milvus_client_api=True)
        db_name = cf.gen_unique_str("db")
        self.high_level_api_wrap.create_database(milvus_client, db_name)
        self.high_level_api_wrap.close(milvus_client)
        uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}"
        milvus_client, _ = self.connection_wrap.MilvusClient(uri=uri, db_name=db_name)
        self.async_milvus_client_wrap.init_async_client(uri, db_name=db_name)

        # create collection
        c_name = cf.gen_unique_str(prefix)
        await self.async_milvus_client_wrap.create_collection(c_name, dimension=ct.default_dim)
        collections, _ = self.high_level_api_wrap.list_collections(milvus_client)
        assert c_name in collections

        # insert entities
        rows = [
            {default_pk_name: i, default_vector_name: [random.random() for _ in range(ct.default_dim)]}
            for i in range(async_default_nb)]
        start_time = time.time()
        tasks = []
        step = 1000
        for i in range(0, async_default_nb, step):
            task = self.async_milvus_client_wrap.insert(c_name, rows[i:i + step])
            tasks.append(task)
        insert_res = await asyncio.gather(*tasks)
        end_time = time.time()
        log.info("Total time: {:.2f} seconds".format(end_time - start_time))
        for r in insert_res:
            assert r[0]['insert_count'] == step

        # dql tasks
        tasks = []
        # search default
        vector = cf.gen_vectors(ct.default_nq, ct.default_dim)
        default_search_task = self.async_milvus_client_wrap.search(
            c_name, vector, limit=ct.default_limit,
            check_task=CheckTasks.check_search_results,
            check_items={"enable_milvus_client_api": True,
                         "nq": ct.default_nq,
                         "limit": ct.default_limit})
        tasks.append(default_search_task)
        # query with filter and default output "*"
        exp_query_res = [{default_pk_name: i} for i in range(ct.default_limit)]
        filter_query_task = self.async_milvus_client_wrap.query(
            c_name,
            filter=f"{default_pk_name} < {ct.default_limit}",
            output_fields=[default_pk_name],
            check_task=CheckTasks.check_query_results,
            check_items={"exp_res": exp_query_res,
                         "primary_field": default_pk_name})
        tasks.append(filter_query_task)
        # get with ids
        get_task = self.async_milvus_client_wrap.get(
            c_name,
            ids=[0, 1],
            output_fields=[default_pk_name, default_vector_name],
            check_task=CheckTasks.check_query_results,
            check_items={"exp_res": rows[:2], "with_vec": True,
                         "primary_field": default_pk_name})
        tasks.append(get_task)
        await asyncio.gather(*tasks)

    @pytest.mark.tags(CaseLabel.L0)
    async def test_async_client_close(self):
        # init async client
        uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}"
        self.async_milvus_client_wrap.init_async_client(uri)

        # create collection
        c_name = cf.gen_unique_str(prefix)
        await self.async_milvus_client_wrap.create_collection(c_name, dimension=ct.default_dim)

        # close -> search raises an error
        await self.async_milvus_client_wrap.close()
        vector = cf.gen_vectors(1, ct.default_dim)
        error = {ct.err_code: 1, ct.err_msg: "should create connection first"}
        await self.async_milvus_client_wrap.search(c_name, vector, check_task=CheckTasks.err_res, check_items=error)

    @pytest.mark.tags(CaseLabel.L3)
    @pytest.mark.skip("connect with zilliz cloud")
    async def test_async_client_with_token(self):
        # init client
        milvus_client = self._connect(enable_milvus_client_api=True)
        uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}"
        token = cf.param_info.param_token
        milvus_client, _ = self.connection_wrap.MilvusClient(uri=uri, token=token)
        self.async_milvus_client_wrap.init_async_client(uri, token=token)

        # create collection
        c_name = cf.gen_unique_str(prefix)
        await self.async_milvus_client_wrap.create_collection(c_name, dimension=ct.default_dim)
        collections, _ = self.high_level_api_wrap.list_collections(milvus_client)
        assert c_name in collections

        # insert entities
        rows = [
            {default_pk_name: i, default_vector_name: [random.random() for _ in range(ct.default_dim)]}
            for i in range(ct.default_nb)]
        start_time = time.time()
        tasks = []
        step = 1000
        for i in range(0, ct.default_nb, step):
            task = self.async_milvus_client_wrap.insert(c_name, rows[i:i + step])
            tasks.append(task)
        insert_res = await asyncio.gather(*tasks)
        end_time = time.time()
        log.info("Total time: {:.2f} seconds".format(end_time - start_time))
        for r in insert_res:
            assert r[0]['insert_count'] == step

        # dql tasks
        tasks = []
        # search default
        vector = cf.gen_vectors(ct.default_nq, ct.default_dim)
        default_search_task = self.async_milvus_client_wrap.search(
            c_name, vector, limit=ct.default_limit,
            check_task=CheckTasks.check_search_results,
            check_items={"enable_milvus_client_api": True,
                         "nq": ct.default_nq,
                         "limit": ct.default_limit})
        tasks.append(default_search_task)
        # query with filter and default output "*"
        exp_query_res = [{default_pk_name: i} for i in range(ct.default_limit)]
        filter_query_task = self.async_milvus_client_wrap.query(
            c_name,
            filter=f"{default_pk_name} < {ct.default_limit}",
            output_fields=[default_pk_name],
            check_task=CheckTasks.check_query_results,
            check_items={"exp_res": exp_query_res,
                         "primary_field": default_pk_name})
        tasks.append(filter_query_task)
        await asyncio.gather(*tasks)
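The insert pattern these cases repeat (slice the rows, fire one insert coroutine per chunk, then gather) distilled into a standalone sketch; the function and argument names here are illustrative, not part of the suite:

import asyncio


async def chunked_insert(client_wrap, collection_name, rows, step=1000):
    # one coroutine per chunk; nothing executes until gather schedules them
    tasks = [client_wrap.insert(collection_name, rows[i:i + step])
             for i in range(0, len(rows), step)]
    results = await asyncio.gather(*tasks)
    # each wrapped call yields (result, check_result); sum the per-chunk counts
    return sum(res[0]["insert_count"] for res in results)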


@@ -1,24 +1,14 @@
import sys
import traceback
import copy
import os

from check.func_check import ResponseChecker, Error
from utils.util_log import test_log as log

# enable_traceback = os.getenv('ENABLE_TRACEBACK', "True")
# log.info(f"enable_traceback:{enable_traceback}")

-class Error:
-    def __init__(self, error):
-        self.code = getattr(error, 'code', -1)
-        self.message = getattr(error, 'message', str(error))
-    def __str__(self):
-        return f"Error(code={self.code}, message={self.message})"
-    def __repr__(self):
-        return f"Error(code={self.code}, message={self.message})"

log_row_length = 300
@@ -62,3 +52,50 @@ def api_request(_list, **kwargs):
        log.debug("(api_request) : [%s] args: %s, kwargs: %s" % (func.__qualname__, log_arg, str(kwargs)))
        return func(*arg, **kwargs)
    return False, False


def logger_interceptor():
    def wrapper(func):
        def log_request(*arg, **kwargs):
            arg = arg[1:]
            arg_str = str(arg)
            log_arg = arg_str[0:log_row_length] + '......' if len(arg_str) > log_row_length else arg_str
            if kwargs.get("enable_traceback", True):
                log.debug("(api_request) : [%s] args: %s, kwargs: %s" % (func.__name__, log_arg, str(kwargs)))

        def log_response(res, **kwargs):
            if kwargs.get("enable_traceback", True):
                res_str = str(res)
                log_res = res_str[0:log_row_length] + '......' if len(res_str) > log_row_length else res_str
                log.debug("(api_response) : [%s] %s " % (func.__name__, log_res))
            return res, True

        async def handler(*args, **kwargs):
            _kwargs = copy.deepcopy(kwargs)
            _kwargs.pop("enable_traceback", None)
            check_task = kwargs.get("check_task", None)
            check_items = kwargs.get("check_items", None)
            try:
                # log request
                log_request(*args, **_kwargs)
                # exec func
                res = await func(*args, **_kwargs)
                # log response
                log_response(res, **_kwargs)
                # check_response
                check_res = ResponseChecker(res, sys._getframe().f_code.co_name, check_task, check_items, True).run()
                return res, check_res
            except Exception as e:
                log.error(str(e))
                e_str = str(e)
                log_e = e_str[0:log_row_length] + '......' if len(e_str) > log_row_length else e_str
                if kwargs.get("enable_traceback", True):
                    log.error(traceback.format_exc())
                log.error("(api_response) : %s" % log_e)
                check_res = ResponseChecker(Error(e), sys._getframe().f_code.co_name, check_task,
                                            check_items, False).run()
                return Error(e), check_res
        return handler
    return wrapper
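A minimal sketch of applying the interceptor outside the suite; the class and method below are hypothetical, while in the tree the decorator wraps the AsyncMilvusClientWrapper methods shown earlier:

class EchoClient:
    @logger_interceptor()
    async def echo(self, message):
        # the interceptor logs the call, awaits this body, runs ResponseChecker,
        # and hands back (result, check_result); exceptions become (Error(e), check_result)
        return message


# usage (inside a running event loop):
#   res, check = await EchoClient().echo("hi")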