mirror of https://github.com/milvus-io/milvus.git
Fix binary results unstable (#19401)
See also: #19338, #19366, 19367 Signed-off-by: yangxuan <xuan.yang@zilliz.com> Signed-off-by: yangxuan <xuan.yang@zilliz.com>pull/19448/head
parent
7819297f68
commit
52c6a2706e
|
@ -564,11 +564,21 @@ func selectHighestScoreIndex(subSearchResultData []*schemapb.SearchResultData, s
|
|||
}
|
||||
sIdx := subSearchNqOffset[i][qi] + cursors[i]
|
||||
sScore := subSearchResultData[i].Scores[sIdx]
|
||||
|
||||
// Choose the larger score idx or the smaller pk idx with the same score
|
||||
if sScore > maxScore {
|
||||
subSearchIdx = i
|
||||
resultDataIdx = sIdx
|
||||
|
||||
maxScore = sScore
|
||||
} else if sScore == maxScore {
|
||||
sID := typeutil.GetPK(subSearchResultData[i].GetIds(), sIdx)
|
||||
tmpID := typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx)
|
||||
|
||||
if typeutil.ComparePK(sID, tmpID) {
|
||||
subSearchIdx = i
|
||||
resultDataIdx = sIdx
|
||||
maxScore = sScore
|
||||
}
|
||||
}
|
||||
}
|
||||
return subSearchIdx, resultDataIdx
|
||||
|
|
|
@ -1218,7 +1218,7 @@ func TestTaskSearch_selectHighestScoreIndex(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
Scores: []float32{1.2, 1.0, 0.7, 0.6, 0.4, 0.2},
|
||||
Scores: []float32{1.2, 1.0, 0.7, 0.5, 0.4, 0.2},
|
||||
Topks: []int64{2, 2, 2},
|
||||
},
|
||||
},
|
||||
|
@ -1280,7 +1280,7 @@ func TestTaskSearch_selectHighestScoreIndex(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
Scores: []float32{1.2, 1.0, 0.7, 0.6, 0.4, 0.2},
|
||||
Scores: []float32{1.2, 1.0, 0.7, 0.5, 0.4, 0.2},
|
||||
Topks: []int64{2, 2, 2},
|
||||
},
|
||||
},
|
||||
|
|
|
@ -183,17 +183,32 @@ func reduceSearchResultData(ctx context.Context, searchResultData []*schemapb.Se
|
|||
}
|
||||
|
||||
func selectSearchResultData(dataArray []*schemapb.SearchResultData, resultOffsets [][]int64, offsets []int64, qi int64) int {
|
||||
sel := -1
|
||||
maxDistance := -1 * float32(math.MaxFloat32)
|
||||
var (
|
||||
sel = -1
|
||||
maxDistance = -1 * float32(math.MaxFloat32)
|
||||
resultDataIdx int64 = -1
|
||||
)
|
||||
for i, offset := range offsets { // query num, the number of ways to merge
|
||||
if offset >= dataArray[i].Topks[qi] {
|
||||
continue
|
||||
}
|
||||
|
||||
idx := resultOffsets[i][qi] + offset
|
||||
distance := dataArray[i].Scores[idx]
|
||||
|
||||
if distance > maxDistance {
|
||||
sel = i
|
||||
maxDistance = distance
|
||||
resultDataIdx = idx
|
||||
} else if distance == maxDistance {
|
||||
sID := typeutil.GetPK(dataArray[i].GetIds(), idx)
|
||||
tmpID := typeutil.GetPK(dataArray[sel].GetIds(), resultDataIdx)
|
||||
|
||||
if typeutil.ComparePK(sID, tmpID) {
|
||||
sel = i
|
||||
maxDistance = distance
|
||||
resultDataIdx = idx
|
||||
}
|
||||
}
|
||||
}
|
||||
return sel
|
||||
|
|
|
@ -35,7 +35,7 @@ func (s *byPK) Swap(i, j int) {
|
|||
}
|
||||
|
||||
func (s *byPK) Less(i, j int) bool {
|
||||
return typeutil.ComparePK(s.r.GetIds(), i, j)
|
||||
return typeutil.ComparePKInSlice(s.r.GetIds(), i, j)
|
||||
}
|
||||
|
||||
func swapFieldData(field *schemapb.FieldData, i int, j int) {
|
||||
|
|
|
@ -657,8 +657,8 @@ func SwapPK(data *schemapb.IDs, i, j int) {
|
|||
}
|
||||
}
|
||||
|
||||
// ComparePK returns if i-th PK < j-th PK
|
||||
func ComparePK(data *schemapb.IDs, i, j int) bool {
|
||||
// ComparePKInSlice returns if i-th PK < j-th PK
|
||||
func ComparePKInSlice(data *schemapb.IDs, i, j int) bool {
|
||||
switch f := data.GetIdField().(type) {
|
||||
case *schemapb.IDs_IntId:
|
||||
return f.IntId.Data[i] < f.IntId.Data[j]
|
||||
|
@ -668,6 +668,17 @@ func ComparePK(data *schemapb.IDs, i, j int) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
// ComparePK returns if i-th PK of dataA > j-th PK of dataB
|
||||
func ComparePK(pkA, pkB interface{}) bool {
|
||||
switch pkA.(type) {
|
||||
case int64:
|
||||
return pkA.(int64) < pkB.(int64)
|
||||
case string:
|
||||
return pkA.(string) < pkB.(string)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type ResultWithID interface {
|
||||
GetIds() *schemapb.IDs
|
||||
}
|
||||
|
|
|
@ -662,18 +662,18 @@ func TestComparePk(t *testing.T) {
|
|||
AppendPKs(intPks, int64(3))
|
||||
require.Equal(t, []int64{1, 2, 3}, intPks.GetIntId().GetData())
|
||||
|
||||
less := ComparePK(intPks, 0, 1)
|
||||
less := ComparePKInSlice(intPks, 0, 1)
|
||||
assert.True(t, less)
|
||||
less = ComparePK(intPks, 0, 2)
|
||||
less = ComparePKInSlice(intPks, 0, 2)
|
||||
assert.True(t, less)
|
||||
less = ComparePK(intPks, 1, 2)
|
||||
less = ComparePKInSlice(intPks, 1, 2)
|
||||
assert.True(t, less)
|
||||
|
||||
less = ComparePK(intPks, 1, 0)
|
||||
less = ComparePKInSlice(intPks, 1, 0)
|
||||
assert.False(t, less)
|
||||
less = ComparePK(intPks, 2, 0)
|
||||
less = ComparePKInSlice(intPks, 2, 0)
|
||||
assert.False(t, less)
|
||||
less = ComparePK(intPks, 2, 1)
|
||||
less = ComparePKInSlice(intPks, 2, 1)
|
||||
assert.False(t, less)
|
||||
|
||||
strPks := &schemapb.IDs{}
|
||||
|
@ -683,17 +683,17 @@ func TestComparePk(t *testing.T) {
|
|||
|
||||
require.Equal(t, []string{"1", "2", "3"}, strPks.GetStrId().GetData())
|
||||
|
||||
less = ComparePK(strPks, 0, 1)
|
||||
less = ComparePKInSlice(strPks, 0, 1)
|
||||
assert.True(t, less)
|
||||
less = ComparePK(strPks, 0, 2)
|
||||
less = ComparePKInSlice(strPks, 0, 2)
|
||||
assert.True(t, less)
|
||||
less = ComparePK(strPks, 1, 2)
|
||||
less = ComparePKInSlice(strPks, 1, 2)
|
||||
assert.True(t, less)
|
||||
|
||||
less = ComparePK(strPks, 1, 0)
|
||||
less = ComparePKInSlice(strPks, 1, 0)
|
||||
assert.False(t, less)
|
||||
less = ComparePK(strPks, 2, 0)
|
||||
less = ComparePKInSlice(strPks, 2, 0)
|
||||
assert.False(t, less)
|
||||
less = ComparePK(strPks, 2, 1)
|
||||
less = ComparePKInSlice(strPks, 2, 1)
|
||||
assert.False(t, less)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue