mirror of https://github.com/milvus-io/milvus.git
related: #35096 Signed-off-by: MrPresent-Han <chun.han@gmail.com> Co-authored-by: MrPresent-Han <chun.han@gmail.com>pull/34976/head
parent
4b077e1bd2
commit
eb23e23cd2
|
|
@ -21,9 +21,12 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type rankParams struct {
|
type rankParams struct {
|
||||||
limit int64
|
limit int64
|
||||||
offset int64
|
offset int64
|
||||||
roundDecimal int64
|
roundDecimal int64
|
||||||
|
groupByFieldId int64
|
||||||
|
groupSize int64
|
||||||
|
groupStrictSize bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *rankParams) GetLimit() int64 {
|
func (r *rankParams) GetLimit() int64 {
|
||||||
|
|
@ -47,6 +50,27 @@ func (r *rankParams) GetRoundDecimal() int64 {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *rankParams) GetGroupByFieldId() int64 {
|
||||||
|
if r != nil {
|
||||||
|
return r.groupByFieldId
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rankParams) GetGroupSize() int64 {
|
||||||
|
if r != nil {
|
||||||
|
return r.groupSize
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rankParams) GetGroupStrictSize() bool {
|
||||||
|
if r != nil {
|
||||||
|
return r.groupStrictSize
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (r *rankParams) String() string {
|
func (r *rankParams) String() string {
|
||||||
return fmt.Sprintf("limit: %d, offset: %d, roundDecimal: %d", r.GetLimit(), r.GetOffset(), r.GetRoundDecimal())
|
return fmt.Sprintf("limit: %d, offset: %d, roundDecimal: %d", r.GetLimit(), r.GetOffset(), r.GetRoundDecimal())
|
||||||
}
|
}
|
||||||
|
|
@ -137,51 +161,16 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. parse group by field and group by size
|
// 5. parse group by field and group by size
|
||||||
groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
|
var groupByFieldId, groupSize int64
|
||||||
if err != nil {
|
|
||||||
groupByFieldName = ""
|
|
||||||
}
|
|
||||||
var groupByFieldId int64 = -1
|
|
||||||
if groupByFieldName != "" {
|
|
||||||
fields := schema.GetFields()
|
|
||||||
for _, field := range fields {
|
|
||||||
if field.GetNullable() {
|
|
||||||
return nil, 0, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("groupBy field(%s) not support nullable == true", groupByFieldName))
|
|
||||||
}
|
|
||||||
if field.Name == groupByFieldName {
|
|
||||||
groupByFieldId = field.FieldID
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if groupByFieldId == -1 {
|
|
||||||
return nil, 0, merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var groupSize int64
|
|
||||||
groupSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupSizeKey, searchParamsPair)
|
|
||||||
if err != nil {
|
|
||||||
groupSize = 1
|
|
||||||
} else {
|
|
||||||
groupSize, err = strconv.ParseInt(groupSizeStr, 0, 64)
|
|
||||||
if err != nil || groupSize <= 0 {
|
|
||||||
groupSize = 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if groupSize > Params.QuotaConfig.MaxGroupSize.GetAsInt64() {
|
|
||||||
return nil, 0, merr.WrapErrParameterInvalidMsg(
|
|
||||||
fmt.Sprintf("input group size:%d exceeds configured max group size:%d", groupSize, Params.QuotaConfig.MaxGroupSize.GetAsInt64()))
|
|
||||||
}
|
|
||||||
|
|
||||||
var groupStrictSize bool
|
var groupStrictSize bool
|
||||||
groupStrictSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupStrictSize, searchParamsPair)
|
if isAdvanced {
|
||||||
if err != nil {
|
groupByFieldId, groupSize, groupStrictSize = rankParams.GetGroupByFieldId(), rankParams.GetGroupSize(), rankParams.GetGroupStrictSize()
|
||||||
groupStrictSize = false
|
|
||||||
} else {
|
} else {
|
||||||
groupStrictSize, err = strconv.ParseBool(groupStrictSizeStr)
|
groupByInfo := parseGroupByInfo(searchParamsPair, schema)
|
||||||
if err != nil {
|
if groupByInfo.err != nil {
|
||||||
groupStrictSize = false
|
return nil, 0, groupByInfo.err
|
||||||
}
|
}
|
||||||
|
groupByFieldId, groupSize, groupStrictSize = groupByInfo.GetGroupByFieldId(), groupByInfo.GetGroupSize(), groupByInfo.GetGroupStrictSize()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
|
// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
|
||||||
|
|
@ -298,8 +287,104 @@ func getPartitionIDs(ctx context.Context, dbName string, collectionName string,
|
||||||
return partitionsSet.Collect(), nil
|
return partitionsSet.Collect(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type groupByInfo struct {
|
||||||
|
groupByFieldId int64
|
||||||
|
groupSize int64
|
||||||
|
groupStrictSize bool
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *groupByInfo) GetGroupByFieldId() int64 {
|
||||||
|
if g != nil {
|
||||||
|
return g.groupByFieldId
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *groupByInfo) GetGroupSize() int64 {
|
||||||
|
if g != nil {
|
||||||
|
return g.groupSize
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *groupByInfo) GetGroupStrictSize() bool {
|
||||||
|
if g != nil {
|
||||||
|
return g.groupStrictSize
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *groupByInfo) GetError() error {
|
||||||
|
if g != nil {
|
||||||
|
return g.err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseGroupByInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema) *groupByInfo {
|
||||||
|
ret := &groupByInfo{}
|
||||||
|
|
||||||
|
// 1. parse group_by_field
|
||||||
|
groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
|
||||||
|
if err != nil {
|
||||||
|
groupByFieldName = ""
|
||||||
|
}
|
||||||
|
var groupByFieldId int64 = -1
|
||||||
|
if groupByFieldName != "" {
|
||||||
|
fields := schema.GetFields()
|
||||||
|
for _, field := range fields {
|
||||||
|
if field.GetNullable() {
|
||||||
|
ret.err = merr.WrapErrParameterInvalidMsg(fmt.Sprintf("groupBy field(%s) not support nullable == true", groupByFieldName))
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
if field.Name == groupByFieldName {
|
||||||
|
groupByFieldId = field.FieldID
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if groupByFieldId == -1 {
|
||||||
|
ret.err = merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema")
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ret.groupByFieldId = groupByFieldId
|
||||||
|
|
||||||
|
// 2. parse group size
|
||||||
|
var groupSize int64
|
||||||
|
groupSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupSizeKey, searchParamsPair)
|
||||||
|
if err != nil {
|
||||||
|
groupSize = 1
|
||||||
|
} else {
|
||||||
|
groupSize, err = strconv.ParseInt(groupSizeStr, 0, 64)
|
||||||
|
if err != nil || groupSize <= 0 {
|
||||||
|
groupSize = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if groupSize > Params.QuotaConfig.MaxGroupSize.GetAsInt64() {
|
||||||
|
ret.err = merr.WrapErrParameterInvalidMsg(
|
||||||
|
fmt.Sprintf("input group size:%d exceeds configured max group size:%d", groupSize, Params.QuotaConfig.MaxGroupSize.GetAsInt64()))
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
ret.groupSize = groupSize
|
||||||
|
|
||||||
|
// 3. parse group strict size
|
||||||
|
var groupStrictSize bool
|
||||||
|
groupStrictSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupStrictSize, searchParamsPair)
|
||||||
|
if err != nil {
|
||||||
|
groupStrictSize = false
|
||||||
|
} else {
|
||||||
|
groupStrictSize, err = strconv.ParseBool(groupStrictSizeStr)
|
||||||
|
if err != nil {
|
||||||
|
groupStrictSize = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ret.groupStrictSize = groupStrictSize
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
// parseRankParams get limit and offset from rankParams, both are optional.
|
// parseRankParams get limit and offset from rankParams, both are optional.
|
||||||
func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, error) {
|
func parseRankParams(rankParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema) (*rankParams, error) {
|
||||||
var (
|
var (
|
||||||
limit int64
|
limit int64
|
||||||
offset int64
|
offset int64
|
||||||
|
|
@ -343,23 +428,30 @@ func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, erro
|
||||||
return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parse group_by parameters from main request body for hybrid search
|
||||||
|
groupByInfo := parseGroupByInfo(rankParamsPair, schema)
|
||||||
|
if groupByInfo.err != nil {
|
||||||
|
return nil, groupByInfo.err
|
||||||
|
}
|
||||||
|
|
||||||
return &rankParams{
|
return &rankParams{
|
||||||
limit: limit,
|
limit: limit,
|
||||||
offset: offset,
|
offset: offset,
|
||||||
roundDecimal: roundDecimal,
|
roundDecimal: roundDecimal,
|
||||||
|
groupByFieldId: groupByInfo.GetGroupByFieldId(),
|
||||||
|
groupSize: groupByInfo.GetGroupSize(),
|
||||||
|
groupStrictSize: groupByInfo.GetGroupStrictSize(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.SearchRequest {
|
func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.SearchRequest {
|
||||||
searchParams := make([]*commonpb.KeyValuePair, len(req.GetRankParams()))
|
|
||||||
copy(searchParams, req.GetRankParams())
|
|
||||||
ret := &milvuspb.SearchRequest{
|
ret := &milvuspb.SearchRequest{
|
||||||
Base: req.GetBase(),
|
Base: req.GetBase(),
|
||||||
DbName: req.GetDbName(),
|
DbName: req.GetDbName(),
|
||||||
CollectionName: req.GetCollectionName(),
|
CollectionName: req.GetCollectionName(),
|
||||||
PartitionNames: req.GetPartitionNames(),
|
PartitionNames: req.GetPartitionNames(),
|
||||||
OutputFields: req.GetOutputFields(),
|
OutputFields: req.GetOutputFields(),
|
||||||
SearchParams: searchParams,
|
SearchParams: req.GetRankParams(),
|
||||||
TravelTimestamp: req.GetTravelTimestamp(),
|
TravelTimestamp: req.GetTravelTimestamp(),
|
||||||
GuaranteeTimestamp: req.GetGuaranteeTimestamp(),
|
GuaranteeTimestamp: req.GetGuaranteeTimestamp(),
|
||||||
Nq: 0,
|
Nq: 0,
|
||||||
|
|
|
||||||
|
|
@ -170,7 +170,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if t.SearchRequest.GetIsAdvanced() {
|
if t.SearchRequest.GetIsAdvanced() {
|
||||||
t.rankParams, err = parseRankParams(t.request.GetSearchParams())
|
t.rankParams, err = parseRankParams(t.request.GetSearchParams(), t.schema.CollectionSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Info("parseRankParams failed", zap.Error(err))
|
log.Info("parseRankParams failed", zap.Error(err))
|
||||||
return err
|
return err
|
||||||
|
|
@ -366,8 +366,8 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
||||||
Topk: queryInfo.GetTopk(),
|
Topk: queryInfo.GetTopk(),
|
||||||
Offset: offset,
|
Offset: offset,
|
||||||
MetricType: queryInfo.GetMetricType(),
|
MetricType: queryInfo.GetMetricType(),
|
||||||
GroupByFieldId: queryInfo.GetGroupByFieldId(),
|
GroupByFieldId: t.rankParams.GetGroupByFieldId(),
|
||||||
GroupSize: queryInfo.GetGroupSize(),
|
GroupSize: t.rankParams.GetGroupSize(),
|
||||||
}
|
}
|
||||||
|
|
||||||
// set PartitionIDs for sub search
|
// set PartitionIDs for sub search
|
||||||
|
|
@ -407,10 +407,9 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
||||||
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
|
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
|
||||||
zap.Stringer("plan", plan)) // may be very large if large term passed.
|
zap.Stringer("plan", plan)) // may be very large if large term passed.
|
||||||
}
|
}
|
||||||
if len(t.queryInfos) > 0 {
|
|
||||||
t.SearchRequest.GroupByFieldId = t.queryInfos[0].GetGroupByFieldId()
|
t.SearchRequest.GroupByFieldId = t.rankParams.GetGroupByFieldId()
|
||||||
t.SearchRequest.GroupSize = t.queryInfos[0].GetGroupSize()
|
t.SearchRequest.GroupSize = t.rankParams.GetGroupSize()
|
||||||
}
|
|
||||||
|
|
||||||
// used for requery
|
// used for requery
|
||||||
if t.partitionKeyMode {
|
if t.partitionKeyMode {
|
||||||
|
|
|
||||||
|
|
@ -2263,6 +2263,63 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
|
||||||
assert.Equal(t, int64(0), offset)
|
assert.Equal(t, int64(0), offset)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("parseSearchInfo groupBy info for hybrid search", func(t *testing.T) {
|
||||||
|
schema := &schemapb.CollectionSchema{
|
||||||
|
Fields: []*schemapb.FieldSchema{
|
||||||
|
{FieldID: 101, Name: "c1"},
|
||||||
|
{FieldID: 102, Name: "c2"},
|
||||||
|
{FieldID: 103, Name: "c3"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
// 1. first parse rank params
|
||||||
|
// outer params require to group by field 101 and groupSize=3 and groupStrictSize=false
|
||||||
|
testRankParamsPairs := getValidSearchParams()
|
||||||
|
testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
|
||||||
|
Key: GroupByFieldKey,
|
||||||
|
Value: "c1",
|
||||||
|
})
|
||||||
|
testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
|
||||||
|
Key: GroupSizeKey,
|
||||||
|
Value: strconv.FormatInt(3, 10),
|
||||||
|
})
|
||||||
|
testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
|
||||||
|
Key: GroupStrictSize,
|
||||||
|
Value: "false",
|
||||||
|
})
|
||||||
|
testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
|
||||||
|
Key: LimitKey,
|
||||||
|
Value: "100",
|
||||||
|
})
|
||||||
|
testRankParams, err := parseRankParams(testRankParamsPairs, schema)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// 2. parse search params for sub request in hybridsearch
|
||||||
|
params := getValidSearchParams()
|
||||||
|
// inner params require to group by field 103 and groupSize=10 and groupStrictSize=true
|
||||||
|
params = append(params, &commonpb.KeyValuePair{
|
||||||
|
Key: GroupByFieldKey,
|
||||||
|
Value: "c3",
|
||||||
|
})
|
||||||
|
params = append(params, &commonpb.KeyValuePair{
|
||||||
|
Key: GroupSizeKey,
|
||||||
|
Value: strconv.FormatInt(10, 10),
|
||||||
|
})
|
||||||
|
params = append(params, &commonpb.KeyValuePair{
|
||||||
|
Key: GroupStrictSize,
|
||||||
|
Value: "true",
|
||||||
|
})
|
||||||
|
|
||||||
|
info, _, err := parseSearchInfo(params, schema, testRankParams)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, info)
|
||||||
|
|
||||||
|
// all group_by related parameters should be aligned to parameters
|
||||||
|
// set by main request rather than inner sub request
|
||||||
|
assert.Equal(t, int64(101), info.GetGroupByFieldId())
|
||||||
|
assert.Equal(t, int64(3), info.GetGroupSize())
|
||||||
|
assert.False(t, info.GetGroupStrictSize())
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("parseSearchInfo error", func(t *testing.T) {
|
t.Run("parseSearchInfo error", func(t *testing.T) {
|
||||||
spNoTopk := []*commonpb.KeyValuePair{{
|
spNoTopk := []*commonpb.KeyValuePair{{
|
||||||
Key: AnnsFieldKey,
|
Key: AnnsFieldKey,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue