mirror of https://github.com/milvus-io/milvus.git
Move TopK check inside parseQueryInfo (#18892)
Signed-off-by: yangxuan <xuan.yang@zilliz.com> Signed-off-by: yangxuan <xuan.yang@zilliz.com>pull/18907/head
parent
1527aee019
commit
867ea63bdd
|
@ -8,6 +8,7 @@
|
|||
**/cmake-build-release/*
|
||||
**/cmake_build_release/*
|
||||
**/cmake_build/*
|
||||
.cache
|
||||
|
||||
internal/core/output/*
|
||||
internal/core/build/*
|
||||
|
@ -87,4 +88,4 @@ deployments/docker/*/volumes
|
|||
|
||||
# rocksdb
|
||||
cwrapper_rocksdb_build/
|
||||
internal/kv/rocksdb/cwrapper/
|
||||
internal/kv/rocksdb/cwrapper/
|
||||
|
|
|
@ -51,17 +51,19 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
AnnsFieldKey = "anns_field"
|
||||
TopKKey = "topk"
|
||||
MetricTypeKey = "metric_type"
|
||||
SearchParamsKey = "params"
|
||||
RoundDecimalKey = "round_decimal"
|
||||
OffsetKey = "offset"
|
||||
|
||||
InsertTaskName = "InsertTask"
|
||||
CreateCollectionTaskName = "CreateCollectionTask"
|
||||
DropCollectionTaskName = "DropCollectionTask"
|
||||
SearchTaskName = "SearchTask"
|
||||
RetrieveTaskName = "RetrieveTask"
|
||||
QueryTaskName = "QueryTask"
|
||||
AnnsFieldKey = "anns_field"
|
||||
TopKKey = "topk"
|
||||
MetricTypeKey = "metric_type"
|
||||
SearchParamsKey = "params"
|
||||
RoundDecimalKey = "round_decimal"
|
||||
HasCollectionTaskName = "HasCollectionTask"
|
||||
DescribeCollectionTaskName = "DescribeCollectionTask"
|
||||
GetCollectionStatisticsTaskName = "GetCollectionStatisticsTask"
|
||||
|
|
|
@ -93,10 +93,13 @@ func parseQueryInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryInf
|
|||
if err != nil {
|
||||
return nil, errors.New(TopKKey + " not found in search_params")
|
||||
}
|
||||
topK, err := strconv.Atoi(topKStr)
|
||||
topK, err := strconv.ParseInt(topKStr, 0, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)
|
||||
}
|
||||
if err := validateTopK(topK); err != nil {
|
||||
return nil, fmt.Errorf("invalid limit, %w", err)
|
||||
}
|
||||
|
||||
metricType, err := funcutil.GetAttrByKeyFromRepeatedKV(MetricTypeKey, searchParamsPair)
|
||||
if err != nil {
|
||||
|
@ -112,7 +115,7 @@ func parseQueryInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryInf
|
|||
if err != nil {
|
||||
roundDecimalStr = "-1"
|
||||
}
|
||||
roundDecimal, err := strconv.Atoi(roundDecimalStr)
|
||||
roundDecimal, err := strconv.ParseInt(roundDecimalStr, 0, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
||||
}
|
||||
|
@ -122,10 +125,10 @@ func parseQueryInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryInf
|
|||
}
|
||||
|
||||
return &planpb.QueryInfo{
|
||||
Topk: int64(topK),
|
||||
Topk: topK,
|
||||
MetricType: metricType,
|
||||
SearchParams: searchParams,
|
||||
RoundDecimal: int64(roundDecimal),
|
||||
RoundDecimal: roundDecimal,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -242,6 +245,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
|||
t.SearchRequest.OutputFieldsId = outputFieldIDs
|
||||
plan.OutputFieldIds = outputFieldIDs
|
||||
|
||||
t.SearchRequest.Topk = queryInfo.GetTopk()
|
||||
t.SearchRequest.MetricType = queryInfo.GetMetricType()
|
||||
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
|
||||
t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan)
|
||||
|
@ -249,10 +253,6 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
t.SearchRequest.Topk = queryInfo.GetTopk()
|
||||
if err := validateTopK(queryInfo.GetTopk()); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Ctx(ctx).Debug("Proxy::searchTask::PreExecute", zap.Int64("msgID", t.ID()),
|
||||
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
|
||||
zap.String("plan", plan.String())) // may be very large if large term passed.
|
||||
|
@ -647,18 +647,6 @@ func reduceSearchResultData(ctx context.Context, searchResultData []*schemapb.Se
|
|||
// }
|
||||
//}
|
||||
|
||||
// func printSearchResult(partialSearchResult *internalpb.SearchResults) {
|
||||
// for i := 0; i < len(partialSearchResult.Hits); i++ {
|
||||
// testHits := milvuspb.Hits{}
|
||||
// err := proto.Unmarshal(partialSearchResult.Hits[i], &testHits)
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
// fmt.Println(testHits.IDs)
|
||||
// fmt.Println(testHits.Scores)
|
||||
// }
|
||||
// }
|
||||
|
||||
func (t *searchTask) TraceCtx() context.Context {
|
||||
return t.ctx
|
||||
}
|
||||
|
|
|
@ -1697,6 +1697,11 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
|
|||
Value: "invalid",
|
||||
})
|
||||
|
||||
spInvalidTopk65536 := append(spNoTopk, &commonpb.KeyValuePair{
|
||||
Key: TopKKey,
|
||||
Value: "65536",
|
||||
})
|
||||
|
||||
spNoMetricType := append(spNoTopk, &commonpb.KeyValuePair{
|
||||
Key: TopKKey,
|
||||
Value: "10",
|
||||
|
@ -1727,6 +1732,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
|
|||
}{
|
||||
{"No_topk", spNoTopk},
|
||||
{"Invalid_topk", spInvalidTopk},
|
||||
{"Invalid_topk_65536", spInvalidTopk65536},
|
||||
{"No_Metric_type", spNoMetricType},
|
||||
{"No_search_params", spNoSearchParams},
|
||||
{"Invalid_round_decimal", spInvalidRoundDecimal},
|
||||
|
|
Loading…
Reference in New Issue