Make NQ <= 16384 limit work (#19726)

https://milvus.io/docs/limitations.md

issue: #19682

/kind bug

Signed-off-by: Yuchen Gao <yuchen.gao@zilliz.com>

Signed-off-by: Yuchen Gao <yuchen.gao@zilliz.com>
pull/19535/head
Ten Thousand Leaves 2022-10-12 18:37:23 +08:00 committed by GitHub
parent 04e1333552
commit ae373d450f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 57 additions and 10 deletions

View File

@ -48,6 +48,7 @@ import (
const (
AnnsFieldKey = "anns_field"
TopKKey = "topk"
NQKey = "nq"
MetricTypeKey = "metric_type"
SearchParamsKey = "params"
RoundDecimalKey = "round_decimal"

View File

@ -259,6 +259,12 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
sp, ctx := trace.StartSpanFromContextWithOperationName(t.TraceCtx(), "Proxy-Search-PreExecute")
defer sp.Finish()
// Check the nq is valid:
// https://milvus.io/docs/limitations.md
if err := validateLimit(t.request.GetNq()); err != nil {
return fmt.Errorf("%s [%d] is invalid, %w", NQKey, t.request.GetNq(), err)
}
if t.searchShardPolicy == nil {
t.searchShardPolicy = mergeRoundRobinPolicy
}

View File

@ -140,6 +140,7 @@ func TestSearchTask_PreExecute(t *testing.T) {
SearchRequest: &internalpb.SearchRequest{},
request: &milvuspb.SearchRequest{
CollectionName: collName,
Nq: 1,
},
qc: qc,
tr: timerecord.NewTimeRecorder("test-search"),
@ -148,6 +149,35 @@ func TestSearchTask_PreExecute(t *testing.T) {
return task
}
getSearchTaskWithNq := func(t *testing.T, nq int64) *searchTask {
task := &searchTask{
ctx: ctx,
SearchRequest: &internalpb.SearchRequest{},
request: &milvuspb.SearchRequest{
CollectionName: "collection name",
Nq: nq,
},
qc: qc,
tr: timerecord.NewTimeRecorder("test-search"),
}
require.NoError(t, task.OnEnqueue())
return task
}
t.Run("bad nq 0", func(t *testing.T) {
// Nq must be in range [1, 16384].
task := getSearchTaskWithNq(t, 0)
err = task.PreExecute(ctx)
assert.Error(t, err)
})
t.Run("bad nq 16385", func(t *testing.T) {
// Nq must be in range [1, 16384].
task := getSearchTaskWithNq(t, 16384+1)
err = task.PreExecute(ctx)
assert.Error(t, err)
})
t.Run("collection not exist", func(t *testing.T) {
task := getSearchTask(t, collectionName)
err = task.PreExecute(ctx)
@ -1713,6 +1743,7 @@ func TestSearchTask_ErrExecute(t *testing.T) {
SourceID: Params.ProxyCfg.GetNodeID(),
},
CollectionName: collectionName,
Nq: 2,
},
qc: qc,
shardMgr: mgr,

View File

@ -54,6 +54,12 @@ const (
// DefaultStringIndexType name of default index type for varChar/string field
DefaultStringIndexType = "Trie"
// Search limit, which applies on:
// maximum # of results to return (topK), and
// maximum # of search requests (nq).
// Check https://milvus.io/docs/limitations.md for more details.
searchCountLimit = 16384
)
var logger = log.L().WithOptions(zap.Fields(zap.String("role", typeutil.ProxyRole)))
@ -76,8 +82,8 @@ func isNumber(c uint8) bool {
func validateLimit(limit int64) error {
// TODO make this configurable
if limit <= 0 || limit >= 16385 {
return fmt.Errorf("should be in range [1, 16385], but got %d", limit)
if limit <= 0 || limit > searchCountLimit {
return fmt.Errorf("should be in range [1, %d], but got %d", searchCountLimit, limit)
}
return nil
}
@ -274,7 +280,7 @@ func validateFieldType(schema *schemapb.CollectionSchema) error {
return nil
}
//ValidateFieldAutoID call after validatePrimaryKey
// ValidateFieldAutoID call after validatePrimaryKey
func ValidateFieldAutoID(coll *schemapb.CollectionSchema) error {
var idx = -1
for i, field := range coll.Fields {
@ -743,14 +749,17 @@ func passwordVerify(ctx context.Context, username, rawPwd string, globalMetaCach
}
// Support wildcard in output fields:
// "*" - all scalar fields
// "%" - all vector fields
//
// "*" - all scalar fields
// "%" - all vector fields
//
// For example, A and B are scalar fields, C and D are vector fields, duplicated fields will automatically be removed.
// output_fields=["*"] ==> [A,B]
// output_fields=["%"] ==> [C,D]
// output_fields=["*","%"] ==> [A,B,C,D]
// output_fields=["*",A] ==> [A,B]
// output_fields=["*",C] ==> [A,B,C]
//
// output_fields=["*"] ==> [A,B]
// output_fields=["%"] ==> [C,D]
// output_fields=["*","%"] ==> [A,B,C,D]
// output_fields=["*",A] ==> [A,B]
// output_fields=["*",C] ==> [A,B,C]
func translateOutputFields(outputFields []string, schema *schemapb.CollectionSchema, addPrimary bool) ([]string, error) {
var primaryFieldName string
scalarFieldNameMap := make(map[string]bool)