mirror of https://github.com/milvus-io/milvus.git
fix: add the range search params check in proxy (#30423)
If the check is done in Segcore, it is not performed when no data has been inserted, so check "radius" and "range_filter" in the proxy instead. Related to #30365. Signed-off-by: lixinguo <xinguo.li@zilliz.com> Co-authored-by: lixinguo <xinguo.li@zilliz.com> pull/30887/head
parent
6e9f3ea531
commit
a4f3e01a3a
|
|
@ -2,6 +2,7 @@ package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
|
@ -42,6 +43,8 @@ const (
|
||||||
// a second query request will be initiated to retrieve output fields data.
|
// a second query request will be initiated to retrieve output fields data.
|
||||||
// In this case, the first search will not return any output field from QueryNodes.
|
// In this case, the first search will not return any output field from QueryNodes.
|
||||||
requeryThreshold = 0.5 * 1024 * 1024
|
requeryThreshold = 0.5 * 1024 * 1024
|
||||||
|
radiusKey = "radius"
|
||||||
|
rangeFilterKey = "range_filter"
|
||||||
)
|
)
|
||||||
|
|
||||||
type searchTask struct {
|
type searchTask struct {
|
||||||
|
|
@ -178,6 +181,11 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
||||||
searchParamStr = ""
|
searchParamStr = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = checkRangeSearchParams(searchParamStr, metricType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
// 5. parse group by field
|
// 5. parse group by field
|
||||||
groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
|
groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -910,6 +918,57 @@ func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb
|
||||||
return ret, nil
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type rangeSearchParams struct {
|
||||||
|
radius float64
|
||||||
|
rangeFilter float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkRangeSearchParams(str string, metricType string) error {
|
||||||
|
if len(str) == 0 {
|
||||||
|
// no search params, no need to check
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var data map[string]*json.RawMessage
|
||||||
|
err := json.Unmarshal([]byte(str), &data)
|
||||||
|
if err != nil {
|
||||||
|
log.Info("json Unmarshal fail when checkRangeSearchParams")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, ok := data[radiusKey]
|
||||||
|
// will not do range search, no need to check
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var params rangeSearchParams
|
||||||
|
err = json.Unmarshal(*data[radiusKey], ¶ms.radius)
|
||||||
|
if err != nil {
|
||||||
|
return merr.WrapErrParameterInvalidMsg("must pass numpy type for radius")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ok = data[rangeFilterKey]
|
||||||
|
// not pass range_filter, no need to check
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(*data[rangeFilterKey], ¶ms.rangeFilter)
|
||||||
|
if err != nil {
|
||||||
|
return merr.WrapErrParameterInvalidMsg("must pass numpy type for range_filter")
|
||||||
|
}
|
||||||
|
|
||||||
|
if metric.PositivelyRelated(metricType) {
|
||||||
|
if params.radius >= params.rangeFilter {
|
||||||
|
msg := fmt.Sprintf("range_filter must be greater than radius for IP/COSINE, range_filter:%f, radius:%f", params.rangeFilter, params.radius)
|
||||||
|
return merr.WrapErrParameterInvalidMsg(msg)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if params.radius <= params.rangeFilter {
|
||||||
|
msg := fmt.Sprintf("range_filter must be less than radius for L2/HAMMING/JACCARD, range_filter:%f, radius:%f", params.rangeFilter, params.radius)
|
||||||
|
return merr.WrapErrParameterInvalidMsg(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t *searchTask) TraceCtx() context.Context {
|
func (t *searchTask) TraceCtx() context.Context {
|
||||||
return t.ctx
|
return t.ctx
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -112,6 +112,56 @@ func getBaseSearchParams() []*commonpb.KeyValuePair {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getBaseParamsForRangeSearchL2() []*commonpb.KeyValuePair {
|
||||||
|
return []*commonpb.KeyValuePair{
|
||||||
|
{
|
||||||
|
Key: AnnsFieldKey,
|
||||||
|
Value: testFloatVecField,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: TopKKey,
|
||||||
|
Value: "10",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: common.MetricTypeKey,
|
||||||
|
Value: metric.L2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: RoundDecimalKey,
|
||||||
|
Value: "-1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: IgnoreGrowingKey,
|
||||||
|
Value: "false",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getBaseParamsForRangeSearchIP() []*commonpb.KeyValuePair {
|
||||||
|
return []*commonpb.KeyValuePair{
|
||||||
|
{
|
||||||
|
Key: AnnsFieldKey,
|
||||||
|
Value: testFloatVecField,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: TopKKey,
|
||||||
|
Value: "10",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: common.MetricTypeKey,
|
||||||
|
Value: metric.IP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: RoundDecimalKey,
|
||||||
|
Value: "-1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: IgnoreGrowingKey,
|
||||||
|
Value: "false",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func getValidSearchParams() []*commonpb.KeyValuePair {
|
func getValidSearchParams() []*commonpb.KeyValuePair {
|
||||||
return []*commonpb.KeyValuePair{
|
return []*commonpb.KeyValuePair{
|
||||||
{
|
{
|
||||||
|
|
@ -1954,6 +2004,101 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
t.Run("check range search params", func(t *testing.T) {
|
||||||
|
normalParam := getValidSearchParams()
|
||||||
|
|
||||||
|
normalParamWithNoFilter := getBaseParamsForRangeSearchL2()
|
||||||
|
normalParamWithNoFilter = append(normalParamWithNoFilter, &commonpb.KeyValuePair{
|
||||||
|
Key: SearchParamsKey,
|
||||||
|
Value: `{"nprobe": 10, "radius": 10}`,
|
||||||
|
})
|
||||||
|
|
||||||
|
normalParamForIP := getBaseParamsForRangeSearchIP()
|
||||||
|
normalParamForIP = append(normalParamForIP, &commonpb.KeyValuePair{
|
||||||
|
Key: SearchParamsKey,
|
||||||
|
Value: `{"nprobe": 10, "radius": 10, "range_filter": 20}`,
|
||||||
|
})
|
||||||
|
|
||||||
|
normalParamForL2 := getBaseParamsForRangeSearchL2()
|
||||||
|
normalParamForL2 = append(normalParamForL2, &commonpb.KeyValuePair{
|
||||||
|
Key: SearchParamsKey,
|
||||||
|
Value: `{"nprobe": 10, "radius": 20, "range_filter": 10}`,
|
||||||
|
})
|
||||||
|
|
||||||
|
abnormalParamForIP := getBaseParamsForRangeSearchIP()
|
||||||
|
abnormalParamForIP = append(abnormalParamForIP, &commonpb.KeyValuePair{
|
||||||
|
Key: SearchParamsKey,
|
||||||
|
Value: `{"nprobe": 10, "radius": 20, "range_filter": 10}`,
|
||||||
|
})
|
||||||
|
|
||||||
|
abnormalParamForL2 := getBaseParamsForRangeSearchL2()
|
||||||
|
abnormalParamForL2 = append(abnormalParamForL2, &commonpb.KeyValuePair{
|
||||||
|
Key: SearchParamsKey,
|
||||||
|
Value: `{"nprobe": 10, "radius": 10, "range_filter": 20}`,
|
||||||
|
})
|
||||||
|
|
||||||
|
wrongTypeRadius := getBaseParamsForRangeSearchIP()
|
||||||
|
wrongTypeRadius = append(wrongTypeRadius, &commonpb.KeyValuePair{
|
||||||
|
Key: SearchParamsKey,
|
||||||
|
Value: `{"nprobe": 10, "radius": "ab"}`,
|
||||||
|
})
|
||||||
|
|
||||||
|
wrongTypeFilter := getBaseParamsForRangeSearchIP()
|
||||||
|
wrongTypeFilter = append(wrongTypeFilter, &commonpb.KeyValuePair{
|
||||||
|
Key: SearchParamsKey,
|
||||||
|
Value: `{"nprobe": 10, "radius": 10, "range_filter": "20"}`,
|
||||||
|
})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
description string
|
||||||
|
validParams []*commonpb.KeyValuePair
|
||||||
|
}{
|
||||||
|
{"normalParam", normalParam},
|
||||||
|
{"normalParamWithNoFilter", normalParamWithNoFilter},
|
||||||
|
{"normalParamForIP", normalParamForIP},
|
||||||
|
{"normalParamForL2", normalParamForL2},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.description, func(t *testing.T) {
|
||||||
|
info, _, err := parseSearchInfo(test.validParams, nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, info)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
tests = []struct {
|
||||||
|
description string
|
||||||
|
validParams []*commonpb.KeyValuePair
|
||||||
|
}{
|
||||||
|
{"abnormalParamForIP", abnormalParamForIP},
|
||||||
|
{"abnormalParamForL2", abnormalParamForL2},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.description, func(t *testing.T) {
|
||||||
|
info, _, err := parseSearchInfo(test.validParams, nil)
|
||||||
|
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
||||||
|
assert.Nil(t, info)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
tests = []struct {
|
||||||
|
description string
|
||||||
|
validParams []*commonpb.KeyValuePair
|
||||||
|
}{
|
||||||
|
{"wrongTypeRadius", wrongTypeRadius},
|
||||||
|
{"wrongTypeFilter", wrongTypeFilter},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.description, func(t *testing.T) {
|
||||||
|
info, _, err := parseSearchInfo(test.validParams, nil)
|
||||||
|
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
||||||
|
assert.Nil(t, info)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func getSearchResultData(nq, topk int64) *schemapb.SearchResultData {
|
func getSearchResultData(nq, topk int64) *schemapb.SearchResultData {
|
||||||
|
|
|
||||||
|
|
@ -209,7 +209,6 @@ func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest
|
||||||
log.Warn("Search organizeSubTask failed", zap.Error(err))
|
log.Warn("Search organizeSubTask failed", zap.Error(err))
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
results, err := executeSubTasks(ctx, tasks, func(ctx context.Context, req *querypb.SearchRequest, worker cluster.Worker) (*internalpb.SearchResults, error) {
|
results, err := executeSubTasks(ctx, tasks, func(ctx context.Context, req *querypb.SearchRequest, worker cluster.Worker) (*internalpb.SearchResults, error) {
|
||||||
return worker.SearchSegments(ctx, req)
|
return worker.SearchSegments(ctx, req)
|
||||||
}, "Search", log)
|
}, "Search", log)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue