mirror of https://github.com/milvus-io/milvus.git
Use id to tell search result validation in proxy reduce stage (#6905)
* use id to tell search result validation
  Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
* enable test_search_binary_hamming_flat_index
  Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
* code optimize
  Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
* fix merge retrieve result issue
  Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
pull/6807/head^2
parent
fb5ca43621
commit
fe50f97726
|
@ -1687,39 +1687,36 @@ func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultDat
|
|||
|
||||
j := 0
|
||||
for ; j < topk; j++ {
|
||||
valid := false
|
||||
valid := true
|
||||
choice, maxDistance := 0, minFloat32
|
||||
for q, loc := range locs { // query num, the number of ways to merge
|
||||
if loc >= topk {
|
||||
continue
|
||||
}
|
||||
distance := searchResultData[q].Scores[idx*topk+loc]
|
||||
// https://github.com/milvus-io/milvus/issues/6781
|
||||
if math.IsNaN(float64(distance)) {
|
||||
continue
|
||||
}
|
||||
if distance > maxDistance || (math.Abs(float64(distance-maxDistance)) < math.SmallestNonzeroFloat32 && choice != q) {
|
||||
choice = q
|
||||
maxDistance = distance
|
||||
valid = true
|
||||
curIdx := idx*topk + loc
|
||||
id := searchResultData[q].Ids.GetIntId().Data[curIdx]
|
||||
if id == -1 {
|
||||
valid = false
|
||||
} else {
|
||||
distance := searchResultData[q].Scores[curIdx]
|
||||
if distance > maxDistance {
|
||||
choice = q
|
||||
maxDistance = distance
|
||||
}
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
break
|
||||
}
|
||||
choiceOffset := locs[choice]
|
||||
// check if distance is valid, `invalid` here means very very big,
|
||||
// in this process, distance here is the smallest, so the rest of distance are all invalid
|
||||
// https://github.com/milvus-io/milvus/issues/6781
|
||||
// tanimoto distance between two binary vectors may be -inf, so -inf distance shouldn't be filtered,
|
||||
// otherwise it will cause that the number of hit records is less than needed (topk).
|
||||
// in the above process, we have already filtered NaN distance.
|
||||
distance := searchResultData[choice].Scores[idx*topk+choiceOffset]
|
||||
if distance < minFloat32 {
|
||||
break
|
||||
}
|
||||
curIdx := idx*topk + choiceOffset
|
||||
ret.Results.Ids.GetIntId().Data = append(ret.Results.Ids.GetIntId().Data, searchResultData[choice].Ids.GetIntId().Data[curIdx])
|
||||
|
||||
// ignore invalid search result
|
||||
id := searchResultData[choice].Ids.GetIntId().Data[curIdx]
|
||||
if id == -1 {
|
||||
continue
|
||||
}
|
||||
ret.Results.Ids.GetIntId().Data = append(ret.Results.Ids.GetIntId().Data, id)
|
||||
// TODO(yukun): Process searchResultData.FieldsData
|
||||
for k, fieldData := range searchResultData[choice].FieldsData {
|
||||
switch fieldType := fieldData.Field.(type) {
|
||||
|
|
|
@ -1141,7 +1141,6 @@ func (q *queryCollection) fillVectorFieldsData(segment *Segment, result *segcore
|
|||
resultLen := dim
|
||||
copy(x.FloatVector.Data[i*int(resultLen):(i+1)*int(resultLen)], floatResult)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
@ -1294,7 +1293,8 @@ func (q *queryCollection) retrieve(msg queryMsg) error {
|
|||
func mergeRetrieveResults(dataArr []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) {
|
||||
var final *segcorepb.RetrieveResults
|
||||
for _, data := range dataArr {
|
||||
if data == nil {
|
||||
// skip empty result, it will break merge result
|
||||
if data == nil || len(data.Offset) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
|
|
|
@ -1429,7 +1429,6 @@ class TestCollectionSearch(TestcaseBase):
|
|||
assert abs(res[0]._distances[0] - min(distance_0, distance_1)) <= epsilon
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
@pytest.mark.xfail(reason="issue 6469")
|
||||
def test_search_binary_hamming_flat_index(self, nq, dim, auto_id, _async):
|
||||
"""
|
||||
target: search binary_collection, and check the result: distance
|
||||
|
|
Loading…
Reference in New Issue