mirror of https://github.com/milvus-io/milvus.git
modify shards for v0.5.3
parent 660953afa2
commit 83d9bf6966
@@ -2,6 +2,7 @@ import logging
 import threading
 from functools import wraps
 from milvus import Milvus
+from milvus.client.hooks import BaseaSearchHook

 from mishards import (settings, exceptions)
 from utils import singleton
@@ -9,6 +10,12 @@ from utils import singleton
 logger = logging.getLogger(__name__)


+class Searchook(BaseaSearchHook):
+
+    def on_response(self, *args, **kwargs):
+        return True
+
+
 class Connection:
     def __init__(self, name, uri, max_retry=1, error_handlers=None, **kwargs):
         self.name = name
@@ -18,6 +25,9 @@ class Connection:
         self.conn = Milvus()
         self.error_handlers = [] if not error_handlers else error_handlers
         self.on_retry_func = kwargs.get('on_retry_func', None)
+
+        # define search hook
+        self.conn._set_hook(search_in_file=Searchook())
         # self._connect()

     def __str__(self):
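The hook wired in above appears to make the client skip response parsing: an `on_response` that returns True tells the client to hand the raw per-file search reply back to mishards, which merges ids and distances itself (see `_do_merge` below). A minimal sketch of that callback pattern; the `Client` class here is an illustrative stand-in, not the real milvus client:

    class BaseHook:
        def on_response(self, *args, **kwargs):
            # Default: the client parses the reply into result objects.
            return False

    class PassthroughHook(BaseHook):
        def on_response(self, *args, **kwargs):
            # True means: return the raw reply, skip client-side parsing.
            return True

    class Client:
        def __init__(self, hook=None):
            self._hook = hook or BaseHook()

        def search(self, raw_reply):
            if self._hook.on_response(raw_reply):
                return raw_reply            # raw reply, merged later by the caller
            return {'parsed': raw_reply}    # eager client-side parsing

    client = Client(hook=PassthroughHook())
    print(client.search({'ids': [1, 2], 'distances': [0.1, 0.2]}))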
@@ -29,39 +29,88 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
         self.router = router
         self.max_workers = max_workers

+    def _reduce(self, source_ids, ids, source_diss, diss, k, reverse):
+        # Short-circuit when one sorted list already dominates the other.
+        if source_diss[k - 1] <= diss[0]:
+            return source_ids, source_diss
+        if diss[k - 1] <= source_diss[0]:
+            return ids, diss
+
+        # Concatenate both candidate lists and re-rank by distance; the
+        # enumerate indices address the concatenated id list below.
+        source_diss.extend(diss)
+        diss_m_rst = sorted(enumerate(source_diss), key=lambda x: x[1])[:k]
+        diss_m_out = [dis for _, dis in diss_m_rst]
+
+        source_ids.extend(ids)
+        id_m_out = [source_ids[i] for i, _ in diss_m_rst]
+
+        return id_m_out, diss_m_out
+
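The added `_reduce` merges two top-k lists that are each already sorted ascending by distance, short-circuiting when one list dominates the other. A standalone sketch with toy data (`reduce_topk` is an illustrative name, not the handler method):

    def reduce_topk(source_ids, ids, source_diss, diss, k):
        # One list's worst hit already beats the other's best hit.
        if source_diss[k - 1] <= diss[0]:
            return source_ids, source_diss
        if diss[k - 1] <= source_diss[0]:
            return ids, diss
        # Otherwise concatenate, re-rank by distance, keep the global top k.
        merged = sorted(zip(source_diss + diss, source_ids + ids))[:k]
        return [i for _, i in merged], [d for d, _ in merged]

    ids_a, diss_a = [10, 11, 12], [0.1, 0.4, 0.9]
    ids_b, diss_b = [20, 21, 22], [0.2, 0.3, 0.5]
    print(reduce_topk(ids_a, ids_b, diss_a, diss_b, k=3))
    # -> ([10, 20, 21], [0.1, 0.2, 0.3])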
     def _do_merge(self, files_n_topk_results, topk, reverse=False, **kwargs):
         status = status_pb2.Status(error_code=status_pb2.SUCCESS,
                                    reason="Success")
         if not files_n_topk_results:
-            return status, []
+            return status, [], []

-        request_results = defaultdict(list)
+        merge_id_results = []
+        merge_dis_results = []
+
         calc_time = time.time()
         for files_collection in files_n_topk_results:
             if isinstance(files_collection, tuple):
                 status, _ = files_collection
-                return status, []
+                return status, [], []
-            for request_pos, each_request_results in enumerate(
-                    files_collection.topk_query_result):
-                request_results[request_pos].extend(
-                    each_request_results.query_result_arrays)
-                request_results[request_pos] = sorted(
-                    request_results[request_pos],
-                    key=lambda x: x.distance,
-                    reverse=reverse)[:topk]
+
+            # Each sub-result carries flat, row-major id/distance arrays.
+            row_num = files_collection.row_num
+            ids = files_collection.ids
+            diss = files_collection.distances  # distance collections
+            batch_len = len(ids) // row_num
+
+            for row_index in range(row_num):
+                id_batch = ids[row_index * batch_len: (row_index + 1) * batch_len]
+                dis_batch = diss[row_index * batch_len: (row_index + 1) * batch_len]
+
+                if len(merge_id_results) < row_index:
+                    raise ValueError("merge error")
+                elif len(merge_id_results) == row_index:
+                    # TODO: may bug here
+                    merge_id_results.append(id_batch)
+                    merge_dis_results.append(dis_batch)
+                else:
+                    merge_id_results[row_index], merge_dis_results[row_index] = \
+                        self._reduce(merge_id_results[row_index], id_batch,
+                                     merge_dis_results[row_index], dis_batch,
+                                     batch_len,
+                                     reverse)
+
         calc_time = time.time() - calc_time
         logger.info('Merge takes {}'.format(calc_time))

-        results = sorted(request_results.items())
-        topk_query_result = []
-
-        for result in results:
-            query_result = TopKQueryResult(query_result_arrays=result[1])
-            topk_query_result.append(query_result)
-
-        return status, topk_query_result
+        # Flatten the per-row merge results back into row-major flat lists.
+        id_merge_list = []
+        dis_merge_list = []
+        for id_results, dis_results in zip(merge_id_results, merge_dis_results):
+            id_merge_list.extend(id_results)
+            dis_merge_list.extend(dis_results)
+
+        return status, id_merge_list, dis_merge_list
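Sub-results arrive as flat arrays: for a request with `row_num` query vectors, row i owns the slice `[i * batch_len, (i + 1) * batch_len)` of both `ids` and `distances`. A toy walk-through of the merge loop above, folding two sub-results for a single query row into one global top-3 (data and names are illustrative):

    sub_results = [
        {'row_num': 1, 'ids': [10, 11, 12], 'distances': [0.1, 0.4, 0.9]},
        {'row_num': 1, 'ids': [20, 21, 22], 'distances': [0.2, 0.3, 0.5]},
    ]
    merge_ids, merge_diss = [], []
    for coll in sub_results:
        k = len(coll['ids']) // coll['row_num']
        for row in range(coll['row_num']):
            id_batch = coll['ids'][row * k:(row + 1) * k]
            dis_batch = coll['distances'][row * k:(row + 1) * k]
            if len(merge_ids) == row:        # first sub-result seeds this row
                merge_ids.append(id_batch)
                merge_diss.append(dis_batch)
            else:                            # later sub-results get reduced in
                merged = sorted(zip(merge_diss[row] + dis_batch,
                                    merge_ids[row] + id_batch))[:k]
                merge_diss[row] = [d for d, _ in merged]
                merge_ids[row] = [i for _, i in merged]
    print(merge_ids[0], merge_diss[0])       # [10, 20, 21] [0.1, 0.2, 0.3]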
     def _do_query(self,
                   context,

@@ -109,8 +158,8 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
             file_ids=query_params['file_ids'],
             query_records=vectors,
             top_k=topk,
-            nprobe=nprobe,
-            lazy_=True)
+            nprobe=nprobe)
         end = time.time()
         logger.info('search_vectors_in_files takes: {}'.format(end - start))
@@ -241,7 +290,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
         logger.info('Search {}: topk={} nprobe={}'.format(
             table_name, topk, nprobe))

-        metadata = {'resp_class': milvus_pb2.TopKQueryResultList}
+        metadata = {'resp_class': milvus_pb2.TopKQueryResult}

         if nprobe > self.MAX_NPROBE or nprobe <= 0:
             raise exceptions.InvalidArgumentError(
@@ -275,22 +324,24 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
             query_range_array.append(
                 Range(query_range.start_value, query_range.end_value))

-        status, results = self._do_query(context,
+        status, id_results, dis_results = self._do_query(context,
                                          table_name,
                                          table_meta,
                                          query_record_array,
                                          topk,
                                          nprobe,
                                          query_range_array,
                                          metadata=metadata)

         now = time.time()
         logger.info('SearchVector takes: {}'.format(now - start))

-        topk_result_list = milvus_pb2.TopKQueryResultList(
+        topk_result_list = milvus_pb2.TopKQueryResult(
             status=status_pb2.Status(error_code=status.error_code,
                                      reason=status.reason),
-            topk_query_result=results)
+            row_num=len(query_record_array),
+            ids=id_results,
+            distances=dis_results)
         return topk_result_list

     @mark_grpc_method
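With this change `SearchVector` returns one flat `TopKQueryResult` (`row_num`, `ids`, `distances`) instead of a list of per-request result objects, so callers regroup rows themselves. A minimal consumer of that flat shape, using a plain stand-in object rather than the generated protobuf class:

    class TopKQueryResult:
        def __init__(self, row_num, ids, distances):
            self.row_num, self.ids, self.distances = row_num, ids, distances

    result = TopKQueryResult(row_num=2,
                             ids=[7, 3, 9, 5, 1, 4],
                             distances=[0.1, 0.2, 0.3, 0.05, 0.2, 0.6])
    k = len(result.ids) // result.row_num
    rows = [list(zip(result.ids[i * k:(i + 1) * k],
                     result.distances[i * k:(i + 1) * k]))
            for i in range(result.row_num)]
    print(rows[0])  # [(7, 0.1), (3, 0.2), (9, 0.3)]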