mirror of https://github.com/milvus-io/milvus.git
related: #39045 Signed-off-by: MrPresent-Han <chun.han@gmail.com> Co-authored-by: MrPresent-Han <chun.han@gmail.com>pull/38799/head
parent
e1f5cb7427
commit
ed31a5a4bf
|
@ -77,10 +77,11 @@ func (r *rankParams) String() string {
|
|||
}
|
||||
|
||||
type SearchInfo struct {
|
||||
planInfo *planpb.QueryInfo
|
||||
offset int64
|
||||
parseError error
|
||||
isIterator bool
|
||||
planInfo *planpb.QueryInfo
|
||||
offset int64
|
||||
parseError error
|
||||
isIterator bool
|
||||
collectionID int64
|
||||
}
|
||||
|
||||
func parseSearchIteratorV2Info(searchParamsPair []*commonpb.KeyValuePair, groupByFieldId int64, isIterator bool, offset int64, queryTopK *int64) (*planpb.SearchIteratorV2Info, error) {
|
||||
|
@ -184,6 +185,9 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
|||
isIteratorStr, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair)
|
||||
isIterator := (isIteratorStr == "True") || (isIteratorStr == "true")
|
||||
|
||||
collectionIDStr, _ := funcutil.GetAttrByKeyFromRepeatedKV(CollectionID, searchParamsPair)
|
||||
collectionId, _ := strconv.ParseInt(collectionIDStr, 0, 64)
|
||||
|
||||
if err := validateLimit(topK); err != nil {
|
||||
if isIterator {
|
||||
// 1. if the request is from iterator, we set topK to QuotaLimit as the iterator can resolve too large topK problem
|
||||
|
@ -289,9 +293,10 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
|||
Hints: hints,
|
||||
SearchIteratorV2Info: planSearchIteratorV2Info,
|
||||
},
|
||||
offset: offset,
|
||||
isIterator: isIterator,
|
||||
parseError: nil,
|
||||
offset: offset,
|
||||
isIterator: isIterator,
|
||||
parseError: nil,
|
||||
collectionID: collectionId,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -55,6 +55,7 @@ const (
|
|||
IgnoreGrowingKey = "ignore_growing"
|
||||
ReduceStopForBestKey = "reduce_stop_for_best"
|
||||
IteratorField = "iterator"
|
||||
CollectionID = "collection_id"
|
||||
GroupByFieldKey = "group_by_field"
|
||||
GroupSizeKey = "group_size"
|
||||
StrictGroupSize = "strict_group_size"
|
||||
|
|
|
@ -76,10 +76,11 @@ type queryTask struct {
|
|||
}
|
||||
|
||||
type queryParams struct {
|
||||
limit int64
|
||||
offset int64
|
||||
reduceType reduce.IReduceType
|
||||
isIterator bool
|
||||
limit int64
|
||||
offset int64
|
||||
reduceType reduce.IReduceType
|
||||
isIterator bool
|
||||
collectionID int64
|
||||
}
|
||||
|
||||
// translateToOutputFieldIDs translates output fields name to output fields id.
|
||||
|
@ -146,6 +147,7 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e
|
|||
reduceStopForBest bool
|
||||
isIterator bool
|
||||
err error
|
||||
collectionID int64
|
||||
)
|
||||
reduceStopForBestStr, err := funcutil.GetAttrByKeyFromRepeatedKV(ReduceStopForBestKey, queryParamsPair)
|
||||
// if reduce_stop_for_best is provided
|
||||
|
@ -167,6 +169,15 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e
|
|||
}
|
||||
}
|
||||
|
||||
collectionIdStr, err := funcutil.GetAttrByKeyFromRepeatedKV(CollectionID, queryParamsPair)
|
||||
if err == nil {
|
||||
collectionID, err = strconv.ParseInt(collectionIdStr, 0, 64)
|
||||
if err != nil {
|
||||
return nil, merr.WrapErrParameterInvalid("int value for collection_id", CollectionID,
|
||||
"value for collection id is invalid")
|
||||
}
|
||||
}
|
||||
|
||||
reduceType := reduce.IReduceNoOrder
|
||||
if isIterator {
|
||||
if reduceStopForBest {
|
||||
|
@ -201,10 +212,11 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e
|
|||
}
|
||||
|
||||
return &queryParams{
|
||||
limit: limit,
|
||||
offset: offset,
|
||||
reduceType: reduceType,
|
||||
isIterator: isIterator,
|
||||
limit: limit,
|
||||
offset: offset,
|
||||
reduceType: reduceType,
|
||||
isIterator: isIterator,
|
||||
collectionID: collectionID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -364,6 +376,10 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if queryParams.collectionID > 0 && queryParams.collectionID != t.GetCollectionID() {
|
||||
return merr.WrapErrAsInputError(merr.WrapErrParameterInvalidMsg("Input collection id is not consistent to collectionID in the context," +
|
||||
"alias or database may have changed"))
|
||||
}
|
||||
if queryParams.reduceType == reduce.IReduceInOrderForBest {
|
||||
t.RetrieveRequest.ReduceStopForBest = true
|
||||
}
|
||||
|
|
|
@ -193,6 +193,29 @@ func TestQueryTask_all(t *testing.T) {
|
|||
Value: "trxxxx",
|
||||
})
|
||||
assert.Error(t, task.PreExecute(ctx))
|
||||
task.request.QueryParams = task.request.QueryParams[0 : len(task.request.QueryParams)-1]
|
||||
|
||||
// check parse collection id
|
||||
task.request.QueryParams = append(task.request.QueryParams, &commonpb.KeyValuePair{
|
||||
Key: CollectionID,
|
||||
Value: "trxxxx",
|
||||
})
|
||||
err := task.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
task.request.QueryParams = task.request.QueryParams[0 : len(task.request.QueryParams)-1]
|
||||
|
||||
// check collection id consistency
|
||||
task.request.QueryParams = append(task.request.QueryParams, &commonpb.KeyValuePair{
|
||||
Key: LimitKey,
|
||||
Value: "11",
|
||||
})
|
||||
task.request.QueryParams = append(task.request.QueryParams, &commonpb.KeyValuePair{
|
||||
Key: CollectionID,
|
||||
Value: "8080",
|
||||
})
|
||||
err = task.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
task.request.QueryParams = make([]*commonpb.KeyValuePair, 0)
|
||||
|
||||
result1 := &internalpb.RetrieveResults{
|
||||
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_RetrieveResult},
|
||||
|
|
|
@ -521,6 +521,11 @@ func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string
|
|||
if searchInfo.parseError != nil {
|
||||
return nil, nil, 0, false, searchInfo.parseError
|
||||
}
|
||||
if searchInfo.collectionID > 0 && searchInfo.collectionID != t.GetCollectionID() {
|
||||
return nil, nil, 0, false, merr.WrapErrParameterInvalidMsg("collection id:%d in the request is not consistent to that in the search context,"+
|
||||
"alias or database may have been changed: %d", searchInfo.collectionID, t.GetCollectionID())
|
||||
}
|
||||
|
||||
annField := typeutil.GetFieldByName(t.schema.CollectionSchema, annsFieldName)
|
||||
if searchInfo.planInfo.GetGroupByFieldId() != -1 && annField.GetDataType() == schemapb.DataType_BinaryVector {
|
||||
return nil, nil, 0, false, errors.New("not support search_group_by operation based on binary vector column")
|
||||
|
|
|
@ -473,6 +473,30 @@ func TestSearchTask_PreExecute(t *testing.T) {
|
|||
st.PostExecute(context.TODO())
|
||||
assert.Equal(t, st.result.GetSessionTs(), enqueueTs)
|
||||
})
|
||||
|
||||
t.Run("search inconsistent collection_id", func(t *testing.T) {
|
||||
collName := "search_inconsistent_collection" + funcutil.GenRandomStr()
|
||||
createColl(t, collName, rc)
|
||||
|
||||
st := getSearchTask(t, collName)
|
||||
st.request.SearchParams = getValidSearchParams()
|
||||
st.request.SearchParams = append(st.request.SearchParams, &commonpb.KeyValuePair{
|
||||
Key: IteratorField,
|
||||
Value: "True",
|
||||
})
|
||||
st.request.SearchParams = append(st.request.SearchParams, &commonpb.KeyValuePair{
|
||||
Key: CollectionID,
|
||||
Value: "8080",
|
||||
})
|
||||
st.request.DslType = commonpb.DslType_BoolExprV1
|
||||
|
||||
_, cancel := context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
require.Equal(t, typeutil.ZeroTimestamp, st.TimeoutTimestamp)
|
||||
enqueueTs := uint64(100000)
|
||||
st.SetTs(enqueueTs)
|
||||
assert.Error(t, st.PreExecute(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func getQueryCoord() *mocks.MockQueryCoord {
|
||||
|
|
Loading…
Reference in New Issue