mirror of https://github.com/milvus-io/milvus.git
fix: add the range search params check in proxy (#30423)
if check in Segcore, will not do the it when not insert data. so, check "radius" and "range_filter" in proxy. related with #30365 Signed-off-by: lixinguo <xinguo.li@zilliz.com> Co-authored-by: lixinguo <xinguo.li@zilliz.com>pull/30887/head
parent
6e9f3ea531
commit
a4f3e01a3a
|
|
@ -2,6 +2,7 @@ package proxy
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"regexp"
|
||||
|
|
@ -42,6 +43,8 @@ const (
|
|||
// a second query request will be initiated to retrieve output fields data.
|
||||
// In this case, the first search will not return any output field from QueryNodes.
|
||||
requeryThreshold = 0.5 * 1024 * 1024
|
||||
radiusKey = "radius"
|
||||
rangeFilterKey = "range_filter"
|
||||
)
|
||||
|
||||
type searchTask struct {
|
||||
|
|
@ -178,6 +181,11 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
|||
searchParamStr = ""
|
||||
}
|
||||
|
||||
err = checkRangeSearchParams(searchParamStr, metricType)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 5. parse group by field
|
||||
groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
|
||||
if err != nil {
|
||||
|
|
@ -910,6 +918,57 @@ func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb
|
|||
return ret, nil
|
||||
}
|
||||
|
||||
type rangeSearchParams struct {
|
||||
radius float64
|
||||
rangeFilter float64
|
||||
}
|
||||
|
||||
func checkRangeSearchParams(str string, metricType string) error {
|
||||
if len(str) == 0 {
|
||||
// no search params, no need to check
|
||||
return nil
|
||||
}
|
||||
var data map[string]*json.RawMessage
|
||||
err := json.Unmarshal([]byte(str), &data)
|
||||
if err != nil {
|
||||
log.Info("json Unmarshal fail when checkRangeSearchParams")
|
||||
return err
|
||||
}
|
||||
_, ok := data[radiusKey]
|
||||
// will not do range search, no need to check
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
var params rangeSearchParams
|
||||
err = json.Unmarshal(*data[radiusKey], ¶ms.radius)
|
||||
if err != nil {
|
||||
return merr.WrapErrParameterInvalidMsg("must pass numpy type for radius")
|
||||
}
|
||||
|
||||
_, ok = data[rangeFilterKey]
|
||||
// not pass range_filter, no need to check
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
err = json.Unmarshal(*data[rangeFilterKey], ¶ms.rangeFilter)
|
||||
if err != nil {
|
||||
return merr.WrapErrParameterInvalidMsg("must pass numpy type for range_filter")
|
||||
}
|
||||
|
||||
if metric.PositivelyRelated(metricType) {
|
||||
if params.radius >= params.rangeFilter {
|
||||
msg := fmt.Sprintf("range_filter must be greater than radius for IP/COSINE, range_filter:%f, radius:%f", params.rangeFilter, params.radius)
|
||||
return merr.WrapErrParameterInvalidMsg(msg)
|
||||
}
|
||||
} else {
|
||||
if params.radius <= params.rangeFilter {
|
||||
msg := fmt.Sprintf("range_filter must be less than radius for L2/HAMMING/JACCARD, range_filter:%f, radius:%f", params.rangeFilter, params.radius)
|
||||
return merr.WrapErrParameterInvalidMsg(msg)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *searchTask) TraceCtx() context.Context {
|
||||
return t.ctx
|
||||
}
|
||||
|
|
|
|||
|
|
@ -112,6 +112,56 @@ func getBaseSearchParams() []*commonpb.KeyValuePair {
|
|||
}
|
||||
}
|
||||
|
||||
func getBaseParamsForRangeSearchL2() []*commonpb.KeyValuePair {
|
||||
return []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: AnnsFieldKey,
|
||||
Value: testFloatVecField,
|
||||
},
|
||||
{
|
||||
Key: TopKKey,
|
||||
Value: "10",
|
||||
},
|
||||
{
|
||||
Key: common.MetricTypeKey,
|
||||
Value: metric.L2,
|
||||
},
|
||||
{
|
||||
Key: RoundDecimalKey,
|
||||
Value: "-1",
|
||||
},
|
||||
{
|
||||
Key: IgnoreGrowingKey,
|
||||
Value: "false",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func getBaseParamsForRangeSearchIP() []*commonpb.KeyValuePair {
|
||||
return []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: AnnsFieldKey,
|
||||
Value: testFloatVecField,
|
||||
},
|
||||
{
|
||||
Key: TopKKey,
|
||||
Value: "10",
|
||||
},
|
||||
{
|
||||
Key: common.MetricTypeKey,
|
||||
Value: metric.IP,
|
||||
},
|
||||
{
|
||||
Key: RoundDecimalKey,
|
||||
Value: "-1",
|
||||
},
|
||||
{
|
||||
Key: IgnoreGrowingKey,
|
||||
Value: "false",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func getValidSearchParams() []*commonpb.KeyValuePair {
|
||||
return []*commonpb.KeyValuePair{
|
||||
{
|
||||
|
|
@ -1954,6 +2004,101 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
|
|||
})
|
||||
}
|
||||
})
|
||||
t.Run("check range search params", func(t *testing.T) {
|
||||
normalParam := getValidSearchParams()
|
||||
|
||||
normalParamWithNoFilter := getBaseParamsForRangeSearchL2()
|
||||
normalParamWithNoFilter = append(normalParamWithNoFilter, &commonpb.KeyValuePair{
|
||||
Key: SearchParamsKey,
|
||||
Value: `{"nprobe": 10, "radius": 10}`,
|
||||
})
|
||||
|
||||
normalParamForIP := getBaseParamsForRangeSearchIP()
|
||||
normalParamForIP = append(normalParamForIP, &commonpb.KeyValuePair{
|
||||
Key: SearchParamsKey,
|
||||
Value: `{"nprobe": 10, "radius": 10, "range_filter": 20}`,
|
||||
})
|
||||
|
||||
normalParamForL2 := getBaseParamsForRangeSearchL2()
|
||||
normalParamForL2 = append(normalParamForL2, &commonpb.KeyValuePair{
|
||||
Key: SearchParamsKey,
|
||||
Value: `{"nprobe": 10, "radius": 20, "range_filter": 10}`,
|
||||
})
|
||||
|
||||
abnormalParamForIP := getBaseParamsForRangeSearchIP()
|
||||
abnormalParamForIP = append(abnormalParamForIP, &commonpb.KeyValuePair{
|
||||
Key: SearchParamsKey,
|
||||
Value: `{"nprobe": 10, "radius": 20, "range_filter": 10}`,
|
||||
})
|
||||
|
||||
abnormalParamForL2 := getBaseParamsForRangeSearchL2()
|
||||
abnormalParamForL2 = append(abnormalParamForL2, &commonpb.KeyValuePair{
|
||||
Key: SearchParamsKey,
|
||||
Value: `{"nprobe": 10, "radius": 10, "range_filter": 20}`,
|
||||
})
|
||||
|
||||
wrongTypeRadius := getBaseParamsForRangeSearchIP()
|
||||
wrongTypeRadius = append(wrongTypeRadius, &commonpb.KeyValuePair{
|
||||
Key: SearchParamsKey,
|
||||
Value: `{"nprobe": 10, "radius": "ab"}`,
|
||||
})
|
||||
|
||||
wrongTypeFilter := getBaseParamsForRangeSearchIP()
|
||||
wrongTypeFilter = append(wrongTypeFilter, &commonpb.KeyValuePair{
|
||||
Key: SearchParamsKey,
|
||||
Value: `{"nprobe": 10, "radius": 10, "range_filter": "20"}`,
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
description string
|
||||
validParams []*commonpb.KeyValuePair
|
||||
}{
|
||||
{"normalParam", normalParam},
|
||||
{"normalParamWithNoFilter", normalParamWithNoFilter},
|
||||
{"normalParamForIP", normalParamForIP},
|
||||
{"normalParamForL2", normalParamForL2},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
info, _, err := parseSearchInfo(test.validParams, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, info)
|
||||
})
|
||||
}
|
||||
|
||||
tests = []struct {
|
||||
description string
|
||||
validParams []*commonpb.KeyValuePair
|
||||
}{
|
||||
{"abnormalParamForIP", abnormalParamForIP},
|
||||
{"abnormalParamForL2", abnormalParamForL2},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
info, _, err := parseSearchInfo(test.validParams, nil)
|
||||
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
||||
assert.Nil(t, info)
|
||||
})
|
||||
}
|
||||
|
||||
tests = []struct {
|
||||
description string
|
||||
validParams []*commonpb.KeyValuePair
|
||||
}{
|
||||
{"wrongTypeRadius", wrongTypeRadius},
|
||||
{"wrongTypeFilter", wrongTypeFilter},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
info, _, err := parseSearchInfo(test.validParams, nil)
|
||||
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
||||
assert.Nil(t, info)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func getSearchResultData(nq, topk int64) *schemapb.SearchResultData {
|
||||
|
|
|
|||
|
|
@ -209,7 +209,6 @@ func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest
|
|||
log.Warn("Search organizeSubTask failed", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
results, err := executeSubTasks(ctx, tasks, func(ctx context.Context, req *querypb.SearchRequest, worker cluster.Worker) (*internalpb.SearchResults, error) {
|
||||
return worker.SearchSegments(ctx, req)
|
||||
}, "Search", log)
|
||||
|
|
|
|||
Loading…
Reference in New Issue