Construct plan directly when search with vector output (#27928)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
pull/27968/head
yihao.dai 2023-10-26 19:30:10 +08:00 committed by GitHub
parent 13877a07ff
commit f9c630247d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 169 additions and 151 deletions

View File

@ -4,6 +4,7 @@ import (
"fmt"
"github.com/antlr/antlr4/runtime/Go/antlr"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
@ -158,3 +159,48 @@ func CreateSearchPlan(schemaPb *schemapb.CollectionSchema, exprStr string, vecto
}
return planNode, nil
}
// CreateRequeryPlan builds a query plan node directly from a set of primary
// keys, without going through expression parsing.
//
// The resulting plan carries a single term predicate on pkField whose values
// are the entries of ids (int64 or string, depending on which variant of the
// IDs oneof is set). The plan limit equals the number of values, and IsCount
// is false. When ids holds neither variant, the value list stays nil and the
// limit is 0.
func CreateRequeryPlan(pkField *schemapb.FieldSchema, ids *schemapb.IDs) *planpb.PlanNode {
	// Wrap each primary key in the GenericValue variant matching its type.
	fromInt64 := func(id int64, _ int) *planpb.GenericValue {
		return &planpb.GenericValue{
			Val: &planpb.GenericValue_Int64Val{Int64Val: id},
		}
	}
	fromString := func(id string, _ int) *planpb.GenericValue {
		return &planpb.GenericValue{
			Val: &planpb.GenericValue_StringVal{StringVal: id},
		}
	}

	var terms []*planpb.GenericValue
	switch ids.GetIdField().(type) {
	case *schemapb.IDs_IntId:
		terms = lo.Map(ids.GetIntId().GetData(), fromInt64)
	case *schemapb.IDs_StrId:
		terms = lo.Map(ids.GetStrId().GetData(), fromString)
	}

	// Describe the primary key column the term predicate applies to.
	column := &planpb.ColumnInfo{
		FieldId:        pkField.GetFieldID(),
		DataType:       pkField.GetDataType(),
		IsPrimaryKey:   true,
		IsAutoID:       pkField.GetAutoID(),
		IsPartitionKey: pkField.GetIsPartitionKey(),
	}

	predicate := &planpb.Expr{
		Expr: &planpb.Expr_TermExpr{
			TermExpr: &planpb.TermExpr{
				ColumnInfo: column,
				Values:     terms,
			},
		},
	}

	query := &planpb.QueryPlanNode{
		Predicates: predicate,
		IsCount:    false,
		// One row is requested per primary key value.
		Limit: int64(len(terms)),
	}

	return &planpb.PlanNode{
		Node: &planpb.PlanNode_Query{Query: query},
	}
}

View File

@ -2774,7 +2774,8 @@ func (node *Proxy) Flush(ctx context.Context, request *milvuspb.FlushRequest) (*
}
// Query get the records by primary keys.
func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) {
func (node *Proxy) query(ctx context.Context, qt *queryTask) (*milvuspb.QueryResults, error) {
request := qt.request
receiveSize := proto.Size(request)
metrics.ProxyReceiveBytes.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
@ -2800,21 +2801,6 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
defer sp.End()
tr := timerecord.NewTimeRecorder("Query")
qt := &queryTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
RetrieveRequest: &internalpb.RetrieveRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
ReqID: paramtable.GetNodeID(),
},
request: request,
qc: node.queryCoord,
lb: node.lbPolicy,
}
method := "Query"
metrics.ProxyFunctionCall.WithLabelValues(
@ -2915,6 +2901,25 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
return qt.result, nil
}
// Query gets the records by primary keys.
//
// It wraps the incoming request in a queryTask wired to this proxy's query
// coordinator and load-balancing policy, then hands the task to node.query,
// whose result and error are returned unchanged.
func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) {
	cond := NewTaskCondition(ctx)

	// Internal retrieve request; its base is tagged as a Retrieve message
	// with this node's ID as the source.
	msgBase := commonpbutil.NewMsgBase(
		commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
		commonpbutil.WithSourceID(paramtable.GetNodeID()),
	)
	retrieveReq := &internalpb.RetrieveRequest{
		Base:  msgBase,
		ReqID: paramtable.GetNodeID(),
	}

	task := &queryTask{
		ctx:             ctx,
		Condition:       cond,
		RetrieveRequest: retrieveReq,
		request:         request,
		qc:              node.queryCoord,
		lb:              node.lbPolicy,
	}
	return node.query(ctx, task)
}
// CreateAlias create alias for collection, then you can search the collection with alias.
func (node *Proxy) CreateAlias(ctx context.Context, request *milvuspb.CreateAliasRequest) (*commonpb.Status, error) {
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {

View File

@ -210,9 +210,12 @@ func (t *queryTask) createPlan(ctx context.Context) error {
return err
}
plan, err := planparserv2.CreateRetrievePlan(schema, t.request.Expr)
if err != nil {
return err
var err error
if t.plan == nil {
t.plan, err = planparserv2.CreateRetrievePlan(schema, t.request.Expr)
if err != nil {
return err
}
}
t.request.OutputFields, t.userOutputFields, err = translateOutputFields(t.request.OutputFields, schema, true)
@ -226,8 +229,7 @@ func (t *queryTask) createPlan(ctx context.Context) error {
}
outputFieldIDs = append(outputFieldIDs, common.TimeStampField)
t.RetrieveRequest.OutputFieldsId = outputFieldIDs
plan.OutputFieldIds = outputFieldIDs
t.plan = plan
t.plan.OutputFieldIds = outputFieldIDs
log.Ctx(ctx).Debug("translate output fields to field ids",
zap.Int64s("OutputFieldsID", t.OutputFieldsId),
zap.String("requestType", "query"))

View File

@ -573,7 +573,7 @@ func (t *searchTask) Requery() error {
return err
}
ids := t.result.GetResults().GetIds()
expr := IDs2Expr(pkField.GetName(), ids)
plan := planparserv2.CreateRequeryPlan(pkField, ids)
queryReq := &milvuspb.QueryRequest{
Base: &commonpb.MsgBase{
@ -581,13 +581,28 @@ func (t *searchTask) Requery() error {
},
DbName: t.request.GetDbName(),
CollectionName: t.request.GetCollectionName(),
Expr: expr,
Expr: "",
OutputFields: t.request.GetOutputFields(),
PartitionNames: t.request.GetPartitionNames(),
GuaranteeTimestamp: t.request.GetGuaranteeTimestamp(),
QueryParams: t.request.GetSearchParams(),
}
queryResult, err := t.node.Query(t.ctx, queryReq)
qt := &queryTask{
ctx: t.ctx,
Condition: NewTaskCondition(t.ctx),
RetrieveRequest: &internalpb.RetrieveRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
ReqID: paramtable.GetNodeID(),
},
request: queryReq,
plan: plan,
qc: t.node.(*Proxy).queryCoord,
lb: t.node.(*Proxy).lbPolicy,
}
queryResult, err := t.node.(*Proxy).query(t.ctx, qt)
if err != nil {
return err
}

View File

@ -36,6 +36,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
@ -1906,11 +1907,46 @@ func TestSearchTask_Requery(t *testing.T) {
ids[i] = int64(i)
}
factory := dependency.NewDefaultFactory(true)
node, err := NewProxy(ctx, factory)
assert.NoError(t, err)
node.UpdateStateCode(commonpb.StateCode_Healthy)
node.tsoAllocator = &timestampAllocator{
tso: newMockTimestampAllocatorInterface(),
}
scheduler, err := newTaskScheduler(ctx, node.tsoAllocator, factory)
assert.NoError(t, err)
node.sched = scheduler
err = node.sched.Start()
assert.NoError(t, err)
err = node.initRateCollector()
assert.NoError(t, err)
node.rootCoord = mocks.NewMockRootCoordClient(t)
node.queryCoord = mocks.NewMockQueryCoordClient(t)
collectionName := "col"
collectionID := UniqueID(0)
cache := NewMockCache(t)
cache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(collectionID, nil).Maybe()
cache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(constructCollectionSchema(pkField, vecField, dim, collection), nil).Maybe()
cache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything).Return(map[string]int64{"_default": UniqueID(1)}, nil).Maybe()
cache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionBasicInfo{}, nil).Maybe()
cache.EXPECT().GetShards(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[string][]nodeInfo{}, nil).Maybe()
cache.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe()
globalMetaCache = cache
t.Run("Test normal", func(t *testing.T) {
schema := constructCollectionSchema(pkField, vecField, dim, collection)
node := mocks.NewMockProxy(t)
node.EXPECT().Query(mock.Anything, mock.Anything).
Return(&milvuspb.QueryResults{
qn := mocks.NewMockQueryNodeClient(t)
qn.EXPECT().Query(mock.Anything, mock.Anything).
Return(&internalpb.RetrieveResults{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: ids,
},
},
},
FieldsData: []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
@ -1929,6 +1965,14 @@ func TestSearchTask_Requery(t *testing.T) {
},
}, nil)
lb := NewMockLBPolicy(t)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload CollectionWorkLoad) {
err = workload.exec(ctx, 0, qn)
assert.NoError(t, err)
}).Return(nil)
lb.EXPECT().UpdateCostMetrics(mock.Anything, mock.Anything).Return()
node.lbPolicy = lb
resultIDs := &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
@ -1937,7 +1981,7 @@ func TestSearchTask_Requery(t *testing.T) {
},
}
outputFields := []string{vecField}
outputFields := []string{pkField, vecField}
qt := &searchTask{
ctx: ctx,
SearchRequest: &internalpb.SearchRequest{
@ -1947,7 +1991,8 @@ func TestSearchTask_Requery(t *testing.T) {
},
},
request: &milvuspb.SearchRequest{
OutputFields: outputFields,
CollectionName: collectionName,
OutputFields: outputFields,
},
result: &milvuspb.SearchResults{
Results: &schemapb.SearchResultData{
@ -1961,8 +2006,9 @@ func TestSearchTask_Requery(t *testing.T) {
err := qt.Requery()
assert.NoError(t, err)
assert.Len(t, qt.result.Results.FieldsData, 1)
assert.Equal(t, vecField, qt.result.Results.FieldsData[0].GetFieldName())
assert.Len(t, qt.result.Results.FieldsData, 2)
assert.Equal(t, pkField, qt.result.Results.FieldsData[0].GetFieldName())
assert.Equal(t, vecField, qt.result.Results.FieldsData[1].GetFieldName())
})
t.Run("Test no primary key", func(t *testing.T) {
@ -1988,41 +2034,17 @@ func TestSearchTask_Requery(t *testing.T) {
assert.Error(t, err)
})
t.Run("Test requery failed 1", func(t *testing.T) {
t.Run("Test requery failed", func(t *testing.T) {
schema := constructCollectionSchema(pkField, vecField, dim, collection)
node := mocks.NewMockProxy(t)
node.EXPECT().Query(mock.Anything, mock.Anything).
qn := mocks.NewMockQueryNodeClient(t)
qn.EXPECT().Query(mock.Anything, mock.Anything).
Return(nil, fmt.Errorf("mock err 1"))
qt := &searchTask{
ctx: ctx,
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: paramtable.GetNodeID(),
},
},
request: &milvuspb.SearchRequest{},
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
}
err := qt.Requery()
t.Logf("err = %s", err)
assert.Error(t, err)
})
t.Run("Test requery failed 2", func(t *testing.T) {
schema := constructCollectionSchema(pkField, vecField, dim, collection)
node := mocks.NewMockProxy(t)
node.EXPECT().Query(mock.Anything, mock.Anything).
Return(&milvuspb.QueryResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "mock err 2",
},
}, nil)
lb := NewMockLBPolicy(t)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload CollectionWorkLoad) {
_ = workload.exec(ctx, 0, qn)
}).Return(fmt.Errorf("mock err 1"))
node.lbPolicy = lb
qt := &searchTask{
ctx: ctx,
@ -2032,88 +2054,8 @@ func TestSearchTask_Requery(t *testing.T) {
SourceID: paramtable.GetNodeID(),
},
},
request: &milvuspb.SearchRequest{},
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
}
err := qt.Requery()
t.Logf("err = %s", err)
assert.Error(t, err)
})
t.Run("Test get pk field data failed", func(t *testing.T) {
schema := constructCollectionSchema(pkField, vecField, dim, collection)
node := mocks.NewMockProxy(t)
node.EXPECT().Query(mock.Anything, mock.Anything).
Return(&milvuspb.QueryResults{
FieldsData: []*schemapb.FieldData{},
}, nil)
qt := &searchTask{
ctx: ctx,
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: paramtable.GetNodeID(),
},
},
request: &milvuspb.SearchRequest{},
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
}
err := qt.Requery()
t.Logf("err = %s", err)
assert.Error(t, err)
})
t.Run("Test incomplete query result", func(t *testing.T) {
schema := constructCollectionSchema(pkField, vecField, dim, collection)
node := mocks.NewMockProxy(t)
node.EXPECT().Query(mock.Anything, mock.Anything).
Return(&milvuspb.QueryResults{
FieldsData: []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
FieldName: pkField,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: ids[:len(ids)-1],
},
},
},
},
},
newFloatVectorFieldData(vecField, rows, dim),
},
}, nil)
resultIDs := &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: ids,
},
},
}
qt := &searchTask{
ctx: ctx,
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: paramtable.GetNodeID(),
},
},
request: &milvuspb.SearchRequest{},
result: &milvuspb.SearchResults{
Results: &schemapb.SearchResultData{
Ids: resultIDs,
},
request: &milvuspb.SearchRequest{
CollectionName: collectionName,
},
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
@ -2127,10 +2069,16 @@ func TestSearchTask_Requery(t *testing.T) {
t.Run("Test postExecute with requery failed", func(t *testing.T) {
schema := constructCollectionSchema(pkField, vecField, dim, collection)
node := mocks.NewMockProxy(t)
node.EXPECT().Query(mock.Anything, mock.Anything).
qn := mocks.NewMockQueryNodeClient(t)
qn.EXPECT().Query(mock.Anything, mock.Anything).
Return(nil, fmt.Errorf("mock err 1"))
lb := NewMockLBPolicy(t)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload CollectionWorkLoad) {
_ = workload.exec(ctx, 0, qn)
}).Return(fmt.Errorf("mock err 1"))
node.lbPolicy = lb
resultIDs := &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
@ -2147,7 +2095,9 @@ func TestSearchTask_Requery(t *testing.T) {
SourceID: paramtable.GetNodeID(),
},
},
request: &milvuspb.SearchRequest{},
request: &milvuspb.SearchRequest{
CollectionName: collectionName,
},
result: &milvuspb.SearchResults{
Results: &schemapb.SearchResultData{
Ids: resultIDs,

View File

@ -77,7 +77,7 @@ func constructCollectionSchema(
collectionName string,
) *schemapb.CollectionSchema {
pk := &schemapb.FieldSchema{
FieldID: 0,
FieldID: 100,
Name: int64Field,
IsPrimaryKey: true,
Description: "",
@ -87,7 +87,7 @@ func constructCollectionSchema(
AutoID: true,
}
fVec := &schemapb.FieldSchema{
FieldID: 0,
FieldID: 101,
Name: floatVecField,
IsPrimaryKey: false,
Description: "",

View File

@ -828,7 +828,7 @@ func GetPrimaryFieldData(datas []*schemapb.FieldData, primaryFieldSchema *schema
}
if primaryFieldData == nil {
return nil, fmt.Errorf("can't find data for primary field %v", primaryFieldName)
return nil, fmt.Errorf("can't find data for primary field: %v", primaryFieldName)
}
return primaryFieldData, nil