Merge branch 'branch-0.3.0' into 'branch-0.3.0'

MS-11 Implement Python SDK

See merge request megasearch/vecwise_engine!74

Former-commit-id: f4862569304e1acb07582db0ac24b8a6c514b1e4
pull/191/head
jinhai 2019-06-12 11:06:52 +08:00
commit a0280e5172
16 changed files with 1070 additions and 456 deletions

View File

@ -12,4 +12,6 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-10 - Add Python SDK APIs
- MS-11 - Implement Python SDK
### Task

View File

@ -1,433 +0,0 @@
from enum import IntEnum
from sdk.exceptions import ConnectParamMissingError
from sdk.Status import Status
class IndexType(IntEnum):
    """Index type identifiers; used as the ``index_type`` of a VectorColumn."""
    RAW = 1  # default index type of VectorColumn
    IVFFLAT = 2
class ColumnType(IntEnum):
    """Data type identifiers for table columns."""
    INVALID = 1  # default type of a plain Column
    INT8 = 2
    INT16 = 3
    INT32 = 4
    INT64 = 5
    FLOAT32 = 6
    FLOAT64 = 7
    DATE = 8
    VECTOR = 9  # type assigned to every VectorColumn
class ConnectParam(object):
    """
    Address of the server a client should connect to.

    :type ip_address: str
    :param ip_address: server IP address
    :type port: str
    :param port: server port
    """

    def __init__(self, ip_address, port):
        self.ip_address, self.port = ip_address, port
class Column(object):
    """
    Description of a single table column.

    :type name: str
    :param name: column name (``None`` when not given)
    :type type: ColumnType
    :param type: column data type, ``ColumnType.INVALID`` by default
    """

    def __init__(self, name=None, type=ColumnType.INVALID):
        self.name = name
        self.type = type
class VectorColumn(Column):
    """
    Description of a vector column; its column type is always
    ``ColumnType.VECTOR``.

    :type dimension: int, int64
    :param dimension: vector dimension
    :type index_type: IndexType
    :param index_type: index built on the column, ``IndexType.RAW`` by default
    :type store_raw_vector: bool
    :param store_raw_vector: whether the raw vector is stored in the table
    :type name: str
    :param name: (Optional) column name; previously a vector column could
        not be given a name through the constructor
    """

    def __init__(self, dimension=0,
                 index_type=IndexType.RAW,
                 store_raw_vector=False,
                 name=None):
        # Forward the name so the base Column attribute is usable; it was
        # always left as None before.
        super(VectorColumn, self).__init__(name=name, type=ColumnType.VECTOR)
        self.dimension = dimension
        self.index_type = index_type
        self.store_raw_vector = store_raw_vector
class TableSchema(object):
    """
    Full description of a table.

    :type table_name: str
    :param table_name: name of the table
    :type vector_columns: list[VectorColumn]
    :param vector_columns: descriptions of the vector columns
    :type attribute_columns: list[Column]
    :param attribute_columns: descriptions of the non-vector columns
    :type partition_column_names: list[str]
    :param partition_column_names: names of the partition columns
    """

    def __init__(self, table_name, vector_columns,
                 attribute_columns, partition_column_names):
        self.table_name = table_name
        self.vector_columns, self.attribute_columns = (
            vector_columns, attribute_columns)
        self.partition_column_names = partition_column_names
class Range(object):
    """
    A pair of boundary values.

    :type start: str
    :param start: value the range starts at
    :type end: str
    :param end: value the range ends at
    """

    def __init__(self, start, end):
        self.start, self.end = start, end
class CreateTablePartitionParam(object):
    """
    Parameters for creating a table partition.

    :type table_name: str
    :param table_name: name of the table; VECTOR/FLOAT32/FLOAT64 column
        types are not allowed for partitioning
    :type partition_name: str
    :param partition_name: name of the partition to create
    :type column_name_to_range: dict{str : Range}
    :param column_name_to_range: keyword arguments mapping a column name to
        its partition Range
    """

    def __init__(self, table_name, partition_name, **column_name_to_range):
        self.table_name, self.partition_name = table_name, partition_name
        # Extra keyword arguments are stored as the column -> range mapping.
        self.column_name_to_range = column_name_to_range
class DeleteTablePartitionParam(object):
    """
    Parameters for deleting table partitions.

    :type table_name: str
    :param table_name: name of the table
    :type partition_names: iterable, str
    :param partition_names: positional arguments naming the partitions;
        stored as a tuple
    """

    def __init__(self, table_name, *partition_names):
        self.table_name, self.partition_names = table_name, partition_names
class RowRecord(object):
    """
    One record to insert.

    :type column_name_to_vector: dict{str : list[float]}
    :param column_name_to_vector: vector values keyed by column name
    :type column_name_to_value: dict{str: str}
    :param column_name_to_value: values of the other attribute columns
    """

    def __init__(self, column_name_to_vector, column_name_to_value):
        self.column_name_to_vector, self.column_name_to_value = (
            column_name_to_vector, column_name_to_value)
class QueryRecord(object):
    """
    One query.

    :type column_name_to_vector: dict{str : list[float]}
    :param column_name_to_vector: query vectors keyed by column name
    :type selected_columns: list[str]
    :param selected_columns: columns to output
    :type name_to_partition_ranges: dict{str : list[Range]}
    :param name_to_partition_ranges: keyword arguments giving the ranges
        used to select partitions
    """

    def __init__(self, column_name_to_vector, selected_columns,
                 **name_to_partition_ranges):
        self.column_name_to_vector = column_name_to_vector
        self.selected_columns = selected_columns
        # Extra keyword arguments are stored as the name -> ranges mapping.
        self.name_to_partition_ranges = name_to_partition_ranges
class QueryResult(object):
    """
    One query hit.

    :type id: int
    :param id: id of the result
    :type score: float
    :param score: vector similarity, 0 <= score <= 100
    :type column_name_to_value: dict{str : str}
    :param column_name_to_value: keyword arguments holding the other columns
    """

    def __init__(self, id, score, **column_name_to_value):
        self.id, self.score = id, score
        self.column_name_to_value = column_name_to_value
class TopKQueryResult(object):
    """
    Container for the top-k hits of one query.

    :type query_results: list[QueryResult]
    :param query_results: the top-k query results
    """

    def __init__(self, query_results):
        # Stored by reference; the list is not copied.
        self.query_results = query_results
def _abstract():
raise NotImplementedError('You need to override this function')
class Connection(object):
    """Abstract SDK client: every method raises until it is overridden."""

    @staticmethod
    def create():
        """
        Create a connection instance and return it.

        Should be implemented

        :return connection: Connection
        """
        _abstract()

    @staticmethod
    def destroy(connection):
        """
        Destroy a connection instance.

        Should be implemented

        :type connection: Connection
        :param connection: the connection instance to destroy

        :return: bool, True if the connection was destroyed
        """
        _abstract()

    def connect(self, param=None, uri=None):
        """
        Connect to the server; must be called before any other operation.

        Should be implemented

        :type param: ConnectParam
        :param param: connection parameters
        :type uri: str
        :param uri: connection uri

        :raises ConnectParamMissingError: when neither or both of ``param``
            and ``uri`` are given
        :return: Status, indicate if connect is successful
        """
        # Exactly one of the two must be supplied: equal truthiness means
        # either both are missing or both are present.
        if bool(param) == bool(uri):
            raise ConnectParamMissingError('You need to parse exact one param')
        _abstract()

    def connected(self):
        """
        Report the connection status.

        Should be implemented

        :return: Status, indicate if the client is connected
        """
        _abstract()

    def disconnect(self):
        """
        Disconnect from the server.

        Should be implemented

        :return: Status, indicate if disconnect is successful
        """
        _abstract()

    def create_table(self, param):
        """
        Create a table.

        Should be implemented

        :type param: TableSchema
        :param param: description of the table to create

        :return: Status, indicate if the table is created successfully
        """
        _abstract()

    def delete_table(self, table_name):
        """
        Delete a table.

        Should be implemented

        :type table_name: str
        :param table_name: name of the table to delete

        :return: Status, indicate if the table is deleted successfully
        """
        _abstract()

    def create_table_partition(self, param):
        """
        Create a table partition.

        Should be implemented

        :type param: CreateTablePartitionParam
        :param param: description of the partition to create

        :return: Status, indicate if table partition is created successfully
        """
        _abstract()

    def delete_table_partition(self, param):
        """
        Delete table partitions.

        Should be implemented

        :type param: DeleteTablePartitionParam
        :param param: description of the partitions to delete

        :return: Status, indicate if partition is deleted successfully
        """
        _abstract()

    def add_vector(self, table_name, records, ids):
        """
        Add vectors to a table.

        Should be implemented

        :type table_name: str
        :param table_name: table to insert into
        :type records: list[RowRecord]
        :param records: vectors to insert
        :type ids: list[int]
        :param ids: list of ids

        :return: Status, indicate if vectors inserted successfully
        """
        _abstract()

    def search_vector(self, table_name, query_records, query_results, top_k):
        """
        Query vectors in a table.

        Should be implemented

        :type table_name: str
        :param table_name: table to query
        :type query_records: list[QueryRecord]
        :param query_records: all vectors to query
        :type query_results: list[TopKQueryResult]
        :param query_results: list of results
        :type top_k: int
        :param top_k: how many similar vectors to search for

        :return: Status, indicate if query is successful
        """
        _abstract()

    def describe_table(self, table_name, table_schema):
        """
        Show table information.

        Should be implemented

        :type table_name: str
        :param table_name: table to describe
        :type table_schema: TableSchema
        :param table_schema: table schema, given when operation is successful

        :return: Status, indicate if query is successful
        """
        _abstract()

    def show_tables(self, tables):
        """
        Show all tables in the database.

        Should be implemented

        :type tables: list[str]
        :param tables: list of tables

        :return: Status, indicate if this operation is successful
        """
        _abstract()

    def client_version(self):
        """
        Provide the client version.

        Should be implemented

        :return: Client version
        """
        _abstract()

    def server_status(self):
        """
        Provide the server status.

        Should be implemented

        :return: Server status
        """
        _abstract()

View File

@ -1,21 +0,0 @@
from enum import IntEnum


class Status(IntEnum):
    """
    Status codes, each carrying a human-readable ``message``.

    Members are declared as ``NAME = code, 'message'``; the integer value of
    a member is its code.
    """

    def __new__(cls, code, message=''):
        obj = int.__new__(cls, code)
        # Enum requires ``_value_`` to be set here; without it the enum
        # machinery tries to build the value from the whole
        # ``(code, message)`` tuple via ``int(code, message)`` and fails.
        obj._value_ = code
        obj._code_ = code  # kept for anything that read the old attribute
        obj.message = message
        return obj

    @property
    def code(self):
        """Integer status code (same as the member value)."""
        return self._value_

    def __str__(self):
        # ``str(status)`` is the bare numeric code, e.g. '200'.
        return str(self.code)

    # success
    OK = 200, 'OK'

    INVALID = 300, 'Invalid'
    UNKNOWN = 400, 'Unknown error'
    NOT_SUPPORTED = 500, 'Not supported'

View File

@ -0,0 +1,298 @@
from enum import IntEnum
class IndexType(IntEnum):
    """
    Index type identifiers for a table.

    ``INVALIDE`` is the original (misspelled) member name and stays the
    canonical one so existing callers keep working; ``INVALID`` is an alias
    for the same member.
    """
    INVALIDE = 0
    INVALID = 0  # correctly spelled alias; IndexType(0).name is still 'INVALIDE'
    IDMAP = 1
    IVFLAT = 2
class TableSchema(object):
    """
    Description of a table.

    :type table_name: str
    :param table_name: (Required) name of table

    :type index_type: IndexType
    :param index_type: (Optional) index type, default = 0

        `IndexType`: 0-invalid, 1-idmap, 2-ivflat

    :type dimension: int64
    :param dimension: dimension of vector, default = 0

    :type store_raw_vector: bool
    :param store_raw_vector: (Optional) default = False
    """

    def __init__(self, table_name,
                 dimension=0,
                 index_type=IndexType.INVALIDE,
                 store_raw_vector=False):
        self.table_name, self.index_type = table_name, index_type
        self.dimension, self.store_raw_vector = dimension, store_raw_vector
class Range(object):
    """
    A pair of boundary values.

    :type start: str
    :param start: value the range starts at

    :type end: str
    :param end: value the range ends at
    """

    def __init__(self, start, end):
        self.start, self.end = start, end
class RowRecord(object):
    """
    One vector record.

    :type vector_data: binary str
    :param vector_data: (Required) a vector
    """

    def __init__(self, vector_data):
        # Stored as given; no copy or conversion is made.
        self.vector_data = vector_data
class QueryResult(object):
    """
    One query hit.

    :type id: int64
    :param id: id of the vector

    :type score: float
    :param score: vector similarity, 0 <= score <= 100
    """

    def __init__(self, id, score):
        self.id, self.score = id, score

    def __repr__(self):
        # Render as ClassName(attr=value, ...) from the instance attributes.
        fields = ', '.join('%s=%r' % item for item in self.__dict__.items())
        return '%s(%s)' % (self.__class__.__name__, fields)
class TopKQueryResult(object):
    """
    Container for the top-k hits of one query.

    :type query_results: list[QueryResult]
    :param query_results: the top-k query results
    """

    def __init__(self, query_results):
        self.query_results = query_results

    def __repr__(self):
        # Render as ClassName(attr=value, ...) from the instance attributes.
        fields = ', '.join('%s=%r' % item for item in self.__dict__.items())
        return '%s(%s)' % (self.__class__.__name__, fields)
def _abstract():
raise NotImplementedError('You need to override this function')
class ConnectIntf(object):
    """SDK client abstract class

    ConnectIntf is an abstract class: every method here only calls
    ``_abstract()``, which raises NotImplementedError, so a concrete client
    has to override all of them. The documented return values describe what
    an implementation is expected to return.
    """

    def connect(self, host=None, port=None, uri=None):
        """
        Connect method should be called before any operations
        Server will be connected after connect return OK
        Should be implemented

        :type host: str
        :param host: host

        :type port: str
        :param port: port

        :type uri: str
        :param uri: (Optional) uri

        :return Status, indicate if connect is successful
        """
        _abstract()

    def connected(self):
        """
        connected, connection status
        Should be implemented

        :return Status, indicate if the client is connected
        """
        _abstract()

    def disconnect(self):
        """
        Disconnect, server will be disconnected after disconnect return SUCCESS
        Should be implemented

        :return Status, indicate if disconnect is successful
        """
        _abstract()

    def create_table(self, param):
        """
        Create table
        Should be implemented

        :type param: TableSchema
        :param param: provide table information to be created

        :return Status, indicate if the table is created successfully
        """
        _abstract()

    def delete_table(self, table_name):
        """
        Delete table
        Should be implemented

        :type table_name: str
        :param table_name: table_name of the deleting table

        :return Status, indicate if the table is deleted successfully
        """
        _abstract()

    def add_vectors(self, table_name, records):
        """
        Add vectors to table
        Should be implemented

        :type table_name: str
        :param table_name: table name been inserted

        :type records: list[RowRecord]
        :param records: list of vectors been inserted

        :returns
            Status : indicate if vectors inserted successfully
            ids :list of id, after inserted every vector is given a id
        """
        _abstract()

    def search_vectors(self, table_name, query_records, query_ranges, top_k):
        """
        Query vectors in a table
        Should be implemented

        :type table_name: str
        :param table_name: table name been queried

        :type query_records: list[RowRecord]
        :param query_records: all vectors going to be queried

        :type query_ranges: list[Range]
        :param query_ranges: Optional ranges for conditional search.
            If not specified, search whole table

        :type top_k: int
        :param top_k: how many similar vectors will be searched

        :returns
            Status: indicate if query is successful
            query_results: list[TopKQueryResult]
        """
        _abstract()

    def describe_table(self, table_name):
        """
        Show table information
        Should be implemented

        :type table_name: str
        :param table_name: which table to be shown

        :returns
            Status: indicate if query is successful
            table_schema: TableSchema, given when operation is successful
        """
        _abstract()

    def get_table_row_count(self, table_name):
        """
        Get table row count
        Should be implemented

        :type table_name: str
        :param table_name: target table name.

        :returns
            Status: indicate if operation is successful
            count: int, table row count
        """
        _abstract()

    def show_tables(self):
        """
        Show all tables in database
        Should be implemented

        :return
            Status: indicate if this operation is successful
            tables: list[str], list of table names
        """
        _abstract()

    def client_version(self):
        """
        Provide client version
        Should be implemented

        :return: str, client version
        """
        _abstract()

    def server_version(self):
        """
        Provide server version
        Should be implemented

        :return: str, server version
        """
        _abstract()

    def server_status(self, cmd):
        """
        Provide server status
        Should be implemented

        :type cmd: str
        :param cmd: command selecting which status to report

        :return: str, server status
        """
        _abstract()

306
python/sdk/client/Client.py Normal file
View File

@ -0,0 +1,306 @@
import logging, logging.config
from thrift.transport import TSocket
from thrift.transport import TTransport
from thrift.protocol import TBinaryProtocol
from thrift.Thrift import TException, TApplicationException
from megasearch.thrift import MegasearchService
from megasearch.thrift import ttypes
from client.Abstract import (
ConnectIntf,
TableSchema,
Range,
RowRecord,
QueryResult,
TopKQueryResult,
IndexType
)
from client.Status import Status
from client.Exceptions import (
RepeatingConnectError,
DisconnectNotConnectedClientError,
NotConnectError
)
LOGGER = logging.getLogger(__name__)
__VERSION__ = '0.0.1'
__NAME__ = 'Thrift_Client'
class Prepare(object):
    """Helpers that turn plain arguments into the Thrift request objects
    (``ttypes.*``) expected by the client methods."""

    @classmethod
    def table_schema(cls,
                     table_name,
                     dimension,
                     index_type=IndexType.INVALIDE,
                     store_raw_vector=False):
        """
        Build the Thrift table schema passed to ``create_table``.

        :param table_name: str, (Required) name of table
        :param dimension: int64, (Required) dimension of the table
        :param index_type: IndexType, (Optional) index type,
            default = IndexType.INVALIDE
        :param store_raw_vector: bool, (Optional) default = False

        :return: ttypes.TableSchema
        """
        # Route every field through the SDK-level object, not just the
        # table name, so the Thrift struct always mirrors it.
        temp = TableSchema(table_name, dimension, index_type, store_raw_vector)

        return ttypes.TableSchema(table_name=temp.table_name,
                                  dimension=temp.dimension,
                                  index_type=temp.index_type,
                                  store_raw_vector=temp.store_raw_vector)

    @classmethod
    def range(cls, start, end):
        """
        Build a Thrift range.

        :param start: str, (Required) range start
        :param end: str (Required) range end

        :return: ttypes.Range
        """
        temp = Range(start=start, end=end)
        return ttypes.Range(start_value=temp.start, end_value=temp.end)

    @classmethod
    def row_record(cls, vector_data):
        """
        Build a Thrift row record, used both for inserting and for querying.

        :param vector_data: float binary str, (Required) a binary str

        :return: ttypes.RowRecord
        """
        temp = RowRecord(vector_data)
        return ttypes.RowRecord(vector_data=temp.vector_data)
class MegaSearch(ConnectIntf):
    """Thrift client for the MegaSearch server.

    ``status`` holds the Status of the last connect attempt (``None`` before
    connecting and after a successful disconnect).
    """

    def __init__(self):
        self.status = None
        self._transport = None
        self._client = None

    def __repr__(self):
        return '{}'.format(self.status)

    def _require_client(self):
        """Raise NotConnectError unless a client has been set up by connect()."""
        if not self._client:
            raise NotConnectError('Please Connect to the server first!')

    def connect(self, host='localhost', port='9090', uri=None):
        """
        Open the connection to the server.

        :param host: str, server host
        :param port: str, server port
        :param uri: (Optional) not used yet

        :raises RepeatingConnectError: when already connected
        :return: Status, SUCCESS or CONNECT_FAILED (with the error message)
        """
        # TODO URI
        if self.status and self.status == Status.SUCCESS:
            raise RepeatingConnectError("You have already connected!")

        sock = TSocket.TSocket(host=host, port=port)
        self._transport = TTransport.TBufferedTransport(sock)
        # The protocol must sit on the buffered transport; building it on the
        # raw socket silently bypassed the buffering layer.
        protocol = TBinaryProtocol.TBinaryProtocol(self._transport)
        self._client = MegasearchService.Client(protocol)

        try:
            self._transport.open()
        except (TTransport.TTransportException, TException) as e:
            self.status = Status(Status.CONNECT_FAILED, message=str(e))
            LOGGER.error('logger.error: {}'.format(self.status))
            # Do not keep a client bound to a transport that never opened.
            self._transport = None
            self._client = None
        else:
            self.status = Status(Status.SUCCESS, 'Connected')
            LOGGER.info('Connected!')
        # Returned after the try block rather than from ``finally`` so that
        # unrelated exceptions are not swallowed.
        return self.status

    @property
    def connected(self):
        """True when the last connect attempt succeeded."""
        return self.status == Status.SUCCESS

    def disconnect(self):
        """
        Close the connection.

        :raises DisconnectNotConnectedClientError: when there is no transport
            to close
        :return: Status, indicate if disconnect is successful
        """
        if not self._transport:
            raise DisconnectNotConnectedClientError('Error')

        try:
            self._transport.close()
        except TException as e:
            return Status(Status.PERMISSION_DENIED, str(e))

        LOGGER.info('Client Disconnected!')
        # Drop every handle so later calls fail with NotConnectError instead
        # of talking to a closed transport.
        self.status = None
        self._transport = None
        self._client = None
        return Status(Status.SUCCESS, 'Disconnected')

    def create_table(self, param):
        """Create table

        :param param: Provide table information to be created,

                `Please use Prepare.table_schema generate param`

        :raises NotConnectError: when connect() has not been called
        :return: Status, indicate if operation is successful
        """
        self._require_client()

        try:
            self._client.CreateTable(param)
        # NOTE(review): narrower than the sibling methods, which also catch
        # TException; transport errors therefore propagate to the caller
        # here. Confirm this is intended before widening.
        except (TApplicationException, ) as e:
            LOGGER.error('Unable to create table')
            return Status(Status.PERMISSION_DENIED, str(e))
        return Status(message='Table {} created!'.format(param.table_name))

    def delete_table(self, table_name):
        """Delete table

        :param table_name: Name of the table being deleted

        :raises NotConnectError: when connect() has not been called
        :return: Status, indicate if operation is successful
        """
        self._require_client()

        try:
            self._client.DeleteTable(table_name)
        except (TApplicationException, TException) as e:
            LOGGER.error('Unable to delete table {}'.format(table_name))
            return Status(Status.PERMISSION_DENIED, str(e))
        return Status(message='Table {} deleted!'.format(table_name))

    def add_vectors(self, table_name, records):
        """
        Add vectors to table

        :param table_name: table name been inserted
        :param records: List[RowRecord], list of vectors been inserted

                `Please use Prepare.row_record generate records`

        :raises NotConnectError: when connect() has not been called
        :returns:
            Status : indicate if vectors inserted successfully
            ids :list of id, after inserted every vector is given a id
                (None on failure)
        """
        self._require_client()

        try:
            ids = self._client.AddVector(table_name=table_name,
                                         record_array=records)
        except (TApplicationException, TException) as e:
            LOGGER.error('{}'.format(e))
            return Status(Status.PERMISSION_DENIED, str(e)), None
        return Status(message='Vectors added successfully!'), ids

    def search_vectors(self, table_name, top_k, query_records, query_ranges=None):
        """
        Query vectors in a table

        :param table_name: str, table name been queried
        :param top_k: int, how many similar vectors will be searched
        :param query_records: list[RowRecord], all vectors going to be queried

                `Please use Prepare.row_record generate the records`

        :param query_ranges: (Optional) list[Range], search range

        :raises NotConnectError: when connect() has not been called
        :returns:
            Status: indicate if query is successful
            res: list[TopKQueryResult], return when operation is successful
                (None on failure)
        """
        self._require_client()

        res = []
        try:
            top_k_query_results = self._client.SearchVector(
                table_name=table_name,
                query_record_array=query_records,
                query_range_array=query_ranges,
                topk=top_k)

            # Empty/None responses and empty/None entries are skipped. The
            # loop variable no longer shadows the ``top_k`` parameter.
            for topk_result in top_k_query_results or []:
                if topk_result:
                    res.append(TopKQueryResult(
                        [QueryResult(qr.id, qr.score)
                         for qr in topk_result.query_result_arrays]))

        except (TApplicationException, TException) as e:
            LOGGER.error('{}'.format(e))
            return Status(Status.PERMISSION_DENIED, str(e)), None
        return Status(message='Success!'), res

    def describe_table(self, table_name):
        """
        Show table information

        :param table_name: str, which table to be shown

        :raises NotConnectError: when connect() has not been called
        :returns:
            Status: indicate if query is successful
            table_schema: the object returned by the server call, returned
                unconverted when operation is successful (None on failure)
        """
        self._require_client()

        try:
            temp = self._client.DescribeTable(table_name)
        except (TApplicationException, TException) as e:
            LOGGER.error('{}'.format(e))
            return Status(Status.PERMISSION_DENIED, str(e)), None
        return Status(message='Success!'), temp

    def show_tables(self):
        """
        Show all tables in database

        :raises NotConnectError: when connect() has not been called
        :return:
            Status: indicate if this operation is successful
            tables: list[str], list of table names, return when operation
                    is successful (None on failure)
        """
        self._require_client()

        try:
            res = self._client.ShowTables()
            tables = []
            if res:
                # NOTE(review): assumes the call returns a 2-item sequence
                # (tables, something); verify against the Thrift IDL.
                tables, _ = res
        except (TApplicationException, TException) as e:
            LOGGER.error('{}'.format(e))
            return Status(Status.PERMISSION_DENIED, str(e)), None
        return Status(message='Success!'), tables

    def get_table_row_count(self, table_name):
        """
        Get table row count

        :type table_name: str
        :param table_name: target table name.

        :raises NotConnectError: when connect() has not been called
        :returns:
            Status: indicate if operation is successful
            res: int, table row count (None on failure)
        """
        self._require_client()

        try:
            # NOTE(review): assumes the call returns a 2-item sequence
            # (count, something); verify against the Thrift IDL.
            count, _ = self._client.GetTableRowCount(table_name)
        except (TApplicationException, TException) as e:
            LOGGER.error('{}'.format(e))
            return Status(Status.PERMISSION_DENIED, str(e)), None
        return Status(message='Success'), count

    def client_version(self):
        """
        Provide client version

        :return: Client version
        """
        return __VERSION__

    def server_version(self):
        """
        Provide server version

        :raises NotConnectError: when not connected
        :return: Server version
        """
        if not self.connected:
            raise NotConnectError('You have to connect first')

        return self._client.Ping('version')

    def server_status(self, cmd=None):
        """
        Provide server status

        :param cmd: (Optional) command forwarded to the server's Ping call

        :raises NotConnectError: when not connected
        :return: Server status
        """
        if not self.connected:
            raise NotConnectError('You have to connect first')

        return self._client.Ping(cmd)

View File

@ -0,0 +1,18 @@
class ParamError(ValueError):
    """Raised for invalid parameters."""


class ConnectError(ValueError):
    """Base class for connection-state errors."""


class NotConnectError(ConnectError):
    """Raised when an operation needs a connection that is not there."""


class RepeatingConnectError(ConnectError):
    """Raised when connecting a client that is already connected."""


class DisconnectNotConnectedClientError(ConnectError):
    """Raised when disconnecting a client that is not connected.

    Derives from ConnectError like the other connection-state errors; it is
    still a ValueError, so existing ``except ValueError`` handlers keep
    working.
    """

View File

@ -0,0 +1,32 @@
class Status(object):
    """
    Result of an SDK operation.

    :attribute code : int (optional) default as ok
    :attribute message : str (optional) current status message
    """
    SUCCESS = 0
    CONNECT_FAILED = 1
    PERMISSION_DENIED = 2
    TABLE_NOT_EXISTS = 3
    ILLEGAL_ARGUMENT = 4
    ILLEGAL_RANGE = 5
    ILLEGAL_DIMENSION = 6

    def __init__(self, code=SUCCESS, message=None):
        self.code = code
        self.message = message

    def __repr__(self):
        L = ['%s=%r' % (key, value)
             for key, value in self.__dict__.items()]
        return '%s(%s)' % (self.__class__.__name__, ', '.join(L))

    def __eq__(self, other):
        """Make Status comparable with self by code"""
        if isinstance(other, int):
            return self.code == other
        else:
            return isinstance(other, self.__class__) and self.code == other.code

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        # Defining __eq__ alone makes instances unhashable on Python 3.
        # Hash by code so it agrees with __eq__ (equal to the bare int code
        # and to any Status with the same code, whatever the message).
        return hash(self.code)

View File

View File

@ -0,0 +1,89 @@
from client.Client import MegaSearch, Prepare, IndexType
import random
import struct
from pprint import pprint
def main():
    """Walk through the SDK: connect, create a table, insert and search
    vectors, then delete the table and disconnect.

    Every table operation targets the table created by this run; the script
    used to create a randomly named table and then operate on a hard-coded
    'test01' that it never created.
    """
    # Get client version
    mega = MegaSearch()
    print('# Client version: {}'.format(mega.client_version()))

    # Connect
    # Please change HOST and PORT to correct one
    param = {'host': 'HOST', 'port': 'PORT'}
    cnn_status = mega.connect(**param)
    print('# Connect Status: {}'.format(cnn_status))

    # Check if connected
    is_connected = mega.connected
    print('# Is connected: {}'.format(is_connected))

    # Get server version
    print('# Server version: {}'.format(mega.server_version()))

    # Show tables and their description
    status, tables = mega.show_tables()
    print('# Show tables: {}'.format(tables))

    # Create table
    # 01.Prepare data
    dim = 256
    table_name = 'test' + str(random.randint(0, 999))
    param = {
        'table_name': table_name,
        'dimension': dim,
        'index_type': IndexType.IDMAP,
        'store_raw_vector': False
    }

    # 02.Create table
    res_status = mega.create_table(Prepare.table_schema(**param))
    print('# Create table status: {}'.format(res_status))

    # Describe the table that was just created
    res_status, table = mega.describe_table(table_name)
    print('# Describe table status: {}'.format(res_status))
    print('# Describe table:{}'.format(table))

    # Add vectors to the table
    # 01. Prepare data: list of binary vectors
    vectors = [Prepare.row_record(struct.pack(str(dim) + 'd',
                                              *[random.random() for _ in range(dim)]))
               for _ in range(20)]

    # 02. Add vectors
    status, ids = mega.add_vectors(table_name=table_name, records=vectors)
    print('# Add vector status: {}'.format(status))
    pprint(ids)

    # Search vectors
    q_records = [Prepare.row_record(struct.pack(str(dim) + 'd',
                                                *[random.random() for _ in range(dim)]))
                 for _ in range(5)]
    param = {
        'table_name': table_name,
        'query_records': q_records,
        'top_k': 10,
        # 'query_ranges': None  # Optional
    }
    sta, results = mega.search_vectors(**param)
    print('# Search vectors status: {}'.format(sta))
    pprint(results)

    # Get table row count
    sta, result = mega.get_table_row_count(table_name)
    print('# Status: {}'.format(sta))
    print('# Count: {}'.format(result))

    # Delete the table created above
    res_status = mega.delete_table(table_name)
    print('# Delete table status: {}'.format(res_status))

    # Disconnect
    discnn_status = mega.disconnect()
    print('# Disconnect Status: {}'.format(discnn_status))


if __name__ == '__main__':
    main()

View File

@ -1,2 +0,0 @@
class ConnectParamMissingError(ValueError):
    """Raised when the connect parameters are not supplied correctly."""

5
python/sdk/pytest.ini Normal file
View File

@ -0,0 +1,5 @@
[pytest]
log_format = [%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)
log_cli = true
log_level = 20

View File

@ -0,0 +1,18 @@
atomicwrites==1.3.0
attrs==19.1.0
Faker==1.0.7
importlib-metadata==0.17
mock==3.0.5
more-itertools==7.0.0
packaging==19.0
pathlib2==2.3.3
pluggy==0.12.0
py==1.8.0
pyparsing==2.4.0
pytest==4.6.0
python-dateutil==2.8.0
six==1.12.0
text-unidecode==1.2
thrift==0.11.0
wcwidth==0.1.7
zipp==0.5.1

21
python/sdk/setup.py Normal file
View File

@ -0,0 +1,21 @@
# Packaging script for the MegaSearch Python SDK.
import setuptools
# NOTE(review): long_description is defined but never passed to setup().
long_description = ''
setuptools.setup(
    name="MegaSearch",
    version="0.0.1",
    author="XuanYang",
    author_email="xuan.yang@zilliz.com",
    description="Sdk for using MegaSearch",
    # Every package found under this directory is included.
    packages=setuptools.find_packages(),
    classifiers=[
        "Programming Language :: Python :: 3.4",
        "Programming Language :: Python :: 3.5",
        "Programming Language :: Python :: 3.6",
        "Operating System :: OS Independent",
    ],
    python_requires='>=3.4'
)

View File

@ -0,0 +1,281 @@
import logging
import pytest
import mock
import faker
import random
import struct
from faker.providers import BaseProvider
from client.Client import MegaSearch, Prepare
from client.Abstract import IndexType, TableSchema
from client.Status import Status
from client.Exceptions import (
RepeatingConnectError,
DisconnectNotConnectedClientError
)
from megasearch.thrift import ttypes, MegasearchService
from thrift.transport.TSocket import TSocket
from thrift.transport import TTransport
from thrift.transport.TTransport import TTransportException
LOGGER = logging.getLogger(__name__)
class FakerProvider(BaseProvider):
    """Faker provider producing random table names, names and dimensions."""

    def table_name(self):
        """Return 'table_name' followed by a random 4-digit number."""
        return 'table_name{}'.format(random.randint(1000, 9999))

    def name(self):
        """Return 'name' followed by a random 4-digit number."""
        return 'name{}'.format(random.randint(1000, 9999))

    def dim(self):
        """Return a random integer in [0, 999]."""
        return random.randint(0, 999)
fake = faker.Faker()
fake.add_provider(FakerProvider)
def range_factory():
    """Build a range with start in ['1', '10'] and end in ['11', '20']."""
    start = str(random.randint(1, 10))
    end = str(random.randint(11, 20))
    return Prepare.range(start=start, end=end)
def ranges_factory():
    """Build a list of five random ranges."""
    ranges = []
    for _ in range(5):
        ranges.append(range_factory())
    return ranges
def table_schema_factory():
    """Build a table schema with a random name and a random dimension in
    [0, 999], using the IDMAP index and no raw-vector storage."""
    return Prepare.table_schema(
        table_name=fake.table_name(),
        dimension=random.randint(0, 999),
        index_type=IndexType.IDMAP,
        store_raw_vector=False)
def row_record_factory(dimension):
    """Build a row record holding ``dimension`` random doubles packed with
    ``struct`` format ``'<dimension>d'``."""
    components = [random.random() + random.randint(0, 9)
                  for _ in range(dimension)]
    packed = struct.pack('{}d'.format(dimension), *components)
    return Prepare.row_record(vector_data=packed)
def row_records_factory(dimension):
    """Build a list of twenty random row records of the given dimension."""
    records = []
    for _ in range(20):
        records.append(row_record_factory(dimension))
    return records
class TestConnection:
    """Connection lifecycle tests; the socket is always mocked."""
    param = {'host': 'localhost', 'port': '5000'}

    @mock.patch.object(TSocket, 'open')
    def test_true_connect(self, mock_open):
        """A successful connect sets SUCCESS and rejects further connects."""
        mock_open.return_value = None
        cnn = MegaSearch()

        cnn.connect(**self.param)
        assert cnn.status == Status.SUCCESS
        assert cnn.connected

        # Each call gets its own context: a second statement placed inside
        # one ``pytest.raises`` block is never reached.
        with pytest.raises(RepeatingConnectError):
            cnn.connect(**self.param)
        with pytest.raises(RepeatingConnectError):
            cnn.connect()

    @mock.patch.object(TSocket, 'open')
    def test_false_connect(self, mock_open):
        """A transport error while opening yields a non-SUCCESS status."""
        # Simulate the failure instead of relying on nothing listening on
        # the port of the machine running the tests.
        mock_open.side_effect = TTransportException(message='refused')
        cnn = MegaSearch()

        cnn.connect(**self.param)
        assert cnn.status != Status.SUCCESS

    @mock.patch.object(TTransport.TBufferedTransport, 'close')
    @mock.patch.object(TSocket, 'open')
    def test_disconnected(self, mock_open, mock_close):
        """Disconnecting a connected client reports SUCCESS."""
        # Patches are injected bottom-up: the innermost decorator (open)
        # supplies the first argument.
        mock_open.return_value = None
        mock_close.return_value = None

        cnn = MegaSearch()
        cnn.connect(**self.param)

        assert cnn.disconnect() == Status.SUCCESS

    def test_disconnected_error(self):
        """Disconnecting a client that never connected raises."""
        cnn = MegaSearch()
        cnn.status = Status(Status.PERMISSION_DENIED)
        with pytest.raises(DisconnectNotConnectedClientError):
            cnn.disconnect()
class TestTable:
    """Table creation/deletion tests against a client with a mocked socket."""

    @pytest.fixture
    @mock.patch.object(TSocket, 'open')
    def client(self, mock_open):
        """Client that went through connect() with the socket open mocked."""
        param = {'host': 'localhost', 'port': '5000'}
        mock_open.return_value = None

        cnn = MegaSearch()
        cnn.connect(**param)
        return cnn

    @mock.patch.object(MegasearchService.Client, 'CreateTable')
    def test_create_table(self, CreateTable, client):
        """A mocked CreateTable call results in SUCCESS."""
        CreateTable.return_value = None

        param = table_schema_factory()
        res = client.create_table(param)
        assert res == Status.SUCCESS

    def test_false_create_table(self, client):
        """Without a real server, create_table lets the transport error out."""
        param = table_schema_factory()
        # Nothing may follow the call inside the block: once it raises, the
        # rest of the block is skipped and no result is ever bound.
        with pytest.raises(TTransportException):
            client.create_table(param)

    @mock.patch.object(MegasearchService.Client, 'DeleteTable')
    def test_delete_table(self, DeleteTable, client):
        """A mocked DeleteTable call results in SUCCESS."""
        DeleteTable.return_value = None
        table_name = 'fake_table_name'
        res = client.delete_table(table_name)
        assert res == Status.SUCCESS

    def test_false_delete_table(self, client):
        """Without a real server, delete_table returns a failure status."""
        table_name = 'fake_table_name'
        res = client.delete_table(table_name)
        assert res != Status.SUCCESS
class TestVector:
    """Vector and table-query tests against a client with a mocked socket.

    Tests decorated with ``mock.patch.object`` replace the named Thrift
    client call; the ``test_false_*`` tests leave it unpatched and expect a
    non-SUCCESS status.
    """

    @pytest.fixture
    @mock.patch.object(TSocket, 'open')
    def client(self, open):
        """Client that went through connect() with the socket open mocked."""
        param = {'host': 'localhost', 'port': '5000'}
        open.return_value = None

        cnn = MegaSearch()
        cnn.connect(**param)
        return cnn

    @mock.patch.object(MegasearchService.Client, 'AddVector')
    def test_add_vector(self, AddVector, client):
        """A mocked AddVector call results in SUCCESS."""
        AddVector.return_value = None

        param ={
            'table_name': fake.table_name(),
            'records': row_records_factory(256)
        }
        res, ids = client.add_vectors(**param)
        assert res == Status.SUCCESS

    def test_false_add_vector(self, client):
        """Unpatched AddVector: add_vectors reports a failure status."""
        param ={
            'table_name': fake.table_name(),
            'records': row_records_factory(256)
        }
        res, ids = client.add_vectors(**param)
        assert res != Status.SUCCESS

    @mock.patch.object(MegasearchService.Client, 'SearchVector')
    def test_search_vector(self, SearchVector, client):
        """A mocked SearchVector call results in SUCCESS."""
        # The mocked response is a pair of None entries.
        SearchVector.return_value = None, None
        param = {
            'table_name': fake.table_name(),
            'query_records': row_records_factory(256),
            'query_ranges': ranges_factory(),
            'top_k': random.randint(0, 10)
        }
        res, results = client.search_vectors(**param)
        assert res == Status.SUCCESS

    def test_false_vector(self, client):
        """Unpatched SearchVector: search_vectors reports a failure status."""
        param = {
            'table_name': fake.table_name(),
            'query_records': row_records_factory(256),
            'query_ranges': ranges_factory(),
            'top_k': random.randint(0, 10)
        }
        res, results = client.search_vectors(**param)
        assert res != Status.SUCCESS

    @mock.patch.object(MegasearchService.Client, 'DescribeTable')
    def test_describe_table(self, DescribeTable, client):
        """describe_table hands back the schema object from the mocked call."""
        DescribeTable.return_value = table_schema_factory()

        table_name = fake.table_name()
        res, table_schema = client.describe_table(table_name)
        assert res == Status.SUCCESS
        assert isinstance(table_schema, ttypes.TableSchema)

    def test_false_decribe_table(self, client):
        """Unpatched DescribeTable: failure status and no schema."""
        table_name = fake.table_name()
        res, table_schema = client.describe_table(table_name)
        assert res != Status.SUCCESS
        assert not table_schema

    @mock.patch.object(MegasearchService.Client, 'ShowTables')
    def test_show_tables(self, ShowTables, client):
        """show_tables returns the list from the mocked (tables, None) pair."""
        ShowTables.return_value = [fake.table_name() for _ in range(10)], None
        res, tables = client.show_tables()
        assert res == Status.SUCCESS
        assert isinstance(tables, list)

    def test_false_show_tables(self, client):
        """Unpatched ShowTables: failure status and no tables."""
        res, tables = client.show_tables()
        assert res != Status.SUCCESS
        assert not tables

    @mock.patch.object(MegasearchService.Client, 'GetTableRowCount')
    def test_get_table_row_count(self, GetTableRowCount, client):
        """A mocked (count, None) pair results in SUCCESS."""
        GetTableRowCount.return_value = 22, None
        res, count = client.get_table_row_count('fake_table')
        assert res == Status.SUCCESS

    def test_false_get_table_row_count(self, client):
        """Unpatched GetTableRowCount: failure status and no count."""
        res,count = client.get_table_row_count('fake_table')
        assert res != Status.SUCCESS
        assert not count

    def test_client_version(self, client):
        """client_version returns the expected version string."""
        res = client.client_version()
        assert res == '0.0.1'
class TestPrepare:
    """Checks that the Prepare helpers return the matching Thrift types."""

    def test_table_schema(self):
        """Prepare.table_schema yields a ttypes.TableSchema."""
        schema = Prepare.table_schema(
            table_name=fake.table_name(),
            dimension=random.randint(0, 999),
            index_type=IndexType.IDMAP,
            store_raw_vector=False)
        assert isinstance(schema, ttypes.TableSchema)

    def test_range(self):
        """Prepare.range yields a ttypes.Range carrying both bounds."""
        bounds = {'start': '200', 'end': '1000'}

        built = Prepare.range(**bounds)
        LOGGER.error('{}'.format(built))
        assert isinstance(built, ttypes.Range)
        assert built.start_value == '200'
        assert built.end_value == '1000'

    def test_row_record(self):
        """Prepare.row_record yields a ttypes.RowRecord from packed doubles."""
        components = [random.random() + random.randint(0, 9)
                      for _ in range(256)]
        packed = struct.pack('{}d'.format(256), *components)

        record = Prepare.row_record(packed)
        assert isinstance(record, ttypes.RowRecord)
        assert isinstance(packed, bytes)

View File