# ---- python/sdk/client/Abstract.py ----
# Part 1: integer constants and plain parameter containers.
from enum import IntEnum

from .Exceptions import ConnectParamMissingError


class AbstactIndexType(object):
    """Integer codes naming how a vector column is indexed."""

    RAW = 1
    IVFFLAT = 2


# Correctly spelled alias. The misspelled class name is kept because
# client/Client.py imports and subclasses it under that name.
AbstractIndexType = AbstactIndexType


class AbstractColumnType(object):
    """Integer codes naming the kind of value a column stores."""

    INVALID = 1
    INT8 = 2
    INT16 = 3
    INT32 = 4
    INT64 = 5
    FLOAT32 = 6
    FLOAT64 = 7
    DATE = 8
    VECTOR = 9


class Column(object):
    """
    Table column description

    :type  name: str
    :param name: name of the column

    :type  type: ColumnType
    :param type: type of the column, AbstractColumnType.INVALID by default
    """

    def __init__(self, name=None, type=AbstractColumnType.INVALID):
        # `type` shadows the builtin, but it is part of the keyword interface.
        self.type = type
        self.name = name


class VectorColumn(Column):
    """
    Table vector column description

    :type  name: str
    :param name: name of the column

    :type  dimension: int, int64
    :param dimension: vector dimension

    :type  index_type: IndexType
    :param index_type: index type, AbstactIndexType.RAW by default

    :type  store_raw_vector: bool
    :param store_raw_vector: whether the raw vector itself is stored in the table

    The inherited ``type`` attribute is always AbstractColumnType.VECTOR.
    """

    def __init__(self, name,
                 dimension=0,
                 index_type=AbstactIndexType.RAW,
                 store_raw_vector=False):
        self.dimension = dimension
        self.index_type = index_type
        self.store_raw_vector = store_raw_vector
        super(VectorColumn, self).__init__(name, type=AbstractColumnType.VECTOR)


class TableSchema(object):
    """
    Table schema

    :type  table_name: str
    :param table_name: name of the table

    :type  vector_columns: list[VectorColumn]
    :param vector_columns: vector columns; a table may hold several kinds of vectors

    :type  attribute_columns: list[Column]
    :param attribute_columns: columns whose type is not VECTOR

    :type  partition_column_names: list[str]
    :param partition_column_names: names (only) of the attribute columns used
        for partitioning; there may be at most as many as attribute columns

    Any extra keyword arguments are accepted and ignored.
    """

    def __init__(self, table_name, vector_columns,
                 attribute_columns, partition_column_names, **kwargs):
        self.table_name = table_name
        self.vector_columns = vector_columns
        self.attribute_columns = attribute_columns
        self.partition_column_names = partition_column_names


class Range(object):
    """
    Range information

    :type  start: str
    :param start: range start value

    :type  end: str
    :param end: range end value
    """

    def __init__(self, start, end):
        self.start = start
        self.end = end


class CreateTablePartitionParam(object):
    """
    Create table partition parameters

    :type  table_name: str
    :param table_name: table name

    :type  partition_name: str
    :param partition_name: name of the partition to create

    :type  column_name_to_range: dict{str : Range}
    :param column_name_to_range: column name to partition Range dictionary;
        VECTOR/FLOAT32/FLOAT64 columns are not allowed for partitioning
    """
    # TODO Iterable

    def __init__(self, table_name, partition_name, column_name_to_range):
        self.table_name = table_name
        self.partition_name = partition_name
        self.column_name_to_range = column_name_to_range


class DeleteTablePartitionParam(object):
    """
    Delete table partition parameters

    :type  table_name: str
    :param table_name: table name

    :type  partition_names: iterable of str
    :param partition_names: names of the partitions to delete
    """
    # TODO Iterable

    def __init__(self, table_name, partition_names):
        self.table_name = table_name
        self.partition_names = partition_names


class RowRecord(object):
    """
    Record to be inserted

    :type  column_name_to_vector: dict{str : list[float]}
    :param column_name_to_vector: column name to vector map

    :type  column_name_to_attribute: dict{str : str}
    :param column_name_to_attribute: values of the other (attribute) columns
    """

    def __init__(self, column_name_to_vector, column_name_to_attribute):
        self.column_name_to_vector = column_name_to_vector
        self.column_name_to_attribute = column_name_to_attribute
# ---- python/sdk/client/Abstract.py ----
# Part 2: query containers and the abstract connection interface.


class QueryRecord(object):
    """
    Query record

    :type  column_name_to_vector: dict{str : list[float]}
    :param column_name_to_vector: query vectors, column name to vector map

    :type  selected_columns: list[str]
    :param selected_columns: names of the columns to output

    :type  name_to_partition_ranges: dict{str : list[Range]}
    :param name_to_partition_ranges: ranges used to select partitions
    """

    def __init__(self, column_name_to_vector, selected_columns, name_to_partition_ranges):
        self.column_name_to_vector = column_name_to_vector
        self.selected_columns = selected_columns
        self.name_to_partition_ranges = name_to_partition_ranges


class QueryResult(object):
    """
    Query result

    :type  id: int
    :param id: id of the matched record

    :type  score: float
    :param score: vector similarity, 0 <= score <= 100

    :type  column_name_to_attribute: dict{str : str}
    :param column_name_to_attribute: values of the other columns
    """

    def __init__(self, id, score, column_name_to_attribute):
        self.id = id
        self.score = score
        # Historical attribute name; kept so existing readers keep working.
        self.column_name_to_value = column_name_to_attribute

    @property
    def column_name_to_attribute(self):
        """Documented name for the same mapping as ``column_name_to_value``."""
        return self.column_name_to_value

    @column_name_to_attribute.setter
    def column_name_to_attribute(self, mapping):
        self.column_name_to_value = mapping


class TopKQueryResult(object):
    """
    TopK query results

    :type  query_results: list[QueryResult]
    :param query_results: TopK query results
    """

    def __init__(self, query_results):
        self.query_results = query_results


def _abstract():
    """Raise NotImplementedError; called by every method a subclass must override."""
    raise NotImplementedError('You need to override this function')


class ConnectIntf(object):
    """SDK client abstract class

    Every method here raises NotImplementedError (through ``_abstract``)
    and is meant to be overridden by a concrete connection class.
    """

    @staticmethod
    def create():
        """Create a connection instance and return it
        should be implemented

        :return connection: Connection
        """
        _abstract()

    @staticmethod
    def destroy(connection):
        """Destroy the connection instance
        should be implemented

        :type  connection: Connection
        :param connection: the connection instance to be destroyed

        :return: bool, True if destroy is successful
        """
        _abstract()

    def connect(self, param=None, uri=None):
        """
        Connect method should be called before any operations
        Server will be connected after connect return OK
        should be implemented

        :type  param: ConnectParam
        :param param: ConnectParam

        :type  uri: str
        :param uri: uri param

        :raises ConnectParamMissingError: unless exactly one of ``param``
            and ``uri`` is truthy
        :return: Status, indicate if connect is successful
        """
        # Exactly one of the two must be given: neither or both is an error.
        if bool(param) == bool(uri):
            raise ConnectParamMissingError('You need to pass exactly one param')
        _abstract()

    def connected(self):
        """
        Connection status
        should be implemented

        :return: Status, indicate if connect is successful
        """
        _abstract()

    def disconnect(self):
        """
        Disconnect, server will be disconnected after disconnect return OK
        should be implemented

        :return: Status, indicate if disconnect is successful
        """
        _abstract()

    def create_table(self, param):
        """
        Create table
        should be implemented

        :type  param: TableSchema
        :param param: provide table information to be created

        :return: Status, indicate if table is created successfully
        """
        _abstract()

    def delete_table(self, table_name):
        """
        Delete table
        should be implemented

        :type  table_name: str
        :param table_name: name of the table to delete

        :return: Status, indicate if table is deleted successfully
        """
        _abstract()

    def create_table_partition(self, param):
        """
        Create table partition
        should be implemented

        :type  param: CreateTablePartitionParam
        :param param: provide partition information

        :return: Status, indicate if table partition is created successfully
        """
        _abstract()

    def delete_table_partition(self, param):
        """
        Delete table partition
        should be implemented

        :type  param: DeleteTablePartitionParam
        :param param: provide partition information to be deleted

        :return: Status, indicate if partition is deleted successfully
        """
        _abstract()

    def add_vector(self, table_name, records):
        """
        Add vectors to table
        should be implemented

        :type  table_name: str
        :param table_name: name of the table to insert into

        :type  records: list[RowRecord]
        :param records: list of vectors to insert

        :returns:
            Status : indicate if vectors inserted successfully
            ids : list of id, every inserted vector is given an id
        """
        _abstract()

    def search_vector(self, table_name, query_records, top_k):
        """
        Query vectors in a table
        should be implemented

        :type  table_name: str
        :param table_name: name of the table to query

        :type  query_records: list[QueryRecord]
        :param query_records: all vectors going to be queried

        :type  top_k: int
        :param top_k: how many similar vectors will be searched

        :returns:
            Status: indicate if query is successful
            query_results: list[TopKQueryResult]
        """
        _abstract()

    def describe_table(self, table_name):
        """
        Show table information
        should be implemented

        :type  table_name: str
        :param table_name: which table to be shown

        :returns:
            Status: indicate if query is successful
            table_schema: TableSchema, given when operation is successful
        """
        _abstract()

    def show_tables(self):
        """
        Show all tables in database
        should be implemented

        :return:
            Status: indicate if this operation is successful
            tables: list[str], list of table names
        """
        _abstract()

    def client_version(self):
        """
        Provide client version
        should be implemented

        :return: Client version
        """
        _abstract()

    def server_version(self):
        """
        Provide server version
        should be implemented

        :return: Server version
        """
        # Previously this fell through and returned None instead of
        # signalling "not implemented" like every sibling method.
        _abstract()

    def server_status(self, cmd):
        """
        Provide server status
        should be implemented

        :type  cmd: TODO what is cmd
        :param cmd: TODO what is cmd

        :return: Server status
        """
        _abstract()
# ---- python/sdk/client/Client.py ----
# Part 1: imports, thrift-facing constants and the Prepare factory.
import logging
import logging.config

from thrift.transport import TSocket
from thrift.transport import TTransport
from thrift.protocol import TBinaryProtocol, TCompactProtocol, TJSONProtocol
from thrift.Thrift import TException, TApplicationException, TType

from megasearch.thrift import MegasearchService
from megasearch.thrift import ttypes
from client.Abstract import (
    ConnectIntf, TableSchema,
    AbstactIndexType, AbstractColumnType,
    Column,
    VectorColumn, Range,
    CreateTablePartitionParam,
    DeleteTablePartitionParam,
    RowRecord, QueryRecord,
    QueryResult, TopKQueryResult
)

from client.Status import Status
from client.Exceptions import (
    RepeatingConnectError, ConnectParamMissingError,
    DisconnectNotConnectedClientError,
    ParamError, NotConnectError
)

LOGGER = logging.getLogger(__name__)

__VERSION__ = '0.0.1'
__NAME__ = 'Thrift_Client'


class IndexType(AbstactIndexType):
    """Index type codes used by this client (same values as the abstract ones)."""
    # TODO thrift in IndexType
    RAW = 1
    IVFFLAT = 2


class ColumnType(AbstractColumnType):
    """Column type codes used by this client.

    Types with a thrift counterpart take the ``TType`` value; the others
    keep the literal codes from ``AbstractColumnType``.
    """
    # NOTE(review): the literal codes below share one number space with the
    # TType values. In the Apache Thrift Python runtime TType.I16 is 6 and
    # TType.I32 is 8, which would make INT16 == FLOAT32 and INT32 == DATE.
    # Values are left untouched here; confirm against the server's enum.
    FLOAT32 = 6
    FLOAT64 = 7
    DATE = 8

    INVALID = TType.STOP
    INT8 = TType.I08
    INT16 = TType.I16
    INT32 = TType.I32
    INT64 = TType.I64
    VECTOR = TType.LIST


class Prepare(object):
    """Factory helpers that build the thrift structs the client methods expect."""

    @classmethod
    def column(cls, name, type):
        """
        Table column param

        :param type: ColumnType, type of the column
        :param name: str, name of the column

        :return: ttypes.Column
        """
        # TODO type in Thrift, may have error
        temp_column = Column(name=name, type=type)
        return ttypes.Column(name=temp_column.name, type=temp_column.type)

    @classmethod
    def vector_column(cls, name, dimension,
                      store_raw_vector=False,
                      index_type=IndexType.RAW):
        """
        Table vector column description

        :param name: str, name of the column
        :param dimension: int64, vector dimension
        :param store_raw_vector: bool, whether the raw vector is stored in the table
        :param index_type: IndexType. Accepted so callers following the
            documented signature do not fail, but not forwarded: the thrift
            struct is currently built without an index type.

        The column type is always ColumnType.VECTOR.

        :return: ttypes.VectorColumn
        """
        temp = VectorColumn(name=name, dimension=dimension,
                            index_type=index_type,
                            store_raw_vector=store_raw_vector)
        base = ttypes.Column(name=temp.name, type=ColumnType.VECTOR)
        # Without IndexType
        return ttypes.VectorColumn(base=base, dimension=temp.dimension,
                                   store_raw_vector=temp.store_raw_vector)

    @classmethod
    def table_schema(cls, table_name,
                     vector_columns,
                     attribute_columns,
                     partition_column_names):
        """
        :param table_name: str, name of the table
        :param vector_columns: list of vector columns, see
            ``Prepare.vector_column`` (dimensions may differ per column)
        :param attribute_columns: list of columns whose type is not
            ColumnType.VECTOR, see ``Prepare.column``
        :param partition_column_names: list of str, names of the attribute
            columns used for partitioning. There may be several as long as
            every name is an attribute column name, so their number cannot
            exceed the number of attribute columns.

        :return: ttypes.TableSchema
        """
        temp = TableSchema(table_name, vector_columns,
                           attribute_columns,
                           partition_column_names)

        return ttypes.TableSchema(table_name=temp.table_name,
                                  vector_column_array=temp.vector_columns,
                                  attribute_column_array=temp.attribute_columns,
                                  partition_column_name_array=temp.partition_column_names)

    @classmethod
    def range(cls, start, end):
        """
        :param start: partition range start value
        :param end: partition range end value

        :return: ttypes.Range
        """
        temp = Range(start=start, end=end)
        return ttypes.Range(start_value=temp.start, end_value=temp.end)

    @classmethod
    def create_table_partition_param(cls,
                                     table_name,
                                     partition_name,
                                     column_name_to_range):
        """
        Create table partition parameters

        :param table_name: str, table name
        :param partition_name: str, name of the partition to create
        :param column_name_to_range: dict, column name to partition range;
            VECTOR/FLOAT32/FLOAT64 columns are not allowed for partitioning

        :return: ttypes.CreateTablePartitionParam
        """
        temp = CreateTablePartitionParam(table_name=table_name,
                                         partition_name=partition_name,
                                         column_name_to_range=column_name_to_range)
        return ttypes.CreateTablePartitionParam(table_name=temp.table_name,
                                                partition_name=temp.partition_name,
                                                range_map=temp.column_name_to_range)

    @classmethod
    def delete_table_partition_param(cls, table_name, partition_names):
        """
        Delete table partition parameters

        :param table_name: str, table name
        :param partition_names: list of partition names

        :return: ttypes.DeleteTablePartitionParam
        """
        temp = DeleteTablePartitionParam(table_name=table_name,
                                         partition_names=partition_names)
        return ttypes.DeleteTablePartitionParam(table_name=temp.table_name,
                                                partition_name_array=temp.partition_names)

    @classmethod
    def row_record(cls, column_name_to_vector, column_name_to_attribute):
        """
        :param column_name_to_vector: dict{str : list[float]},
            column name to vector map
        :param column_name_to_attribute: dict{str : str},
            values of the other (attribute) columns

        :return: ttypes.RowRecord
        """
        temp = RowRecord(column_name_to_vector=column_name_to_vector,
                         column_name_to_attribute=column_name_to_attribute)
        return ttypes.RowRecord(vector_map=temp.column_name_to_vector,
                                attribute_map=temp.column_name_to_attribute)

    @classmethod
    def query_record(cls, column_name_to_vector,
                     selected_columns, name_to_partition_ranges):
        """
        :param column_name_to_vector: dict{str : list[float]},
            query vectors, column name to vector map
        :param selected_columns: list[str], names of the output columns
        :param name_to_partition_ranges: dict{str : list[Range]},
            partition ranges used to search, see ``Prepare.range``

        :return: ttypes.QueryRecord
        """
        temp = QueryRecord(column_name_to_vector=column_name_to_vector,
                           selected_columns=selected_columns,
                           name_to_partition_ranges=name_to_partition_ranges)
        return ttypes.QueryRecord(vector_map=temp.column_name_to_vector,
                                  selected_column_array=temp.selected_columns,
                                  partition_filter_column_map=temp.name_to_partition_ranges)
# ---- python/sdk/client/Client.py ----
# Part 2: the thrift-backed connection class.


class MegaSearch(ConnectIntf):
    """Connection to a Megasearch server over thrift (JSON protocol)."""

    def __init__(self):
        self.transport = None
        self.client = None
        self.status = None

    def __repr__(self):
        return '{}'.format(self.status)

    @staticmethod
    def create():
        """Return a new, unconnected MegaSearch instance."""
        # TODO in python, maybe this method is useless
        return MegaSearch()

    @staticmethod
    def destroy(connection):
        """Destroy the connection instance (currently a no-op)."""
        # TODO in python, maybe this method is useless

    def _require_client(self):
        """Return the thrift client, raising NotConnectError when there is none."""
        if not self.client:
            raise NotConnectError('Please Connect to the server first!')
        return self.client

    def connect(self, host='localhost', port='9090', uri=None):
        """
        Open the connection to the server

        :param host: str, server host
        :param port: str or int, server port
        :param uri: not used yet (TODO URI)

        :raises RepeatingConnectError: when already connected
        :return: Status, OK when the transport was opened, INVALID (with the
            error text as message) when opening failed
        """
        if self.connected:
            raise RepeatingConnectError("You have already connected!")

        sock = TSocket.TSocket(host=host, port=port)
        transport = TTransport.TBufferedTransport(sock)
        # The protocol talks through the buffered transport, i.e. the same
        # object that disconnect() later closes.
        protocol = TJSONProtocol.TJSONProtocol(transport)

        try:
            transport.open()
        except Exception as e:
            # connect() has always reported failures through the returned
            # Status instead of raising, so every open error is mapped here
            # (explicitly, rather than through a `return` inside `finally`).
            self.transport = None
            self.client = None
            self.status = Status(Status.INVALID, message=str(e))
            LOGGER.error('logger.error: %s', self.status)
        else:
            self.transport = transport
            self.client = MegasearchService.Client(protocol)
            self.status = Status(Status.OK, 'Connected')
            LOGGER.info('Connected!')
        return self.status

    @property
    def connected(self):
        """True when the last connect() succeeded and disconnect() was not called."""
        return self.status == Status()

    def disconnect(self):
        """
        Close the connection

        :raises DisconnectNotConnectedClientError: when there is no transport
        :return: Status, indicate if disconnect is successful
        """
        if not self.transport:
            raise DisconnectNotConnectedClientError('Error')

        try:
            self.transport.close()
        except TException as e:
            return Status(Status.INVALID, str(e))

        LOGGER.info('Client Disconnected!')
        # Drop every handle so later calls fail with NotConnectError and a
        # second disconnect() raises, instead of reusing a closed transport.
        self.transport = None
        self.client = None
        self.status = None
        return Status(Status.OK, 'Disconnected')

    def create_table(self, param):
        """Create table

        :param param: provide table information to be created,

                `Please use Prepare.table_schema generate param`

        :raises NotConnectError: when not connected
        :return: Status, indicate if operation is successful
        """
        client = self._require_client()
        try:
            client.CreateTable(param)
        except (TApplicationException, TException) as e:
            LOGGER.error('Unable to create table')
            return Status(Status.INVALID, str(e))
        return Status(message='Table {} created!'.format(param.table_name))

    def delete_table(self, table_name):
        """Delete table

        :param table_name: name of the table being deleted

        :raises NotConnectError: when not connected
        :return: Status, indicate if operation is successful
        """
        client = self._require_client()
        try:
            client.DeleteTable(table_name)
        except (TApplicationException, TException) as e:
            LOGGER.error('Unable to delete table %s', table_name)
            return Status(Status.INVALID, str(e))
        return Status(message='Table {} deleted!'.format(table_name))

    def create_table_partition(self, param):
        """
        Create table partition

        :param param: CreateTablePartitionParam, provide partition information

                `Please use Prepare.create_table_partition_param generate param`

        :raises NotConnectError: when not connected
        :return: Status, indicate if table partition is created successfully
        """
        client = self._require_client()
        try:
            client.CreateTablePartition(param)
        except (TApplicationException, TException) as e:
            LOGGER.error('%s', e)
            return Status(Status.INVALID, str(e))
        return Status(message='Table partition created successfully!')

    def delete_table_partition(self, param):
        """
        Delete table partition

        :param param: DeleteTablePartitionParam, provide partition
            information to be deleted

                `Please use Prepare.delete_table_partition_param generate param`

        :raises NotConnectError: when not connected
        :return: Status, indicate if partition is deleted successfully
        """
        client = self._require_client()
        try:
            client.DeleteTablePartition(param)
        except (TApplicationException, TException) as e:
            LOGGER.error('%s', e)
            return Status(Status.INVALID, str(e))
        return Status(message='Table partition deleted successfully!')

    def add_vector(self, table_name, records):
        """
        Add vectors to table

        :param table_name: name of the table to insert into
        :param records: list[RowRecord], list of vectors to insert

                `Please use Prepare.row_record generate records`

        :raises NotConnectError: when not connected
        :returns:
            Status : indicate if vectors inserted successfully
            ids : list of id, every inserted vector is given an id
                (None on failure)
        """
        client = self._require_client()
        try:
            ids = client.AddVector(table_name=table_name, record_array=records)
        except (TApplicationException, TException) as e:
            LOGGER.error('%s', e)
            return Status(Status.INVALID, str(e)), None
        return Status(message='Vector added successfully!'), ids

    def search_vector(self, table_name, query_records, top_k):
        """
        Query vectors in a table

        :param table_name: str, name of the table to query
        :param query_records: list[QueryRecord], all vectors going to be queried

                `Please use Prepare.query_record generate QueryRecord`

        :param top_k: int, how many similar vectors will be searched

        :raises NotConnectError: when not connected
        :returns:
            Status: indicate if query is successful
            query_results: what the server returned (None on failure)
        """
        # TODO topk_query_results
        client = self._require_client()
        try:
            topk_query_results = client.SearchVector(
                table_name=table_name, query_record_array=query_records, topk=top_k)
        except (TApplicationException, TException) as e:
            LOGGER.error('%s', e)
            return Status(Status.INVALID, str(e)), None
        return Status(message='Success!'), topk_query_results

    def describe_table(self, table_name):
        """
        Show table information

        :param table_name: str, which table to be shown

        :raises NotConnectError: when not connected
        :returns:
            Status: indicate if query is successful
            table_schema: the thrift table schema returned by the server
                (None on failure)
        """
        client = self._require_client()
        try:
            thrift_table_schema = client.DescribeTable(table_name)
        except (TApplicationException, TException) as e:
            LOGGER.error('%s', e)
            return Status(Status.INVALID, str(e)), None
        # TODO Table Schema
        return Status(message='Success!'), thrift_table_schema

    def show_tables(self):
        """
        Show all tables in database

        :raises NotConnectError: when not connected
        :return:
            Status: indicate if this operation is successful
            tables: list of table names as returned by the server
                (None on failure)
        """
        client = self._require_client()
        try:
            tables = client.ShowTables()
        except (TApplicationException, TException) as e:
            LOGGER.error('%s', e)
            return Status(Status.INVALID, str(e)), None
        return Status(message='Success!'), tables

    def client_version(self):
        """
        Provide client version

        :return: Client version
        """
        return __VERSION__

    def server_version(self):
        """
        Provide server version

        :return: None for now
        """
        # TODO How to get server version
        return None

    def server_status(self, cmd):
        """
        Provide server status

        :param cmd: value passed to the server's Ping call

        :raises NotConnectError: when not connected
        :return: whatever the server's Ping call returns
        """
        return self._require_client().Ping(cmd)
# ---- python/sdk/client/Exceptions.py ----
class ParamError(ValueError):
    """A parameter given to the client is wrong."""


class ConnectParamMissingError(ParamError):
    """The connect parameters are missing or ambiguous."""


class ConnectError(ValueError):
    """Base class of the connection state errors."""


class NotConnectError(ConnectError):
    """An operation was attempted before connecting."""


class RepeatingConnectError(ConnectError):
    """connect() was called on an already connected client."""


class DisconnectNotConnectedClientError(ValueError):
    """disconnect() was called on a client that is not connected."""


# ---- python/sdk/client/Status.py ----
class Status(object):
    """
    Result of a client operation

    :attribute code: int (optional), defaults to OK
    :attribute message: str (optional), current status message

    Two Status objects are equal when their codes are equal (the message
    is ignored); a Status also compares equal to its bare integer code.
    """
    OK = 0
    INVALID = 1
    UNKNOWN_ERROR = 2
    NOT_SUPPORTED = 3
    NOT_CONNECTED = 4

    def __init__(self, code=OK, message=None):
        self.code = code
        self.message = message

    def __repr__(self):
        fields = ', '.join('%s=%r' % (key, value)
                           for key, value in self.__dict__.items())
        return '%s(%s)' % (self.__class__.__name__, fields)

    def __eq__(self, other):
        """Make Status comparable with another Status or an int, by code"""
        if isinstance(other, int):
            return self.code == other
        return isinstance(other, self.__class__) and self.code == other.code

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        # Defining __eq__ alone makes instances unhashable on Python 3.
        # Hash by code so that equal objects (including the equal int)
        # hash equally.
        return hash(self.code)
# ---- python/sdk/examples/connection_exp.py ----
from client.Client import MegaSearch, Prepare, IndexType, ColumnType
from client.Status import Status


def main():
    """Connect to a server, ping it and disconnect, printing each status."""
    mega = MegaSearch()

    # Connect
    param = {'host': '192.168.1.129', 'port': '33001'}
    cnn_status = mega.connect(**param)
    print('Connect Status: {}'.format(cnn_status))

    is_connected = mega.connected
    print('Connect status: {}'.format(is_connected))

    # # Create table with 1 vector column, 1 attribute column and 1 partition column
    # # 1. prepare table_schema
    # vector_column = {
    #     'name': 'fake_vec_name01',
    #     'store_raw_vector': True,
    #     'dimension': 10
    # }
    # attribute_column = {
    #     'name': 'fake_attri_name01',
    #     'type': ColumnType.DATE,
    # }
    #
    # table = {
    #     'table_name': 'fake_table_name01',
    #     'vector_columns': [Prepare.vector_column(**vector_column)],
    #     'attribute_columns': [Prepare.column(**attribute_column)],
    #     'partition_column_names': ['fake_attri_name01']
    # }
    # table_schema = Prepare.table_schema(**table)
    #
    # # 2. Create Table
    # create_status = mega.create_table(table_schema)
    # print('Create table status: {}'.format(create_status))

    mega.server_status('ok!')

    # Disconnect
    discnn_status = mega.disconnect()
    print('Disconnect Status{}'.format(discnn_status))


if __name__ == '__main__':
    main()


# ---- python/sdk/pytest.ini ----
# [pytest]
# log_format = [%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)
#
# log_cli = true
# log_level = 30


# ---- python/sdk/tests/TestClient.py ----
import logging
import pytest
import mock
import faker
import random
from faker.providers import BaseProvider

from client.Client import MegaSearch, Prepare, IndexType, ColumnType
from client.Status import Status
from client.Exceptions import (
    RepeatingConnectError,
    DisconnectNotConnectedClientError
)

from thrift.transport.TSocket import TSocket
from thrift.transport.TTransport import TTransportException
from megasearch.thrift import ttypes, MegasearchService

LOGGER = logging.getLogger(__name__)


class FakerProvider(BaseProvider):
    """Faker provider producing random table names, column names and dimensions."""

    def table_name(self):
        return 'table_name' + str(random.randint(1000, 9999))

    def name(self):
        return 'name' + str(random.randint(1000, 9999))

    def dim(self):
        return random.randint(0, 999)


fake = faker.Faker()
fake.add_provider(FakerProvider)


def vector_column_factory():
    """Keyword arguments for Prepare.vector_column."""
    # No 'index_type' key: Prepare.vector_column builds the thrift struct
    # without an index type.
    return {
        'name': fake.name(),
        'dimension': fake.dim(),
        'store_raw_vector': True
    }


def column_factory():
    """Keyword arguments for Prepare.column."""
    return {
        'name': fake.table_name(),
        # A column takes a ColumnType, not an IndexType.
        'type': ColumnType.INT32
    }


def range_factory():
    """Keyword arguments for Prepare.range."""
    return {
        'start': str(random.randint(1, 10)),
        'end': str(random.randint(11, 20)),
    }


def table_schema_factory():
    vec_params = [vector_column_factory() for _ in range(10)]
    column_params = [column_factory() for _ in range(5)]
    param = {
        'table_name': fake.table_name(),
        'vector_columns': [Prepare.vector_column(**pa) for pa in vec_params],
        'attribute_columns': [Prepare.column(**pa) for pa in column_params],
        'partition_column_names': [str(x) for x in range(2)]
    }
    return Prepare.table_schema(**param)


def create_table_partition_param_factory():
    param = {
        'table_name': fake.table_name(),
        'partition_name': fake.table_name(),
        # The range map holds Range structs, not the raw keyword dicts.
        'column_name_to_range': {fake.name(): Prepare.range(**range_factory())
                                 for _ in range(3)}
    }
    return Prepare.create_table_partition_param(**param)


def delete_table_partition_param_factory():
    param = {
        'table_name': fake.table_name(),
        'partition_names': [fake.name() for _ in range(5)]
    }
    return Prepare.delete_table_partition_param(**param)


def row_record_factory():
    param = {
        'column_name_to_vector': {fake.name(): [random.random() for _ in range(256)]},
        'column_name_to_attribute': {fake.name(): fake.name()}
    }
    return Prepare.row_record(**param)


def query_record_factory():
    param = {
        'column_name_to_vector': {fake.name(): [random.random() for _ in range(256)]},
        'selected_columns': [fake.name() for _ in range(3)],
        'name_to_partition_ranges': {fake.name(): [Prepare.range(**range_factory())]}
    }
    return Prepare.query_record(**param)


class TestConnection:
    param = {'host': 'localhost', 'port': '5000'}

    @mock.patch.object(TSocket, 'open')
    def test_true_connect(self, mock_open):
        mock_open.return_value = None
        cnn = MegaSearch()

        cnn.connect(**self.param)
        assert cnn.status == Status.OK
        assert cnn.connected
        assert isinstance(cnn.client, MegasearchService.Client)

        # One call per block: a second statement inside the same
        # pytest.raises block would never run.
        with pytest.raises(RepeatingConnectError):
            cnn.connect(**self.param)
        with pytest.raises(RepeatingConnectError):
            cnn.connect()

    @mock.patch.object(TSocket, 'open',
                       side_effect=TTransportException(message='connection refused'))
    def test_false_connect(self, mock_open):
        # Force the failure instead of passing the whole dict as `host`
        # or depending on what listens on the local port.
        cnn = MegaSearch()

        cnn.connect(**self.param)
        assert cnn.status != Status.OK

    def test_disconnected_error(self):
        cnn = MegaSearch()
        cnn.status = Status(Status.INVALID)
        with pytest.raises(DisconnectNotConnectedClientError):
            cnn.disconnect()


class TestTable:

    @pytest.fixture
    @mock.patch.object(TSocket, 'open')
    def client(self, mock_open):
        param = {'host': 'localhost', 'port': '5000'}
        mock_open.return_value = None

        cnn = MegaSearch()
        cnn.connect(**param)
        return cnn

    @mock.patch.object(MegasearchService.Client, 'CreateTable')
    def test_create_table(self, CreateTable, client):
        CreateTable.return_value = None

        param = table_schema_factory()
        res = client.create_table(param)
        assert res == Status.OK

    def test_false_create_table(self, client):
        param = table_schema_factory()
        res = client.create_table(param)
        LOGGER.error('{}'.format(res))
        assert res != Status.OK

    @mock.patch.object(MegasearchService.Client, 'DeleteTable')
    def test_delete_table(self, DeleteTable, client):
        DeleteTable.return_value = None
        table_name = 'fake_table_name'
        res = client.delete_table(table_name)
        assert res == Status.OK

    def test_false_delete_table(self, client):
        table_name = 'fake_table_name'
        res = client.delete_table(table_name)
        assert res != Status.OK


class TestVector:

    @pytest.fixture
    @mock.patch.object(TSocket, 'open')
    def client(self, mock_open):
        param = {'host': 'localhost', 'port': '5000'}
        mock_open.return_value = None

        cnn = MegaSearch()
        cnn.connect(**param)
        return cnn

    @mock.patch.object(MegasearchService.Client, 'CreateTablePartition')
    def test_create_table_partition(self, CreateTablePartition, client):
        CreateTablePartition.return_value = None

        param = create_table_partition_param_factory()
        res = client.create_table_partition(param)
        assert res == Status.OK

    def test_false_table_partition(self, client):
        param = create_table_partition_param_factory()
        res = client.create_table_partition(param)
        assert res != Status.OK

    @mock.patch.object(MegasearchService.Client, 'DeleteTablePartition')
    def test_delete_table_partition(self, DeleteTablePartition, client):
        DeleteTablePartition.return_value = None

        param = delete_table_partition_param_factory()
        res = client.delete_table_partition(param)
        assert res == Status.OK

    def test_false_delete_table_partition(self, client):
        param = delete_table_partition_param_factory()
        res = client.delete_table_partition(param)
        assert res != Status.OK

    @mock.patch.object(MegasearchService.Client, 'AddVector')
    def test_add_vector(self, AddVector, client):
        AddVector.return_value = None

        param = {
            'table_name': fake.table_name(),
            'records': [row_record_factory() for _ in range(1000)]
        }
        res, ids = client.add_vector(**param)
        assert res == Status.OK

    def test_false_add_vector(self, client):
        param = {
            'table_name': fake.table_name(),
            'records': [row_record_factory() for _ in range(1000)]
        }
        res, ids = client.add_vector(**param)
        assert res != Status.OK

    @mock.patch.object(MegasearchService.Client, 'SearchVector')
    def test_search_vector(self, SearchVector, client):
        SearchVector.return_value = None
        param = {
            'table_name': fake.table_name(),
            'query_records': [query_record_factory() for _ in range(1000)],
            'top_k': random.randint(0, 10)
        }
        res, results = client.search_vector(**param)
        assert res == Status.OK

    def test_false_vector(self, client):
        param = {
            'table_name': fake.table_name(),
            'query_records': [query_record_factory() for _ in range(1000)],
            'top_k': random.randint(0, 10)
        }
        res, results = client.search_vector(**param)
        assert res != Status.OK

    @mock.patch.object(MegasearchService.Client, 'DescribeTable')
    def test_describe_table(self, DescribeTable, client):
        DescribeTable.return_value = table_schema_factory()

        table_name = fake.table_name()
        res, table_schema = client.describe_table(table_name)
        assert res == Status.OK
        assert isinstance(table_schema, ttypes.TableSchema)

    def test_false_decribe_table(self, client):
        table_name = fake.table_name()
        res, table_schema = client.describe_table(table_name)
        assert res != Status.OK
        assert not table_schema

    @mock.patch.object(MegasearchService.Client, 'ShowTables')
    def test_show_tables(self, ShowTables, client):
        ShowTables.return_value = [fake.table_name() for _ in range(10)]
        res, tables = client.show_tables()
        assert res == Status.OK
        assert isinstance(tables, list)

    def test_false_show_tables(self, client):
        res, tables = client.show_tables()
        assert res != Status.OK
        assert not tables

    def test_client_version(self, client):
        res = client.client_version()
        assert res == '0.0.1'


class TestPrepare:

    def test_column(self):
        param = {
            'name': 'test01',
            'type': ColumnType.DATE
        }
        res = Prepare.column(**param)
        LOGGER.error('{}'.format(res))
        assert res.name == 'test01'
        assert res.type == ColumnType.DATE
        assert isinstance(res, ttypes.Column)

    def test_vector_column(self):
        param = vector_column_factory()

        res = Prepare.vector_column(**param)
        LOGGER.error('{}'.format(res))
        assert isinstance(res, ttypes.VectorColumn)

    def test_table_schema(self):

        vec_params = [vector_column_factory() for _ in range(10)]
        column_params = [column_factory() for _ in range(5)]

        param = {
            'table_name': 'test03',
            'vector_columns': [Prepare.vector_column(**pa) for pa in vec_params],
            'attribute_columns': [Prepare.column(**pa) for pa in column_params],
            'partition_column_names': [str(x) for x in range(2)]
        }
        res = Prepare.table_schema(**param)
+ assert isinstance(res, ttypes.TableSchema) + + def test_range(self): + param = { + 'start': '200', + 'end': '1000' + } + + res = Prepare.range(**param) + LOGGER.error('{}'.format(res)) + assert isinstance(res, ttypes.Range) + assert res.start_value == '200' + assert res.end_value == '1000' + + def test_create_table_partition_param(self): + param = { + 'table_name': fake.table_name(), + 'partition_name': fake.table_name(), + 'column_name_to_range': {fake.name(): range_factory() for _ in range(3)} + } + res = Prepare.create_table_partition_param(**param) + LOGGER.error('{}'.format(res)) + assert isinstance(res, ttypes.CreateTablePartitionParam) + + def test_delete_table_partition_param(self): + param = { + 'table_name': fake.table_name(), + 'partition_names': [fake.name() for i in range(5)] + } + res = Prepare.delete_table_partition_param(**param) + assert isinstance(res, ttypes.DeleteTablePartitionParam) + + def test_row_record(self): + param={ + 'column_name_to_vector': {fake.name(): [random.random() for i in range(256)]}, + 'column_name_to_attribute': {fake.name(): fake.name()} + } + res = Prepare.row_record(**param) + assert isinstance(res, ttypes.RowRecord) + + def test_query_record(self): + param = { + 'column_name_to_vector': {fake.name(): [random.random() for i in range(256)]}, + 'selected_columns': [fake.name() for _ in range(10)], + 'name_to_partition_ranges': {fake.name(): [range_factory() for _ in range(5)]} + } + res = Prepare.query_record(**param) + assert isinstance(res, ttypes.QueryRecord) + + diff --git a/python/sdk/tests/__init__.py b/python/sdk/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2