mirror of https://github.com/milvus-io/milvus.git
Guard bad distance/score from knowhere (#21034)
/kind bug issue: #21138 Signed-off-by: Yuchen Gao <yuchen.gao@zilliz.com> Signed-off-by: Yuchen Gao <yuchen.gao@zilliz.com>pull/21208/head
parent
fc10c74005
commit
b01546edd3
|
@ -648,10 +648,13 @@ func selectHighestScoreIndex(subSearchResultData []*schemapb.SearchResultData, s
|
|||
resultDataIdx = sIdx
|
||||
maxScore = sScore
|
||||
} else if sScore == maxScore {
|
||||
sID := typeutil.GetPK(subSearchResultData[i].GetIds(), sIdx)
|
||||
tmpID := typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx)
|
||||
|
||||
if typeutil.ComparePK(sID, tmpID) {
|
||||
if subSearchIdx == -1 {
|
||||
// A bad case happens where Knowhere returns distance/score == +/-maxFloat32
|
||||
// by mistake.
|
||||
log.Error("a bad score is returned, something is wrong here!", zap.Float32("score", sScore))
|
||||
} else if typeutil.ComparePK(
|
||||
typeutil.GetPK(subSearchResultData[i].GetIds(), sIdx),
|
||||
typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx)) {
|
||||
subSearchIdx = i
|
||||
resultDataIdx = sIdx
|
||||
maxScore = sScore
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
@ -1270,6 +1271,68 @@ func TestTaskSearch_selectHighestScoreIndex(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
t.Run("Integer ID with bad score", func(t *testing.T) {
|
||||
type args struct {
|
||||
subSearchResultData []*schemapb.SearchResultData
|
||||
subSearchNqOffset [][]int64
|
||||
cursors []int64
|
||||
topk int64
|
||||
nq int64
|
||||
}
|
||||
tests := []struct {
|
||||
description string
|
||||
args args
|
||||
|
||||
expectedIdx []int
|
||||
expectedDataIdx []int
|
||||
}{
|
||||
{
|
||||
description: "reduce 2 subSearchResultData",
|
||||
args: args{
|
||||
subSearchResultData: []*schemapb.SearchResultData{
|
||||
{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{11, 9, 8, 5, 3, 1},
|
||||
},
|
||||
},
|
||||
},
|
||||
Scores: []float32{-math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32},
|
||||
Topks: []int64{2, 2, 2},
|
||||
},
|
||||
{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{12, 10, 7, 6, 4, 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
Scores: []float32{-math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32},
|
||||
Topks: []int64{2, 2, 2},
|
||||
},
|
||||
},
|
||||
subSearchNqOffset: [][]int64{{0, 2, 4}, {0, 2, 4}},
|
||||
cursors: []int64{0, 0},
|
||||
topk: 2,
|
||||
nq: 3,
|
||||
},
|
||||
expectedIdx: []int{-1, -1, -1},
|
||||
expectedDataIdx: []int{-1, -1, -1},
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
for nqNum := int64(0); nqNum < test.args.nq; nqNum++ {
|
||||
idx, dataIdx := selectHighestScoreIndex(test.args.subSearchResultData, test.args.subSearchNqOffset, test.args.cursors, nqNum)
|
||||
assert.Equal(t, test.expectedIdx[nqNum], idx)
|
||||
assert.Equal(t, test.expectedDataIdx[nqNum], int(dataIdx))
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("String ID", func(t *testing.T) {
|
||||
type args struct {
|
||||
subSearchResultData []*schemapb.SearchResultData
|
||||
|
|
|
@ -198,10 +198,13 @@ func selectSearchResultData(dataArray []*schemapb.SearchResultData, resultOffset
|
|||
maxDistance = distance
|
||||
resultDataIdx = idx
|
||||
} else if distance == maxDistance {
|
||||
sID := typeutil.GetPK(dataArray[i].GetIds(), idx)
|
||||
tmpID := typeutil.GetPK(dataArray[sel].GetIds(), resultDataIdx)
|
||||
|
||||
if typeutil.ComparePK(sID, tmpID) {
|
||||
if sel == -1 {
|
||||
// A bad case happens where knowhere returns distance == +/-maxFloat32
|
||||
// by mistake.
|
||||
log.Error("a bad distance is found, something is wrong here!", zap.Float32("score", distance))
|
||||
} else if typeutil.ComparePK(
|
||||
typeutil.GetPK(dataArray[i].GetIds(), idx),
|
||||
typeutil.GetPK(dataArray[sel].GetIds(), resultDataIdx)) {
|
||||
sel = i
|
||||
maxDistance = distance
|
||||
resultDataIdx = idx
|
||||
|
|
|
@ -18,6 +18,7 @@ package querynode
|
|||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -424,51 +425,104 @@ func TestResult_selectSearchResultData_int(t *testing.T) {
|
|||
nq int64
|
||||
qi int64
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want int
|
||||
}{
|
||||
{
|
||||
args: args{
|
||||
dataArray: []*schemapb.SearchResultData{
|
||||
{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{11, 9, 7, 5, 3, 1},
|
||||
t.Run("Integer ID", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want int
|
||||
}{
|
||||
{
|
||||
args: args{
|
||||
dataArray: []*schemapb.SearchResultData{
|
||||
{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{11, 9, 7, 5, 3, 1},
|
||||
},
|
||||
},
|
||||
},
|
||||
Scores: []float32{1.1, 0.9, 0.7, 0.5, 0.3, 0.1},
|
||||
Topks: []int64{2, 2, 2},
|
||||
},
|
||||
Scores: []float32{1.1, 0.9, 0.7, 0.5, 0.3, 0.1},
|
||||
Topks: []int64{2, 2, 2},
|
||||
},
|
||||
{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{12, 10, 8, 6, 4, 2},
|
||||
{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{12, 10, 8, 6, 4, 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
Scores: []float32{1.2, 1.0, 0.8, 0.6, 0.4, 0.2},
|
||||
Topks: []int64{2, 2, 2},
|
||||
},
|
||||
Scores: []float32{1.2, 1.0, 0.8, 0.6, 0.4, 0.2},
|
||||
Topks: []int64{2, 2, 2},
|
||||
},
|
||||
resultOffsets: [][]int64{{0, 2, 4}, {0, 2, 4}},
|
||||
offsets: []int64{0, 1},
|
||||
topk: 2,
|
||||
nq: 3,
|
||||
qi: 0,
|
||||
},
|
||||
resultOffsets: [][]int64{{0, 2, 4}, {0, 2, 4}},
|
||||
offsets: []int64{0, 1},
|
||||
topk: 2,
|
||||
nq: 3,
|
||||
qi: 0,
|
||||
want: 0,
|
||||
},
|
||||
want: 0,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := selectSearchResultData(tt.args.dataArray, tt.args.resultOffsets, tt.args.offsets, tt.args.qi); got != tt.want {
|
||||
t.Errorf("selectSearchResultData() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := selectSearchResultData(tt.args.dataArray, tt.args.resultOffsets, tt.args.offsets, tt.args.qi); got != tt.want {
|
||||
t.Errorf("selectSearchResultData() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Integer ID with bad score", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want int
|
||||
}{
|
||||
{
|
||||
args: args{
|
||||
dataArray: []*schemapb.SearchResultData{
|
||||
{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{11, 9, 7, 5, 3, 1},
|
||||
},
|
||||
},
|
||||
},
|
||||
Scores: []float32{-math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32},
|
||||
Topks: []int64{2, 2, 2},
|
||||
},
|
||||
{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{12, 10, 8, 6, 4, 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
Scores: []float32{-math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32},
|
||||
Topks: []int64{2, 2, 2},
|
||||
},
|
||||
},
|
||||
resultOffsets: [][]int64{{0, 2, 4}, {0, 2, 4}},
|
||||
offsets: []int64{0, 1},
|
||||
topk: 2,
|
||||
nq: 3,
|
||||
qi: 0,
|
||||
},
|
||||
want: -1,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := selectSearchResultData(tt.args.dataArray, tt.args.resultOffsets, tt.args.offsets, tt.args.qi); got != tt.want {
|
||||
t.Errorf("selectSearchResultData() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue