enhance: fix inconsistency of alias and db for query iterator (#39045) (#39216)

related: #39045

Signed-off-by: MrPresent-Han <chun.han@gmail.com>
Co-authored-by: MrPresent-Han <chun.han@gmail.com>
pull/38799/head
Chun Han 2025-01-15 09:48:59 +08:00 committed by GitHub
parent e1f5cb7427
commit ed31a5a4bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 89 additions and 15 deletions

View File

@ -77,10 +77,11 @@ func (r *rankParams) String() string {
}
type SearchInfo struct {
planInfo *planpb.QueryInfo
offset int64
parseError error
isIterator bool
planInfo *planpb.QueryInfo
offset int64
parseError error
isIterator bool
collectionID int64
}
func parseSearchIteratorV2Info(searchParamsPair []*commonpb.KeyValuePair, groupByFieldId int64, isIterator bool, offset int64, queryTopK *int64) (*planpb.SearchIteratorV2Info, error) {
@ -184,6 +185,9 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
isIteratorStr, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair)
isIterator := (isIteratorStr == "True") || (isIteratorStr == "true")
collectionIDStr, _ := funcutil.GetAttrByKeyFromRepeatedKV(CollectionID, searchParamsPair)
collectionId, _ := strconv.ParseInt(collectionIDStr, 0, 64)
if err := validateLimit(topK); err != nil {
if isIterator {
// 1. if the request is from iterator, we set topK to QuotaLimit as the iterator can resolve too large topK problem
@ -289,9 +293,10 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
Hints: hints,
SearchIteratorV2Info: planSearchIteratorV2Info,
},
offset: offset,
isIterator: isIterator,
parseError: nil,
offset: offset,
isIterator: isIterator,
parseError: nil,
collectionID: collectionId,
}
}

View File

@ -55,6 +55,7 @@ const (
IgnoreGrowingKey = "ignore_growing"
ReduceStopForBestKey = "reduce_stop_for_best"
IteratorField = "iterator"
CollectionID = "collection_id"
GroupByFieldKey = "group_by_field"
GroupSizeKey = "group_size"
StrictGroupSize = "strict_group_size"

View File

@ -76,10 +76,11 @@ type queryTask struct {
}
type queryParams struct {
limit int64
offset int64
reduceType reduce.IReduceType
isIterator bool
limit int64
offset int64
reduceType reduce.IReduceType
isIterator bool
collectionID int64
}
// translateToOutputFieldIDs translates output fields name to output fields id.
@ -146,6 +147,7 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e
reduceStopForBest bool
isIterator bool
err error
collectionID int64
)
reduceStopForBestStr, err := funcutil.GetAttrByKeyFromRepeatedKV(ReduceStopForBestKey, queryParamsPair)
// if reduce_stop_for_best is provided
@ -167,6 +169,15 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e
}
}
collectionIdStr, err := funcutil.GetAttrByKeyFromRepeatedKV(CollectionID, queryParamsPair)
if err == nil {
collectionID, err = strconv.ParseInt(collectionIdStr, 0, 64)
if err != nil {
return nil, merr.WrapErrParameterInvalid("int value for collection_id", CollectionID,
"value for collection id is invalid")
}
}
reduceType := reduce.IReduceNoOrder
if isIterator {
if reduceStopForBest {
@ -201,10 +212,11 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e
}
return &queryParams{
limit: limit,
offset: offset,
reduceType: reduceType,
isIterator: isIterator,
limit: limit,
offset: offset,
reduceType: reduceType,
isIterator: isIterator,
collectionID: collectionID,
}, nil
}
@ -364,6 +376,10 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
if err != nil {
return err
}
if queryParams.collectionID > 0 && queryParams.collectionID != t.GetCollectionID() {
return merr.WrapErrAsInputError(merr.WrapErrParameterInvalidMsg("Input collection id is not consistent to collectionID in the context," +
"alias or database may have changed"))
}
if queryParams.reduceType == reduce.IReduceInOrderForBest {
t.RetrieveRequest.ReduceStopForBest = true
}

View File

@ -193,6 +193,29 @@ func TestQueryTask_all(t *testing.T) {
Value: "trxxxx",
})
assert.Error(t, task.PreExecute(ctx))
task.request.QueryParams = task.request.QueryParams[0 : len(task.request.QueryParams)-1]
// check parse collection id
task.request.QueryParams = append(task.request.QueryParams, &commonpb.KeyValuePair{
Key: CollectionID,
Value: "trxxxx",
})
err := task.PreExecute(ctx)
assert.Error(t, err)
task.request.QueryParams = task.request.QueryParams[0 : len(task.request.QueryParams)-1]
// check collection id consistency
task.request.QueryParams = append(task.request.QueryParams, &commonpb.KeyValuePair{
Key: LimitKey,
Value: "11",
})
task.request.QueryParams = append(task.request.QueryParams, &commonpb.KeyValuePair{
Key: CollectionID,
Value: "8080",
})
err = task.PreExecute(ctx)
assert.Error(t, err)
task.request.QueryParams = make([]*commonpb.KeyValuePair, 0)
result1 := &internalpb.RetrieveResults{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_RetrieveResult},

View File

@ -521,6 +521,11 @@ func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string
if searchInfo.parseError != nil {
return nil, nil, 0, false, searchInfo.parseError
}
if searchInfo.collectionID > 0 && searchInfo.collectionID != t.GetCollectionID() {
return nil, nil, 0, false, merr.WrapErrParameterInvalidMsg("collection id:%d in the request is not consistent to that in the search context,"+
"alias or database may have been changed: %d", searchInfo.collectionID, t.GetCollectionID())
}
annField := typeutil.GetFieldByName(t.schema.CollectionSchema, annsFieldName)
if searchInfo.planInfo.GetGroupByFieldId() != -1 && annField.GetDataType() == schemapb.DataType_BinaryVector {
return nil, nil, 0, false, errors.New("not support search_group_by operation based on binary vector column")

View File

@ -473,6 +473,30 @@ func TestSearchTask_PreExecute(t *testing.T) {
st.PostExecute(context.TODO())
assert.Equal(t, st.result.GetSessionTs(), enqueueTs)
})
t.Run("search inconsistent collection_id", func(t *testing.T) {
collName := "search_inconsistent_collection" + funcutil.GenRandomStr()
createColl(t, collName, rc)
st := getSearchTask(t, collName)
st.request.SearchParams = getValidSearchParams()
st.request.SearchParams = append(st.request.SearchParams, &commonpb.KeyValuePair{
Key: IteratorField,
Value: "True",
})
st.request.SearchParams = append(st.request.SearchParams, &commonpb.KeyValuePair{
Key: CollectionID,
Value: "8080",
})
st.request.DslType = commonpb.DslType_BoolExprV1
_, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
require.Equal(t, typeutil.ZeroTimestamp, st.TimeoutTimestamp)
enqueueTs := uint64(100000)
st.SetTs(enqueueTs)
assert.Error(t, st.PreExecute(ctx))
})
}
func getQueryCoord() *mocks.MockQueryCoord {