diff --git a/internal/proxy/task_hybrid_search.go b/internal/proxy/task_hybrid_search.go index 05f1fd76e2..662a0fc07a 100644 --- a/internal/proxy/task_hybrid_search.go +++ b/internal/proxy/task_hybrid_search.go @@ -305,7 +305,8 @@ func (t *hybridSearchTask) Requery() error { } // TODO:Xige-16 refine the mvcc functionality of hybrid search - return doRequery(t.ctx, t.collectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs) + // TODO:silverxia move partitionIDs to hybrid search level + return doRequery(t.ctx, t.collectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, []int64{}) } func rankSearchResultData(ctx context.Context, diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index 00b85ae437..699bd1ea1b 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -63,6 +63,8 @@ type queryTask struct { lb LBPolicy channelsMvcc map[string]Timestamp fastSkip bool + + reQuery bool } type queryParams struct { @@ -327,23 +329,26 @@ func (t *queryTask) PreExecute(ctx context.Context) error { return fmt.Errorf("empty expression should be used with limit") } - partitionNames := t.request.GetPartitionNames() - if t.partitionKeyMode { - expr, err := ParseExprFromPlan(t.plan) - if err != nil { - return err - } - partitionKeys := ParsePartitionKeys(expr) - hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), t.request.CollectionName, partitionKeys) - if err != nil { - return err - } + // convert partition names only when requery is false + if !t.reQuery { + partitionNames := t.request.GetPartitionNames() + if t.partitionKeyMode { + expr, err := ParseExprFromPlan(t.plan) + if err != nil { + return err + } + partitionKeys := ParsePartitionKeys(expr) + hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), t.request.CollectionName, partitionKeys) + if err != nil { + return err + } - partitionNames = append(partitionNames, hashedPartitionNames...) - } - t.RetrieveRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), t.request.CollectionName, partitionNames) - if err != nil { - return err + partitionNames = append(partitionNames, hashedPartitionNames...) + } + t.RetrieveRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), t.request.CollectionName, partitionNames) + if err != nil { + return err + } } // count with pagination diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 88d1019882..60af81cbed 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -627,7 +627,7 @@ func (t *searchTask) Requery() error { QueryParams: t.request.GetSearchParams(), } - return doRequery(t.ctx, t.GetCollectionID(), t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs) + return doRequery(t.ctx, t.GetCollectionID(), t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, t.GetPartitionIDs()) } func (t *searchTask) fillInEmptyResult(numQueries int64) { @@ -681,6 +681,7 @@ func doRequery(ctx context.Context, request *milvuspb.QueryRequest, result *milvuspb.SearchResults, queryChannelsTs map[string]Timestamp, + partitionIDs []int64, ) error { outputFields := request.GetOutputFields() pkField, err := typeutil.GetPrimaryFieldSchema(schema) @@ -701,7 +702,8 @@ func doRequery(ctx context.Context, commonpbutil.WithMsgType(commonpb.MsgType_Retrieve), commonpbutil.WithSourceID(paramtable.GetNodeID()), ), - ReqID: paramtable.GetNodeID(), + ReqID: paramtable.GetNodeID(), + PartitionIDs: partitionIDs, // use search partitionIDs }, request: request, plan: plan, @@ -709,6 +711,7 @@ func doRequery(ctx context.Context, lb: node.(*Proxy).lbPolicy, channelsMvcc: channelsMvcc, fastSkip: true, + reQuery: true, } queryResult, err := node.(*Proxy).query(ctx, qt) if err != nil {