mirror of https://github.com/milvus-io/milvus.git
pr: #38699 issue: #38697 Signed-off-by: ThreadDao <yufen.zong@zilliz.com>pull/38872/head
parent
71dea30d44
commit
cf148f9bd2
|
@ -0,0 +1,174 @@
|
|||
import asyncio
|
||||
import sys
|
||||
from typing import Optional, List, Union, Dict
|
||||
|
||||
from pymilvus import (
|
||||
AsyncMilvusClient,
|
||||
AnnSearchRequest,
|
||||
RRFRanker,
|
||||
)
|
||||
from pymilvus.orm.types import CONSISTENCY_STRONG
|
||||
from pymilvus.orm.collection import CollectionSchema
|
||||
|
||||
from check.func_check import ResponseChecker
|
||||
from utils.api_request import api_request, logger_interceptor
|
||||
|
||||
|
||||
class AsyncMilvusClientWrapper:
|
||||
async_milvus_client = None
|
||||
|
||||
def __init__(self, active_trace=False):
|
||||
self.active_trace = active_trace
|
||||
|
||||
def init_async_client(self, uri: str = "http://localhost:19530",
|
||||
user: str = "",
|
||||
password: str = "",
|
||||
db_name: str = "",
|
||||
token: str = "",
|
||||
timeout: Optional[float] = None,
|
||||
active_trace=False,
|
||||
check_task=None, check_items=None,
|
||||
**kwargs):
|
||||
self.active_trace = active_trace
|
||||
|
||||
""" In order to distinguish the same name of collection """
|
||||
func_name = sys._getframe().f_code.co_name
|
||||
res, is_succ = api_request([AsyncMilvusClient, uri, user, password, db_name, token,
|
||||
timeout], **kwargs)
|
||||
self.async_milvus_client = res if is_succ else None
|
||||
check_result = ResponseChecker(res, func_name, check_task, check_items, is_succ, **kwargs).run()
|
||||
return res, check_result
|
||||
|
||||
@logger_interceptor()
|
||||
async def create_collection(self,
|
||||
collection_name: str,
|
||||
dimension: Optional[int] = None,
|
||||
primary_field_name: str = "id", # default is "id"
|
||||
id_type: str = "int", # or "string",
|
||||
vector_field_name: str = "vector", # default is "vector"
|
||||
metric_type: str = "COSINE",
|
||||
auto_id: bool = False,
|
||||
timeout: Optional[float] = None,
|
||||
schema: Optional[CollectionSchema] = None,
|
||||
index_params=None,
|
||||
**kwargs):
|
||||
kwargs["consistency_level"] = kwargs.get("consistency_level", CONSISTENCY_STRONG)
|
||||
|
||||
return await self.async_milvus_client.create_collection(collection_name, dimension,
|
||||
primary_field_name,
|
||||
id_type, vector_field_name, metric_type,
|
||||
auto_id,
|
||||
timeout, schema, index_params, **kwargs)
|
||||
|
||||
@logger_interceptor()
|
||||
async def drop_collection(self, collection_name: str, timeout: Optional[float] = None, **kwargs):
|
||||
return await self.async_milvus_client.drop_collection(collection_name, timeout, **kwargs)
|
||||
|
||||
@logger_interceptor()
|
||||
async def load_collection(self, collection_name: str, timeout: Optional[float] = None, **kwargs):
|
||||
return await self.async_milvus_client.load_collection(collection_name, timeout, **kwargs)
|
||||
|
||||
@logger_interceptor()
|
||||
async def create_index(self, collection_name: str, index_params, timeout: Optional[float] = None,
|
||||
**kwargs):
|
||||
return await self.async_milvus_client.create_index(collection_name, index_params, timeout, **kwargs)
|
||||
|
||||
@logger_interceptor()
|
||||
async def insert(self,
|
||||
collection_name: str,
|
||||
data: Union[Dict, List[Dict]],
|
||||
timeout: Optional[float] = None,
|
||||
partition_name: Optional[str] = "",
|
||||
**kwargs):
|
||||
return await self.async_milvus_client.insert(collection_name, data, timeout, partition_name, **kwargs)
|
||||
|
||||
@logger_interceptor()
|
||||
async def upsert(self,
|
||||
collection_name: str,
|
||||
data: Union[Dict, List[Dict]],
|
||||
timeout: Optional[float] = None,
|
||||
partition_name: Optional[str] = "",
|
||||
**kwargs):
|
||||
return await self.async_milvus_client.upsert(collection_name, data, timeout, partition_name, **kwargs)
|
||||
|
||||
@logger_interceptor()
|
||||
async def search(self,
|
||||
collection_name: str,
|
||||
data: Union[List[list], list],
|
||||
filter: str = "",
|
||||
limit: int = 10,
|
||||
output_fields: Optional[List[str]] = None,
|
||||
search_params: Optional[dict] = None,
|
||||
timeout: Optional[float] = None,
|
||||
partition_names: Optional[List[str]] = None,
|
||||
anns_field: Optional[str] = None,
|
||||
**kwargs):
|
||||
return await self.async_milvus_client.search(collection_name, data,
|
||||
filter,
|
||||
limit, output_fields, search_params,
|
||||
timeout,
|
||||
partition_names, anns_field, **kwargs)
|
||||
|
||||
@logger_interceptor()
|
||||
async def hybrid_search(self,
|
||||
collection_name: str,
|
||||
reqs: List[AnnSearchRequest],
|
||||
ranker: RRFRanker,
|
||||
limit: int = 10,
|
||||
output_fields: Optional[List[str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
partition_names: Optional[List[str]] = None,
|
||||
**kwargs):
|
||||
return await self.async_milvus_client.hybrid_search(collection_name, reqs,
|
||||
ranker,
|
||||
limit, output_fields,
|
||||
timeout, partition_names, **kwargs)
|
||||
|
||||
@logger_interceptor()
|
||||
async def query(self,
|
||||
collection_name: str,
|
||||
filter: str = "",
|
||||
output_fields: Optional[List[str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
ids: Optional[Union[List, str, int]] = None,
|
||||
partition_names: Optional[List[str]] = None,
|
||||
**kwargs):
|
||||
return await self.async_milvus_client.query(collection_name, filter,
|
||||
output_fields, timeout,
|
||||
ids, partition_names,
|
||||
**kwargs)
|
||||
|
||||
@logger_interceptor()
|
||||
async def get(self,
|
||||
collection_name: str,
|
||||
ids: Union[list, str, int],
|
||||
output_fields: Optional[List[str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
partition_names: Optional[List[str]] = None,
|
||||
**kwargs):
|
||||
return await self.async_milvus_client.get(collection_name, ids,
|
||||
output_fields, timeout,
|
||||
partition_names,
|
||||
**kwargs)
|
||||
|
||||
@logger_interceptor()
|
||||
async def delete(self,
|
||||
collection_name: str,
|
||||
ids: Optional[Union[list, str, int]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
filter: Optional[str] = None,
|
||||
partition_name: Optional[str] = None,
|
||||
**kwargs):
|
||||
return await self.async_milvus_client.delete(collection_name, ids,
|
||||
timeout, filter,
|
||||
partition_name,
|
||||
**kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_schema(cls, **kwargs):
|
||||
kwargs["check_fields"] = False # do not check fields for now
|
||||
return CollectionSchema([], **kwargs)
|
||||
|
||||
@logger_interceptor()
|
||||
async def close(self, **kwargs):
|
||||
return await self.async_milvus_client.close(**kwargs)
|
|
@ -13,6 +13,7 @@ from base.index_wrapper import ApiIndexWrapper
|
|||
from base.utility_wrapper import ApiUtilityWrapper
|
||||
from base.schema_wrapper import ApiCollectionSchemaWrapper, ApiFieldSchemaWrapper
|
||||
from base.high_level_api_wrapper import HighLevelApiWrapper
|
||||
from base.async_milvus_client_wrapper import AsyncMilvusClientWrapper
|
||||
from utils.util_log import test_log as log
|
||||
from common import common_func as cf
|
||||
from common import common_type as ct
|
||||
|
@ -35,6 +36,7 @@ class Base:
|
|||
collection_object_list = []
|
||||
resource_group_list = []
|
||||
high_level_api_wrap = None
|
||||
async_milvus_client_wrap = None
|
||||
skip_connection = False
|
||||
|
||||
def setup_class(self):
|
||||
|
@ -59,6 +61,7 @@ class Base:
|
|||
self.field_schema_wrap = ApiFieldSchemaWrapper()
|
||||
self.database_wrap = ApiDatabaseWrapper()
|
||||
self.high_level_api_wrap = HighLevelApiWrapper()
|
||||
self.async_milvus_client_wrap = AsyncMilvusClientWrapper()
|
||||
|
||||
def teardown_method(self, method):
|
||||
log.info(("*" * 35) + " teardown " + ("*" * 35))
|
||||
|
@ -166,6 +169,16 @@ class TestcaseBase(Base):
|
|||
log.info(f"server version: {server_version}")
|
||||
return res
|
||||
|
||||
def init_async_milvus_client(self):
|
||||
uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}"
|
||||
kwargs = {
|
||||
"uri": uri,
|
||||
"user": cf.param_info.param_user,
|
||||
"password": cf.param_info.param_password,
|
||||
"token": cf.param_info.param_token,
|
||||
}
|
||||
self.async_milvus_client_wrap.init_async_client(**kwargs)
|
||||
|
||||
def init_collection_wrap(self, name=None, schema=None, check_task=None, check_items=None,
|
||||
enable_dynamic_field=False, with_json=True, **kwargs):
|
||||
name = cf.gen_unique_str('coll_') if name is None else name
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
import sys
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import timeout_decorator
|
||||
from numpy import NaN
|
||||
|
||||
|
@ -40,6 +42,13 @@ class HighLevelApiWrapper:
|
|||
timeout=timeout, **kwargs).run()
|
||||
return res, check_result
|
||||
|
||||
@trace()
|
||||
def close(self, client, check_task=None, check_items=None):
|
||||
func_name = sys._getframe().f_code.co_name
|
||||
res, is_succ = api_request([client.close])
|
||||
check_result = ResponseChecker(res, func_name, check_task, check_items, is_succ).run()
|
||||
return res, check_result
|
||||
|
||||
@trace()
|
||||
def create_schema(self, client, timeout=None, check_task=None,
|
||||
check_items=None, **kwargs):
|
||||
|
@ -103,6 +112,17 @@ class HighLevelApiWrapper:
|
|||
**kwargs).run()
|
||||
return res, check_result
|
||||
|
||||
@trace()
|
||||
def get_collection_stats(self, client, collection_name, timeout=None, check_task=None, check_items=None, **kwargs):
|
||||
timeout = TIMEOUT if timeout is None else timeout
|
||||
kwargs.update({"timeout": timeout})
|
||||
|
||||
func_name = sys._getframe().f_code.co_name
|
||||
res, check = api_request([client.get_collection_stats, collection_name], **kwargs)
|
||||
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
|
||||
collection_name=collection_name, **kwargs).run()
|
||||
return res, check_result
|
||||
|
||||
@trace()
|
||||
def search(self, client, collection_name, data, limit=10, filter=None, output_fields=None, search_params=None,
|
||||
timeout=None, check_task=None, check_items=None, **kwargs):
|
||||
|
@ -315,6 +335,16 @@ class HighLevelApiWrapper:
|
|||
**kwargs).run()
|
||||
return res, check_result
|
||||
|
||||
@trace()
|
||||
def create_database(self, client, db_name, properties: Optional[dict] = None, check_task=None, check_items=None, **kwargs):
|
||||
func_name = sys._getframe().f_code.co_name
|
||||
res, check = api_request([client.create_database, db_name, properties], **kwargs)
|
||||
check_result = ResponseChecker(res, func_name, check_task,
|
||||
check_items, check,
|
||||
db_name=db_name, properties=properties,
|
||||
**kwargs).run()
|
||||
return res, check_result
|
||||
|
||||
@trace()
|
||||
def create_partition(self, client, collection_name, partition_name, timeout=None, check_task=None, check_items=None, **kwargs):
|
||||
timeout = TIMEOUT if timeout is None else timeout
|
||||
|
|
|
@ -7,10 +7,21 @@ from common import common_func as cf
|
|||
from common.common_type import CheckTasks, Connect_Object_Name
|
||||
# from common.code_mapping import ErrorCode, ErrorMessage
|
||||
from pymilvus import Collection, Partition, ResourceGroupInfo
|
||||
from utils.api_request import Error
|
||||
import check.param_check as pc
|
||||
|
||||
|
||||
class Error:
|
||||
def __init__(self, error):
|
||||
self.code = getattr(error, 'code', -1)
|
||||
self.message = getattr(error, 'message', str(error))
|
||||
|
||||
def __str__(self):
|
||||
return f"Error(code={self.code}, message={self.message})"
|
||||
|
||||
def __repr__(self):
|
||||
return f"Error(code={self.code}, message={self.message})"
|
||||
|
||||
|
||||
class ResponseChecker:
|
||||
def __init__(self, response, func_name, check_task, check_items, is_succ=True, **kwargs):
|
||||
self.response = response # response of api request
|
||||
|
|
|
@ -25,7 +25,7 @@ def pytest_addoption(parser):
|
|||
parser.addoption("--user", action="store", default="", help="user name for connection")
|
||||
parser.addoption("--password", action="store", default="", help="password for connection")
|
||||
parser.addoption("--db_name", action="store", default="default", help="database name for connection")
|
||||
parser.addoption("--secure", type=bool, action="store", default=False, help="secure for connection")
|
||||
parser.addoption("--secure", action="store", default=False, help="secure for connection")
|
||||
parser.addoption("--milvus_ns", action="store", default="chaos-testing", help="milvus_ns")
|
||||
parser.addoption("--http_port", action="store", default=19121, help="http's port")
|
||||
parser.addoption("--handler", action="store", default="GRPC", help="handler of request")
|
||||
|
@ -45,7 +45,7 @@ def pytest_addoption(parser):
|
|||
parser.addoption('--term_expr', action='store', default="term_expr", help="expr of query quest")
|
||||
parser.addoption('--check_content', action='store', default="check_content", help="content of check")
|
||||
parser.addoption('--field_name', action='store', default="field_name", help="field_name of index")
|
||||
parser.addoption('--replica_num', type='int', action='store', default=ct.default_replica_num, help="memory replica number")
|
||||
parser.addoption('--replica_num', action='store', default=ct.default_replica_num, help="memory replica number")
|
||||
parser.addoption('--minio_host', action='store', default="localhost", help="minio service's ip")
|
||||
parser.addoption('--uri', action='store', default="", help="uri for high level api")
|
||||
parser.addoption('--token', action='store', default="", help="token for high level api")
|
||||
|
|
|
@ -10,3 +10,6 @@ log_date_format = %Y-%m-%d %H:%M:%S
|
|||
|
||||
filterwarnings =
|
||||
ignore::DeprecationWarning
|
||||
|
||||
asyncio_default_fixture_loop_scope = function
|
||||
|
||||
|
|
|
@ -4,7 +4,8 @@ requests==2.26.0
|
|||
scikit-learn==1.1.3
|
||||
timeout_decorator==0.5.0
|
||||
ujson==5.5.0
|
||||
pytest==7.2.0
|
||||
pytest==8.3.4
|
||||
pytest-asyncio==0.24.0
|
||||
pytest-assume==2.4.3
|
||||
pytest-timeout==1.3.3
|
||||
pytest-repeat==0.8.0
|
||||
|
|
|
@ -0,0 +1,509 @@
|
|||
import random
|
||||
import time
|
||||
import pytest
|
||||
import asyncio
|
||||
from pymilvus.client.types import LoadState, DataType
|
||||
from pymilvus import AnnSearchRequest, RRFRanker
|
||||
|
||||
from base.client_base import TestcaseBase
|
||||
from common import common_func as cf
|
||||
from common import common_type as ct
|
||||
from common.common_type import CaseLabel, CheckTasks
|
||||
from utils.util_log import test_log as log
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
prefix = "async"
|
||||
async_default_nb = 5000
|
||||
default_pk_name = "id"
|
||||
default_vector_name = "vector"
|
||||
|
||||
|
||||
class TestAsyncMilvusClient(TestcaseBase):
|
||||
|
||||
def teardown_method(self, method):
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(self.async_milvus_client_wrap.close())
|
||||
super().teardown_method(method)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
async def test_async_client_default(self):
|
||||
# init client
|
||||
milvus_client = self._connect(enable_milvus_client_api=True)
|
||||
self.init_async_milvus_client()
|
||||
|
||||
# create collection
|
||||
c_name = cf.gen_unique_str(prefix)
|
||||
await self.async_milvus_client_wrap.create_collection(c_name, dimension=ct.default_dim)
|
||||
collections, _ = self.high_level_api_wrap.list_collections(milvus_client)
|
||||
assert c_name in collections
|
||||
|
||||
# insert entities
|
||||
rows = [
|
||||
{default_pk_name: i, default_vector_name: [random.random() for _ in range(ct.default_dim)]}
|
||||
for i in range(async_default_nb)]
|
||||
start_time = time.time()
|
||||
tasks = []
|
||||
step = 1000
|
||||
for i in range(0, async_default_nb, step):
|
||||
task = self.async_milvus_client_wrap.insert(c_name, rows[i:i + step])
|
||||
tasks.append(task)
|
||||
insert_res = await asyncio.gather(*tasks)
|
||||
end_time = time.time()
|
||||
log.info("Total time: {:.2f} seconds".format(end_time - start_time))
|
||||
for r in insert_res:
|
||||
assert r[0]['insert_count'] == step
|
||||
|
||||
# dql tasks
|
||||
tasks = []
|
||||
# search default
|
||||
vector = cf.gen_vectors(ct.default_nq, ct.default_dim)
|
||||
default_search_task = self.async_milvus_client_wrap.search(c_name, vector, limit=ct.default_limit,
|
||||
check_task=CheckTasks.check_search_results,
|
||||
check_items={"enable_milvus_client_api": True,
|
||||
"nq": ct.default_nq,
|
||||
"limit": ct.default_limit})
|
||||
tasks.append(default_search_task)
|
||||
|
||||
# search with filter & search_params
|
||||
sp = {"metric_type": "COSINE", "params": {"ef": "96"}}
|
||||
filter_params_search_task = self.async_milvus_client_wrap.search(c_name, vector, limit=ct.default_limit,
|
||||
filter=f"{default_pk_name} > 10",
|
||||
search_params=sp,
|
||||
check_task=CheckTasks.check_search_results,
|
||||
check_items={"enable_milvus_client_api": True,
|
||||
"nq": ct.default_nq,
|
||||
"limit": ct.default_limit})
|
||||
tasks.append(filter_params_search_task)
|
||||
|
||||
# search output fields
|
||||
output_search_task = self.async_milvus_client_wrap.search(c_name, vector, limit=ct.default_limit,
|
||||
output_fields=["*"],
|
||||
check_task=CheckTasks.check_search_results,
|
||||
check_items={"enable_milvus_client_api": True,
|
||||
"nq": ct.default_nq,
|
||||
"limit": ct.default_limit})
|
||||
tasks.append(output_search_task)
|
||||
|
||||
# query with filter and default output "*"
|
||||
exp_query_res = [{default_pk_name: i} for i in range(ct.default_limit)]
|
||||
filter_query_task = self.async_milvus_client_wrap.query(c_name,
|
||||
filter=f"{default_pk_name} < {ct.default_limit}",
|
||||
output_fields=[default_pk_name],
|
||||
check_task=CheckTasks.check_query_results,
|
||||
check_items={"exp_res": exp_query_res,
|
||||
"primary_field": default_pk_name})
|
||||
tasks.append(filter_query_task)
|
||||
# query with ids and output all fields
|
||||
ids_query_task = self.async_milvus_client_wrap.query(c_name,
|
||||
ids=[i for i in range(ct.default_limit)],
|
||||
output_fields=["*"],
|
||||
check_task=CheckTasks.check_query_results,
|
||||
check_items={"exp_res": rows[:ct.default_limit],
|
||||
"with_vec": True,
|
||||
"primary_field": default_pk_name})
|
||||
tasks.append(ids_query_task)
|
||||
# get with ids
|
||||
get_task = self.async_milvus_client_wrap.get(c_name,
|
||||
ids=[0, 1],
|
||||
output_fields=[default_pk_name, default_vector_name],
|
||||
check_task=CheckTasks.check_query_results,
|
||||
check_items={"exp_res": rows[:2], "with_vec": True,
|
||||
"primary_field": default_pk_name})
|
||||
tasks.append(get_task)
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
async def test_async_client_partition(self):
|
||||
# init client
|
||||
milvus_client = self._connect(enable_milvus_client_api=True)
|
||||
self.init_async_milvus_client()
|
||||
|
||||
# create collection & partition
|
||||
c_name = cf.gen_unique_str(prefix)
|
||||
p_name = cf.gen_unique_str("par")
|
||||
await self.async_milvus_client_wrap.create_collection(c_name, dimension=ct.default_dim)
|
||||
collections, _ = self.high_level_api_wrap.list_collections(milvus_client)
|
||||
assert c_name in collections
|
||||
self.high_level_api_wrap.create_partition(milvus_client, c_name, p_name)
|
||||
partitions, _ = self.high_level_api_wrap.list_partitions(milvus_client, c_name)
|
||||
assert p_name in partitions
|
||||
|
||||
# insert entities
|
||||
rows = [
|
||||
{default_pk_name: i, default_vector_name: [random.random() for _ in range(ct.default_dim)]}
|
||||
for i in range(async_default_nb)]
|
||||
start_time = time.time()
|
||||
tasks = []
|
||||
step = 1000
|
||||
for i in range(0, async_default_nb, step):
|
||||
task = self.async_milvus_client_wrap.insert(c_name, rows[i:i + step], partition_name=p_name)
|
||||
tasks.append(task)
|
||||
insert_res = await asyncio.gather(*tasks)
|
||||
end_time = time.time()
|
||||
log.info("Total time: {:.2f} seconds".format(end_time - start_time))
|
||||
for r in insert_res:
|
||||
assert r[0]['insert_count'] == step
|
||||
|
||||
# count from default partition
|
||||
count_res, _ = await self.async_milvus_client_wrap.query(c_name, output_fields=["count(*)"], partition_names=[ct.default_partition_name])
|
||||
assert count_res[0]["count(*)"] == 0
|
||||
|
||||
# dql tasks
|
||||
tasks = []
|
||||
# search default
|
||||
vector = cf.gen_vectors(ct.default_nq, ct.default_dim)
|
||||
default_search_task = self.async_milvus_client_wrap.search(c_name, vector, limit=ct.default_limit,
|
||||
partition_names=[p_name],
|
||||
check_task=CheckTasks.check_search_results,
|
||||
check_items={"enable_milvus_client_api": True,
|
||||
"nq": ct.default_nq,
|
||||
"limit": ct.default_limit})
|
||||
tasks.append(default_search_task)
|
||||
|
||||
# search with filter & search_params
|
||||
sp = {"metric_type": "COSINE", "params": {"ef": "96"}}
|
||||
filter_params_search_task = self.async_milvus_client_wrap.search(c_name, vector, limit=ct.default_limit,
|
||||
filter=f"{default_pk_name} > 10",
|
||||
search_params=sp,
|
||||
partition_names=[p_name],
|
||||
check_task=CheckTasks.check_search_results,
|
||||
check_items={"enable_milvus_client_api": True,
|
||||
"nq": ct.default_nq,
|
||||
"limit": ct.default_limit})
|
||||
tasks.append(filter_params_search_task)
|
||||
|
||||
# search output fields
|
||||
output_search_task = self.async_milvus_client_wrap.search(c_name, vector, limit=ct.default_limit,
|
||||
output_fields=["*"],
|
||||
partition_names=[p_name],
|
||||
check_task=CheckTasks.check_search_results,
|
||||
check_items={"enable_milvus_client_api": True,
|
||||
"nq": ct.default_nq,
|
||||
"limit": ct.default_limit})
|
||||
tasks.append(output_search_task)
|
||||
|
||||
# query with filter and default output "*"
|
||||
exp_query_res = [{default_pk_name: i} for i in range(ct.default_limit)]
|
||||
filter_query_task = self.async_milvus_client_wrap.query(c_name,
|
||||
filter=f"{default_pk_name} < {ct.default_limit}",
|
||||
output_fields=[default_pk_name],
|
||||
partition_names=[p_name],
|
||||
check_task=CheckTasks.check_query_results,
|
||||
check_items={"exp_res": exp_query_res,
|
||||
"primary_field": default_pk_name})
|
||||
tasks.append(filter_query_task)
|
||||
# query with ids and output all fields
|
||||
ids_query_task = self.async_milvus_client_wrap.query(c_name,
|
||||
ids=[i for i in range(ct.default_limit)],
|
||||
output_fields=["*"],
|
||||
partition_names=[p_name],
|
||||
check_task=CheckTasks.check_query_results,
|
||||
check_items={"exp_res": rows[:ct.default_limit],
|
||||
"with_vec": True,
|
||||
"primary_field": default_pk_name})
|
||||
tasks.append(ids_query_task)
|
||||
# get with ids
|
||||
get_task = self.async_milvus_client_wrap.get(c_name,
|
||||
ids=[0, 1], partition_names=[p_name],
|
||||
output_fields=[default_pk_name, default_vector_name],
|
||||
check_task=CheckTasks.check_query_results,
|
||||
check_items={"exp_res": rows[:2], "with_vec": True,
|
||||
"primary_field": default_pk_name})
|
||||
tasks.append(get_task)
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
async def test_async_client_with_schema(self, schema):
|
||||
# init client
|
||||
milvus_client = self._connect(enable_milvus_client_api=True)
|
||||
self.init_async_milvus_client()
|
||||
|
||||
# create collection
|
||||
c_name = cf.gen_unique_str(prefix)
|
||||
schema = self.async_milvus_client_wrap.create_schema(auto_id=False,
|
||||
partition_key_field=ct.default_int64_field_name)
|
||||
schema.add_field(ct.default_string_field_name, DataType.VARCHAR, max_length=100, is_primary=True)
|
||||
schema.add_field(ct.default_int64_field_name, DataType.INT64, is_partition_key=True)
|
||||
schema.add_field(ct.default_float_vec_field_name, DataType.FLOAT_VECTOR, dim=ct.default_dim)
|
||||
schema.add_field(default_vector_name, DataType.FLOAT_VECTOR, dim=ct.default_dim)
|
||||
await self.async_milvus_client_wrap.create_collection(c_name, schema=schema)
|
||||
collections, _ = self.high_level_api_wrap.list_collections(milvus_client)
|
||||
assert c_name in collections
|
||||
|
||||
# insert entities
|
||||
rows = [
|
||||
{ct.default_string_field_name: str(i),
|
||||
ct.default_int64_field_name: i,
|
||||
ct.default_float_vec_field_name: [random.random() for _ in range(ct.default_dim)],
|
||||
default_vector_name: [random.random() for _ in range(ct.default_dim)],
|
||||
} for i in range(async_default_nb)]
|
||||
start_time = time.time()
|
||||
tasks = []
|
||||
step = 1000
|
||||
for i in range(0, async_default_nb, step):
|
||||
task = self.async_milvus_client_wrap.insert(c_name, rows[i:i + step])
|
||||
tasks.append(task)
|
||||
insert_res = await asyncio.gather(*tasks)
|
||||
end_time = time.time()
|
||||
log.info("Total time: {:.2f} seconds".format(end_time - start_time))
|
||||
for r in insert_res:
|
||||
assert r[0]['insert_count'] == step
|
||||
|
||||
# flush
|
||||
self.high_level_api_wrap.flush(milvus_client, c_name)
|
||||
stats, _ = self.high_level_api_wrap.get_collection_stats(milvus_client, c_name)
|
||||
assert stats["row_count"] == async_default_nb
|
||||
|
||||
# create index -> load
|
||||
index_params, _ = self.high_level_api_wrap.prepare_index_params(milvus_client,
|
||||
field_name=ct.default_float_vec_field_name,
|
||||
index_type="HNSW", metric_type="COSINE", M=30,
|
||||
efConstruction=200)
|
||||
index_params.add_index(field_name=default_vector_name, index_type="IVF_SQ8",
|
||||
metric_type="L2", nlist=32)
|
||||
await self.async_milvus_client_wrap.create_index(c_name, index_params)
|
||||
await self.async_milvus_client_wrap.load_collection(c_name)
|
||||
|
||||
_index, _ = self.high_level_api_wrap.describe_index(milvus_client, c_name, default_vector_name)
|
||||
assert _index["indexed_rows"] == async_default_nb
|
||||
assert _index["state"] == "Finished"
|
||||
_load, _ = self.high_level_api_wrap.get_load_state(milvus_client, c_name)
|
||||
assert _load["state"] == LoadState.Loaded
|
||||
|
||||
# dql tasks
|
||||
tasks = []
|
||||
# search default
|
||||
vector = cf.gen_vectors(ct.default_nq, ct.default_dim)
|
||||
default_search_task = self.async_milvus_client_wrap.search(c_name, vector, limit=ct.default_limit,
|
||||
anns_field=ct.default_float_vec_field_name,
|
||||
search_params={"metric_type": "COSINE",
|
||||
"params": {"ef": "96"}},
|
||||
check_task=CheckTasks.check_search_results,
|
||||
check_items={"enable_milvus_client_api": True,
|
||||
"nq": ct.default_nq,
|
||||
"limit": ct.default_limit})
|
||||
tasks.append(default_search_task)
|
||||
|
||||
# hybrid_search
|
||||
search_param = {
|
||||
"data": cf.gen_vectors(ct.default_nq, ct.default_dim, vector_data_type="FLOAT_VECTOR"),
|
||||
"anns_field": ct.default_float_vec_field_name,
|
||||
"param": {"metric_type": "COSINE", "params": {"ef": "96"}},
|
||||
"limit": ct.default_limit,
|
||||
"expr": f"{ct.default_int64_field_name} > 10"}
|
||||
req = AnnSearchRequest(**search_param)
|
||||
|
||||
search_param2 = {
|
||||
"data": cf.gen_vectors(ct.default_nq, ct.default_dim, vector_data_type="FLOAT_VECTOR"),
|
||||
"anns_field": default_vector_name,
|
||||
"param": {"metric_type": "L2", "params": {"nprobe": "32"}},
|
||||
"limit": ct.default_limit
|
||||
}
|
||||
req2 = AnnSearchRequest(**search_param2)
|
||||
_output_fields = [ct.default_int64_field_name, ct.default_string_field_name]
|
||||
filter_params_search_task = self.async_milvus_client_wrap.hybrid_search(c_name, [req, req2], RRFRanker(),
|
||||
limit=5,
|
||||
check_task=CheckTasks.check_search_results,
|
||||
check_items={
|
||||
"enable_milvus_client_api": True,
|
||||
"nq": ct.default_nq,
|
||||
"limit": 5})
|
||||
tasks.append(filter_params_search_task)
|
||||
|
||||
# get with ids
|
||||
get_task = self.async_milvus_client_wrap.get(c_name, ids=['0', '1'], output_fields=[ct.default_int64_field_name,
|
||||
ct.default_string_field_name])
|
||||
tasks.append(get_task)
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
async def test_async_client_dml(self):
|
||||
# init client
|
||||
milvus_client = self._connect(enable_milvus_client_api=True)
|
||||
self.init_async_milvus_client()
|
||||
|
||||
# create collection
|
||||
c_name = cf.gen_unique_str(prefix)
|
||||
await self.async_milvus_client_wrap.create_collection(c_name, dimension=ct.default_dim)
|
||||
collections, _ = self.high_level_api_wrap.list_collections(milvus_client)
|
||||
assert c_name in collections
|
||||
|
||||
# insert entities
|
||||
rows = [
|
||||
{default_pk_name: i, default_vector_name: [random.random() for _ in range(ct.default_dim)]}
|
||||
for i in range(ct.default_nb)]
|
||||
start_time = time.time()
|
||||
tasks = []
|
||||
step = 1000
|
||||
for i in range(0, ct.default_nb, step):
|
||||
task = self.async_milvus_client_wrap.insert(c_name, rows[i:i + step])
|
||||
tasks.append(task)
|
||||
insert_res = await asyncio.gather(*tasks)
|
||||
end_time = time.time()
|
||||
log.info("Total time: {:.2f} seconds".format(end_time - start_time))
|
||||
for r in insert_res:
|
||||
assert r[0]['insert_count'] == step
|
||||
|
||||
# dml tasks
|
||||
# query id -> upsert id -> query id -> delete id -> query id
|
||||
_id = 10
|
||||
get_res, _ = await self.async_milvus_client_wrap.get(c_name, ids=[_id],
|
||||
output_fields=[default_pk_name, default_vector_name])
|
||||
assert len(get_res) == 1
|
||||
|
||||
# upsert
|
||||
upsert_row = [{
|
||||
default_pk_name: _id, default_vector_name: [random.random() for _ in range(ct.default_dim)]
|
||||
}]
|
||||
upsert_res, _ = await self.async_milvus_client_wrap.upsert(c_name, upsert_row)
|
||||
assert upsert_res["upsert_count"] == 1
|
||||
|
||||
# get _id after upsert
|
||||
get_res, _ = await self.async_milvus_client_wrap.get(c_name, ids=[_id],
|
||||
output_fields=[default_pk_name, default_vector_name])
|
||||
for j in range(5):
|
||||
assert abs(get_res[0][default_vector_name][j] - upsert_row[0][default_vector_name][j]) < ct.epsilon
|
||||
|
||||
# delete
|
||||
del_res, _ = await self.async_milvus_client_wrap.delete(c_name, ids=[_id])
|
||||
assert del_res["delete_count"] == 1
|
||||
|
||||
# query after delete
|
||||
get_res, _ = await self.async_milvus_client_wrap.get(c_name, ids=[_id],
|
||||
output_fields=[default_pk_name, default_vector_name])
|
||||
assert len(get_res) == 0
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
async def test_async_client_with_db(self):
|
||||
# init client
|
||||
milvus_client = self._connect(enable_milvus_client_api=True)
|
||||
db_name = cf.gen_unique_str("db")
|
||||
self.high_level_api_wrap.create_database(milvus_client, db_name)
|
||||
self.high_level_api_wrap.close(milvus_client)
|
||||
uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}"
|
||||
milvus_client, _ = self.connection_wrap.MilvusClient(uri=uri, db_name=db_name)
|
||||
self.async_milvus_client_wrap.init_async_client(uri, db_name=db_name)
|
||||
|
||||
# create collection
|
||||
c_name = cf.gen_unique_str(prefix)
|
||||
await self.async_milvus_client_wrap.create_collection(c_name, dimension=ct.default_dim)
|
||||
collections, _ = self.high_level_api_wrap.list_collections(milvus_client)
|
||||
assert c_name in collections
|
||||
|
||||
# insert entities
|
||||
rows = [
|
||||
{default_pk_name: i, default_vector_name: [random.random() for _ in range(ct.default_dim)]}
|
||||
for i in range(async_default_nb)]
|
||||
start_time = time.time()
|
||||
tasks = []
|
||||
step = 1000
|
||||
for i in range(0, async_default_nb, step):
|
||||
task = self.async_milvus_client_wrap.insert(c_name, rows[i:i + step])
|
||||
tasks.append(task)
|
||||
insert_res = await asyncio.gather(*tasks)
|
||||
end_time = time.time()
|
||||
log.info("Total time: {:.2f} seconds".format(end_time - start_time))
|
||||
for r in insert_res:
|
||||
assert r[0]['insert_count'] == step
|
||||
|
||||
# dql tasks
|
||||
tasks = []
|
||||
# search default
|
||||
vector = cf.gen_vectors(ct.default_nq, ct.default_dim)
|
||||
default_search_task = self.async_milvus_client_wrap.search(c_name, vector, limit=ct.default_limit,
|
||||
check_task=CheckTasks.check_search_results,
|
||||
check_items={"enable_milvus_client_api": True,
|
||||
"nq": ct.default_nq,
|
||||
"limit": ct.default_limit})
|
||||
tasks.append(default_search_task)
|
||||
|
||||
# query with filter and default output "*"
|
||||
exp_query_res = [{default_pk_name: i} for i in range(ct.default_limit)]
|
||||
filter_query_task = self.async_milvus_client_wrap.query(c_name,
|
||||
filter=f"{default_pk_name} < {ct.default_limit}",
|
||||
output_fields=[default_pk_name],
|
||||
check_task=CheckTasks.check_query_results,
|
||||
check_items={"exp_res": exp_query_res,
|
||||
"primary_field": default_pk_name})
|
||||
tasks.append(filter_query_task)
|
||||
|
||||
# get with ids
|
||||
get_task = self.async_milvus_client_wrap.get(c_name,
|
||||
ids=[0, 1],
|
||||
output_fields=[default_pk_name, default_vector_name],
|
||||
check_task=CheckTasks.check_query_results,
|
||||
check_items={"exp_res": rows[:2], "with_vec": True,
|
||||
"primary_field": default_pk_name})
|
||||
tasks.append(get_task)
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
async def test_async_client_close(self):
|
||||
# init async client
|
||||
uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}"
|
||||
self.async_milvus_client_wrap.init_async_client(uri)
|
||||
|
||||
# create collection
|
||||
c_name = cf.gen_unique_str(prefix)
|
||||
await self.async_milvus_client_wrap.create_collection(c_name, dimension=ct.default_dim)
|
||||
|
||||
# close -> search raise error
|
||||
await self.async_milvus_client_wrap.close()
|
||||
vector = cf.gen_vectors(1, ct.default_dim)
|
||||
error = {ct.err_code: 1, ct.err_msg: "should create connection first"}
|
||||
await self.async_milvus_client_wrap.search(c_name, vector, check_task=CheckTasks.err_res, check_items=error)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L3)
|
||||
@pytest.mark.skip("connect with zilliz cloud")
|
||||
async def test_async_client_with_token(self):
|
||||
# init client
|
||||
milvus_client = self._connect(enable_milvus_client_api=True)
|
||||
uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}"
|
||||
token = cf.param_info.param_token
|
||||
milvus_client, _ = self.connection_wrap.MilvusClient(uri=uri, token=token)
|
||||
self.async_milvus_client_wrap.init_async_client(uri, token=token)
|
||||
|
||||
# create collection
|
||||
c_name = cf.gen_unique_str(prefix)
|
||||
await self.async_milvus_client_wrap.create_collection(c_name, dimension=ct.default_dim)
|
||||
collections, _ = self.high_level_api_wrap.list_collections(milvus_client)
|
||||
assert c_name in collections
|
||||
|
||||
# insert entities
|
||||
rows = [
|
||||
{default_pk_name: i, default_vector_name: [random.random() for _ in range(ct.default_dim)]}
|
||||
for i in range(ct.default_nb)]
|
||||
start_time = time.time()
|
||||
tasks = []
|
||||
step = 1000
|
||||
for i in range(0, ct.default_nb, step):
|
||||
task = self.async_milvus_client_wrap.insert(c_name, rows[i:i + step])
|
||||
tasks.append(task)
|
||||
insert_res = await asyncio.gather(*tasks)
|
||||
end_time = time.time()
|
||||
log.info("Total time: {:.2f} seconds".format(end_time - start_time))
|
||||
for r in insert_res:
|
||||
assert r[0]['insert_count'] == step
|
||||
|
||||
# dql tasks
|
||||
tasks = []
|
||||
# search default
|
||||
vector = cf.gen_vectors(ct.default_nq, ct.default_dim)
|
||||
default_search_task = self.async_milvus_client_wrap.search(c_name, vector, limit=ct.default_limit,
|
||||
check_task=CheckTasks.check_search_results,
|
||||
check_items={"enable_milvus_client_api": True,
|
||||
"nq": ct.default_nq,
|
||||
"limit": ct.default_limit})
|
||||
tasks.append(default_search_task)
|
||||
|
||||
# query with filter and default output "*"
|
||||
exp_query_res = [{default_pk_name: i} for i in range(ct.default_limit)]
|
||||
filter_query_task = self.async_milvus_client_wrap.query(c_name,
|
||||
filter=f"{default_pk_name} < {ct.default_limit}",
|
||||
output_fields=[default_pk_name],
|
||||
check_task=CheckTasks.check_query_results,
|
||||
check_items={"exp_res": exp_query_res,
|
||||
"primary_field": default_pk_name})
|
||||
tasks.append(filter_query_task)
|
||||
await asyncio.gather(*tasks)
|
|
@ -1,24 +1,14 @@
|
|||
import sys
|
||||
import traceback
|
||||
import copy
|
||||
import os
|
||||
|
||||
from check.func_check import ResponseChecker, Error
|
||||
from utils.util_log import test_log as log
|
||||
|
||||
# enable_traceback = os.getenv('ENABLE_TRACEBACK', "True")
|
||||
# log.info(f"enable_traceback:{enable_traceback}")
|
||||
|
||||
|
||||
class Error:
|
||||
def __init__(self, error):
|
||||
self.code = getattr(error, 'code', -1)
|
||||
self.message = getattr(error, 'message', str(error))
|
||||
|
||||
def __str__(self):
|
||||
return f"Error(code={self.code}, message={self.message})"
|
||||
|
||||
def __repr__(self):
|
||||
return f"Error(code={self.code}, message={self.message})"
|
||||
|
||||
|
||||
log_row_length = 300
|
||||
|
||||
|
||||
|
@ -62,3 +52,50 @@ def api_request(_list, **kwargs):
|
|||
log.debug("(api_request) : [%s] args: %s, kwargs: %s" % (func.__qualname__, log_arg, str(kwargs)))
|
||||
return func(*arg, **kwargs)
|
||||
return False, False
|
||||
|
||||
|
||||
def logger_interceptor():
|
||||
def wrapper(func):
|
||||
def log_request(*arg, **kwargs):
|
||||
arg = arg[1:]
|
||||
arg_str = str(arg)
|
||||
log_arg = arg_str[0:log_row_length] + '......' if len(arg_str) > log_row_length else arg_str
|
||||
if kwargs.get("enable_traceback", True):
|
||||
log.debug("(api_request) : [%s] args: %s, kwargs: %s" % (func.__name__, log_arg, str(kwargs)))
|
||||
|
||||
def log_response(res, **kwargs):
|
||||
if kwargs.get("enable_traceback", True):
|
||||
res_str = str(res)
|
||||
log_res = res_str[0:log_row_length] + '......' if len(res_str) > log_row_length else res_str
|
||||
log.debug("(api_response) : [%s] %s " % (func.__name__, log_res))
|
||||
return res, True
|
||||
|
||||
async def handler(*args, **kwargs):
|
||||
_kwargs = copy.deepcopy(kwargs)
|
||||
_kwargs.pop("enable_traceback", None)
|
||||
check_task = kwargs.get("check_task", None)
|
||||
check_items = kwargs.get("check_items", None)
|
||||
try:
|
||||
# log request
|
||||
log_request(*args, **_kwargs)
|
||||
# exec func
|
||||
res = await func(*args, **_kwargs)
|
||||
# log response
|
||||
log_response(res, **_kwargs)
|
||||
# check_response
|
||||
check_res = ResponseChecker(res, sys._getframe().f_code.co_name, check_task, check_items, True).run()
|
||||
return res, check_res
|
||||
except Exception as e:
|
||||
log.error(str(e))
|
||||
e_str = str(e)
|
||||
log_e = e_str[0:log_row_length] + '......' if len(e_str) > log_row_length else e_str
|
||||
if kwargs.get("enable_traceback", True):
|
||||
log.error(traceback.format_exc())
|
||||
log.error("(api_response) : %s" % log_e)
|
||||
check_res = ResponseChecker(Error(e), sys._getframe().f_code.co_name, check_task,
|
||||
check_items, False).run()
|
||||
return Error(e), check_res
|
||||
|
||||
return handler
|
||||
|
||||
return wrapper
|
||||
|
|
Loading…
Reference in New Issue