mirror of https://github.com/milvus-io/milvus.git
related: #35096
Signed-off-by: MrPresent-Han <chun.han@gmail.com>
Co-authored-by: MrPresent-Han <chun.han@gmail.com>
pull/34976/head
parent
4b077e1bd2
commit
eb23e23cd2
|
|
@ -21,9 +21,12 @@ import (
|
|||
)
|
||||
|
||||
// rankParams carries the request-level ranking options of a hybrid search,
// including the group-by settings that apply to the request as a whole.
type rankParams struct {
	// limit, offset and roundDecimal are exposed through GetLimit, GetOffset
	// and GetRoundDecimal.
	limit        int64
	offset       int64
	roundDecimal int64

	// groupByFieldId is the field ID to group results by; the nil-receiver
	// getter reports -1.
	groupByFieldId int64
	// groupSize is the group size; the nil-receiver getter reports 1.
	groupSize int64
	// groupStrictSize is the strict-size flag; the nil-receiver getter
	// reports false.
	groupStrictSize bool
}
|
||||
|
||||
func (r *rankParams) GetLimit() int64 {
|
||||
|
|
@ -47,6 +50,27 @@ func (r *rankParams) GetRoundDecimal() int64 {
|
|||
return 0
|
||||
}
|
||||
|
||||
func (r *rankParams) GetGroupByFieldId() int64 {
|
||||
if r != nil {
|
||||
return r.groupByFieldId
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func (r *rankParams) GetGroupSize() int64 {
|
||||
if r != nil {
|
||||
return r.groupSize
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func (r *rankParams) GetGroupStrictSize() bool {
|
||||
if r != nil {
|
||||
return r.groupStrictSize
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *rankParams) String() string {
|
||||
return fmt.Sprintf("limit: %d, offset: %d, roundDecimal: %d", r.GetLimit(), r.GetOffset(), r.GetRoundDecimal())
|
||||
}
|
||||
|
|
@ -137,51 +161,16 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
|||
}
|
||||
|
||||
// 5. parse group by field and group by size
|
||||
groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
|
||||
if err != nil {
|
||||
groupByFieldName = ""
|
||||
}
|
||||
var groupByFieldId int64 = -1
|
||||
if groupByFieldName != "" {
|
||||
fields := schema.GetFields()
|
||||
for _, field := range fields {
|
||||
if field.GetNullable() {
|
||||
return nil, 0, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("groupBy field(%s) not support nullable == true", groupByFieldName))
|
||||
}
|
||||
if field.Name == groupByFieldName {
|
||||
groupByFieldId = field.FieldID
|
||||
break
|
||||
}
|
||||
}
|
||||
if groupByFieldId == -1 {
|
||||
return nil, 0, merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema")
|
||||
}
|
||||
}
|
||||
|
||||
var groupSize int64
|
||||
groupSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupSizeKey, searchParamsPair)
|
||||
if err != nil {
|
||||
groupSize = 1
|
||||
} else {
|
||||
groupSize, err = strconv.ParseInt(groupSizeStr, 0, 64)
|
||||
if err != nil || groupSize <= 0 {
|
||||
groupSize = 1
|
||||
}
|
||||
}
|
||||
if groupSize > Params.QuotaConfig.MaxGroupSize.GetAsInt64() {
|
||||
return nil, 0, merr.WrapErrParameterInvalidMsg(
|
||||
fmt.Sprintf("input group size:%d exceeds configured max group size:%d", groupSize, Params.QuotaConfig.MaxGroupSize.GetAsInt64()))
|
||||
}
|
||||
|
||||
var groupByFieldId, groupSize int64
|
||||
var groupStrictSize bool
|
||||
groupStrictSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupStrictSize, searchParamsPair)
|
||||
if err != nil {
|
||||
groupStrictSize = false
|
||||
if isAdvanced {
|
||||
groupByFieldId, groupSize, groupStrictSize = rankParams.GetGroupByFieldId(), rankParams.GetGroupSize(), rankParams.GetGroupStrictSize()
|
||||
} else {
|
||||
groupStrictSize, err = strconv.ParseBool(groupStrictSizeStr)
|
||||
if err != nil {
|
||||
groupStrictSize = false
|
||||
groupByInfo := parseGroupByInfo(searchParamsPair, schema)
|
||||
if groupByInfo.err != nil {
|
||||
return nil, 0, groupByInfo.err
|
||||
}
|
||||
groupByFieldId, groupSize, groupStrictSize = groupByInfo.GetGroupByFieldId(), groupByInfo.GetGroupSize(), groupByInfo.GetGroupStrictSize()
|
||||
}
|
||||
|
||||
// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
|
||||
|
|
@ -298,8 +287,104 @@ func getPartitionIDs(ctx context.Context, dbName string, collectionName string,
|
|||
return partitionsSet.Collect(), nil
|
||||
}
|
||||
|
||||
// groupByInfo is the result of parsing group-by related search parameters.
// A parse failure is reported through err instead of a separate return value.
type groupByInfo struct {
	// groupByFieldId is the resolved field ID, or -1 when no group-by field
	// was requested.
	groupByFieldId int64
	// groupSize is the group size.
	groupSize int64
	// groupStrictSize is the strict-size flag.
	groupStrictSize bool
	// err holds the parse/validation failure, if any.
	err error
}

// GetGroupByFieldId returns the resolved group-by field ID.
// A nil receiver yields -1 (the "no group-by" sentinel), matching
// rankParams.GetGroupByFieldId.
func (g *groupByInfo) GetGroupByFieldId() int64 {
	if g != nil {
		return g.groupByFieldId
	}
	return -1
}

// GetGroupSize returns the group size.
// A nil receiver yields 1, matching rankParams.GetGroupSize.
func (g *groupByInfo) GetGroupSize() int64 {
	if g != nil {
		return g.groupSize
	}
	return 1
}

// GetGroupStrictSize returns the strict-group-size flag.
// A nil receiver yields false.
func (g *groupByInfo) GetGroupStrictSize() bool {
	if g != nil {
		return g.groupStrictSize
	}
	return false
}

// GetError returns the parse error recorded in g, or nil when g is nil or
// parsing succeeded.
func (g *groupByInfo) GetError() error {
	if g != nil {
		return g.err
	}
	return nil
}
|
||||
|
||||
func parseGroupByInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema) *groupByInfo {
|
||||
ret := &groupByInfo{}
|
||||
|
||||
// 1. parse group_by_field
|
||||
groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
|
||||
if err != nil {
|
||||
groupByFieldName = ""
|
||||
}
|
||||
var groupByFieldId int64 = -1
|
||||
if groupByFieldName != "" {
|
||||
fields := schema.GetFields()
|
||||
for _, field := range fields {
|
||||
if field.GetNullable() {
|
||||
ret.err = merr.WrapErrParameterInvalidMsg(fmt.Sprintf("groupBy field(%s) not support nullable == true", groupByFieldName))
|
||||
return ret
|
||||
}
|
||||
if field.Name == groupByFieldName {
|
||||
groupByFieldId = field.FieldID
|
||||
break
|
||||
}
|
||||
}
|
||||
if groupByFieldId == -1 {
|
||||
ret.err = merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema")
|
||||
return ret
|
||||
}
|
||||
}
|
||||
ret.groupByFieldId = groupByFieldId
|
||||
|
||||
// 2. parse group size
|
||||
var groupSize int64
|
||||
groupSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupSizeKey, searchParamsPair)
|
||||
if err != nil {
|
||||
groupSize = 1
|
||||
} else {
|
||||
groupSize, err = strconv.ParseInt(groupSizeStr, 0, 64)
|
||||
if err != nil || groupSize <= 0 {
|
||||
groupSize = 1
|
||||
}
|
||||
}
|
||||
if groupSize > Params.QuotaConfig.MaxGroupSize.GetAsInt64() {
|
||||
ret.err = merr.WrapErrParameterInvalidMsg(
|
||||
fmt.Sprintf("input group size:%d exceeds configured max group size:%d", groupSize, Params.QuotaConfig.MaxGroupSize.GetAsInt64()))
|
||||
return ret
|
||||
}
|
||||
ret.groupSize = groupSize
|
||||
|
||||
// 3. parse group strict size
|
||||
var groupStrictSize bool
|
||||
groupStrictSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupStrictSize, searchParamsPair)
|
||||
if err != nil {
|
||||
groupStrictSize = false
|
||||
} else {
|
||||
groupStrictSize, err = strconv.ParseBool(groupStrictSizeStr)
|
||||
if err != nil {
|
||||
groupStrictSize = false
|
||||
}
|
||||
}
|
||||
ret.groupStrictSize = groupStrictSize
|
||||
return ret
|
||||
}
|
||||
|
||||
// parseRankParams gets limit and offset from rankParamsPair (both are optional), along with the group-by settings of the main request.
|
||||
func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, error) {
|
||||
func parseRankParams(rankParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema) (*rankParams, error) {
|
||||
var (
|
||||
limit int64
|
||||
offset int64
|
||||
|
|
@ -343,23 +428,30 @@ func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, erro
|
|||
return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
||||
}
|
||||
|
||||
// parse group_by parameters from main request body for hybrid search
|
||||
groupByInfo := parseGroupByInfo(rankParamsPair, schema)
|
||||
if groupByInfo.err != nil {
|
||||
return nil, groupByInfo.err
|
||||
}
|
||||
|
||||
return &rankParams{
|
||||
limit: limit,
|
||||
offset: offset,
|
||||
roundDecimal: roundDecimal,
|
||||
limit: limit,
|
||||
offset: offset,
|
||||
roundDecimal: roundDecimal,
|
||||
groupByFieldId: groupByInfo.GetGroupByFieldId(),
|
||||
groupSize: groupByInfo.GetGroupSize(),
|
||||
groupStrictSize: groupByInfo.GetGroupStrictSize(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.SearchRequest {
|
||||
searchParams := make([]*commonpb.KeyValuePair, len(req.GetRankParams()))
|
||||
copy(searchParams, req.GetRankParams())
|
||||
ret := &milvuspb.SearchRequest{
|
||||
Base: req.GetBase(),
|
||||
DbName: req.GetDbName(),
|
||||
CollectionName: req.GetCollectionName(),
|
||||
PartitionNames: req.GetPartitionNames(),
|
||||
OutputFields: req.GetOutputFields(),
|
||||
SearchParams: searchParams,
|
||||
SearchParams: req.GetRankParams(),
|
||||
TravelTimestamp: req.GetTravelTimestamp(),
|
||||
GuaranteeTimestamp: req.GetGuaranteeTimestamp(),
|
||||
Nq: 0,
|
||||
|
|
|
|||
|
|
@ -170,7 +170,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
|||
}
|
||||
}
|
||||
if t.SearchRequest.GetIsAdvanced() {
|
||||
t.rankParams, err = parseRankParams(t.request.GetSearchParams())
|
||||
t.rankParams, err = parseRankParams(t.request.GetSearchParams(), t.schema.CollectionSchema)
|
||||
if err != nil {
|
||||
log.Info("parseRankParams failed", zap.Error(err))
|
||||
return err
|
||||
|
|
@ -366,8 +366,8 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
|||
Topk: queryInfo.GetTopk(),
|
||||
Offset: offset,
|
||||
MetricType: queryInfo.GetMetricType(),
|
||||
GroupByFieldId: queryInfo.GetGroupByFieldId(),
|
||||
GroupSize: queryInfo.GetGroupSize(),
|
||||
GroupByFieldId: t.rankParams.GetGroupByFieldId(),
|
||||
GroupSize: t.rankParams.GetGroupSize(),
|
||||
}
|
||||
|
||||
// set PartitionIDs for sub search
|
||||
|
|
@ -407,10 +407,9 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
|||
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
|
||||
zap.Stringer("plan", plan)) // may be very large if large term passed.
|
||||
}
|
||||
if len(t.queryInfos) > 0 {
|
||||
t.SearchRequest.GroupByFieldId = t.queryInfos[0].GetGroupByFieldId()
|
||||
t.SearchRequest.GroupSize = t.queryInfos[0].GetGroupSize()
|
||||
}
|
||||
|
||||
t.SearchRequest.GroupByFieldId = t.rankParams.GetGroupByFieldId()
|
||||
t.SearchRequest.GroupSize = t.rankParams.GetGroupSize()
|
||||
|
||||
// used for requery
|
||||
if t.partitionKeyMode {
|
||||
|
|
|
|||
|
|
@ -2263,6 +2263,63 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
|
|||
assert.Equal(t, int64(0), offset)
|
||||
})
|
||||
|
||||
t.Run("parseSearchInfo groupBy info for hybrid search", func(t *testing.T) {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 101, Name: "c1"},
|
||||
{FieldID: 102, Name: "c2"},
|
||||
{FieldID: 103, Name: "c3"},
|
||||
},
|
||||
}
|
||||
// 1. first parse rank params
|
||||
// outer params require to group by field 101 and groupSize=3 and groupStrictSize=false
|
||||
testRankParamsPairs := getValidSearchParams()
|
||||
testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
|
||||
Key: GroupByFieldKey,
|
||||
Value: "c1",
|
||||
})
|
||||
testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
|
||||
Key: GroupSizeKey,
|
||||
Value: strconv.FormatInt(3, 10),
|
||||
})
|
||||
testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
|
||||
Key: GroupStrictSize,
|
||||
Value: "false",
|
||||
})
|
||||
testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
|
||||
Key: LimitKey,
|
||||
Value: "100",
|
||||
})
|
||||
testRankParams, err := parseRankParams(testRankParamsPairs, schema)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 2. parse search params for sub request in hybridsearch
|
||||
params := getValidSearchParams()
|
||||
// inner params require to group by field 103 and groupSize=10 and groupStrictSize=true
|
||||
params = append(params, &commonpb.KeyValuePair{
|
||||
Key: GroupByFieldKey,
|
||||
Value: "c3",
|
||||
})
|
||||
params = append(params, &commonpb.KeyValuePair{
|
||||
Key: GroupSizeKey,
|
||||
Value: strconv.FormatInt(10, 10),
|
||||
})
|
||||
params = append(params, &commonpb.KeyValuePair{
|
||||
Key: GroupStrictSize,
|
||||
Value: "true",
|
||||
})
|
||||
|
||||
info, _, err := parseSearchInfo(params, schema, testRankParams)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, info)
|
||||
|
||||
// all group_by related parameters should be aligned to parameters
|
||||
// set by main request rather than inner sub request
|
||||
assert.Equal(t, int64(101), info.GetGroupByFieldId())
|
||||
assert.Equal(t, int64(3), info.GetGroupSize())
|
||||
assert.False(t, info.GetGroupStrictSize())
|
||||
})
|
||||
|
||||
t.Run("parseSearchInfo error", func(t *testing.T) {
|
||||
spNoTopk := []*commonpb.KeyValuePair{{
|
||||
Key: AnnsFieldKey,
|
||||
|
|
|
|||
Loading…
Reference in New Issue