fix: iterator cursor progress too fast(#36179) (#36180)

related: #36179

Signed-off-by: MrPresent-Han <chun.han@gmail.com>
Co-authored-by: MrPresent-Han <chun.han@gmail.com>
pull/36463/head
Chun Han 2024-09-24 11:45:13 +08:00 committed by GitHub
parent 4779c6cb8f
commit df7ae08851
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 222 additions and 54 deletions

View File

@ -184,11 +184,11 @@ message RetrieveRequest {
bool is_count = 13;
int64 iteration_extension_reduce_rate = 14;
string username = 15;
bool reduce_stop_for_best = 16;
bool reduce_stop_for_best = 16; //deprecated
int32 reduce_type = 17;
}
message RetrieveResults {
common.MsgBase base = 1;
common.Status status = 2;

View File

@ -20,6 +20,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/exprutil"
"github.com/milvus-io/milvus/internal/util/reduce"
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
@ -77,7 +78,7 @@ type queryTask struct {
type queryParams struct {
limit int64
offset int64
reduceStopForBest bool
reduceType reduce.IReduceType
}
// translateToOutputFieldIDs translates output fields name to output fields id.
@ -142,6 +143,7 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e
limit int64
offset int64
reduceStopForBest bool
isIterator bool
err error
)
reduceStopForBestStr, err := funcutil.GetAttrByKeyFromRepeatedKV(ReduceStopForBestKey, queryParamsPair)
@ -154,10 +156,29 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e
}
}
isIteratorStr, err := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, queryParamsPair)
// if reduce_stop_for_best is provided
if err == nil {
isIterator, err = strconv.ParseBool(isIteratorStr)
if err != nil {
return nil, merr.WrapErrParameterInvalid("true or false", isIteratorStr,
"value for iterator field is invalid")
}
}
reduceType := reduce.IReduceNoOrder
if isIterator {
if reduceStopForBest {
reduceType = reduce.IReduceInOrderForBest
} else {
reduceType = reduce.IReduceInOrder
}
}
limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, queryParamsPair)
// if limit is not provided
if err != nil {
return &queryParams{limit: typeutil.Unlimited, reduceStopForBest: reduceStopForBest}, nil
return &queryParams{limit: typeutil.Unlimited, reduceType: reduceType}, nil
}
limit, err = strconv.ParseInt(limitStr, 0, 64)
if err != nil {
@ -181,7 +202,7 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e
return &queryParams{
limit: limit,
offset: offset,
reduceStopForBest: reduceStopForBest,
reduceType: reduceType,
}, nil
}
@ -343,7 +364,10 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
if err != nil {
return err
}
t.RetrieveRequest.ReduceStopForBest = queryParams.reduceStopForBest
if queryParams.reduceType == reduce.IReduceInOrderForBest {
t.RetrieveRequest.ReduceStopForBest = true
}
t.RetrieveRequest.ReduceType = int32(queryParams.reduceType)
t.queryParams = queryParams
t.RetrieveRequest.Limit = queryParams.limit + queryParams.offset
@ -612,9 +636,10 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re
cursors := make([]int64, len(validRetrieveResults))
if queryParams != nil && queryParams.limit != typeutil.Unlimited {
// reduceStopForBest will try to get as many results as possible
// IReduceInOrderForBest will try to get as many results as possible
// so loopEnd in this case will be set to the sum of all results' size
if !queryParams.reduceStopForBest {
// to get as many qualified results as possible
if reduce.ShouldUseInputLimit(queryParams.reduceType) {
loopEnd = int(queryParams.limit)
}
}
@ -623,7 +648,7 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re
if queryParams != nil && queryParams.offset > 0 {
for i := int64(0); i < queryParams.offset; i++ {
sel, drainOneResult := typeutil.SelectMinPK(validRetrieveResults, cursors)
if sel == -1 || (queryParams.reduceStopForBest && drainOneResult) {
if sel == -1 || (reduce.ShouldStopWhenDrained(queryParams.reduceType) && drainOneResult) {
return ret, nil
}
cursors[sel]++
@ -635,7 +660,7 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
for j := 0; j < loopEnd; j++ {
sel, drainOneResult := typeutil.SelectMinPK(validRetrieveResults, cursors)
if sel == -1 || (queryParams.reduceStopForBest && drainOneResult) {
if sel == -1 || (reduce.ShouldStopWhenDrained(queryParams.reduceType) && drainOneResult) {
break
}
retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel])

View File

@ -34,6 +34,7 @@ import (
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/reduce"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
@ -442,6 +443,101 @@ func TestTaskQuery_functions(t *testing.T) {
}
})
t.Run("test parseQueryParams for reduce type", func(t *testing.T) {
{
var inParams []*commonpb.KeyValuePair
inParams = append(inParams, &commonpb.KeyValuePair{
Key: ReduceStopForBestKey,
Value: "True",
})
inParams = append(inParams, &commonpb.KeyValuePair{
Key: IteratorField,
Value: "True",
})
ret, err := parseQueryParams(inParams)
assert.NoError(t, err)
assert.Equal(t, reduce.IReduceInOrderForBest, ret.reduceType)
}
{
var inParams []*commonpb.KeyValuePair
inParams = append(inParams, &commonpb.KeyValuePair{
Key: ReduceStopForBestKey,
Value: "True",
})
inParams = append(inParams, &commonpb.KeyValuePair{
Key: IteratorField,
Value: "TrueXXXX",
})
ret, err := parseQueryParams(inParams)
assert.Error(t, err)
assert.Nil(t, ret)
}
{
var inParams []*commonpb.KeyValuePair
inParams = append(inParams, &commonpb.KeyValuePair{
Key: ReduceStopForBestKey,
Value: "TrueXXXXX",
})
inParams = append(inParams, &commonpb.KeyValuePair{
Key: IteratorField,
Value: "True",
})
ret, err := parseQueryParams(inParams)
assert.Error(t, err)
assert.Nil(t, ret)
}
{
var inParams []*commonpb.KeyValuePair
inParams = append(inParams, &commonpb.KeyValuePair{
Key: ReduceStopForBestKey,
Value: "True",
})
// when not setting iterator tag, ignore reduce_stop_for_best
ret, err := parseQueryParams(inParams)
assert.NoError(t, err)
assert.Equal(t, reduce.IReduceNoOrder, ret.reduceType)
}
{
var inParams []*commonpb.KeyValuePair
inParams = append(inParams, &commonpb.KeyValuePair{
Key: IteratorField,
Value: "True",
})
// when not setting reduce_stop_for_best tag, reduce by keep results in order
ret, err := parseQueryParams(inParams)
assert.NoError(t, err)
assert.Equal(t, reduce.IReduceInOrder, ret.reduceType)
}
{
var inParams []*commonpb.KeyValuePair
inParams = append(inParams, &commonpb.KeyValuePair{
Key: ReduceStopForBestKey,
Value: "False",
})
inParams = append(inParams, &commonpb.KeyValuePair{
Key: IteratorField,
Value: "True",
})
ret, err := parseQueryParams(inParams)
assert.NoError(t, err)
assert.Equal(t, reduce.IReduceInOrder, ret.reduceType)
}
{
var inParams []*commonpb.KeyValuePair
inParams = append(inParams, &commonpb.KeyValuePair{
Key: ReduceStopForBestKey,
Value: "False",
})
inParams = append(inParams, &commonpb.KeyValuePair{
Key: IteratorField,
Value: "False",
})
ret, err := parseQueryParams(inParams)
assert.NoError(t, err)
assert.Equal(t, reduce.IReduceNoOrder, ret.reduceType)
}
})
t.Run("test reduceRetrieveResults", func(t *testing.T) {
const (
Dim = 8
@ -572,7 +668,7 @@ func TestTaskQuery_functions(t *testing.T) {
r2.HasMoreResult = false
result, err := reduceRetrieveResults(context.Background(),
[]*internalpb.RetrieveResults{r1, r2},
&queryParams{limit: 2, reduceStopForBest: true})
&queryParams{limit: 2, reduceType: reduce.IReduceInOrderForBest})
assert.NoError(t, err)
assert.Equal(t, 2, len(result.GetFieldsData()))
assert.Equal(t, []int64{11, 11, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
@ -585,7 +681,7 @@ func TestTaskQuery_functions(t *testing.T) {
r2.HasMoreResult = true
result, err := reduceRetrieveResults(context.Background(),
[]*internalpb.RetrieveResults{r1, r2},
&queryParams{limit: 1, offset: 1, reduceStopForBest: true})
&queryParams{limit: 1, offset: 1, reduceType: reduce.IReduceInOrderForBest})
assert.NoError(t, err)
assert.Equal(t, 2, len(result.GetFieldsData()))
assert.Equal(t, []int64{11, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
@ -596,7 +692,7 @@ func TestTaskQuery_functions(t *testing.T) {
r2.HasMoreResult = true
result, err := reduceRetrieveResults(context.Background(),
[]*internalpb.RetrieveResults{r1, r2},
&queryParams{limit: 2, offset: 1, reduceStopForBest: true})
&queryParams{limit: 2, offset: 1, reduceType: reduce.IReduceInOrderForBest})
assert.NoError(t, err)
assert.Equal(t, 2, len(result.GetFieldsData()))
@ -609,7 +705,7 @@ func TestTaskQuery_functions(t *testing.T) {
r2.HasMoreResult = false
result, err := reduceRetrieveResults(context.Background(),
[]*internalpb.RetrieveResults{r1, r2},
&queryParams{limit: typeutil.Unlimited, reduceStopForBest: true})
&queryParams{limit: typeutil.Unlimited, reduceType: reduce.IReduceInOrderForBest})
assert.NoError(t, err)
assert.Equal(t, 2, len(result.GetFieldsData()))
assert.Equal(t, []int64{11, 11, 22, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
@ -620,11 +716,21 @@ func TestTaskQuery_functions(t *testing.T) {
t.Run("test stop reduce for best for unlimited set amd offset", func(t *testing.T) {
result, err := reduceRetrieveResults(context.Background(),
[]*internalpb.RetrieveResults{r1, r2},
&queryParams{limit: typeutil.Unlimited, offset: 3, reduceStopForBest: true})
&queryParams{limit: typeutil.Unlimited, offset: 3, reduceType: reduce.IReduceInOrderForBest})
assert.NoError(t, err)
assert.Equal(t, 2, len(result.GetFieldsData()))
assert.Equal(t, []int64{22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
})
t.Run("test iterator without setting reduce stop for best", func(t *testing.T) {
r1.HasMoreResult = true
r2.HasMoreResult = true
result, err := reduceRetrieveResults(context.Background(),
[]*internalpb.RetrieveResults{r1, r2},
&queryParams{limit: 1, reduceType: reduce.IReduceInOrder})
assert.NoError(t, err)
assert.Equal(t, 2, len(result.GetFieldsData()))
assert.Equal(t, []int64{11}, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
})
})
})
}

View File

@ -7,6 +7,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/util/reduce"
)
type defaultLimitReducer struct {
@ -18,21 +19,21 @@ type mergeParam struct {
limit int64
outputFieldsId []int64
schema *schemapb.CollectionSchema
mergeStopForBest bool
reduceType reduce.IReduceType
}
func NewMergeParam(limit int64, outputFieldsId []int64, schema *schemapb.CollectionSchema, reduceStopForBest bool) *mergeParam {
func NewMergeParam(limit int64, outputFieldsId []int64, schema *schemapb.CollectionSchema, reduceType reduce.IReduceType) *mergeParam {
return &mergeParam{
limit: limit,
outputFieldsId: outputFieldsId,
schema: schema,
mergeStopForBest: reduceStopForBest,
reduceType: reduceType,
}
}
func (r *defaultLimitReducer) Reduce(ctx context.Context, results []*internalpb.RetrieveResults) (*internalpb.RetrieveResults, error) {
reduceParam := NewMergeParam(r.req.GetReq().GetLimit(), r.req.GetReq().GetOutputFieldsId(),
r.schema, r.req.GetReq().GetReduceStopForBest())
r.schema, reduce.ToReduceType(r.req.GetReq().GetReduceType()))
return mergeInternalRetrieveResultsAndFillIfEmpty(ctx, results, reduceParam)
}
@ -50,7 +51,7 @@ type defaultLimitReducerSegcore struct {
}
func (r *defaultLimitReducerSegcore) Reduce(ctx context.Context, results []*segcorepb.RetrieveResults, segments []Segment, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) {
mergeParam := NewMergeParam(r.req.GetReq().GetLimit(), r.req.GetReq().GetOutputFieldsId(), r.schema, r.req.GetReq().GetReduceStopForBest())
mergeParam := NewMergeParam(r.req.GetReq().GetLimit(), r.req.GetReq().GetOutputFieldsId(), r.schema, reduce.ToReduceType(r.req.GetReq().GetReduceType()))
return mergeSegcoreRetrieveResultsAndFillIfEmpty(ctx, results, mergeParam, segments, plan, r.manager)
}

View File

@ -293,7 +293,7 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna
return ret, nil
}
if param.limit != typeutil.Unlimited && !param.mergeStopForBest {
if param.limit != typeutil.Unlimited && reduce.ShouldUseInputLimit(param.reduceType) {
loopEnd = int(param.limit)
}
@ -305,7 +305,7 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
for j := 0; j < loopEnd; {
sel, drainOneResult := typeutil.SelectMinPKWithTimestamp(validRetrieveResults, cursors)
if sel == -1 || (param.mergeStopForBest && drainOneResult) {
if sel == -1 || (reduce.ShouldStopWhenDrained(param.reduceType) && drainOneResult) {
break
}
@ -416,7 +416,7 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
}
var limit int = -1
if param.limit != typeutil.Unlimited && !param.mergeStopForBest {
if param.limit != typeutil.Unlimited && reduce.ShouldUseInputLimit(param.reduceType) {
limit = int(param.limit)
}
@ -438,7 +438,7 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
for j := 0; j < loopEnd && (limit == -1 || availableCount < limit); j++ {
sel, drainOneResult := typeutil.SelectMinPKWithTimestamp(validRetrieveResults, cursors)
if sel == -1 || (param.mergeStopForBest && drainOneResult) {
if sel == -1 || (reduce.ShouldStopWhenDrained(param.reduceType) && drainOneResult) {
break
}

View File

@ -100,7 +100,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() {
}
result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2},
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder))
suite.NoError(err)
suite.Equal(3, len(result.GetFieldsData()))
suite.Equal([]int64{0, 1}, result.GetIds().GetIntId().GetData())
@ -114,7 +114,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() {
suite.Run("test nil results", func() {
ret, err := MergeSegcoreRetrieveResultsV1(context.Background(), nil,
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder))
suite.NoError(err)
suite.Empty(ret.GetIds())
suite.Empty(ret.GetFieldsData())
@ -133,7 +133,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() {
}
ret, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{r},
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder))
suite.NoError(err)
suite.Empty(ret.GetIds())
suite.Empty(ret.GetFieldsData())
@ -185,7 +185,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() {
for _, test := range tests {
suite.Run(test.description, func() {
result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{r1, r2},
NewMergeParam(test.limit, make([]int64, 0), nil, false))
NewMergeParam(test.limit, make([]int64, 0), nil, reduce.IReduceNoOrder))
suite.Equal(3, len(result.GetFieldsData()))
suite.Equal(int(test.limit), len(result.GetIds().GetIntId().GetData()))
suite.Equal(resultIDs[0:test.limit], result.GetIds().GetIntId().GetData())
@ -225,14 +225,14 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() {
}
_, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result},
NewMergeParam(reqLimit, make([]int64, 0), nil, false))
NewMergeParam(reqLimit, make([]int64, 0), nil, reduce.IReduceNoOrder))
suite.Error(err)
paramtable.Get().Save(paramtable.Get().QuotaConfig.MaxOutputSize.Key, "1104857600")
})
suite.Run("test int ID", func() {
result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{r1, r2},
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder))
suite.Equal(3, len(result.GetFieldsData()))
suite.Equal([]int64{1, 2, 3, 4}, result.GetIds().GetIntId().GetData())
intFieldData, has := getFieldData(result, Int64FieldID)
@ -262,7 +262,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() {
}
result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{r1, r2},
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder))
suite.NoError(err)
suite.Equal(3, len(result.GetFieldsData()))
suite.Equal([]string{"a", "b", "c", "d"}, result.GetIds().GetStrId().GetData())
@ -321,7 +321,7 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() {
}
result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2},
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder))
suite.NoError(err)
suite.Equal(3, len(result.GetFieldsData()))
suite.Equal([]int64{0, 1}, result.GetIds().GetIntId().GetData())
@ -335,7 +335,7 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() {
suite.Run("test nil results", func() {
ret, err := MergeInternalRetrieveResult(context.Background(), nil,
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder))
suite.NoError(err)
suite.Empty(ret.GetIds())
suite.Empty(ret.GetFieldsData())
@ -373,7 +373,7 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() {
},
}
result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{ret1, ret2},
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder))
suite.NoError(err)
suite.Equal(2, len(result.GetFieldsData()))
suite.Equal([]int64{0, 1}, result.GetIds().GetIntId().GetData())
@ -424,7 +424,7 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() {
for _, test := range tests {
suite.Run(test.description, func() {
result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{r1, r2},
NewMergeParam(test.limit, make([]int64, 0), nil, false))
NewMergeParam(test.limit, make([]int64, 0), nil, reduce.IReduceNoOrder))
suite.Equal(3, len(result.GetFieldsData()))
suite.Equal(int(test.limit), len(result.GetIds().GetIntId().GetData()))
suite.Equal(resultIDs[0:test.limit], result.GetIds().GetIntId().GetData())
@ -463,14 +463,14 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() {
}
_, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result, result},
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder))
suite.Error(err)
paramtable.Get().Save(paramtable.Get().QuotaConfig.MaxOutputSize.Key, "1104857600")
})
suite.Run("test int ID", func() {
result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{r1, r2},
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder))
suite.Equal(3, len(result.GetFieldsData()))
suite.Equal([]int64{1, 2, 3, 4}, result.GetIds().GetIntId().GetData())
@ -501,7 +501,7 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() {
}
result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{r1, r2},
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder))
suite.NoError(err)
suite.Equal(3, len(result.GetFieldsData()))
suite.Equal([]string{"a", "b", "c", "d"}, result.GetIds().GetStrId().GetData())
@ -568,7 +568,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() {
result1.HasMoreResult = true
result2.HasMoreResult = true
result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2},
NewMergeParam(3, make([]int64, 0), nil, true))
NewMergeParam(3, make([]int64, 0), nil, reduce.IReduceInOrderForBest))
suite.NoError(err)
suite.Equal(3, len(result.GetFieldsData()))
// has more result both, stop reduce when draining one result
@ -586,7 +586,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() {
result1.HasMoreResult = false
result2.HasMoreResult = false
result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2},
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, true))
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceInOrderForBest))
suite.NoError(err)
suite.Equal(3, len(result.GetFieldsData()))
// as result1 and result2 don't have better results neither
@ -604,7 +604,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() {
result1.HasMoreResult = true
result2.HasMoreResult = false
result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2},
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, true))
NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceInOrderForBest))
suite.NoError(err)
suite.Equal(3, len(result.GetFieldsData()))
// as result1 may have better results, stop reducing when draining it
@ -643,7 +643,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() {
result1.HasMoreResult = true
result2.HasMoreResult = false
result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2},
NewMergeParam(3, make([]int64, 0), nil, true))
NewMergeParam(3, make([]int64, 0), nil, reduce.IReduceInOrderForBest))
suite.NoError(err)
suite.Equal(3, len(result.GetFieldsData()))
suite.Equal([]int64{0, 2, 4, 6, 7}, result.GetIds().GetIntId().GetData())
@ -687,7 +687,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() {
result1.HasMoreResult = false
result2.HasMoreResult = false
result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2},
NewMergeParam(3, make([]int64, 0), nil, true))
NewMergeParam(3, make([]int64, 0), nil, reduce.IReduceInOrderForBest))
suite.NoError(err)
suite.Equal(3, len(result.GetFieldsData()))
suite.Equal([]int64{0, 2, 4, 7}, result.GetIds().GetIntId().GetData())
@ -696,11 +696,20 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() {
result1.HasMoreResult = false
result2.HasMoreResult = true
result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2},
NewMergeParam(3, make([]int64, 0), nil, true))
NewMergeParam(3, make([]int64, 0), nil, reduce.IReduceInOrderForBest))
suite.NoError(err)
suite.Equal(3, len(result.GetFieldsData()))
suite.Equal([]int64{0, 2}, result.GetIds().GetIntId().GetData())
})
suite.Run("test no stop reduce for best ", func() {
result1.HasMoreResult = true
result2.HasMoreResult = true
result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2},
NewMergeParam(1, make([]int64, 0), nil, reduce.IReduceInOrder))
suite.NoError(err)
suite.Equal(3, len(result.GetFieldsData()))
suite.Equal([]int64{0}, result.GetIds().GetIntId().GetData())
})
})
}

View File

@ -90,3 +90,30 @@ func (r *ResultInfo) GetIsAdvance() bool {
func (r *ResultInfo) SetMetricType(metricType string) {
r.metricType = metricType
}
type IReduceType int32
const (
IReduceNoOrder IReduceType = iota
IReduceInOrder
IReduceInOrderForBest
)
func ShouldStopWhenDrained(reduceType IReduceType) bool {
return reduceType == IReduceInOrder || reduceType == IReduceInOrderForBest
}
func ToReduceType(val int32) IReduceType {
switch val {
case 1:
return IReduceInOrder
case 2:
return IReduceInOrderForBest
default:
return IReduceNoOrder
}
}
func ShouldUseInputLimit(reduceType IReduceType) bool {
return reduceType == IReduceNoOrder || reduceType == IReduceInOrder
}