mirror of https://github.com/milvus-io/milvus.git
				
				
				
			related: #35096 Signed-off-by: MrPresent-Han <chun.han@gmail.com> Co-authored-by: MrPresent-Han <chun.han@gmail.com>pull/34976/head
							parent
							
								
									4b077e1bd2
								
							
						
					
					
						commit
						eb23e23cd2
					
				| 
						 | 
					@ -24,6 +24,9 @@ type rankParams struct {
 | 
				
			||||||
	limit           int64
 | 
						limit           int64
 | 
				
			||||||
	offset          int64
 | 
						offset          int64
 | 
				
			||||||
	roundDecimal    int64
 | 
						roundDecimal    int64
 | 
				
			||||||
 | 
						groupByFieldId  int64
 | 
				
			||||||
 | 
						groupSize       int64
 | 
				
			||||||
 | 
						groupStrictSize bool
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (r *rankParams) GetLimit() int64 {
 | 
					func (r *rankParams) GetLimit() int64 {
 | 
				
			||||||
| 
						 | 
					@ -47,6 +50,27 @@ func (r *rankParams) GetRoundDecimal() int64 {
 | 
				
			||||||
	return 0
 | 
						return 0
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *rankParams) GetGroupByFieldId() int64 {
 | 
				
			||||||
 | 
						if r != nil {
 | 
				
			||||||
 | 
							return r.groupByFieldId
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return -1
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *rankParams) GetGroupSize() int64 {
 | 
				
			||||||
 | 
						if r != nil {
 | 
				
			||||||
 | 
							return r.groupSize
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return 1
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *rankParams) GetGroupStrictSize() bool {
 | 
				
			||||||
 | 
						if r != nil {
 | 
				
			||||||
 | 
							return r.groupStrictSize
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (r *rankParams) String() string {
 | 
					func (r *rankParams) String() string {
 | 
				
			||||||
	return fmt.Sprintf("limit: %d, offset: %d, roundDecimal: %d", r.GetLimit(), r.GetOffset(), r.GetRoundDecimal())
 | 
						return fmt.Sprintf("limit: %d, offset: %d, roundDecimal: %d", r.GetLimit(), r.GetOffset(), r.GetRoundDecimal())
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -137,51 +161,16 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 5. parse group by field and group by size
 | 
						// 5. parse group by field and group by size
 | 
				
			||||||
	groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
 | 
						var groupByFieldId, groupSize int64
 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		groupByFieldName = ""
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	var groupByFieldId int64 = -1
 | 
					 | 
				
			||||||
	if groupByFieldName != "" {
 | 
					 | 
				
			||||||
		fields := schema.GetFields()
 | 
					 | 
				
			||||||
		for _, field := range fields {
 | 
					 | 
				
			||||||
			if field.GetNullable() {
 | 
					 | 
				
			||||||
				return nil, 0, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("groupBy field(%s) not support nullable == true", groupByFieldName))
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			if field.Name == groupByFieldName {
 | 
					 | 
				
			||||||
				groupByFieldId = field.FieldID
 | 
					 | 
				
			||||||
				break
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		if groupByFieldId == -1 {
 | 
					 | 
				
			||||||
			return nil, 0, merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var groupSize int64
 | 
					 | 
				
			||||||
	groupSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupSizeKey, searchParamsPair)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		groupSize = 1
 | 
					 | 
				
			||||||
	} else {
 | 
					 | 
				
			||||||
		groupSize, err = strconv.ParseInt(groupSizeStr, 0, 64)
 | 
					 | 
				
			||||||
		if err != nil || groupSize <= 0 {
 | 
					 | 
				
			||||||
			groupSize = 1
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if groupSize > Params.QuotaConfig.MaxGroupSize.GetAsInt64() {
 | 
					 | 
				
			||||||
		return nil, 0, merr.WrapErrParameterInvalidMsg(
 | 
					 | 
				
			||||||
			fmt.Sprintf("input group size:%d exceeds configured max group size:%d", groupSize, Params.QuotaConfig.MaxGroupSize.GetAsInt64()))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var groupStrictSize bool
 | 
						var groupStrictSize bool
 | 
				
			||||||
	groupStrictSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupStrictSize, searchParamsPair)
 | 
						if isAdvanced {
 | 
				
			||||||
	if err != nil {
 | 
							groupByFieldId, groupSize, groupStrictSize = rankParams.GetGroupByFieldId(), rankParams.GetGroupSize(), rankParams.GetGroupStrictSize()
 | 
				
			||||||
		groupStrictSize = false
 | 
					 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		groupStrictSize, err = strconv.ParseBool(groupStrictSizeStr)
 | 
							groupByInfo := parseGroupByInfo(searchParamsPair, schema)
 | 
				
			||||||
		if err != nil {
 | 
							if groupByInfo.err != nil {
 | 
				
			||||||
			groupStrictSize = false
 | 
								return nil, 0, groupByInfo.err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
							groupByFieldId, groupSize, groupStrictSize = groupByInfo.GetGroupByFieldId(), groupByInfo.GetGroupSize(), groupByInfo.GetGroupStrictSize()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
 | 
						// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
 | 
				
			||||||
| 
						 | 
					@ -298,8 +287,104 @@ func getPartitionIDs(ctx context.Context, dbName string, collectionName string,
 | 
				
			||||||
	return partitionsSet.Collect(), nil
 | 
						return partitionsSet.Collect(), nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type groupByInfo struct {
 | 
				
			||||||
 | 
						groupByFieldId  int64
 | 
				
			||||||
 | 
						groupSize       int64
 | 
				
			||||||
 | 
						groupStrictSize bool
 | 
				
			||||||
 | 
						err             error
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (g *groupByInfo) GetGroupByFieldId() int64 {
 | 
				
			||||||
 | 
						if g != nil {
 | 
				
			||||||
 | 
							return g.groupByFieldId
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return 0
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (g *groupByInfo) GetGroupSize() int64 {
 | 
				
			||||||
 | 
						if g != nil {
 | 
				
			||||||
 | 
							return g.groupSize
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return 0
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (g *groupByInfo) GetGroupStrictSize() bool {
 | 
				
			||||||
 | 
						if g != nil {
 | 
				
			||||||
 | 
							return g.groupStrictSize
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (g *groupByInfo) GetError() error {
 | 
				
			||||||
 | 
						if g != nil {
 | 
				
			||||||
 | 
							return g.err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func parseGroupByInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema) *groupByInfo {
 | 
				
			||||||
 | 
						ret := &groupByInfo{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 1. parse group_by_field
 | 
				
			||||||
 | 
						groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							groupByFieldName = ""
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						var groupByFieldId int64 = -1
 | 
				
			||||||
 | 
						if groupByFieldName != "" {
 | 
				
			||||||
 | 
							fields := schema.GetFields()
 | 
				
			||||||
 | 
							for _, field := range fields {
 | 
				
			||||||
 | 
								if field.GetNullable() {
 | 
				
			||||||
 | 
									ret.err = merr.WrapErrParameterInvalidMsg(fmt.Sprintf("groupBy field(%s) not support nullable == true", groupByFieldName))
 | 
				
			||||||
 | 
									return ret
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								if field.Name == groupByFieldName {
 | 
				
			||||||
 | 
									groupByFieldId = field.FieldID
 | 
				
			||||||
 | 
									break
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if groupByFieldId == -1 {
 | 
				
			||||||
 | 
								ret.err = merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema")
 | 
				
			||||||
 | 
								return ret
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						ret.groupByFieldId = groupByFieldId
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 2. parse group size
 | 
				
			||||||
 | 
						var groupSize int64
 | 
				
			||||||
 | 
						groupSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupSizeKey, searchParamsPair)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							groupSize = 1
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							groupSize, err = strconv.ParseInt(groupSizeStr, 0, 64)
 | 
				
			||||||
 | 
							if err != nil || groupSize <= 0 {
 | 
				
			||||||
 | 
								groupSize = 1
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if groupSize > Params.QuotaConfig.MaxGroupSize.GetAsInt64() {
 | 
				
			||||||
 | 
							ret.err = merr.WrapErrParameterInvalidMsg(
 | 
				
			||||||
 | 
								fmt.Sprintf("input group size:%d exceeds configured max group size:%d", groupSize, Params.QuotaConfig.MaxGroupSize.GetAsInt64()))
 | 
				
			||||||
 | 
							return ret
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						ret.groupSize = groupSize
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 3.  parse group strict size
 | 
				
			||||||
 | 
						var groupStrictSize bool
 | 
				
			||||||
 | 
						groupStrictSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupStrictSize, searchParamsPair)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							groupStrictSize = false
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							groupStrictSize, err = strconv.ParseBool(groupStrictSizeStr)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								groupStrictSize = false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						ret.groupStrictSize = groupStrictSize
 | 
				
			||||||
 | 
						return ret
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// parseRankParams get limit and offset from rankParams, both are optional.
 | 
					// parseRankParams get limit and offset from rankParams, both are optional.
 | 
				
			||||||
func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, error) {
 | 
					func parseRankParams(rankParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema) (*rankParams, error) {
 | 
				
			||||||
	var (
 | 
						var (
 | 
				
			||||||
		limit        int64
 | 
							limit        int64
 | 
				
			||||||
		offset       int64
 | 
							offset       int64
 | 
				
			||||||
| 
						 | 
					@ -343,23 +428,30 @@ func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, erro
 | 
				
			||||||
		return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
 | 
							return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// parse group_by parameters from main request body for hybrid search
 | 
				
			||||||
 | 
						groupByInfo := parseGroupByInfo(rankParamsPair, schema)
 | 
				
			||||||
 | 
						if groupByInfo.err != nil {
 | 
				
			||||||
 | 
							return nil, groupByInfo.err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &rankParams{
 | 
						return &rankParams{
 | 
				
			||||||
		limit:           limit,
 | 
							limit:           limit,
 | 
				
			||||||
		offset:          offset,
 | 
							offset:          offset,
 | 
				
			||||||
		roundDecimal:    roundDecimal,
 | 
							roundDecimal:    roundDecimal,
 | 
				
			||||||
 | 
							groupByFieldId:  groupByInfo.GetGroupByFieldId(),
 | 
				
			||||||
 | 
							groupSize:       groupByInfo.GetGroupSize(),
 | 
				
			||||||
 | 
							groupStrictSize: groupByInfo.GetGroupStrictSize(),
 | 
				
			||||||
	}, nil
 | 
						}, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.SearchRequest {
 | 
					func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.SearchRequest {
 | 
				
			||||||
	searchParams := make([]*commonpb.KeyValuePair, len(req.GetRankParams()))
 | 
					 | 
				
			||||||
	copy(searchParams, req.GetRankParams())
 | 
					 | 
				
			||||||
	ret := &milvuspb.SearchRequest{
 | 
						ret := &milvuspb.SearchRequest{
 | 
				
			||||||
		Base:                  req.GetBase(),
 | 
							Base:                  req.GetBase(),
 | 
				
			||||||
		DbName:                req.GetDbName(),
 | 
							DbName:                req.GetDbName(),
 | 
				
			||||||
		CollectionName:        req.GetCollectionName(),
 | 
							CollectionName:        req.GetCollectionName(),
 | 
				
			||||||
		PartitionNames:        req.GetPartitionNames(),
 | 
							PartitionNames:        req.GetPartitionNames(),
 | 
				
			||||||
		OutputFields:          req.GetOutputFields(),
 | 
							OutputFields:          req.GetOutputFields(),
 | 
				
			||||||
		SearchParams:          searchParams,
 | 
							SearchParams:          req.GetRankParams(),
 | 
				
			||||||
		TravelTimestamp:       req.GetTravelTimestamp(),
 | 
							TravelTimestamp:       req.GetTravelTimestamp(),
 | 
				
			||||||
		GuaranteeTimestamp:    req.GetGuaranteeTimestamp(),
 | 
							GuaranteeTimestamp:    req.GetGuaranteeTimestamp(),
 | 
				
			||||||
		Nq:                    0,
 | 
							Nq:                    0,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -170,7 +170,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if t.SearchRequest.GetIsAdvanced() {
 | 
						if t.SearchRequest.GetIsAdvanced() {
 | 
				
			||||||
		t.rankParams, err = parseRankParams(t.request.GetSearchParams())
 | 
							t.rankParams, err = parseRankParams(t.request.GetSearchParams(), t.schema.CollectionSchema)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			log.Info("parseRankParams failed", zap.Error(err))
 | 
								log.Info("parseRankParams failed", zap.Error(err))
 | 
				
			||||||
			return err
 | 
								return err
 | 
				
			||||||
| 
						 | 
					@ -366,8 +366,8 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
 | 
				
			||||||
			Topk:               queryInfo.GetTopk(),
 | 
								Topk:               queryInfo.GetTopk(),
 | 
				
			||||||
			Offset:             offset,
 | 
								Offset:             offset,
 | 
				
			||||||
			MetricType:         queryInfo.GetMetricType(),
 | 
								MetricType:         queryInfo.GetMetricType(),
 | 
				
			||||||
			GroupByFieldId:     queryInfo.GetGroupByFieldId(),
 | 
								GroupByFieldId:     t.rankParams.GetGroupByFieldId(),
 | 
				
			||||||
			GroupSize:          queryInfo.GetGroupSize(),
 | 
								GroupSize:          t.rankParams.GetGroupSize(),
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// set PartitionIDs for sub search
 | 
							// set PartitionIDs for sub search
 | 
				
			||||||
| 
						 | 
					@ -407,10 +407,9 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
 | 
				
			||||||
			zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
 | 
								zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
 | 
				
			||||||
			zap.Stringer("plan", plan)) // may be very large if large term passed.
 | 
								zap.Stringer("plan", plan)) // may be very large if large term passed.
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if len(t.queryInfos) > 0 {
 | 
					
 | 
				
			||||||
		t.SearchRequest.GroupByFieldId = t.queryInfos[0].GetGroupByFieldId()
 | 
						t.SearchRequest.GroupByFieldId = t.rankParams.GetGroupByFieldId()
 | 
				
			||||||
		t.SearchRequest.GroupSize = t.queryInfos[0].GetGroupSize()
 | 
						t.SearchRequest.GroupSize = t.rankParams.GetGroupSize()
 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// used for requery
 | 
						// used for requery
 | 
				
			||||||
	if t.partitionKeyMode {
 | 
						if t.partitionKeyMode {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2263,6 +2263,63 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
 | 
				
			||||||
		assert.Equal(t, int64(0), offset)
 | 
							assert.Equal(t, int64(0), offset)
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t.Run("parseSearchInfo groupBy info for hybrid search", func(t *testing.T) {
 | 
				
			||||||
 | 
							schema := &schemapb.CollectionSchema{
 | 
				
			||||||
 | 
								Fields: []*schemapb.FieldSchema{
 | 
				
			||||||
 | 
									{FieldID: 101, Name: "c1"},
 | 
				
			||||||
 | 
									{FieldID: 102, Name: "c2"},
 | 
				
			||||||
 | 
									{FieldID: 103, Name: "c3"},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							// 1. first parse rank params
 | 
				
			||||||
 | 
							// outer params require to group by field 101 and groupSize=3 and groupStrictSize=false
 | 
				
			||||||
 | 
							testRankParamsPairs := getValidSearchParams()
 | 
				
			||||||
 | 
							testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
 | 
				
			||||||
 | 
								Key:   GroupByFieldKey,
 | 
				
			||||||
 | 
								Value: "c1",
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
							testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
 | 
				
			||||||
 | 
								Key:   GroupSizeKey,
 | 
				
			||||||
 | 
								Value: strconv.FormatInt(3, 10),
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
							testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
 | 
				
			||||||
 | 
								Key:   GroupStrictSize,
 | 
				
			||||||
 | 
								Value: "false",
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
							testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
 | 
				
			||||||
 | 
								Key:   LimitKey,
 | 
				
			||||||
 | 
								Value: "100",
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
							testRankParams, err := parseRankParams(testRankParamsPairs, schema)
 | 
				
			||||||
 | 
							assert.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// 2. parse search params for sub request in hybridsearch
 | 
				
			||||||
 | 
							params := getValidSearchParams()
 | 
				
			||||||
 | 
							// inner params require to group by field 103 and groupSize=10 and groupStrictSize=true
 | 
				
			||||||
 | 
							params = append(params, &commonpb.KeyValuePair{
 | 
				
			||||||
 | 
								Key:   GroupByFieldKey,
 | 
				
			||||||
 | 
								Value: "c3",
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
							params = append(params, &commonpb.KeyValuePair{
 | 
				
			||||||
 | 
								Key:   GroupSizeKey,
 | 
				
			||||||
 | 
								Value: strconv.FormatInt(10, 10),
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
							params = append(params, &commonpb.KeyValuePair{
 | 
				
			||||||
 | 
								Key:   GroupStrictSize,
 | 
				
			||||||
 | 
								Value: "true",
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							info, _, err := parseSearchInfo(params, schema, testRankParams)
 | 
				
			||||||
 | 
							assert.NoError(t, err)
 | 
				
			||||||
 | 
							assert.NotNil(t, info)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// all group_by related parameters should be aligned to parameters
 | 
				
			||||||
 | 
							// set by main request rather than inner sub request
 | 
				
			||||||
 | 
							assert.Equal(t, int64(101), info.GetGroupByFieldId())
 | 
				
			||||||
 | 
							assert.Equal(t, int64(3), info.GetGroupSize())
 | 
				
			||||||
 | 
							assert.False(t, info.GetGroupStrictSize())
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	t.Run("parseSearchInfo error", func(t *testing.T) {
 | 
						t.Run("parseSearchInfo error", func(t *testing.T) {
 | 
				
			||||||
		spNoTopk := []*commonpb.KeyValuePair{{
 | 
							spNoTopk := []*commonpb.KeyValuePair{{
 | 
				
			||||||
			Key:   AnnsFieldKey,
 | 
								Key:   AnnsFieldKey,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue