# milvus/python/sdk/tests/TestClient.py
#
# 358 lines
# 11 KiB
# Python

import logging
import pytest
import mock
import faker
import random
from faker.providers import BaseProvider
from client.Client import MegaSearch, Prepare, IndexType, ColumnType
from client.Status import Status
from client.Exceptions import (
RepeatingConnectError,
DisconnectNotConnectedClientError
)
from thrift.transport.TSocket import TSocket
from megasearch.thrift import ttypes, MegasearchService
from thrift.transport.TTransport import TTransportException
# Module-level logger; the tests below use it to dump intermediate objects.
LOGGER = logging.getLogger(__name__)
class FakerProvider(BaseProvider):
    """Faker provider supplying random identifiers and dimensions for the tests.

    Values are drawn from the stdlib ``random`` module rather than from
    Faker's own generator.
    """

    def table_name(self):
        """Return ``'table_name'`` followed by a random integer in [1000, 9999]."""
        return 'table_name{}'.format(random.randint(1000, 9999))

    def name(self):
        """Return ``'name'`` followed by a random integer in [1000, 9999]."""
        # NOTE(review): once registered, this presumably takes precedence over
        # Faker's built-in ``name`` provider; confirm that is intended.
        return 'name{}'.format(random.randint(1000, 9999))

    def dim(self):
        """Return a random vector dimension in [0, 999] (inclusive)."""
        # NOTE(review): 0 is a possible result; confirm a zero-dimension
        # vector column is acceptable wherever this value is used.
        return random.randint(0, 999)
# Shared Faker instance, extended with FakerProvider so the factories and
# tests can call fake.table_name(), fake.name() and fake.dim().
fake = faker.Faker()
fake.add_provider(FakerProvider)
def vector_column_factory():
    """Return keyword arguments for ``Prepare.vector_column``.

    The column gets a random name and dimension from ``fake``, and
    ``store_raw_vector`` is always ``True``.
    """
    column_kwargs = dict(name=fake.name(), dimension=fake.dim())
    column_kwargs['store_raw_vector'] = True
    return column_kwargs
def column_factory():
    """Return keyword arguments for ``Prepare.column``: a random INT32 attribute column."""
    # NOTE(review): the column name comes from the table-name generator, so it
    # looks like 'table_nameNNNN'; confirm this reuse is intentional.
    column_name = fake.table_name()
    return {'name': column_name, 'type': ColumnType.INT32}
def range_factory():
    """Return keyword arguments for ``Prepare.range`` describing a random span.

    ``start`` is drawn from [1, 10] and ``end`` from [11, 20], both as
    strings, so ``start`` is always numerically smaller than ``end``.
    """
    lower = random.randint(1, 10)
    upper = random.randint(11, 20)
    return dict(start=str(lower), end=str(upper))
def table_schema_factory():
    """Build a random table schema via ``Prepare.table_schema``.

    The schema has 10 vector columns, 5 INT32 attribute columns, a random
    table name and the partition column names ``'0'`` and ``'1'``.
    """
    # Draw every random spec first, then the table name, then build the
    # thrift objects, matching the original evaluation order.
    vector_specs = [vector_column_factory() for _ in range(10)]
    attribute_specs = [column_factory() for _ in range(5)]
    schema_table_name = fake.table_name()
    vector_columns = [Prepare.vector_column(**spec) for spec in vector_specs]
    attribute_columns = [Prepare.column(**spec) for spec in attribute_specs]
    partition_names = list(map(str, range(2)))
    return Prepare.table_schema(
        table_name=schema_table_name,
        vector_columns=vector_columns,
        attribute_columns=attribute_columns,
        partition_column_names=partition_names,
    )
def create_table_partition_param_factory():
    """Build a random ``CreateTablePartitionParam`` via ``Prepare``.

    Up to three random column names are each mapped to a ``range_factory()``
    span. Duplicate random names collapse, so fewer than three entries are
    possible.
    """
    target_table = fake.table_name()
    target_partition = fake.table_name()
    ranges_by_column = {fake.name(): range_factory() for _ in range(3)}
    return Prepare.create_table_partition_param(
        table_name=target_table,
        partition_name=target_partition,
        column_name_to_range=ranges_by_column,
    )
def delete_table_partition_param_factory():
    """Build a random ``DeleteTablePartitionParam`` naming five partitions."""
    target_table = fake.table_name()
    doomed_partitions = [fake.name() for _ in range(5)]
    return Prepare.delete_table_partition_param(
        table_name=target_table,
        partition_names=doomed_partitions,
    )
def row_record_factory():
    """Build a random ``RowRecord`` via ``Prepare.row_record``.

    The record has one 256-element float vector under a random column name
    and one random string attribute.
    """
    vector_column = fake.name()
    vector = [random.random() for _ in range(256)]
    attribute_column = fake.name()
    attribute_value = fake.name()
    return Prepare.row_record(
        column_name_to_vector={vector_column: vector},
        column_name_to_attribute={attribute_column: attribute_value},
    )
class TestConnection:
    """Connection lifecycle tests for ``MegaSearch``."""

    # Connection target shared by the tests in this class.
    param = {'host': 'localhost', 'port': '5000'}

    @mock.patch.object(TSocket, 'open')
    def test_true_connect(self, mock_open):
        """A successful connect reports OK, and any further connect must raise."""
        # Stub the socket so no real server is needed.
        mock_open.return_value = None
        cnn = MegaSearch()
        cnn.connect(**self.param)
        assert cnn.status == Status.OK
        assert cnn.connected
        assert isinstance(cnn.client, MegasearchService.Client)

        # Each reconnect attempt gets its own ``raises`` block. Inside a single
        # block, everything after the first raising call would never execute.
        with pytest.raises(RepeatingConnectError):
            cnn.connect(**self.param)
        with pytest.raises(RepeatingConnectError):
            cnn.connect()

    def test_false_connect(self):
        """Without a stubbed socket, connecting must not report OK."""
        # NOTE(review): assumes nothing is listening on localhost:5000.
        cnn = MegaSearch()
        # Unpack the params (as test_true_connect does) instead of passing the
        # whole dict positionally as ``host``.
        cnn.connect(**self.param)
        assert cnn.status != Status.OK

    def test_disconnected_error(self):
        """Disconnecting a client that never connected must raise."""
        cnn = MegaSearch()
        # NOTE(review): the other tests read ``cnn.status``; confirm
        # ``connect_status`` is the attribute disconnect() actually checks.
        cnn.connect_status = Status(Status.INVALID)
        with pytest.raises(DisconnectNotConnectedClientError):
            cnn.disconnect()
class TestTable:
    """Table create/delete tests using a client connected through a stubbed socket."""

    @pytest.fixture
    @mock.patch.object(TSocket, 'open')
    def client(self, mock_open):
        """Return a ``MegaSearch`` client that connected via a stubbed ``TSocket.open``."""
        param = {'host': 'localhost', 'port': '5000'}
        mock_open.return_value = None
        cnn = MegaSearch()
        cnn.connect(**param)
        return cnn

    @mock.patch.object(MegasearchService.Client, 'CreateTable')
    def test_create_table(self, mock_create_table, client):
        """With the CreateTable RPC stubbed, create_table reports OK."""
        mock_create_table.return_value = None
        param = table_schema_factory()
        res = client.create_table(param)
        assert res == Status.OK

    def test_false_create_table(self, client):
        """With CreateTable unstubbed, the call is expected to fail at the transport."""
        param = table_schema_factory()
        # Only the raising call belongs in the block. The old trailing
        # log/assert lines were either never reached or hit an unbound ``res``.
        with pytest.raises(TTransportException):
            client.create_table(param)

    @mock.patch.object(MegasearchService.Client, 'DeleteTable')
    def test_delete_table(self, mock_delete_table, client):
        """With the DeleteTable RPC stubbed, delete_table reports OK."""
        mock_delete_table.return_value = None
        table_name = 'fake_table_name'
        res = client.delete_table(table_name)
        assert res == Status.OK

    def test_false_delete_table(self, client):
        """With DeleteTable unstubbed, delete_table must not report OK."""
        # NOTE(review): unlike test_false_create_table this expects a status
        # return rather than TTransportException; confirm delete_table's contract.
        table_name = 'fake_table_name'
        res = client.delete_table(table_name)
        assert res != Status.OK
class TestVector:
    """Partition, vector and metadata tests for ``MegaSearch``.

    Tests decorated with ``mock.patch.object`` stub the matching
    ``MegasearchService.Client`` RPC and expect OK. The ``test_false_*``
    variants leave the RPC unstubbed and expect a non-OK status.
    """

    @pytest.fixture
    @mock.patch.object(TSocket, 'open')
    def client(self, patched_open):
        """Return a ``MegaSearch`` that connected through a stubbed ``TSocket.open``."""
        patched_open.return_value = None
        connection_kwargs = dict(host='localhost', port='5000')
        megasearch = MegaSearch()
        megasearch.connect(**connection_kwargs)
        return megasearch

    @mock.patch.object(MegasearchService.Client, 'CreateTablePartition')
    def test_create_table_partition(self, stub, client):
        stub.return_value = None
        partition_param = create_table_partition_param_factory()
        assert client.create_table_partition(partition_param) == Status.OK

    def test_false_table_partition(self, client):
        partition_param = create_table_partition_param_factory()
        assert client.create_table_partition(partition_param) != Status.OK

    @mock.patch.object(MegasearchService.Client, 'DeleteTablePartition')
    def test_delete_table_partition(self, stub, client):
        stub.return_value = None
        partition_param = delete_table_partition_param_factory()
        assert client.delete_table_partition(partition_param) == Status.OK

    def test_false_delete_table_partition(self, client):
        partition_param = delete_table_partition_param_factory()
        assert client.delete_table_partition(partition_param) != Status.OK

    @mock.patch.object(MegasearchService.Client, 'AddVector')
    def test_add_vector(self, stub, client):
        stub.return_value = None
        target_table = fake.table_name()
        rows = [row_record_factory() for _ in range(1000)]
        status, _ids = client.add_vector(table_name=target_table, records=rows)
        assert status == Status.OK

    def test_false_add_vector(self, client):
        target_table = fake.table_name()
        rows = [row_record_factory() for _ in range(1000)]
        status, _ids = client.add_vector(table_name=target_table, records=rows)
        assert status != Status.OK

    @mock.patch.object(MegasearchService.Client, 'SearchVector')
    def test_search_vector(self, stub, client):
        stub.return_value = None
        target_table = fake.table_name()
        queries = [row_record_factory() for _ in range(1000)]
        k = random.randint(0, 10)
        status, _results = client.search_vector(
            table_name=target_table, query_records=queries, top_k=k
        )
        assert status == Status.OK

    def test_false_vector(self, client):
        target_table = fake.table_name()
        queries = [row_record_factory() for _ in range(1000)]
        k = random.randint(0, 10)
        status, _results = client.search_vector(
            table_name=target_table, query_records=queries, top_k=k
        )
        assert status != Status.OK

    @mock.patch.object(MegasearchService.Client, 'DescribeTable')
    def test_describe_table(self, stub, client):
        stub.return_value = table_schema_factory()
        status, schema = client.describe_table(fake.table_name())
        assert status == Status.OK
        assert isinstance(schema, ttypes.TableSchema)

    def test_false_decribe_table(self, client):
        status, schema = client.describe_table(fake.table_name())
        assert status != Status.OK
        assert not schema

    @mock.patch.object(MegasearchService.Client, 'ShowTables')
    def test_show_tables(self, stub, client):
        stub.return_value = [fake.table_name() for _ in range(10)]
        status, table_names = client.show_tables()
        assert status == Status.OK
        assert isinstance(table_names, list)

    def test_false_show_tables(self, client):
        status, table_names = client.show_tables()
        assert status != Status.OK
        assert not table_names

    def test_client_version(self, client):
        assert client.client_version() == '0.0.1'
class TestPrepare:
    """Unit tests for the ``Prepare`` builders that produce thrift ``ttypes`` objects.

    Built objects are dumped at DEBUG level. They are diagnostics, not
    errors, and the ``%s`` argument is only formatted if DEBUG is enabled.
    """

    def test_column(self):
        """``Prepare.column`` copies name and type into a ``ttypes.Column``."""
        param = {
            'name': 'test01',
            'type': ColumnType.DATE
        }
        res = Prepare.column(**param)
        LOGGER.debug('%s', res)
        assert res.name == 'test01'
        assert res.type == ColumnType.DATE
        assert isinstance(res, ttypes.Column)

    def test_vector_column(self):
        """``Prepare.vector_column`` returns a ``ttypes.VectorColumn``."""
        param = vector_column_factory()
        res = Prepare.vector_column(**param)
        LOGGER.debug('%s', res)
        assert isinstance(res, ttypes.VectorColumn)

    def test_table_schema(self):
        """``Prepare.table_schema`` returns a ``ttypes.TableSchema``."""
        vec_params = [vector_column_factory() for _ in range(10)]
        column_params = [column_factory() for _ in range(5)]
        param = {
            'table_name': 'test03',
            'vector_columns': [Prepare.vector_column(**pa) for pa in vec_params],
            'attribute_columns': [Prepare.column(**pa) for pa in column_params],
            'partition_column_names': [str(x) for x in range(2)]
        }
        res = Prepare.table_schema(**param)
        assert isinstance(res, ttypes.TableSchema)

    def test_range(self):
        """``Prepare.range`` maps ``start``/``end`` to ``start_value``/``end_value``."""
        param = {
            'start': '200',
            'end': '1000'
        }
        res = Prepare.range(**param)
        LOGGER.debug('%s', res)
        assert isinstance(res, ttypes.Range)
        assert res.start_value == '200'
        assert res.end_value == '1000'

    def test_create_table_partition_param(self):
        """``Prepare.create_table_partition_param`` returns the matching ttype."""
        param = {
            'table_name': fake.table_name(),
            'partition_name': fake.table_name(),
            'column_name_to_range': {fake.name(): range_factory() for _ in range(3)}
        }
        res = Prepare.create_table_partition_param(**param)
        LOGGER.debug('%s', res)
        assert isinstance(res, ttypes.CreateTablePartitionParam)

    def test_delete_table_partition_param(self):
        """``Prepare.delete_table_partition_param`` returns the matching ttype."""
        param = {
            'table_name': fake.table_name(),
            'partition_names': [fake.name() for _ in range(5)]
        }
        res = Prepare.delete_table_partition_param(**param)
        assert isinstance(res, ttypes.DeleteTablePartitionParam)

    def test_row_record(self):
        """``Prepare.row_record`` returns a ``ttypes.RowRecord``."""
        param = {
            'column_name_to_vector': {fake.name(): [random.random() for _ in range(256)]},
            'column_name_to_attribute': {fake.name(): fake.name()}
        }
        res = Prepare.row_record(**param)
        assert isinstance(res, ttypes.RowRecord)

    def test_query_record(self):
        """``Prepare.query_record`` returns a ``ttypes.QueryRecord``."""
        param = {
            'column_name_to_vector': {fake.name(): [random.random() for _ in range(256)]},
            'selected_columns': [fake.name() for _ in range(10)],
            'name_to_partition_ranges': {fake.name(): [range_factory() for _ in range(5)]}
        }
        res = Prepare.query_record(**param)
        assert isinstance(res, ttypes.QueryRecord)