enhance: Utilize partition key optimization in reQuery (#30253)

See also #30250

This PR add requery flag in query task. When reQuery flag is true, query
task shall skip partition name conversion and use pre-calculated
partitionIDs passed from search task.

TODO: hybrid search does not have partition id information. we shall
apply same logic for hybrid search later.

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/30189/head
congqixia 2024-01-25 11:05:07 +08:00 committed by GitHub
parent 4f25066aa7
commit 8e8ac213aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 28 additions and 19 deletions

View File

@ -305,7 +305,8 @@ func (t *hybridSearchTask) Requery() error {
}
// TODO:Xige-16 refine the mvcc functionality of hybrid search
return doRequery(t.ctx, t.collectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs)
// TODO:silverxia move partitionIDs to hybrid search level
return doRequery(t.ctx, t.collectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, []int64{})
}
func rankSearchResultData(ctx context.Context,

View File

@ -63,6 +63,8 @@ type queryTask struct {
lb LBPolicy
channelsMvcc map[string]Timestamp
fastSkip bool
reQuery bool
}
type queryParams struct {
@ -327,23 +329,26 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
return fmt.Errorf("empty expression should be used with limit")
}
partitionNames := t.request.GetPartitionNames()
if t.partitionKeyMode {
expr, err := ParseExprFromPlan(t.plan)
if err != nil {
return err
}
partitionKeys := ParsePartitionKeys(expr)
hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), t.request.CollectionName, partitionKeys)
if err != nil {
return err
}
// convert partition names only when requery is false
if !t.reQuery {
partitionNames := t.request.GetPartitionNames()
if t.partitionKeyMode {
expr, err := ParseExprFromPlan(t.plan)
if err != nil {
return err
}
partitionKeys := ParsePartitionKeys(expr)
hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), t.request.CollectionName, partitionKeys)
if err != nil {
return err
}
partitionNames = append(partitionNames, hashedPartitionNames...)
}
t.RetrieveRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), t.request.CollectionName, partitionNames)
if err != nil {
return err
partitionNames = append(partitionNames, hashedPartitionNames...)
}
t.RetrieveRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), t.request.CollectionName, partitionNames)
if err != nil {
return err
}
}
// count with pagination

View File

@ -627,7 +627,7 @@ func (t *searchTask) Requery() error {
QueryParams: t.request.GetSearchParams(),
}
return doRequery(t.ctx, t.GetCollectionID(), t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs)
return doRequery(t.ctx, t.GetCollectionID(), t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, t.GetPartitionIDs())
}
func (t *searchTask) fillInEmptyResult(numQueries int64) {
@ -681,6 +681,7 @@ func doRequery(ctx context.Context,
request *milvuspb.QueryRequest,
result *milvuspb.SearchResults,
queryChannelsTs map[string]Timestamp,
partitionIDs []int64,
) error {
outputFields := request.GetOutputFields()
pkField, err := typeutil.GetPrimaryFieldSchema(schema)
@ -701,7 +702,8 @@ func doRequery(ctx context.Context,
commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
ReqID: paramtable.GetNodeID(),
ReqID: paramtable.GetNodeID(),
PartitionIDs: partitionIDs, // use search partitionIDs
},
request: request,
plan: plan,
@ -709,6 +711,7 @@ func doRequery(ctx context.Context,
lb: node.(*Proxy).lbPolicy,
channelsMvcc: channelsMvcc,
fastSkip: true,
reQuery: true,
}
queryResult, err := node.(*Proxy).query(ctx, qt)
if err != nil {