mirror of https://github.com/milvus-io/milvus.git
enhance: Utilize partition key optimization in reQuery (#30253)
See also #30250 This PR add requery flag in query task. When reQuery flag is true, query task shall skip partition name conversion and use pre-calculated partitionIDs passed from search task. TODO: hybrid search does not have partition id information. we shall apply same logic for hybrid search later. Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>pull/30189/head
parent
4f25066aa7
commit
8e8ac213aa
|
@ -305,7 +305,8 @@ func (t *hybridSearchTask) Requery() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO:Xige-16 refine the mvcc functionality of hybrid search
|
// TODO:Xige-16 refine the mvcc functionality of hybrid search
|
||||||
return doRequery(t.ctx, t.collectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs)
|
// TODO:silverxia move partitionIDs to hybrid search level
|
||||||
|
return doRequery(t.ctx, t.collectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, []int64{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func rankSearchResultData(ctx context.Context,
|
func rankSearchResultData(ctx context.Context,
|
||||||
|
|
|
@ -63,6 +63,8 @@ type queryTask struct {
|
||||||
lb LBPolicy
|
lb LBPolicy
|
||||||
channelsMvcc map[string]Timestamp
|
channelsMvcc map[string]Timestamp
|
||||||
fastSkip bool
|
fastSkip bool
|
||||||
|
|
||||||
|
reQuery bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type queryParams struct {
|
type queryParams struct {
|
||||||
|
@ -327,23 +329,26 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
|
||||||
return fmt.Errorf("empty expression should be used with limit")
|
return fmt.Errorf("empty expression should be used with limit")
|
||||||
}
|
}
|
||||||
|
|
||||||
partitionNames := t.request.GetPartitionNames()
|
// convert partition names only when requery is false
|
||||||
if t.partitionKeyMode {
|
if !t.reQuery {
|
||||||
expr, err := ParseExprFromPlan(t.plan)
|
partitionNames := t.request.GetPartitionNames()
|
||||||
if err != nil {
|
if t.partitionKeyMode {
|
||||||
return err
|
expr, err := ParseExprFromPlan(t.plan)
|
||||||
}
|
if err != nil {
|
||||||
partitionKeys := ParsePartitionKeys(expr)
|
return err
|
||||||
hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), t.request.CollectionName, partitionKeys)
|
}
|
||||||
if err != nil {
|
partitionKeys := ParsePartitionKeys(expr)
|
||||||
return err
|
hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), t.request.CollectionName, partitionKeys)
|
||||||
}
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
partitionNames = append(partitionNames, hashedPartitionNames...)
|
partitionNames = append(partitionNames, hashedPartitionNames...)
|
||||||
}
|
}
|
||||||
t.RetrieveRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), t.request.CollectionName, partitionNames)
|
t.RetrieveRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), t.request.CollectionName, partitionNames)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// count with pagination
|
// count with pagination
|
||||||
|
|
|
@ -627,7 +627,7 @@ func (t *searchTask) Requery() error {
|
||||||
QueryParams: t.request.GetSearchParams(),
|
QueryParams: t.request.GetSearchParams(),
|
||||||
}
|
}
|
||||||
|
|
||||||
return doRequery(t.ctx, t.GetCollectionID(), t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs)
|
return doRequery(t.ctx, t.GetCollectionID(), t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, t.GetPartitionIDs())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *searchTask) fillInEmptyResult(numQueries int64) {
|
func (t *searchTask) fillInEmptyResult(numQueries int64) {
|
||||||
|
@ -681,6 +681,7 @@ func doRequery(ctx context.Context,
|
||||||
request *milvuspb.QueryRequest,
|
request *milvuspb.QueryRequest,
|
||||||
result *milvuspb.SearchResults,
|
result *milvuspb.SearchResults,
|
||||||
queryChannelsTs map[string]Timestamp,
|
queryChannelsTs map[string]Timestamp,
|
||||||
|
partitionIDs []int64,
|
||||||
) error {
|
) error {
|
||||||
outputFields := request.GetOutputFields()
|
outputFields := request.GetOutputFields()
|
||||||
pkField, err := typeutil.GetPrimaryFieldSchema(schema)
|
pkField, err := typeutil.GetPrimaryFieldSchema(schema)
|
||||||
|
@ -701,7 +702,8 @@ func doRequery(ctx context.Context,
|
||||||
commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
|
commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
|
||||||
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
||||||
),
|
),
|
||||||
ReqID: paramtable.GetNodeID(),
|
ReqID: paramtable.GetNodeID(),
|
||||||
|
PartitionIDs: partitionIDs, // use search partitionIDs
|
||||||
},
|
},
|
||||||
request: request,
|
request: request,
|
||||||
plan: plan,
|
plan: plan,
|
||||||
|
@ -709,6 +711,7 @@ func doRequery(ctx context.Context,
|
||||||
lb: node.(*Proxy).lbPolicy,
|
lb: node.(*Proxy).lbPolicy,
|
||||||
channelsMvcc: channelsMvcc,
|
channelsMvcc: channelsMvcc,
|
||||||
fastSkip: true,
|
fastSkip: true,
|
||||||
|
reQuery: true,
|
||||||
}
|
}
|
||||||
queryResult, err := node.(*Proxy).query(ctx, qt)
|
queryResult, err := node.(*Proxy).query(ctx, qt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Reference in New Issue