[test] Add test cases for mixed use of the RESTful API and the SDK (#26146)

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
pull/26133/head
zhuwenxing 2023-08-08 10:25:10 +08:00 committed by GitHub
parent d1d0169fa3
commit 3eb870c8d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 335 additions and 5 deletions

View File

@ -92,7 +92,7 @@ class Requests:
class VectorClient(Requests):
def __init__(self, url, api_key, protocol="http"):
def __init__(self, url, api_key, protocol):
super().__init__(url, api_key)
self.protocol = protocol
self.url = url
@ -189,7 +189,7 @@ class VectorClient(Requests):
class CollectionClient(Requests):
def __init__(self, url, api_key, protocol="http"):
def __init__(self, url, api_key, protocol):
super().__init__(url, api_key)
self.protocol = protocol
self.url = url

View File

@ -15,6 +15,7 @@ def get_config():
class Base:
name = None
protocol = None
host = None
port = None
url = None
@ -42,7 +43,8 @@ class TestBase(Base):
logger.error(e)
@pytest.fixture(scope="function", autouse=True)
def init_client(self, host, port, username, password):
def init_client(self, protocol, host, port, username, password):
self.protocol = protocol
self.host = host
self.port = port
self.url = f"{host}:{port}/v1"
@ -52,6 +54,7 @@ class TestBase(Base):
self.invalid_api_key = "invalid_token"
self.vector_client = VectorClient(self.url, self.api_key)
self.collection_client = CollectionClient(self.url, self.api_key)
connections.connect(host=self.host, port=self.port)
def init_collection(self, collection_name, pk_field="id", metric_type="L2", dim=128, nb=100, batch_size=1000):
# create collection

View File

@ -3,12 +3,18 @@ import yaml
def pytest_addoption(parser):
    """Register the command-line options used by the RESTful client tests.

    Args:
        parser: the pytest option parser handed to this hook; only its
            ``addoption`` method is used.
    """
    # (flag, default, help) -- registration order is kept as in the original.
    options = (
        # The help text for --protocol used to be a copy-paste of "host".
        ("--protocol", "http", "protocol"),
        ("--host", "127.0.0.1", "host"),
        ("--port", "19530", "port"),
        ("--username", "root", "email"),
        ("--password", "Milvus", "password"),
    )
    for flag, default, help_text in options:
        parser.addoption(flag, action="store", default=default, help=help_text)
@pytest.fixture
def protocol(request):
    """Provide the value of the ``--protocol`` command-line option."""
    option_value = request.config.getoption("--protocol")
    return option_value
@pytest.fixture
def host(request):
    """Provide the value of the ``--host`` command-line option."""
    option_value = request.config.getoption("--host")
    return option_value

View File

@ -274,7 +274,7 @@ class TestDescribeCollection(TestBase):
all_collections = rsp['data']
assert name in all_collections
# describe collection
illegal_client = CollectionClient(self.url, "illegal_api_key")
illegal_client = CollectionClient(self.url, "illegal_api_key", self.protocol)
rsp = illegal_client.collection_describe(name)
assert rsp['code'] == 1800
@ -361,7 +361,7 @@ class TestDropCollection(TestBase):
payload = {
"collectionName": name,
}
illegal_client = CollectionClient(self.url, "invalid_api_key")
illegal_client = CollectionClient(self.url, "invalid_api_key", self.protocol)
rsp = illegal_client.collection_drop(payload)
assert rsp['code'] == 1800
rsp = client.collection_list()

View File

@ -0,0 +1,321 @@
import random
import time
from utils.utils import gen_collection_name
from utils.util_log import test_log as logger
import pytest
from base.testbase import TestBase
from pymilvus import (
FieldSchema, CollectionSchema, DataType,
Collection
)
class TestRestfulSdkCompatibility(TestBase):
    """Check that the RESTful API and the pymilvus SDK can be mixed.

    Every test performs one half of a workflow (create / index / load /
    insert) through one interface and the other half (describe / search /
    query / delete) through the other one.  ``self.collection_client`` and
    ``self.vector_client`` are the RESTful clients; ``Collection`` is the
    SDK entry point.
    """

    @staticmethod
    def _sdk_create_collection(name, dim=128, enable_dynamic=True, **collection_kwargs):
        """Create a four-field collection through the SDK and return it.

        Args:
            name: collection name.
            dim: dimension of the ``float_vector`` field.
            enable_dynamic: value passed as ``enable_dynamic_field``.
            **collection_kwargs: forwarded to ``Collection`` (e.g. ``shards_num``).
        """
        fields = [
            FieldSchema(name="int64", dtype=DataType.INT64, is_primary=True),
            FieldSchema(name="float", dtype=DataType.FLOAT),
            FieldSchema(name="varchar", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="float_vector", dtype=DataType.FLOAT_VECTOR, dim=dim),
        ]
        schema = CollectionSchema(fields=fields, description="test collection",
                                  enable_dynamic_field=enable_dynamic)
        return Collection(name=name, schema=schema, **collection_kwargs)

    @staticmethod
    def _sdk_create_index(collection, metric_type="L2"):
        """Create an IVF_FLAT index (nlist=128) on ``float_vector`` through the SDK."""
        index_param = {"metric_type": metric_type, "index_type": "IVF_FLAT", "params": {"nlist": 128}}
        collection.create_index(field_name="float_vector", index_params=index_param)

    @staticmethod
    def _gen_rows(nb, dim):
        """Build ``nb`` row dicts matching the schema of ``_sdk_create_collection``.

        Each row also carries an ``age`` key that is not declared in the
        schema -- presumably to exercise the dynamic field; confirm.
        """
        return [
            {"int64": i, "float": i, "varchar": str(i),
             "float_vector": [random.random() for _ in range(dim)], "age": i}
            for i in range(nb)
        ]

    @pytest.mark.parametrize("dim", [128, 256])
    @pytest.mark.parametrize("enable_dynamic", [True, False])
    @pytest.mark.parametrize("shard_num", [1, 2])
    def test_collection_created_by_sdk_describe_by_restful(self, dim, enable_dynamic, shard_num):
        """
        Create a collection with the SDK, then list and describe it over
        RESTful and compare name, dynamic-field flag, load state and shard
        number with what was requested.
        """
        # 1. create collection by sdk
        name = gen_collection_name()
        collection = self._sdk_create_collection(name, dim=dim, enable_dynamic=enable_dynamic,
                                                 shards_num=shard_num)
        logger.info(collection.schema)
        # 2. use restful to get collection info
        client = self.collection_client
        rsp = client.collection_list()
        assert name in rsp['data']
        rsp = client.collection_describe(name)
        assert rsp['code'] == 200
        described = rsp['data']
        assert described['collectionName'] == name
        assert described['enableDynamic'] == enable_dynamic
        # The collection was never loaded, so the reported state is "not loaded".
        assert described['load'] == "LoadStateNotLoad"
        assert described['shardsNum'] == shard_num

    @pytest.mark.parametrize("vector_field", ["vector", "emb"])
    @pytest.mark.parametrize("primary_field", ["id", "doc_id"])
    @pytest.mark.parametrize("metric_type", ["L2", "IP"])
    @pytest.mark.parametrize("dim", [128])
    def test_collection_created_by_restful_describe_by_sdk(self, dim, metric_type, primary_field, vector_field):
        """
        Create a collection over RESTful, then open it with the SDK and
        check that it has two fields, has the dynamic field enabled and has
        at least one index.
        """
        name = gen_collection_name()
        client = self.collection_client
        payload = {
            "collectionName": name,
            "dimension": dim,
            "metricType": metric_type,
            "primaryField": primary_field,
            "vectorField": vector_field,
        }
        # Optional keys are left out of the request when no value is given.
        for optional_key in ("primaryField", "vectorField"):
            if payload[optional_key] is None:
                del payload[optional_key]
        rsp = client.collection_create(payload)
        # Fail here with the REST response instead of later inside the SDK.
        assert rsp['code'] == 200
        collection = Collection(name=name)
        logger.info(collection.schema)
        field_names = [field.name for field in collection.schema.fields]
        assert len(field_names) == 2
        assert collection.schema.enable_dynamic_field is True
        assert len(collection.indexes) > 0

    @pytest.mark.parametrize("metric_type", ["L2", "IP"])
    def test_collection_created_index_by_sdk_describe_by_restful(self, metric_type):
        """
        Create a collection and an index with the SDK, then describe the
        collection over RESTful and check that exactly one index with the
        requested metric type is reported.
        """
        # 1. create collection and index by sdk
        name = gen_collection_name()
        collection = self._sdk_create_collection(name)
        self._sdk_create_index(collection, metric_type=metric_type)
        # 2. use restful to get collection info
        client = self.collection_client
        rsp = client.collection_list()
        assert name in rsp['data']
        rsp = client.collection_describe(name)
        assert rsp['code'] == 200
        assert rsp['data']['collectionName'] == name
        indexes = rsp['data']['indexes']
        assert len(indexes) == 1 and indexes[0]['metricType'] == metric_type

    @pytest.mark.parametrize("metric_type", ["L2", "IP"])
    def test_collection_load_by_sdk_describe_by_restful(self, metric_type):
        """
        Create, index and load a collection with the SDK, then describe it
        over RESTful and check that it is reported as loaded.
        """
        # 1. create, index and load collection by sdk
        name = gen_collection_name()
        collection = self._sdk_create_collection(name)
        self._sdk_create_index(collection, metric_type=metric_type)
        collection.load()
        # 2. use restful to get collection info
        client = self.collection_client
        rsp = client.collection_list()
        assert name in rsp['data']
        rsp = client.collection_describe(name)
        assert rsp['data']['load'] == "LoadStateLoaded"

    def test_collection_create_by_sdk_insert_vector_by_restful(self):
        """
        Create, index and load a collection with the SDK, then insert rows
        over RESTful and check the reported insert count.
        """
        dim = 128
        nb = 100
        # 1. create, index and load collection by sdk
        name = gen_collection_name()
        collection = self._sdk_create_collection(name, dim=dim)
        self._sdk_create_index(collection)
        collection.load()
        # 2. insert data by restful
        payload = {
            "collectionName": name,
            "data": self._gen_rows(nb, dim),
        }
        rsp = self.vector_client.vector_insert(payload)
        assert rsp['code'] == 200
        assert rsp['data']['insertCount'] == nb

    def test_collection_create_by_sdk_search_vector_by_restful(self):
        """
        Create, index, load and fill a collection with the SDK, then run a
        vector search over RESTful and check that ``limit`` hits come back.
        """
        dim = 128
        nb = 100
        name = gen_collection_name()
        # init collection by sdk
        collection = self._sdk_create_collection(name, dim=dim)
        self._sdk_create_index(collection)
        collection.load()
        collection.insert(self._gen_rows(nb, dim))
        payload = {
            "collectionName": name,
            "vector": [random.random() for _ in range(dim)],
            "limit": 10
        }
        # search data by restful
        rsp = self.vector_client.vector_search(payload)
        assert rsp['code'] == 200
        assert len(rsp['data']) == 10

    def test_collection_create_by_sdk_query_vector_by_restful(self):
        """
        Create, index, load and fill a collection with the SDK, then query
        the first ten primary keys over RESTful and check ten rows come back.
        """
        dim = 128
        nb = 100
        name = gen_collection_name()
        # init collection by sdk
        collection = self._sdk_create_collection(name, dim=dim)
        self._sdk_create_index(collection)
        collection.load()
        collection.insert(self._gen_rows(nb, dim))
        payload = {
            "collectionName": name,
            "filter": "int64 in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]",
        }
        # query data by restful
        rsp = self.vector_client.vector_query(payload)
        assert rsp['code'] == 200
        assert len(rsp['data']) == 10

    def test_collection_create_by_restful_search_vector_by_sdk(self):
        """
        Prepare a collection through ``init_collection`` (RESTful side), then
        search it with the SDK and check the shape of the result.
        """
        name = gen_collection_name()
        dim = 128
        # init_collection comes from TestBase; per the original comment it
        # inserts the data over RESTful.
        self.init_collection(name, metric_type="L2", dim=dim)
        # presumably gives the inserted data time to become searchable -- confirm
        time.sleep(5)
        # search data by sdk
        collection = Collection(name=name)
        nq = 5
        vectors_to_search = [[random.random() for _ in range(dim)] for _ in range(nq)]
        res = collection.search(data=vectors_to_search, anns_field="vector", param={}, limit=10)
        assert len(res) == nq
        assert len(res[0]) == 10

    def test_collection_create_by_restful_query_vector_by_sdk(self):
        """
        Prepare a collection through ``init_collection`` (RESTful side), then
        query it with the SDK and check every returned ``uid`` matches the
        filter.
        """
        name = gen_collection_name()
        dim = 128
        # init_collection comes from TestBase; per the original comment it
        # inserts the data over RESTful.
        self.init_collection(name, metric_type="L2", dim=dim)
        # presumably gives the inserted data time to become queryable -- confirm
        time.sleep(5)
        # query data by sdk
        collection = Collection(name=name)
        res = collection.query(expr="uid in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]", output_fields=["*"])
        # NOTE(review): this loop passes vacuously when the query returns no
        # rows; consider asserting a non-empty result once the data written
        # by init_collection is confirmed to contain these uids.
        for item in res:
            assert item["uid"] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

    def test_collection_create_by_restful_delete_vector_by_sdk(self):
        """
        Prepare a collection through ``init_collection`` (RESTful side), then
        delete the rows matching a ``uid`` filter with the SDK and check the
        same filter returns nothing afterwards.
        """
        name = gen_collection_name()
        dim = 128
        # init_collection comes from TestBase; per the original comment it
        # inserts the data over RESTful.
        self.init_collection(name, metric_type="L2", dim=dim)
        # presumably gives the inserted data time to become queryable -- confirm
        time.sleep(5)
        uid_filter = "uid in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]"
        # look up the primary keys of the rows to delete by sdk
        collection = Collection(name=name)
        res = collection.query(expr=uid_filter, output_fields=["*"])
        # NOTE(review): if this query returns no rows nothing is deleted and
        # the final assertion passes vacuously -- verify against init_collection.
        pk_id_list = [item["id"] for item in res]
        collection.delete(f"id in {pk_id_list}")
        # presumably waits for the delete to become visible -- confirm
        time.sleep(5)
        res = collection.query(expr=uid_filter, output_fields=["*"])
        assert len(res) == 0

    def test_collection_create_by_sdk_delete_vector_by_restful(self):
        """
        Create, index, load and fill a collection with the SDK, delete the
        first ten primary keys over RESTful, then check with the SDK that
        those rows are gone.
        """
        dim = 128
        nb = 100
        name = gen_collection_name()
        # init collection by sdk
        collection = self._sdk_create_collection(name, dim=dim)
        self._sdk_create_index(collection)
        collection.load()
        collection.insert(self._gen_rows(nb, dim))
        # presumably gives the inserted data time to become queryable -- confirm
        time.sleep(5)
        pk_filter = "int64 in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]"
        res = collection.query(expr=pk_filter, output_fields=["*"])
        pk_id_list = [item["int64"] for item in res]
        payload = {
            "collectionName": name,
            "id": pk_id_list
        }
        # delete data by restful
        rsp = self.vector_client.vector_delete(payload)
        # A rejected delete must fail the test instead of being ignored.
        assert rsp['code'] == 200
        # presumably waits for the delete to become visible -- confirm
        time.sleep(5)
        res = collection.query(expr=pk_filter, output_fields=["*"])
        assert len(res) == 0