mirror of https://github.com/milvus-io/milvus.git
				
				
				
			fix: add the range search params check in proxy (#30423)
If the check is done in Segcore, it is not performed when no data has been inserted. So, check "radius" and "range_filter" in the proxy. Related to #30365. Signed-off-by: lixinguo <xinguo.li@zilliz.com> Co-authored-by: lixinguo <xinguo.li@zilliz.com> pull/30887/head
							parent
							
								
									6e9f3ea531
								
							
						
					
					
						commit
						a4f3e01a3a
					
				| 
						 | 
					@ -2,6 +2,7 @@ package proxy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"math"
 | 
						"math"
 | 
				
			||||||
	"regexp"
 | 
						"regexp"
 | 
				
			||||||
| 
						 | 
					@ -42,6 +43,8 @@ const (
 | 
				
			||||||
	// a second query request will be initiated to retrieve output fields data.
 | 
						// a second query request will be initiated to retrieve output fields data.
 | 
				
			||||||
	// In this case, the first search will not return any output field from QueryNodes.
 | 
						// In this case, the first search will not return any output field from QueryNodes.
 | 
				
			||||||
	requeryThreshold = 0.5 * 1024 * 1024
 | 
						requeryThreshold = 0.5 * 1024 * 1024
 | 
				
			||||||
 | 
						radiusKey        = "radius"
 | 
				
			||||||
 | 
						rangeFilterKey   = "range_filter"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type searchTask struct {
 | 
					type searchTask struct {
 | 
				
			||||||
| 
						 | 
					@ -178,6 +181,11 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
 | 
				
			||||||
		searchParamStr = ""
 | 
							searchParamStr = ""
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = checkRangeSearchParams(searchParamStr, metricType)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, 0, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 5. parse group by field
 | 
						// 5. parse group by field
 | 
				
			||||||
	groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
 | 
						groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -910,6 +918,57 @@ func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb
 | 
				
			||||||
	return ret, nil
 | 
						return ret, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type rangeSearchParams struct {
 | 
				
			||||||
 | 
						radius      float64
 | 
				
			||||||
 | 
						rangeFilter float64
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func checkRangeSearchParams(str string, metricType string) error {
 | 
				
			||||||
 | 
						if len(str) == 0 {
 | 
				
			||||||
 | 
							// no search params, no need to check
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						var data map[string]*json.RawMessage
 | 
				
			||||||
 | 
						err := json.Unmarshal([]byte(str), &data)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Info("json Unmarshal fail when checkRangeSearchParams")
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						_, ok := data[radiusKey]
 | 
				
			||||||
 | 
						// will not do range search, no need to check
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						var params rangeSearchParams
 | 
				
			||||||
 | 
						err = json.Unmarshal(*data[radiusKey], &params.radius)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return merr.WrapErrParameterInvalidMsg("must pass numpy type for radius")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, ok = data[rangeFilterKey]
 | 
				
			||||||
 | 
						// not pass range_filter, no need to check
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err = json.Unmarshal(*data[rangeFilterKey], &params.rangeFilter)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return merr.WrapErrParameterInvalidMsg("must pass numpy type for range_filter")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if metric.PositivelyRelated(metricType) {
 | 
				
			||||||
 | 
							if params.radius >= params.rangeFilter {
 | 
				
			||||||
 | 
								msg := fmt.Sprintf("range_filter must be greater than radius for IP/COSINE, range_filter:%f, radius:%f", params.rangeFilter, params.radius)
 | 
				
			||||||
 | 
								return merr.WrapErrParameterInvalidMsg(msg)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							if params.radius <= params.rangeFilter {
 | 
				
			||||||
 | 
								msg := fmt.Sprintf("range_filter must be less than radius for L2/HAMMING/JACCARD, range_filter:%f, radius:%f", params.rangeFilter, params.radius)
 | 
				
			||||||
 | 
								return merr.WrapErrParameterInvalidMsg(msg)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (t *searchTask) TraceCtx() context.Context {
 | 
					func (t *searchTask) TraceCtx() context.Context {
 | 
				
			||||||
	return t.ctx
 | 
						return t.ctx
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -112,6 +112,56 @@ func getBaseSearchParams() []*commonpb.KeyValuePair {
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func getBaseParamsForRangeSearchL2() []*commonpb.KeyValuePair {
 | 
				
			||||||
 | 
						return []*commonpb.KeyValuePair{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Key:   AnnsFieldKey,
 | 
				
			||||||
 | 
								Value: testFloatVecField,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Key:   TopKKey,
 | 
				
			||||||
 | 
								Value: "10",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Key:   common.MetricTypeKey,
 | 
				
			||||||
 | 
								Value: metric.L2,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Key:   RoundDecimalKey,
 | 
				
			||||||
 | 
								Value: "-1",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Key:   IgnoreGrowingKey,
 | 
				
			||||||
 | 
								Value: "false",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func getBaseParamsForRangeSearchIP() []*commonpb.KeyValuePair {
 | 
				
			||||||
 | 
						return []*commonpb.KeyValuePair{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Key:   AnnsFieldKey,
 | 
				
			||||||
 | 
								Value: testFloatVecField,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Key:   TopKKey,
 | 
				
			||||||
 | 
								Value: "10",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Key:   common.MetricTypeKey,
 | 
				
			||||||
 | 
								Value: metric.IP,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Key:   RoundDecimalKey,
 | 
				
			||||||
 | 
								Value: "-1",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Key:   IgnoreGrowingKey,
 | 
				
			||||||
 | 
								Value: "false",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func getValidSearchParams() []*commonpb.KeyValuePair {
 | 
					func getValidSearchParams() []*commonpb.KeyValuePair {
 | 
				
			||||||
	return []*commonpb.KeyValuePair{
 | 
						return []*commonpb.KeyValuePair{
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
| 
						 | 
					@ -1954,6 +2004,101 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
						t.Run("check range search params", func(t *testing.T) {
 | 
				
			||||||
 | 
							normalParam := getValidSearchParams()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							normalParamWithNoFilter := getBaseParamsForRangeSearchL2()
 | 
				
			||||||
 | 
							normalParamWithNoFilter = append(normalParamWithNoFilter, &commonpb.KeyValuePair{
 | 
				
			||||||
 | 
								Key:   SearchParamsKey,
 | 
				
			||||||
 | 
								Value: `{"nprobe": 10, "radius": 10}`,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							normalParamForIP := getBaseParamsForRangeSearchIP()
 | 
				
			||||||
 | 
							normalParamForIP = append(normalParamForIP, &commonpb.KeyValuePair{
 | 
				
			||||||
 | 
								Key:   SearchParamsKey,
 | 
				
			||||||
 | 
								Value: `{"nprobe": 10, "radius": 10, "range_filter": 20}`,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							normalParamForL2 := getBaseParamsForRangeSearchL2()
 | 
				
			||||||
 | 
							normalParamForL2 = append(normalParamForL2, &commonpb.KeyValuePair{
 | 
				
			||||||
 | 
								Key:   SearchParamsKey,
 | 
				
			||||||
 | 
								Value: `{"nprobe": 10, "radius": 20, "range_filter": 10}`,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							abnormalParamForIP := getBaseParamsForRangeSearchIP()
 | 
				
			||||||
 | 
							abnormalParamForIP = append(abnormalParamForIP, &commonpb.KeyValuePair{
 | 
				
			||||||
 | 
								Key:   SearchParamsKey,
 | 
				
			||||||
 | 
								Value: `{"nprobe": 10, "radius": 20, "range_filter": 10}`,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							abnormalParamForL2 := getBaseParamsForRangeSearchL2()
 | 
				
			||||||
 | 
							abnormalParamForL2 = append(abnormalParamForL2, &commonpb.KeyValuePair{
 | 
				
			||||||
 | 
								Key:   SearchParamsKey,
 | 
				
			||||||
 | 
								Value: `{"nprobe": 10, "radius": 10, "range_filter": 20}`,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							wrongTypeRadius := getBaseParamsForRangeSearchIP()
 | 
				
			||||||
 | 
							wrongTypeRadius = append(wrongTypeRadius, &commonpb.KeyValuePair{
 | 
				
			||||||
 | 
								Key:   SearchParamsKey,
 | 
				
			||||||
 | 
								Value: `{"nprobe": 10, "radius": "ab"}`,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							wrongTypeFilter := getBaseParamsForRangeSearchIP()
 | 
				
			||||||
 | 
							wrongTypeFilter = append(wrongTypeFilter, &commonpb.KeyValuePair{
 | 
				
			||||||
 | 
								Key:   SearchParamsKey,
 | 
				
			||||||
 | 
								Value: `{"nprobe": 10, "radius": 10, "range_filter": "20"}`,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							tests := []struct {
 | 
				
			||||||
 | 
								description string
 | 
				
			||||||
 | 
								validParams []*commonpb.KeyValuePair
 | 
				
			||||||
 | 
							}{
 | 
				
			||||||
 | 
								{"normalParam", normalParam},
 | 
				
			||||||
 | 
								{"normalParamWithNoFilter", normalParamWithNoFilter},
 | 
				
			||||||
 | 
								{"normalParamForIP", normalParamForIP},
 | 
				
			||||||
 | 
								{"normalParamForL2", normalParamForL2},
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							for _, test := range tests {
 | 
				
			||||||
 | 
								t.Run(test.description, func(t *testing.T) {
 | 
				
			||||||
 | 
									info, _, err := parseSearchInfo(test.validParams, nil)
 | 
				
			||||||
 | 
									assert.NoError(t, err)
 | 
				
			||||||
 | 
									assert.NotNil(t, info)
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							tests = []struct {
 | 
				
			||||||
 | 
								description string
 | 
				
			||||||
 | 
								validParams []*commonpb.KeyValuePair
 | 
				
			||||||
 | 
							}{
 | 
				
			||||||
 | 
								{"abnormalParamForIP", abnormalParamForIP},
 | 
				
			||||||
 | 
								{"abnormalParamForL2", abnormalParamForL2},
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							for _, test := range tests {
 | 
				
			||||||
 | 
								t.Run(test.description, func(t *testing.T) {
 | 
				
			||||||
 | 
									info, _, err := parseSearchInfo(test.validParams, nil)
 | 
				
			||||||
 | 
									assert.ErrorIs(t, err, merr.ErrParameterInvalid)
 | 
				
			||||||
 | 
									assert.Nil(t, info)
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							tests = []struct {
 | 
				
			||||||
 | 
								description string
 | 
				
			||||||
 | 
								validParams []*commonpb.KeyValuePair
 | 
				
			||||||
 | 
							}{
 | 
				
			||||||
 | 
								{"wrongTypeRadius", wrongTypeRadius},
 | 
				
			||||||
 | 
								{"wrongTypeFilter", wrongTypeFilter},
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							for _, test := range tests {
 | 
				
			||||||
 | 
								t.Run(test.description, func(t *testing.T) {
 | 
				
			||||||
 | 
									info, _, err := parseSearchInfo(test.validParams, nil)
 | 
				
			||||||
 | 
									assert.ErrorIs(t, err, merr.ErrParameterInvalid)
 | 
				
			||||||
 | 
									assert.Nil(t, info)
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func getSearchResultData(nq, topk int64) *schemapb.SearchResultData {
 | 
					func getSearchResultData(nq, topk int64) *schemapb.SearchResultData {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -209,7 +209,6 @@ func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest
 | 
				
			||||||
		log.Warn("Search organizeSubTask failed", zap.Error(err))
 | 
							log.Warn("Search organizeSubTask failed", zap.Error(err))
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					 | 
				
			||||||
	results, err := executeSubTasks(ctx, tasks, func(ctx context.Context, req *querypb.SearchRequest, worker cluster.Worker) (*internalpb.SearchResults, error) {
 | 
						results, err := executeSubTasks(ctx, tasks, func(ctx context.Context, req *querypb.SearchRequest, worker cluster.Worker) (*internalpb.SearchResults, error) {
 | 
				
			||||||
		return worker.SearchSegments(ctx, req)
 | 
							return worker.SearchSegments(ctx, req)
 | 
				
			||||||
	}, "Search", log)
 | 
						}, "Search", log)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue