mirror of https://github.com/milvus-io/milvus.git
Construct plan directly when search with vector output (#27928)
Signed-off-by: bigsheeper <yihao.dai@zilliz.com>pull/27968/head
parent
13877a07ff
commit
f9c630247d
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/antlr/antlr4/runtime/Go/antlr"
|
||||
"github.com/samber/lo"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
|
@ -158,3 +159,48 @@ func CreateSearchPlan(schemaPb *schemapb.CollectionSchema, exprStr string, vecto
|
|||
}
|
||||
return planNode, nil
|
||||
}
|
||||
|
||||
func CreateRequeryPlan(pkField *schemapb.FieldSchema, ids *schemapb.IDs) *planpb.PlanNode {
|
||||
var values []*planpb.GenericValue
|
||||
switch ids.GetIdField().(type) {
|
||||
case *schemapb.IDs_IntId:
|
||||
values = lo.Map(ids.GetIntId().GetData(), func(id int64, _ int) *planpb.GenericValue {
|
||||
return &planpb.GenericValue{
|
||||
Val: &planpb.GenericValue_Int64Val{
|
||||
Int64Val: id,
|
||||
},
|
||||
}
|
||||
})
|
||||
case *schemapb.IDs_StrId:
|
||||
values = lo.Map(ids.GetStrId().GetData(), func(id string, _ int) *planpb.GenericValue {
|
||||
return &planpb.GenericValue{
|
||||
Val: &planpb.GenericValue_StringVal{
|
||||
StringVal: id,
|
||||
},
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return &planpb.PlanNode{
|
||||
Node: &planpb.PlanNode_Query{
|
||||
Query: &planpb.QueryPlanNode{
|
||||
Predicates: &planpb.Expr{
|
||||
Expr: &planpb.Expr_TermExpr{
|
||||
TermExpr: &planpb.TermExpr{
|
||||
ColumnInfo: &planpb.ColumnInfo{
|
||||
FieldId: pkField.GetFieldID(),
|
||||
DataType: pkField.GetDataType(),
|
||||
IsPrimaryKey: true,
|
||||
IsAutoID: pkField.GetAutoID(),
|
||||
IsPartitionKey: pkField.GetIsPartitionKey(),
|
||||
},
|
||||
Values: values,
|
||||
},
|
||||
},
|
||||
},
|
||||
IsCount: false,
|
||||
Limit: int64(len(values)),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2774,7 +2774,8 @@ func (node *Proxy) Flush(ctx context.Context, request *milvuspb.FlushRequest) (*
|
|||
}
|
||||
|
||||
// Query get the records by primary keys.
|
||||
func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) {
|
||||
func (node *Proxy) query(ctx context.Context, qt *queryTask) (*milvuspb.QueryResults, error) {
|
||||
request := qt.request
|
||||
receiveSize := proto.Size(request)
|
||||
metrics.ProxyReceiveBytes.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
|
@ -2800,21 +2801,6 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
|
|||
defer sp.End()
|
||||
tr := timerecord.NewTimeRecorder("Query")
|
||||
|
||||
qt := &queryTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
RetrieveRequest: &internalpb.RetrieveRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
|
||||
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
||||
),
|
||||
ReqID: paramtable.GetNodeID(),
|
||||
},
|
||||
request: request,
|
||||
qc: node.queryCoord,
|
||||
lb: node.lbPolicy,
|
||||
}
|
||||
|
||||
method := "Query"
|
||||
|
||||
metrics.ProxyFunctionCall.WithLabelValues(
|
||||
|
@ -2915,6 +2901,25 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
|
|||
return qt.result, nil
|
||||
}
|
||||
|
||||
// Query get the records by primary keys.
|
||||
func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) {
|
||||
qt := &queryTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
RetrieveRequest: &internalpb.RetrieveRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
|
||||
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
||||
),
|
||||
ReqID: paramtable.GetNodeID(),
|
||||
},
|
||||
request: request,
|
||||
qc: node.queryCoord,
|
||||
lb: node.lbPolicy,
|
||||
}
|
||||
return node.query(ctx, qt)
|
||||
}
|
||||
|
||||
// CreateAlias create alias for collection, then you can search the collection with alias.
|
||||
func (node *Proxy) CreateAlias(ctx context.Context, request *milvuspb.CreateAliasRequest) (*commonpb.Status, error) {
|
||||
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
|
||||
|
|
|
@ -210,9 +210,12 @@ func (t *queryTask) createPlan(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
plan, err := planparserv2.CreateRetrievePlan(schema, t.request.Expr)
|
||||
if err != nil {
|
||||
return err
|
||||
var err error
|
||||
if t.plan == nil {
|
||||
t.plan, err = planparserv2.CreateRetrievePlan(schema, t.request.Expr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
t.request.OutputFields, t.userOutputFields, err = translateOutputFields(t.request.OutputFields, schema, true)
|
||||
|
@ -226,8 +229,7 @@ func (t *queryTask) createPlan(ctx context.Context) error {
|
|||
}
|
||||
outputFieldIDs = append(outputFieldIDs, common.TimeStampField)
|
||||
t.RetrieveRequest.OutputFieldsId = outputFieldIDs
|
||||
plan.OutputFieldIds = outputFieldIDs
|
||||
t.plan = plan
|
||||
t.plan.OutputFieldIds = outputFieldIDs
|
||||
log.Ctx(ctx).Debug("translate output fields to field ids",
|
||||
zap.Int64s("OutputFieldsID", t.OutputFieldsId),
|
||||
zap.String("requestType", "query"))
|
||||
|
|
|
@ -573,7 +573,7 @@ func (t *searchTask) Requery() error {
|
|||
return err
|
||||
}
|
||||
ids := t.result.GetResults().GetIds()
|
||||
expr := IDs2Expr(pkField.GetName(), ids)
|
||||
plan := planparserv2.CreateRequeryPlan(pkField, ids)
|
||||
|
||||
queryReq := &milvuspb.QueryRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
|
@ -581,13 +581,28 @@ func (t *searchTask) Requery() error {
|
|||
},
|
||||
DbName: t.request.GetDbName(),
|
||||
CollectionName: t.request.GetCollectionName(),
|
||||
Expr: expr,
|
||||
Expr: "",
|
||||
OutputFields: t.request.GetOutputFields(),
|
||||
PartitionNames: t.request.GetPartitionNames(),
|
||||
GuaranteeTimestamp: t.request.GetGuaranteeTimestamp(),
|
||||
QueryParams: t.request.GetSearchParams(),
|
||||
}
|
||||
queryResult, err := t.node.Query(t.ctx, queryReq)
|
||||
qt := &queryTask{
|
||||
ctx: t.ctx,
|
||||
Condition: NewTaskCondition(t.ctx),
|
||||
RetrieveRequest: &internalpb.RetrieveRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
|
||||
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
||||
),
|
||||
ReqID: paramtable.GetNodeID(),
|
||||
},
|
||||
request: queryReq,
|
||||
plan: plan,
|
||||
qc: t.node.(*Proxy).queryCoord,
|
||||
lb: t.node.(*Proxy).lbPolicy,
|
||||
}
|
||||
queryResult, err := t.node.(*Proxy).query(t.ctx, qt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -36,6 +36,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
|
@ -1906,11 +1907,46 @@ func TestSearchTask_Requery(t *testing.T) {
|
|||
ids[i] = int64(i)
|
||||
}
|
||||
|
||||
factory := dependency.NewDefaultFactory(true)
|
||||
node, err := NewProxy(ctx, factory)
|
||||
assert.NoError(t, err)
|
||||
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
node.tsoAllocator = ×tampAllocator{
|
||||
tso: newMockTimestampAllocatorInterface(),
|
||||
}
|
||||
scheduler, err := newTaskScheduler(ctx, node.tsoAllocator, factory)
|
||||
assert.NoError(t, err)
|
||||
node.sched = scheduler
|
||||
err = node.sched.Start()
|
||||
assert.NoError(t, err)
|
||||
err = node.initRateCollector()
|
||||
assert.NoError(t, err)
|
||||
node.rootCoord = mocks.NewMockRootCoordClient(t)
|
||||
node.queryCoord = mocks.NewMockQueryCoordClient(t)
|
||||
|
||||
collectionName := "col"
|
||||
collectionID := UniqueID(0)
|
||||
cache := NewMockCache(t)
|
||||
cache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(collectionID, nil).Maybe()
|
||||
cache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(constructCollectionSchema(pkField, vecField, dim, collection), nil).Maybe()
|
||||
cache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything).Return(map[string]int64{"_default": UniqueID(1)}, nil).Maybe()
|
||||
cache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionBasicInfo{}, nil).Maybe()
|
||||
cache.EXPECT().GetShards(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[string][]nodeInfo{}, nil).Maybe()
|
||||
cache.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe()
|
||||
globalMetaCache = cache
|
||||
|
||||
t.Run("Test normal", func(t *testing.T) {
|
||||
schema := constructCollectionSchema(pkField, vecField, dim, collection)
|
||||
node := mocks.NewMockProxy(t)
|
||||
node.EXPECT().Query(mock.Anything, mock.Anything).
|
||||
Return(&milvuspb.QueryResults{
|
||||
qn := mocks.NewMockQueryNodeClient(t)
|
||||
qn.EXPECT().Query(mock.Anything, mock.Anything).
|
||||
Return(&internalpb.RetrieveResults{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: ids,
|
||||
},
|
||||
},
|
||||
},
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
|
@ -1929,6 +1965,14 @@ func TestSearchTask_Requery(t *testing.T) {
|
|||
},
|
||||
}, nil)
|
||||
|
||||
lb := NewMockLBPolicy(t)
|
||||
lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload CollectionWorkLoad) {
|
||||
err = workload.exec(ctx, 0, qn)
|
||||
assert.NoError(t, err)
|
||||
}).Return(nil)
|
||||
lb.EXPECT().UpdateCostMetrics(mock.Anything, mock.Anything).Return()
|
||||
node.lbPolicy = lb
|
||||
|
||||
resultIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
|
@ -1937,7 +1981,7 @@ func TestSearchTask_Requery(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
outputFields := []string{vecField}
|
||||
outputFields := []string{pkField, vecField}
|
||||
qt := &searchTask{
|
||||
ctx: ctx,
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
|
@ -1947,7 +1991,8 @@ func TestSearchTask_Requery(t *testing.T) {
|
|||
},
|
||||
},
|
||||
request: &milvuspb.SearchRequest{
|
||||
OutputFields: outputFields,
|
||||
CollectionName: collectionName,
|
||||
OutputFields: outputFields,
|
||||
},
|
||||
result: &milvuspb.SearchResults{
|
||||
Results: &schemapb.SearchResultData{
|
||||
|
@ -1961,8 +2006,9 @@ func TestSearchTask_Requery(t *testing.T) {
|
|||
|
||||
err := qt.Requery()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, qt.result.Results.FieldsData, 1)
|
||||
assert.Equal(t, vecField, qt.result.Results.FieldsData[0].GetFieldName())
|
||||
assert.Len(t, qt.result.Results.FieldsData, 2)
|
||||
assert.Equal(t, pkField, qt.result.Results.FieldsData[0].GetFieldName())
|
||||
assert.Equal(t, vecField, qt.result.Results.FieldsData[1].GetFieldName())
|
||||
})
|
||||
|
||||
t.Run("Test no primary key", func(t *testing.T) {
|
||||
|
@ -1988,41 +2034,17 @@ func TestSearchTask_Requery(t *testing.T) {
|
|||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Test requery failed 1", func(t *testing.T) {
|
||||
t.Run("Test requery failed", func(t *testing.T) {
|
||||
schema := constructCollectionSchema(pkField, vecField, dim, collection)
|
||||
node := mocks.NewMockProxy(t)
|
||||
node.EXPECT().Query(mock.Anything, mock.Anything).
|
||||
qn := mocks.NewMockQueryNodeClient(t)
|
||||
qn.EXPECT().Query(mock.Anything, mock.Anything).
|
||||
Return(nil, fmt.Errorf("mock err 1"))
|
||||
|
||||
qt := &searchTask{
|
||||
ctx: ctx,
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
},
|
||||
request: &milvuspb.SearchRequest{},
|
||||
schema: schema,
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
node: node,
|
||||
}
|
||||
|
||||
err := qt.Requery()
|
||||
t.Logf("err = %s", err)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Test requery failed 2", func(t *testing.T) {
|
||||
schema := constructCollectionSchema(pkField, vecField, dim, collection)
|
||||
node := mocks.NewMockProxy(t)
|
||||
node.EXPECT().Query(mock.Anything, mock.Anything).
|
||||
Return(&milvuspb.QueryResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "mock err 2",
|
||||
},
|
||||
}, nil)
|
||||
lb := NewMockLBPolicy(t)
|
||||
lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload CollectionWorkLoad) {
|
||||
_ = workload.exec(ctx, 0, qn)
|
||||
}).Return(fmt.Errorf("mock err 1"))
|
||||
node.lbPolicy = lb
|
||||
|
||||
qt := &searchTask{
|
||||
ctx: ctx,
|
||||
|
@ -2032,88 +2054,8 @@ func TestSearchTask_Requery(t *testing.T) {
|
|||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
},
|
||||
request: &milvuspb.SearchRequest{},
|
||||
schema: schema,
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
node: node,
|
||||
}
|
||||
|
||||
err := qt.Requery()
|
||||
t.Logf("err = %s", err)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Test get pk field data failed", func(t *testing.T) {
|
||||
schema := constructCollectionSchema(pkField, vecField, dim, collection)
|
||||
node := mocks.NewMockProxy(t)
|
||||
node.EXPECT().Query(mock.Anything, mock.Anything).
|
||||
Return(&milvuspb.QueryResults{
|
||||
FieldsData: []*schemapb.FieldData{},
|
||||
}, nil)
|
||||
|
||||
qt := &searchTask{
|
||||
ctx: ctx,
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
},
|
||||
request: &milvuspb.SearchRequest{},
|
||||
schema: schema,
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
node: node,
|
||||
}
|
||||
|
||||
err := qt.Requery()
|
||||
t.Logf("err = %s", err)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Test incomplete query result", func(t *testing.T) {
|
||||
schema := constructCollectionSchema(pkField, vecField, dim, collection)
|
||||
node := mocks.NewMockProxy(t)
|
||||
node.EXPECT().Query(mock.Anything, mock.Anything).
|
||||
Return(&milvuspb.QueryResults{
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: pkField,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: ids[:len(ids)-1],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
newFloatVectorFieldData(vecField, rows, dim),
|
||||
},
|
||||
}, nil)
|
||||
|
||||
resultIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: ids,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
qt := &searchTask{
|
||||
ctx: ctx,
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
},
|
||||
request: &milvuspb.SearchRequest{},
|
||||
result: &milvuspb.SearchResults{
|
||||
Results: &schemapb.SearchResultData{
|
||||
Ids: resultIDs,
|
||||
},
|
||||
request: &milvuspb.SearchRequest{
|
||||
CollectionName: collectionName,
|
||||
},
|
||||
schema: schema,
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
|
@ -2127,10 +2069,16 @@ func TestSearchTask_Requery(t *testing.T) {
|
|||
|
||||
t.Run("Test postExecute with requery failed", func(t *testing.T) {
|
||||
schema := constructCollectionSchema(pkField, vecField, dim, collection)
|
||||
node := mocks.NewMockProxy(t)
|
||||
node.EXPECT().Query(mock.Anything, mock.Anything).
|
||||
qn := mocks.NewMockQueryNodeClient(t)
|
||||
qn.EXPECT().Query(mock.Anything, mock.Anything).
|
||||
Return(nil, fmt.Errorf("mock err 1"))
|
||||
|
||||
lb := NewMockLBPolicy(t)
|
||||
lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload CollectionWorkLoad) {
|
||||
_ = workload.exec(ctx, 0, qn)
|
||||
}).Return(fmt.Errorf("mock err 1"))
|
||||
node.lbPolicy = lb
|
||||
|
||||
resultIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
|
@ -2147,7 +2095,9 @@ func TestSearchTask_Requery(t *testing.T) {
|
|||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
},
|
||||
request: &milvuspb.SearchRequest{},
|
||||
request: &milvuspb.SearchRequest{
|
||||
CollectionName: collectionName,
|
||||
},
|
||||
result: &milvuspb.SearchResults{
|
||||
Results: &schemapb.SearchResultData{
|
||||
Ids: resultIDs,
|
||||
|
|
|
@ -77,7 +77,7 @@ func constructCollectionSchema(
|
|||
collectionName string,
|
||||
) *schemapb.CollectionSchema {
|
||||
pk := &schemapb.FieldSchema{
|
||||
FieldID: 0,
|
||||
FieldID: 100,
|
||||
Name: int64Field,
|
||||
IsPrimaryKey: true,
|
||||
Description: "",
|
||||
|
@ -87,7 +87,7 @@ func constructCollectionSchema(
|
|||
AutoID: true,
|
||||
}
|
||||
fVec := &schemapb.FieldSchema{
|
||||
FieldID: 0,
|
||||
FieldID: 101,
|
||||
Name: floatVecField,
|
||||
IsPrimaryKey: false,
|
||||
Description: "",
|
||||
|
|
|
@ -828,7 +828,7 @@ func GetPrimaryFieldData(datas []*schemapb.FieldData, primaryFieldSchema *schema
|
|||
}
|
||||
|
||||
if primaryFieldData == nil {
|
||||
return nil, fmt.Errorf("can't find data for primary field %v", primaryFieldName)
|
||||
return nil, fmt.Errorf("can't find data for primary field: %v", primaryFieldName)
|
||||
}
|
||||
|
||||
return primaryFieldData, nil
|
||||
|
|
Loading…
Reference in New Issue