Move TopK check inside parseQueryInfo (#18892)

Signed-off-by: yangxuan <xuan.yang@zilliz.com>

Signed-off-by: yangxuan <xuan.yang@zilliz.com>
pull/18907/head
XuanYang-cn 2022-08-30 10:32:56 +08:00 committed by GitHub
parent 1527aee019
commit 867ea63bdd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 23 additions and 26 deletions

3
.gitignore vendored
View File

@ -8,6 +8,7 @@
**/cmake-build-release/*
**/cmake_build_release/*
**/cmake_build/*
.cache
internal/core/output/*
internal/core/build/*
@ -87,4 +88,4 @@ deployments/docker/*/volumes
# rocksdb
cwrapper_rocksdb_build/
internal/kv/rocksdb/cwrapper/
internal/kv/rocksdb/cwrapper/

View File

@ -51,17 +51,19 @@ import (
)
const (
AnnsFieldKey = "anns_field"
TopKKey = "topk"
MetricTypeKey = "metric_type"
SearchParamsKey = "params"
RoundDecimalKey = "round_decimal"
OffsetKey = "offset"
InsertTaskName = "InsertTask"
CreateCollectionTaskName = "CreateCollectionTask"
DropCollectionTaskName = "DropCollectionTask"
SearchTaskName = "SearchTask"
RetrieveTaskName = "RetrieveTask"
QueryTaskName = "QueryTask"
AnnsFieldKey = "anns_field"
TopKKey = "topk"
MetricTypeKey = "metric_type"
SearchParamsKey = "params"
RoundDecimalKey = "round_decimal"
HasCollectionTaskName = "HasCollectionTask"
DescribeCollectionTaskName = "DescribeCollectionTask"
GetCollectionStatisticsTaskName = "GetCollectionStatisticsTask"

View File

@ -93,10 +93,13 @@ func parseQueryInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryInf
if err != nil {
return nil, errors.New(TopKKey + " not found in search_params")
}
topK, err := strconv.Atoi(topKStr)
topK, err := strconv.ParseInt(topKStr, 0, 64)
if err != nil {
return nil, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)
}
if err := validateTopK(topK); err != nil {
return nil, fmt.Errorf("invalid limit, %w", err)
}
metricType, err := funcutil.GetAttrByKeyFromRepeatedKV(MetricTypeKey, searchParamsPair)
if err != nil {
@ -112,7 +115,7 @@ func parseQueryInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryInf
if err != nil {
roundDecimalStr = "-1"
}
roundDecimal, err := strconv.Atoi(roundDecimalStr)
roundDecimal, err := strconv.ParseInt(roundDecimalStr, 0, 64)
if err != nil {
return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
}
@ -122,10 +125,10 @@ func parseQueryInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryInf
}
return &planpb.QueryInfo{
Topk: int64(topK),
Topk: topK,
MetricType: metricType,
SearchParams: searchParams,
RoundDecimal: int64(roundDecimal),
RoundDecimal: roundDecimal,
}, nil
}
@ -242,6 +245,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
t.SearchRequest.OutputFieldsId = outputFieldIDs
plan.OutputFieldIds = outputFieldIDs
t.SearchRequest.Topk = queryInfo.GetTopk()
t.SearchRequest.MetricType = queryInfo.GetMetricType()
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan)
@ -249,10 +253,6 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
return err
}
t.SearchRequest.Topk = queryInfo.GetTopk()
if err := validateTopK(queryInfo.GetTopk()); err != nil {
return err
}
log.Ctx(ctx).Debug("Proxy::searchTask::PreExecute", zap.Int64("msgID", t.ID()),
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
zap.String("plan", plan.String())) // may be very large if large term passed.
@ -647,18 +647,6 @@ func reduceSearchResultData(ctx context.Context, searchResultData []*schemapb.Se
// }
//}
// func printSearchResult(partialSearchResult *internalpb.SearchResults) {
// for i := 0; i < len(partialSearchResult.Hits); i++ {
// testHits := milvuspb.Hits{}
// err := proto.Unmarshal(partialSearchResult.Hits[i], &testHits)
// if err != nil {
// panic(err)
// }
// fmt.Println(testHits.IDs)
// fmt.Println(testHits.Scores)
// }
// }
func (t *searchTask) TraceCtx() context.Context {
return t.ctx
}

View File

@ -1697,6 +1697,11 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
Value: "invalid",
})
spInvalidTopk65536 := append(spNoTopk, &commonpb.KeyValuePair{
Key: TopKKey,
Value: "65536",
})
spNoMetricType := append(spNoTopk, &commonpb.KeyValuePair{
Key: TopKKey,
Value: "10",
@ -1727,6 +1732,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
}{
{"No_topk", spNoTopk},
{"Invalid_topk", spInvalidTopk},
{"Invalid_topk_65536", spInvalidTopk65536},
{"No_Metric_type", spNoMetricType},
{"No_search_params", spNoSearchParams},
{"Invalid_round_decimal", spInvalidRoundDecimal},