mirror of https://github.com/milvus-io/milvus.git
Merge branch 'branch-0.3.0' into 'branch-0.3.0'
MS-11 Implement Python SDK See merge request megasearch/vecwise_engine!74 Former-commit-id: f4862569304e1acb07582db0ac24b8a6c514b1e4pull/191/head
commit
a0280e5172
|
@ -12,4 +12,6 @@ Please mark all change in change log and use the ticket from JIRA.
|
|||
|
||||
- MS-10 - Add Python SDK APIs
|
||||
|
||||
- MS-11 - Implement Python SDK
|
||||
|
||||
### Task
|
||||
|
|
|
@ -1,433 +0,0 @@
|
|||
from enum import IntEnum
|
||||
from sdk.exceptions import ConnectParamMissingError
|
||||
from sdk.Status import Status
|
||||
|
||||
|
||||
class IndexType(IntEnum):
|
||||
RAW = 1
|
||||
IVFFLAT = 2
|
||||
|
||||
|
||||
class ColumnType(IntEnum):
|
||||
INVALID = 1
|
||||
INT8 = 2
|
||||
INT16 = 3
|
||||
INT32 = 4
|
||||
INT64 = 5
|
||||
FLOAT32 = 6
|
||||
FLOAT64 = 7
|
||||
DATE = 8
|
||||
VECTOR = 9
|
||||
|
||||
|
||||
class ConnectParam(object):
|
||||
"""
|
||||
Connect API parameter
|
||||
|
||||
:type ip_address: str
|
||||
:param ip_address: Server IP address
|
||||
|
||||
:type port: str,
|
||||
:param port: Sever PORT
|
||||
|
||||
"""
|
||||
def __init__(self, ip_address, port):
|
||||
|
||||
self.ip_address = ip_address
|
||||
self.port = port
|
||||
|
||||
|
||||
class Column(object):
|
||||
"""
|
||||
Table column description
|
||||
|
||||
:type type: ColumnType
|
||||
:param type: type of the column
|
||||
|
||||
:type name: str
|
||||
:param name: name of the column
|
||||
|
||||
"""
|
||||
def __init__(self, name=None, type=ColumnType.INVALID):
|
||||
self.type = type
|
||||
self.name = name
|
||||
|
||||
|
||||
class VectorColumn(Column):
|
||||
"""
|
||||
Table vector column description
|
||||
|
||||
:type dimension: int, int64
|
||||
:param dimension: vector dimension
|
||||
|
||||
:type index_type: IndexType
|
||||
:param index_type: IndexType
|
||||
|
||||
:type store_raw_vector: bool
|
||||
:param store_raw_vector: Is vector self stored in the table
|
||||
|
||||
"""
|
||||
def __init__(self, dimension=0,
|
||||
index_type=IndexType.RAW,
|
||||
store_raw_vector=False):
|
||||
self.dimension = dimension
|
||||
self.index_type = index_type
|
||||
self.store_raw_vector = store_raw_vector
|
||||
super(VectorColumn, self).__init__(type=ColumnType.VECTOR)
|
||||
|
||||
|
||||
class TableSchema(object):
|
||||
"""
|
||||
Table Schema
|
||||
|
||||
:type table_name: str
|
||||
:param table_name: Table name
|
||||
|
||||
:type vector_columns: list[VectorColumn]
|
||||
:param vector_columns: vector column description
|
||||
|
||||
:type attribute_columns: list[Column]
|
||||
:param attribute_columns: Columns description
|
||||
|
||||
:type partition_column_names: list[str]
|
||||
:param partition_column_names: Partition column name
|
||||
|
||||
"""
|
||||
def __init__(self, table_name, vector_columns,
|
||||
attribute_columns, partition_column_names):
|
||||
self.table_name = table_name
|
||||
self.vector_columns = vector_columns
|
||||
self.attribute_columns = attribute_columns
|
||||
self.partition_column_names = partition_column_names
|
||||
|
||||
|
||||
class Range(object):
|
||||
"""
|
||||
Range information
|
||||
|
||||
:type start: str
|
||||
:param start: Range start value
|
||||
|
||||
:type end: str
|
||||
:param end: Range end value
|
||||
|
||||
"""
|
||||
def __init__(self, start, end):
|
||||
self.start = start
|
||||
self.end = end
|
||||
|
||||
|
||||
class CreateTablePartitionParam(object):
|
||||
"""
|
||||
Create table partition parameters
|
||||
|
||||
:type table_name: str
|
||||
:param table_name: Table name,
|
||||
VECTOR/FLOAT32/FLOAT64 ColumnType is not allowed for partition
|
||||
|
||||
:type partition_name: str
|
||||
:param partition_name: partition name, created partition name
|
||||
|
||||
:type column_name_to_range: dict{str : Range}
|
||||
:param column_name_to_range: Column name to PartitionRange dictionary
|
||||
"""
|
||||
def __init__(self, table_name, partition_name, **column_name_to_range):
|
||||
self.table_name = table_name
|
||||
self.partition_name = partition_name
|
||||
self.column_name_to_range = column_name_to_range
|
||||
|
||||
|
||||
class DeleteTablePartitionParam(object):
|
||||
"""
|
||||
Delete table partition parameters
|
||||
|
||||
:type table_name: str
|
||||
:param table_name: Table name
|
||||
|
||||
:type partition_names: iterable, str
|
||||
:param partition_names: Partition name array
|
||||
|
||||
"""
|
||||
def __init__(self, table_name, *partition_names):
|
||||
self.table_name = table_name
|
||||
self.partition_names = partition_names
|
||||
|
||||
|
||||
class RowRecord(object):
|
||||
"""
|
||||
Record inserted
|
||||
|
||||
:type column_name_to_vector: dict{str : list[float]}
|
||||
:param column_name_to_vector: Column name to vector map
|
||||
|
||||
:type column_name_to_value: dict{str: str}
|
||||
:param column_name_to_value: Other attribute columns
|
||||
"""
|
||||
def __init__(self, column_name_to_vector, column_name_to_value):
|
||||
self.column_name_to_vector = column_name_to_vector
|
||||
self.column_name_to_value = column_name_to_value
|
||||
|
||||
|
||||
class QueryRecord(object):
|
||||
"""
|
||||
Query record
|
||||
|
||||
:type column_name_to_vector: dict{str : list[float]}
|
||||
:param column_name_to_vector: Query vectors, column name to vector map
|
||||
|
||||
:type selected_columns: list[str]
|
||||
:param selected_columns: Output column array
|
||||
|
||||
:type name_to_partition_ranges: dict{str : list[Range]}
|
||||
:param name_to_partition_ranges: Range used to select partitions
|
||||
|
||||
"""
|
||||
def __init__(self, column_name_to_vector, selected_columns, **name_to_partition_ranges):
|
||||
self.column_name_to_vector = column_name_to_vector
|
||||
self.selected_columns = selected_columns
|
||||
self.name_to_partition_ranges = name_to_partition_ranges
|
||||
|
||||
|
||||
class QueryResult(object):
|
||||
"""
|
||||
Query result
|
||||
|
||||
:type id: int
|
||||
:param id: Output result
|
||||
|
||||
:type score: float
|
||||
:param score: Vector similarity 0 <= score <= 100
|
||||
|
||||
:type column_name_to_value: dict{str : str}
|
||||
:param column_name_to_value: Other columns
|
||||
|
||||
"""
|
||||
def __init__(self, id, score, **column_name_to_value):
|
||||
self.id = id
|
||||
self.score = score
|
||||
self.column_name_to_value = column_name_to_value
|
||||
|
||||
|
||||
class TopKQueryResult(object):
|
||||
"""
|
||||
TopK query results
|
||||
|
||||
:type query_results: list[QueryResult]
|
||||
:param query_results: TopK query results
|
||||
|
||||
"""
|
||||
def __init__(self, query_results):
|
||||
self.query_results = query_results
|
||||
|
||||
|
||||
def _abstract():
|
||||
raise NotImplementedError('You need to override this function')
|
||||
|
||||
|
||||
class Connection(object):
|
||||
"""SDK client class"""
|
||||
|
||||
@staticmethod
|
||||
def create():
|
||||
"""Create a connection instance and return it
|
||||
should be implemented
|
||||
|
||||
:return connection: Connection
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
@staticmethod
|
||||
def destroy(connection):
|
||||
"""Destroy the connection instance
|
||||
should be implemented
|
||||
|
||||
:type connection: Connection
|
||||
:param connection: The connection instance to be destroyed
|
||||
|
||||
:return bool, return True if destroy is successful
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def connect(self, param=None, uri=None):
|
||||
"""
|
||||
Connect method should be called before any operations
|
||||
Server will be connected after connect return OK
|
||||
should be implemented
|
||||
|
||||
:type param: ConnectParam
|
||||
:param param: ConnectParam
|
||||
|
||||
:type uri: str
|
||||
:param uri: uri param
|
||||
|
||||
:return: Status, indicate if connect is successful
|
||||
"""
|
||||
if (not param and not uri) or (param and uri):
|
||||
raise ConnectParamMissingError('You need to parse exact one param')
|
||||
_abstract()
|
||||
|
||||
def connected(self):
|
||||
"""
|
||||
connected, connection status
|
||||
should be implemented
|
||||
|
||||
:return: Status, indicate if connect is successful
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def disconnect(self):
|
||||
"""
|
||||
Disconnect, server will be disconnected after disconnect return OK
|
||||
should be implemented
|
||||
|
||||
:return: Status, indicate if connect is successful
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def create_table(self, param):
|
||||
"""
|
||||
Create table
|
||||
should be implemented
|
||||
|
||||
:type param: TableSchema
|
||||
:param param: provide table information to be created
|
||||
|
||||
:return: Status, indicate if connect is successful
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def delete_table(self, table_name):
|
||||
"""
|
||||
Delete table
|
||||
should be implemented
|
||||
|
||||
:type table_name: str
|
||||
:param table_name: table_name of the deleting table
|
||||
|
||||
:return: Status, indicate if connect is successful
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def create_table_partition(self, param):
|
||||
"""
|
||||
Create table partition
|
||||
should be implemented
|
||||
|
||||
:type param: CreateTablePartitionParam
|
||||
:param param: provide partition information
|
||||
|
||||
:return: Status, indicate if table partition is created successfully
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def delete_table_partition(self, param):
|
||||
"""
|
||||
Delete table partition
|
||||
should be implemented
|
||||
|
||||
:type param: DeleteTablePartitionParam
|
||||
:param param: provide partition information to be deleted
|
||||
:return: Status, indicate if partition is deleted successfully
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def add_vector(self, table_name, records, ids):
|
||||
"""
|
||||
Add vectors to table
|
||||
should be implemented
|
||||
|
||||
:type table_name: str
|
||||
:param table_name: table name been inserted
|
||||
|
||||
:type records: list[RowRecord]
|
||||
:param records: list of vectors been inserted
|
||||
|
||||
:type ids: list[int]
|
||||
:param ids: list of ids
|
||||
|
||||
:return: Status, indicate if vectors inserted successfully
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def search_vector(self, table_name, query_records, query_results, top_k):
|
||||
"""
|
||||
Query vectors in a table
|
||||
should be implemented
|
||||
|
||||
:type table_name: str
|
||||
:param table_name: table name been queried
|
||||
|
||||
:type query_records: list[QueryRecord]
|
||||
:param query_records: all vectors going to be queried
|
||||
|
||||
:type query_results: list[TopKQueryResult]
|
||||
:param query_results: list of results
|
||||
|
||||
:type top_k: int
|
||||
:param top_k: how many similar vectors will be searched
|
||||
|
||||
:return: Status, indicate if query is successful
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def describe_table(self, table_name, table_schema):
|
||||
"""
|
||||
Show table information
|
||||
should be implemented
|
||||
|
||||
:type table_name: str
|
||||
:param table_name: which table to be shown
|
||||
|
||||
:type table_schema: TableSchema
|
||||
:param table_schema: table schema is given when operation is successful
|
||||
|
||||
:return: Status, indicate if query is successful
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def show_tables(self, tables):
|
||||
"""
|
||||
Show all tables in database
|
||||
should be implemented
|
||||
|
||||
:type tables: list[str]
|
||||
:param tables: list of tables
|
||||
|
||||
:return: Status, indicate if this operation is successful
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def client_version(self):
|
||||
"""
|
||||
Provide server version
|
||||
should be implemented
|
||||
|
||||
:return: Server version
|
||||
"""
|
||||
_abstract()
|
||||
pass
|
||||
|
||||
def server_status(self):
|
||||
"""
|
||||
Provide server status
|
||||
should be implemented
|
||||
|
||||
:return: Server status
|
||||
"""
|
||||
_abstract()
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
from enum import IntEnum
|
||||
|
||||
|
||||
class Status(IntEnum):
|
||||
|
||||
def __new__(cls, code, message=''):
|
||||
obj = int.__new__(cls, code)
|
||||
obj._code_ = code
|
||||
|
||||
obj.message = message
|
||||
return obj
|
||||
|
||||
def __str__(self):
|
||||
return str(self.code)
|
||||
|
||||
# success
|
||||
OK = 200, 'OK'
|
||||
|
||||
INVALID = 300, 'Invalid'
|
||||
UNKNOWN = 400, 'Unknown error'
|
||||
NOT_SUPPORTED = 500, 'Not supported'
|
|
@ -0,0 +1,298 @@
|
|||
from enum import IntEnum
|
||||
|
||||
|
||||
class IndexType(IntEnum):
|
||||
INVALIDE = 0
|
||||
IDMAP = 1
|
||||
IVFLAT = 2
|
||||
|
||||
|
||||
class TableSchema(object):
|
||||
"""
|
||||
Table Schema
|
||||
|
||||
:type table_name: str
|
||||
:param table_name: (Required) name of table
|
||||
|
||||
:type index_type: IndexType
|
||||
:param index_type: (Optional) index type, default = 0
|
||||
|
||||
`IndexType`: 0-invalid, 1-idmap, 2-ivflat
|
||||
|
||||
:type dimension: int64
|
||||
:param dimension: (Required) dimension of vector
|
||||
|
||||
:type store_raw_vector: bool
|
||||
:param store_raw_vector: (Optional) default = False
|
||||
|
||||
"""
|
||||
def __init__(self, table_name,
|
||||
dimension=0,
|
||||
index_type=IndexType.INVALIDE,
|
||||
store_raw_vector=False):
|
||||
self.table_name = table_name
|
||||
self.index_type = index_type
|
||||
self.dimension = dimension
|
||||
self.store_raw_vector = store_raw_vector
|
||||
|
||||
|
||||
class Range(object):
|
||||
"""
|
||||
Range information
|
||||
|
||||
:type start: str
|
||||
:param start: Range start value
|
||||
|
||||
:type end: str
|
||||
:param end: Range end value
|
||||
|
||||
"""
|
||||
def __init__(self, start, end):
|
||||
self.start = start
|
||||
self.end = end
|
||||
|
||||
|
||||
class RowRecord(object):
|
||||
"""
|
||||
Record inserted
|
||||
|
||||
:type vector_data: binary str
|
||||
:param vector_data: (Required) a vector
|
||||
|
||||
"""
|
||||
def __init__(self, vector_data):
|
||||
self.vector_data = vector_data
|
||||
|
||||
|
||||
class QueryResult(object):
|
||||
"""
|
||||
Query result
|
||||
|
||||
:type id: int64
|
||||
:param id: id of the vector
|
||||
|
||||
:type score: float
|
||||
:param score: Vector similarity 0 <= score <= 100
|
||||
|
||||
"""
|
||||
def __init__(self, id, score):
|
||||
self.id = id
|
||||
self.score = score
|
||||
|
||||
def __repr__(self):
|
||||
L = ['%s=%r' % (key, value)
|
||||
for key, value in self.__dict__.items()]
|
||||
return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
|
||||
|
||||
|
||||
class TopKQueryResult(object):
|
||||
"""
|
||||
TopK query results
|
||||
|
||||
:type query_results: list[QueryResult]
|
||||
:param query_results: TopK query results
|
||||
|
||||
"""
|
||||
def __init__(self, query_results):
|
||||
self.query_results = query_results
|
||||
|
||||
def __repr__(self):
|
||||
L = ['%s=%r' % (key, value)
|
||||
for key, value in self.__dict__.items()]
|
||||
return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
|
||||
|
||||
|
||||
def _abstract():
|
||||
raise NotImplementedError('You need to override this function')
|
||||
|
||||
|
||||
class ConnectIntf(object):
|
||||
"""SDK client abstract class
|
||||
|
||||
Connection is a abstract class
|
||||
|
||||
"""
|
||||
|
||||
def connect(self, host=None, port=None, uri=None):
|
||||
"""
|
||||
Connect method should be called before any operations
|
||||
Server will be connected after connect return OK
|
||||
Should be implemented
|
||||
|
||||
:type host: str
|
||||
:param host: host
|
||||
|
||||
:type port: str
|
||||
:param port: port
|
||||
|
||||
:type uri: str
|
||||
:param uri: (Optional) uri
|
||||
|
||||
:return Status, indicate if connect is successful
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def connected(self):
|
||||
"""
|
||||
connected, connection status
|
||||
Should be implemented
|
||||
|
||||
:return Status, indicate if connect is successful
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def disconnect(self):
|
||||
"""
|
||||
Disconnect, server will be disconnected after disconnect return SUCCESS
|
||||
Should be implemented
|
||||
|
||||
:return Status, indicate if connect is successful
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def create_table(self, param):
|
||||
"""
|
||||
Create table
|
||||
Should be implemented
|
||||
|
||||
:type param: TableSchema
|
||||
:param param: provide table information to be created
|
||||
|
||||
:return Status, indicate if connect is successful
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def delete_table(self, table_name):
|
||||
"""
|
||||
Delete table
|
||||
Should be implemented
|
||||
|
||||
:type table_name: str
|
||||
:param table_name: table_name of the deleting table
|
||||
|
||||
:return Status, indicate if connect is successful
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def add_vectors(self, table_name, records):
|
||||
"""
|
||||
Add vectors to table
|
||||
Should be implemented
|
||||
|
||||
:type table_name: str
|
||||
:param table_name: table name been inserted
|
||||
|
||||
:type records: list[RowRecord]
|
||||
:param records: list of vectors been inserted
|
||||
|
||||
:returns
|
||||
Status : indicate if vectors inserted successfully
|
||||
ids :list of id, after inserted every vector is given a id
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def search_vectors(self, table_name, query_records, query_ranges, top_k):
|
||||
"""
|
||||
Query vectors in a table
|
||||
Should be implemented
|
||||
|
||||
:type table_name: str
|
||||
:param table_name: table name been queried
|
||||
|
||||
:type query_records: list[RowRecord]
|
||||
:param query_records: all vectors going to be queried
|
||||
|
||||
:type query_ranges: list[Range]
|
||||
:param query_ranges: Optional ranges for conditional search.
|
||||
If not specified, search whole table
|
||||
|
||||
:type top_k: int
|
||||
:param top_k: how many similar vectors will be searched
|
||||
|
||||
:returns
|
||||
Status: indicate if query is successful
|
||||
query_results: list[TopKQueryResult]
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def describe_table(self, table_name):
|
||||
"""
|
||||
Show table information
|
||||
Should be implemented
|
||||
|
||||
:type table_name: str
|
||||
:param table_name: which table to be shown
|
||||
|
||||
:returns
|
||||
Status: indicate if query is successful
|
||||
table_schema: TableSchema, given when operation is successful
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def get_table_row_count(self, table_name):
|
||||
"""
|
||||
Get table row count
|
||||
Should be implemented
|
||||
|
||||
:type table_name, str
|
||||
:param table_name, target table name.
|
||||
|
||||
:returns
|
||||
Status: indicate if operation is successful
|
||||
count: int, table row count
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def show_tables(self):
|
||||
"""
|
||||
Show all tables in database
|
||||
should be implemented
|
||||
|
||||
:return
|
||||
Status: indicate if this operation is successful
|
||||
tables: list[str], list of table names
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def client_version(self):
|
||||
"""
|
||||
Provide client version
|
||||
should be implemented
|
||||
|
||||
:return: str, client version
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def server_version(self):
|
||||
"""
|
||||
Provide server version
|
||||
should be implemented
|
||||
|
||||
:return: str, server version
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def server_status(self, cmd):
|
||||
"""
|
||||
Provide server status
|
||||
should be implemented
|
||||
:type cmd, str
|
||||
|
||||
:return: str, server status
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,306 @@
|
|||
import logging, logging.config
|
||||
|
||||
from thrift.transport import TSocket
|
||||
from thrift.transport import TTransport
|
||||
from thrift.protocol import TBinaryProtocol
|
||||
from thrift.Thrift import TException, TApplicationException
|
||||
|
||||
from megasearch.thrift import MegasearchService
|
||||
from megasearch.thrift import ttypes
|
||||
from client.Abstract import (
|
||||
ConnectIntf,
|
||||
TableSchema,
|
||||
Range,
|
||||
RowRecord,
|
||||
QueryResult,
|
||||
TopKQueryResult,
|
||||
IndexType
|
||||
)
|
||||
|
||||
from client.Status import Status
|
||||
from client.Exceptions import (
|
||||
RepeatingConnectError,
|
||||
DisconnectNotConnectedClientError,
|
||||
NotConnectError
|
||||
)
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
__VERSION__ = '0.0.1'
|
||||
__NAME__ = 'Thrift_Client'
|
||||
|
||||
|
||||
class Prepare(object):
|
||||
|
||||
@classmethod
|
||||
def table_schema(cls,
|
||||
table_name,
|
||||
dimension,
|
||||
index_type=IndexType.INVALIDE,
|
||||
store_raw_vector = False):
|
||||
"""
|
||||
|
||||
:param table_name: str, (Required) name of table
|
||||
:param index_type: IndexType, (Required) index type, default = IndexType.INVALID
|
||||
:param dimension: int64, (Optional) dimension of the table
|
||||
:param store_raw_vector: bool, (Optional) default = False
|
||||
|
||||
:return: TableSchema
|
||||
"""
|
||||
temp = TableSchema(table_name,dimension, index_type, store_raw_vector)
|
||||
|
||||
return ttypes.TableSchema(table_name=temp.table_name,
|
||||
dimension=dimension,
|
||||
index_type=index_type,
|
||||
store_raw_vector=store_raw_vector)
|
||||
|
||||
@classmethod
|
||||
def range(cls, start, end):
|
||||
"""
|
||||
:param start: str, (Required) range start
|
||||
:param end: str (Required) range end
|
||||
|
||||
:return Range
|
||||
"""
|
||||
temp = Range(start=start, end=end)
|
||||
return ttypes.Range(start_value=temp.start, end_value=temp.end)
|
||||
|
||||
@classmethod
|
||||
def row_record(cls, vector_data):
|
||||
"""
|
||||
Record inserted
|
||||
|
||||
:param vector_data: float binary str, (Required) a binary str
|
||||
|
||||
"""
|
||||
temp = RowRecord(vector_data)
|
||||
return ttypes.RowRecord(vector_data=temp.vector_data)
|
||||
|
||||
|
||||
class MegaSearch(ConnectIntf):
|
||||
|
||||
def __init__(self):
|
||||
self.status = None
|
||||
self._transport = None
|
||||
self._client = None
|
||||
|
||||
def __repr__(self):
|
||||
return '{}'.format(self.status)
|
||||
|
||||
def connect(self, host='localhost', port='9090', uri=None):
|
||||
# TODO URI
|
||||
if self.status and self.status == Status.SUCCESS:
|
||||
raise RepeatingConnectError("You have already connected!")
|
||||
|
||||
transport = TSocket.TSocket(host=host, port=port)
|
||||
self._transport = TTransport.TBufferedTransport(transport)
|
||||
protocol = TBinaryProtocol.TBinaryProtocol(transport)
|
||||
self._client = MegasearchService.Client(protocol)
|
||||
|
||||
try:
|
||||
transport.open()
|
||||
self.status = Status(Status.SUCCESS, 'Connected')
|
||||
LOGGER.info('Connected!')
|
||||
|
||||
except (TTransport.TTransportException, TException) as e:
|
||||
self.status = Status(Status.CONNECT_FAILED, message=str(e))
|
||||
LOGGER.error('logger.error: {}'.format(self.status))
|
||||
finally:
|
||||
return self.status
|
||||
|
||||
@property
|
||||
def connected(self):
|
||||
return self.status == Status.SUCCESS
|
||||
|
||||
def disconnect(self):
|
||||
|
||||
if not self._transport:
|
||||
raise DisconnectNotConnectedClientError('Error')
|
||||
|
||||
try:
|
||||
|
||||
self._transport.close()
|
||||
LOGGER.info('Client Disconnected!')
|
||||
self.status = None
|
||||
|
||||
except TException as e:
|
||||
return Status(Status.PERMISSION_DENIED, str(e))
|
||||
return Status(Status.SUCCESS, 'Disconnected')
|
||||
|
||||
def create_table(self, param):
|
||||
"""Create table
|
||||
|
||||
:param param: Provide table information to be created,
|
||||
|
||||
`Please use Prepare.table_schema generate param`
|
||||
|
||||
:return: Status, indicate if operation is successful
|
||||
"""
|
||||
if not self._client:
|
||||
raise NotConnectError('Please Connect to the server first!')
|
||||
|
||||
try:
|
||||
self._client.CreateTable(param)
|
||||
except (TApplicationException, ) as e:
|
||||
LOGGER.error('Unable to create table')
|
||||
return Status(Status.PERMISSION_DENIED, str(e))
|
||||
return Status(message='Table {} created!'.format(param.table_name))
|
||||
|
||||
def delete_table(self, table_name):
|
||||
"""Delete table
|
||||
|
||||
:param table_name: Name of the table being deleted
|
||||
|
||||
:return: Status, indicate if operation is successful
|
||||
"""
|
||||
try:
|
||||
self._client.DeleteTable(table_name)
|
||||
except (TApplicationException, TException) as e:
|
||||
LOGGER.error('Unable to delete table {}'.format(table_name))
|
||||
return Status(Status.PERMISSION_DENIED, str(e))
|
||||
return Status(message='Table {} deleted!'.format(table_name))
|
||||
|
||||
def add_vectors(self, table_name, records):
|
||||
"""
|
||||
Add vectors to table
|
||||
|
||||
:param table_name: table name been inserted
|
||||
:param records: List[RowRecord], list of vectors been inserted
|
||||
|
||||
`Please use Prepare.row_record generate records`
|
||||
|
||||
:returns:
|
||||
Status : indicate if vectors inserted successfully
|
||||
ids :list of id, after inserted every vector is given a id
|
||||
"""
|
||||
try:
|
||||
ids = self._client.AddVector(table_name=table_name, record_array=records)
|
||||
except (TApplicationException, TException) as e:
|
||||
LOGGER.error('{}'.format(e))
|
||||
return Status(Status.PERMISSION_DENIED, str(e)), None
|
||||
return Status(message='Vectors added successfully!'), ids
|
||||
|
||||
def search_vectors(self, table_name, top_k, query_records, query_ranges=None):
|
||||
"""
|
||||
Query vectors in a table
|
||||
|
||||
:param table_name: str, table name been queried
|
||||
:param query_records: list[QueryRecord], all vectors going to be queried
|
||||
|
||||
`Please use Prepare.query_record generate QueryRecord`
|
||||
|
||||
:param top_k: int, how many similar vectors will be searched
|
||||
:param query_ranges, (Optional) list[Range], search range
|
||||
|
||||
:returns:
|
||||
Status: indicate if query is successful
|
||||
res: list[TopKQueryResult], return when operation is successful
|
||||
"""
|
||||
res = []
|
||||
try:
|
||||
top_k_query_results = self._client.SearchVector(
|
||||
table_name=table_name,
|
||||
query_record_array=query_records,
|
||||
query_range_array=query_ranges,
|
||||
topk=top_k)
|
||||
|
||||
if top_k_query_results:
|
||||
for top_k in top_k_query_results:
|
||||
if top_k:
|
||||
res.append(TopKQueryResult([QueryResult(qr.id, qr.score)
|
||||
for qr in top_k.query_result_arrays]))
|
||||
|
||||
except (TApplicationException, TException) as e:
|
||||
LOGGER.error('{}'.format(e))
|
||||
return Status(Status.PERMISSION_DENIED, str(e)), None
|
||||
return Status(message='Success!'), res
|
||||
|
||||
def describe_table(self, table_name):
|
||||
"""
|
||||
Show table information
|
||||
|
||||
:param table_name: str, which table to be shown
|
||||
|
||||
:returns:
|
||||
Status: indicate if query is successful
|
||||
table_schema: TableSchema, return when operation is successful
|
||||
"""
|
||||
try:
|
||||
temp = self._client.DescribeTable(table_name)
|
||||
|
||||
# res = TableSchema(table_name=temp.table_name, dimension=temp.dimension,
|
||||
# index_type=temp.index_type, store_raw_vector=temp.store_raw_vector)
|
||||
except (TApplicationException, TException) as e:
|
||||
LOGGER.error('{}'.format(e))
|
||||
return Status(Status.PERMISSION_DENIED, str(e)), None
|
||||
return Status(message='Success!'), temp
|
||||
|
||||
def show_tables(self):
|
||||
"""
|
||||
Show all tables in database
|
||||
|
||||
:return:
|
||||
Status: indicate if this operation is successful
|
||||
tables: list[str], list of table names, return when operation
|
||||
is successful
|
||||
"""
|
||||
try:
|
||||
res = self._client.ShowTables()
|
||||
tables = []
|
||||
if res:
|
||||
tables, _ = res
|
||||
|
||||
except (TApplicationException, TException) as e:
|
||||
LOGGER.error('{}'.format(e))
|
||||
return Status(Status.PERMISSION_DENIED, str(e)), None
|
||||
return Status(message='Success!'), tables
|
||||
|
||||
def get_table_row_count(self, table_name):
|
||||
"""
|
||||
Get table row count
|
||||
|
||||
:type table_name, str
|
||||
:param table_name, target table name.
|
||||
|
||||
:returns:
|
||||
Status: indicate if operation is successful
|
||||
res: int, table row count
|
||||
|
||||
"""
|
||||
try:
|
||||
count, _ = self._client.GetTableRowCount(table_name)
|
||||
|
||||
except (TApplicationException, TException) as e:
|
||||
LOGGER.error('{}'.format(e))
|
||||
return Status(Status.PERMISSION_DENIED, str(e)), None
|
||||
return Status(message='Success'), count
|
||||
|
||||
def client_version(self):
|
||||
"""
|
||||
Provide client version
|
||||
|
||||
:return: Client version
|
||||
"""
|
||||
return __VERSION__
|
||||
|
||||
def server_version(self):
|
||||
"""
|
||||
Provide server version
|
||||
|
||||
:return: Server version
|
||||
"""
|
||||
if not self.connected:
|
||||
raise NotConnectError('You have to connect first')
|
||||
|
||||
return self._client.Ping('version')
|
||||
|
||||
def server_status(self, cmd=None):
|
||||
"""
|
||||
Provide server status
|
||||
|
||||
:return: Server status
|
||||
"""
|
||||
if not self.connected:
|
||||
raise NotConnectError('You have to connect first')
|
||||
|
||||
return self._client.Ping(cmd)
|
|
@ -0,0 +1,18 @@
|
|||
class ParamError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ConnectError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class NotConnectError(ConnectError):
|
||||
pass
|
||||
|
||||
|
||||
class RepeatingConnectError(ConnectError):
|
||||
pass
|
||||
|
||||
|
||||
class DisconnectNotConnectedClientError(ValueError):
|
||||
pass
|
|
@ -0,0 +1,32 @@
|
|||
class Status(object):
|
||||
"""
|
||||
:attribute code : int (optional) default as ok
|
||||
:attribute message : str (optional) current status message
|
||||
"""
|
||||
SUCCESS = 0
|
||||
CONNECT_FAILED = 1
|
||||
PERMISSION_DENIED = 2
|
||||
TABLE_NOT_EXISTS = 3
|
||||
ILLEGAL_ARGUMENT = 4
|
||||
ILLEGAL_RANGE = 5
|
||||
ILLEGAL_DIMENSION = 6
|
||||
|
||||
def __init__(self, code=SUCCESS, message=None):
|
||||
self.code = code
|
||||
self.message = message
|
||||
|
||||
def __repr__(self):
|
||||
L = ['%s=%r' % (key, value)
|
||||
for key, value in self.__dict__.items()]
|
||||
return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
|
||||
|
||||
def __eq__(self, other):
|
||||
"""Make Status comparable with self by code"""
|
||||
if isinstance(other, int):
|
||||
return self.code == other
|
||||
else:
|
||||
return isinstance(other, self.__class__) and self.code == other.code
|
||||
|
||||
def __ne__(self, other):
|
||||
return not (self == other)
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
from client.Client import MegaSearch, Prepare, IndexType
|
||||
import random
|
||||
import struct
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
def main():
|
||||
# Get client version
|
||||
mega = MegaSearch()
|
||||
print('# Client version: {}'.format(mega.client_version()))
|
||||
|
||||
# Connect
|
||||
# Please change HOST and PORT to correct one
|
||||
param = {'host': 'HOST', 'port': 'PORT'}
|
||||
cnn_status = mega.connect(**param)
|
||||
print('# Connect Status: {}'.format(cnn_status))
|
||||
|
||||
# Check if connected
|
||||
is_connected = mega.connected
|
||||
print('# Is connected: {}'.format(is_connected))
|
||||
|
||||
# Get server version
|
||||
print('# Server version: {}'.format(mega.server_version()))
|
||||
|
||||
# Show tables and their description
|
||||
status, tables = mega.show_tables()
|
||||
print('# Show tables: {}'.format(tables))
|
||||
|
||||
# Create table
|
||||
# 01.Prepare data
|
||||
param = {
|
||||
'table_name': 'test'+ str(random.randint(0,999)),
|
||||
'dimension': 256,
|
||||
'index_type': IndexType.IDMAP,
|
||||
'store_raw_vector': False
|
||||
}
|
||||
|
||||
# 02.Create table
|
||||
res_status = mega.create_table(Prepare.table_schema(**param))
|
||||
print('# Create table status: {}'.format(res_status))
|
||||
|
||||
# Describe table
|
||||
table_name = 'test01'
|
||||
res_status, table = mega.describe_table(table_name)
|
||||
print('# Describe table status: {}'.format(res_status))
|
||||
print('# Describe table:{}'.format(table))
|
||||
|
||||
# Add vectors to table 'test01'
|
||||
# 01. Prepare data
|
||||
dim = 256
|
||||
# list of binary vectors
|
||||
vectors = [Prepare.row_record(struct.pack(str(dim)+'d',
|
||||
*[random.random()for _ in range(dim)]))
|
||||
for _ in range(20)]
|
||||
# 02. Add vectors
|
||||
status, ids = mega.add_vectors(table_name=table_name, records=vectors)
|
||||
print('# Add vector status: {}'.format(status))
|
||||
pprint(ids)
|
||||
|
||||
# Search vectors
|
||||
q_records = [Prepare.row_record(struct.pack(str(dim) + 'd',
|
||||
*[random.random() for _ in range(dim)]))
|
||||
for _ in range(5)]
|
||||
param = {
|
||||
'table_name': 'test01',
|
||||
'query_records': q_records,
|
||||
'top_k': 10,
|
||||
# 'query_ranges': None # Optional
|
||||
}
|
||||
sta, results = mega.search_vectors(**param)
|
||||
print('# Search vectors status: {}'.format(sta))
|
||||
pprint(results)
|
||||
|
||||
# Get table row count
|
||||
sta, result = mega.get_table_row_count(table_name)
|
||||
print('# Status: {}'.format(sta))
|
||||
print('# Count: {}'.format(result))
|
||||
|
||||
# Delete table 'test01'
|
||||
res_status = mega.delete_table(table_name)
|
||||
print('# Delete table status: {}'.format(res_status))
|
||||
|
||||
# Disconnect
|
||||
discnn_status = mega.disconnect()
|
||||
print('# Disconnect Status: {}'.format(discnn_status))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -1,2 +0,0 @@
|
|||
class ConnectParamMissingError(ValueError):
|
||||
pass
|
|
@ -0,0 +1,5 @@
|
|||
[pytest]
|
||||
log_format = [%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)
|
||||
|
||||
log_cli = true
|
||||
log_level = 20
|
|
@ -0,0 +1,18 @@
|
|||
atomicwrites==1.3.0
|
||||
attrs==19.1.0
|
||||
Faker==1.0.7
|
||||
importlib-metadata==0.17
|
||||
mock==3.0.5
|
||||
more-itertools==7.0.0
|
||||
packaging==19.0
|
||||
pathlib2==2.3.3
|
||||
pluggy==0.12.0
|
||||
py==1.8.0
|
||||
pyparsing==2.4.0
|
||||
pytest==4.6.0
|
||||
python-dateutil==2.8.0
|
||||
six==1.12.0
|
||||
text-unidecode==1.2
|
||||
thrift==0.11.0
|
||||
wcwidth==0.1.7
|
||||
zipp==0.5.1
|
|
@ -0,0 +1,21 @@
|
|||
import setuptools
|
||||
|
||||
long_description = ''
|
||||
|
||||
setuptools.setup(
|
||||
name="MegaSearch",
|
||||
version="0.0.1",
|
||||
author="XuanYang",
|
||||
author_email="xuan.yang@zilliz.com",
|
||||
description="Sdk for using MegaSearch",
|
||||
packages=setuptools.find_packages(),
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3.4",
|
||||
"Programming Language :: Python :: 3.5",
|
||||
"Programming Language :: Python :: 3.6",
|
||||
"Operating System :: OS Independent",
|
||||
],
|
||||
|
||||
|
||||
python_requires='>=3.4'
|
||||
)
|
|
@ -0,0 +1,281 @@
|
|||
import logging
|
||||
import pytest
|
||||
import mock
|
||||
import faker
|
||||
import random
|
||||
import struct
|
||||
from faker.providers import BaseProvider
|
||||
|
||||
from client.Client import MegaSearch, Prepare
|
||||
from client.Abstract import IndexType, TableSchema
|
||||
from client.Status import Status
|
||||
from client.Exceptions import (
|
||||
RepeatingConnectError,
|
||||
DisconnectNotConnectedClientError
|
||||
)
|
||||
from megasearch.thrift import ttypes, MegasearchService
|
||||
|
||||
from thrift.transport.TSocket import TSocket
|
||||
from thrift.transport import TTransport
|
||||
from thrift.transport.TTransport import TTransportException
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FakerProvider(BaseProvider):
|
||||
|
||||
def table_name(self):
|
||||
return 'table_name' + str(random.randint(1000, 9999))
|
||||
|
||||
def name(self):
|
||||
return 'name' + str(random.randint(1000, 9999))
|
||||
|
||||
def dim(self):
|
||||
return random.randint(0, 999)
|
||||
|
||||
|
||||
fake = faker.Faker()
|
||||
fake.add_provider(FakerProvider)
|
||||
|
||||
|
||||
def range_factory():
|
||||
param = {
|
||||
'start': str(random.randint(1, 10)),
|
||||
'end': str(random.randint(11, 20)),
|
||||
}
|
||||
return Prepare.range(**param)
|
||||
|
||||
|
||||
def ranges_factory():
|
||||
return [range_factory() for _ in range(5)]
|
||||
|
||||
|
||||
def table_schema_factory():
|
||||
param = {
|
||||
'table_name': fake.table_name(),
|
||||
'dimension': random.randint(0, 999),
|
||||
'index_type': IndexType.IDMAP,
|
||||
'store_raw_vector': False
|
||||
}
|
||||
return Prepare.table_schema(**param)
|
||||
|
||||
|
||||
def row_record_factory(dimension):
|
||||
vec = [random.random() + random.randint(0,9) for _ in range(dimension)]
|
||||
bin_vec = struct.pack(str(dimension) + "d", *vec)
|
||||
|
||||
return Prepare.row_record(vector_data=bin_vec)
|
||||
|
||||
|
||||
def row_records_factory(dimension):
|
||||
return [row_record_factory(dimension) for _ in range(20)]
|
||||
|
||||
|
||||
class TestConnection:
|
||||
param = {'host':'localhost', 'port': '5000'}
|
||||
|
||||
@mock.patch.object(TSocket, 'open')
|
||||
def test_true_connect(self, open):
|
||||
open.return_value = None
|
||||
cnn = MegaSearch()
|
||||
|
||||
cnn.connect(**self.param)
|
||||
assert cnn.status == Status.SUCCESS
|
||||
assert cnn.connected
|
||||
|
||||
with pytest.raises(RepeatingConnectError):
|
||||
cnn.connect(**self.param)
|
||||
cnn.connect()
|
||||
|
||||
def test_false_connect(self):
|
||||
cnn = MegaSearch()
|
||||
|
||||
cnn.connect(**self.param)
|
||||
assert cnn.status != Status.SUCCESS
|
||||
|
||||
@mock.patch.object(TTransport.TBufferedTransport, 'close')
|
||||
@mock.patch.object(TSocket, 'open')
|
||||
def test_disconnected(self, close, open):
|
||||
close.return_value = None
|
||||
open.return_value = None
|
||||
|
||||
cnn = MegaSearch()
|
||||
cnn.connect(**self.param)
|
||||
|
||||
assert cnn.disconnect() == Status.SUCCESS
|
||||
|
||||
def test_disconnected_error(self):
|
||||
cnn = MegaSearch()
|
||||
cnn.connect_status = Status(Status.PERMISSION_DENIED)
|
||||
with pytest.raises(DisconnectNotConnectedClientError):
|
||||
cnn.disconnect()
|
||||
|
||||
|
||||
class TestTable:
|
||||
|
||||
@pytest.fixture
|
||||
@mock.patch.object(TSocket, 'open')
|
||||
def client(self, open):
|
||||
param = {'host': 'localhost', 'port': '5000'}
|
||||
open.return_value = None
|
||||
|
||||
cnn = MegaSearch()
|
||||
cnn.connect(**param)
|
||||
return cnn
|
||||
|
||||
@mock.patch.object(MegasearchService.Client, 'CreateTable')
|
||||
def test_create_table(self, CreateTable, client):
|
||||
CreateTable.return_value = None
|
||||
|
||||
param = table_schema_factory()
|
||||
res = client.create_table(param)
|
||||
assert res == Status.SUCCESS
|
||||
|
||||
def test_false_create_table(self, client):
|
||||
param = table_schema_factory()
|
||||
with pytest.raises(TTransportException):
|
||||
res = client.create_table(param)
|
||||
LOGGER.error('{}'.format(res))
|
||||
assert res != Status.SUCCESS
|
||||
|
||||
@mock.patch.object(MegasearchService.Client, 'DeleteTable')
|
||||
def test_delete_table(self, DeleteTable, client):
|
||||
DeleteTable.return_value = None
|
||||
table_name = 'fake_table_name'
|
||||
res = client.delete_table(table_name)
|
||||
assert res == Status.SUCCESS
|
||||
|
||||
def test_false_delete_table(self, client):
|
||||
table_name = 'fake_table_name'
|
||||
res = client.delete_table(table_name)
|
||||
assert res != Status.SUCCESS
|
||||
|
||||
|
||||
class TestVector:
|
||||
|
||||
@pytest.fixture
|
||||
@mock.patch.object(TSocket, 'open')
|
||||
def client(self, open):
|
||||
param = {'host': 'localhost', 'port': '5000'}
|
||||
open.return_value = None
|
||||
|
||||
cnn = MegaSearch()
|
||||
cnn.connect(**param)
|
||||
return cnn
|
||||
|
||||
@mock.patch.object(MegasearchService.Client, 'AddVector')
|
||||
def test_add_vector(self, AddVector, client):
|
||||
AddVector.return_value = None
|
||||
|
||||
param ={
|
||||
'table_name': fake.table_name(),
|
||||
'records': row_records_factory(256)
|
||||
}
|
||||
res, ids = client.add_vectors(**param)
|
||||
assert res == Status.SUCCESS
|
||||
|
||||
def test_false_add_vector(self, client):
|
||||
param ={
|
||||
'table_name': fake.table_name(),
|
||||
'records': row_records_factory(256)
|
||||
}
|
||||
res, ids = client.add_vectors(**param)
|
||||
assert res != Status.SUCCESS
|
||||
|
||||
@mock.patch.object(MegasearchService.Client, 'SearchVector')
|
||||
def test_search_vector(self, SearchVector, client):
|
||||
SearchVector.return_value = None, None
|
||||
param = {
|
||||
'table_name': fake.table_name(),
|
||||
'query_records': row_records_factory(256),
|
||||
'query_ranges': ranges_factory(),
|
||||
'top_k': random.randint(0, 10)
|
||||
}
|
||||
res, results = client.search_vectors(**param)
|
||||
assert res == Status.SUCCESS
|
||||
|
||||
def test_false_vector(self, client):
|
||||
param = {
|
||||
'table_name': fake.table_name(),
|
||||
'query_records': row_records_factory(256),
|
||||
'query_ranges': ranges_factory(),
|
||||
'top_k': random.randint(0, 10)
|
||||
}
|
||||
res, results = client.search_vectors(**param)
|
||||
assert res != Status.SUCCESS
|
||||
|
||||
@mock.patch.object(MegasearchService.Client, 'DescribeTable')
|
||||
def test_describe_table(self, DescribeTable, client):
|
||||
DescribeTable.return_value = table_schema_factory()
|
||||
|
||||
table_name = fake.table_name()
|
||||
res, table_schema = client.describe_table(table_name)
|
||||
assert res == Status.SUCCESS
|
||||
assert isinstance(table_schema, ttypes.TableSchema)
|
||||
|
||||
def test_false_decribe_table(self, client):
|
||||
table_name = fake.table_name()
|
||||
res, table_schema = client.describe_table(table_name)
|
||||
assert res != Status.SUCCESS
|
||||
assert not table_schema
|
||||
|
||||
@mock.patch.object(MegasearchService.Client, 'ShowTables')
|
||||
def test_show_tables(self, ShowTables, client):
|
||||
ShowTables.return_value = [fake.table_name() for _ in range(10)], None
|
||||
res, tables = client.show_tables()
|
||||
assert res == Status.SUCCESS
|
||||
assert isinstance(tables, list)
|
||||
|
||||
def test_false_show_tables(self, client):
|
||||
res, tables = client.show_tables()
|
||||
assert res != Status.SUCCESS
|
||||
assert not tables
|
||||
|
||||
@mock.patch.object(MegasearchService.Client, 'GetTableRowCount')
|
||||
def test_get_table_row_count(self, GetTableRowCount, client):
|
||||
GetTableRowCount.return_value = 22, None
|
||||
res, count = client.get_table_row_count('fake_table')
|
||||
assert res == Status.SUCCESS
|
||||
|
||||
def test_false_get_table_row_count(self, client):
|
||||
res,count = client.get_table_row_count('fake_table')
|
||||
assert res != Status.SUCCESS
|
||||
assert not count
|
||||
|
||||
def test_client_version(self, client):
|
||||
res = client.client_version()
|
||||
assert res == '0.0.1'
|
||||
|
||||
|
||||
class TestPrepare:
|
||||
|
||||
def test_table_schema(self):
|
||||
|
||||
param = {
|
||||
'table_name': fake.table_name(),
|
||||
'dimension': random.randint(0, 999),
|
||||
'index_type': IndexType.IDMAP,
|
||||
'store_raw_vector': False
|
||||
}
|
||||
res = Prepare.table_schema(**param)
|
||||
assert isinstance(res, ttypes.TableSchema)
|
||||
|
||||
def test_range(self):
|
||||
param = {
|
||||
'start': '200',
|
||||
'end': '1000'
|
||||
}
|
||||
|
||||
res = Prepare.range(**param)
|
||||
LOGGER.error('{}'.format(res))
|
||||
assert isinstance(res, ttypes.Range)
|
||||
assert res.start_value == '200'
|
||||
assert res.end_value == '1000'
|
||||
|
||||
def test_row_record(self):
|
||||
vec = [random.random() + random.randint(0, 9) for _ in range(256)]
|
||||
bin_vec = struct.pack(str(256) + "d", *vec)
|
||||
res = Prepare.row_record(bin_vec)
|
||||
assert isinstance(res, ttypes.RowRecord)
|
||||
assert isinstance(bin_vec, bytes)
|
||||
|
Loading…
Reference in New Issue