[v0.10.2] (mishards) Fix mishards search bug (#3169)

* [skip ci]Reverse query result if metric is IP

Signed-off-by: yinghao.zou <yinghao.zou@zilliz.com>

* [skip ci] Update version check

Signed-off-by: yinghao.zou <yinghao.zou@zilliz.com>
pull/3232/head
BossZou 2020-08-07 19:33:49 +08:00 committed by GitHub
parent 058cdf03bb
commit f9f597b257
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 20 additions and 8 deletions

View File

@ -9,6 +9,7 @@ Please mark all change in change log and use the issue from GitHub
- \#2952 Fix the result merging of IVF_PQ IP
- \#2975 Fix config UT failed
- \#3012 If the cache is too small, queries using multiple GPUs will cause to crash
- \#3133 Reverse query result in mishards if metric type is IP
## Feature

View File

@ -224,6 +224,16 @@ logger = logging.getLogger(__name__)
# connection = Connection(name=self.name, uri=self.uri, max_retry=self.max_retry, **self.kwargs)
# return connection
def version_supported(version):
version_pattern = lambda v : ".".join(v.split(".")[:2])
sv_patterns = set()
for supported_version in settings.SERVER_VERSIONS:
sv_patterns.add(version_pattern(supported_version))
v_pattern = version_pattern(version)
return v_pattern in sv_patterns
class ConnectionGroup(topology.TopoGroup):
def __init__(self, name):
@ -243,7 +253,7 @@ class ConnectionGroup(topology.TopoGroup):
if not status.OK():
logger.error('Cannot connect to newly added address: {}. Remove it now'.format(topo_object.name))
return False
if version not in settings.SERVER_VERSIONS:
if not version_supported(version):
logger.error('Cannot connect to server of version: {}. Only {} supported'.format(version,
settings.SERVER_VERSIONS))
return False

View File

@ -27,14 +27,15 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
self.max_workers = max_workers
def _reduce(self, source_ids, ids, source_diss, diss, k, reverse):
if source_diss[k - 1] <= diss[0]:
sort_f = lambda x, y: x >= y if reverse else lambda x, y: x <= y
if sort_f(source_diss[k - 1], diss[0]):
return source_ids, source_diss
if diss[k - 1] <= source_diss[0]:
if sort_f(diss[k - 1], source_diss[0]):
return ids, diss
source_diss.extend(diss)
diss_t = enumerate(source_diss)
diss_m_rst = sorted(diss_t, key=lambda x: x[1])[:k]
diss_m_rst = sorted(diss_t, key=lambda x: x[1], reverse=reverse)[:k]
diss_m_out = [id_ for _, id_ in diss_m_rst]
source_ids.extend(ids)
@ -149,9 +150,9 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
params=search_params, _async=True)
futures.append(future)
for f in futures:
ret = f.result(raw=True)
all_topk_results.append(ret)
for f in futures:
ret = f.result(raw=True)
all_topk_results.append(ret)
reverse = collection_meta.metric_type == Types.MetricType.IP
with self.tracer.start_span('do_merge', child_of=p_span):

View File

@ -12,7 +12,7 @@ else:
env.read_env()
SERVER_VERSIONS = ['0.9.0', '0.9.1', '0.10.0', '0.10.1']
SERVER_VERSIONS = ['0.9.x', '0.10.x']
DEBUG = env.bool('DEBUG', False)
MAX_RETRY = env.int('MAX_RETRY', 3)