import logging
import random
import struct

import faker
import mock
import pytest
from faker.providers import BaseProvider
from thrift.transport import TTransport
from thrift.transport.TSocket import TSocket
from thrift.transport.TTransport import TTransportException

from client.Abstract import IndexType, TableSchema
from client.Client import MegaSearch, Prepare
from client.Exceptions import (
    DisconnectNotConnectedClientError,
    RepeatingConnectError,
)
from client.Status import Status
from megasearch.thrift import ttypes, MegasearchService
# Logger for this test module, named after the module itself.
LOGGER = logging.getLogger(name=__name__)
class FakerProvider(BaseProvider):
    """Faker provider producing random table names, names and dimensions."""

    def table_name(self):
        """Return ``'table_name'`` followed by a random number in 1000-9999."""
        suffix = random.randint(1000, 9999)
        return 'table_name{}'.format(suffix)

    def name(self):
        """Return ``'name'`` followed by a random number in 1000-9999."""
        suffix = random.randint(1000, 9999)
        return 'name{}'.format(suffix)

    def dim(self):
        """Return a random integer dimension in the inclusive range 0-999."""
        lowest, highest = 0, 999
        return random.randint(lowest, highest)
# Shared Faker instance used by the factories and tests below.
fake = faker.Faker()
# Register the custom provider; without this, ``fake.table_name()`` (used
# throughout this module) raises AttributeError because Faker has no such
# method built in.
fake.add_provider(FakerProvider)
def range_factory():
    """Return a random query range built with ``Prepare.range``.

    ``start`` is drawn from 1-10 and ``end`` from 11-20, so ``start < end``
    always holds. Both bounds are passed to ``Prepare.range`` as strings.
    """
    param = {
        'start': str(random.randint(1, 10)),
        'end': str(random.randint(11, 20)),
    }
    return Prepare.range(**param)
def ranges_factory():
    """Return a list of five independently generated random query ranges."""
    count = 5
    return [range_factory() for _ in range(count)]
def table_schema_factory():
    """Return a random table schema built with ``Prepare.table_schema``.

    Uses a Faker-generated table name, a random dimension in 0-999, the
    ``IndexType.IDMAP`` index type and ``store_raw_vector=False``.
    """
    param = {
        'table_name': fake.table_name(),
        'dimension': random.randint(0, 999),
        'index_type': IndexType.IDMAP,
        'store_raw_vector': False,
    }
    return Prepare.table_schema(**param)
def row_record_factory(dimension):
    """Build a row record holding ``dimension`` random doubles.

    Each component is a random integer 0-9 plus a random fraction in
    ``[0, 1)``. The vector is packed with ``struct`` as native-order C doubles
    and handed to ``Prepare.row_record`` as ``vector_data``.
    """
    fmt = '{}d'.format(dimension)
    packed = struct.pack(
        fmt,
        *(random.random() + random.randint(0, 9) for _ in range(dimension))
    )
    return Prepare.row_record(vector_data=packed)
def row_records_factory(dimension):
    """Return a batch of 20 random row records, each ``dimension`` wide."""
    batch_size = 20
    return [row_record_factory(dimension) for _ in range(batch_size)]
class TestConnection:
    """Connection lifecycle tests for ``MegaSearch``.

    Where a successful connection is required, ``TSocket.open`` is mocked so
    no server is needed.
    """

    # Address handed to ``MegaSearch.connect`` in the mocked tests.
    param = {'host': 'localhost', 'port': '5000'}

    @mock.patch.object(TSocket, 'open')
    def test_true_connect(self, mock_open):
        """A mocked socket connects successfully; a second connect is rejected."""
        mock_open.return_value = None
        cnn = MegaSearch()
        cnn.connect(**self.param)

        assert cnn.status == Status.SUCCESS
        assert cnn.connected

        # Connecting an already-connected client must raise.
        with pytest.raises(RepeatingConnectError):
            cnn.connect(**self.param)

    def test_false_connect(self):
        """A freshly constructed, never-connected client is not in SUCCESS."""
        cnn = MegaSearch()
        assert cnn.status != Status.SUCCESS

    # Stacked patches are applied bottom-up, so the bottom decorator
    # (TSocket.open) supplies the first mock argument.
    @mock.patch.object(TTransport.TBufferedTransport, 'close')
    @mock.patch.object(TSocket, 'open')
    def test_disconnected(self, mock_open, mock_close):
        """Disconnecting a (mock-)connected client reports SUCCESS."""
        mock_open.return_value = None
        mock_close.return_value = None
        cnn = MegaSearch()
        cnn.connect(**self.param)

        assert cnn.disconnect() == Status.SUCCESS

    def test_disconnected_error(self):
        """Disconnecting a client that is not connected raises."""
        cnn = MegaSearch()
        # NOTE(review): the other tests read ``cnn.status``; confirm that
        # ``connect_status`` is the attribute ``disconnect()`` actually checks.
        cnn.connect_status = Status(Status.PERMISSION_DENIED)

        with pytest.raises(DisconnectNotConnectedClientError):
            cnn.disconnect()
class TestTable:
    """Table create/delete calls on a ``MegaSearch`` client.

    Positive tests mock the matching ``MegasearchService.Client`` RPC. The
    ``test_false_*`` variants leave the RPC unmocked, so the call fails.
    """

    @pytest.fixture
    @mock.patch.object(TSocket, 'open')
    def client(self, mock_open):
        """Return a ``MegaSearch`` client connected through a mocked socket."""
        param = {'host': 'localhost', 'port': '5000'}
        mock_open.return_value = None
        cnn = MegaSearch()
        cnn.connect(**param)
        return cnn

    @mock.patch.object(MegasearchService.Client, 'CreateTable')
    def test_create_table(self, mock_create_table, client):
        mock_create_table.return_value = None
        param = table_schema_factory()
        res = client.create_table(param)
        assert res == Status.SUCCESS

    def test_false_create_table(self, client):
        """Without a mocked RPC, ``create_table`` propagates the transport error."""
        param = table_schema_factory()
        # Only the raising call belongs inside the block; any assertion after
        # it would never execute.
        with pytest.raises(TTransportException):
            client.create_table(param)

    @mock.patch.object(MegasearchService.Client, 'DeleteTable')
    def test_delete_table(self, mock_delete_table, client):
        mock_delete_table.return_value = None
        table_name = 'fake_table_name'
        res = client.delete_table(table_name)
        assert res == Status.SUCCESS

    def test_false_delete_table(self, client):
        table_name = 'fake_table_name'
        res = client.delete_table(table_name)
        assert res != Status.SUCCESS
class TestVector:
    """Vector, describe/show-table, row-count and version calls on a client.

    Positive tests mock the matching ``MegasearchService.Client`` RPC. The
    ``test_false_*`` variants leave it unmocked and expect a non-SUCCESS status.
    """

    @pytest.fixture
    @mock.patch.object(TSocket, 'open')
    def client(self, mock_open):
        """Return a ``MegaSearch`` client connected through a mocked socket."""
        param = {'host': 'localhost', 'port': '5000'}
        mock_open.return_value = None
        cnn = MegaSearch()
        cnn.connect(**param)
        return cnn

    @mock.patch.object(MegasearchService.Client, 'AddVector')
    def test_add_vector(self, mock_add_vector, client):
        mock_add_vector.return_value = None
        param = {
            'table_name': fake.table_name(),
            'records': row_records_factory(256),
        }
        res, ids = client.add_vectors(**param)
        assert res == Status.SUCCESS

    def test_false_add_vector(self, client):
        param = {
            'table_name': fake.table_name(),
            'records': row_records_factory(256),
        }
        res, ids = client.add_vectors(**param)
        assert res != Status.SUCCESS

    @mock.patch.object(MegasearchService.Client, 'SearchVector')
    def test_search_vector(self, mock_search_vector, client):
        mock_search_vector.return_value = None, None
        param = {
            'table_name': fake.table_name(),
            'query_records': row_records_factory(256),
            'query_ranges': ranges_factory(),
            'top_k': random.randint(0, 10),
        }
        res, results = client.search_vectors(**param)
        assert res == Status.SUCCESS

    def test_false_vector(self, client):
        param = {
            'table_name': fake.table_name(),
            'query_records': row_records_factory(256),
            'query_ranges': ranges_factory(),
            'top_k': random.randint(0, 10),
        }
        res, results = client.search_vectors(**param)
        assert res != Status.SUCCESS

    @mock.patch.object(MegasearchService.Client, 'DescribeTable')
    def test_describe_table(self, mock_describe_table, client):
        mock_describe_table.return_value = table_schema_factory()
        table_name = fake.table_name()
        res, table_schema = client.describe_table(table_name)
        assert res == Status.SUCCESS
        assert isinstance(table_schema, TableSchema)

    def test_false_decribe_table(self, client):
        table_name = fake.table_name()
        res, table_schema = client.describe_table(table_name)
        assert res != Status.SUCCESS
        assert not table_schema

    @mock.patch.object(MegasearchService.Client, 'ShowTables')
    def test_show_tables(self, mock_show_tables, client):
        mock_show_tables.return_value = (
            [fake.table_name() for _ in range(10)], None
        )
        res, tables = client.show_tables()
        assert res == Status.SUCCESS
        assert isinstance(tables, list)

    def test_false_show_tables(self, client):
        res, tables = client.show_tables()
        assert res != Status.SUCCESS
        assert not tables

    @mock.patch.object(MegasearchService.Client, 'GetTableRowCount')
    def test_get_table_row_count(self, mock_get_table_row_count, client):
        mock_get_table_row_count.return_value = 22, None
        res, count = client.get_table_row_count('fake_table')
        assert res == Status.SUCCESS

    def test_false_get_table_row_count(self, client):
        res, count = client.get_table_row_count('fake_table')
        assert res != Status.SUCCESS
        assert not count

    def test_client_version(self, client):
        res = client.client_version()
        assert res == '0.0.1'
class TestPrepare:
    """Unit tests for the ``Prepare`` helpers that build Thrift ``ttypes`` objects."""

    def test_table_schema(self):
        param = {
            'table_name': fake.table_name(),
            'dimension': random.randint(0, 999),
            'index_type': IndexType.IDMAP,
            'store_raw_vector': False,
        }
        res = Prepare.table_schema(**param)
        assert isinstance(res, ttypes.TableSchema)

    def test_range(self):
        param = {
            'start': '200',
            'end': '1000',
        }
        res = Prepare.range(**param)
        assert isinstance(res, ttypes.Range)
        assert res.start_value == '200'
        assert res.end_value == '1000'

    def test_row_record(self):
        # 256 doubles, each a random int 0-9 plus a fraction in [0, 1).
        vec = [random.random() + random.randint(0, 9) for _ in range(256)]
        bin_vec = struct.pack(str(256) + "d", *vec)
        res = Prepare.row_record(bin_vec)
        assert isinstance(res, ttypes.RowRecord)
        # NOTE(review): this checks the packed input, not ``res``; consider
        # asserting the record's stored vector field instead once its
        # attribute name on ``ttypes.RowRecord`` is confirmed.
        assert isinstance(bin_vec, bytes)