diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go
index 793ef2083b..68791f4c57 100644
--- a/internal/proxy/task_search.go
+++ b/internal/proxy/task_search.go
@@ -2,6 +2,7 @@ package proxy
 
 import (
 	"context"
+	"encoding/json"
 	"fmt"
 	"math"
 	"regexp"
@@ -42,6 +43,8 @@ const (
 	// a second query request will be initiated to retrieve output fields data.
 	// In this case, the first search will not return any output field from QueryNodes.
 	requeryThreshold = 0.5 * 1024 * 1024
+	radiusKey        = "radius"
+	rangeFilterKey   = "range_filter"
 )
 
 type searchTask struct {
@@ -178,6 +181,11 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
 		searchParamStr = ""
 	}
 
+	err = checkRangeSearchParams(searchParamStr, metricType)
+	if err != nil {
+		return nil, 0, err
+	}
+
 	// 5. parse group by field
 	groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
 	if err != nil {
@@ -910,6 +918,70 @@ func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb
 	return ret, nil
 }
 
+// rangeSearchParams holds the decoded "radius" and "range_filter" values of a
+// range search request.
+type rangeSearchParams struct {
+	radius      float64
+	rangeFilter float64
+}
+
+// checkRangeSearchParams validates the optional range search settings carried
+// in str, the JSON-encoded search params, against metricType.
+//
+// It returns nil when str is empty, when "radius" is absent (no range search
+// is requested) or when "range_filter" is absent. A "radius" or "range_filter"
+// that is JSON null or does not decode into a float64 is rejected with a
+// parameter-invalid error, as is a pair in the wrong order for the metric:
+// range_filter must be greater than radius when metric.PositivelyRelated
+// reports true and less than radius otherwise. If str itself fails to decode,
+// the decode error is returned unchanged.
+func checkRangeSearchParams(str string, metricType string) error {
+	if len(str) == 0 {
+		// no search params, no need to check
+		return nil
+	}
+	var data map[string]*json.RawMessage
+	err := json.Unmarshal([]byte(str), &data)
+	if err != nil {
+		log.Info("json Unmarshal fail when checkRangeSearchParams")
+		return err
+	}
+	radiusRaw, ok := data[radiusKey]
+	// will not do range search, no need to check
+	if !ok {
+		return nil
+	}
+	var params rangeSearchParams
+	// encoding/json stores a JSON null as a nil *json.RawMessage, so check for
+	// nil before dereferencing: {"radius": null} must be rejected, not panic.
+	if radiusRaw == nil || json.Unmarshal(*radiusRaw, &params.radius) != nil {
+		return merr.WrapErrParameterInvalidMsg("must pass numpy type for radius")
+	}
+
+	filterRaw, ok := data[rangeFilterKey]
+	// not pass range_filter, no need to check
+	if !ok {
+		return nil
+	}
+	// a null range_filter is likewise stored as a nil pointer
+	if filterRaw == nil || json.Unmarshal(*filterRaw, &params.rangeFilter) != nil {
+		return merr.WrapErrParameterInvalidMsg("must pass numpy type for range_filter")
+	}
+
+	if metric.PositivelyRelated(metricType) {
+		if params.radius >= params.rangeFilter {
+			msg := fmt.Sprintf("range_filter must be greater than radius for IP/COSINE, range_filter:%f, radius:%f", params.rangeFilter, params.radius)
+			return merr.WrapErrParameterInvalidMsg(msg)
+		}
+	} else {
+		if params.radius <= params.rangeFilter {
+			msg := fmt.Sprintf("range_filter must be less than radius for L2/HAMMING/JACCARD, range_filter:%f, radius:%f", params.rangeFilter, params.radius)
+			return merr.WrapErrParameterInvalidMsg(msg)
+		}
+	}
+	return nil
+}
+
 func (t *searchTask) TraceCtx() context.Context {
 	return t.ctx
 }
diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go
index d928a4eea6..f3c671b704 100644
--- a/internal/proxy/task_search_test.go
+++ b/internal/proxy/task_search_test.go
@@ -112,6 +112,56 @@ func getBaseSearchParams() []*commonpb.KeyValuePair {
 	}
 }
 
+func getBaseParamsForRangeSearchL2() []*commonpb.KeyValuePair {
+	return []*commonpb.KeyValuePair{
+		{
+			Key:   AnnsFieldKey,
+			Value: testFloatVecField,
+		},
+		{
+			Key:   TopKKey,
+			Value: "10",
+		},
+		{
+			Key:   common.MetricTypeKey,
+			Value: metric.L2,
+		},
+		{
+			Key:   RoundDecimalKey,
+			Value: "-1",
+		},
+		{
+			Key:   IgnoreGrowingKey,
+			Value: "false",
+		},
+	}
+}
+
+func getBaseParamsForRangeSearchIP() []*commonpb.KeyValuePair {
+	return []*commonpb.KeyValuePair{
+		{
+			Key:   AnnsFieldKey,
+			Value: testFloatVecField,
+		},
+		{
+			Key:   TopKKey,
+			Value: "10",
+		},
+		{
+			Key:   common.MetricTypeKey,
+			Value: metric.IP,
+		},
+		{
+			Key:   RoundDecimalKey,
+			Value: "-1",
+		},
+		{
+			Key:   IgnoreGrowingKey,
+			Value: "false",
+		},
+	}
+}
+
 func getValidSearchParams() []*commonpb.KeyValuePair {
 	return []*commonpb.KeyValuePair{
 		{
@@ -1954,6 +2004,116 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
 			})
 		}
 	})
+	t.Run("check range search params", func(t *testing.T) {
+		normalParam := getValidSearchParams()
+
+		normalParamWithNoFilter := getBaseParamsForRangeSearchL2()
+		normalParamWithNoFilter = append(normalParamWithNoFilter, &commonpb.KeyValuePair{
+			Key:   SearchParamsKey,
+			Value: `{"nprobe": 10, "radius": 10}`,
+		})
+
+		normalParamForIP := getBaseParamsForRangeSearchIP()
+		normalParamForIP = append(normalParamForIP, &commonpb.KeyValuePair{
+			Key:   SearchParamsKey,
+			Value: `{"nprobe": 10, "radius": 10, "range_filter": 20}`,
+		})
+
+		normalParamForL2 := getBaseParamsForRangeSearchL2()
+		normalParamForL2 = append(normalParamForL2, &commonpb.KeyValuePair{
+			Key:   SearchParamsKey,
+			Value: `{"nprobe": 10, "radius": 20, "range_filter": 10}`,
+		})
+
+		abnormalParamForIP := getBaseParamsForRangeSearchIP()
+		abnormalParamForIP = append(abnormalParamForIP, &commonpb.KeyValuePair{
+			Key:   SearchParamsKey,
+			Value: `{"nprobe": 10, "radius": 20, "range_filter": 10}`,
+		})
+
+		abnormalParamForL2 := getBaseParamsForRangeSearchL2()
+		abnormalParamForL2 = append(abnormalParamForL2, &commonpb.KeyValuePair{
+			Key:   SearchParamsKey,
+			Value: `{"nprobe": 10, "radius": 10, "range_filter": 20}`,
+		})
+
+		wrongTypeRadius := getBaseParamsForRangeSearchIP()
+		wrongTypeRadius = append(wrongTypeRadius, &commonpb.KeyValuePair{
+			Key:   SearchParamsKey,
+			Value: `{"nprobe": 10, "radius": "ab"}`,
+		})
+
+		wrongTypeFilter := getBaseParamsForRangeSearchIP()
+		wrongTypeFilter = append(wrongTypeFilter, &commonpb.KeyValuePair{
+			Key:   SearchParamsKey,
+			Value: `{"nprobe": 10, "radius": 10, "range_filter": "20"}`,
+		})
+
+		// JSON null values must be rejected as invalid parameters instead of panicking
+		nullRadius := getBaseParamsForRangeSearchIP()
+		nullRadius = append(nullRadius, &commonpb.KeyValuePair{
+			Key:   SearchParamsKey,
+			Value: `{"nprobe": 10, "radius": null}`,
+		})
+
+		nullFilter := getBaseParamsForRangeSearchIP()
+		nullFilter = append(nullFilter, &commonpb.KeyValuePair{
+			Key:   SearchParamsKey,
+			Value: `{"nprobe": 10, "radius": 10, "range_filter": null}`,
+		})
+
+		tests := []struct {
+			description string
+			validParams []*commonpb.KeyValuePair
+		}{
+			{"normalParam", normalParam},
+			{"normalParamWithNoFilter", normalParamWithNoFilter},
+			{"normalParamForIP", normalParamForIP},
+			{"normalParamForL2", normalParamForL2},
+		}
+
+		for _, test := range tests {
+			t.Run(test.description, func(t *testing.T) {
+				info, _, err := parseSearchInfo(test.validParams, nil)
+				assert.NoError(t, err)
+				assert.NotNil(t, info)
+			})
+		}
+
+		tests = []struct {
+			description string
+			validParams []*commonpb.KeyValuePair
+		}{
+			{"abnormalParamForIP", abnormalParamForIP},
+			{"abnormalParamForL2", abnormalParamForL2},
+		}
+
+		for _, test := range tests {
+			t.Run(test.description, func(t *testing.T) {
+				info, _, err := parseSearchInfo(test.validParams, nil)
+				assert.ErrorIs(t, err, merr.ErrParameterInvalid)
+				assert.Nil(t, info)
+			})
+		}
+
+		tests = []struct {
+			description string
+			validParams []*commonpb.KeyValuePair
+		}{
+			{"wrongTypeRadius", wrongTypeRadius},
+			{"wrongTypeFilter", wrongTypeFilter},
+			{"nullRadius", nullRadius},
+			{"nullFilter", nullFilter},
+		}
+
+		for _, test := range tests {
+			t.Run(test.description, func(t *testing.T) {
+				info, _, err := parseSearchInfo(test.validParams, nil)
+				assert.ErrorIs(t, err, merr.ErrParameterInvalid)
+				assert.Nil(t, info)
+			})
+		}
+	})
 }
 
 func getSearchResultData(nq, topk int64) *schemapb.SearchResultData {
diff --git a/internal/querynodev2/delegator/delegator.go b/internal/querynodev2/delegator/delegator.go
index 00c35c720f..7c99b268d8 100644
--- a/internal/querynodev2/delegator/delegator.go
+++ b/internal/querynodev2/delegator/delegator.go
@@ -209,7 +209,6 @@ func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest
 		log.Warn("Search organizeSubTask failed", zap.Error(err))
 		return nil, err
 	}
-
 	results, err := executeSubTasks(ctx, tasks, func(ctx context.Context, req *querypb.SearchRequest, worker cluster.Worker) (*internalpb.SearchResults, error) {
 		return worker.SearchSegments(ctx, req)
 	}, "Search", log)