milvus/python/sdk/client/Client.py

307 lines
9.5 KiB
Python

import logging, logging.config
from thrift.transport import TSocket
from thrift.transport import TTransport
from thrift.transport.TTransport import TTransportException
from thrift.protocol import TBinaryProtocol, TCompactProtocol, TJSONProtocol
from thrift.Thrift import TException, TApplicationException, TType
from megasearch.thrift import MegasearchService
from megasearch.thrift import ttypes
from client.Abstract import (
ConnectIntf,
TableSchema,
IndexType,
Range,
RowRecord,
QueryResult,
TopKQueryResult
)
from client.Status import Status
from client.Exceptions import (
RepeatingConnectError,
DisconnectNotConnectedClientError,
NotConnectError
)
# Module-level logger, named after this module; used by the client methods
# below to report connection events and failed RPCs.
LOGGER = logging.getLogger(__name__)
# Version string of this client SDK; MegaSearch.client_version() returns it.
__VERSION__ = '0.0.1'
# Name tag for this client implementation (not referenced in the visible
# part of this file).
__NAME__ = 'Thrift_Client'
class Prepare(object):
    """Builders for the Thrift structs that the MegaSearch client methods take.

    Every builder first instantiates the matching ``client.Abstract`` class
    with the caller's arguments (presumably so that class can validate them;
    confirm against client.Abstract) and then returns the corresponding
    ``ttypes`` struct.
    """

    @classmethod
    def table_schema(cls,
                     table_name, *,
                     dimension,
                     index_type,
                     store_raw_vector):
        """Build the Thrift table schema passed to ``create_table``.

        All arguments after ``table_name`` are keyword-only and none of them
        has a default value.

        :param table_name: str, (Required) name of table
        :param dimension: int64, (Required) dimension of the table
        :param index_type: IndexType, (Required) index type
        :param store_raw_vector: bool, (Required) whether raw vectors are kept
        :return: ttypes.TableSchema
        """
        checked = TableSchema(table_name, dimension, index_type, store_raw_vector)
        # Only the table name is read back from the checked object; the other
        # three fields are forwarded exactly as the caller supplied them.
        fields = dict(
            table_name=checked.table_name,
            dimension=dimension,
            index_type=index_type,
            store_raw_vector=store_raw_vector,
        )
        return ttypes.TableSchema(**fields)

    @classmethod
    def range(cls, start, end):
        """Build a Thrift range from its two bounds.

        :param start: str, (Required) range start
        :param end: str, (Required) range end
        :return: ttypes.Range
        """
        checked = Range(start=start, end=end)
        start_value, end_value = checked.start, checked.end
        return ttypes.Range(start_value=start_value, end_value=end_value)

    @classmethod
    def row_record(cls, vector_data):
        """Build the Thrift record for one vector to be inserted.

        :param vector_data: float binary str, (Required) a binary str
        :return: ttypes.RowRecord
        """
        return ttypes.RowRecord(vector_data=RowRecord(vector_data).vector_data)
class MegaSearch(ConnectIntf):
    """Thrift client for a Megasearch server.

    Most operations report their outcome as a ``Status`` object (plus a
    payload where applicable) instead of raising on RPC failure. Calling an
    operation without a live client raises ``NotConnectError``.
    """

    def __init__(self):
        # Status of the last connect attempt; None before the first connect
        # and after a successful disconnect.
        self.status = None
        # Buffered Thrift transport; set only while a connection is held.
        self._transport = None
        # Generated Thrift service client; set only while a connection is held.
        self._client = None

    def __repr__(self):
        return '{}'.format(self.status)

    def _require_client(self):
        """Return the Thrift client, or raise if there is none.

        :raises NotConnectError: if no client has been set up by ``connect``
        """
        if not self._client:
            raise NotConnectError('Please Connect to the server first!')
        return self._client

    def connect(self, host='localhost', port='9090', uri=None):
        """Open a connection to the server.

        :param host: str, server host
        :param port: server port
        :param uri: currently ignored
        :return: Status, SUCCESS when connected, CONNECT_FAILED otherwise
        :raises RepeatingConnectError: if the client is already connected
        """
        # TODO URI
        if self.status and self.status == Status.SUCCESS:
            raise RepeatingConnectError("You have already connected!")

        sock = TSocket.TSocket(host=host, port=port)
        buffered = TTransport.TBufferedTransport(sock)
        # The protocol must sit on the buffered transport; building it on the
        # raw socket would bypass the buffering entirely.
        protocol = TBinaryProtocol.TBinaryProtocol(buffered)

        try:
            buffered.open()
        except TException as e:  # TTransportException is a TException subclass
            # Do not keep a half-initialised client around: later calls must
            # see this instance as not connected.
            self._transport = None
            self._client = None
            self.status = Status(Status.CONNECT_FAILED, message=str(e))
            LOGGER.error('logger.error: {}'.format(self.status))
            return self.status

        self._transport = buffered
        self._client = MegasearchService.Client(protocol)
        self.status = Status(Status.SUCCESS, 'Connected')
        LOGGER.info('Connected!')
        # Returned here rather than from a ``finally`` clause, which would
        # silently swallow any unexpected (non-Thrift) exception.
        return self.status

    @property
    def connected(self):
        """Whether the last connect attempt left the client in SUCCESS state."""
        return self.status == Status.SUCCESS

    def disconnect(self):
        """Close the connection to the server.

        :return: Status, SUCCESS when closed, PERMISSION_DENIED if closing
            the transport raised a Thrift error
        :raises DisconnectNotConnectedClientError: if there is no transport
            to close
        """
        if not self._transport:
            raise DisconnectNotConnectedClientError('Error')

        try:
            self._transport.close()
        except TException as e:
            return Status(Status.PERMISSION_DENIED, str(e))

        LOGGER.info('Client Disconnected!')
        # Drop every piece of connection state so that later calls (including
        # a second disconnect) correctly see the client as not connected.
        self.status = None
        self._transport = None
        self._client = None
        return Status(Status.SUCCESS, 'Disconnected')

    def create_table(self, param):
        """Create table

        :param param: Provide table information to be created,
            `Please use Prepare.table_schema generate param`
        :return: Status, indicate if operation is successful
        :raises NotConnectError: if the client is not connected
        """
        client = self._require_client()

        try:
            client.CreateTable(param)
        # NOTE(review): unlike the other operations, only
        # TApplicationException is handled here, so other Thrift errors
        # propagate to the caller. Kept as-is to preserve what callers see.
        except TApplicationException as e:
            LOGGER.error('Unable to create table')
            return Status(Status.PERMISSION_DENIED, str(e))
        return Status(message='Table {} created!'.format(param.table_name))

    def delete_table(self, table_name):
        """Delete table

        :param table_name: Name of the table being deleted
        :return: Status, indicate if operation is successful
        :raises NotConnectError: if the client is not connected
        """
        client = self._require_client()

        try:
            client.DeleteTable(table_name)
        except TException as e:  # also covers TApplicationException
            LOGGER.error('Unable to delete table {}'.format(table_name))
            return Status(Status.PERMISSION_DENIED, str(e))
        return Status(message='Table {} deleted!'.format(table_name))

    def add_vectors(self, table_name, records):
        """Add vectors to table

        :param table_name: table name been inserted
        :param records: List[RowRecord], list of vectors been inserted
            `Please use Prepare.row_record generate records`
        :returns:
            Status: indicate if vectors inserted successfully
            ids: list of id, after inserted every vector is given a id;
                None when the operation failed
        :raises NotConnectError: if the client is not connected
        """
        client = self._require_client()

        try:
            ids = client.AddVector(table_name=table_name, record_array=records)
        except TException as e:  # also covers TApplicationException
            LOGGER.error('{}'.format(e))
            return Status(Status.PERMISSION_DENIED, str(e)), None
        return Status(message='Vectors added successfully!'), ids

    def search_vectors(self, table_name, top_k, query_records, query_ranges=None):
        """Query vectors in a table

        :param table_name: str, table name been queried
        :param top_k: int, how many similar vectors will be searched
        :param query_records: list of records, all vectors going to be queried
            (Prepare.row_record is the only record builder visible in this
            module; earlier docs referred to a Prepare.query_record - TODO
            confirm which one callers should use)
        :param query_ranges: (Optional) list[Range], search range
        :returns:
            Status: indicate if query is successful
            res: list[TopKQueryResult], return when operation is successful;
                None when the operation failed
        :raises NotConnectError: if the client is not connected
        """
        client = self._require_client()

        res = []
        try:
            top_k_query_results = client.SearchVector(
                table_name=table_name,
                query_record_array=query_records,
                query_range_array=query_ranges,
                topk=top_k)

            if top_k_query_results:
                # One TopKQueryResult per returned entry. The loop variable
                # deliberately does not reuse the name of the ``top_k``
                # parameter.
                for entry in top_k_query_results:
                    res.append(TopKQueryResult([QueryResult(qr.id, qr.score)
                                                for qr in entry.query_result_arrays]))
        except TException as e:  # also covers TApplicationException
            LOGGER.error('{}'.format(e))
            return Status(Status.PERMISSION_DENIED, str(e)), None
        return Status(message='Success!'), res

    def describe_table(self, table_name):
        """Show table information

        :param table_name: str, which table to be shown
        :returns:
            Status: indicate if query is successful
            table_schema: the object returned by the server's DescribeTable
                call, passed through unconverted; None when the operation
                failed
        :raises NotConnectError: if the client is not connected
        """
        client = self._require_client()

        try:
            schema = client.DescribeTable(table_name)
        except TException as e:  # also covers TApplicationException
            LOGGER.error('{}'.format(e))
            return Status(Status.PERMISSION_DENIED, str(e)), None
        return Status(message='Success!'), schema

    def show_tables(self):
        """Show all tables in database

        :return:
            Status: indicate if this operation is successful
            tables: list[str], list of table names, return when operation
                is successful; None when the operation failed
        :raises NotConnectError: if the client is not connected
        """
        client = self._require_client()

        try:
            res = client.ShowTables()
            tables = []
            if res:
                # NOTE(review): this unpacks the reply as a 2-item sequence
                # and keeps the first item. If ShowTables returns a plain
                # list of names, this raises ValueError unless the list has
                # exactly two entries - confirm against the Thrift IDL.
                tables, _ = res
        except TException as e:  # also covers TApplicationException
            LOGGER.error('{}'.format(e))
            return Status(Status.PERMISSION_DENIED, str(e)), None
        return Status(message='Success!'), tables

    def get_table_row_count(self, table_name):
        """Get table row count

        :type table_name: str
        :param table_name: target table name.
        :returns:
            Status: indicate if operation is successful
            res: int, table row count; None when the operation failed
        :raises NotConnectError: if the client is not connected
        """
        client = self._require_client()

        try:
            # NOTE(review): the reply is unpacked as a 2-item sequence; if
            # GetTableRowCount returns a bare integer this raises TypeError -
            # confirm against the Thrift IDL.
            count, _ = client.GetTableRowCount(table_name)
        except TException as e:  # also covers TApplicationException
            LOGGER.error('{}'.format(e))
            return Status(Status.PERMISSION_DENIED, str(e)), None
        return Status(message='Success'), count

    def client_version(self):
        """Provide client version

        :return: Client version
        """
        return __VERSION__

    def server_version(self):
        """Provide server version

        :return: Server version, the reply to the ``Ping('version')`` call
        :raises NotConnectError: if the client is not connected
        """
        if not self.connected:
            raise NotConnectError('You have to connect first')

        return self._client.Ping('version')

    def server_status(self, cmd=None):
        """Provide server status

        :param cmd: command forwarded unchanged to the server's Ping call
        :return: Server status, the reply to the ``Ping(cmd)`` call
        :raises NotConnectError: if the client is not connected
        """
        if not self.connected:
            raise NotConnectError('You have to connect first')

        return self._client.Ping(cmd)