From f9c630247d2f6406f49a17ec2fae68b9a888d322 Mon Sep 17 00:00:00 2001 From: "yihao.dai" <yihao.dai@zilliz.com> Date: Thu, 26 Oct 2023 19:30:10 +0800 Subject: [PATCH] Construct plan directly when search with vector output (#27928) Signed-off-by: bigsheeper <yihao.dai@zilliz.com> --- .../parser/planparserv2/plan_parser_v2.go | 46 ++++ internal/proxy/impl.go | 37 ++-- internal/proxy/task_query.go | 12 +- internal/proxy/task_search.go | 21 +- internal/proxy/task_search_test.go | 198 +++++++----------- internal/proxy/task_test.go | 4 +- pkg/util/typeutil/schema.go | 2 +- 7 files changed, 169 insertions(+), 151 deletions(-) diff --git a/internal/parser/planparserv2/plan_parser_v2.go b/internal/parser/planparserv2/plan_parser_v2.go index bd7feb3b58..2b93154687 100644 --- a/internal/parser/planparserv2/plan_parser_v2.go +++ b/internal/parser/planparserv2/plan_parser_v2.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/antlr/antlr4/runtime/Go/antlr" + "github.com/samber/lo" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" @@ -158,3 +159,48 @@ func CreateSearchPlan(schemaPb *schemapb.CollectionSchema, exprStr string, vecto } return planNode, nil } + +func CreateRequeryPlan(pkField *schemapb.FieldSchema, ids *schemapb.IDs) *planpb.PlanNode { + var values []*planpb.GenericValue + switch ids.GetIdField().(type) { + case *schemapb.IDs_IntId: + values = lo.Map(ids.GetIntId().GetData(), func(id int64, _ int) *planpb.GenericValue { + return &planpb.GenericValue{ + Val: &planpb.GenericValue_Int64Val{ + Int64Val: id, + }, + } + }) + case *schemapb.IDs_StrId: + values = lo.Map(ids.GetStrId().GetData(), func(id string, _ int) *planpb.GenericValue { + return &planpb.GenericValue{ + Val: &planpb.GenericValue_StringVal{ + StringVal: id, + }, + } + }) + } + + return &planpb.PlanNode{ + Node: &planpb.PlanNode_Query{ + Query: &planpb.QueryPlanNode{ + Predicates: &planpb.Expr{ + Expr: &planpb.Expr_TermExpr{ + TermExpr: &planpb.TermExpr{ + ColumnInfo: &planpb.ColumnInfo{ + FieldId: pkField.GetFieldID(), + DataType: pkField.GetDataType(), + IsPrimaryKey: true, + IsAutoID: pkField.GetAutoID(), + IsPartitionKey: pkField.GetIsPartitionKey(), + }, + Values: values, + }, + }, + }, + IsCount: false, + Limit: int64(len(values)), + }, + }, + } +} diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index ff9e578d96..21f7f18ee5 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -2774,7 +2774,8 @@ func (node *Proxy) Flush(ctx context.Context, request *milvuspb.FlushRequest) (* } // Query get the records by primary keys. -func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) { +func (node *Proxy) query(ctx context.Context, qt *queryTask) (*milvuspb.QueryResults, error) { + request := qt.request receiveSize := proto.Size(request) metrics.ProxyReceiveBytes.WithLabelValues( strconv.FormatInt(paramtable.GetNodeID(), 10), @@ -2800,21 +2801,6 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (* defer sp.End() tr := timerecord.NewTimeRecorder("Query") - qt := &queryTask{ - ctx: ctx, - Condition: NewTaskCondition(ctx), - RetrieveRequest: &internalpb.RetrieveRequest{ - Base: commonpbutil.NewMsgBase( - commonpbutil.WithMsgType(commonpb.MsgType_Retrieve), - commonpbutil.WithSourceID(paramtable.GetNodeID()), - ), - ReqID: paramtable.GetNodeID(), - }, - request: request, - qc: node.queryCoord, - lb: node.lbPolicy, - } - method := "Query" metrics.ProxyFunctionCall.WithLabelValues( @@ -2915,6 +2901,25 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (* return qt.result, nil } +// Query get the records by primary keys. +func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) { + qt := &queryTask{ + ctx: ctx, + Condition: NewTaskCondition(ctx), + RetrieveRequest: &internalpb.RetrieveRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_Retrieve), + commonpbutil.WithSourceID(paramtable.GetNodeID()), + ), + ReqID: paramtable.GetNodeID(), + }, + request: request, + qc: node.queryCoord, + lb: node.lbPolicy, + } + return node.query(ctx, qt) +} + // CreateAlias create alias for collection, then you can search the collection with alias. func (node *Proxy) CreateAlias(ctx context.Context, request *milvuspb.CreateAliasRequest) (*commonpb.Status, error) { if err := merr.CheckHealthy(node.GetStateCode()); err != nil { diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index ce13ee9885..8e6a0bec76 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -210,9 +210,12 @@ func (t *queryTask) createPlan(ctx context.Context) error { return err } - plan, err := planparserv2.CreateRetrievePlan(schema, t.request.Expr) - if err != nil { - return err + var err error + if t.plan == nil { + t.plan, err = planparserv2.CreateRetrievePlan(schema, t.request.Expr) + if err != nil { + return err + } } t.request.OutputFields, t.userOutputFields, err = translateOutputFields(t.request.OutputFields, schema, true) @@ -226,8 +229,7 @@ func (t *queryTask) createPlan(ctx context.Context) error { } outputFieldIDs = append(outputFieldIDs, common.TimeStampField) t.RetrieveRequest.OutputFieldsId = outputFieldIDs - plan.OutputFieldIds = outputFieldIDs - t.plan = plan + t.plan.OutputFieldIds = outputFieldIDs log.Ctx(ctx).Debug("translate output fields to field ids", zap.Int64s("OutputFieldsID", t.OutputFieldsId), zap.String("requestType", "query")) diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index b5f7493914..19c5498017 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -573,7 +573,7 @@ func (t *searchTask) Requery() error { return err } ids := t.result.GetResults().GetIds() - expr := IDs2Expr(pkField.GetName(), ids) + plan := planparserv2.CreateRequeryPlan(pkField, ids) queryReq := &milvuspb.QueryRequest{ Base: &commonpb.MsgBase{ @@ -581,13 +581,28 @@ func (t *searchTask) Requery() error { }, DbName: t.request.GetDbName(), CollectionName: t.request.GetCollectionName(), - Expr: expr, + Expr: "", OutputFields: t.request.GetOutputFields(), PartitionNames: t.request.GetPartitionNames(), GuaranteeTimestamp: t.request.GetGuaranteeTimestamp(), QueryParams: t.request.GetSearchParams(), } - queryResult, err := t.node.Query(t.ctx, queryReq) + qt := &queryTask{ + ctx: t.ctx, + Condition: NewTaskCondition(t.ctx), + RetrieveRequest: &internalpb.RetrieveRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_Retrieve), + commonpbutil.WithSourceID(paramtable.GetNodeID()), + ), + ReqID: paramtable.GetNodeID(), + }, + request: queryReq, + plan: plan, + qc: t.node.(*Proxy).queryCoord, + lb: t.node.(*Proxy).lbPolicy, + } + queryResult, err := t.node.(*Proxy).query(t.ctx, qt) if err != nil { return err } diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index d4edc62269..c44e685db6 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -36,6 +36,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" @@ -1906,11 +1907,46 @@ func TestSearchTask_Requery(t *testing.T) { ids[i] = int64(i) } + factory := dependency.NewDefaultFactory(true) + node, err := NewProxy(ctx, factory) + assert.NoError(t, err) + node.UpdateStateCode(commonpb.StateCode_Healthy) + node.tsoAllocator = ×tampAllocator{ + tso: newMockTimestampAllocatorInterface(), + } + scheduler, err := newTaskScheduler(ctx, node.tsoAllocator, factory) + assert.NoError(t, err) + node.sched = scheduler + err = node.sched.Start() + assert.NoError(t, err) + err = node.initRateCollector() + assert.NoError(t, err) + node.rootCoord = mocks.NewMockRootCoordClient(t) + node.queryCoord = mocks.NewMockQueryCoordClient(t) + + collectionName := "col" + collectionID := UniqueID(0) + cache := NewMockCache(t) + cache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(collectionID, nil).Maybe() + cache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(constructCollectionSchema(pkField, vecField, dim, collection), nil).Maybe() + cache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything).Return(map[string]int64{"_default": UniqueID(1)}, nil).Maybe() + cache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionBasicInfo{}, nil).Maybe() + cache.EXPECT().GetShards(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[string][]nodeInfo{}, nil).Maybe() + cache.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe() + globalMetaCache = cache + t.Run("Test normal", func(t *testing.T) { schema := constructCollectionSchema(pkField, vecField, dim, collection) - node := mocks.NewMockProxy(t) - node.EXPECT().Query(mock.Anything, mock.Anything). - Return(&milvuspb.QueryResults{ + qn := mocks.NewMockQueryNodeClient(t) + qn.EXPECT().Query(mock.Anything, mock.Anything). + Return(&internalpb.RetrieveResults{ + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: ids, + }, + }, + }, FieldsData: []*schemapb.FieldData{ { Type: schemapb.DataType_Int64, @@ -1929,6 +1965,14 @@ func TestSearchTask_Requery(t *testing.T) { }, }, nil) + lb := NewMockLBPolicy(t) + lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload CollectionWorkLoad) { + err = workload.exec(ctx, 0, qn) + assert.NoError(t, err) + }).Return(nil) + lb.EXPECT().UpdateCostMetrics(mock.Anything, mock.Anything).Return() + node.lbPolicy = lb + resultIDs := &schemapb.IDs{ IdField: &schemapb.IDs_IntId{ IntId: &schemapb.LongArray{ @@ -1937,7 +1981,7 @@ func TestSearchTask_Requery(t *testing.T) { }, } - outputFields := []string{vecField} + outputFields := []string{pkField, vecField} qt := &searchTask{ ctx: ctx, SearchRequest: &internalpb.SearchRequest{ @@ -1947,7 +1991,8 @@ func TestSearchTask_Requery(t *testing.T) { }, }, request: &milvuspb.SearchRequest{ - OutputFields: outputFields, + CollectionName: collectionName, + OutputFields: outputFields, }, result: &milvuspb.SearchResults{ Results: &schemapb.SearchResultData{ @@ -1961,8 +2006,9 @@ func TestSearchTask_Requery(t *testing.T) { err := qt.Requery() assert.NoError(t, err) - assert.Len(t, qt.result.Results.FieldsData, 1) - assert.Equal(t, vecField, qt.result.Results.FieldsData[0].GetFieldName()) + assert.Len(t, qt.result.Results.FieldsData, 2) + assert.Equal(t, pkField, qt.result.Results.FieldsData[0].GetFieldName()) + assert.Equal(t, vecField, qt.result.Results.FieldsData[1].GetFieldName()) }) t.Run("Test no primary key", func(t *testing.T) { @@ -1988,41 +2034,17 @@ func TestSearchTask_Requery(t *testing.T) { assert.Error(t, err) }) - t.Run("Test requery failed 1", func(t *testing.T) { + t.Run("Test requery failed", func(t *testing.T) { schema := constructCollectionSchema(pkField, vecField, dim, collection) - node := mocks.NewMockProxy(t) - node.EXPECT().Query(mock.Anything, mock.Anything). + qn := mocks.NewMockQueryNodeClient(t) + qn.EXPECT().Query(mock.Anything, mock.Anything). Return(nil, fmt.Errorf("mock err 1")) - qt := &searchTask{ - ctx: ctx, - SearchRequest: &internalpb.SearchRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Search, - SourceID: paramtable.GetNodeID(), - }, - }, - request: &milvuspb.SearchRequest{}, - schema: schema, - tr: timerecord.NewTimeRecorder("search"), - node: node, - } - - err := qt.Requery() - t.Logf("err = %s", err) - assert.Error(t, err) - }) - - t.Run("Test requery failed 2", func(t *testing.T) { - schema := constructCollectionSchema(pkField, vecField, dim, collection) - node := mocks.NewMockProxy(t) - node.EXPECT().Query(mock.Anything, mock.Anything). - Return(&milvuspb.QueryResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "mock err 2", - }, - }, nil) + lb := NewMockLBPolicy(t) + lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload CollectionWorkLoad) { + _ = workload.exec(ctx, 0, qn) + }).Return(fmt.Errorf("mock err 1")) + node.lbPolicy = lb qt := &searchTask{ ctx: ctx, @@ -2032,88 +2054,8 @@ func TestSearchTask_Requery(t *testing.T) { SourceID: paramtable.GetNodeID(), }, }, - request: &milvuspb.SearchRequest{}, - schema: schema, - tr: timerecord.NewTimeRecorder("search"), - node: node, - } - - err := qt.Requery() - t.Logf("err = %s", err) - assert.Error(t, err) - }) - - t.Run("Test get pk field data failed", func(t *testing.T) { - schema := constructCollectionSchema(pkField, vecField, dim, collection) - node := mocks.NewMockProxy(t) - node.EXPECT().Query(mock.Anything, mock.Anything). - Return(&milvuspb.QueryResults{ - FieldsData: []*schemapb.FieldData{}, - }, nil) - - qt := &searchTask{ - ctx: ctx, - SearchRequest: &internalpb.SearchRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Search, - SourceID: paramtable.GetNodeID(), - }, - }, - request: &milvuspb.SearchRequest{}, - schema: schema, - tr: timerecord.NewTimeRecorder("search"), - node: node, - } - - err := qt.Requery() - t.Logf("err = %s", err) - assert.Error(t, err) - }) - - t.Run("Test incomplete query result", func(t *testing.T) { - schema := constructCollectionSchema(pkField, vecField, dim, collection) - node := mocks.NewMockProxy(t) - node.EXPECT().Query(mock.Anything, mock.Anything). - Return(&milvuspb.QueryResults{ - FieldsData: []*schemapb.FieldData{ - { - Type: schemapb.DataType_Int64, - FieldName: pkField, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_LongData{ - LongData: &schemapb.LongArray{ - Data: ids[:len(ids)-1], - }, - }, - }, - }, - }, - newFloatVectorFieldData(vecField, rows, dim), - }, - }, nil) - - resultIDs := &schemapb.IDs{ - IdField: &schemapb.IDs_IntId{ - IntId: &schemapb.LongArray{ - Data: ids, - }, - }, - } - - qt := &searchTask{ - ctx: ctx, - SearchRequest: &internalpb.SearchRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Search, - SourceID: paramtable.GetNodeID(), - }, - }, - request: &milvuspb.SearchRequest{}, - result: &milvuspb.SearchResults{ - Results: &schemapb.SearchResultData{ - Ids: resultIDs, - }, + request: &milvuspb.SearchRequest{ + CollectionName: collectionName, }, schema: schema, tr: timerecord.NewTimeRecorder("search"), @@ -2127,10 +2069,16 @@ func TestSearchTask_Requery(t *testing.T) { t.Run("Test postExecute with requery failed", func(t *testing.T) { schema := constructCollectionSchema(pkField, vecField, dim, collection) - node := mocks.NewMockProxy(t) - node.EXPECT().Query(mock.Anything, mock.Anything). + qn := mocks.NewMockQueryNodeClient(t) + qn.EXPECT().Query(mock.Anything, mock.Anything). Return(nil, fmt.Errorf("mock err 1")) + lb := NewMockLBPolicy(t) + lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload CollectionWorkLoad) { + _ = workload.exec(ctx, 0, qn) + }).Return(fmt.Errorf("mock err 1")) + node.lbPolicy = lb + resultIDs := &schemapb.IDs{ IdField: &schemapb.IDs_IntId{ IntId: &schemapb.LongArray{ @@ -2147,7 +2095,9 @@ func TestSearchTask_Requery(t *testing.T) { SourceID: paramtable.GetNodeID(), }, }, - request: &milvuspb.SearchRequest{}, + request: &milvuspb.SearchRequest{ + CollectionName: collectionName, + }, result: &milvuspb.SearchResults{ Results: &schemapb.SearchResultData{ Ids: resultIDs, diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index 6713dd625e..91ab82044f 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -77,7 +77,7 @@ func constructCollectionSchema( collectionName string, ) *schemapb.CollectionSchema { pk := &schemapb.FieldSchema{ - FieldID: 0, + FieldID: 100, Name: int64Field, IsPrimaryKey: true, Description: "", @@ -87,7 +87,7 @@ func constructCollectionSchema( AutoID: true, } fVec := &schemapb.FieldSchema{ - FieldID: 0, + FieldID: 101, Name: floatVecField, IsPrimaryKey: false, Description: "", diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go index 16f12fa40f..4dd7a80ce3 100644 --- a/pkg/util/typeutil/schema.go +++ b/pkg/util/typeutil/schema.go @@ -828,7 +828,7 @@ func GetPrimaryFieldData(datas []*schemapb.FieldData, primaryFieldSchema *schema } if primaryFieldData == nil { - return nil, fmt.Errorf("can't find data for primary field %v", primaryFieldName) + return nil, fmt.Errorf("can't find data for primary field: %v", primaryFieldName) } return primaryFieldData, nil