mirror of https://github.com/milvus-io/milvus.git
enhance: ban range-search iteration for search-group-by (#30824)
related: #30033
Signed-off-by: MrPresent-Han <chun.han@zilliz.com>
pull/31052/head
parent
c8efed6562
commit
3574bdf858
|
@ -203,7 +203,7 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
|
||||||
search_info.search_params_ =
|
search_info.search_params_ =
|
||||||
nlohmann::json::parse(query_info_proto.search_params());
|
nlohmann::json::parse(query_info_proto.search_params());
|
||||||
|
|
||||||
if (query_info_proto.group_by_field_id() != 0) {
|
if (query_info_proto.group_by_field_id() > 0) {
|
||||||
auto group_by_field_id = FieldId(query_info_proto.group_by_field_id());
|
auto group_by_field_id = FieldId(query_info_proto.group_by_field_id());
|
||||||
search_info.group_by_field_id_ = group_by_field_id;
|
search_info.group_by_field_id_ = group_by_field_id;
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,6 +44,7 @@ import (
|
||||||
const (
|
const (
|
||||||
IgnoreGrowingKey = "ignore_growing"
|
IgnoreGrowingKey = "ignore_growing"
|
||||||
ReduceStopForBestKey = "reduce_stop_for_best"
|
ReduceStopForBestKey = "reduce_stop_for_best"
|
||||||
|
IteratorField = "iterator"
|
||||||
GroupByFieldKey = "group_by_field"
|
GroupByFieldKey = "group_by_field"
|
||||||
AnnsFieldKey = "anns_field"
|
AnnsFieldKey = "anns_field"
|
||||||
TopKKey = "topk"
|
TopKKey = "topk"
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"math"
|
"math"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/cockroachdb/errors"
|
"github.com/cockroachdb/errors"
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
|
@ -192,9 +193,8 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
||||||
if err != nil {
|
if err != nil {
|
||||||
groupByFieldName = ""
|
groupByFieldName = ""
|
||||||
}
|
}
|
||||||
var groupByFieldId int64
|
var groupByFieldId int64 = -1
|
||||||
if groupByFieldName != "" {
|
if groupByFieldName != "" {
|
||||||
groupByFieldId = -1
|
|
||||||
fields := schema.GetFields()
|
fields := schema.GetFields()
|
||||||
for _, field := range fields {
|
for _, field := range fields {
|
||||||
if field.Name == groupByFieldName {
|
if field.Name == groupByFieldName {
|
||||||
|
@ -207,6 +207,17 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
|
||||||
|
isIterator, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair)
|
||||||
|
if isIterator == "True" && groupByFieldId > 0 {
|
||||||
|
return nil, 0, merr.WrapErrParameterInvalid("", "",
|
||||||
|
"Not allowed to do groupBy when doing iteration")
|
||||||
|
}
|
||||||
|
if strings.Contains(searchParamStr, radiusKey) && groupByFieldId > 0 {
|
||||||
|
return nil, 0, merr.WrapErrParameterInvalid("", "",
|
||||||
|
"Not allowed to do range-search when doing search-group-by")
|
||||||
|
}
|
||||||
|
|
||||||
return &planpb.QueryInfo{
|
return &planpb.QueryInfo{
|
||||||
Topk: queryTopK,
|
Topk: queryTopK,
|
||||||
MetricType: metricType,
|
MetricType: metricType,
|
||||||
|
|
|
@ -192,6 +192,14 @@ func getValidSearchParams() []*commonpb.KeyValuePair {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func resetSearchParamsValue(kvs []*commonpb.KeyValuePair, keyName string, newVal string) {
|
||||||
|
for _, kv := range kvs {
|
||||||
|
if kv.GetKey() == keyName {
|
||||||
|
kv.Value = newVal
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func getInvalidSearchParams(invalidName string) []*commonpb.KeyValuePair {
|
func getInvalidSearchParams(invalidName string) []*commonpb.KeyValuePair {
|
||||||
kvs := getValidSearchParams()
|
kvs := getValidSearchParams()
|
||||||
for _, kv := range kvs {
|
for _, kv := range kvs {
|
||||||
|
@ -2173,6 +2181,47 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
t.Run("check iterator and groupBy", func(t *testing.T) {
|
||||||
|
normalParam := getValidSearchParams()
|
||||||
|
normalParam = append(normalParam, &commonpb.KeyValuePair{
|
||||||
|
Key: IteratorField,
|
||||||
|
Value: "True",
|
||||||
|
})
|
||||||
|
normalParam = append(normalParam, &commonpb.KeyValuePair{
|
||||||
|
Key: GroupByFieldKey,
|
||||||
|
Value: "string_field",
|
||||||
|
})
|
||||||
|
fields := make([]*schemapb.FieldSchema, 0)
|
||||||
|
fields = append(fields, &schemapb.FieldSchema{
|
||||||
|
FieldID: int64(101),
|
||||||
|
Name: "string_field",
|
||||||
|
})
|
||||||
|
schema := &schemapb.CollectionSchema{
|
||||||
|
Fields: fields,
|
||||||
|
}
|
||||||
|
info, _, err := parseSearchInfo(normalParam, schema)
|
||||||
|
assert.Nil(t, info)
|
||||||
|
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
||||||
|
})
|
||||||
|
t.Run("check range-search and groupBy", func(t *testing.T) {
|
||||||
|
normalParam := getValidSearchParams()
|
||||||
|
resetSearchParamsValue(normalParam, SearchParamsKey, `{"nprobe": 10, "radius":0.2}`)
|
||||||
|
normalParam = append(normalParam, &commonpb.KeyValuePair{
|
||||||
|
Key: GroupByFieldKey,
|
||||||
|
Value: "string_field",
|
||||||
|
})
|
||||||
|
fields := make([]*schemapb.FieldSchema, 0)
|
||||||
|
fields = append(fields, &schemapb.FieldSchema{
|
||||||
|
FieldID: int64(101),
|
||||||
|
Name: "string_field",
|
||||||
|
})
|
||||||
|
schema := &schemapb.CollectionSchema{
|
||||||
|
Fields: fields,
|
||||||
|
}
|
||||||
|
info, _, err := parseSearchInfo(normalParam, schema)
|
||||||
|
assert.Nil(t, info)
|
||||||
|
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func getSearchResultData(nq, topk int64) *schemapb.SearchResultData {
|
func getSearchResultData(nq, topk int64) *schemapb.SearchResultData {
|
||||||
|
|
Loading…
Reference in New Issue