enhance: refine parameter relationship for hybridsearch_group_by(#35096) (#36289)

related: #35096

Signed-off-by: MrPresent-Han <chun.han@gmail.com>
Co-authored-by: MrPresent-Han <chun.han@gmail.com>
pull/34976/head
Chun Han 2024-09-20 14:55:11 +08:00 committed by GitHub
parent 4b077e1bd2
commit eb23e23cd2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 207 additions and 59 deletions

View File

@ -21,9 +21,12 @@ import (
) )
type rankParams struct { type rankParams struct {
limit int64 limit int64
offset int64 offset int64
roundDecimal int64 roundDecimal int64
groupByFieldId int64
groupSize int64
groupStrictSize bool
} }
func (r *rankParams) GetLimit() int64 { func (r *rankParams) GetLimit() int64 {
@ -47,6 +50,27 @@ func (r *rankParams) GetRoundDecimal() int64 {
return 0 return 0
} }
func (r *rankParams) GetGroupByFieldId() int64 {
if r != nil {
return r.groupByFieldId
}
return -1
}
func (r *rankParams) GetGroupSize() int64 {
if r != nil {
return r.groupSize
}
return 1
}
func (r *rankParams) GetGroupStrictSize() bool {
if r != nil {
return r.groupStrictSize
}
return false
}
func (r *rankParams) String() string { func (r *rankParams) String() string {
return fmt.Sprintf("limit: %d, offset: %d, roundDecimal: %d", r.GetLimit(), r.GetOffset(), r.GetRoundDecimal()) return fmt.Sprintf("limit: %d, offset: %d, roundDecimal: %d", r.GetLimit(), r.GetOffset(), r.GetRoundDecimal())
} }
@ -137,51 +161,16 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
} }
// 5. parse group by field and group by size // 5. parse group by field and group by size
groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair) var groupByFieldId, groupSize int64
if err != nil {
groupByFieldName = ""
}
var groupByFieldId int64 = -1
if groupByFieldName != "" {
fields := schema.GetFields()
for _, field := range fields {
if field.GetNullable() {
return nil, 0, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("groupBy field(%s) not support nullable == true", groupByFieldName))
}
if field.Name == groupByFieldName {
groupByFieldId = field.FieldID
break
}
}
if groupByFieldId == -1 {
return nil, 0, merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema")
}
}
var groupSize int64
groupSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupSizeKey, searchParamsPair)
if err != nil {
groupSize = 1
} else {
groupSize, err = strconv.ParseInt(groupSizeStr, 0, 64)
if err != nil || groupSize <= 0 {
groupSize = 1
}
}
if groupSize > Params.QuotaConfig.MaxGroupSize.GetAsInt64() {
return nil, 0, merr.WrapErrParameterInvalidMsg(
fmt.Sprintf("input group size:%d exceeds configured max group size:%d", groupSize, Params.QuotaConfig.MaxGroupSize.GetAsInt64()))
}
var groupStrictSize bool var groupStrictSize bool
groupStrictSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupStrictSize, searchParamsPair) if isAdvanced {
if err != nil { groupByFieldId, groupSize, groupStrictSize = rankParams.GetGroupByFieldId(), rankParams.GetGroupSize(), rankParams.GetGroupStrictSize()
groupStrictSize = false
} else { } else {
groupStrictSize, err = strconv.ParseBool(groupStrictSizeStr) groupByInfo := parseGroupByInfo(searchParamsPair, schema)
if err != nil { if groupByInfo.err != nil {
groupStrictSize = false return nil, 0, groupByInfo.err
} }
groupByFieldId, groupSize, groupStrictSize = groupByInfo.GetGroupByFieldId(), groupByInfo.GetGroupSize(), groupByInfo.GetGroupStrictSize()
} }
// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search // 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
@ -298,8 +287,104 @@ func getPartitionIDs(ctx context.Context, dbName string, collectionName string,
return partitionsSet.Collect(), nil return partitionsSet.Collect(), nil
} }
type groupByInfo struct {
groupByFieldId int64
groupSize int64
groupStrictSize bool
err error
}
func (g *groupByInfo) GetGroupByFieldId() int64 {
if g != nil {
return g.groupByFieldId
}
return 0
}
func (g *groupByInfo) GetGroupSize() int64 {
if g != nil {
return g.groupSize
}
return 0
}
func (g *groupByInfo) GetGroupStrictSize() bool {
if g != nil {
return g.groupStrictSize
}
return false
}
func (g *groupByInfo) GetError() error {
if g != nil {
return g.err
}
return nil
}
func parseGroupByInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema) *groupByInfo {
ret := &groupByInfo{}
// 1. parse group_by_field
groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
if err != nil {
groupByFieldName = ""
}
var groupByFieldId int64 = -1
if groupByFieldName != "" {
fields := schema.GetFields()
for _, field := range fields {
if field.GetNullable() {
ret.err = merr.WrapErrParameterInvalidMsg(fmt.Sprintf("groupBy field(%s) not support nullable == true", groupByFieldName))
return ret
}
if field.Name == groupByFieldName {
groupByFieldId = field.FieldID
break
}
}
if groupByFieldId == -1 {
ret.err = merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema")
return ret
}
}
ret.groupByFieldId = groupByFieldId
// 2. parse group size
var groupSize int64
groupSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupSizeKey, searchParamsPair)
if err != nil {
groupSize = 1
} else {
groupSize, err = strconv.ParseInt(groupSizeStr, 0, 64)
if err != nil || groupSize <= 0 {
groupSize = 1
}
}
if groupSize > Params.QuotaConfig.MaxGroupSize.GetAsInt64() {
ret.err = merr.WrapErrParameterInvalidMsg(
fmt.Sprintf("input group size:%d exceeds configured max group size:%d", groupSize, Params.QuotaConfig.MaxGroupSize.GetAsInt64()))
return ret
}
ret.groupSize = groupSize
// 3. parse group strict size
var groupStrictSize bool
groupStrictSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupStrictSize, searchParamsPair)
if err != nil {
groupStrictSize = false
} else {
groupStrictSize, err = strconv.ParseBool(groupStrictSizeStr)
if err != nil {
groupStrictSize = false
}
}
ret.groupStrictSize = groupStrictSize
return ret
}
// parseRankParams get limit and offset from rankParams, both are optional. // parseRankParams get limit and offset from rankParams, both are optional.
func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, error) { func parseRankParams(rankParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema) (*rankParams, error) {
var ( var (
limit int64 limit int64
offset int64 offset int64
@ -343,23 +428,30 @@ func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, erro
return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr) return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
} }
// parse group_by parameters from main request body for hybrid search
groupByInfo := parseGroupByInfo(rankParamsPair, schema)
if groupByInfo.err != nil {
return nil, groupByInfo.err
}
return &rankParams{ return &rankParams{
limit: limit, limit: limit,
offset: offset, offset: offset,
roundDecimal: roundDecimal, roundDecimal: roundDecimal,
groupByFieldId: groupByInfo.GetGroupByFieldId(),
groupSize: groupByInfo.GetGroupSize(),
groupStrictSize: groupByInfo.GetGroupStrictSize(),
}, nil }, nil
} }
func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.SearchRequest { func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.SearchRequest {
searchParams := make([]*commonpb.KeyValuePair, len(req.GetRankParams()))
copy(searchParams, req.GetRankParams())
ret := &milvuspb.SearchRequest{ ret := &milvuspb.SearchRequest{
Base: req.GetBase(), Base: req.GetBase(),
DbName: req.GetDbName(), DbName: req.GetDbName(),
CollectionName: req.GetCollectionName(), CollectionName: req.GetCollectionName(),
PartitionNames: req.GetPartitionNames(), PartitionNames: req.GetPartitionNames(),
OutputFields: req.GetOutputFields(), OutputFields: req.GetOutputFields(),
SearchParams: searchParams, SearchParams: req.GetRankParams(),
TravelTimestamp: req.GetTravelTimestamp(), TravelTimestamp: req.GetTravelTimestamp(),
GuaranteeTimestamp: req.GetGuaranteeTimestamp(), GuaranteeTimestamp: req.GetGuaranteeTimestamp(),
Nq: 0, Nq: 0,

View File

@ -170,7 +170,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
} }
} }
if t.SearchRequest.GetIsAdvanced() { if t.SearchRequest.GetIsAdvanced() {
t.rankParams, err = parseRankParams(t.request.GetSearchParams()) t.rankParams, err = parseRankParams(t.request.GetSearchParams(), t.schema.CollectionSchema)
if err != nil { if err != nil {
log.Info("parseRankParams failed", zap.Error(err)) log.Info("parseRankParams failed", zap.Error(err))
return err return err
@ -366,8 +366,8 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
Topk: queryInfo.GetTopk(), Topk: queryInfo.GetTopk(),
Offset: offset, Offset: offset,
MetricType: queryInfo.GetMetricType(), MetricType: queryInfo.GetMetricType(),
GroupByFieldId: queryInfo.GetGroupByFieldId(), GroupByFieldId: t.rankParams.GetGroupByFieldId(),
GroupSize: queryInfo.GetGroupSize(), GroupSize: t.rankParams.GetGroupSize(),
} }
// set PartitionIDs for sub search // set PartitionIDs for sub search
@ -407,10 +407,9 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()), zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
zap.Stringer("plan", plan)) // may be very large if large term passed. zap.Stringer("plan", plan)) // may be very large if large term passed.
} }
if len(t.queryInfos) > 0 {
t.SearchRequest.GroupByFieldId = t.queryInfos[0].GetGroupByFieldId() t.SearchRequest.GroupByFieldId = t.rankParams.GetGroupByFieldId()
t.SearchRequest.GroupSize = t.queryInfos[0].GetGroupSize() t.SearchRequest.GroupSize = t.rankParams.GetGroupSize()
}
// used for requery // used for requery
if t.partitionKeyMode { if t.partitionKeyMode {

View File

@ -2263,6 +2263,63 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
assert.Equal(t, int64(0), offset) assert.Equal(t, int64(0), offset)
}) })
t.Run("parseSearchInfo groupBy info for hybrid search", func(t *testing.T) {
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{FieldID: 101, Name: "c1"},
{FieldID: 102, Name: "c2"},
{FieldID: 103, Name: "c3"},
},
}
// 1. first parse rank params
// outer params require to group by field 101 and groupSize=3 and groupStrictSize=false
testRankParamsPairs := getValidSearchParams()
testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
Key: GroupByFieldKey,
Value: "c1",
})
testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
Key: GroupSizeKey,
Value: strconv.FormatInt(3, 10),
})
testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
Key: GroupStrictSize,
Value: "false",
})
testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
Key: LimitKey,
Value: "100",
})
testRankParams, err := parseRankParams(testRankParamsPairs, schema)
assert.NoError(t, err)
// 2. parse search params for sub request in hybridsearch
params := getValidSearchParams()
// inner params require to group by field 103 and groupSize=10 and groupStrictSize=true
params = append(params, &commonpb.KeyValuePair{
Key: GroupByFieldKey,
Value: "c3",
})
params = append(params, &commonpb.KeyValuePair{
Key: GroupSizeKey,
Value: strconv.FormatInt(10, 10),
})
params = append(params, &commonpb.KeyValuePair{
Key: GroupStrictSize,
Value: "true",
})
info, _, err := parseSearchInfo(params, schema, testRankParams)
assert.NoError(t, err)
assert.NotNil(t, info)
// all group_by related parameters should be aligned to parameters
// set by main request rather than inner sub request
assert.Equal(t, int64(101), info.GetGroupByFieldId())
assert.Equal(t, int64(3), info.GetGroupSize())
assert.False(t, info.GetGroupStrictSize())
})
t.Run("parseSearchInfo error", func(t *testing.T) { t.Run("parseSearchInfo error", func(t *testing.T) {
spNoTopk := []*commonpb.KeyValuePair{{ spNoTopk := []*commonpb.KeyValuePair{{
Key: AnnsFieldKey, Key: AnnsFieldKey,