mirror of https://github.com/milvus-io/milvus.git
[test]Add test cases for restful and sdk mix use (#26146)
Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>pull/26133/head
parent
d1d0169fa3
commit
3eb870c8d6
|
@ -92,7 +92,7 @@ class Requests:
|
|||
|
||||
|
||||
class VectorClient(Requests):
|
||||
def __init__(self, url, api_key, protocol="http"):
|
||||
def __init__(self, url, api_key, protocol):
|
||||
super().__init__(url, api_key)
|
||||
self.protocol = protocol
|
||||
self.url = url
|
||||
|
@ -189,7 +189,7 @@ class VectorClient(Requests):
|
|||
|
||||
class CollectionClient(Requests):
|
||||
|
||||
def __init__(self, url, api_key, protocol="http"):
|
||||
def __init__(self, url, api_key, protocol):
|
||||
super().__init__(url, api_key)
|
||||
self.protocol = protocol
|
||||
self.url = url
|
||||
|
|
|
@ -15,6 +15,7 @@ def get_config():
|
|||
|
||||
class Base:
|
||||
name = None
|
||||
protocol = None
|
||||
host = None
|
||||
port = None
|
||||
url = None
|
||||
|
@ -42,7 +43,8 @@ class TestBase(Base):
|
|||
logger.error(e)
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def init_client(self, host, port, username, password):
|
||||
def init_client(self, protocol, host, port, username, password):
|
||||
self.protocol = protocol
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.url = f"{host}:{port}/v1"
|
||||
|
@ -52,6 +54,7 @@ class TestBase(Base):
|
|||
self.invalid_api_key = "invalid_token"
|
||||
self.vector_client = VectorClient(self.url, self.api_key)
|
||||
self.collection_client = CollectionClient(self.url, self.api_key)
|
||||
connections.connect(host=self.host, port=self.port)
|
||||
|
||||
def init_collection(self, collection_name, pk_field="id", metric_type="L2", dim=128, nb=100, batch_size=1000):
|
||||
# create collection
|
||||
|
|
|
@ -3,12 +3,18 @@ import yaml
|
|||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption("--protocol", action="store", default="http", help="host")
|
||||
parser.addoption("--host", action="store", default="127.0.0.1", help="host")
|
||||
parser.addoption("--port", action="store", default="19530", help="port")
|
||||
parser.addoption("--username", action="store", default="root", help="email")
|
||||
parser.addoption("--password", action="store", default="Milvus", help="password")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def protocol(request):
|
||||
return request.config.getoption("--protocol")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def host(request):
|
||||
return request.config.getoption("--host")
|
||||
|
|
|
@ -274,7 +274,7 @@ class TestDescribeCollection(TestBase):
|
|||
all_collections = rsp['data']
|
||||
assert name in all_collections
|
||||
# describe collection
|
||||
illegal_client = CollectionClient(self.url, "illegal_api_key")
|
||||
illegal_client = CollectionClient(self.url, "illegal_api_key", self.protocol)
|
||||
rsp = illegal_client.collection_describe(name)
|
||||
assert rsp['code'] == 1800
|
||||
|
||||
|
@ -361,7 +361,7 @@ class TestDropCollection(TestBase):
|
|||
payload = {
|
||||
"collectionName": name,
|
||||
}
|
||||
illegal_client = CollectionClient(self.url, "invalid_api_key")
|
||||
illegal_client = CollectionClient(self.url, "invalid_api_key", self.protocol)
|
||||
rsp = illegal_client.collection_drop(payload)
|
||||
assert rsp['code'] == 1800
|
||||
rsp = client.collection_list()
|
||||
|
|
|
@ -0,0 +1,321 @@
|
|||
import random
|
||||
import time
|
||||
from utils.utils import gen_collection_name
|
||||
from utils.util_log import test_log as logger
|
||||
import pytest
|
||||
from base.testbase import TestBase
|
||||
from pymilvus import (
|
||||
FieldSchema, CollectionSchema, DataType,
|
||||
Collection
|
||||
)
|
||||
|
||||
|
||||
class TestRestfulSdkCompatibility(TestBase):
|
||||
|
||||
@pytest.mark.parametrize("dim", [128, 256])
|
||||
@pytest.mark.parametrize("enable_dynamic", [True, False])
|
||||
@pytest.mark.parametrize("shard_num", [1, 2])
|
||||
def test_collection_created_by_sdk_describe_by_restful(self, dim, enable_dynamic, shard_num):
|
||||
"""
|
||||
"""
|
||||
# 1. create collection by sdk
|
||||
name = gen_collection_name()
|
||||
default_fields = [
|
||||
FieldSchema(name="int64", dtype=DataType.INT64, is_primary=True),
|
||||
FieldSchema(name="float", dtype=DataType.FLOAT),
|
||||
FieldSchema(name="varchar", dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name="float_vector", dtype=DataType.FLOAT_VECTOR, dim=dim)
|
||||
]
|
||||
default_schema = CollectionSchema(fields=default_fields, description="test collection",
|
||||
enable_dynamic_field=enable_dynamic)
|
||||
collection = Collection(name=name, schema=default_schema, shards_num=shard_num)
|
||||
logger.info(collection.schema)
|
||||
# 2. use restful to get collection info
|
||||
client = self.collection_client
|
||||
rsp = client.collection_list()
|
||||
all_collections = rsp['data']
|
||||
assert name in all_collections
|
||||
rsp = client.collection_describe(name)
|
||||
assert rsp['code'] == 200
|
||||
assert rsp['data']['collectionName'] == name
|
||||
assert rsp['data']['enableDynamic'] == enable_dynamic
|
||||
assert rsp['data']['load'] == "LoadStateNotLoad"
|
||||
assert rsp['data']['shardsNum'] == shard_num
|
||||
|
||||
@pytest.mark.parametrize("vector_field", ["vector", "emb"])
|
||||
@pytest.mark.parametrize("primary_field", ["id", "doc_id"])
|
||||
@pytest.mark.parametrize("metric_type", ["L2", "IP"])
|
||||
@pytest.mark.parametrize("dim", [128])
|
||||
def test_collection_created_by_restful_describe_by_sdk(self, dim, metric_type, primary_field, vector_field):
|
||||
"""
|
||||
"""
|
||||
name = gen_collection_name()
|
||||
dim = 128
|
||||
client = self.collection_client
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
"dimension": dim,
|
||||
"metricType": metric_type,
|
||||
"primaryField": primary_field,
|
||||
"vectorField": vector_field,
|
||||
}
|
||||
if primary_field is None:
|
||||
del payload["primaryField"]
|
||||
if vector_field is None:
|
||||
del payload["vectorField"]
|
||||
rsp = client.collection_create(payload)
|
||||
collection = Collection(name=name)
|
||||
logger.info(collection.schema)
|
||||
field_names = [field.name for field in collection.schema.fields]
|
||||
assert len(field_names) == 2
|
||||
assert collection.schema.enable_dynamic_field is True
|
||||
assert len(collection.indexes) > 0
|
||||
|
||||
@pytest.mark.parametrize("metric_type", ["L2", "IP"])
|
||||
def test_collection_created_index_by_sdk_describe_by_restful(self, metric_type):
|
||||
"""
|
||||
"""
|
||||
# 1. create collection by sdk
|
||||
name = gen_collection_name()
|
||||
default_fields = [
|
||||
FieldSchema(name="int64", dtype=DataType.INT64, is_primary=True),
|
||||
FieldSchema(name="float", dtype=DataType.FLOAT),
|
||||
FieldSchema(name="varchar", dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name="float_vector", dtype=DataType.FLOAT_VECTOR, dim=128)
|
||||
]
|
||||
default_schema = CollectionSchema(fields=default_fields, description="test collection",
|
||||
enable_dynamic_field=True)
|
||||
collection = Collection(name=name, schema=default_schema)
|
||||
# create index by sdk
|
||||
index_param = {"metric_type": metric_type, "index_type": "IVF_FLAT", "params": {"nlist": 128}}
|
||||
collection.create_index(field_name="float_vector", index_params=index_param)
|
||||
# 2. use restful to get collection info
|
||||
client = self.collection_client
|
||||
rsp = client.collection_list()
|
||||
all_collections = rsp['data']
|
||||
assert name in all_collections
|
||||
rsp = client.collection_describe(name)
|
||||
assert rsp['code'] == 200
|
||||
assert rsp['data']['collectionName'] == name
|
||||
assert len(rsp['data']['indexes']) == 1 and rsp['data']['indexes'][0]['metricType'] == metric_type
|
||||
|
||||
@pytest.mark.parametrize("metric_type", ["L2", "IP"])
|
||||
def test_collection_load_by_sdk_describe_by_restful(self, metric_type):
|
||||
"""
|
||||
"""
|
||||
# 1. create collection by sdk
|
||||
name = gen_collection_name()
|
||||
default_fields = [
|
||||
FieldSchema(name="int64", dtype=DataType.INT64, is_primary=True),
|
||||
FieldSchema(name="float", dtype=DataType.FLOAT),
|
||||
FieldSchema(name="varchar", dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name="float_vector", dtype=DataType.FLOAT_VECTOR, dim=128)
|
||||
]
|
||||
default_schema = CollectionSchema(fields=default_fields, description="test collection",
|
||||
enable_dynamic_field=True)
|
||||
collection = Collection(name=name, schema=default_schema)
|
||||
# create index by sdk
|
||||
index_param = {"metric_type": metric_type, "index_type": "IVF_FLAT", "params": {"nlist": 128}}
|
||||
collection.create_index(field_name="float_vector", index_params=index_param)
|
||||
collection.load()
|
||||
# 2. use restful to get collection info
|
||||
client = self.collection_client
|
||||
rsp = client.collection_list()
|
||||
all_collections = rsp['data']
|
||||
assert name in all_collections
|
||||
rsp = client.collection_describe(name)
|
||||
assert rsp['data']['load'] == "LoadStateLoaded"
|
||||
|
||||
def test_collection_create_by_sdk_insert_vector_by_restful(self):
|
||||
"""
|
||||
"""
|
||||
# 1. create collection by sdk
|
||||
dim = 128
|
||||
nb = 100
|
||||
name = gen_collection_name()
|
||||
default_fields = [
|
||||
FieldSchema(name="int64", dtype=DataType.INT64, is_primary=True),
|
||||
FieldSchema(name="float", dtype=DataType.FLOAT),
|
||||
FieldSchema(name="varchar", dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name="float_vector", dtype=DataType.FLOAT_VECTOR, dim=128)
|
||||
]
|
||||
default_schema = CollectionSchema(fields=default_fields, description="test collection",
|
||||
enable_dynamic_field=True)
|
||||
collection = Collection(name=name, schema=default_schema)
|
||||
# create index by sdk
|
||||
index_param = {"metric_type": "L2", "index_type": "IVF_FLAT", "params": {"nlist": 128}}
|
||||
collection.create_index(field_name="float_vector", index_params=index_param)
|
||||
collection.load()
|
||||
# insert data by restful
|
||||
data = [
|
||||
{"int64": i, "float": i, "varchar": str(i), "float_vector": [random.random() for _ in range(dim)], "age": i}
|
||||
for i in range(nb)
|
||||
]
|
||||
client = self.vector_client
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
"data": data,
|
||||
}
|
||||
rsp = client.vector_insert(payload)
|
||||
assert rsp['code'] == 200
|
||||
assert rsp['data']['insertCount'] == nb
|
||||
|
||||
def test_collection_create_by_sdk_search_vector_by_restful(self):
|
||||
"""
|
||||
"""
|
||||
dim = 128
|
||||
nb = 100
|
||||
name = gen_collection_name()
|
||||
default_fields = [
|
||||
FieldSchema(name="int64", dtype=DataType.INT64, is_primary=True),
|
||||
FieldSchema(name="float", dtype=DataType.FLOAT),
|
||||
FieldSchema(name="varchar", dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name="float_vector", dtype=DataType.FLOAT_VECTOR, dim=128)
|
||||
]
|
||||
default_schema = CollectionSchema(fields=default_fields, description="test collection",
|
||||
enable_dynamic_field=True)
|
||||
# init collection by sdk
|
||||
collection = Collection(name=name, schema=default_schema)
|
||||
index_param = {"metric_type": "L2", "index_type": "IVF_FLAT", "params": {"nlist": 128}}
|
||||
collection.create_index(field_name="float_vector", index_params=index_param)
|
||||
collection.load()
|
||||
data = [
|
||||
{"int64": i, "float": i, "varchar": str(i), "float_vector": [random.random() for _ in range(dim)], "age": i}
|
||||
for i in range(nb)
|
||||
]
|
||||
collection.insert(data)
|
||||
client = self.vector_client
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
"vector": [random.random() for _ in range(dim)],
|
||||
"limit": 10
|
||||
}
|
||||
# search data by restful
|
||||
rsp = client.vector_search(payload)
|
||||
assert rsp['code'] == 200
|
||||
assert len(rsp['data']) == 10
|
||||
|
||||
def test_collection_create_by_sdk_query_vector_by_restful(self):
|
||||
"""
|
||||
"""
|
||||
dim = 128
|
||||
nb = 100
|
||||
name = gen_collection_name()
|
||||
default_fields = [
|
||||
FieldSchema(name="int64", dtype=DataType.INT64, is_primary=True),
|
||||
FieldSchema(name="float", dtype=DataType.FLOAT),
|
||||
FieldSchema(name="varchar", dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name="float_vector", dtype=DataType.FLOAT_VECTOR, dim=128)
|
||||
]
|
||||
default_schema = CollectionSchema(fields=default_fields, description="test collection",
|
||||
enable_dynamic_field=True)
|
||||
# init collection by sdk
|
||||
collection = Collection(name=name, schema=default_schema)
|
||||
index_param = {"metric_type": "L2", "index_type": "IVF_FLAT", "params": {"nlist": 128}}
|
||||
collection.create_index(field_name="float_vector", index_params=index_param)
|
||||
collection.load()
|
||||
data = [
|
||||
{"int64": i, "float": i, "varchar": str(i), "float_vector": [random.random() for _ in range(dim)], "age": i}
|
||||
for i in range(nb)
|
||||
]
|
||||
collection.insert(data)
|
||||
client = self.vector_client
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
"filter": "int64 in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]",
|
||||
}
|
||||
# query data by restful
|
||||
rsp = client.vector_query(payload)
|
||||
assert rsp['code'] == 200
|
||||
assert len(rsp['data']) == 10
|
||||
|
||||
def test_collection_create_by_restful_search_vector_by_sdk(self):
|
||||
"""
|
||||
"""
|
||||
name = gen_collection_name()
|
||||
dim = 128
|
||||
# insert data by restful
|
||||
self.init_collection(name, metric_type="L2", dim=dim)
|
||||
time.sleep(5)
|
||||
# search data by sdk
|
||||
collection = Collection(name=name)
|
||||
nq = 5
|
||||
vectors_to_search = [[random.random() for i in range(dim)] for j in range(nq)]
|
||||
res = collection.search(data=vectors_to_search, anns_field="vector", param={}, limit=10)
|
||||
assert len(res) == nq
|
||||
assert len(res[0]) == 10
|
||||
|
||||
def test_collection_create_by_restful_query_vector_by_sdk(self):
|
||||
"""
|
||||
"""
|
||||
name = gen_collection_name()
|
||||
dim = 128
|
||||
# insert data by restful
|
||||
self.init_collection(name, metric_type="L2", dim=dim)
|
||||
time.sleep(5)
|
||||
# query data by sdk
|
||||
collection = Collection(name=name)
|
||||
res = collection.query(expr=f"uid in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]", output_fields=["*"])
|
||||
for item in res:
|
||||
uid = item["uid"]
|
||||
assert uid in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
|
||||
def test_collection_create_by_restful_delete_vector_by_sdk(self):
|
||||
"""
|
||||
"""
|
||||
name = gen_collection_name()
|
||||
dim = 128
|
||||
# insert data by restful
|
||||
self.init_collection(name, metric_type="L2", dim=dim)
|
||||
time.sleep(5)
|
||||
# query data by sdk
|
||||
collection = Collection(name=name)
|
||||
res = collection.query(expr=f"uid in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]", output_fields=["*"])
|
||||
pk_id_list = []
|
||||
for item in res:
|
||||
uid = item["uid"]
|
||||
pk_id_list.append(item["id"])
|
||||
expr = f"id in {pk_id_list}"
|
||||
collection.delete(expr)
|
||||
time.sleep(5)
|
||||
res = collection.query(expr=f"uid in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]", output_fields=["*"])
|
||||
assert len(res) == 0
|
||||
|
||||
def test_collection_create_by_sdk_delete_vector_by_restful(self):
|
||||
"""
|
||||
"""
|
||||
dim = 128
|
||||
nb = 100
|
||||
name = gen_collection_name()
|
||||
default_fields = [
|
||||
FieldSchema(name="int64", dtype=DataType.INT64, is_primary=True),
|
||||
FieldSchema(name="float", dtype=DataType.FLOAT),
|
||||
FieldSchema(name="varchar", dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name="float_vector", dtype=DataType.FLOAT_VECTOR, dim=128)
|
||||
]
|
||||
default_schema = CollectionSchema(fields=default_fields, description="test collection",
|
||||
enable_dynamic_field=True)
|
||||
# init collection by sdk
|
||||
collection = Collection(name=name, schema=default_schema)
|
||||
index_param = {"metric_type": "L2", "index_type": "IVF_FLAT", "params": {"nlist": 128}}
|
||||
collection.create_index(field_name="float_vector", index_params=index_param)
|
||||
collection.load()
|
||||
data = [
|
||||
{"int64": i, "float": i, "varchar": str(i), "float_vector": [random.random() for _ in range(dim)], "age": i}
|
||||
for i in range(nb)
|
||||
]
|
||||
collection.insert(data)
|
||||
time.sleep(5)
|
||||
res = collection.query(expr=f"int64 in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]", output_fields=["*"])
|
||||
pk_id_list = []
|
||||
for item in res:
|
||||
pk_id_list.append(item["int64"])
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
"id": pk_id_list
|
||||
}
|
||||
# delete data by restful
|
||||
rsp = self.vector_client.vector_delete(payload)
|
||||
time.sleep(5)
|
||||
res = collection.query(expr=f"int64 in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]", output_fields=["*"])
|
||||
assert len(res) == 0
|
Loading…
Reference in New Issue