enhance: Utilize partition key optimization in reQuery (#30253)

See also #30250

This PR adds a reQuery flag to the query task. When the reQuery flag is true, the query
task skips partition name conversion and uses the pre-calculated
partitionIDs passed from the search task.

TODO: hybrid search does not have partition ID information yet. We shall
apply the same logic to hybrid search later.

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/30189/head
congqixia 2024-01-25 11:05:07 +08:00 committed by GitHub
parent 4f25066aa7
commit 8e8ac213aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 28 additions and 19 deletions

View File

@ -305,7 +305,8 @@ func (t *hybridSearchTask) Requery() error {
} }
// TODO:Xige-16 refine the mvcc functionality of hybrid search // TODO:Xige-16 refine the mvcc functionality of hybrid search
return doRequery(t.ctx, t.collectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs) // TODO:silverxia move partitionIDs to hybrid search level
return doRequery(t.ctx, t.collectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, []int64{})
} }
func rankSearchResultData(ctx context.Context, func rankSearchResultData(ctx context.Context,

View File

@ -63,6 +63,8 @@ type queryTask struct {
lb LBPolicy lb LBPolicy
channelsMvcc map[string]Timestamp channelsMvcc map[string]Timestamp
fastSkip bool fastSkip bool
reQuery bool
} }
type queryParams struct { type queryParams struct {
@ -327,23 +329,26 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
return fmt.Errorf("empty expression should be used with limit") return fmt.Errorf("empty expression should be used with limit")
} }
partitionNames := t.request.GetPartitionNames() // convert partition names only when requery is false
if t.partitionKeyMode { if !t.reQuery {
expr, err := ParseExprFromPlan(t.plan) partitionNames := t.request.GetPartitionNames()
if err != nil { if t.partitionKeyMode {
return err expr, err := ParseExprFromPlan(t.plan)
} if err != nil {
partitionKeys := ParsePartitionKeys(expr) return err
hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), t.request.CollectionName, partitionKeys) }
if err != nil { partitionKeys := ParsePartitionKeys(expr)
return err hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), t.request.CollectionName, partitionKeys)
} if err != nil {
return err
}
partitionNames = append(partitionNames, hashedPartitionNames...) partitionNames = append(partitionNames, hashedPartitionNames...)
} }
t.RetrieveRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), t.request.CollectionName, partitionNames) t.RetrieveRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), t.request.CollectionName, partitionNames)
if err != nil { if err != nil {
return err return err
}
} }
// count with pagination // count with pagination

View File

@ -627,7 +627,7 @@ func (t *searchTask) Requery() error {
QueryParams: t.request.GetSearchParams(), QueryParams: t.request.GetSearchParams(),
} }
return doRequery(t.ctx, t.GetCollectionID(), t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs) return doRequery(t.ctx, t.GetCollectionID(), t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, t.GetPartitionIDs())
} }
func (t *searchTask) fillInEmptyResult(numQueries int64) { func (t *searchTask) fillInEmptyResult(numQueries int64) {
@ -681,6 +681,7 @@ func doRequery(ctx context.Context,
request *milvuspb.QueryRequest, request *milvuspb.QueryRequest,
result *milvuspb.SearchResults, result *milvuspb.SearchResults,
queryChannelsTs map[string]Timestamp, queryChannelsTs map[string]Timestamp,
partitionIDs []int64,
) error { ) error {
outputFields := request.GetOutputFields() outputFields := request.GetOutputFields()
pkField, err := typeutil.GetPrimaryFieldSchema(schema) pkField, err := typeutil.GetPrimaryFieldSchema(schema)
@ -701,7 +702,8 @@ func doRequery(ctx context.Context,
commonpbutil.WithMsgType(commonpb.MsgType_Retrieve), commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
commonpbutil.WithSourceID(paramtable.GetNodeID()), commonpbutil.WithSourceID(paramtable.GetNodeID()),
), ),
ReqID: paramtable.GetNodeID(), ReqID: paramtable.GetNodeID(),
PartitionIDs: partitionIDs, // use search partitionIDs
}, },
request: request, request: request,
plan: plan, plan: plan,
@ -709,6 +711,7 @@ func doRequery(ctx context.Context,
lb: node.(*Proxy).lbPolicy, lb: node.(*Proxy).lbPolicy,
channelsMvcc: channelsMvcc, channelsMvcc: channelsMvcc,
fastSkip: true, fastSkip: true,
reQuery: true,
} }
queryResult, err := node.(*Proxy).query(ctx, qt) queryResult, err := node.(*Proxy).query(ctx, qt)
if err != nil { if err != nil {