mirror of https://github.com/milvus-io/milvus.git

enhance: simplify reduction on single search result (#36334)

See: #36122

Signed-off-by: Ted Xu <ted.xu@zilliz.com>

branch: pull/36278/head
parent: 89397d1e66
commit: 363004fd44
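At its core, the change adds a fast path to the proxy-side reduce: when there is exactly one sub-result and no offset, the k-way merge across shards (and the duplicate-PK bookkeeping that went with it) does no useful work, so the sub-result is handed back directly. A minimal sketch of the pattern, using a hypothetical Results type in place of the actual schemapb.SearchResultData:

    package main

    import "fmt"

    // Results stands in for schemapb.SearchResultData, and mergeKWay for the
    // full cursor-based merge; both are hypothetical, for illustration only.
    type Results struct{ Ids []int64 }

    func mergeKWay(sub []*Results, offset int) *Results {
        // ... full k-way merge over per-sub-result cursors elided ...
        return &Results{}
    }

    // reduce shows the fast path: a single sub-result is already sorted by the
    // query node, so with no offset to skip it can be returned as-is.
    func reduce(sub []*Results, offset int) *Results {
        if len(sub) == 1 && offset == 0 {
            return sub[0]
        }
        return mergeKWay(sub, offset)
    }

    func main() {
        fmt.Println(reduce([]*Results{{Ids: []int64{7, 5, 4}}}, 0).Ids) // [7 5 4]
    }

The hunks below apply this shape to both the group-by and plain search reduce paths, and swap the boxed primary-key copy for a typed one.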
@@ -218,55 +218,49 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData

In reduceSearchResultDataWithGroupBy, the merge body is wrapped in the new fast path: a single sub-result with no offset is assigned to the response directly. The region after the change:

        totalResCount += subSearchNqOffset[i][nq-1]
    }

    var (
        skipDupCnt int64
        realTopK   int64 = -1
    )
    if subSearchNum == 1 && offset == 0 {
        ret.Results = subSearchResultData[0]
    } else {
        var retSize int64
        maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()

        // reducing nq * topk results
        for i := int64(0); i < nq; i++ {
            var (
                // cursor of current data of each subSearch for merging the j-th data of TopK.
                // sum(cursors) == j
                cursors = make([]int64, subSearchNum)

                j              int64
                pkSet          = make(map[interface{}]struct{})
                groupByValMap  = make(map[interface{}][]*groupReduceInfo)
                skipOffsetMap  = make(map[interface{}]bool)
                groupByValList = make([]interface{}, limit)
                groupByValIdx  = 0
            )
            for j = 0; j < groupBound; {
                subSearchIdx, resultDataIdx := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
                if subSearchIdx == -1 {
                    break
                }
                subSearchRes := subSearchResultData[subSearchIdx]

                id := typeutil.GetPK(subSearchRes.GetIds(), resultDataIdx)
                score := subSearchRes.GetScores()[resultDataIdx]
                groupByVal := typeutil.GetData(subSearchRes.GetGroupByFieldValue(), int(resultDataIdx))
                if groupByVal == nil {
                    return nil, errors.New("get nil groupByVal from subSearchRes, wrong states, as milvus doesn't support nil value," +
                        "there must be sth wrong on queryNode side")
                }
                // remove duplicates
                if _, ok := pkSet[id]; !ok {
                    if int64(len(skipOffsetMap)) < offset || skipOffsetMap[groupByVal] {
                        skipOffsetMap[groupByVal] = true
                        // the first offset's group will be ignored
                        skipDupCnt++
                    } else if len(groupByValMap[groupByVal]) == 0 && int64(len(groupByValMap)) >= limit {
                        // skip when groupbyMap has been full and found new groupByVal
                        skipDupCnt++
                    } else if int64(len(groupByValMap[groupByVal])) >= groupSize {
                        // skip when target group has been full
                        skipDupCnt++
                    } else {
                        if len(groupByValMap[groupByVal]) == 0 {
                            groupByValList[groupByValIdx] = groupByVal
@@ -276,55 +270,43 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData

                            subSearchIdx: subSearchIdx,
                            resultIdx:    resultDataIdx, id: id, score: score,
                        })
                        pkSet[id] = struct{}{}
                        j++
                    }
                } else {
                    skipDupCnt++
                }
                cursors[subSearchIdx]++
            }

            // assemble all eligible values in group
            // values in groupByValList is sorted by the highest score in each group
            for _, groupVal := range groupByValList {
                if groupVal != nil {
                    groupEntities := groupByValMap[groupVal]
                    for _, groupEntity := range groupEntities {
                        subResData := subSearchResultData[groupEntity.subSearchIdx]
                        retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subResData.FieldsData, groupEntity.resultIdx)
                        typeutil.AppendPKs(ret.Results.Ids, groupEntity.id)
                        ret.Results.Scores = append(ret.Results.Scores, groupEntity.score)
                        if err := typeutil.AppendGroupByValue(ret.Results, groupVal, subResData.GetGroupByFieldValue().GetType()); err != nil {
                            log.Ctx(ctx).Error("failed to append groupByValues", zap.Error(err))
                            return ret, err
                        }
                    }
                }
            }

            if realTopK != -1 && realTopK != j {
                log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
            }
            realTopK = j
            ret.Results.Topks = append(ret.Results.Topks, realTopK)

            // limit search result to avoid oom
            if retSize > maxOutputSize {
                return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
            }
        }
        ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
    }

    log.Ctx(ctx).Debug("skip duplicated search result when doing group by", zap.Int64("count", skipDupCnt))

    if float64(skipDupCnt) >= float64(totalResCount)*0.3 {
        log.Warn("GroupBy reduce skipped too many results, "+
            "this may influence the final result seriously",
            zap.Int64("skipDupCnt", skipDupCnt),
            zap.Int64("groupBound", groupBound))
    }

    if !metric.PositivelyRelated(metricType) {
        for k := range ret.Results.Scores {
            ret.Results.Scores[k] *= -1
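Both reduce paths lean on selectHighestScoreIndex, which this commit calls but does not change: each merge step scans the per-sub-result cursors and takes the entry with the best score. A simplified sketch of that selection step (assumed semantics; the real helper takes the result sets plus the per-nq offsets and has its own tie-breaking):

    package main

    import "fmt"

    // selectHighest picks the sub-result whose cursor currently points at the
    // highest score; it returns (-1, -1) once every cursor is exhausted.
    func selectHighest(scores [][]float32, cursors []int64) (int, int64) {
        subIdx, dataIdx := -1, int64(-1)
        var best float32
        for k, cur := range cursors {
            if cur >= int64(len(scores[k])) {
                continue // this sub-result is drained
            }
            if subIdx == -1 || scores[k][cur] > best {
                subIdx, dataIdx, best = k, cur, scores[k][cur]
            }
        }
        return subIdx, dataIdx
    }

    func main() {
        scores := [][]float32{{0.9, 0.4}, {0.7, 0.6}}
        cursors := []int64{0, 0}
        for {
            sub, idx := selectHighest(scores, cursors)
            if sub == -1 {
                break
            }
            fmt.Printf("take sub=%d idx=%d score=%.1f\n", sub, idx, scores[sub][idx])
            cursors[sub]++ // advance the winning cursor, as the reduce loops do
        }
        // Emits 0.9, 0.7, 0.6, 0.4: a descending k-way merge.
    }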
@@ -370,91 +352,79 @@ func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []

reduceSearchResultDataNoGroupBy gets the same single-result fast path, and its per-hit duplicate-PK bookkeeping (idSet, skipDupCnt and the skip-duplicate logging) is dropped; primary keys are now copied with the new typed typeutil.CopyPk instead of being boxed through GetPK/AppendPKs. The region after the change:

        ret.GetResults().AllSearchCount = allSearchCount
    }

    subSearchNum := len(subSearchResultData)
    if subSearchNum == 1 && offset == 0 {
        // sorting is not needed if there is only one shard and no offset, assigning the result directly.
        // we still need to adjust the scores later.
        ret.Results = subSearchResultData[0]
        // realTopK is the topK of the nq-th query, it is used in proxy but not handled by delegator.
        topks := subSearchResultData[0].Topks
        if len(topks) > 0 {
            ret.Results.TopK = topks[len(topks)-1]
        }
    } else {
        var realTopK int64 = -1
        var retSize int64

        // for results of each subSearchResultData, storing the start offset of each query of nq queries
        subSearchNqOffset := make([][]int64, subSearchNum)
        for i := 0; i < subSearchNum; i++ {
            subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
            for j := int64(1); j < nq; j++ {
                subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
            }
        }
        maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()

        // reducing nq * topk results
        for i := int64(0); i < nq; i++ {
            var (
                // cursor of current data of each subSearch for merging the j-th data of TopK.
                // sum(cursors) == j
                cursors = make([]int64, subSearchNum)

                j int64
            )

            // skip offset results
            for k := int64(0); k < offset; k++ {
                subSearchIdx, _ := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
                if subSearchIdx == -1 {
                    break
                }
                cursors[subSearchIdx]++
            }

            // keep limit results
            for j = 0; j < limit; j++ {
                // From all the sub-query result sets of the i-th query vector,
                // find the sub-query result set index of the score j-th data,
                // and the index of the data in schemapb.SearchResultData
                subSearchIdx, resultDataIdx := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
                if subSearchIdx == -1 {
                    break
                }
                score := subSearchResultData[subSearchIdx].Scores[resultDataIdx]

                retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx)
                typeutil.CopyPk(ret.Results.Ids, subSearchResultData[subSearchIdx].GetIds(), int(resultDataIdx))
                ret.Results.Scores = append(ret.Results.Scores, score)
                cursors[subSearchIdx]++
            }
            if realTopK != -1 && realTopK != j {
                log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
                // return nil, errors.New("the length (topk) between all result of query is different")
            }
            realTopK = j
            ret.Results.Topks = append(ret.Results.Topks, realTopK)

            // limit search result to avoid oom
            if retSize > maxOutputSize {
                return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
            }
        }
        ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
    }
    if !metric.PositivelyRelated(metricType) {
        for k := range ret.Results.Scores {
            ret.Results.Scores[k] *= -1
@@ -14,7 +14,7 @@ type SearchReduceUtilTestSuite struct {

In the search-reduce tests, the shared fixture moves out of TestRankByGroup into genTestDataSearchResultsData:

        suite.Suite
    }

    func genTestDataSearchResultsData() []*schemapb.SearchResultData {
        var searchResultData1 *schemapb.SearchResultData
        var searchResultData2 *schemapb.SearchResultData
@@ -49,10 +49,14 @@ func (struts *SearchReduceUtilTestSuite) TestRankByGroup() {

            GroupByFieldValue: getFieldData("string", int64(101), schemapb.DataType_VarChar, groupFieldValue, 1),
        }
        return []*schemapb.SearchResultData{searchResultData1, searchResultData2}
    }

    func (struts *SearchReduceUtilTestSuite) TestRankByGroup() {
        data := genTestDataSearchResultsData()
        searchResults := []*milvuspb.SearchResults{
            {Results: data[0]},
            {Results: data[1]},
        }

        nq := int64(1)
@@ -128,6 +132,16 @@ func (struts *SearchReduceUtilTestSuite) TestRankByGroup() {

        }
    }

    func (struts *SearchReduceUtilTestSuite) TestReduceSearchResult() {
        data := genTestDataSearchResultsData()

        {
            results, err := reduceSearchResultDataNoGroupBy(context.Background(), []*schemapb.SearchResultData{data[0]}, 0, 0, "L2", schemapb.DataType_Int64, 0)
            struts.NoError(err)
            struts.Equal([]string{"7", "5", "4", "2", "3", "6", "1", "9", "8"}, results.Results.GetIds().GetStrId().Data)
        }
    }

    func TestSearchReduceUtilTestSuite(t *testing.T) {
        suite.Run(t, new(SearchReduceUtilTestSuite))
    }
@@ -591,10 +591,8 @@ func IDs2Expr(fieldName string, ids *schemapb.IDs) string {

reduceRetrieveResults loses its duplicate-PK tracking as well: idSet and skipDupCnt are removed, and the merge loop advances unconditionally. After the change:

    func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.RetrieveResults, queryParams *queryParams) (*milvuspb.QueryResults, error) {
        log.Ctx(ctx).Debug("reduceInternalRetrieveResults", zap.Int("len(retrieveResults)", len(retrieveResults)))
        var (
            ret     = &milvuspb.QueryResults{}
            loopEnd int
        )

        validRetrieveResults := []*internalpb.RetrieveResults{}
@@ -611,7 +609,6 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re

            return ret, nil
        }

        cursors := make([]int64, len(validRetrieveResults))

        if queryParams != nil && queryParams.limit != typeutil.Unlimited {
@@ -636,21 +633,12 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re

        ret.FieldsData = typeutil.PrepareResultFieldData(validRetrieveResults[0].GetFieldsData(), int64(loopEnd))
        var retSize int64
        maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
        for j := 0; j < loopEnd; j++ {
            sel, drainOneResult := typeutil.SelectMinPK(validRetrieveResults, cursors)
            if sel == -1 || (queryParams.reduceStopForBest && drainOneResult) {
                break
            }

            retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel])

            // limit retrieve result to avoid oom
            if retSize > maxOutputSize {
@@ -660,10 +648,6 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re

            cursors[sel]++
        }

        return ret, nil
    }
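The retrieve path merges by primary key rather than by score: typeutil.SelectMinPK (unchanged by this commit) picks the sub-result whose cursor sits on the smallest PK, so rows come out in PK order. A rough sketch over plain int64 keys (assumed semantics; the real helper works on RetrieveResults, handles string PKs, and its drain signal is what reduceStopForBest checks to stop early):

    package main

    import "fmt"

    // selectMinPK returns the index whose cursor points at the smallest key,
    // plus whether some input has run out of rows; -1 means all are drained.
    func selectMinPK(pks [][]int64, cursors []int64) (int, bool) {
        sel, drained := -1, false
        var minPk int64
        for k, cur := range cursors {
            if cur >= int64(len(pks[k])) {
                drained = true // this input has no rows left
                continue
            }
            if sel == -1 || pks[k][cur] < minPk {
                sel, minPk = k, pks[k][cur]
            }
        }
        return sel, drained
    }

    func main() {
        pks := [][]int64{{1, 4, 9}, {2, 3, 5}}
        cursors := []int64{0, 0}
        for {
            sel, _ := selectMinPK(pks, cursors)
            if sel == -1 {
                break
            }
            fmt.Print(pks[sel][cursors[sel]], " ") // 1 2 3 4 5 9
            cursors[sel]++
        }
    }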
@@ -461,34 +461,6 @@ func TestTaskQuery_functions(t *testing.T) {

        fieldDataArray2 = append(fieldDataArray2, getFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1))
        fieldDataArray2 = append(fieldDataArray2, getFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim))

The "test skip dupPK 2" subtest is deleted; with duplicate-PK skipping gone from reduceRetrieveResults, it no longer applies. The removed code:

        t.Run("test skip dupPK 2", func(t *testing.T) {
            result1 := &internalpb.RetrieveResults{
                Ids: &schemapb.IDs{
                    IdField: &schemapb.IDs_IntId{
                        IntId: &schemapb.LongArray{
                            Data: []int64{0, 1},
                        },
                    },
                },
                FieldsData: fieldDataArray1,
            }
            result2 := &internalpb.RetrieveResults{
                Ids: &schemapb.IDs{
                    IdField: &schemapb.IDs_IntId{
                        IntId: &schemapb.LongArray{
                            Data: []int64{0, 1},
                        },
                    },
                },
                FieldsData: fieldDataArray2,
            }
            result, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{result1, result2}, &queryParams{limit: 2})
            assert.NoError(t, err)
            assert.Equal(t, 2, len(result.GetFieldsData()))
            assert.Equal(t, Int64Array, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
            assert.InDeltaSlice(t, FloatVector, result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10)
        })

The adjacent "test nil results" subtest is unchanged:

        t.Run("test nil results", func(t *testing.T) {
            ret, err := reduceRetrieveResults(context.Background(), nil, &queryParams{})
            assert.NoError(t, err)
@@ -0,0 +1,71 @@

A new file in package typeutil introduces typed primary-key slices over schemapb.IDs and the typed CopyPk helper:

    // Licensed to the LF AI & Data foundation under one
    // or more contributor license agreements. See the NOTICE file
    // distributed with this work for additional information
    // regarding copyright ownership. The ASF licenses this file
    // to you under the Apache License, Version 2.0 (the
    // "License"); you may not use this file except in compliance
    // with the License. You may obtain a copy of the License at
    //
    //     http://www.apache.org/licenses/LICENSE-2.0
    //
    // Unless required by applicable law or agreed to in writing, software
    // distributed under the License is distributed on an "AS IS" BASIS,
    // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    // See the License for the specific language governing permissions and
    // limitations under the License.

    package typeutil

    import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"

    type GenericSchemaSlice[T any] interface {
        Append(T)
        Get(int) T
    }

    type int64PkSlice struct {
        data *schemapb.IDs
    }

    func (s *int64PkSlice) Append(v int64) {
        s.data.GetIntId().Data = append(s.data.GetIntId().Data, v)
    }

    func (s *int64PkSlice) Get(i int) int64 {
        return s.data.GetIntId().Data[i]
    }

    type stringPkSlice struct {
        data *schemapb.IDs
    }

    func (s *stringPkSlice) Append(v string) {
        s.data.GetStrId().Data = append(s.data.GetStrId().Data, v)
    }

    func (s *stringPkSlice) Get(i int) string {
        return s.data.GetStrId().Data[i]
    }

    func NewInt64PkSchemaSlice(data *schemapb.IDs) GenericSchemaSlice[int64] {
        return &int64PkSlice{
            data: data,
        }
    }

    func NewStringPkSchemaSlice(data *schemapb.IDs) GenericSchemaSlice[string] {
        return &stringPkSlice{
            data: data,
        }
    }

    func CopyPk(dst *schemapb.IDs, src *schemapb.IDs, offset int) {
        switch dst.GetIdField().(type) {
        case *schemapb.IDs_IntId:
            v := src.GetIntId().Data[offset]
            dst.GetIntId().Data = append(dst.GetIntId().Data, v)
        case *schemapb.IDs_StrId:
            v := src.GetStrId().Data[offset]
            dst.GetStrId().Data = append(dst.GetStrId().Data, v)
        }
    }
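For reference, a small usage sketch of the new helpers: CopyPk appends the key at src[offset] to dst inside the concrete oneof branch, with no interface{} boxing, and the GenericSchemaSlice constructors give typed views over a schemapb.IDs. The import path for typeutil is assumed here:

    package main

    import (
        "fmt"

        "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
        "github.com/milvus-io/milvus/pkg/util/typeutil"
    )

    func main() {
        src := &schemapb.IDs{
            IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{10, 20, 30}}},
        }
        dst := &schemapb.IDs{
            IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{}}},
        }

        // Typed copy of src[1] into dst: no boxing through interface{}.
        typeutil.CopyPk(dst, src, 1)

        // Typed view over the same underlying IDs message.
        ints := typeutil.NewInt64PkSchemaSlice(dst)
        ints.Append(40)
        fmt.Println(dst.GetIntId().Data) // [20 40]
    }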
@@ -0,0 +1,176 @@

Its companion test file covers the typed slices and CopyPk, and benchmarks the typed copy against the interface{}-based GetPK/AppendPKs path:

    // Licensed to the LF AI & Data foundation under one
    // or more contributor license agreements. See the NOTICE file
    // distributed with this work for additional information
    // regarding copyright ownership. The ASF licenses this file
    // to you under the Apache License, Version 2.0 (the
    // "License"); you may not use this file except in compliance
    // with the License. You may obtain a copy of the License at
    //
    //     http://www.apache.org/licenses/LICENSE-2.0
    //
    // Unless required by applicable law or agreed to in writing, software
    // distributed under the License is distributed on an "AS IS" BASIS,
    // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    // See the License for the specific language governing permissions and
    // limitations under the License.

    package typeutil

    import (
        "testing"

        "github.com/stretchr/testify/assert"

        "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
    )

    func TestPKSlice(t *testing.T) {
        data1 := &schemapb.IDs{
            IdField: &schemapb.IDs_IntId{
                IntId: &schemapb.LongArray{
                    Data: []int64{1, 2, 3},
                },
            },
        }

        ints := NewInt64PkSchemaSlice(data1)
        assert.Equal(t, int64(1), ints.Get(0))
        assert.Equal(t, int64(2), ints.Get(1))
        assert.Equal(t, int64(3), ints.Get(2))
        ints.Append(4)
        assert.Equal(t, int64(4), ints.Get(3))

        data2 := &schemapb.IDs{
            IdField: &schemapb.IDs_StrId{
                StrId: &schemapb.StringArray{
                    Data: []string{"1", "2", "3"},
                },
            },
        }
        strs := NewStringPkSchemaSlice(data2)
        assert.Equal(t, "1", strs.Get(0))
        assert.Equal(t, "2", strs.Get(1))
        assert.Equal(t, "3", strs.Get(2))
        strs.Append("4")
        assert.Equal(t, "4", strs.Get(3))
    }

    func TestCopyPk(t *testing.T) {
        type args struct {
            dst      *schemapb.IDs
            src      *schemapb.IDs
            offset   int
            dstAfter *schemapb.IDs
        }
        tests := []struct {
            name string
            args args
        }{
            {
                name: "ints",
                args: args{
                    dst: &schemapb.IDs{
                        IdField: &schemapb.IDs_IntId{
                            IntId: &schemapb.LongArray{
                                Data: []int64{1, 2, 3},
                            },
                        },
                    },
                    src: &schemapb.IDs{
                        IdField: &schemapb.IDs_IntId{
                            IntId: &schemapb.LongArray{
                                Data: []int64{1, 2, 3},
                            },
                        },
                    },
                    offset: 0,
                    dstAfter: &schemapb.IDs{
                        IdField: &schemapb.IDs_IntId{
                            IntId: &schemapb.LongArray{
                                Data: []int64{1, 2, 3, 1},
                            },
                        },
                    },
                },
            },
            {
                name: "strs",
                args: args{
                    dst: &schemapb.IDs{
                        IdField: &schemapb.IDs_StrId{
                            StrId: &schemapb.StringArray{
                                Data: []string{"1", "2", "3"},
                            },
                        },
                    },
                    src: &schemapb.IDs{
                        IdField: &schemapb.IDs_StrId{
                            StrId: &schemapb.StringArray{
                                Data: []string{"1", "2", "3"},
                            },
                        },
                    },
                    offset: 0,
                    dstAfter: &schemapb.IDs{
                        IdField: &schemapb.IDs_StrId{
                            StrId: &schemapb.StringArray{
                                Data: []string{"1", "2", "3", "1"},
                            },
                        },
                    },
                },
            },
        }
        for _, tt := range tests {
            t.Run(tt.name, func(t *testing.T) {
                CopyPk(tt.args.dst, tt.args.src, tt.args.offset)
                assert.Equal(t, tt.args.dst, tt.args.dstAfter)
            })
        }
    }

    func BenchmarkCopyPK(b *testing.B) {
        internal := make([]int64, 1000)
        for i := 0; i < 1000; i++ {
            internal[i] = int64(i)
        }
        src := &schemapb.IDs{
            IdField: &schemapb.IDs_IntId{
                IntId: &schemapb.LongArray{
                    Data: internal,
                },
            },
        }

        b.ResetTimer()

        b.Run("Typed", func(b *testing.B) {
            dst := &schemapb.IDs{
                IdField: &schemapb.IDs_IntId{
                    IntId: &schemapb.LongArray{
                        Data: make([]int64, 0, 1000),
                    },
                },
            }
            for i := 0; i < b.N; i++ {
                for j := 0; j < GetSizeOfIDs(src); j++ {
                    CopyPk(dst, src, j)
                }
            }
        })
        b.Run("Any", func(b *testing.B) {
            dst := &schemapb.IDs{
                IdField: &schemapb.IDs_IntId{
                    IntId: &schemapb.LongArray{
                        Data: make([]int64, 0, 1000),
                    },
                },
            }
            for i := 0; i < b.N; i++ {
                for j := 0; j < GetSizeOfIDs(src); j++ {
                    pk := GetPK(src, int64(j))
                    AppendPKs(dst, pk)
                }
            }
        })
    }
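The benchmark contrasts the two copy strategies: "Any" routes every key through GetPK/AppendPKs, boxing each value into an interface{} and type-switching it back, while "Typed" stays inside the concrete oneof branch via CopyPk. To reproduce it locally, something like go test -run=^$ -bench=BenchmarkCopyPK in the typeutil package directory should work (the exact package path is not shown in this diff).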