mirror of https://github.com/milvus-io/milvus.git
				
				
				
			related: #36179 Signed-off-by: MrPresent-Han <chun.han@gmail.com> Co-authored-by: MrPresent-Han <chun.han@gmail.com>pull/36463/head
							parent
							
								
									4779c6cb8f
								
							
						
					
					
						commit
						df7ae08851
					
				| 
						 | 
				
			
			@ -184,11 +184,11 @@ message RetrieveRequest {
 | 
			
		|||
  bool is_count = 13;
 | 
			
		||||
  int64 iteration_extension_reduce_rate = 14;
 | 
			
		||||
  string username = 15;
 | 
			
		||||
  bool reduce_stop_for_best = 16;
 | 
			
		||||
  bool reduce_stop_for_best = 16; //deprecated
 | 
			
		||||
  int32 reduce_type = 17;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
message RetrieveResults {
 | 
			
		||||
  common.MsgBase base = 1;
 | 
			
		||||
  common.Status status = 2;
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -20,6 +20,7 @@ import (
 | 
			
		|||
	"github.com/milvus-io/milvus/internal/proto/querypb"
 | 
			
		||||
	"github.com/milvus-io/milvus/internal/types"
 | 
			
		||||
	"github.com/milvus-io/milvus/internal/util/exprutil"
 | 
			
		||||
	"github.com/milvus-io/milvus/internal/util/reduce"
 | 
			
		||||
	typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
 | 
			
		||||
	"github.com/milvus-io/milvus/pkg/common"
 | 
			
		||||
	"github.com/milvus-io/milvus/pkg/log"
 | 
			
		||||
| 
						 | 
				
			
			@ -75,9 +76,9 @@ type queryTask struct {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
type queryParams struct {
 | 
			
		||||
	limit             int64
 | 
			
		||||
	offset            int64
 | 
			
		||||
	reduceStopForBest bool
 | 
			
		||||
	limit      int64
 | 
			
		||||
	offset     int64
 | 
			
		||||
	reduceType reduce.IReduceType
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// translateToOutputFieldIDs translates output fields name to output fields id.
 | 
			
		||||
| 
						 | 
				
			
			@ -142,6 +143,7 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e
 | 
			
		|||
		limit             int64
 | 
			
		||||
		offset            int64
 | 
			
		||||
		reduceStopForBest bool
 | 
			
		||||
		isIterator        bool
 | 
			
		||||
		err               error
 | 
			
		||||
	)
 | 
			
		||||
	reduceStopForBestStr, err := funcutil.GetAttrByKeyFromRepeatedKV(ReduceStopForBestKey, queryParamsPair)
 | 
			
		||||
| 
						 | 
				
			
			@ -154,10 +156,29 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e
 | 
			
		|||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	isIteratorStr, err := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, queryParamsPair)
 | 
			
		||||
	// if reduce_stop_for_best is provided
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		isIterator, err = strconv.ParseBool(isIteratorStr)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, merr.WrapErrParameterInvalid("true or false", isIteratorStr,
 | 
			
		||||
				"value for iterator field is invalid")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	reduceType := reduce.IReduceNoOrder
 | 
			
		||||
	if isIterator {
 | 
			
		||||
		if reduceStopForBest {
 | 
			
		||||
			reduceType = reduce.IReduceInOrderForBest
 | 
			
		||||
		} else {
 | 
			
		||||
			reduceType = reduce.IReduceInOrder
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, queryParamsPair)
 | 
			
		||||
	// if limit is not provided
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return &queryParams{limit: typeutil.Unlimited, reduceStopForBest: reduceStopForBest}, nil
 | 
			
		||||
		return &queryParams{limit: typeutil.Unlimited, reduceType: reduceType}, nil
 | 
			
		||||
	}
 | 
			
		||||
	limit, err = strconv.ParseInt(limitStr, 0, 64)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
| 
						 | 
				
			
			@ -179,9 +200,9 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e
 | 
			
		|||
	}
 | 
			
		||||
 | 
			
		||||
	return &queryParams{
 | 
			
		||||
		limit:             limit,
 | 
			
		||||
		offset:            offset,
 | 
			
		||||
		reduceStopForBest: reduceStopForBest,
 | 
			
		||||
		limit:      limit,
 | 
			
		||||
		offset:     offset,
 | 
			
		||||
		reduceType: reduceType,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -343,7 +364,10 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
 | 
			
		|||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	t.RetrieveRequest.ReduceStopForBest = queryParams.reduceStopForBest
 | 
			
		||||
	if queryParams.reduceType == reduce.IReduceInOrderForBest {
 | 
			
		||||
		t.RetrieveRequest.ReduceStopForBest = true
 | 
			
		||||
	}
 | 
			
		||||
	t.RetrieveRequest.ReduceType = int32(queryParams.reduceType)
 | 
			
		||||
 | 
			
		||||
	t.queryParams = queryParams
 | 
			
		||||
	t.RetrieveRequest.Limit = queryParams.limit + queryParams.offset
 | 
			
		||||
| 
						 | 
				
			
			@ -612,9 +636,10 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re
 | 
			
		|||
	cursors := make([]int64, len(validRetrieveResults))
 | 
			
		||||
 | 
			
		||||
	if queryParams != nil && queryParams.limit != typeutil.Unlimited {
 | 
			
		||||
		// reduceStopForBest will try to get as many results as possible
 | 
			
		||||
		// IReduceInOrderForBest will try to get as many results as possible
 | 
			
		||||
		// so loopEnd in this case will be set to the sum of all results' size
 | 
			
		||||
		if !queryParams.reduceStopForBest {
 | 
			
		||||
		// to get as many qualified results as possible
 | 
			
		||||
		if reduce.ShouldUseInputLimit(queryParams.reduceType) {
 | 
			
		||||
			loopEnd = int(queryParams.limit)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -623,7 +648,7 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re
 | 
			
		|||
	if queryParams != nil && queryParams.offset > 0 {
 | 
			
		||||
		for i := int64(0); i < queryParams.offset; i++ {
 | 
			
		||||
			sel, drainOneResult := typeutil.SelectMinPK(validRetrieveResults, cursors)
 | 
			
		||||
			if sel == -1 || (queryParams.reduceStopForBest && drainOneResult) {
 | 
			
		||||
			if sel == -1 || (reduce.ShouldStopWhenDrained(queryParams.reduceType) && drainOneResult) {
 | 
			
		||||
				return ret, nil
 | 
			
		||||
			}
 | 
			
		||||
			cursors[sel]++
 | 
			
		||||
| 
						 | 
				
			
			@ -635,7 +660,7 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re
 | 
			
		|||
	maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
 | 
			
		||||
	for j := 0; j < loopEnd; j++ {
 | 
			
		||||
		sel, drainOneResult := typeutil.SelectMinPK(validRetrieveResults, cursors)
 | 
			
		||||
		if sel == -1 || (queryParams.reduceStopForBest && drainOneResult) {
 | 
			
		||||
		if sel == -1 || (reduce.ShouldStopWhenDrained(queryParams.reduceType) && drainOneResult) {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel])
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -34,6 +34,7 @@ import (
 | 
			
		|||
	"github.com/milvus-io/milvus/internal/mocks"
 | 
			
		||||
	"github.com/milvus-io/milvus/internal/proto/internalpb"
 | 
			
		||||
	"github.com/milvus-io/milvus/internal/proto/querypb"
 | 
			
		||||
	"github.com/milvus-io/milvus/internal/util/reduce"
 | 
			
		||||
	"github.com/milvus-io/milvus/pkg/common"
 | 
			
		||||
	"github.com/milvus-io/milvus/pkg/log"
 | 
			
		||||
	"github.com/milvus-io/milvus/pkg/util/funcutil"
 | 
			
		||||
| 
						 | 
				
			
			@ -442,6 +443,101 @@ func TestTaskQuery_functions(t *testing.T) {
 | 
			
		|||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("test parseQueryParams for reduce type", func(t *testing.T) {
 | 
			
		||||
		{
 | 
			
		||||
			var inParams []*commonpb.KeyValuePair
 | 
			
		||||
			inParams = append(inParams, &commonpb.KeyValuePair{
 | 
			
		||||
				Key:   ReduceStopForBestKey,
 | 
			
		||||
				Value: "True",
 | 
			
		||||
			})
 | 
			
		||||
			inParams = append(inParams, &commonpb.KeyValuePair{
 | 
			
		||||
				Key:   IteratorField,
 | 
			
		||||
				Value: "True",
 | 
			
		||||
			})
 | 
			
		||||
			ret, err := parseQueryParams(inParams)
 | 
			
		||||
			assert.NoError(t, err)
 | 
			
		||||
			assert.Equal(t, reduce.IReduceInOrderForBest, ret.reduceType)
 | 
			
		||||
		}
 | 
			
		||||
		{
 | 
			
		||||
			var inParams []*commonpb.KeyValuePair
 | 
			
		||||
			inParams = append(inParams, &commonpb.KeyValuePair{
 | 
			
		||||
				Key:   ReduceStopForBestKey,
 | 
			
		||||
				Value: "True",
 | 
			
		||||
			})
 | 
			
		||||
			inParams = append(inParams, &commonpb.KeyValuePair{
 | 
			
		||||
				Key:   IteratorField,
 | 
			
		||||
				Value: "TrueXXXX",
 | 
			
		||||
			})
 | 
			
		||||
			ret, err := parseQueryParams(inParams)
 | 
			
		||||
			assert.Error(t, err)
 | 
			
		||||
			assert.Nil(t, ret)
 | 
			
		||||
		}
 | 
			
		||||
		{
 | 
			
		||||
			var inParams []*commonpb.KeyValuePair
 | 
			
		||||
			inParams = append(inParams, &commonpb.KeyValuePair{
 | 
			
		||||
				Key:   ReduceStopForBestKey,
 | 
			
		||||
				Value: "TrueXXXXX",
 | 
			
		||||
			})
 | 
			
		||||
			inParams = append(inParams, &commonpb.KeyValuePair{
 | 
			
		||||
				Key:   IteratorField,
 | 
			
		||||
				Value: "True",
 | 
			
		||||
			})
 | 
			
		||||
			ret, err := parseQueryParams(inParams)
 | 
			
		||||
			assert.Error(t, err)
 | 
			
		||||
			assert.Nil(t, ret)
 | 
			
		||||
		}
 | 
			
		||||
		{
 | 
			
		||||
			var inParams []*commonpb.KeyValuePair
 | 
			
		||||
			inParams = append(inParams, &commonpb.KeyValuePair{
 | 
			
		||||
				Key:   ReduceStopForBestKey,
 | 
			
		||||
				Value: "True",
 | 
			
		||||
			})
 | 
			
		||||
			// when not setting iterator tag, ignore reduce_stop_for_best
 | 
			
		||||
			ret, err := parseQueryParams(inParams)
 | 
			
		||||
			assert.NoError(t, err)
 | 
			
		||||
			assert.Equal(t, reduce.IReduceNoOrder, ret.reduceType)
 | 
			
		||||
		}
 | 
			
		||||
		{
 | 
			
		||||
			var inParams []*commonpb.KeyValuePair
 | 
			
		||||
			inParams = append(inParams, &commonpb.KeyValuePair{
 | 
			
		||||
				Key:   IteratorField,
 | 
			
		||||
				Value: "True",
 | 
			
		||||
			})
 | 
			
		||||
			// when not setting reduce_stop_for_best tag, reduce by keep results in order
 | 
			
		||||
			ret, err := parseQueryParams(inParams)
 | 
			
		||||
			assert.NoError(t, err)
 | 
			
		||||
			assert.Equal(t, reduce.IReduceInOrder, ret.reduceType)
 | 
			
		||||
		}
 | 
			
		||||
		{
 | 
			
		||||
			var inParams []*commonpb.KeyValuePair
 | 
			
		||||
			inParams = append(inParams, &commonpb.KeyValuePair{
 | 
			
		||||
				Key:   ReduceStopForBestKey,
 | 
			
		||||
				Value: "False",
 | 
			
		||||
			})
 | 
			
		||||
			inParams = append(inParams, &commonpb.KeyValuePair{
 | 
			
		||||
				Key:   IteratorField,
 | 
			
		||||
				Value: "True",
 | 
			
		||||
			})
 | 
			
		||||
			ret, err := parseQueryParams(inParams)
 | 
			
		||||
			assert.NoError(t, err)
 | 
			
		||||
			assert.Equal(t, reduce.IReduceInOrder, ret.reduceType)
 | 
			
		||||
		}
 | 
			
		||||
		{
 | 
			
		||||
			var inParams []*commonpb.KeyValuePair
 | 
			
		||||
			inParams = append(inParams, &commonpb.KeyValuePair{
 | 
			
		||||
				Key:   ReduceStopForBestKey,
 | 
			
		||||
				Value: "False",
 | 
			
		||||
			})
 | 
			
		||||
			inParams = append(inParams, &commonpb.KeyValuePair{
 | 
			
		||||
				Key:   IteratorField,
 | 
			
		||||
				Value: "False",
 | 
			
		||||
			})
 | 
			
		||||
			ret, err := parseQueryParams(inParams)
 | 
			
		||||
			assert.NoError(t, err)
 | 
			
		||||
			assert.Equal(t, reduce.IReduceNoOrder, ret.reduceType)
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("test reduceRetrieveResults", func(t *testing.T) {
 | 
			
		||||
		const (
 | 
			
		||||
			Dim                  = 8
 | 
			
		||||
| 
						 | 
				
			
			@ -572,7 +668,7 @@ func TestTaskQuery_functions(t *testing.T) {
 | 
			
		|||
				r2.HasMoreResult = false
 | 
			
		||||
				result, err := reduceRetrieveResults(context.Background(),
 | 
			
		||||
					[]*internalpb.RetrieveResults{r1, r2},
 | 
			
		||||
					&queryParams{limit: 2, reduceStopForBest: true})
 | 
			
		||||
					&queryParams{limit: 2, reduceType: reduce.IReduceInOrderForBest})
 | 
			
		||||
				assert.NoError(t, err)
 | 
			
		||||
				assert.Equal(t, 2, len(result.GetFieldsData()))
 | 
			
		||||
				assert.Equal(t, []int64{11, 11, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
 | 
			
		||||
| 
						 | 
				
			
			@ -585,7 +681,7 @@ func TestTaskQuery_functions(t *testing.T) {
 | 
			
		|||
				r2.HasMoreResult = true
 | 
			
		||||
				result, err := reduceRetrieveResults(context.Background(),
 | 
			
		||||
					[]*internalpb.RetrieveResults{r1, r2},
 | 
			
		||||
					&queryParams{limit: 1, offset: 1, reduceStopForBest: true})
 | 
			
		||||
					&queryParams{limit: 1, offset: 1, reduceType: reduce.IReduceInOrderForBest})
 | 
			
		||||
				assert.NoError(t, err)
 | 
			
		||||
				assert.Equal(t, 2, len(result.GetFieldsData()))
 | 
			
		||||
				assert.Equal(t, []int64{11, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
 | 
			
		||||
| 
						 | 
				
			
			@ -596,7 +692,7 @@ func TestTaskQuery_functions(t *testing.T) {
 | 
			
		|||
				r2.HasMoreResult = true
 | 
			
		||||
				result, err := reduceRetrieveResults(context.Background(),
 | 
			
		||||
					[]*internalpb.RetrieveResults{r1, r2},
 | 
			
		||||
					&queryParams{limit: 2, offset: 1, reduceStopForBest: true})
 | 
			
		||||
					&queryParams{limit: 2, offset: 1, reduceType: reduce.IReduceInOrderForBest})
 | 
			
		||||
				assert.NoError(t, err)
 | 
			
		||||
				assert.Equal(t, 2, len(result.GetFieldsData()))
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -609,7 +705,7 @@ func TestTaskQuery_functions(t *testing.T) {
 | 
			
		|||
				r2.HasMoreResult = false
 | 
			
		||||
				result, err := reduceRetrieveResults(context.Background(),
 | 
			
		||||
					[]*internalpb.RetrieveResults{r1, r2},
 | 
			
		||||
					&queryParams{limit: typeutil.Unlimited, reduceStopForBest: true})
 | 
			
		||||
					&queryParams{limit: typeutil.Unlimited, reduceType: reduce.IReduceInOrderForBest})
 | 
			
		||||
				assert.NoError(t, err)
 | 
			
		||||
				assert.Equal(t, 2, len(result.GetFieldsData()))
 | 
			
		||||
				assert.Equal(t, []int64{11, 11, 22, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
 | 
			
		||||
| 
						 | 
				
			
			@ -620,11 +716,21 @@ func TestTaskQuery_functions(t *testing.T) {
 | 
			
		|||
			t.Run("test stop reduce for best for unlimited set amd offset", func(t *testing.T) {
 | 
			
		||||
				result, err := reduceRetrieveResults(context.Background(),
 | 
			
		||||
					[]*internalpb.RetrieveResults{r1, r2},
 | 
			
		||||
					&queryParams{limit: typeutil.Unlimited, offset: 3, reduceStopForBest: true})
 | 
			
		||||
					&queryParams{limit: typeutil.Unlimited, offset: 3, reduceType: reduce.IReduceInOrderForBest})
 | 
			
		||||
				assert.NoError(t, err)
 | 
			
		||||
				assert.Equal(t, 2, len(result.GetFieldsData()))
 | 
			
		||||
				assert.Equal(t, []int64{22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
 | 
			
		||||
			})
 | 
			
		||||
			t.Run("test iterator without setting reduce stop for best", func(t *testing.T) {
 | 
			
		||||
				r1.HasMoreResult = true
 | 
			
		||||
				r2.HasMoreResult = true
 | 
			
		||||
				result, err := reduceRetrieveResults(context.Background(),
 | 
			
		||||
					[]*internalpb.RetrieveResults{r1, r2},
 | 
			
		||||
					&queryParams{limit: 1, reduceType: reduce.IReduceInOrder})
 | 
			
		||||
				assert.NoError(t, err)
 | 
			
		||||
				assert.Equal(t, 2, len(result.GetFieldsData()))
 | 
			
		||||
				assert.Equal(t, []int64{11}, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
 | 
			
		||||
			})
 | 
			
		||||
		})
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -7,6 +7,7 @@ import (
 | 
			
		|||
	"github.com/milvus-io/milvus/internal/proto/internalpb"
 | 
			
		||||
	"github.com/milvus-io/milvus/internal/proto/querypb"
 | 
			
		||||
	"github.com/milvus-io/milvus/internal/proto/segcorepb"
 | 
			
		||||
	"github.com/milvus-io/milvus/internal/util/reduce"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type defaultLimitReducer struct {
 | 
			
		||||
| 
						 | 
				
			
			@ -15,24 +16,24 @@ type defaultLimitReducer struct {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
type mergeParam struct {
 | 
			
		||||
	limit            int64
 | 
			
		||||
	outputFieldsId   []int64
 | 
			
		||||
	schema           *schemapb.CollectionSchema
 | 
			
		||||
	mergeStopForBest bool
 | 
			
		||||
	limit          int64
 | 
			
		||||
	outputFieldsId []int64
 | 
			
		||||
	schema         *schemapb.CollectionSchema
 | 
			
		||||
	reduceType     reduce.IReduceType
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewMergeParam(limit int64, outputFieldsId []int64, schema *schemapb.CollectionSchema, reduceStopForBest bool) *mergeParam {
 | 
			
		||||
func NewMergeParam(limit int64, outputFieldsId []int64, schema *schemapb.CollectionSchema, reduceType reduce.IReduceType) *mergeParam {
 | 
			
		||||
	return &mergeParam{
 | 
			
		||||
		limit:            limit,
 | 
			
		||||
		outputFieldsId:   outputFieldsId,
 | 
			
		||||
		schema:           schema,
 | 
			
		||||
		mergeStopForBest: reduceStopForBest,
 | 
			
		||||
		limit:          limit,
 | 
			
		||||
		outputFieldsId: outputFieldsId,
 | 
			
		||||
		schema:         schema,
 | 
			
		||||
		reduceType:     reduceType,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *defaultLimitReducer) Reduce(ctx context.Context, results []*internalpb.RetrieveResults) (*internalpb.RetrieveResults, error) {
 | 
			
		||||
	reduceParam := NewMergeParam(r.req.GetReq().GetLimit(), r.req.GetReq().GetOutputFieldsId(),
 | 
			
		||||
		r.schema, r.req.GetReq().GetReduceStopForBest())
 | 
			
		||||
		r.schema, reduce.ToReduceType(r.req.GetReq().GetReduceType()))
 | 
			
		||||
	return mergeInternalRetrieveResultsAndFillIfEmpty(ctx, results, reduceParam)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -50,7 +51,7 @@ type defaultLimitReducerSegcore struct {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func (r *defaultLimitReducerSegcore) Reduce(ctx context.Context, results []*segcorepb.RetrieveResults, segments []Segment, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) {
 | 
			
		||||
	mergeParam := NewMergeParam(r.req.GetReq().GetLimit(), r.req.GetReq().GetOutputFieldsId(), r.schema, r.req.GetReq().GetReduceStopForBest())
 | 
			
		||||
	mergeParam := NewMergeParam(r.req.GetReq().GetLimit(), r.req.GetReq().GetOutputFieldsId(), r.schema, reduce.ToReduceType(r.req.GetReq().GetReduceType()))
 | 
			
		||||
	return mergeSegcoreRetrieveResultsAndFillIfEmpty(ctx, results, mergeParam, segments, plan, r.manager)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -293,7 +293,7 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna
 | 
			
		|||
		return ret, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if param.limit != typeutil.Unlimited && !param.mergeStopForBest {
 | 
			
		||||
	if param.limit != typeutil.Unlimited && reduce.ShouldUseInputLimit(param.reduceType) {
 | 
			
		||||
		loopEnd = int(param.limit)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -305,7 +305,7 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna
 | 
			
		|||
	maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
 | 
			
		||||
	for j := 0; j < loopEnd; {
 | 
			
		||||
		sel, drainOneResult := typeutil.SelectMinPKWithTimestamp(validRetrieveResults, cursors)
 | 
			
		||||
		if sel == -1 || (param.mergeStopForBest && drainOneResult) {
 | 
			
		||||
		if sel == -1 || (reduce.ShouldStopWhenDrained(param.reduceType) && drainOneResult) {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -416,7 +416,7 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
 | 
			
		|||
	}
 | 
			
		||||
 | 
			
		||||
	var limit int = -1
 | 
			
		||||
	if param.limit != typeutil.Unlimited && !param.mergeStopForBest {
 | 
			
		||||
	if param.limit != typeutil.Unlimited && reduce.ShouldUseInputLimit(param.reduceType) {
 | 
			
		||||
		limit = int(param.limit)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -438,7 +438,7 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
 | 
			
		|||
 | 
			
		||||
	for j := 0; j < loopEnd && (limit == -1 || availableCount < limit); j++ {
 | 
			
		||||
		sel, drainOneResult := typeutil.SelectMinPKWithTimestamp(validRetrieveResults, cursors)
 | 
			
		||||
		if sel == -1 || (param.mergeStopForBest && drainOneResult) {
 | 
			
		||||
		if sel == -1 || (reduce.ShouldStopWhenDrained(param.reduceType) && drainOneResult) {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -100,7 +100,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() {
 | 
			
		|||
		}
 | 
			
		||||
 | 
			
		||||
		result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2},
 | 
			
		||||
			NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
 | 
			
		||||
			NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder))
 | 
			
		||||
		suite.NoError(err)
 | 
			
		||||
		suite.Equal(3, len(result.GetFieldsData()))
 | 
			
		||||
		suite.Equal([]int64{0, 1}, result.GetIds().GetIntId().GetData())
 | 
			
		||||
| 
						 | 
				
			
			@ -114,7 +114,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() {
 | 
			
		|||
 | 
			
		||||
	suite.Run("test nil results", func() {
 | 
			
		||||
		ret, err := MergeSegcoreRetrieveResultsV1(context.Background(), nil,
 | 
			
		||||
			NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
 | 
			
		||||
			NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder))
 | 
			
		||||
		suite.NoError(err)
 | 
			
		||||
		suite.Empty(ret.GetIds())
 | 
			
		||||
		suite.Empty(ret.GetFieldsData())
 | 
			
		||||
| 
						 | 
				
			
			@ -133,7 +133,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() {
 | 
			
		|||
		}
 | 
			
		||||
 | 
			
		||||
		ret, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{r},
 | 
			
		||||
			NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
 | 
			
		||||
			NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder))
 | 
			
		||||
		suite.NoError(err)
 | 
			
		||||
		suite.Empty(ret.GetIds())
 | 
			
		||||
		suite.Empty(ret.GetFieldsData())
 | 
			
		||||
| 
						 | 
				
			
			@ -185,7 +185,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() {
 | 
			
		|||
			for _, test := range tests {
 | 
			
		||||
				suite.Run(test.description, func() {
 | 
			
		||||
					result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{r1, r2},
 | 
			
		||||
						NewMergeParam(test.limit, make([]int64, 0), nil, false))
 | 
			
		||||
						NewMergeParam(test.limit, make([]int64, 0), nil, reduce.IReduceNoOrder))
 | 
			
		||||
					suite.Equal(3, len(result.GetFieldsData()))
 | 
			
		||||
					suite.Equal(int(test.limit), len(result.GetIds().GetIntId().GetData()))
 | 
			
		||||
					suite.Equal(resultIDs[0:test.limit], result.GetIds().GetIntId().GetData())
 | 
			
		||||
| 
						 | 
				
			
			@ -225,14 +225,14 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() {
 | 
			
		|||
			}
 | 
			
		||||
 | 
			
		||||
			_, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result},
 | 
			
		||||
				NewMergeParam(reqLimit, make([]int64, 0), nil, false))
 | 
			
		||||
				NewMergeParam(reqLimit, make([]int64, 0), nil, reduce.IReduceNoOrder))
 | 
			
		||||
			suite.Error(err)
 | 
			
		||||
			paramtable.Get().Save(paramtable.Get().QuotaConfig.MaxOutputSize.Key, "1104857600")
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		suite.Run("test int ID", func() {
 | 
			
		||||
			result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{r1, r2},
 | 
			
		||||
				NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
 | 
			
		||||
				NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder))
 | 
			
		||||
			suite.Equal(3, len(result.GetFieldsData()))
 | 
			
		||||
			suite.Equal([]int64{1, 2, 3, 4}, result.GetIds().GetIntId().GetData())
 | 
			
		||||
			intFieldData, has := getFieldData(result, Int64FieldID)
 | 
			
		||||
| 
						 | 
				
			
			@ -262,7 +262,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() {
 | 
			
		|||
			}
 | 
			
		||||
 | 
			
		||||
			result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{r1, r2},
 | 
			
		||||
				NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
 | 
			
		||||
				NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder))
 | 
			
		||||
			suite.NoError(err)
 | 
			
		||||
			suite.Equal(3, len(result.GetFieldsData()))
 | 
			
		||||
			suite.Equal([]string{"a", "b", "c", "d"}, result.GetIds().GetStrId().GetData())
 | 
			
		||||
| 
						 | 
				
			
			@ -321,7 +321,7 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() {
 | 
			
		|||
		}
 | 
			
		||||
 | 
			
		||||
		result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2},
 | 
			
		||||
			NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
 | 
			
		||||
			NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder))
 | 
			
		||||
		suite.NoError(err)
 | 
			
		||||
		suite.Equal(3, len(result.GetFieldsData()))
 | 
			
		||||
		suite.Equal([]int64{0, 1}, result.GetIds().GetIntId().GetData())
 | 
			
		||||
| 
						 | 
				
			
			@ -335,7 +335,7 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() {
 | 
			
		|||
 | 
			
		||||
	suite.Run("test nil results", func() {
 | 
			
		||||
		ret, err := MergeInternalRetrieveResult(context.Background(), nil,
 | 
			
		||||
			NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
 | 
			
		||||
			NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder))
 | 
			
		||||
		suite.NoError(err)
 | 
			
		||||
		suite.Empty(ret.GetIds())
 | 
			
		||||
		suite.Empty(ret.GetFieldsData())
 | 
			
		||||
| 
						 | 
				
			
			@ -373,7 +373,7 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() {
 | 
			
		|||
			},
 | 
			
		||||
		}
 | 
			
		||||
		result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{ret1, ret2},
 | 
			
		||||
			NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
 | 
			
		||||
			NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder))
 | 
			
		||||
		suite.NoError(err)
 | 
			
		||||
		suite.Equal(2, len(result.GetFieldsData()))
 | 
			
		||||
		suite.Equal([]int64{0, 1}, result.GetIds().GetIntId().GetData())
 | 
			
		||||
| 
						 | 
				
			
			@ -424,7 +424,7 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() {
 | 
			
		|||
			for _, test := range tests {
 | 
			
		||||
				suite.Run(test.description, func() {
 | 
			
		||||
					result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{r1, r2},
 | 
			
		||||
						NewMergeParam(test.limit, make([]int64, 0), nil, false))
 | 
			
		||||
						NewMergeParam(test.limit, make([]int64, 0), nil, reduce.IReduceNoOrder))
 | 
			
		||||
					suite.Equal(3, len(result.GetFieldsData()))
 | 
			
		||||
					suite.Equal(int(test.limit), len(result.GetIds().GetIntId().GetData()))
 | 
			
		||||
					suite.Equal(resultIDs[0:test.limit], result.GetIds().GetIntId().GetData())
 | 
			
		||||
| 
						 | 
				
			
			@ -463,14 +463,14 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() {
 | 
			
		|||
			}
 | 
			
		||||
 | 
			
		||||
			_, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result, result},
 | 
			
		||||
				NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
 | 
			
		||||
				NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder))
 | 
			
		||||
			suite.Error(err)
 | 
			
		||||
			paramtable.Get().Save(paramtable.Get().QuotaConfig.MaxOutputSize.Key, "1104857600")
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		suite.Run("test int ID", func() {
 | 
			
		||||
			result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{r1, r2},
 | 
			
		||||
				NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
 | 
			
		||||
				NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder))
 | 
			
		||||
			suite.Equal(3, len(result.GetFieldsData()))
 | 
			
		||||
			suite.Equal([]int64{1, 2, 3, 4}, result.GetIds().GetIntId().GetData())
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -501,7 +501,7 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() {
 | 
			
		|||
			}
 | 
			
		||||
 | 
			
		||||
			result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{r1, r2},
 | 
			
		||||
				NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false))
 | 
			
		||||
				NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder))
 | 
			
		||||
			suite.NoError(err)
 | 
			
		||||
			suite.Equal(3, len(result.GetFieldsData()))
 | 
			
		||||
			suite.Equal([]string{"a", "b", "c", "d"}, result.GetIds().GetStrId().GetData())
 | 
			
		||||
| 
						 | 
				
			
			@ -568,7 +568,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() {
 | 
			
		|||
			result1.HasMoreResult = true
 | 
			
		||||
			result2.HasMoreResult = true
 | 
			
		||||
			result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2},
 | 
			
		||||
				NewMergeParam(3, make([]int64, 0), nil, true))
 | 
			
		||||
				NewMergeParam(3, make([]int64, 0), nil, reduce.IReduceInOrderForBest))
 | 
			
		||||
			suite.NoError(err)
 | 
			
		||||
			suite.Equal(3, len(result.GetFieldsData()))
 | 
			
		||||
			// has more result both, stop reduce when draining one result
 | 
			
		||||
| 
						 | 
				
			
			@ -586,7 +586,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() {
 | 
			
		|||
			result1.HasMoreResult = false
 | 
			
		||||
			result2.HasMoreResult = false
 | 
			
		||||
			result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2},
 | 
			
		||||
				NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, true))
 | 
			
		||||
				NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceInOrderForBest))
 | 
			
		||||
			suite.NoError(err)
 | 
			
		||||
			suite.Equal(3, len(result.GetFieldsData()))
 | 
			
		||||
			// as result1 and result2 don't have better results neither
 | 
			
		||||
| 
						 | 
				
			
			@ -604,7 +604,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() {
 | 
			
		|||
			result1.HasMoreResult = true
 | 
			
		||||
			result2.HasMoreResult = false
 | 
			
		||||
			result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2},
 | 
			
		||||
				NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, true))
 | 
			
		||||
				NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceInOrderForBest))
 | 
			
		||||
			suite.NoError(err)
 | 
			
		||||
			suite.Equal(3, len(result.GetFieldsData()))
 | 
			
		||||
			// as result1 may have better results, stop reducing when draining it
 | 
			
		||||
| 
						 | 
				
			
			@ -643,7 +643,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() {
 | 
			
		|||
		result1.HasMoreResult = true
 | 
			
		||||
		result2.HasMoreResult = false
 | 
			
		||||
		result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2},
 | 
			
		||||
			NewMergeParam(3, make([]int64, 0), nil, true))
 | 
			
		||||
			NewMergeParam(3, make([]int64, 0), nil, reduce.IReduceInOrderForBest))
 | 
			
		||||
		suite.NoError(err)
 | 
			
		||||
		suite.Equal(3, len(result.GetFieldsData()))
 | 
			
		||||
		suite.Equal([]int64{0, 2, 4, 6, 7}, result.GetIds().GetIntId().GetData())
 | 
			
		||||
| 
						 | 
				
			
			@ -687,7 +687,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() {
 | 
			
		|||
			result1.HasMoreResult = false
 | 
			
		||||
			result2.HasMoreResult = false
 | 
			
		||||
			result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2},
 | 
			
		||||
				NewMergeParam(3, make([]int64, 0), nil, true))
 | 
			
		||||
				NewMergeParam(3, make([]int64, 0), nil, reduce.IReduceInOrderForBest))
 | 
			
		||||
			suite.NoError(err)
 | 
			
		||||
			suite.Equal(3, len(result.GetFieldsData()))
 | 
			
		||||
			suite.Equal([]int64{0, 2, 4, 7}, result.GetIds().GetIntId().GetData())
 | 
			
		||||
| 
						 | 
				
			
			@ -696,11 +696,20 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() {
 | 
			
		|||
			result1.HasMoreResult = false
 | 
			
		||||
			result2.HasMoreResult = true
 | 
			
		||||
			result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2},
 | 
			
		||||
				NewMergeParam(3, make([]int64, 0), nil, true))
 | 
			
		||||
				NewMergeParam(3, make([]int64, 0), nil, reduce.IReduceInOrderForBest))
 | 
			
		||||
			suite.NoError(err)
 | 
			
		||||
			suite.Equal(3, len(result.GetFieldsData()))
 | 
			
		||||
			suite.Equal([]int64{0, 2}, result.GetIds().GetIntId().GetData())
 | 
			
		||||
		})
 | 
			
		||||
		suite.Run("test no stop reduce for best ", func() {
 | 
			
		||||
			result1.HasMoreResult = true
 | 
			
		||||
			result2.HasMoreResult = true
 | 
			
		||||
			result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2},
 | 
			
		||||
				NewMergeParam(1, make([]int64, 0), nil, reduce.IReduceInOrder))
 | 
			
		||||
			suite.NoError(err)
 | 
			
		||||
			suite.Equal(3, len(result.GetFieldsData()))
 | 
			
		||||
			suite.Equal([]int64{0}, result.GetIds().GetIntId().GetData())
 | 
			
		||||
		})
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -90,3 +90,30 @@ func (r *ResultInfo) GetIsAdvance() bool {
 | 
			
		|||
func (r *ResultInfo) SetMetricType(metricType string) {
 | 
			
		||||
	r.metricType = metricType
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type IReduceType int32
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	IReduceNoOrder IReduceType = iota
 | 
			
		||||
	IReduceInOrder
 | 
			
		||||
	IReduceInOrderForBest
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func ShouldStopWhenDrained(reduceType IReduceType) bool {
 | 
			
		||||
	return reduceType == IReduceInOrder || reduceType == IReduceInOrderForBest
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ToReduceType(val int32) IReduceType {
 | 
			
		||||
	switch val {
 | 
			
		||||
	case 1:
 | 
			
		||||
		return IReduceInOrder
 | 
			
		||||
	case 2:
 | 
			
		||||
		return IReduceInOrderForBest
 | 
			
		||||
	default:
 | 
			
		||||
		return IReduceNoOrder
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ShouldUseInputLimit(reduceType IReduceType) bool {
 | 
			
		||||
	return reduceType == IReduceNoOrder || reduceType == IReduceInOrder
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue