fix: add the range search params check in proxy (#30423)

if check in Segcore, will not do the it when not insert data.
so, check "radius" and "range_filter" in proxy.
related with #30365

Signed-off-by: lixinguo <xinguo.li@zilliz.com>
Co-authored-by: lixinguo <xinguo.li@zilliz.com>
pull/30887/head
smellthemoon 2024-02-28 11:24:58 +08:00 committed by GitHub
parent 6e9f3ea531
commit a4f3e01a3a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 204 additions and 1 deletions

View File

@ -2,6 +2,7 @@ package proxy
import (
"context"
"encoding/json"
"fmt"
"math"
"regexp"
@ -42,6 +43,8 @@ const (
// a second query request will be initiated to retrieve output fields data.
// In this case, the first search will not return any output field from QueryNodes.
requeryThreshold = 0.5 * 1024 * 1024
radiusKey = "radius"
rangeFilterKey = "range_filter"
)
type searchTask struct {
@ -178,6 +181,11 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
searchParamStr = ""
}
err = checkRangeSearchParams(searchParamStr, metricType)
if err != nil {
return nil, 0, err
}
// 5. parse group by field
groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
if err != nil {
@ -910,6 +918,57 @@ func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb
return ret, nil
}
type rangeSearchParams struct {
radius float64
rangeFilter float64
}
func checkRangeSearchParams(str string, metricType string) error {
if len(str) == 0 {
// no search params, no need to check
return nil
}
var data map[string]*json.RawMessage
err := json.Unmarshal([]byte(str), &data)
if err != nil {
log.Info("json Unmarshal fail when checkRangeSearchParams")
return err
}
_, ok := data[radiusKey]
// will not do range search, no need to check
if !ok {
return nil
}
var params rangeSearchParams
err = json.Unmarshal(*data[radiusKey], &params.radius)
if err != nil {
return merr.WrapErrParameterInvalidMsg("must pass numpy type for radius")
}
_, ok = data[rangeFilterKey]
// not pass range_filter, no need to check
if !ok {
return nil
}
err = json.Unmarshal(*data[rangeFilterKey], &params.rangeFilter)
if err != nil {
return merr.WrapErrParameterInvalidMsg("must pass numpy type for range_filter")
}
if metric.PositivelyRelated(metricType) {
if params.radius >= params.rangeFilter {
msg := fmt.Sprintf("range_filter must be greater than radius for IP/COSINE, range_filter:%f, radius:%f", params.rangeFilter, params.radius)
return merr.WrapErrParameterInvalidMsg(msg)
}
} else {
if params.radius <= params.rangeFilter {
msg := fmt.Sprintf("range_filter must be less than radius for L2/HAMMING/JACCARD, range_filter:%f, radius:%f", params.rangeFilter, params.radius)
return merr.WrapErrParameterInvalidMsg(msg)
}
}
return nil
}
func (t *searchTask) TraceCtx() context.Context {
return t.ctx
}

View File

@ -112,6 +112,56 @@ func getBaseSearchParams() []*commonpb.KeyValuePair {
}
}
func getBaseParamsForRangeSearchL2() []*commonpb.KeyValuePair {
return []*commonpb.KeyValuePair{
{
Key: AnnsFieldKey,
Value: testFloatVecField,
},
{
Key: TopKKey,
Value: "10",
},
{
Key: common.MetricTypeKey,
Value: metric.L2,
},
{
Key: RoundDecimalKey,
Value: "-1",
},
{
Key: IgnoreGrowingKey,
Value: "false",
},
}
}
func getBaseParamsForRangeSearchIP() []*commonpb.KeyValuePair {
return []*commonpb.KeyValuePair{
{
Key: AnnsFieldKey,
Value: testFloatVecField,
},
{
Key: TopKKey,
Value: "10",
},
{
Key: common.MetricTypeKey,
Value: metric.IP,
},
{
Key: RoundDecimalKey,
Value: "-1",
},
{
Key: IgnoreGrowingKey,
Value: "false",
},
}
}
func getValidSearchParams() []*commonpb.KeyValuePair {
return []*commonpb.KeyValuePair{
{
@ -1954,6 +2004,101 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
})
}
})
t.Run("check range search params", func(t *testing.T) {
normalParam := getValidSearchParams()
normalParamWithNoFilter := getBaseParamsForRangeSearchL2()
normalParamWithNoFilter = append(normalParamWithNoFilter, &commonpb.KeyValuePair{
Key: SearchParamsKey,
Value: `{"nprobe": 10, "radius": 10}`,
})
normalParamForIP := getBaseParamsForRangeSearchIP()
normalParamForIP = append(normalParamForIP, &commonpb.KeyValuePair{
Key: SearchParamsKey,
Value: `{"nprobe": 10, "radius": 10, "range_filter": 20}`,
})
normalParamForL2 := getBaseParamsForRangeSearchL2()
normalParamForL2 = append(normalParamForL2, &commonpb.KeyValuePair{
Key: SearchParamsKey,
Value: `{"nprobe": 10, "radius": 20, "range_filter": 10}`,
})
abnormalParamForIP := getBaseParamsForRangeSearchIP()
abnormalParamForIP = append(abnormalParamForIP, &commonpb.KeyValuePair{
Key: SearchParamsKey,
Value: `{"nprobe": 10, "radius": 20, "range_filter": 10}`,
})
abnormalParamForL2 := getBaseParamsForRangeSearchL2()
abnormalParamForL2 = append(abnormalParamForL2, &commonpb.KeyValuePair{
Key: SearchParamsKey,
Value: `{"nprobe": 10, "radius": 10, "range_filter": 20}`,
})
wrongTypeRadius := getBaseParamsForRangeSearchIP()
wrongTypeRadius = append(wrongTypeRadius, &commonpb.KeyValuePair{
Key: SearchParamsKey,
Value: `{"nprobe": 10, "radius": "ab"}`,
})
wrongTypeFilter := getBaseParamsForRangeSearchIP()
wrongTypeFilter = append(wrongTypeFilter, &commonpb.KeyValuePair{
Key: SearchParamsKey,
Value: `{"nprobe": 10, "radius": 10, "range_filter": "20"}`,
})
tests := []struct {
description string
validParams []*commonpb.KeyValuePair
}{
{"normalParam", normalParam},
{"normalParamWithNoFilter", normalParamWithNoFilter},
{"normalParamForIP", normalParamForIP},
{"normalParamForL2", normalParamForL2},
}
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
info, _, err := parseSearchInfo(test.validParams, nil)
assert.NoError(t, err)
assert.NotNil(t, info)
})
}
tests = []struct {
description string
validParams []*commonpb.KeyValuePair
}{
{"abnormalParamForIP", abnormalParamForIP},
{"abnormalParamForL2", abnormalParamForL2},
}
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
info, _, err := parseSearchInfo(test.validParams, nil)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
assert.Nil(t, info)
})
}
tests = []struct {
description string
validParams []*commonpb.KeyValuePair
}{
{"wrongTypeRadius", wrongTypeRadius},
{"wrongTypeFilter", wrongTypeFilter},
}
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
info, _, err := parseSearchInfo(test.validParams, nil)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
assert.Nil(t, info)
})
}
})
}
func getSearchResultData(nq, topk int64) *schemapb.SearchResultData {

View File

@ -209,7 +209,6 @@ func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest
log.Warn("Search organizeSubTask failed", zap.Error(err))
return nil, err
}
results, err := executeSubTasks(ctx, tasks, func(ctx context.Context, req *querypb.SearchRequest, worker cluster.Worker) (*internalpb.SearchResults, error) {
return worker.SearchSegments(ctx, req)
}, "Search", log)