From f9c630247d2f6406f49a17ec2fae68b9a888d322 Mon Sep 17 00:00:00 2001
From: "yihao.dai" <yihao.dai@zilliz.com>
Date: Thu, 26 Oct 2023 19:30:10 +0800
Subject: [PATCH] Construct plan directly when search with vector output
 (#27928)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
---
 .../parser/planparserv2/plan_parser_v2.go     |  46 ++++
 internal/proxy/impl.go                        |  37 ++--
 internal/proxy/task_query.go                  |  12 +-
 internal/proxy/task_search.go                 |  21 +-
 internal/proxy/task_search_test.go            | 198 +++++++-----------
 internal/proxy/task_test.go                   |   4 +-
 pkg/util/typeutil/schema.go                   |   2 +-
 7 files changed, 169 insertions(+), 151 deletions(-)

diff --git a/internal/parser/planparserv2/plan_parser_v2.go b/internal/parser/planparserv2/plan_parser_v2.go
index bd7feb3b58..2b93154687 100644
--- a/internal/parser/planparserv2/plan_parser_v2.go
+++ b/internal/parser/planparserv2/plan_parser_v2.go
@@ -4,6 +4,7 @@ import (
 	"fmt"
 
 	"github.com/antlr/antlr4/runtime/Go/antlr"
+	"github.com/samber/lo"
 	"go.uber.org/zap"
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
@@ -158,3 +159,48 @@ func CreateSearchPlan(schemaPb *schemapb.CollectionSchema, exprStr string, vecto
 	}
 	return planNode, nil
 }
+
+func CreateRequeryPlan(pkField *schemapb.FieldSchema, ids *schemapb.IDs) *planpb.PlanNode {
+	var values []*planpb.GenericValue
+	switch ids.GetIdField().(type) {
+	case *schemapb.IDs_IntId:
+		values = lo.Map(ids.GetIntId().GetData(), func(id int64, _ int) *planpb.GenericValue {
+			return &planpb.GenericValue{
+				Val: &planpb.GenericValue_Int64Val{
+					Int64Val: id,
+				},
+			}
+		})
+	case *schemapb.IDs_StrId:
+		values = lo.Map(ids.GetStrId().GetData(), func(id string, _ int) *planpb.GenericValue {
+			return &planpb.GenericValue{
+				Val: &planpb.GenericValue_StringVal{
+					StringVal: id,
+				},
+			}
+		})
+	}
+
+	return &planpb.PlanNode{
+		Node: &planpb.PlanNode_Query{
+			Query: &planpb.QueryPlanNode{
+				Predicates: &planpb.Expr{
+					Expr: &planpb.Expr_TermExpr{
+						TermExpr: &planpb.TermExpr{
+							ColumnInfo: &planpb.ColumnInfo{
+								FieldId:        pkField.GetFieldID(),
+								DataType:       pkField.GetDataType(),
+								IsPrimaryKey:   true,
+								IsAutoID:       pkField.GetAutoID(),
+								IsPartitionKey: pkField.GetIsPartitionKey(),
+							},
+							Values: values,
+						},
+					},
+				},
+				IsCount: false,
+				Limit:   int64(len(values)),
+			},
+		},
+	}
+}
diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go
index ff9e578d96..21f7f18ee5 100644
--- a/internal/proxy/impl.go
+++ b/internal/proxy/impl.go
@@ -2774,7 +2774,8 @@ func (node *Proxy) Flush(ctx context.Context, request *milvuspb.FlushRequest) (*
 }
 
 // Query get the records by primary keys.
-func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) {
+func (node *Proxy) query(ctx context.Context, qt *queryTask) (*milvuspb.QueryResults, error) {
+	request := qt.request
 	receiveSize := proto.Size(request)
 	metrics.ProxyReceiveBytes.WithLabelValues(
 		strconv.FormatInt(paramtable.GetNodeID(), 10),
@@ -2800,21 +2801,6 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
 	defer sp.End()
 	tr := timerecord.NewTimeRecorder("Query")
 
-	qt := &queryTask{
-		ctx:       ctx,
-		Condition: NewTaskCondition(ctx),
-		RetrieveRequest: &internalpb.RetrieveRequest{
-			Base: commonpbutil.NewMsgBase(
-				commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
-				commonpbutil.WithSourceID(paramtable.GetNodeID()),
-			),
-			ReqID: paramtable.GetNodeID(),
-		},
-		request: request,
-		qc:      node.queryCoord,
-		lb:      node.lbPolicy,
-	}
-
 	method := "Query"
 
 	metrics.ProxyFunctionCall.WithLabelValues(
@@ -2915,6 +2901,25 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
 	return qt.result, nil
 }
 
+// Query get the records by primary keys.
+func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) {
+	qt := &queryTask{
+		ctx:       ctx,
+		Condition: NewTaskCondition(ctx),
+		RetrieveRequest: &internalpb.RetrieveRequest{
+			Base: commonpbutil.NewMsgBase(
+				commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
+				commonpbutil.WithSourceID(paramtable.GetNodeID()),
+			),
+			ReqID: paramtable.GetNodeID(),
+		},
+		request: request,
+		qc:      node.queryCoord,
+		lb:      node.lbPolicy,
+	}
+	return node.query(ctx, qt)
+}
+
 // CreateAlias create alias for collection, then you can search the collection with alias.
 func (node *Proxy) CreateAlias(ctx context.Context, request *milvuspb.CreateAliasRequest) (*commonpb.Status, error) {
 	if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go
index ce13ee9885..8e6a0bec76 100644
--- a/internal/proxy/task_query.go
+++ b/internal/proxy/task_query.go
@@ -210,9 +210,12 @@ func (t *queryTask) createPlan(ctx context.Context) error {
 		return err
 	}
 
-	plan, err := planparserv2.CreateRetrievePlan(schema, t.request.Expr)
-	if err != nil {
-		return err
+	var err error
+	if t.plan == nil {
+		t.plan, err = planparserv2.CreateRetrievePlan(schema, t.request.Expr)
+		if err != nil {
+			return err
+		}
 	}
 
 	t.request.OutputFields, t.userOutputFields, err = translateOutputFields(t.request.OutputFields, schema, true)
@@ -226,8 +229,7 @@ func (t *queryTask) createPlan(ctx context.Context) error {
 	}
 	outputFieldIDs = append(outputFieldIDs, common.TimeStampField)
 	t.RetrieveRequest.OutputFieldsId = outputFieldIDs
-	plan.OutputFieldIds = outputFieldIDs
-	t.plan = plan
+	t.plan.OutputFieldIds = outputFieldIDs
 	log.Ctx(ctx).Debug("translate output fields to field ids",
 		zap.Int64s("OutputFieldsID", t.OutputFieldsId),
 		zap.String("requestType", "query"))
diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go
index b5f7493914..19c5498017 100644
--- a/internal/proxy/task_search.go
+++ b/internal/proxy/task_search.go
@@ -573,7 +573,7 @@ func (t *searchTask) Requery() error {
 		return err
 	}
 	ids := t.result.GetResults().GetIds()
-	expr := IDs2Expr(pkField.GetName(), ids)
+	plan := planparserv2.CreateRequeryPlan(pkField, ids)
 
 	queryReq := &milvuspb.QueryRequest{
 		Base: &commonpb.MsgBase{
@@ -581,13 +581,28 @@ func (t *searchTask) Requery() error {
 		},
 		DbName:             t.request.GetDbName(),
 		CollectionName:     t.request.GetCollectionName(),
-		Expr:               expr,
+		Expr:               "",
 		OutputFields:       t.request.GetOutputFields(),
 		PartitionNames:     t.request.GetPartitionNames(),
 		GuaranteeTimestamp: t.request.GetGuaranteeTimestamp(),
 		QueryParams:        t.request.GetSearchParams(),
 	}
-	queryResult, err := t.node.Query(t.ctx, queryReq)
+	qt := &queryTask{
+		ctx:       t.ctx,
+		Condition: NewTaskCondition(t.ctx),
+		RetrieveRequest: &internalpb.RetrieveRequest{
+			Base: commonpbutil.NewMsgBase(
+				commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
+				commonpbutil.WithSourceID(paramtable.GetNodeID()),
+			),
+			ReqID: paramtable.GetNodeID(),
+		},
+		request: queryReq,
+		plan:    plan,
+		qc:      t.node.(*Proxy).queryCoord,
+		lb:      t.node.(*Proxy).lbPolicy,
+	}
+	queryResult, err := t.node.(*Proxy).query(t.ctx, qt)
 	if err != nil {
 		return err
 	}
diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go
index d4edc62269..c44e685db6 100644
--- a/internal/proxy/task_search_test.go
+++ b/internal/proxy/task_search_test.go
@@ -36,6 +36,7 @@ import (
 	"github.com/milvus-io/milvus/internal/proto/internalpb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/milvus-io/milvus/internal/types"
+	"github.com/milvus-io/milvus/internal/util/dependency"
 	"github.com/milvus-io/milvus/pkg/common"
 	"github.com/milvus-io/milvus/pkg/util/funcutil"
 	"github.com/milvus-io/milvus/pkg/util/merr"
@@ -1906,11 +1907,46 @@ func TestSearchTask_Requery(t *testing.T) {
 		ids[i] = int64(i)
 	}
 
+	factory := dependency.NewDefaultFactory(true)
+	node, err := NewProxy(ctx, factory)
+	assert.NoError(t, err)
+	node.UpdateStateCode(commonpb.StateCode_Healthy)
+	node.tsoAllocator = &timestampAllocator{
+		tso: newMockTimestampAllocatorInterface(),
+	}
+	scheduler, err := newTaskScheduler(ctx, node.tsoAllocator, factory)
+	assert.NoError(t, err)
+	node.sched = scheduler
+	err = node.sched.Start()
+	assert.NoError(t, err)
+	err = node.initRateCollector()
+	assert.NoError(t, err)
+	node.rootCoord = mocks.NewMockRootCoordClient(t)
+	node.queryCoord = mocks.NewMockQueryCoordClient(t)
+
+	collectionName := "col"
+	collectionID := UniqueID(0)
+	cache := NewMockCache(t)
+	cache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(collectionID, nil).Maybe()
+	cache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(constructCollectionSchema(pkField, vecField, dim, collection), nil).Maybe()
+	cache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything).Return(map[string]int64{"_default": UniqueID(1)}, nil).Maybe()
+	cache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionBasicInfo{}, nil).Maybe()
+	cache.EXPECT().GetShards(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[string][]nodeInfo{}, nil).Maybe()
+	cache.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe()
+	globalMetaCache = cache
+
 	t.Run("Test normal", func(t *testing.T) {
 		schema := constructCollectionSchema(pkField, vecField, dim, collection)
-		node := mocks.NewMockProxy(t)
-		node.EXPECT().Query(mock.Anything, mock.Anything).
-			Return(&milvuspb.QueryResults{
+		qn := mocks.NewMockQueryNodeClient(t)
+		qn.EXPECT().Query(mock.Anything, mock.Anything).
+			Return(&internalpb.RetrieveResults{
+				Ids: &schemapb.IDs{
+					IdField: &schemapb.IDs_IntId{
+						IntId: &schemapb.LongArray{
+							Data: ids,
+						},
+					},
+				},
 				FieldsData: []*schemapb.FieldData{
 					{
 						Type:      schemapb.DataType_Int64,
@@ -1929,6 +1965,14 @@ func TestSearchTask_Requery(t *testing.T) {
 				},
 			}, nil)
 
+		lb := NewMockLBPolicy(t)
+		lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload CollectionWorkLoad) {
+			err = workload.exec(ctx, 0, qn)
+			assert.NoError(t, err)
+		}).Return(nil)
+		lb.EXPECT().UpdateCostMetrics(mock.Anything, mock.Anything).Return()
+		node.lbPolicy = lb
+
 		resultIDs := &schemapb.IDs{
 			IdField: &schemapb.IDs_IntId{
 				IntId: &schemapb.LongArray{
@@ -1937,7 +1981,7 @@ func TestSearchTask_Requery(t *testing.T) {
 			},
 		}
 
-		outputFields := []string{vecField}
+		outputFields := []string{pkField, vecField}
 		qt := &searchTask{
 			ctx: ctx,
 			SearchRequest: &internalpb.SearchRequest{
@@ -1947,7 +1991,8 @@ func TestSearchTask_Requery(t *testing.T) {
 				},
 			},
 			request: &milvuspb.SearchRequest{
-				OutputFields: outputFields,
+				CollectionName: collectionName,
+				OutputFields:   outputFields,
 			},
 			result: &milvuspb.SearchResults{
 				Results: &schemapb.SearchResultData{
@@ -1961,8 +2006,9 @@ func TestSearchTask_Requery(t *testing.T) {
 
 		err := qt.Requery()
 		assert.NoError(t, err)
-		assert.Len(t, qt.result.Results.FieldsData, 1)
-		assert.Equal(t, vecField, qt.result.Results.FieldsData[0].GetFieldName())
+		assert.Len(t, qt.result.Results.FieldsData, 2)
+		assert.Equal(t, pkField, qt.result.Results.FieldsData[0].GetFieldName())
+		assert.Equal(t, vecField, qt.result.Results.FieldsData[1].GetFieldName())
 	})
 
 	t.Run("Test no primary key", func(t *testing.T) {
@@ -1988,41 +2034,17 @@ func TestSearchTask_Requery(t *testing.T) {
 		assert.Error(t, err)
 	})
 
-	t.Run("Test requery failed 1", func(t *testing.T) {
+	t.Run("Test requery failed", func(t *testing.T) {
 		schema := constructCollectionSchema(pkField, vecField, dim, collection)
-		node := mocks.NewMockProxy(t)
-		node.EXPECT().Query(mock.Anything, mock.Anything).
+		qn := mocks.NewMockQueryNodeClient(t)
+		qn.EXPECT().Query(mock.Anything, mock.Anything).
 			Return(nil, fmt.Errorf("mock err 1"))
 
-		qt := &searchTask{
-			ctx: ctx,
-			SearchRequest: &internalpb.SearchRequest{
-				Base: &commonpb.MsgBase{
-					MsgType:  commonpb.MsgType_Search,
-					SourceID: paramtable.GetNodeID(),
-				},
-			},
-			request: &milvuspb.SearchRequest{},
-			schema:  schema,
-			tr:      timerecord.NewTimeRecorder("search"),
-			node:    node,
-		}
-
-		err := qt.Requery()
-		t.Logf("err = %s", err)
-		assert.Error(t, err)
-	})
-
-	t.Run("Test requery failed 2", func(t *testing.T) {
-		schema := constructCollectionSchema(pkField, vecField, dim, collection)
-		node := mocks.NewMockProxy(t)
-		node.EXPECT().Query(mock.Anything, mock.Anything).
-			Return(&milvuspb.QueryResults{
-				Status: &commonpb.Status{
-					ErrorCode: commonpb.ErrorCode_UnexpectedError,
-					Reason:    "mock err 2",
-				},
-			}, nil)
+		lb := NewMockLBPolicy(t)
+		lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload CollectionWorkLoad) {
+			_ = workload.exec(ctx, 0, qn)
+		}).Return(fmt.Errorf("mock err 1"))
+		node.lbPolicy = lb
 
 		qt := &searchTask{
 			ctx: ctx,
@@ -2032,88 +2054,8 @@ func TestSearchTask_Requery(t *testing.T) {
 					SourceID: paramtable.GetNodeID(),
 				},
 			},
-			request: &milvuspb.SearchRequest{},
-			schema:  schema,
-			tr:      timerecord.NewTimeRecorder("search"),
-			node:    node,
-		}
-
-		err := qt.Requery()
-		t.Logf("err = %s", err)
-		assert.Error(t, err)
-	})
-
-	t.Run("Test get pk field data failed", func(t *testing.T) {
-		schema := constructCollectionSchema(pkField, vecField, dim, collection)
-		node := mocks.NewMockProxy(t)
-		node.EXPECT().Query(mock.Anything, mock.Anything).
-			Return(&milvuspb.QueryResults{
-				FieldsData: []*schemapb.FieldData{},
-			}, nil)
-
-		qt := &searchTask{
-			ctx: ctx,
-			SearchRequest: &internalpb.SearchRequest{
-				Base: &commonpb.MsgBase{
-					MsgType:  commonpb.MsgType_Search,
-					SourceID: paramtable.GetNodeID(),
-				},
-			},
-			request: &milvuspb.SearchRequest{},
-			schema:  schema,
-			tr:      timerecord.NewTimeRecorder("search"),
-			node:    node,
-		}
-
-		err := qt.Requery()
-		t.Logf("err = %s", err)
-		assert.Error(t, err)
-	})
-
-	t.Run("Test incomplete query result", func(t *testing.T) {
-		schema := constructCollectionSchema(pkField, vecField, dim, collection)
-		node := mocks.NewMockProxy(t)
-		node.EXPECT().Query(mock.Anything, mock.Anything).
-			Return(&milvuspb.QueryResults{
-				FieldsData: []*schemapb.FieldData{
-					{
-						Type:      schemapb.DataType_Int64,
-						FieldName: pkField,
-						Field: &schemapb.FieldData_Scalars{
-							Scalars: &schemapb.ScalarField{
-								Data: &schemapb.ScalarField_LongData{
-									LongData: &schemapb.LongArray{
-										Data: ids[:len(ids)-1],
-									},
-								},
-							},
-						},
-					},
-					newFloatVectorFieldData(vecField, rows, dim),
-				},
-			}, nil)
-
-		resultIDs := &schemapb.IDs{
-			IdField: &schemapb.IDs_IntId{
-				IntId: &schemapb.LongArray{
-					Data: ids,
-				},
-			},
-		}
-
-		qt := &searchTask{
-			ctx: ctx,
-			SearchRequest: &internalpb.SearchRequest{
-				Base: &commonpb.MsgBase{
-					MsgType:  commonpb.MsgType_Search,
-					SourceID: paramtable.GetNodeID(),
-				},
-			},
-			request: &milvuspb.SearchRequest{},
-			result: &milvuspb.SearchResults{
-				Results: &schemapb.SearchResultData{
-					Ids: resultIDs,
-				},
+			request: &milvuspb.SearchRequest{
+				CollectionName: collectionName,
 			},
 			schema: schema,
 			tr:     timerecord.NewTimeRecorder("search"),
@@ -2127,10 +2069,16 @@ func TestSearchTask_Requery(t *testing.T) {
 
 	t.Run("Test postExecute with requery failed", func(t *testing.T) {
 		schema := constructCollectionSchema(pkField, vecField, dim, collection)
-		node := mocks.NewMockProxy(t)
-		node.EXPECT().Query(mock.Anything, mock.Anything).
+		qn := mocks.NewMockQueryNodeClient(t)
+		qn.EXPECT().Query(mock.Anything, mock.Anything).
 			Return(nil, fmt.Errorf("mock err 1"))
 
+		lb := NewMockLBPolicy(t)
+		lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload CollectionWorkLoad) {
+			_ = workload.exec(ctx, 0, qn)
+		}).Return(fmt.Errorf("mock err 1"))
+		node.lbPolicy = lb
+
 		resultIDs := &schemapb.IDs{
 			IdField: &schemapb.IDs_IntId{
 				IntId: &schemapb.LongArray{
@@ -2147,7 +2095,9 @@ func TestSearchTask_Requery(t *testing.T) {
 					SourceID: paramtable.GetNodeID(),
 				},
 			},
-			request: &milvuspb.SearchRequest{},
+			request: &milvuspb.SearchRequest{
+				CollectionName: collectionName,
+			},
 			result: &milvuspb.SearchResults{
 				Results: &schemapb.SearchResultData{
 					Ids: resultIDs,
diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go
index 6713dd625e..91ab82044f 100644
--- a/internal/proxy/task_test.go
+++ b/internal/proxy/task_test.go
@@ -77,7 +77,7 @@ func constructCollectionSchema(
 	collectionName string,
 ) *schemapb.CollectionSchema {
 	pk := &schemapb.FieldSchema{
-		FieldID:      0,
+		FieldID:      100,
 		Name:         int64Field,
 		IsPrimaryKey: true,
 		Description:  "",
@@ -87,7 +87,7 @@ func constructCollectionSchema(
 		AutoID:       true,
 	}
 	fVec := &schemapb.FieldSchema{
-		FieldID:      0,
+		FieldID:      101,
 		Name:         floatVecField,
 		IsPrimaryKey: false,
 		Description:  "",
diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go
index 16f12fa40f..4dd7a80ce3 100644
--- a/pkg/util/typeutil/schema.go
+++ b/pkg/util/typeutil/schema.go
@@ -828,7 +828,7 @@ func GetPrimaryFieldData(datas []*schemapb.FieldData, primaryFieldSchema *schema
 	}
 
 	if primaryFieldData == nil {
-		return nil, fmt.Errorf("can't find data for primary field %v", primaryFieldName)
+		return nil, fmt.Errorf("can't find data for primary field: %v", primaryFieldName)
 	}
 
 	return primaryFieldData, nil