// Source: milvus/internal/querynodev2/segments/search_reduce_test.go
// (original listing metadata: 259 lines, 8.9 KiB, Go)

package segments
import (
"context"
"testing"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
// SearchReduceSuite groups the search-result reduction tests. Embedding
// testify's suite.Suite supplies the assertion methods and the Run helper
// for subtests.
type (
	SearchReduceSuite struct {
		suite.Suite
	}
)
// TestResult_ReduceSearchResultData covers the plain (non group-by) reducer.
// InitSearchReducer receives a ReduceInfo without groupByFieldID, and the
// reducer must merge two partial results for one query into a single top-k list.
func (s *SearchReduceSuite) TestResult_ReduceSearchResultData() {
	const (
		nq   = 1
		topk = 4
	)

	// partial builds one single-query partial result whose hit count is len(ids).
	partial := func(ids []int64, scores []float32) *schemapb.SearchResultData {
		return genSearchResultData(nq, topk, ids, scores, []int64{int64(len(ids))})
	}

	// reduce runs the reducer that InitSearchReducer selects. Require aborts the
	// subtest on error, so a nil result is never dereferenced by the checks.
	reduce := func(results ...*schemapb.SearchResultData) *schemapb.SearchResultData {
		info := &ReduceInfo{nq: nq, topK: topk}
		res, err := InitSearchReducer(info).ReduceSearchResultData(context.TODO(), results, info)
		s.Require().NoError(err)
		return res
	}

	s.Run("case1", func() {
		// Identical inputs: every primary key appears twice but is kept once,
		// so the output equals either input.
		ids := []int64{1, 2, 3, 4}
		scores := []float32{-1.0, -2.0, -3.0, -4.0}
		res := reduce(partial(ids, scores), partial(ids, scores))
		s.Equal(ids, res.Ids.GetIntId().Data)
		s.Equal(scores, res.Scores)
	})

	s.Run("case2", func() {
		// Overlapping inputs: PKs 1, 3 and 4 occur in both results. The best
		// four distinct hits are 1 and 5 (-1), 2 (-2) and 3 (-3).
		res := reduce(
			partial([]int64{1, 2, 3, 4}, []float32{-1.0, -2.0, -3.0, -4.0}),
			partial([]int64{5, 1, 3, 4}, []float32{-1.0, -1.0, -3.0, -4.0}),
		)
		s.ElementsMatch([]int64{1, 5, 2, 3}, res.Ids.GetIntId().Data)
		s.ElementsMatch([]float32{-1.0, -1.0, -2.0, -3.0}, res.Scores)
	})
}
// TestResult_SearchGroupByResult covers the group-by reducer. Every ReduceInfo
// here sets groupByFieldID, and each partial result carries a GroupByFieldValue
// column parallel to its ids. The subtests exercise:
//   - int, bool and string group keys;
//   - a groupSize above one;
//   - an empty input.
func (s *SearchReduceSuite) TestResult_SearchGroupByResult() {
	const (
		nq   = 1
		topk = 4
	)

	// scalarGroupBy wraps a scalar column in the FieldData shape used for
	// GroupByFieldValue.
	scalarGroupBy := func(dataType schemapb.DataType, scalars *schemapb.ScalarField) *schemapb.FieldData {
		return &schemapb.FieldData{
			Type:  dataType,
			Field: &schemapb.FieldData_Scalars{Scalars: scalars},
		}
	}
	// intGroupBy builds an Int8-typed group-by column backed by IntData.
	intGroupBy := func(values ...int32) *schemapb.FieldData {
		return scalarGroupBy(schemapb.DataType_Int8, &schemapb.ScalarField{
			Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: values}},
		})
	}
	// boolGroupBy builds a Bool-typed group-by column.
	boolGroupBy := func(values ...bool) *schemapb.FieldData {
		return scalarGroupBy(schemapb.DataType_Bool, &schemapb.ScalarField{
			Data: &schemapb.ScalarField_BoolData{BoolData: &schemapb.BoolArray{Data: values}},
		})
	}
	// stringGroupBy builds a VarChar-typed group-by column.
	stringGroupBy := func(values ...string) *schemapb.FieldData {
		return scalarGroupBy(schemapb.DataType_VarChar, &schemapb.ScalarField{
			Data: &schemapb.ScalarField_StringData{StringData: &schemapb.StringArray{Data: values}},
		})
	}

	// partial builds one single-query partial result whose hit count is
	// len(ids) and attaches the given group-by column. Every call gets its own
	// FieldData, so no two results share a message.
	partial := func(ids []int64, scores []float32, groupBy *schemapb.FieldData) *schemapb.SearchResultData {
		data := genSearchResultData(nq, topk, ids, scores, []int64{int64(len(ids))})
		data.GroupByFieldValue = groupBy
		return data
	}

	// reduce runs the reducer that InitSearchReducer selects for info. Require
	// aborts the subtest on error, so a nil result is never dereferenced.
	reduce := func(info *ReduceInfo, results []*schemapb.SearchResultData) *schemapb.SearchResultData {
		res, err := InitSearchReducer(info).ReduceSearchResultData(context.TODO(), results, info)
		s.Require().NoError(err)
		return res
	}

	s.Run("reduce_group_by_int", func() {
		// No groupSize is set, and the expected output keeps one hit per key.
		results := []*schemapb.SearchResultData{
			partial([]int64{1, 2, 3, 4}, []float32{-1.0, -2.0, -3.0, -4.0}, intGroupBy(2, 3, 4, 5)),
			partial([]int64{5, 1, 3, 4}, []float32{-1.0, -1.0, -3.0, -4.0}, intGroupBy(2, 3, 4, 5)),
		}
		res := reduce(&ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101}, results)
		s.ElementsMatch([]int64{1, 2, 3, 4}, res.Ids.GetIntId().Data)
		s.ElementsMatch([]float32{-1.0, -2.0, -3.0, -4.0}, res.Scores)
		s.ElementsMatch([]int32{2, 3, 4, 5}, res.GroupByFieldValue.GetScalars().GetIntData().Data)
	})

	s.Run("reduce_group_by_bool", func() {
		// Only two keys exist (true/false), so at most two hits survive even
		// though topK is 4.
		results := []*schemapb.SearchResultData{
			partial([]int64{1, 2}, []float32{-1.0, -2.0}, boolGroupBy(true, false)),
			partial([]int64{3, 4}, []float32{-1.0, -1.0}, boolGroupBy(true, false)),
		}
		res := reduce(&ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101}, results)
		s.ElementsMatch([]int64{1, 4}, res.Ids.GetIntId().Data)
		s.ElementsMatch([]float32{-1.0, -1.0}, res.Scores)
		s.ElementsMatch([]bool{true, false}, res.GroupByFieldValue.GetScalars().GetBoolData().Data)
	})

	s.Run("reduce_group_by_string", func() {
		results := []*schemapb.SearchResultData{
			partial([]int64{1, 2, 3, 4}, []float32{-1.0, -2.0, -3.0, -4.0}, stringGroupBy("1", "2", "3", "4")),
			partial([]int64{5, 1, 3, 4}, []float32{-1.0, -1.0, -3.0, -4.0}, stringGroupBy("1", "2", "3", "4")),
		}
		res := reduce(&ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101}, results)
		s.ElementsMatch([]int64{1, 2, 3, 4}, res.Ids.GetIntId().Data)
		s.ElementsMatch([]float32{-1.0, -2.0, -3.0, -4.0}, res.Scores)
		s.ElementsMatch([]string{"1", "2", "3", "4"}, res.GroupByFieldValue.GetScalars().GetStringData().Data)
	})

	s.Run("reduce_group_by_string_with_group_size", func() {
		// groupSize 3 lets a key hold several hits. PK 4 appears in both
		// results (key "1" at -1, key "4" at -4) but is kept only once.
		results := []*schemapb.SearchResultData{
			partial([]int64{1, 2, 3, 4}, []float32{-1.0, -2.0, -3.0, -4.0}, stringGroupBy("1", "2", "3", "4")),
			partial([]int64{4, 5, 6, 7}, []float32{-1.0, -1.0, -3.0, -4.0}, stringGroupBy("1", "2", "3", "4")),
		}
		res := reduce(&ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101, groupSize: 3}, results)
		s.ElementsMatch([]int64{1, 4, 2, 5, 3, 6, 7}, res.Ids.GetIntId().Data)
		s.ElementsMatch([]float32{-1.0, -1.0, -1.0, -2.0, -3.0, -3.0, -4.0}, res.Scores)
		s.ElementsMatch([]string{"1", "1", "2", "2", "3", "3", "4"}, res.GroupByFieldValue.GetScalars().GetStringData().Data)
	})

	s.Run("reduce_group_by_empty_input", func() {
		// An empty (non-nil) input slice must yield an empty result that still
		// echoes nq and topK.
		info := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101, groupSize: 3}
		res := reduce(info, make([]*schemapb.SearchResultData, 0))
		s.Nil(res.GetIds().GetIdField())
		s.Equal(0, len(res.GetTopks()))
		s.Equal(0, len(res.GetScores()))
		s.Equal(int64(nq), res.GetNumQueries())
		s.Equal(int64(topk), res.GetTopK())
		s.Equal(0, len(res.GetFieldsData()))
	})
}
// TestSearchReduce hooks SearchReduceSuite into `go test`. paramtable is
// initialized before the suite runs. Presumably the reducers read
// configuration from it (NOTE(review): confirm).
func TestSearchReduce(t *testing.T) {
	paramtable.Init()
	suite.Run(t, &SearchReduceSuite{})
}