fix: [2.5] Avoid update original search/query request (#41127)

Cherry-pick from master
pr: #41126
Related to #41034

Recent pr #40842 introduced logic to avoid requery pk column, which
updates the original request which makes the request not equavilant to
the original one.

When retry happens due to incomplete request error, this change makes
the final result set lacks the pk column even when user specifies it
explicitly.

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/41149/head
congqixia 2025-04-07 22:13:29 +08:00 committed by GitHub
parent d679195a5a
commit c5f87c1b6d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 23 additions and 20 deletions

View File

@ -59,8 +59,9 @@ type queryTask struct {
queryParams *queryParams
schema *schemaInfo
userOutputFields []string
userDynamicFields []string
translatedOutputFields []string
userOutputFields []string
userDynamicFields []string
resultBuf *typeutil.ConcurrentSet[*internalpb.RetrieveResults]
@ -271,12 +272,12 @@ func (t *queryTask) createPlan(ctx context.Context) error {
metrics.ProxyParseExpressionLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), "query", metrics.SuccessLabel).Observe(float64(time.Since(start).Milliseconds()))
}
t.request.OutputFields, t.userOutputFields, t.userDynamicFields, _, err = translateOutputFields(t.request.OutputFields, t.schema, false)
t.translatedOutputFields, t.userOutputFields, t.userDynamicFields, _, err = translateOutputFields(t.request.OutputFields, t.schema, false)
if err != nil {
return err
}
outputFieldIDs, err := translateToOutputFieldIDs(t.request.GetOutputFields(), schema.CollectionSchema)
outputFieldIDs, err := translateToOutputFieldIDs(t.translatedOutputFields, schema.CollectionSchema)
if err != nil {
return err
}

View File

@ -69,8 +69,9 @@ type searchTask struct {
isTopkReduce bool
isRecallEvaluation bool
userOutputFields []string
userDynamicFields []string
translatedOutputFields []string
userOutputFields []string
userDynamicFields []string
resultBuf *typeutil.ConcurrentSet[*internalpb.SearchResults]
@ -166,13 +167,13 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
}
}
t.request.OutputFields, t.userOutputFields, t.userDynamicFields, t.userRequestedPkFieldExplicitly, err = translateOutputFields(t.request.OutputFields, t.schema, true)
t.translatedOutputFields, t.userOutputFields, t.userDynamicFields, t.userRequestedPkFieldExplicitly, err = translateOutputFields(t.request.OutputFields, t.schema, true)
if err != nil {
log.Warn("translate output fields failed", zap.Error(err))
return err
}
log.Debug("translate output fields",
zap.Strings("output fields", t.request.GetOutputFields()))
zap.Strings("output fields", t.translatedOutputFields))
if t.SearchRequest.GetIsAdvanced() {
if len(t.request.GetSubReqs()) > defaultMaxSearchRequest {
@ -200,7 +201,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
return err
}
outputFieldIDs, err := getOutputFieldIDs(t.schema, t.request.GetOutputFields())
outputFieldIDs, err := getOutputFieldIDs(t.schema, t.translatedOutputFields)
if err != nil {
log.Info("fail to get output field ids", zap.Error(err))
return err
@ -210,11 +211,11 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
// Currently, we get vectors by requery. Once we support getting vectors from search,
// searches with small result size could no longer need requery.
vectorOutputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool {
return lo.Contains(t.request.GetOutputFields(), field.GetName()) && typeutil.IsVectorType(field.GetDataType())
return lo.Contains(t.translatedOutputFields, field.GetName()) && typeutil.IsVectorType(field.GetDataType())
})
if t.SearchRequest.GetIsAdvanced() {
t.requery = len(t.request.OutputFields) > 0
t.requery = len(t.translatedOutputFields) > 0
err = t.initAdvancedSearchRequest(ctx)
} else {
t.requery = len(vectorOutputFields) > 0
@ -881,7 +882,7 @@ func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.Que
func (t *searchTask) estimateResultSize(nq int64, topK int64) (int64, error) {
vectorOutputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool {
return lo.Contains(t.request.GetOutputFields(), field.GetName()) && typeutil.IsVectorType(field.GetDataType())
return lo.Contains(t.translatedOutputFields, field.GetName()) && typeutil.IsVectorType(field.GetDataType())
})
// Currently, we get vectors by requery. Once we support getting vectors from search,
// searches with small result size could no longer need requery.
@ -892,7 +893,7 @@ func (t *searchTask) estimateResultSize(nq int64, topK int64) (int64, error) {
return 0, nil
//outputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool {
// return lo.Contains(t.request.GetOutputFields(), field.GetName())
// return lo.Contains(t.translatedOutputFields, field.GetName())
//})
//sizePerRecord, err := typeutil.EstimateSizePerRecord(&schemapb.CollectionSchema{Fields: outputFields})
//if err != nil {
@ -912,7 +913,7 @@ func (t *searchTask) Requery(span trace.Span) error {
ConsistencyLevel: t.SearchRequest.GetConsistencyLevel(),
NotReturnAllMeta: t.request.GetNotReturnAllMeta(),
Expr: "",
OutputFields: t.request.GetOutputFields(),
OutputFields: t.translatedOutputFields,
PartitionNames: t.request.GetPartitionNames(),
UseDefaultConsistency: false,
GuaranteeTimestamp: t.SearchRequest.GuaranteeTimestamp,
@ -992,14 +993,14 @@ func (t *searchTask) Requery(span trace.Span) error {
}
t.result.Results.FieldsData = lo.Filter(t.result.Results.FieldsData, func(fieldData *schemapb.FieldData, i int) bool {
return lo.Contains(t.request.GetOutputFields(), fieldData.GetFieldName())
return lo.Contains(t.translatedOutputFields, fieldData.GetFieldName())
})
return nil
}
func (t *searchTask) fillInFieldInfo() {
if len(t.request.OutputFields) != 0 && len(t.result.Results.FieldsData) != 0 {
for i, name := range t.request.OutputFields {
if len(t.translatedOutputFields) != 0 && len(t.result.Results.FieldsData) != 0 {
for i, name := range t.translatedOutputFields {
for _, field := range t.schema.Fields {
if t.result.Results.FieldsData[i] != nil && field.Name == name {
t.result.Results.FieldsData[i].FieldName = field.Name

View File

@ -3032,9 +3032,10 @@ func TestSearchTask_Requery(t *testing.T) {
Ids: resultIDs,
},
},
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
translatedOutputFields: outputFields,
}
err := qt.Requery(nil)