From fe50f977266e417ee2cf28695aaa5a5c9331b918 Mon Sep 17 00:00:00 2001 From: Cai Yudong Date: Mon, 2 Aug 2021 10:25:49 +0800 Subject: [PATCH] Use id to tell search result validation in proxy reduce stage (#6905) * use id to tell search result validation Signed-off-by: yudong.cai * enable test_search_binary_hamming_flat_index Signed-off-by: yudong.cai * code optimize Signed-off-by: yudong.cai * fix merge retrieve result issue Signed-off-by: yudong.cai --- internal/proxy/task.go | 39 +++++++++---------- internal/querynode/query_collection.go | 4 +- .../python_client/testcases/test_search.py | 1 - 3 files changed, 20 insertions(+), 24 deletions(-) diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 3769516186..5a99c0e9d9 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -1687,39 +1687,36 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat j := 0 for ; j < topk; j++ { - valid := false + valid := true choice, maxDistance := 0, minFloat32 for q, loc := range locs { // query num, the number of ways to merge if loc >= topk { continue } - distance := searchResultData[q].Scores[idx*topk+loc] - // https://github.com/milvus-io/milvus/issues/6781 - if math.IsNaN(float64(distance)) { - continue - } - if distance > maxDistance || (math.Abs(float64(distance-maxDistance)) < math.SmallestNonzeroFloat32 && choice != q) { - choice = q - maxDistance = distance - valid = true + curIdx := idx*topk + loc + id := searchResultData[q].Ids.GetIntId().Data[curIdx] + if id == -1 { + valid = false + } else { + distance := searchResultData[q].Scores[curIdx] + if distance > maxDistance { + choice = q + maxDistance = distance + } } } if !valid { break } choiceOffset := locs[choice] - // check if distance is valid, `invalid` here means very very big, - // in this process, distance here is the smallest, so the rest of distance are all invalid - // https://github.com/milvus-io/milvus/issues/6781 - // tanimoto distance between two binary vectors maybe -inf, so -inf distance shouldn't be filtered, - // otherwise it will cause that the number of hit records is less than needed (topk). - // in the above process, we have already filtered NaN distance. - distance := searchResultData[choice].Scores[idx*topk+choiceOffset] - if distance < minFloat32 { - break - } curIdx := idx*topk + choiceOffset - ret.Results.Ids.GetIntId().Data = append(ret.Results.Ids.GetIntId().Data, searchResultData[choice].Ids.GetIntId().Data[curIdx]) + + // ignore invalid search result + id := searchResultData[choice].Ids.GetIntId().Data[curIdx] + if id == -1 { + continue + } + ret.Results.Ids.GetIntId().Data = append(ret.Results.Ids.GetIntId().Data, id) // TODO(yukun): Process searchResultData.FieldsData for k, fieldData := range searchResultData[choice].FieldsData { switch fieldType := fieldData.Field.(type) { diff --git a/internal/querynode/query_collection.go b/internal/querynode/query_collection.go index 34f551cf6f..aa059757f7 100644 --- a/internal/querynode/query_collection.go +++ b/internal/querynode/query_collection.go @@ -1141,7 +1141,6 @@ func (q *queryCollection) fillVectorFieldsData(segment *Segment, result *segcore resultLen := dim copy(x.FloatVector.Data[i*int(resultLen):(i+1)*int(resultLen)], floatResult) } - } } return nil @@ -1294,7 +1293,8 @@ func (q *queryCollection) retrieve(msg queryMsg) error { func mergeRetrieveResults(dataArr []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) { var final *segcorepb.RetrieveResults for _, data := range dataArr { - if data == nil { + // skip empty result, it will break merge result + if data == nil || len(data.Offset) == 0 { continue } diff --git a/tests20/python_client/testcases/test_search.py b/tests20/python_client/testcases/test_search.py index 8eefa16414..41600a8cc7 100644 --- a/tests20/python_client/testcases/test_search.py +++ b/tests20/python_client/testcases/test_search.py @@ -1429,7 +1429,6 @@ class TestCollectionSearch(TestcaseBase): assert abs(res[0]._distances[0] - min(distance_0, distance_1)) <= epsilon @pytest.mark.tags(CaseLabel.L2) - @pytest.mark.xfail(reason="issue 6469") def test_search_binary_hamming_flat_index(self, nq, dim, auto_id, _async): """ target: search binary_collection, and check the result: distance