mirror of https://github.com/milvus-io/milvus.git
enhance: Utilize partition key optimization in reQuery (#30253)
See also #30250 This PR add requery flag in query task. When reQuery flag is true, query task shall skip partition name conversion and use pre-calculated partitionIDs passed from search task. TODO: hybrid search does not have partition id information. we shall apply same logic for hybrid search later. Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>pull/30189/head
parent
4f25066aa7
commit
8e8ac213aa
|
@ -305,7 +305,8 @@ func (t *hybridSearchTask) Requery() error {
|
|||
}
|
||||
|
||||
// TODO:Xige-16 refine the mvcc functionality of hybrid search
|
||||
return doRequery(t.ctx, t.collectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs)
|
||||
// TODO:silverxia move partitionIDs to hybrid search level
|
||||
return doRequery(t.ctx, t.collectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, []int64{})
|
||||
}
|
||||
|
||||
func rankSearchResultData(ctx context.Context,
|
||||
|
|
|
@ -63,6 +63,8 @@ type queryTask struct {
|
|||
lb LBPolicy
|
||||
channelsMvcc map[string]Timestamp
|
||||
fastSkip bool
|
||||
|
||||
reQuery bool
|
||||
}
|
||||
|
||||
type queryParams struct {
|
||||
|
@ -327,23 +329,26 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
|
|||
return fmt.Errorf("empty expression should be used with limit")
|
||||
}
|
||||
|
||||
partitionNames := t.request.GetPartitionNames()
|
||||
if t.partitionKeyMode {
|
||||
expr, err := ParseExprFromPlan(t.plan)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
partitionKeys := ParsePartitionKeys(expr)
|
||||
hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), t.request.CollectionName, partitionKeys)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// convert partition names only when requery is false
|
||||
if !t.reQuery {
|
||||
partitionNames := t.request.GetPartitionNames()
|
||||
if t.partitionKeyMode {
|
||||
expr, err := ParseExprFromPlan(t.plan)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
partitionKeys := ParsePartitionKeys(expr)
|
||||
hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), t.request.CollectionName, partitionKeys)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
partitionNames = append(partitionNames, hashedPartitionNames...)
|
||||
}
|
||||
t.RetrieveRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), t.request.CollectionName, partitionNames)
|
||||
if err != nil {
|
||||
return err
|
||||
partitionNames = append(partitionNames, hashedPartitionNames...)
|
||||
}
|
||||
t.RetrieveRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), t.request.CollectionName, partitionNames)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// count with pagination
|
||||
|
|
|
@ -627,7 +627,7 @@ func (t *searchTask) Requery() error {
|
|||
QueryParams: t.request.GetSearchParams(),
|
||||
}
|
||||
|
||||
return doRequery(t.ctx, t.GetCollectionID(), t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs)
|
||||
return doRequery(t.ctx, t.GetCollectionID(), t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, t.GetPartitionIDs())
|
||||
}
|
||||
|
||||
func (t *searchTask) fillInEmptyResult(numQueries int64) {
|
||||
|
@ -681,6 +681,7 @@ func doRequery(ctx context.Context,
|
|||
request *milvuspb.QueryRequest,
|
||||
result *milvuspb.SearchResults,
|
||||
queryChannelsTs map[string]Timestamp,
|
||||
partitionIDs []int64,
|
||||
) error {
|
||||
outputFields := request.GetOutputFields()
|
||||
pkField, err := typeutil.GetPrimaryFieldSchema(schema)
|
||||
|
@ -701,7 +702,8 @@ func doRequery(ctx context.Context,
|
|||
commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
|
||||
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
||||
),
|
||||
ReqID: paramtable.GetNodeID(),
|
||||
ReqID: paramtable.GetNodeID(),
|
||||
PartitionIDs: partitionIDs, // use search partitionIDs
|
||||
},
|
||||
request: request,
|
||||
plan: plan,
|
||||
|
@ -709,6 +711,7 @@ func doRequery(ctx context.Context,
|
|||
lb: node.(*Proxy).lbPolicy,
|
||||
channelsMvcc: channelsMvcc,
|
||||
fastSkip: true,
|
||||
reQuery: true,
|
||||
}
|
||||
queryResult, err := node.(*Proxy).query(ctx, qt)
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in New Issue