modify shards for v0.5.3

pull/422/head
yhz 2019-11-19 17:37:13 +08:00
parent 660953afa2
commit 83d9bf6966
2 changed files with 89 additions and 28 deletions


@@ -2,6 +2,7 @@ import logging
 import threading
 from functools import wraps
 from milvus import Milvus
+from milvus.client.hooks import BaseaSearchHook

 from mishards import (settings, exceptions)
 from utils import singleton
@@ -9,6 +10,12 @@ from utils import singleton
 logger = logging.getLogger(__name__)


+class Searchook(BaseaSearchHook):
+
+    def on_response(self, *args, **kwargs):
+        return True
+
+
 class Connection:
     def __init__(self, name, uri, max_retry=1, error_handlers=None, **kwargs):
         self.name = name
@@ -18,6 +25,9 @@ class Connection:
         self.conn = Milvus()
         self.error_handlers = [] if not error_handlers else error_handlers
         self.on_retry_func = kwargs.get('on_retry_func', None)
+
+        # define search hook
+        self.conn._set_hook(search_in_file=Searchook())
         # self._connect()

     def __str__(self):
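
Review note: the Searchook above is the pivot of this whole change. Its on_response returning True appears to make the client hand search_in_file responses back raw (flat ids/distances protobufs) rather than parsing them into TopKQueryResult wrappers, which is exactly what the new merge path in the service handler consumes. A usage sketch under that assumption (the name and uri values are illustrative):

# Hypothetical usage, assuming the hook semantics described above:
conn = Connection(name='shard-0', uri='tcp://127.0.0.1:19530')
# After __init__, conn.conn dispatches search_in_file through Searchook,
# so search_vectors_in_files results arrive as raw protobufs carrying
# flat `ids` and `distances` arrays plus `row_num`.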


@@ -29,39 +29,88 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
         self.router = router
         self.max_workers = max_workers

+    def _reduce(self, source_ids, ids, source_diss, diss, k, reverse):
+        # Both lists are sorted ascending by distance. If one side's worst
+        # kept distance already beats the other side's best, skip the merge.
+        if source_diss[k - 1] <= diss[0]:
+            return source_ids, source_diss
+        if diss[k - 1] <= source_diss[0]:
+            return ids, diss
+
+        # extend() mutates in place and returns None, so it cannot be
+        # nested inside enumerate(); extend first, then enumerate.
+        source_diss.extend(diss)
+        diss_t = enumerate(source_diss)
+        diss_m_rst = sorted(diss_t, key=lambda x: x[1])[:k]
+        diss_m_out = [dis for _, dis in diss_m_rst]
+
+        # Pick the ids at the same positions as the k smallest distances.
+        source_ids.extend(ids)
+        id_m_out = [source_ids[i] for i, _ in diss_m_rst]
+
+        return id_m_out, diss_m_out
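
Review note: a quick worked example of _reduce with invented values (distances ascending, k=3):

source_ids, source_diss = [7, 2, 9], [0.10, 0.20, 0.30]
ids, diss = [4, 8, 1], [0.15, 0.40, 0.50]
# Neither early exit fires (0.30 > 0.15 and 0.50 > 0.10), so the lists
# are concatenated and the 3 smallest distances are kept:
# -> ids [7, 4, 2], distances [0.10, 0.15, 0.20]
# Also worth noting: reverse is accepted but unused; the early exits
# assume ascending order.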
     def _do_merge(self, files_n_topk_results, topk, reverse=False, **kwargs):
         status = status_pb2.Status(error_code=status_pb2.SUCCESS,
                                    reason="Success")
         if not files_n_topk_results:
             return status, []

-        request_results = defaultdict(list)
+        # request_results = defaultdict(list)
+        # row_num = files_n_topk_results[0].row_num
+        merge_id_results = []
+        merge_dis_results = []

         calc_time = time.time()
         for files_collection in files_n_topk_results:
             if isinstance(files_collection, tuple):
                 status, _ = files_collection
                 return status, []
-            for request_pos, each_request_results in enumerate(
-                    files_collection.topk_query_result):
-                request_results[request_pos].extend(
-                    each_request_results.query_result_arrays)
-                request_results[request_pos] = sorted(
-                    request_results[request_pos],
-                    key=lambda x: x.distance,
-                    reverse=reverse)[:topk]
+
+            row_num = files_collection.row_num
+            ids = files_collection.ids
+            diss = files_collection.distances  # distance collections
+            batch_len = len(ids) // row_num
+
+            for row_index in range(row_num):
+                id_batch = ids[row_index * batch_len: (row_index + 1) * batch_len]
+                dis_batch = diss[row_index * batch_len: (row_index + 1) * batch_len]
+
+                if len(merge_id_results) < row_index:
+                    raise ValueError("merge error")
+                elif len(merge_id_results) == row_index:
+                    # TODO: may bug here
+                    merge_id_results.append(id_batch)
+                    merge_dis_results.append(dis_batch)
+                else:
+                    # _reduce(_ids, _diss, k, reverse)
+                    merge_id_results[row_index], merge_dis_results[row_index] = \
+                        self._reduce(merge_id_results[row_index], id_batch,
+                                     merge_dis_results[row_index], dis_batch,
+                                     batch_len,
+                                     reverse)

         calc_time = time.time() - calc_time
         logger.info('Merge takes {}'.format(calc_time))

-        results = sorted(request_results.items())
-        topk_query_result = []
+        # results = sorted(request_results.items())
+        id_mrege_list = []
+        dis_mrege_list = []

-        for result in results:
-            query_result = TopKQueryResult(query_result_arrays=result[1])
-            topk_query_result.append(query_result)
+        for id_results, dis_results in zip(merge_id_results, merge_dis_results):
+            id_mrege_list.extend(id_results)
+            dis_mrege_list.extend(dis_results)

-        return status, topk_query_result
+        # for result in results:
+        #     query_result = TopKQueryResult(query_result_arrays=result[1])
+        #     topk_query_result.append(query_result)
+
+        return status, id_mrege_list, dis_mrege_list

     def _do_query(self,
                   context,
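
Review note: the merge assumes each files_collection packs the results of row_num queries back to back in flat ids/distances arrays. A small illustration (values invented):

row_num = 3                                  # queries in the batch
ids = [11, 12, 21, 22, 31, 32]               # 2 hits per query, row-major
distances = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
batch_len = len(ids) // row_num              # -> 2
# row 0 -> ids[0:2] == [11, 12], row 1 -> ids[2:4] == [21, 22], ...

One caveat worth flagging: the early returns in _do_merge still yield two values (status, []), while SearchVector below now unpacks three; if _do_query passes that tuple through unchanged, the error paths will raise ValueError at the unpack site.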
@@ -109,8 +158,8 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
                 file_ids=query_params['file_ids'],
                 query_records=vectors,
                 top_k=topk,
-                nprobe=nprobe,
-                lazy_=True)
+                nprobe=nprobe
+                )
         end = time.time()
         logger.info('search_vectors_in_files takes: {}'.format(end - start))
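
Review note: dropping lazy_=True here lines up with the hook installed in connections.py; with Searchook.on_response returning True, the client presumably always skips deserialization for search_in_file, making the per-call flag redundant. That reading is inferred from the two files in this commit.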
@@ -241,7 +290,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
         logger.info('Search {}: topk={} nprobe={}'.format(
             table_name, topk, nprobe))

-        metadata = {'resp_class': milvus_pb2.TopKQueryResultList}
+        metadata = {'resp_class': milvus_pb2.TopKQueryResult}

         if nprobe > self.MAX_NPROBE or nprobe <= 0:
             raise exceptions.InvalidArgumentError(
@ -275,22 +324,24 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
query_range_array.append( query_range_array.append(
Range(query_range.start_value, query_range.end_value)) Range(query_range.start_value, query_range.end_value))
status, results = self._do_query(context, status, id_results, dis_results = self._do_query(context,
table_name, table_name,
table_meta, table_meta,
query_record_array, query_record_array,
topk, topk,
nprobe, nprobe,
query_range_array, query_range_array,
metadata=metadata) metadata=metadata)
now = time.time() now = time.time()
logger.info('SearchVector takes: {}'.format(now - start)) logger.info('SearchVector takes: {}'.format(now - start))
topk_result_list = milvus_pb2.TopKQueryResultList( topk_result_list = milvus_pb2.TopKQueryResult(
status=status_pb2.Status(error_code=status.error_code, status=status_pb2.Status(error_code=status.error_code,
reason=status.reason), reason=status.reason),
topk_query_result=results) row_num=len(query_record_array),
ids=id_results,
distances=dis_results)
return topk_result_list return topk_result_list
@mark_grpc_method @mark_grpc_method
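
Review note: SearchVector now returns one flat TopKQueryResult for the whole batch (row_num, ids, distances) instead of a TopKQueryResultList of per-query arrays. A client-side sketch for slicing it back into per-query rows (field names come from this diff; the helper itself is illustrative):

def rows_from_result(result):
    # result.row_num queries, each with len(result.ids) // row_num hits
    k = len(result.ids) // result.row_num if result.row_num else 0
    return [list(zip(result.ids[i * k:(i + 1) * k],
                     result.distances[i * k:(i + 1) * k]))
            for i in range(result.row_num)]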