mirror of https://github.com/milvus-io/milvus.git
related: #22718 Signed-off-by: MrPresent-Han <chun.han@gmail.com> Co-authored-by: MrPresent-Han <chun.han@gmail.com>pull/35970/head
parent
c271c21f17
commit
903450f5c6
2
go.mod
2
go.mod
|
@ -23,7 +23,7 @@ require (
|
|||
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
|
||||
github.com/klauspost/compress v1.17.7
|
||||
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240923125106-ef9b8fd69497
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240930043709-0c23514e4c34
|
||||
github.com/minio/minio-go/v7 v7.0.61
|
||||
github.com/pingcap/log v1.1.1-0.20221015072633-39906604fb81
|
||||
github.com/prometheus/client_golang v1.14.0
|
||||
|
|
4
go.sum
4
go.sum
|
@ -625,8 +625,8 @@ github.com/milvus-io/cgosymbolizer v0.0.0-20240722103217-b7dee0e50119 h1:9VXijWu
|
|||
github.com/milvus-io/cgosymbolizer v0.0.0-20240722103217-b7dee0e50119/go.mod h1:DvXTE/K/RtHehxU8/GtDs4vFtfw64jJ3PaCnFri8CRg=
|
||||
github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b h1:TfeY0NxYxZzUfIfYe5qYDBzt4ZYRqzUjTR6CvUzjat8=
|
||||
github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b/go.mod h1:iwW+9cWfIzzDseEBCCeDSN5SD16Tidvy8cwQ7ZY8Qj4=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240923125106-ef9b8fd69497 h1:t4sQMbSy05p8qgMGvEGyLYYLoZ9fD1dushS1bj5X6+0=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240923125106-ef9b8fd69497/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240930043709-0c23514e4c34 h1:Fwxpg98128gfWRbQ1A3PMP9o2IfYZk7RSEy8rcoCWDA=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240930043709-0c23514e4c34/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
|
||||
github.com/milvus-io/pulsar-client-go v0.12.1 h1:O2JZp1tsYiO7C0MQ4hrUY/aJXnn2Gry6hpm7UodghmE=
|
||||
github.com/milvus-io/pulsar-client-go v0.12.1/go.mod h1:dkutuH4oS2pXiGm+Ti7fQZ4MRjrMPZ8IJeEGAWMeckk=
|
||||
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs=
|
||||
|
|
|
@ -75,22 +75,29 @@ func (r *rankParams) String() string {
|
|||
return fmt.Sprintf("limit: %d, offset: %d, roundDecimal: %d", r.GetLimit(), r.GetOffset(), r.GetRoundDecimal())
|
||||
}
|
||||
|
||||
type SearchInfo struct {
|
||||
planInfo *planpb.QueryInfo
|
||||
offset int64
|
||||
parseError error
|
||||
isIterator bool
|
||||
}
|
||||
|
||||
// parseSearchInfo returns QueryInfo and offset
|
||||
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) (*planpb.QueryInfo, int64, error) {
|
||||
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) *SearchInfo {
|
||||
var topK int64
|
||||
isAdvanced := rankParams != nil
|
||||
externalLimit := rankParams.GetLimit() + rankParams.GetOffset()
|
||||
topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair)
|
||||
if err != nil {
|
||||
if externalLimit <= 0 {
|
||||
return nil, 0, fmt.Errorf("%s is required", TopKKey)
|
||||
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s is required", TopKKey)}
|
||||
}
|
||||
topK = externalLimit
|
||||
} else {
|
||||
topKInParam, err := strconv.ParseInt(topKStr, 0, 64)
|
||||
if err != nil {
|
||||
if externalLimit <= 0 {
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)
|
||||
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)}
|
||||
}
|
||||
topK = externalLimit
|
||||
} else {
|
||||
|
@ -98,15 +105,16 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
|||
}
|
||||
}
|
||||
|
||||
isIterator, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair)
|
||||
isIteratorStr, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair)
|
||||
isIterator := (isIteratorStr == "True") || (isIteratorStr == "true")
|
||||
|
||||
if err := validateLimit(topK); err != nil {
|
||||
if isIterator == "True" {
|
||||
if isIterator {
|
||||
// 1. if the request is from iterator, we set topK to QuotaLimit as the iterator can resolve too large topK problem
|
||||
// 2. GetAsInt64 has cached inside, no need to worry about cpu cost for parsing here
|
||||
topK = Params.QuotaConfig.TopKLimit.GetAsInt64()
|
||||
} else {
|
||||
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err)
|
||||
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err)}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -117,12 +125,12 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
|||
if err == nil {
|
||||
offset, err = strconv.ParseInt(offsetStr, 0, 64)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
|
||||
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)}
|
||||
}
|
||||
|
||||
if offset != 0 {
|
||||
if err := validateLimit(offset); err != nil {
|
||||
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err)
|
||||
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err)}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -130,7 +138,7 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
|||
|
||||
queryTopK := topK + offset
|
||||
if err := validateLimit(queryTopK); err != nil {
|
||||
return nil, 0, fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err)
|
||||
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err)}
|
||||
}
|
||||
|
||||
// 2. parse metrics type
|
||||
|
@ -147,11 +155,11 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
|||
|
||||
roundDecimal, err := strconv.ParseInt(roundDecimalStr, 0, 64)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
||||
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)}
|
||||
}
|
||||
|
||||
if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
||||
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)}
|
||||
}
|
||||
|
||||
// 4. parse search param str
|
||||
|
@ -168,30 +176,35 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
|||
} else {
|
||||
groupByInfo := parseGroupByInfo(searchParamsPair, schema)
|
||||
if groupByInfo.err != nil {
|
||||
return nil, 0, groupByInfo.err
|
||||
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: groupByInfo.err}
|
||||
}
|
||||
groupByFieldId, groupSize, groupStrictSize = groupByInfo.GetGroupByFieldId(), groupByInfo.GetGroupSize(), groupByInfo.GetGroupStrictSize()
|
||||
}
|
||||
|
||||
// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
|
||||
if isIterator == "True" && groupByFieldId > 0 {
|
||||
return nil, 0, merr.WrapErrParameterInvalid("", "",
|
||||
"Not allowed to do groupBy when doing iteration")
|
||||
if isIterator && groupByFieldId > 0 {
|
||||
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: merr.WrapErrParameterInvalid("", "",
|
||||
"Not allowed to do groupBy when doing iteration")}
|
||||
}
|
||||
if strings.Contains(searchParamStr, radiusKey) && groupByFieldId > 0 {
|
||||
return nil, 0, merr.WrapErrParameterInvalid("", "",
|
||||
"Not allowed to do range-search when doing search-group-by")
|
||||
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: merr.WrapErrParameterInvalid("", "",
|
||||
"Not allowed to do range-search when doing search-group-by")}
|
||||
}
|
||||
|
||||
return &planpb.QueryInfo{
|
||||
Topk: queryTopK,
|
||||
MetricType: metricType,
|
||||
SearchParams: searchParamStr,
|
||||
RoundDecimal: roundDecimal,
|
||||
GroupByFieldId: groupByFieldId,
|
||||
GroupSize: groupSize,
|
||||
GroupStrictSize: groupStrictSize,
|
||||
}, offset, nil
|
||||
return &SearchInfo{
|
||||
planInfo: &planpb.QueryInfo{
|
||||
Topk: queryTopK,
|
||||
MetricType: metricType,
|
||||
SearchParams: searchParamStr,
|
||||
RoundDecimal: roundDecimal,
|
||||
GroupByFieldId: groupByFieldId,
|
||||
GroupSize: groupSize,
|
||||
GroupStrictSize: groupStrictSize,
|
||||
},
|
||||
offset: offset,
|
||||
isIterator: isIterator,
|
||||
parseError: nil,
|
||||
}
|
||||
}
|
||||
|
||||
func getOutputFieldIDs(schema *schemaInfo, outputFields []string) (outputFieldIDs []UniqueID, err error) {
|
||||
|
|
|
@ -79,6 +79,7 @@ type queryParams struct {
|
|||
limit int64
|
||||
offset int64
|
||||
reduceType reduce.IReduceType
|
||||
isIterator bool
|
||||
}
|
||||
|
||||
// translateToOutputFieldIDs translates output fields name to output fields id.
|
||||
|
@ -178,7 +179,7 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e
|
|||
limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, queryParamsPair)
|
||||
// if limit is not provided
|
||||
if err != nil {
|
||||
return &queryParams{limit: typeutil.Unlimited, reduceType: reduceType}, nil
|
||||
return &queryParams{limit: typeutil.Unlimited, reduceType: reduceType, isIterator: isIterator}, nil
|
||||
}
|
||||
limit, err = strconv.ParseInt(limitStr, 0, 64)
|
||||
if err != nil {
|
||||
|
@ -203,6 +204,7 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e
|
|||
limit: limit,
|
||||
offset: offset,
|
||||
reduceType: reduceType,
|
||||
isIterator: isIterator,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -461,6 +463,11 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
|
|||
}
|
||||
}
|
||||
t.GuaranteeTimestamp = guaranteeTs
|
||||
// need modify mvccTs and guaranteeTs for iterator specially
|
||||
if t.queryParams.isIterator && t.request.GetGuaranteeTimestamp() > 0 {
|
||||
t.MvccTimestamp = t.request.GetGuaranteeTimestamp()
|
||||
t.GuaranteeTimestamp = t.request.GetGuaranteeTimestamp()
|
||||
}
|
||||
|
||||
deadline, ok := t.TraceCtx().Deadline()
|
||||
if ok {
|
||||
|
@ -542,6 +549,10 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
|
|||
t.result.OutputFields = t.userOutputFields
|
||||
metrics.ProxyReduceResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.QueryLabel).Observe(float64(tr.RecordSpan().Milliseconds()))
|
||||
|
||||
if t.queryParams.isIterator && t.request.GetGuaranteeTimestamp() == 0 {
|
||||
// first page for iteration, need to set up sessionTs for iterator
|
||||
t.result.SessionTs = t.BeginTs()
|
||||
}
|
||||
log.Debug("Query PostExecute done")
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -132,126 +132,223 @@ func TestQueryTask_all(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
|
||||
|
||||
// test begins
|
||||
task := &queryTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
RetrieveRequest: &internalpb.RetrieveRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Retrieve,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
t.Run("test query task parameters", func(t *testing.T) {
|
||||
task := &queryTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
RetrieveRequest: &internalpb.RetrieveRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Retrieve,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
CollectionID: collectionID,
|
||||
OutputFieldsId: make([]int64, len(fieldName2Types)),
|
||||
},
|
||||
CollectionID: collectionID,
|
||||
OutputFieldsId: make([]int64, len(fieldName2Types)),
|
||||
},
|
||||
ctx: ctx,
|
||||
result: &milvuspb.QueryResults{
|
||||
Status: merr.Success(),
|
||||
FieldsData: []*schemapb.FieldData{},
|
||||
},
|
||||
request: &milvuspb.QueryRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Retrieve,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
ctx: ctx,
|
||||
result: &milvuspb.QueryResults{
|
||||
Status: merr.Success(),
|
||||
FieldsData: []*schemapb.FieldData{},
|
||||
},
|
||||
CollectionName: collectionName,
|
||||
Expr: expr,
|
||||
QueryParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: IgnoreGrowingKey,
|
||||
Value: "false",
|
||||
request: &milvuspb.QueryRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Retrieve,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
CollectionName: collectionName,
|
||||
Expr: expr,
|
||||
QueryParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: IgnoreGrowingKey,
|
||||
Value: "false",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
qc: qc,
|
||||
lb: lb,
|
||||
}
|
||||
qc: qc,
|
||||
lb: lb,
|
||||
}
|
||||
|
||||
assert.NoError(t, task.OnEnqueue())
|
||||
assert.NoError(t, task.OnEnqueue())
|
||||
|
||||
// test query task with timeout
|
||||
ctx1, cancel1 := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel1()
|
||||
// before preExecute
|
||||
assert.Equal(t, typeutil.ZeroTimestamp, task.TimeoutTimestamp)
|
||||
task.ctx = ctx1
|
||||
assert.NoError(t, task.PreExecute(ctx))
|
||||
// test query task with timeout
|
||||
ctx1, cancel1 := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel1()
|
||||
// before preExecute
|
||||
assert.Equal(t, typeutil.ZeroTimestamp, task.TimeoutTimestamp)
|
||||
task.ctx = ctx1
|
||||
assert.NoError(t, task.PreExecute(ctx))
|
||||
|
||||
{
|
||||
task.mustUsePartitionKey = true
|
||||
err := task.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
||||
task.mustUsePartitionKey = false
|
||||
}
|
||||
{
|
||||
task.mustUsePartitionKey = true
|
||||
err := task.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
||||
task.mustUsePartitionKey = false
|
||||
}
|
||||
|
||||
// after preExecute
|
||||
assert.Greater(t, task.TimeoutTimestamp, typeutil.ZeroTimestamp)
|
||||
// after preExecute
|
||||
assert.Greater(t, task.TimeoutTimestamp, typeutil.ZeroTimestamp)
|
||||
|
||||
// check reduce_stop_for_best
|
||||
assert.Equal(t, false, task.RetrieveRequest.GetReduceStopForBest())
|
||||
task.request.QueryParams = append(task.request.QueryParams, &commonpb.KeyValuePair{
|
||||
Key: ReduceStopForBestKey,
|
||||
Value: "trxxxx",
|
||||
})
|
||||
assert.Error(t, task.PreExecute(ctx))
|
||||
// check reduce_stop_for_best
|
||||
assert.Equal(t, false, task.RetrieveRequest.GetReduceStopForBest())
|
||||
task.request.QueryParams = append(task.request.QueryParams, &commonpb.KeyValuePair{
|
||||
Key: ReduceStopForBestKey,
|
||||
Value: "trxxxx",
|
||||
})
|
||||
assert.Error(t, task.PreExecute(ctx))
|
||||
|
||||
result1 := &internalpb.RetrieveResults{
|
||||
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_RetrieveResult},
|
||||
Status: merr.Success(),
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{Data: testutils.GenerateInt64Array(hitNum)},
|
||||
result1 := &internalpb.RetrieveResults{
|
||||
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_RetrieveResult},
|
||||
Status: merr.Success(),
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{Data: testutils.GenerateInt64Array(hitNum)},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
outputFieldIDs := make([]UniqueID, 0, len(fieldName2Types))
|
||||
for i := 0; i < len(fieldName2Types); i++ {
|
||||
outputFieldIDs = append(outputFieldIDs, int64(common.StartOfUserFieldID+i))
|
||||
}
|
||||
task.RetrieveRequest.OutputFieldsId = outputFieldIDs
|
||||
for fieldName, dataType := range fieldName2Types {
|
||||
result1.FieldsData = append(result1.FieldsData, generateFieldData(dataType, fieldName, hitNum))
|
||||
}
|
||||
result1.FieldsData = append(result1.FieldsData, generateFieldData(schemapb.DataType_Int64, common.TimeStampFieldName, hitNum))
|
||||
task.RetrieveRequest.OutputFieldsId = append(task.RetrieveRequest.OutputFieldsId, common.TimeStampField)
|
||||
task.ctx = ctx
|
||||
qn.ExpectedCalls = nil
|
||||
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
||||
qn.EXPECT().Query(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
|
||||
assert.Error(t, task.Execute(ctx))
|
||||
outputFieldIDs := make([]UniqueID, 0, len(fieldName2Types))
|
||||
for i := 0; i < len(fieldName2Types); i++ {
|
||||
outputFieldIDs = append(outputFieldIDs, int64(common.StartOfUserFieldID+i))
|
||||
}
|
||||
task.RetrieveRequest.OutputFieldsId = outputFieldIDs
|
||||
for fieldName, dataType := range fieldName2Types {
|
||||
result1.FieldsData = append(result1.FieldsData, generateFieldData(dataType, fieldName, hitNum))
|
||||
}
|
||||
result1.FieldsData = append(result1.FieldsData, generateFieldData(schemapb.DataType_Int64, common.TimeStampFieldName, hitNum))
|
||||
task.RetrieveRequest.OutputFieldsId = append(task.RetrieveRequest.OutputFieldsId, common.TimeStampField)
|
||||
task.ctx = ctx
|
||||
qn.ExpectedCalls = nil
|
||||
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
||||
qn.EXPECT().Query(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
|
||||
assert.Error(t, task.Execute(ctx))
|
||||
|
||||
qn.ExpectedCalls = nil
|
||||
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
||||
qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{
|
||||
Status: merr.Status(merr.ErrChannelNotAvailable),
|
||||
}, nil)
|
||||
err = task.Execute(ctx)
|
||||
assert.ErrorIs(t, err, merr.ErrChannelNotAvailable)
|
||||
qn.ExpectedCalls = nil
|
||||
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
||||
qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{
|
||||
Status: merr.Status(merr.ErrChannelNotAvailable),
|
||||
}, nil)
|
||||
err = task.Execute(ctx)
|
||||
assert.ErrorIs(t, err, merr.ErrChannelNotAvailable)
|
||||
|
||||
qn.ExpectedCalls = nil
|
||||
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
||||
qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
},
|
||||
}, nil)
|
||||
assert.Error(t, task.Execute(ctx))
|
||||
qn.ExpectedCalls = nil
|
||||
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
||||
qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
},
|
||||
}, nil)
|
||||
assert.Error(t, task.Execute(ctx))
|
||||
|
||||
qn.ExpectedCalls = nil
|
||||
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
||||
qn.EXPECT().Query(mock.Anything, mock.Anything).Return(result1, nil)
|
||||
assert.NoError(t, task.Execute(ctx))
|
||||
qn.ExpectedCalls = nil
|
||||
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
||||
qn.EXPECT().Query(mock.Anything, mock.Anything).Return(result1, nil)
|
||||
assert.NoError(t, task.Execute(ctx))
|
||||
|
||||
task.queryParams = &queryParams{
|
||||
limit: 100,
|
||||
offset: 100,
|
||||
}
|
||||
assert.NoError(t, task.PostExecute(ctx))
|
||||
task.queryParams = &queryParams{
|
||||
limit: 100,
|
||||
offset: 100,
|
||||
}
|
||||
assert.NoError(t, task.PostExecute(ctx))
|
||||
|
||||
for i := 0; i < len(task.result.FieldsData); i++ {
|
||||
assert.NotEqual(t, task.result.FieldsData[i].FieldId, common.TimeStampField)
|
||||
}
|
||||
for i := 0; i < len(task.result.FieldsData); i++ {
|
||||
assert.NotEqual(t, task.result.FieldsData[i].FieldId, common.TimeStampField)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("test query for iterator", func(t *testing.T) {
|
||||
qt := &queryTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
RetrieveRequest: &internalpb.RetrieveRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Retrieve,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
CollectionID: collectionID,
|
||||
OutputFieldsId: make([]int64, len(fieldName2Types)),
|
||||
},
|
||||
ctx: ctx,
|
||||
result: &milvuspb.QueryResults{
|
||||
Status: merr.Success(),
|
||||
FieldsData: []*schemapb.FieldData{},
|
||||
},
|
||||
request: &milvuspb.QueryRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Retrieve,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
CollectionName: collectionName,
|
||||
Expr: expr,
|
||||
QueryParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: IgnoreGrowingKey,
|
||||
Value: "false",
|
||||
},
|
||||
{
|
||||
Key: IteratorField,
|
||||
Value: "True",
|
||||
},
|
||||
},
|
||||
},
|
||||
qc: qc,
|
||||
lb: lb,
|
||||
resultBuf: &typeutil.ConcurrentSet[*internalpb.RetrieveResults]{},
|
||||
}
|
||||
// simulate scheduler enqueue task
|
||||
enqueTs := uint64(10000)
|
||||
qt.SetTs(enqueTs)
|
||||
qtErr := qt.PreExecute(context.TODO())
|
||||
assert.Nil(t, qtErr)
|
||||
assert.True(t, qt.queryParams.isIterator)
|
||||
qt.resultBuf.Insert(&internalpb.RetrieveResults{})
|
||||
qtErr = qt.PostExecute(context.TODO())
|
||||
assert.Nil(t, qtErr)
|
||||
// after first page, sessionTs is set
|
||||
assert.True(t, qt.result.GetSessionTs() > 0)
|
||||
|
||||
// next page query task
|
||||
qt = &queryTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
RetrieveRequest: &internalpb.RetrieveRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Retrieve,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
CollectionID: collectionID,
|
||||
OutputFieldsId: make([]int64, len(fieldName2Types)),
|
||||
},
|
||||
ctx: ctx,
|
||||
result: &milvuspb.QueryResults{
|
||||
Status: merr.Success(),
|
||||
FieldsData: []*schemapb.FieldData{},
|
||||
},
|
||||
request: &milvuspb.QueryRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Retrieve,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
CollectionName: collectionName,
|
||||
Expr: expr,
|
||||
QueryParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: IgnoreGrowingKey,
|
||||
Value: "false",
|
||||
},
|
||||
{
|
||||
Key: IteratorField,
|
||||
Value: "True",
|
||||
},
|
||||
},
|
||||
GuaranteeTimestamp: enqueTs,
|
||||
},
|
||||
qc: qc,
|
||||
lb: lb,
|
||||
resultBuf: &typeutil.ConcurrentSet[*internalpb.RetrieveResults]{},
|
||||
}
|
||||
qtErr = qt.PreExecute(context.TODO())
|
||||
assert.Nil(t, qtErr)
|
||||
assert.True(t, qt.queryParams.isIterator)
|
||||
// from the second page, the mvccTs is set to the sessionTs init in the first page
|
||||
assert.Equal(t, enqueTs, qt.GetMvccTimestamp())
|
||||
})
|
||||
}
|
||||
|
||||
func Test_translateToOutputFieldIDs(t *testing.T) {
|
||||
|
|
|
@ -80,6 +80,8 @@ type searchTask struct {
|
|||
reScorers []reScorer
|
||||
rankParams *rankParams
|
||||
groupScorer func(group *Group) error
|
||||
|
||||
isIterator bool
|
||||
}
|
||||
|
||||
func (t *searchTask) CanSkipAllocTimestamp() bool {
|
||||
|
@ -249,6 +251,10 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
|||
}
|
||||
t.SearchRequest.GuaranteeTimestamp = guaranteeTs
|
||||
t.SearchRequest.ConsistencyLevel = consistencyLevel
|
||||
if t.isIterator && t.request.GetGuaranteeTimestamp() > 0 {
|
||||
t.MvccTimestamp = t.request.GetGuaranteeTimestamp()
|
||||
t.GuaranteeTimestamp = t.request.GetGuaranteeTimestamp()
|
||||
}
|
||||
|
||||
if deadline, ok := t.TraceCtx().Deadline(); ok {
|
||||
t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
|
||||
|
@ -351,7 +357,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
|||
t.SearchRequest.SubReqs = make([]*internalpb.SubSearchRequest, len(t.request.GetSubReqs()))
|
||||
t.queryInfos = make([]*planpb.QueryInfo, len(t.request.GetSubReqs()))
|
||||
for index, subReq := range t.request.GetSubReqs() {
|
||||
plan, queryInfo, offset, err := t.tryGeneratePlan(subReq.GetSearchParams(), subReq.GetDsl())
|
||||
plan, queryInfo, offset, _, err := t.tryGeneratePlan(subReq.GetSearchParams(), subReq.GetDsl())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -444,11 +450,12 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
|
|||
log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))
|
||||
// fetch search_growing from search param
|
||||
|
||||
plan, queryInfo, offset, err := t.tryGeneratePlan(t.request.GetSearchParams(), t.request.GetDsl())
|
||||
plan, queryInfo, offset, isIterator, err := t.tryGeneratePlan(t.request.GetSearchParams(), t.request.GetDsl())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.isIterator = isIterator
|
||||
t.SearchRequest.Offset = offset
|
||||
t.SearchRequest.FieldId = queryInfo.GetQueryFieldId()
|
||||
|
||||
|
@ -492,40 +499,40 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string) (*planpb.PlanNode, *planpb.QueryInfo, int64, error) {
|
||||
func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string) (*planpb.PlanNode, *planpb.QueryInfo, int64, bool, error) {
|
||||
annsFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, params)
|
||||
if err != nil || len(annsFieldName) == 0 {
|
||||
vecFields := typeutil.GetVectorFieldSchemas(t.schema.CollectionSchema)
|
||||
if len(vecFields) == 0 {
|
||||
return nil, nil, 0, errors.New(AnnsFieldKey + " not found in schema")
|
||||
return nil, nil, 0, false, errors.New(AnnsFieldKey + " not found in schema")
|
||||
}
|
||||
|
||||
if enableMultipleVectorFields && len(vecFields) > 1 {
|
||||
return nil, nil, 0, errors.New("multiple anns_fields exist, please specify a anns_field in search_params")
|
||||
return nil, nil, 0, false, errors.New("multiple anns_fields exist, please specify a anns_field in search_params")
|
||||
}
|
||||
annsFieldName = vecFields[0].Name
|
||||
}
|
||||
queryInfo, offset, parseErr := parseSearchInfo(params, t.schema.CollectionSchema, t.rankParams)
|
||||
if parseErr != nil {
|
||||
return nil, nil, 0, parseErr
|
||||
searchInfo := parseSearchInfo(params, t.schema.CollectionSchema, t.rankParams)
|
||||
if searchInfo.parseError != nil {
|
||||
return nil, nil, 0, false, searchInfo.parseError
|
||||
}
|
||||
annField := typeutil.GetFieldByName(t.schema.CollectionSchema, annsFieldName)
|
||||
if queryInfo.GetGroupByFieldId() != -1 && annField.GetDataType() == schemapb.DataType_BinaryVector {
|
||||
return nil, nil, 0, errors.New("not support search_group_by operation based on binary vector column")
|
||||
if searchInfo.planInfo.GetGroupByFieldId() != -1 && annField.GetDataType() == schemapb.DataType_BinaryVector {
|
||||
return nil, nil, 0, false, errors.New("not support search_group_by operation based on binary vector column")
|
||||
}
|
||||
|
||||
queryInfo.QueryFieldId = annField.GetFieldID()
|
||||
plan, planErr := planparserv2.CreateSearchPlan(t.schema.schemaHelper, dsl, annsFieldName, queryInfo)
|
||||
searchInfo.planInfo.QueryFieldId = annField.GetFieldID()
|
||||
plan, planErr := planparserv2.CreateSearchPlan(t.schema.schemaHelper, dsl, annsFieldName, searchInfo.planInfo)
|
||||
if planErr != nil {
|
||||
log.Warn("failed to create query plan", zap.Error(planErr),
|
||||
zap.String("dsl", dsl), // may be very large if large term passed.
|
||||
zap.String("anns field", annsFieldName), zap.Any("query info", queryInfo))
|
||||
return nil, nil, 0, merr.WrapErrParameterInvalidMsg("failed to create query plan: %v", planErr)
|
||||
zap.String("anns field", annsFieldName), zap.Any("query info", searchInfo.planInfo))
|
||||
return nil, nil, 0, false, merr.WrapErrParameterInvalidMsg("failed to create query plan: %v", planErr)
|
||||
}
|
||||
log.Debug("create query plan",
|
||||
zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
|
||||
zap.String("anns field", annsFieldName), zap.Any("query info", queryInfo))
|
||||
return plan, queryInfo, offset, nil
|
||||
zap.String("anns field", annsFieldName), zap.Any("query info", searchInfo.planInfo))
|
||||
return plan, searchInfo.planInfo, searchInfo.offset, searchInfo.isIterator, nil
|
||||
}
|
||||
|
||||
func (t *searchTask) tryParsePartitionIDsFromPlan(plan *planpb.PlanNode) ([]int64, error) {
|
||||
|
@ -718,6 +725,10 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
|
|||
}
|
||||
t.result.Results.OutputFields = t.userOutputFields
|
||||
t.result.CollectionName = t.request.GetCollectionName()
|
||||
if t.isIterator && t.request.GetGuaranteeTimestamp() == 0 {
|
||||
// first page for iteration, need to set up sessionTs for iterator
|
||||
t.result.SessionTs = t.BeginTs()
|
||||
}
|
||||
|
||||
metrics.ProxyReduceResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.SearchLabel).Observe(float64(tr.RecordSpan().Milliseconds()))
|
||||
|
||||
|
|
|
@ -301,6 +301,55 @@ func TestSearchTask_PreExecute(t *testing.T) {
|
|||
task.request.OutputFields = []string{testFloatVecField}
|
||||
assert.NoError(t, task.PreExecute(ctx))
|
||||
})
|
||||
|
||||
t.Run("search consistent iterator pre_ts", func(t *testing.T) {
|
||||
collName := "search_with_timeout" + funcutil.GenRandomStr()
|
||||
createColl(t, collName, rc)
|
||||
|
||||
st := getSearchTask(t, collName)
|
||||
st.request.SearchParams = getValidSearchParams()
|
||||
st.request.SearchParams = append(st.request.SearchParams, &commonpb.KeyValuePair{
|
||||
Key: IteratorField,
|
||||
Value: "True",
|
||||
})
|
||||
st.request.GuaranteeTimestamp = 1000
|
||||
st.request.DslType = commonpb.DslType_BoolExprV1
|
||||
|
||||
ctxTimeout, cancel := context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
require.Equal(t, typeutil.ZeroTimestamp, st.TimeoutTimestamp)
|
||||
|
||||
st.ctx = ctxTimeout
|
||||
assert.NoError(t, st.PreExecute(ctx))
|
||||
assert.True(t, st.isIterator)
|
||||
assert.True(t, st.GetMvccTimestamp() > 0)
|
||||
assert.Equal(t, uint64(1000), st.GetGuaranteeTimestamp())
|
||||
})
|
||||
|
||||
t.Run("search consistent iterator post_ts", func(t *testing.T) {
|
||||
collName := "search_with_timeout" + funcutil.GenRandomStr()
|
||||
createColl(t, collName, rc)
|
||||
|
||||
st := getSearchTask(t, collName)
|
||||
st.request.SearchParams = getValidSearchParams()
|
||||
st.request.SearchParams = append(st.request.SearchParams, &commonpb.KeyValuePair{
|
||||
Key: IteratorField,
|
||||
Value: "True",
|
||||
})
|
||||
st.request.DslType = commonpb.DslType_BoolExprV1
|
||||
|
||||
_, cancel := context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
require.Equal(t, typeutil.ZeroTimestamp, st.TimeoutTimestamp)
|
||||
enqueueTs := uint64(100000)
|
||||
st.SetTs(enqueueTs)
|
||||
assert.NoError(t, st.PreExecute(ctx))
|
||||
assert.True(t, st.isIterator)
|
||||
assert.True(t, st.GetMvccTimestamp() == 0)
|
||||
st.resultBuf.Insert(&internalpb.SearchResults{})
|
||||
st.PostExecute(context.TODO())
|
||||
assert.Equal(t, st.result.GetSessionTs(), enqueueTs)
|
||||
})
|
||||
}
|
||||
|
||||
func getQueryCoord() *mocks.MockQueryCoord {
|
||||
|
@ -2235,11 +2284,11 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
|
|||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
info, offset, err := parseSearchInfo(test.validParams, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, info)
|
||||
searchInfo := parseSearchInfo(test.validParams, nil, nil)
|
||||
assert.NoError(t, searchInfo.parseError)
|
||||
assert.NotNil(t, searchInfo.planInfo)
|
||||
if test.description == "offsetParam" {
|
||||
assert.Equal(t, targetOffset, offset)
|
||||
assert.Equal(t, targetOffset, searchInfo.offset)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -2256,11 +2305,11 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
|
|||
limit: externalLimit,
|
||||
}
|
||||
|
||||
info, offset, err := parseSearchInfo(offsetParam, nil, rank)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, info)
|
||||
assert.Equal(t, int64(10), info.GetTopk())
|
||||
assert.Equal(t, int64(0), offset)
|
||||
searchInfo := parseSearchInfo(offsetParam, nil, rank)
|
||||
assert.NoError(t, searchInfo.parseError)
|
||||
assert.NotNil(t, searchInfo.planInfo)
|
||||
assert.Equal(t, int64(10), searchInfo.planInfo.GetTopk())
|
||||
assert.Equal(t, int64(0), searchInfo.offset)
|
||||
})
|
||||
|
||||
t.Run("parseSearchInfo groupBy info for hybrid search", func(t *testing.T) {
|
||||
|
@ -2309,15 +2358,15 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
|
|||
Value: "true",
|
||||
})
|
||||
|
||||
info, _, err := parseSearchInfo(params, schema, testRankParams)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, info)
|
||||
searchInfo := parseSearchInfo(params, schema, testRankParams)
|
||||
assert.NoError(t, searchInfo.parseError)
|
||||
assert.NotNil(t, searchInfo.planInfo)
|
||||
|
||||
// all group_by related parameters should be aligned to parameters
|
||||
// set by main request rather than inner sub request
|
||||
assert.Equal(t, int64(101), info.GetGroupByFieldId())
|
||||
assert.Equal(t, int64(3), info.GetGroupSize())
|
||||
assert.False(t, info.GetGroupStrictSize())
|
||||
assert.Equal(t, int64(101), searchInfo.planInfo.GetGroupByFieldId())
|
||||
assert.Equal(t, int64(3), searchInfo.planInfo.GetGroupSize())
|
||||
assert.False(t, searchInfo.planInfo.GetGroupStrictSize())
|
||||
})
|
||||
|
||||
t.Run("parseSearchInfo error", func(t *testing.T) {
|
||||
|
@ -2399,12 +2448,12 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
|
|||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
info, offset, err := parseSearchInfo(test.invalidParams, nil, nil)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, info)
|
||||
assert.Zero(t, offset)
|
||||
searchInfo := parseSearchInfo(test.invalidParams, nil, nil)
|
||||
assert.Error(t, searchInfo.parseError)
|
||||
assert.Nil(t, searchInfo.planInfo)
|
||||
assert.Zero(t, searchInfo.offset)
|
||||
|
||||
t.Logf("err=%s", err.Error())
|
||||
t.Logf("err=%s", searchInfo.parseError)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
@ -2426,9 +2475,9 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
|
|||
schema := &schemapb.CollectionSchema{
|
||||
Fields: fields,
|
||||
}
|
||||
info, _, err := parseSearchInfo(normalParam, schema, nil)
|
||||
assert.Nil(t, info)
|
||||
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
||||
searchInfo := parseSearchInfo(normalParam, schema, nil)
|
||||
assert.Nil(t, searchInfo.planInfo)
|
||||
assert.ErrorIs(t, searchInfo.parseError, merr.ErrParameterInvalid)
|
||||
})
|
||||
t.Run("check range-search and groupBy", func(t *testing.T) {
|
||||
normalParam := getValidSearchParams()
|
||||
|
@ -2445,9 +2494,9 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
|
|||
schema := &schemapb.CollectionSchema{
|
||||
Fields: fields,
|
||||
}
|
||||
info, _, err := parseSearchInfo(normalParam, schema, nil)
|
||||
assert.Nil(t, info)
|
||||
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
||||
searchInfo := parseSearchInfo(normalParam, schema, nil)
|
||||
assert.Nil(t, searchInfo.planInfo)
|
||||
assert.ErrorIs(t, searchInfo.parseError, merr.ErrParameterInvalid)
|
||||
})
|
||||
t.Run("check nullable and groupBy", func(t *testing.T) {
|
||||
normalParam := getValidSearchParams()
|
||||
|
@ -2464,9 +2513,9 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
|
|||
schema := &schemapb.CollectionSchema{
|
||||
Fields: fields,
|
||||
}
|
||||
info, _, err := parseSearchInfo(normalParam, schema, nil)
|
||||
assert.Nil(t, info)
|
||||
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
||||
searchInfo := parseSearchInfo(normalParam, schema, nil)
|
||||
assert.Nil(t, searchInfo.planInfo)
|
||||
assert.ErrorIs(t, searchInfo.parseError, merr.ErrParameterInvalid)
|
||||
})
|
||||
t.Run("check iterator and topK", func(t *testing.T) {
|
||||
normalParam := getValidSearchParams()
|
||||
|
@ -2483,10 +2532,10 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
|
|||
schema := &schemapb.CollectionSchema{
|
||||
Fields: fields,
|
||||
}
|
||||
info, _, err := parseSearchInfo(normalParam, schema, nil)
|
||||
assert.NotNil(t, info)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, Params.QuotaConfig.TopKLimit.GetAsInt64(), info.Topk)
|
||||
searchInfo := parseSearchInfo(normalParam, schema, nil)
|
||||
assert.NotNil(t, searchInfo.planInfo)
|
||||
assert.NoError(t, searchInfo.parseError)
|
||||
assert.Equal(t, Params.QuotaConfig.TopKLimit.GetAsInt64(), searchInfo.planInfo.GetTopk())
|
||||
})
|
||||
|
||||
t.Run("check max group size", func(t *testing.T) {
|
||||
|
@ -2503,15 +2552,15 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
|
|||
schema := &schemapb.CollectionSchema{
|
||||
Fields: fields,
|
||||
}
|
||||
info, _, err := parseSearchInfo(normalParam, schema, nil)
|
||||
assert.Nil(t, info)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, strings.Contains(err.Error(), "exceeds configured max group size"))
|
||||
searchInfo := parseSearchInfo(normalParam, schema, nil)
|
||||
assert.Nil(t, searchInfo.planInfo)
|
||||
assert.Error(t, searchInfo.parseError)
|
||||
assert.True(t, strings.Contains(searchInfo.parseError.Error(), "exceeds configured max group size"))
|
||||
|
||||
resetSearchParamsValue(normalParam, GroupSizeKey, `10`)
|
||||
info, _, err = parseSearchInfo(normalParam, schema, nil)
|
||||
assert.NotNil(t, info)
|
||||
assert.NoError(t, err)
|
||||
searchInfo = parseSearchInfo(normalParam, schema, nil)
|
||||
assert.NotNil(t, searchInfo.planInfo)
|
||||
assert.NoError(t, searchInfo.parseError)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -439,7 +439,7 @@ func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryReq
|
|||
|
||||
// wait tsafe
|
||||
waitTr := timerecord.NewTimeRecorder("wait tSafe")
|
||||
tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
|
||||
tSafe, err := sd.waitTSafe(ctx, req.Req.GetGuaranteeTimestamp())
|
||||
if err != nil {
|
||||
log.Warn("delegator query failed to wait tsafe", zap.Error(err))
|
||||
return err
|
||||
|
@ -512,7 +512,7 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest)
|
|||
|
||||
// wait tsafe
|
||||
waitTr := timerecord.NewTimeRecorder("wait tSafe")
|
||||
tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
|
||||
tSafe, err := sd.waitTSafe(ctx, req.Req.GetGuaranteeTimestamp())
|
||||
if err != nil {
|
||||
log.Warn("delegator query failed to wait tsafe", zap.Error(err))
|
||||
return nil, err
|
||||
|
|
Loading…
Reference in New Issue