mirror of https://github.com/milvus-io/milvus.git
related: #36146 Signed-off-by: MrPresent-Han <chun.han@gmail.com> Co-authored-by: MrPresent-Han <chun.han@gmail.com>pull/36964/head
parent
20750c061b
commit
eccc326e8b
|
@ -964,9 +964,13 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
|
|||
searchParams := generateSearchParams(ctx, c, httpReq.SearchParams)
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)})
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)})
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: httpReq.GroupByField})
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupSize, Value: strconv.FormatInt(int64(httpReq.GroupSize), 10)})
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupStrictSize, Value: strconv.FormatBool(httpReq.GroupStrictSize)})
|
||||
if httpReq.GroupByField != "" {
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: httpReq.GroupByField})
|
||||
}
|
||||
if httpReq.GroupByField != "" && httpReq.GroupSize > 0 {
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupSize, Value: strconv.FormatInt(int64(httpReq.GroupSize), 10)})
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupStrictSize, Value: strconv.FormatBool(httpReq.GroupStrictSize)})
|
||||
}
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: httpReq.AnnsField})
|
||||
body, _ := c.Get(gin.BodyBytesKey)
|
||||
placeholderGroup, err := generatePlaceholderGroup(ctx, string(body.([]byte)), collSchema, httpReq.AnnsField)
|
||||
|
@ -1064,9 +1068,13 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq
|
|||
{Key: proxy.RankParamsKey, Value: string(bs)},
|
||||
{Key: ParamLimit, Value: strconv.FormatInt(int64(httpReq.Limit), 10)},
|
||||
{Key: ParamRoundDecimal, Value: "-1"},
|
||||
{Key: ParamGroupByField, Value: httpReq.GroupByField},
|
||||
{Key: ParamGroupSize, Value: strconv.FormatInt(int64(httpReq.GroupSize), 10)},
|
||||
{Key: ParamGroupStrictSize, Value: strconv.FormatBool(httpReq.GroupStrictSize)},
|
||||
}
|
||||
if httpReq.GroupByField != "" {
|
||||
req.RankParams = append(req.RankParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: httpReq.GroupByField})
|
||||
}
|
||||
if httpReq.GroupByField != "" && httpReq.GroupSize > 0 {
|
||||
req.RankParams = append(req.RankParams, &commonpb.KeyValuePair{Key: ParamGroupSize, Value: strconv.FormatInt(int64(httpReq.GroupSize), 10)})
|
||||
req.RankParams = append(req.RankParams, &commonpb.KeyValuePair{Key: ParamGroupStrictSize, Value: strconv.FormatBool(httpReq.GroupStrictSize)})
|
||||
}
|
||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/HybridSearch", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.HybridSearch(reqCtx, req.(*milvuspb.HybridSearchRequest))
|
||||
|
|
|
@ -370,8 +370,15 @@ func parseGroupByInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemap
|
|||
groupSize = 1
|
||||
} else {
|
||||
groupSize, err = strconv.ParseInt(groupSizeStr, 0, 64)
|
||||
if err != nil || groupSize <= 0 {
|
||||
groupSize = 1
|
||||
if err != nil {
|
||||
ret.err = merr.WrapErrParameterInvalidMsg(
|
||||
fmt.Sprintf("failed to parse input group size:%s", groupSizeStr))
|
||||
return ret
|
||||
}
|
||||
if groupSize <= 0 {
|
||||
ret.err = merr.WrapErrParameterInvalidMsg(
|
||||
fmt.Sprintf("input group size:%d is negative, failed to do search_groupby", groupSize))
|
||||
return ret
|
||||
}
|
||||
}
|
||||
if groupSize > Params.QuotaConfig.MaxGroupSize.GetAsInt64() {
|
||||
|
|
|
@ -2538,7 +2538,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
|
|||
assert.Equal(t, Params.QuotaConfig.TopKLimit.GetAsInt64(), searchInfo.planInfo.GetTopk())
|
||||
})
|
||||
|
||||
t.Run("check max group size", func(t *testing.T) {
|
||||
t.Run("check correctness of group size", func(t *testing.T) {
|
||||
normalParam := getValidSearchParams()
|
||||
normalParam = append(normalParam, &commonpb.KeyValuePair{
|
||||
Key: GroupSizeKey,
|
||||
|
@ -2553,14 +2553,26 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
|
|||
Fields: fields,
|
||||
}
|
||||
searchInfo := parseSearchInfo(normalParam, schema, nil)
|
||||
assert.Nil(t, searchInfo.planInfo)
|
||||
assert.Error(t, searchInfo.parseError)
|
||||
assert.True(t, strings.Contains(searchInfo.parseError.Error(), "exceeds configured max group size"))
|
||||
|
||||
resetSearchParamsValue(normalParam, GroupSizeKey, `10`)
|
||||
searchInfo = parseSearchInfo(normalParam, schema, nil)
|
||||
assert.NotNil(t, searchInfo.planInfo)
|
||||
assert.NoError(t, searchInfo.parseError)
|
||||
{
|
||||
resetSearchParamsValue(normalParam, GroupSizeKey, `10`)
|
||||
searchInfo = parseSearchInfo(normalParam, schema, nil)
|
||||
assert.NoError(t, searchInfo.parseError)
|
||||
assert.Equal(t, int64(10), searchInfo.planInfo.GroupSize)
|
||||
}
|
||||
{
|
||||
resetSearchParamsValue(normalParam, GroupSizeKey, `-1`)
|
||||
searchInfo = parseSearchInfo(normalParam, schema, nil)
|
||||
assert.Error(t, searchInfo.parseError)
|
||||
assert.True(t, strings.Contains(searchInfo.parseError.Error(), "is negative"))
|
||||
}
|
||||
{
|
||||
resetSearchParamsValue(normalParam, GroupSizeKey, `xxx`)
|
||||
searchInfo = parseSearchInfo(normalParam, schema, nil)
|
||||
assert.Error(t, searchInfo.parseError)
|
||||
assert.True(t, strings.Contains(searchInfo.parseError.Error(), "failed to parse input group size"))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue