enhance: add ts support for iterator(#22718) (#36572)

related: #22718

Signed-off-by: MrPresent-Han <chun.han@gmail.com>
Co-authored-by: MrPresent-Han <chun.han@gmail.com>
pull/35970/head
Chun Han 2024-10-16 18:51:23 +08:00 committed by GitHub
parent c271c21f17
commit 903450f5c6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 372 additions and 191 deletions

2
go.mod
View File

@ -23,7 +23,7 @@ require (
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/klauspost/compress v1.17.7
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240923125106-ef9b8fd69497
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240930043709-0c23514e4c34
github.com/minio/minio-go/v7 v7.0.61
github.com/pingcap/log v1.1.1-0.20221015072633-39906604fb81
github.com/prometheus/client_golang v1.14.0

4
go.sum
View File

@ -625,8 +625,8 @@ github.com/milvus-io/cgosymbolizer v0.0.0-20240722103217-b7dee0e50119 h1:9VXijWu
github.com/milvus-io/cgosymbolizer v0.0.0-20240722103217-b7dee0e50119/go.mod h1:DvXTE/K/RtHehxU8/GtDs4vFtfw64jJ3PaCnFri8CRg=
github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b h1:TfeY0NxYxZzUfIfYe5qYDBzt4ZYRqzUjTR6CvUzjat8=
github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b/go.mod h1:iwW+9cWfIzzDseEBCCeDSN5SD16Tidvy8cwQ7ZY8Qj4=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240923125106-ef9b8fd69497 h1:t4sQMbSy05p8qgMGvEGyLYYLoZ9fD1dushS1bj5X6+0=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240923125106-ef9b8fd69497/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240930043709-0c23514e4c34 h1:Fwxpg98128gfWRbQ1A3PMP9o2IfYZk7RSEy8rcoCWDA=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240930043709-0c23514e4c34/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/pulsar-client-go v0.12.1 h1:O2JZp1tsYiO7C0MQ4hrUY/aJXnn2Gry6hpm7UodghmE=
github.com/milvus-io/pulsar-client-go v0.12.1/go.mod h1:dkutuH4oS2pXiGm+Ti7fQZ4MRjrMPZ8IJeEGAWMeckk=
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs=

View File

@ -75,22 +75,29 @@ func (r *rankParams) String() string {
return fmt.Sprintf("limit: %d, offset: %d, roundDecimal: %d", r.GetLimit(), r.GetOffset(), r.GetRoundDecimal())
}
type SearchInfo struct {
planInfo *planpb.QueryInfo
offset int64
parseError error
isIterator bool
}
// parseSearchInfo returns QueryInfo and offset
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) (*planpb.QueryInfo, int64, error) {
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) *SearchInfo {
var topK int64
isAdvanced := rankParams != nil
externalLimit := rankParams.GetLimit() + rankParams.GetOffset()
topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair)
if err != nil {
if externalLimit <= 0 {
return nil, 0, fmt.Errorf("%s is required", TopKKey)
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s is required", TopKKey)}
}
topK = externalLimit
} else {
topKInParam, err := strconv.ParseInt(topKStr, 0, 64)
if err != nil {
if externalLimit <= 0 {
return nil, 0, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)}
}
topK = externalLimit
} else {
@ -98,15 +105,16 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
}
}
isIterator, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair)
isIteratorStr, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair)
isIterator := (isIteratorStr == "True") || (isIteratorStr == "true")
if err := validateLimit(topK); err != nil {
if isIterator == "True" {
if isIterator {
// 1. if the request is from iterator, we set topK to QuotaLimit as the iterator can resolve too large topK problem
// 2. GetAsInt64 has cached inside, no need to worry about cpu cost for parsing here
topK = Params.QuotaConfig.TopKLimit.GetAsInt64()
} else {
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err)
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err)}
}
}
@ -117,12 +125,12 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
if err == nil {
offset, err = strconv.ParseInt(offsetStr, 0, 64)
if err != nil {
return nil, 0, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)}
}
if offset != 0 {
if err := validateLimit(offset); err != nil {
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err)
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err)}
}
}
}
@ -130,7 +138,7 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
queryTopK := topK + offset
if err := validateLimit(queryTopK); err != nil {
return nil, 0, fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err)
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err)}
}
// 2. parse metrics type
@ -147,11 +155,11 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
roundDecimal, err := strconv.ParseInt(roundDecimalStr, 0, 64)
if err != nil {
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)}
}
if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)}
}
// 4. parse search param str
@ -168,30 +176,35 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
} else {
groupByInfo := parseGroupByInfo(searchParamsPair, schema)
if groupByInfo.err != nil {
return nil, 0, groupByInfo.err
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: groupByInfo.err}
}
groupByFieldId, groupSize, groupStrictSize = groupByInfo.GetGroupByFieldId(), groupByInfo.GetGroupSize(), groupByInfo.GetGroupStrictSize()
}
// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
if isIterator == "True" && groupByFieldId > 0 {
return nil, 0, merr.WrapErrParameterInvalid("", "",
"Not allowed to do groupBy when doing iteration")
if isIterator && groupByFieldId > 0 {
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: merr.WrapErrParameterInvalid("", "",
"Not allowed to do groupBy when doing iteration")}
}
if strings.Contains(searchParamStr, radiusKey) && groupByFieldId > 0 {
return nil, 0, merr.WrapErrParameterInvalid("", "",
"Not allowed to do range-search when doing search-group-by")
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: merr.WrapErrParameterInvalid("", "",
"Not allowed to do range-search when doing search-group-by")}
}
return &planpb.QueryInfo{
Topk: queryTopK,
MetricType: metricType,
SearchParams: searchParamStr,
RoundDecimal: roundDecimal,
GroupByFieldId: groupByFieldId,
GroupSize: groupSize,
GroupStrictSize: groupStrictSize,
}, offset, nil
return &SearchInfo{
planInfo: &planpb.QueryInfo{
Topk: queryTopK,
MetricType: metricType,
SearchParams: searchParamStr,
RoundDecimal: roundDecimal,
GroupByFieldId: groupByFieldId,
GroupSize: groupSize,
GroupStrictSize: groupStrictSize,
},
offset: offset,
isIterator: isIterator,
parseError: nil,
}
}
func getOutputFieldIDs(schema *schemaInfo, outputFields []string) (outputFieldIDs []UniqueID, err error) {

View File

@ -79,6 +79,7 @@ type queryParams struct {
limit int64
offset int64
reduceType reduce.IReduceType
isIterator bool
}
// translateToOutputFieldIDs translates output fields name to output fields id.
@ -178,7 +179,7 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e
limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, queryParamsPair)
// if limit is not provided
if err != nil {
return &queryParams{limit: typeutil.Unlimited, reduceType: reduceType}, nil
return &queryParams{limit: typeutil.Unlimited, reduceType: reduceType, isIterator: isIterator}, nil
}
limit, err = strconv.ParseInt(limitStr, 0, 64)
if err != nil {
@ -203,6 +204,7 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e
limit: limit,
offset: offset,
reduceType: reduceType,
isIterator: isIterator,
}, nil
}
@ -461,6 +463,11 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
}
}
t.GuaranteeTimestamp = guaranteeTs
// need modify mvccTs and guaranteeTs for iterator specially
if t.queryParams.isIterator && t.request.GetGuaranteeTimestamp() > 0 {
t.MvccTimestamp = t.request.GetGuaranteeTimestamp()
t.GuaranteeTimestamp = t.request.GetGuaranteeTimestamp()
}
deadline, ok := t.TraceCtx().Deadline()
if ok {
@ -542,6 +549,10 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
t.result.OutputFields = t.userOutputFields
metrics.ProxyReduceResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.QueryLabel).Observe(float64(tr.RecordSpan().Milliseconds()))
if t.queryParams.isIterator && t.request.GetGuaranteeTimestamp() == 0 {
// first page for iteration, need to set up sessionTs for iterator
t.result.SessionTs = t.BeginTs()
}
log.Debug("Query PostExecute done")
return nil
}

View File

@ -132,126 +132,223 @@ func TestQueryTask_all(t *testing.T) {
require.NoError(t, err)
require.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
// test begins
task := &queryTask{
Condition: NewTaskCondition(ctx),
RetrieveRequest: &internalpb.RetrieveRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
SourceID: paramtable.GetNodeID(),
t.Run("test query task parameters", func(t *testing.T) {
task := &queryTask{
Condition: NewTaskCondition(ctx),
RetrieveRequest: &internalpb.RetrieveRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
SourceID: paramtable.GetNodeID(),
},
CollectionID: collectionID,
OutputFieldsId: make([]int64, len(fieldName2Types)),
},
CollectionID: collectionID,
OutputFieldsId: make([]int64, len(fieldName2Types)),
},
ctx: ctx,
result: &milvuspb.QueryResults{
Status: merr.Success(),
FieldsData: []*schemapb.FieldData{},
},
request: &milvuspb.QueryRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
SourceID: paramtable.GetNodeID(),
ctx: ctx,
result: &milvuspb.QueryResults{
Status: merr.Success(),
FieldsData: []*schemapb.FieldData{},
},
CollectionName: collectionName,
Expr: expr,
QueryParams: []*commonpb.KeyValuePair{
{
Key: IgnoreGrowingKey,
Value: "false",
request: &milvuspb.QueryRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
SourceID: paramtable.GetNodeID(),
},
CollectionName: collectionName,
Expr: expr,
QueryParams: []*commonpb.KeyValuePair{
{
Key: IgnoreGrowingKey,
Value: "false",
},
},
},
},
qc: qc,
lb: lb,
}
qc: qc,
lb: lb,
}
assert.NoError(t, task.OnEnqueue())
assert.NoError(t, task.OnEnqueue())
// test query task with timeout
ctx1, cancel1 := context.WithTimeout(ctx, 10*time.Second)
defer cancel1()
// before preExecute
assert.Equal(t, typeutil.ZeroTimestamp, task.TimeoutTimestamp)
task.ctx = ctx1
assert.NoError(t, task.PreExecute(ctx))
// test query task with timeout
ctx1, cancel1 := context.WithTimeout(ctx, 10*time.Second)
defer cancel1()
// before preExecute
assert.Equal(t, typeutil.ZeroTimestamp, task.TimeoutTimestamp)
task.ctx = ctx1
assert.NoError(t, task.PreExecute(ctx))
{
task.mustUsePartitionKey = true
err := task.PreExecute(ctx)
assert.Error(t, err)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
task.mustUsePartitionKey = false
}
{
task.mustUsePartitionKey = true
err := task.PreExecute(ctx)
assert.Error(t, err)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
task.mustUsePartitionKey = false
}
// after preExecute
assert.Greater(t, task.TimeoutTimestamp, typeutil.ZeroTimestamp)
// after preExecute
assert.Greater(t, task.TimeoutTimestamp, typeutil.ZeroTimestamp)
// check reduce_stop_for_best
assert.Equal(t, false, task.RetrieveRequest.GetReduceStopForBest())
task.request.QueryParams = append(task.request.QueryParams, &commonpb.KeyValuePair{
Key: ReduceStopForBestKey,
Value: "trxxxx",
})
assert.Error(t, task.PreExecute(ctx))
// check reduce_stop_for_best
assert.Equal(t, false, task.RetrieveRequest.GetReduceStopForBest())
task.request.QueryParams = append(task.request.QueryParams, &commonpb.KeyValuePair{
Key: ReduceStopForBestKey,
Value: "trxxxx",
})
assert.Error(t, task.PreExecute(ctx))
result1 := &internalpb.RetrieveResults{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_RetrieveResult},
Status: merr.Success(),
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{Data: testutils.GenerateInt64Array(hitNum)},
result1 := &internalpb.RetrieveResults{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_RetrieveResult},
Status: merr.Success(),
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{Data: testutils.GenerateInt64Array(hitNum)},
},
},
},
}
}
outputFieldIDs := make([]UniqueID, 0, len(fieldName2Types))
for i := 0; i < len(fieldName2Types); i++ {
outputFieldIDs = append(outputFieldIDs, int64(common.StartOfUserFieldID+i))
}
task.RetrieveRequest.OutputFieldsId = outputFieldIDs
for fieldName, dataType := range fieldName2Types {
result1.FieldsData = append(result1.FieldsData, generateFieldData(dataType, fieldName, hitNum))
}
result1.FieldsData = append(result1.FieldsData, generateFieldData(schemapb.DataType_Int64, common.TimeStampFieldName, hitNum))
task.RetrieveRequest.OutputFieldsId = append(task.RetrieveRequest.OutputFieldsId, common.TimeStampField)
task.ctx = ctx
qn.ExpectedCalls = nil
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
qn.EXPECT().Query(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
assert.Error(t, task.Execute(ctx))
outputFieldIDs := make([]UniqueID, 0, len(fieldName2Types))
for i := 0; i < len(fieldName2Types); i++ {
outputFieldIDs = append(outputFieldIDs, int64(common.StartOfUserFieldID+i))
}
task.RetrieveRequest.OutputFieldsId = outputFieldIDs
for fieldName, dataType := range fieldName2Types {
result1.FieldsData = append(result1.FieldsData, generateFieldData(dataType, fieldName, hitNum))
}
result1.FieldsData = append(result1.FieldsData, generateFieldData(schemapb.DataType_Int64, common.TimeStampFieldName, hitNum))
task.RetrieveRequest.OutputFieldsId = append(task.RetrieveRequest.OutputFieldsId, common.TimeStampField)
task.ctx = ctx
qn.ExpectedCalls = nil
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
qn.EXPECT().Query(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
assert.Error(t, task.Execute(ctx))
qn.ExpectedCalls = nil
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{
Status: merr.Status(merr.ErrChannelNotAvailable),
}, nil)
err = task.Execute(ctx)
assert.ErrorIs(t, err, merr.ErrChannelNotAvailable)
qn.ExpectedCalls = nil
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{
Status: merr.Status(merr.ErrChannelNotAvailable),
}, nil)
err = task.Execute(ctx)
assert.ErrorIs(t, err, merr.ErrChannelNotAvailable)
qn.ExpectedCalls = nil
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
},
}, nil)
assert.Error(t, task.Execute(ctx))
qn.ExpectedCalls = nil
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
},
}, nil)
assert.Error(t, task.Execute(ctx))
qn.ExpectedCalls = nil
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
qn.EXPECT().Query(mock.Anything, mock.Anything).Return(result1, nil)
assert.NoError(t, task.Execute(ctx))
qn.ExpectedCalls = nil
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
qn.EXPECT().Query(mock.Anything, mock.Anything).Return(result1, nil)
assert.NoError(t, task.Execute(ctx))
task.queryParams = &queryParams{
limit: 100,
offset: 100,
}
assert.NoError(t, task.PostExecute(ctx))
task.queryParams = &queryParams{
limit: 100,
offset: 100,
}
assert.NoError(t, task.PostExecute(ctx))
for i := 0; i < len(task.result.FieldsData); i++ {
assert.NotEqual(t, task.result.FieldsData[i].FieldId, common.TimeStampField)
}
for i := 0; i < len(task.result.FieldsData); i++ {
assert.NotEqual(t, task.result.FieldsData[i].FieldId, common.TimeStampField)
}
})
t.Run("test query for iterator", func(t *testing.T) {
qt := &queryTask{
Condition: NewTaskCondition(ctx),
RetrieveRequest: &internalpb.RetrieveRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
SourceID: paramtable.GetNodeID(),
},
CollectionID: collectionID,
OutputFieldsId: make([]int64, len(fieldName2Types)),
},
ctx: ctx,
result: &milvuspb.QueryResults{
Status: merr.Success(),
FieldsData: []*schemapb.FieldData{},
},
request: &milvuspb.QueryRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
SourceID: paramtable.GetNodeID(),
},
CollectionName: collectionName,
Expr: expr,
QueryParams: []*commonpb.KeyValuePair{
{
Key: IgnoreGrowingKey,
Value: "false",
},
{
Key: IteratorField,
Value: "True",
},
},
},
qc: qc,
lb: lb,
resultBuf: &typeutil.ConcurrentSet[*internalpb.RetrieveResults]{},
}
// simulate scheduler enqueue task
enqueTs := uint64(10000)
qt.SetTs(enqueTs)
qtErr := qt.PreExecute(context.TODO())
assert.Nil(t, qtErr)
assert.True(t, qt.queryParams.isIterator)
qt.resultBuf.Insert(&internalpb.RetrieveResults{})
qtErr = qt.PostExecute(context.TODO())
assert.Nil(t, qtErr)
// after first page, sessionTs is set
assert.True(t, qt.result.GetSessionTs() > 0)
// next page query task
qt = &queryTask{
Condition: NewTaskCondition(ctx),
RetrieveRequest: &internalpb.RetrieveRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
SourceID: paramtable.GetNodeID(),
},
CollectionID: collectionID,
OutputFieldsId: make([]int64, len(fieldName2Types)),
},
ctx: ctx,
result: &milvuspb.QueryResults{
Status: merr.Success(),
FieldsData: []*schemapb.FieldData{},
},
request: &milvuspb.QueryRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
SourceID: paramtable.GetNodeID(),
},
CollectionName: collectionName,
Expr: expr,
QueryParams: []*commonpb.KeyValuePair{
{
Key: IgnoreGrowingKey,
Value: "false",
},
{
Key: IteratorField,
Value: "True",
},
},
GuaranteeTimestamp: enqueTs,
},
qc: qc,
lb: lb,
resultBuf: &typeutil.ConcurrentSet[*internalpb.RetrieveResults]{},
}
qtErr = qt.PreExecute(context.TODO())
assert.Nil(t, qtErr)
assert.True(t, qt.queryParams.isIterator)
// from the second page, the mvccTs is set to the sessionTs init in the first page
assert.Equal(t, enqueTs, qt.GetMvccTimestamp())
})
}
func Test_translateToOutputFieldIDs(t *testing.T) {

View File

@ -80,6 +80,8 @@ type searchTask struct {
reScorers []reScorer
rankParams *rankParams
groupScorer func(group *Group) error
isIterator bool
}
func (t *searchTask) CanSkipAllocTimestamp() bool {
@ -249,6 +251,10 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
}
t.SearchRequest.GuaranteeTimestamp = guaranteeTs
t.SearchRequest.ConsistencyLevel = consistencyLevel
if t.isIterator && t.request.GetGuaranteeTimestamp() > 0 {
t.MvccTimestamp = t.request.GetGuaranteeTimestamp()
t.GuaranteeTimestamp = t.request.GetGuaranteeTimestamp()
}
if deadline, ok := t.TraceCtx().Deadline(); ok {
t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
@ -351,7 +357,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
t.SearchRequest.SubReqs = make([]*internalpb.SubSearchRequest, len(t.request.GetSubReqs()))
t.queryInfos = make([]*planpb.QueryInfo, len(t.request.GetSubReqs()))
for index, subReq := range t.request.GetSubReqs() {
plan, queryInfo, offset, err := t.tryGeneratePlan(subReq.GetSearchParams(), subReq.GetDsl())
plan, queryInfo, offset, _, err := t.tryGeneratePlan(subReq.GetSearchParams(), subReq.GetDsl())
if err != nil {
return err
}
@ -444,11 +450,12 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))
// fetch search_growing from search param
plan, queryInfo, offset, err := t.tryGeneratePlan(t.request.GetSearchParams(), t.request.GetDsl())
plan, queryInfo, offset, isIterator, err := t.tryGeneratePlan(t.request.GetSearchParams(), t.request.GetDsl())
if err != nil {
return err
}
t.isIterator = isIterator
t.SearchRequest.Offset = offset
t.SearchRequest.FieldId = queryInfo.GetQueryFieldId()
@ -492,40 +499,40 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
return nil
}
func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string) (*planpb.PlanNode, *planpb.QueryInfo, int64, error) {
func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string) (*planpb.PlanNode, *planpb.QueryInfo, int64, bool, error) {
annsFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, params)
if err != nil || len(annsFieldName) == 0 {
vecFields := typeutil.GetVectorFieldSchemas(t.schema.CollectionSchema)
if len(vecFields) == 0 {
return nil, nil, 0, errors.New(AnnsFieldKey + " not found in schema")
return nil, nil, 0, false, errors.New(AnnsFieldKey + " not found in schema")
}
if enableMultipleVectorFields && len(vecFields) > 1 {
return nil, nil, 0, errors.New("multiple anns_fields exist, please specify a anns_field in search_params")
return nil, nil, 0, false, errors.New("multiple anns_fields exist, please specify a anns_field in search_params")
}
annsFieldName = vecFields[0].Name
}
queryInfo, offset, parseErr := parseSearchInfo(params, t.schema.CollectionSchema, t.rankParams)
if parseErr != nil {
return nil, nil, 0, parseErr
searchInfo := parseSearchInfo(params, t.schema.CollectionSchema, t.rankParams)
if searchInfo.parseError != nil {
return nil, nil, 0, false, searchInfo.parseError
}
annField := typeutil.GetFieldByName(t.schema.CollectionSchema, annsFieldName)
if queryInfo.GetGroupByFieldId() != -1 && annField.GetDataType() == schemapb.DataType_BinaryVector {
return nil, nil, 0, errors.New("not support search_group_by operation based on binary vector column")
if searchInfo.planInfo.GetGroupByFieldId() != -1 && annField.GetDataType() == schemapb.DataType_BinaryVector {
return nil, nil, 0, false, errors.New("not support search_group_by operation based on binary vector column")
}
queryInfo.QueryFieldId = annField.GetFieldID()
plan, planErr := planparserv2.CreateSearchPlan(t.schema.schemaHelper, dsl, annsFieldName, queryInfo)
searchInfo.planInfo.QueryFieldId = annField.GetFieldID()
plan, planErr := planparserv2.CreateSearchPlan(t.schema.schemaHelper, dsl, annsFieldName, searchInfo.planInfo)
if planErr != nil {
log.Warn("failed to create query plan", zap.Error(planErr),
zap.String("dsl", dsl), // may be very large if large term passed.
zap.String("anns field", annsFieldName), zap.Any("query info", queryInfo))
return nil, nil, 0, merr.WrapErrParameterInvalidMsg("failed to create query plan: %v", planErr)
zap.String("anns field", annsFieldName), zap.Any("query info", searchInfo.planInfo))
return nil, nil, 0, false, merr.WrapErrParameterInvalidMsg("failed to create query plan: %v", planErr)
}
log.Debug("create query plan",
zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
zap.String("anns field", annsFieldName), zap.Any("query info", queryInfo))
return plan, queryInfo, offset, nil
zap.String("anns field", annsFieldName), zap.Any("query info", searchInfo.planInfo))
return plan, searchInfo.planInfo, searchInfo.offset, searchInfo.isIterator, nil
}
func (t *searchTask) tryParsePartitionIDsFromPlan(plan *planpb.PlanNode) ([]int64, error) {
@ -718,6 +725,10 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
}
t.result.Results.OutputFields = t.userOutputFields
t.result.CollectionName = t.request.GetCollectionName()
if t.isIterator && t.request.GetGuaranteeTimestamp() == 0 {
// first page for iteration, need to set up sessionTs for iterator
t.result.SessionTs = t.BeginTs()
}
metrics.ProxyReduceResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.SearchLabel).Observe(float64(tr.RecordSpan().Milliseconds()))

View File

@ -301,6 +301,55 @@ func TestSearchTask_PreExecute(t *testing.T) {
task.request.OutputFields = []string{testFloatVecField}
assert.NoError(t, task.PreExecute(ctx))
})
t.Run("search consistent iterator pre_ts", func(t *testing.T) {
collName := "search_with_timeout" + funcutil.GenRandomStr()
createColl(t, collName, rc)
st := getSearchTask(t, collName)
st.request.SearchParams = getValidSearchParams()
st.request.SearchParams = append(st.request.SearchParams, &commonpb.KeyValuePair{
Key: IteratorField,
Value: "True",
})
st.request.GuaranteeTimestamp = 1000
st.request.DslType = commonpb.DslType_BoolExprV1
ctxTimeout, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
require.Equal(t, typeutil.ZeroTimestamp, st.TimeoutTimestamp)
st.ctx = ctxTimeout
assert.NoError(t, st.PreExecute(ctx))
assert.True(t, st.isIterator)
assert.True(t, st.GetMvccTimestamp() > 0)
assert.Equal(t, uint64(1000), st.GetGuaranteeTimestamp())
})
t.Run("search consistent iterator post_ts", func(t *testing.T) {
collName := "search_with_timeout" + funcutil.GenRandomStr()
createColl(t, collName, rc)
st := getSearchTask(t, collName)
st.request.SearchParams = getValidSearchParams()
st.request.SearchParams = append(st.request.SearchParams, &commonpb.KeyValuePair{
Key: IteratorField,
Value: "True",
})
st.request.DslType = commonpb.DslType_BoolExprV1
_, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
require.Equal(t, typeutil.ZeroTimestamp, st.TimeoutTimestamp)
enqueueTs := uint64(100000)
st.SetTs(enqueueTs)
assert.NoError(t, st.PreExecute(ctx))
assert.True(t, st.isIterator)
assert.True(t, st.GetMvccTimestamp() == 0)
st.resultBuf.Insert(&internalpb.SearchResults{})
st.PostExecute(context.TODO())
assert.Equal(t, st.result.GetSessionTs(), enqueueTs)
})
}
func getQueryCoord() *mocks.MockQueryCoord {
@ -2235,11 +2284,11 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
info, offset, err := parseSearchInfo(test.validParams, nil, nil)
assert.NoError(t, err)
assert.NotNil(t, info)
searchInfo := parseSearchInfo(test.validParams, nil, nil)
assert.NoError(t, searchInfo.parseError)
assert.NotNil(t, searchInfo.planInfo)
if test.description == "offsetParam" {
assert.Equal(t, targetOffset, offset)
assert.Equal(t, targetOffset, searchInfo.offset)
}
})
}
@ -2256,11 +2305,11 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
limit: externalLimit,
}
info, offset, err := parseSearchInfo(offsetParam, nil, rank)
assert.NoError(t, err)
assert.NotNil(t, info)
assert.Equal(t, int64(10), info.GetTopk())
assert.Equal(t, int64(0), offset)
searchInfo := parseSearchInfo(offsetParam, nil, rank)
assert.NoError(t, searchInfo.parseError)
assert.NotNil(t, searchInfo.planInfo)
assert.Equal(t, int64(10), searchInfo.planInfo.GetTopk())
assert.Equal(t, int64(0), searchInfo.offset)
})
t.Run("parseSearchInfo groupBy info for hybrid search", func(t *testing.T) {
@ -2309,15 +2358,15 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
Value: "true",
})
info, _, err := parseSearchInfo(params, schema, testRankParams)
assert.NoError(t, err)
assert.NotNil(t, info)
searchInfo := parseSearchInfo(params, schema, testRankParams)
assert.NoError(t, searchInfo.parseError)
assert.NotNil(t, searchInfo.planInfo)
// all group_by related parameters should be aligned to parameters
// set by main request rather than inner sub request
assert.Equal(t, int64(101), info.GetGroupByFieldId())
assert.Equal(t, int64(3), info.GetGroupSize())
assert.False(t, info.GetGroupStrictSize())
assert.Equal(t, int64(101), searchInfo.planInfo.GetGroupByFieldId())
assert.Equal(t, int64(3), searchInfo.planInfo.GetGroupSize())
assert.False(t, searchInfo.planInfo.GetGroupStrictSize())
})
t.Run("parseSearchInfo error", func(t *testing.T) {
@ -2399,12 +2448,12 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
info, offset, err := parseSearchInfo(test.invalidParams, nil, nil)
assert.Error(t, err)
assert.Nil(t, info)
assert.Zero(t, offset)
searchInfo := parseSearchInfo(test.invalidParams, nil, nil)
assert.Error(t, searchInfo.parseError)
assert.Nil(t, searchInfo.planInfo)
assert.Zero(t, searchInfo.offset)
t.Logf("err=%s", err.Error())
t.Logf("err=%s", searchInfo.parseError)
})
}
})
@ -2426,9 +2475,9 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
schema := &schemapb.CollectionSchema{
Fields: fields,
}
info, _, err := parseSearchInfo(normalParam, schema, nil)
assert.Nil(t, info)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
searchInfo := parseSearchInfo(normalParam, schema, nil)
assert.Nil(t, searchInfo.planInfo)
assert.ErrorIs(t, searchInfo.parseError, merr.ErrParameterInvalid)
})
t.Run("check range-search and groupBy", func(t *testing.T) {
normalParam := getValidSearchParams()
@ -2445,9 +2494,9 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
schema := &schemapb.CollectionSchema{
Fields: fields,
}
info, _, err := parseSearchInfo(normalParam, schema, nil)
assert.Nil(t, info)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
searchInfo := parseSearchInfo(normalParam, schema, nil)
assert.Nil(t, searchInfo.planInfo)
assert.ErrorIs(t, searchInfo.parseError, merr.ErrParameterInvalid)
})
t.Run("check nullable and groupBy", func(t *testing.T) {
normalParam := getValidSearchParams()
@ -2464,9 +2513,9 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
schema := &schemapb.CollectionSchema{
Fields: fields,
}
info, _, err := parseSearchInfo(normalParam, schema, nil)
assert.Nil(t, info)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
searchInfo := parseSearchInfo(normalParam, schema, nil)
assert.Nil(t, searchInfo.planInfo)
assert.ErrorIs(t, searchInfo.parseError, merr.ErrParameterInvalid)
})
t.Run("check iterator and topK", func(t *testing.T) {
normalParam := getValidSearchParams()
@ -2483,10 +2532,10 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
schema := &schemapb.CollectionSchema{
Fields: fields,
}
info, _, err := parseSearchInfo(normalParam, schema, nil)
assert.NotNil(t, info)
assert.NoError(t, err)
assert.Equal(t, Params.QuotaConfig.TopKLimit.GetAsInt64(), info.Topk)
searchInfo := parseSearchInfo(normalParam, schema, nil)
assert.NotNil(t, searchInfo.planInfo)
assert.NoError(t, searchInfo.parseError)
assert.Equal(t, Params.QuotaConfig.TopKLimit.GetAsInt64(), searchInfo.planInfo.GetTopk())
})
t.Run("check max group size", func(t *testing.T) {
@ -2503,15 +2552,15 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
schema := &schemapb.CollectionSchema{
Fields: fields,
}
info, _, err := parseSearchInfo(normalParam, schema, nil)
assert.Nil(t, info)
assert.Error(t, err)
assert.True(t, strings.Contains(err.Error(), "exceeds configured max group size"))
searchInfo := parseSearchInfo(normalParam, schema, nil)
assert.Nil(t, searchInfo.planInfo)
assert.Error(t, searchInfo.parseError)
assert.True(t, strings.Contains(searchInfo.parseError.Error(), "exceeds configured max group size"))
resetSearchParamsValue(normalParam, GroupSizeKey, `10`)
info, _, err = parseSearchInfo(normalParam, schema, nil)
assert.NotNil(t, info)
assert.NoError(t, err)
searchInfo = parseSearchInfo(normalParam, schema, nil)
assert.NotNil(t, searchInfo.planInfo)
assert.NoError(t, searchInfo.parseError)
})
}

View File

@ -439,7 +439,7 @@ func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryReq
// wait tsafe
waitTr := timerecord.NewTimeRecorder("wait tSafe")
tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
tSafe, err := sd.waitTSafe(ctx, req.Req.GetGuaranteeTimestamp())
if err != nil {
log.Warn("delegator query failed to wait tsafe", zap.Error(err))
return err
@ -512,7 +512,7 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest)
// wait tsafe
waitTr := timerecord.NewTimeRecorder("wait tSafe")
tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
tSafe, err := sd.waitTSafe(ctx, req.Req.GetGuaranteeTimestamp())
if err != nil {
log.Warn("delegator query failed to wait tsafe", zap.Error(err))
return nil, err