# Mirror of https://github.com/milvus-io/milvus.git (358 lines, 11 KiB, Python)
import logging
|
|
import pytest
|
|
import mock
|
|
import faker
|
|
import random
|
|
from faker.providers import BaseProvider
|
|
|
|
from client.Client import MegaSearch, Prepare, IndexType, ColumnType
|
|
from client.Status import Status
|
|
from client.Exceptions import (
|
|
RepeatingConnectError,
|
|
DisconnectNotConnectedClientError
|
|
)
|
|
|
|
from thrift.transport.TSocket import TSocket
|
|
from megasearch.thrift import ttypes, MegasearchService
|
|
from thrift.transport.TTransport import TTransportException
|
|
|
|
# Module-level logger, named after this module's import name.
LOGGER = logging.getLogger(name=__name__)
|
|
|
|
|
|
class FakerProvider(BaseProvider):
    """Faker provider producing random identifiers and vector dimensions.

    Each method draws from the stdlib ``random`` module.
    """

    def table_name(self):
        """Return ``'table_name'`` followed by a random 4-digit number."""
        return 'table_name' + str(random.randint(1000, 9999))

    def name(self):
        """Return ``'name'`` followed by a random 4-digit number.

        NOTE(review): presumably this overrides Faker's built-in person
        ``name`` on any generator this provider is added to — confirm.
        """
        return 'name' + str(random.randint(1000, 9999))

    def dim(self):
        """Return a random vector dimension in the inclusive range [1, 999].

        The lower bound is 1, not 0: a zero-dimensional vector column is not
        a meaningful schema and would make the factories below emit
        degenerate vector columns.
        """
        return random.randint(1, 999)
|
|
|
|
|
|
def _build_fake():
    """Create a Faker generator with ``FakerProvider`` registered on it."""
    generator = faker.Faker()
    generator.add_provider(FakerProvider)
    return generator


# Shared Faker instance used by the factories and tests in this module.
fake = _build_fake()
|
|
|
|
|
|
def vector_column_factory():
    """Build keyword arguments for ``Prepare.vector_column``.

    Returns:
        dict with a random ``name``, a random ``dimension`` taken from
        ``fake.dim()``, and ``store_raw_vector`` set to True.
    """
    return dict(
        name=fake.name(),
        dimension=fake.dim(),
        store_raw_vector=True,
    )
|
|
|
|
|
|
def column_factory():
    """Build keyword arguments for ``Prepare.column``: a randomly named INT32 column."""
    column_name = fake.table_name()
    return dict(name=column_name, type=ColumnType.INT32)
|
|
|
|
|
|
def range_factory():
    """Build keyword arguments for ``Prepare.range``.

    ``start`` is drawn from 1..10 and ``end`` from 11..20 (both inclusive),
    so ``start < end`` always holds. Both values are returned as strings.
    """
    lower = random.randint(1, 10)
    upper = random.randint(11, 20)
    return {'start': str(lower), 'end': str(upper)}
|
|
|
|
|
|
def table_schema_factory():
    """Build a random table schema via ``Prepare.table_schema``.

    The schema has 10 random vector columns, 5 random INT32 attribute
    columns, a random table name, and partition column names ``'0'`` and
    ``'1'``.
    """
    vector_specs = [vector_column_factory() for _ in range(10)]
    attribute_specs = [column_factory() for _ in range(5)]
    table_name = fake.table_name()
    return Prepare.table_schema(
        table_name=table_name,
        vector_columns=[Prepare.vector_column(**spec) for spec in vector_specs],
        attribute_columns=[Prepare.column(**spec) for spec in attribute_specs],
        partition_column_names=[str(idx) for idx in range(2)],
    )
|
|
|
|
|
|
def create_table_partition_param_factory():
    """Build a random partition-creation param via ``Prepare``.

    Uses a random table name, a random partition name, and up to three
    column-name -> range mappings (colliding random names collapse).
    """
    table_name = fake.table_name()
    partition_name = fake.table_name()
    column_ranges = {fake.name(): range_factory() for _ in range(3)}
    return Prepare.create_table_partition_param(
        table_name=table_name,
        partition_name=partition_name,
        column_name_to_range=column_ranges,
    )
|
|
|
|
|
|
def delete_table_partition_param_factory():
    """Build a random partition-deletion param naming five random partitions."""
    table_name = fake.table_name()
    partition_names = [fake.name() for _ in range(5)]
    return Prepare.delete_table_partition_param(
        table_name=table_name,
        partition_names=partition_names,
    )
|
|
|
|
|
|
def row_record_factory():
    """Build a random row record via ``Prepare.row_record``.

    The record holds one 256-element list of ``random.random()`` floats and
    one string attribute, each keyed by a random column name.
    """
    vectors = {fake.name(): [random.random() for _ in range(256)]}
    attributes = {fake.name(): fake.name()}
    return Prepare.row_record(
        column_name_to_vector=vectors,
        column_name_to_attribute=attributes,
    )
|
|
|
|
|
|
class TestConnection:
    """Connection lifecycle tests for ``MegaSearch``."""

    # Connection parameters shared by the tests in this class.
    param = {'host': 'localhost', 'port': '5000'}

    @mock.patch.object(TSocket, 'open')
    def test_true_connect(self, mock_open):
        """A successful transport open yields OK status and a Thrift client.

        Connecting again, with or without arguments, must raise
        ``RepeatingConnectError``.
        """
        mock_open.return_value = None

        cnn = MegaSearch()
        cnn.connect(**self.param)
        assert cnn.status == Status.OK
        assert cnn.connected
        assert isinstance(cnn.client, MegasearchService.Client)

        # One ``pytest.raises`` block per call: the first raising call exits
        # its block, so a second call placed in the same block never runs.
        with pytest.raises(RepeatingConnectError):
            cnn.connect(**self.param)
        with pytest.raises(RepeatingConnectError):
            cnn.connect()

    @mock.patch.object(TSocket, 'open')
    def test_false_connect(self, mock_open):
        """A failed transport open leaves the client in a non-OK status."""
        # Simulate the error a real TSocket raises on a refused connection, so
        # the result does not depend on whether anything listens on port 5000.
        mock_open.side_effect = TTransportException(
            type=TTransportException.NOT_OPEN,
            message='Could not connect to localhost:5000',
        )

        cnn = MegaSearch()
        # Unpack the parameters. Passing the dict positionally would bind it
        # to connect's first parameter (presumably the host) and fail for an
        # unrelated reason.
        cnn.connect(**self.param)
        assert cnn.status != Status.OK

    def test_disconnected_error(self):
        """Disconnecting a client that never connected raises."""
        cnn = MegaSearch()
        # NOTE(review): the other tests read ``cnn.status``; ``connect_status``
        # may be a stale attribute name the client no longer reads — confirm.
        cnn.connect_status = Status(Status.INVALID)
        with pytest.raises(DisconnectNotConnectedClientError):
            cnn.disconnect()
|
|
|
|
|
|
class TestTable:
    """Table create/delete tests against a client whose socket open is mocked."""

    @pytest.fixture
    @mock.patch.object(TSocket, 'open')
    def client(self, mock_open):
        """Return a ``MegaSearch`` connected through a mocked ``TSocket.open``.

        Only ``open`` is mocked. RPCs that a test does not mock itself are
        expected to fail (see ``test_false_create_table``).
        """
        param = {'host': 'localhost', 'port': '5000'}
        mock_open.return_value = None

        cnn = MegaSearch()
        cnn.connect(**param)
        return cnn

    @mock.patch.object(MegasearchService.Client, 'CreateTable')
    def test_create_table(self, mock_create_table, client):
        """create_table reports OK when the CreateTable RPC succeeds."""
        mock_create_table.return_value = None

        param = table_schema_factory()
        res = client.create_table(param)
        assert res == Status.OK

    def test_false_create_table(self, client):
        """Without a mocked RPC, create_table surfaces the transport error."""
        param = table_schema_factory()
        # Only the raising call belongs inside ``pytest.raises``: any statement
        # after it in the block is unreachable.
        with pytest.raises(TTransportException):
            client.create_table(param)

    @mock.patch.object(MegasearchService.Client, 'DeleteTable')
    def test_delete_table(self, mock_delete_table, client):
        """delete_table reports OK when the DeleteTable RPC succeeds."""
        mock_delete_table.return_value = None
        table_name = 'fake_table_name'
        res = client.delete_table(table_name)
        assert res == Status.OK

    def test_false_delete_table(self, client):
        """Without a mocked RPC, delete_table reports a non-OK status."""
        table_name = 'fake_table_name'
        res = client.delete_table(table_name)
        assert res != Status.OK
|
|
|
|
|
|
class TestVector:
    """Partition, vector and metadata RPC tests for ``MegaSearch``.

    Tests that mock the matching ``MegasearchService.Client`` RPC expect
    ``Status.OK``. The ``test_false_*`` variants leave the RPC unmocked and
    expect a non-OK status.
    """

    @pytest.fixture
    @mock.patch.object(TSocket, 'open')
    def client(self, mock_open):
        """Return a ``MegaSearch`` connected through a mocked ``TSocket.open``."""
        param = {'host': 'localhost', 'port': '5000'}
        mock_open.return_value = None

        cnn = MegaSearch()
        cnn.connect(**param)
        return cnn

    @mock.patch.object(MegasearchService.Client, 'CreateTablePartition')
    def test_create_table_partition(self, mock_create_table_partition, client):
        mock_create_table_partition.return_value = None

        param = create_table_partition_param_factory()
        res = client.create_table_partition(param)
        assert res == Status.OK

    def test_false_table_partition(self, client):
        param = create_table_partition_param_factory()
        res = client.create_table_partition(param)
        assert res != Status.OK

    @mock.patch.object(MegasearchService.Client, 'DeleteTablePartition')
    def test_delete_table_partition(self, mock_delete_table_partition, client):
        mock_delete_table_partition.return_value = None

        param = delete_table_partition_param_factory()
        res = client.delete_table_partition(param)
        assert res == Status.OK

    def test_false_delete_table_partition(self, client):
        param = delete_table_partition_param_factory()
        res = client.delete_table_partition(param)
        assert res != Status.OK

    @mock.patch.object(MegasearchService.Client, 'AddVector')
    def test_add_vector(self, mock_add_vector, client):
        mock_add_vector.return_value = None

        param = {
            'table_name': fake.table_name(),
            'records': [row_record_factory() for _ in range(1000)],
        }
        # add_vector returns (status, ids); only the status is checked here.
        res, _ = client.add_vector(**param)
        assert res == Status.OK

    def test_false_add_vector(self, client):
        param = {
            'table_name': fake.table_name(),
            'records': [row_record_factory() for _ in range(1000)],
        }
        res, _ = client.add_vector(**param)
        assert res != Status.OK

    @mock.patch.object(MegasearchService.Client, 'SearchVector')
    def test_search_vector(self, mock_search_vector, client):
        mock_search_vector.return_value = None
        param = {
            'table_name': fake.table_name(),
            'query_records': [row_record_factory() for _ in range(1000)],
            'top_k': random.randint(0, 10),
        }
        # search_vector returns (status, results); only the status is checked.
        res, _ = client.search_vector(**param)
        assert res == Status.OK

    def test_false_vector(self, client):
        param = {
            'table_name': fake.table_name(),
            'query_records': [row_record_factory() for _ in range(1000)],
            'top_k': random.randint(0, 10),
        }
        res, _ = client.search_vector(**param)
        assert res != Status.OK

    @mock.patch.object(MegasearchService.Client, 'DescribeTable')
    def test_describe_table(self, mock_describe_table, client):
        mock_describe_table.return_value = table_schema_factory()

        table_name = fake.table_name()
        res, table_schema = client.describe_table(table_name)
        assert res == Status.OK
        assert isinstance(table_schema, ttypes.TableSchema)

    def test_false_decribe_table(self, client):
        table_name = fake.table_name()
        res, table_schema = client.describe_table(table_name)
        assert res != Status.OK
        assert not table_schema

    @mock.patch.object(MegasearchService.Client, 'ShowTables')
    def test_show_tables(self, mock_show_tables, client):
        mock_show_tables.return_value = [fake.table_name() for _ in range(10)]
        res, tables = client.show_tables()
        assert res == Status.OK
        assert isinstance(tables, list)

    def test_false_show_tables(self, client):
        res, tables = client.show_tables()
        assert res != Status.OK
        assert not tables

    def test_client_version(self, client):
        res = client.client_version()
        assert res == '0.0.1'
|
|
|
|
|
|
class TestPrepare:
    """Tests that the ``Prepare`` helpers build the expected ``ttypes`` objects."""

    def test_column(self):
        param = {
            'name': 'test01',
            'type': ColumnType.DATE,
        }
        res = Prepare.column(**param)
        # Diagnostic output only: DEBUG with lazy %-args keeps passing tests
        # out of ERROR-level logs and skips formatting when it is disabled.
        LOGGER.debug('%s', res)
        assert res.name == 'test01'
        assert res.type == ColumnType.DATE
        assert isinstance(res, ttypes.Column)

    def test_vector_column(self):
        param = vector_column_factory()

        res = Prepare.vector_column(**param)
        LOGGER.debug('%s', res)
        assert isinstance(res, ttypes.VectorColumn)

    def test_table_schema(self):
        vec_params = [vector_column_factory() for _ in range(10)]
        column_params = [column_factory() for _ in range(5)]

        param = {
            'table_name': 'test03',
            'vector_columns': [Prepare.vector_column(**pa) for pa in vec_params],
            'attribute_columns': [Prepare.column(**pa) for pa in column_params],
            'partition_column_names': [str(x) for x in range(2)],
        }
        res = Prepare.table_schema(**param)
        assert isinstance(res, ttypes.TableSchema)

    def test_range(self):
        param = {
            'start': '200',
            'end': '1000',
        }

        res = Prepare.range(**param)
        LOGGER.debug('%s', res)
        assert isinstance(res, ttypes.Range)
        # Prepare.range maps ``start``/``end`` onto ``start_value``/``end_value``.
        assert res.start_value == '200'
        assert res.end_value == '1000'

    def test_create_table_partition_param(self):
        param = {
            'table_name': fake.table_name(),
            'partition_name': fake.table_name(),
            'column_name_to_range': {fake.name(): range_factory() for _ in range(3)},
        }
        res = Prepare.create_table_partition_param(**param)
        LOGGER.debug('%s', res)
        assert isinstance(res, ttypes.CreateTablePartitionParam)

    def test_delete_table_partition_param(self):
        param = {
            'table_name': fake.table_name(),
            'partition_names': [fake.name() for _ in range(5)],
        }
        res = Prepare.delete_table_partition_param(**param)
        assert isinstance(res, ttypes.DeleteTablePartitionParam)

    def test_row_record(self):
        param = {
            'column_name_to_vector': {fake.name(): [random.random() for _ in range(256)]},
            'column_name_to_attribute': {fake.name(): fake.name()},
        }
        res = Prepare.row_record(**param)
        assert isinstance(res, ttypes.RowRecord)

    def test_query_record(self):
        param = {
            'column_name_to_vector': {fake.name(): [random.random() for _ in range(256)]},
            'selected_columns': [fake.name() for _ in range(10)],
            'name_to_partition_ranges': {fake.name(): [range_factory() for _ in range(5)]},
        }
        res = Prepare.query_record(**param)
        assert isinstance(res, ttypes.QueryRecord)
|
|
|
|
|