package proxy

import (
	"context"
	"errors"
	"fmt"
	"regexp"
	"strings"
	"time"

	"github.com/golang/protobuf/proto"
	"go.uber.org/zap"

	"github.com/milvus-io/milvus/internal/log"
	"github.com/milvus-io/milvus/internal/mq/msgstream"
	"github.com/milvus-io/milvus/internal/types"

	"github.com/milvus-io/milvus/internal/proto/commonpb"
	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/internal/proto/milvuspb"
	"github.com/milvus-io/milvus/internal/proto/querypb"
	"github.com/milvus-io/milvus/internal/proto/schemapb"

	"github.com/milvus-io/milvus/internal/util/timerecord"
	"github.com/milvus-io/milvus/internal/util/tsoutil"
	"github.com/milvus-io/milvus/internal/util/typeutil"
)
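
// queryTask wraps a client retrieve (query) request and carries its state
// through the proxy task pipeline: PreExecute validates and plans the request,
// Execute publishes it to the DQL message stream, and PostExecute merges the
// results returned by the query nodes.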
type queryTask struct {
	Condition
	*internalpb.RetrieveRequest
	ctx            context.Context
	resultBuf      chan []*internalpb.RetrieveResults
	result         *milvuspb.QueryResults
	query          *milvuspb.QueryRequest
	chMgr          channelsMgr
	qc             types.QueryCoord
	ids            *schemapb.IDs
	collectionName string
	collectionID   UniqueID
}
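
// PreExecute validates the request, checks that the target collection is
// loaded on the query nodes, builds the serialized retrieve plan from the
// expression, and resolves output fields, timestamps and partition IDs.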
func (qt *queryTask) PreExecute(ctx context.Context) error {
	qt.Base.MsgType = commonpb.MsgType_Retrieve
	qt.Base.SourceID = Params.ProxyCfg.ProxyID

	collectionName := qt.query.CollectionName

	if err := validateCollectionName(qt.query.CollectionName); err != nil {
		log.Warn("Invalid collection name.", zap.String("collectionName", collectionName),
			zap.Int64("requestID", qt.Base.MsgID), zap.String("requestType", "query"))
		return err
	}
	log.Info("Validate collection name.", zap.Any("collectionName", collectionName),
		zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))

	info, err := globalMetaCache.GetCollectionInfo(ctx, collectionName)
	if err != nil {
		log.Debug("Failed to get collection id.", zap.Any("collectionName", collectionName),
			zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
		return err
	}
	qt.collectionName = info.schema.Name
	log.Info("Get collection id by name.", zap.Any("collectionName", collectionName),
		zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))

	for _, tag := range qt.query.PartitionNames {
		if err := validatePartitionTag(tag, false); err != nil {
			log.Debug("Invalid partition name.", zap.Any("partitionName", tag),
				zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
			return err
		}
	}
	log.Info("Validate partition names.",
		zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))

	// check if collection was already loaded into query node
	showResp, err := qt.qc.ShowCollections(qt.ctx, &querypb.ShowCollectionsRequest{
		Base: &commonpb.MsgBase{
			MsgType:   commonpb.MsgType_ShowCollections,
			MsgID:     qt.Base.MsgID,
			Timestamp: qt.Base.Timestamp,
			SourceID:  Params.ProxyCfg.ProxyID,
		},
		DbID: 0, // TODO(dragondriver)
	})
	if err != nil {
		return err
	}
	if showResp.Status.ErrorCode != commonpb.ErrorCode_Success {
		return errors.New(showResp.Status.Reason)
	}
	log.Debug("QueryCoord show collections",
		zap.Any("collections", showResp.CollectionIDs),
		zap.Any("collID", info.collID))

	collectionLoaded := false
	for _, collID := range showResp.CollectionIDs {
		if info.collID == collID {
			collectionLoaded = true
			break
		}
	}
	if !collectionLoaded {
		return fmt.Errorf("collection %v was not loaded into memory", collectionName)
	}

	schema, err := globalMetaCache.GetCollectionSchema(ctx, qt.query.CollectionName)
	if err != nil {
		return err
	}
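
	// If the client supplied primary-key IDs instead of an expression,
	// rewrite them into an equivalent term expression on the primary-key field.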
	if qt.ids != nil {
		pkField := ""
		for _, field := range schema.Fields {
			if field.IsPrimaryKey {
				pkField = field.Name
			}
		}
		qt.query.Expr = IDs2Expr(pkField, qt.ids.GetIntId().Data)
	}

	if qt.query.Expr == "" {
		return errors.New("query expression is empty")
	}

	plan, err := createExprPlan(schema, qt.query.Expr)
	if err != nil {
		return err
	}
	qt.query.OutputFields, err = translateOutputFields(qt.query.OutputFields, schema, true)
	if err != nil {
		return err
	}
	log.Debug("translate output fields", zap.Any("OutputFields", qt.query.OutputFields))
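	// With no explicit output fields, return every user-defined non-vector field
	// (FieldID >= 100); otherwise map the requested names to field IDs and make
	// sure the primary-key field is always included.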
	if len(qt.query.OutputFields) == 0 {
		for _, field := range schema.Fields {
			if field.FieldID >= 100 && field.DataType != schemapb.DataType_FloatVector && field.DataType != schemapb.DataType_BinaryVector {
				qt.OutputFieldsId = append(qt.OutputFieldsId, field.FieldID)
			}
		}
	} else {
		addPrimaryKey := false
		for _, reqField := range qt.query.OutputFields {
			findField := false
			for _, field := range schema.Fields {
				if reqField == field.Name {
					if field.IsPrimaryKey {
						addPrimaryKey = true
					}
					findField = true
					qt.OutputFieldsId = append(qt.OutputFieldsId, field.FieldID)
					plan.OutputFieldIds = append(plan.OutputFieldIds, field.FieldID)
				} else {
					if field.IsPrimaryKey && !addPrimaryKey {
						qt.OutputFieldsId = append(qt.OutputFieldsId, field.FieldID)
						plan.OutputFieldIds = append(plan.OutputFieldIds, field.FieldID)
						addPrimaryKey = true
					}
				}
			}
			if !findField {
				return errors.New("field " + reqField + " does not exist")
			}
		}
	}
	log.Debug("translate output fields to field ids", zap.Any("OutputFieldsID", qt.OutputFieldsId))

	qt.RetrieveRequest.SerializedExprPlan, err = proto.Marshal(plan)
	if err != nil {
		return err
	}
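	// Default the travel and guarantee timestamps to the task's begin timestamp,
	// and reject travel timestamps that fall outside the retention window.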
	travelTimestamp := qt.query.TravelTimestamp
	if travelTimestamp == 0 {
		travelTimestamp = qt.BeginTs()
	} else {
		durationSeconds := tsoutil.CalculateDuration(qt.BeginTs(), travelTimestamp) / 1000
		if durationSeconds > Params.CommonCfg.RetentionDuration {
			duration := time.Second * time.Duration(durationSeconds)
			return fmt.Errorf("only support to travel back to %s so far", duration.String())
		}
	}
	guaranteeTimestamp := qt.query.GuaranteeTimestamp
	if guaranteeTimestamp == 0 {
		guaranteeTimestamp = qt.BeginTs()
	}
	qt.TravelTimestamp = travelTimestamp
	qt.GuaranteeTimestamp = guaranteeTimestamp
	deadline, ok := qt.TraceCtx().Deadline()
	if ok {
		qt.RetrieveRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
	}

	qt.ResultChannelID = Params.ProxyCfg.RetrieveResultChannelNames[0]
	qt.DbID = 0 // todo(yukun)

	qt.CollectionID = info.collID
	qt.PartitionIDs = make([]UniqueID, 0)

	partitionsMap, err := globalMetaCache.GetPartitions(ctx, collectionName)
	if err != nil {
		log.Debug("Failed to get partitions in collection.", zap.Any("collectionName", collectionName),
			zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
		return err
	}
	log.Info("Get partitions in collection.", zap.Any("collectionName", collectionName),
		zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
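
	// Each requested partition name is compiled into an anchored regular
	// expression and matched against the collection's partitions; matching
	// partition IDs are collected once each.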
	partitionsRecord := make(map[UniqueID]bool)
	for _, partitionName := range qt.query.PartitionNames {
		pattern := fmt.Sprintf("^%s$", partitionName)
		re, err := regexp.Compile(pattern)
		if err != nil {
			log.Debug("Failed to compile partition name regex expression.", zap.Any("partitionName", partitionName),
				zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
			return errors.New("invalid partition names")
		}
		found := false
		for name, pID := range partitionsMap {
			if re.MatchString(name) {
				if _, exist := partitionsRecord[pID]; !exist {
					qt.PartitionIDs = append(qt.PartitionIDs, pID)
					partitionsRecord[pID] = true
				}
				found = true
			}
		}
		if !found {
			// FIXME(wxyu): undefined behavior
			return fmt.Errorf("partition name %s not found", partitionName)
		}
	}

	log.Info("Query PreExecute done.",
		zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
	return nil
}
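
// Execute wraps the retrieve request in a RetrieveMsg and publishes it on the
// collection's DQL message stream; the actual retrieval is performed by the
// query nodes that consume that stream.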
func (qt *queryTask) Execute(ctx context.Context) error {
	tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute query %d", qt.ID()))
	defer tr.Elapse("done")

	var tsMsg msgstream.TsMsg = &msgstream.RetrieveMsg{
		RetrieveRequest: *qt.RetrieveRequest,
		BaseMsg: msgstream.BaseMsg{
			Ctx:            ctx,
			HashValues:     []uint32{uint32(Params.ProxyCfg.ProxyID)},
			BeginTimestamp: qt.Base.Timestamp,
			EndTimestamp:   qt.Base.Timestamp,
		},
	}
	msgPack := msgstream.MsgPack{
		BeginTs: qt.Base.Timestamp,
		EndTs:   qt.Base.Timestamp,
		Msgs:    make([]msgstream.TsMsg, 1),
	}
	msgPack.Msgs[0] = tsMsg
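
	// Lazily create the DQL stream for this collection if it does not exist yet.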
	stream, err := qt.chMgr.getDQLStream(qt.CollectionID)
	if err != nil {
		err = qt.chMgr.createDQLStream(qt.CollectionID)
		if err != nil {
			qt.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
			qt.result.Status.Reason = err.Error()
			return err
		}
		stream, err = qt.chMgr.getDQLStream(qt.CollectionID)
		if err != nil {
			qt.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
			qt.result.Status.Reason = err.Error()
			return err
		}
	}
	tr.Record("get used message stream")

	err = stream.Produce(&msgPack)
	if err != nil {
		log.Debug("Failed to send retrieve request.",
			zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
	}
	log.Debug("proxy sent one retrieveMsg",
		zap.Int64("collectionID", qt.CollectionID),
		zap.Int64("msgID", tsMsg.ID()),
		zap.Int("length of retrieve msg", len(msgPack.Msgs)),
		zap.Uint64("timeoutTs", qt.RetrieveRequest.TimeoutTimestamp))
	tr.Record("send retrieve request to message stream")

	log.Info("Query Execute done.",
		zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
	return err
}
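
// PostExecute waits for the retrieve results from all query nodes, merges the
// successful partial results, and annotates the returned field data with the
// field names and types from the collection schema.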
func (qt *queryTask) PostExecute(ctx context.Context) error {
	tr := timerecord.NewTimeRecorder("queryTask PostExecute")
	defer func() {
		tr.Elapse("done")
	}()
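	// Block until either the trace context expires or the partial results arrive.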
	select {
	case <-qt.TraceCtx().Done():
		log.Debug("Query: wait to finish failed, timeout!", zap.Int64("taskID", qt.ID()))
		return fmt.Errorf("queryTask: wait to finish failed, timeout: %d", qt.ID())
	case retrieveResults := <-qt.resultBuf:
		filterRetrieveResults := make([]*internalpb.RetrieveResults, 0)
		var reason string
		for _, partialRetrieveResult := range retrieveResults {
			if partialRetrieveResult.Status.ErrorCode == commonpb.ErrorCode_Success {
				filterRetrieveResults = append(filterRetrieveResults, partialRetrieveResult)
			} else {
				reason += partialRetrieveResult.Status.Reason + "\n"
			}
		}

		if len(filterRetrieveResults) == 0 {
			qt.result = &milvuspb.QueryResults{
				Status: &commonpb.Status{
					ErrorCode: commonpb.ErrorCode_UnexpectedError,
					Reason:    reason,
				},
				CollectionName: qt.collectionName,
			}
			log.Debug("Query failed on all query nodes.",
				zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
			return errors.New(reason)
		}

		var err error
		qt.result, err = mergeRetrieveResults(filterRetrieveResults)
		if err != nil {
			return err
		}
		qt.result.CollectionName = qt.collectionName

		if len(qt.result.FieldsData) > 0 {
			qt.result.Status = &commonpb.Status{
				ErrorCode: commonpb.ErrorCode_Success,
			}
		} else {
			log.Info("Query result is nil", zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
			qt.result.Status = &commonpb.Status{
				ErrorCode: commonpb.ErrorCode_EmptyCollection,
				Reason:    reason,
			}
			return nil
		}
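
		// Fill in the field name, ID and type for each returned column using the
		// collection schema, so the client receives fully described field data.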
		schema, err := globalMetaCache.GetCollectionSchema(ctx, qt.query.CollectionName)
		if err != nil {
			return err
		}
		for i := 0; i < len(qt.result.FieldsData); i++ {
			for _, field := range schema.Fields {
				if field.FieldID == qt.OutputFieldsId[i] {
					qt.result.FieldsData[i].FieldName = field.Name
					qt.result.FieldsData[i].FieldId = field.FieldID
					qt.result.FieldsData[i].Type = field.DataType
				}
			}
		}
	}

	log.Info("Query PostExecute done", zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
	return nil
}
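
// getChannels returns the physical DML channels of the target collection,
// creating the DML message stream on first use.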
func (qt *queryTask) getChannels() ([]pChan, error) {
	collID, err := globalMetaCache.GetCollectionID(qt.ctx, qt.query.CollectionName)
	if err != nil {
		return nil, err
	}
	var channels []pChan
	channels, err = qt.chMgr.getChannels(collID)
	if err != nil {
		err := qt.chMgr.createDMLMsgStream(collID)
		if err != nil {
			return nil, err
		}
		return qt.chMgr.getChannels(collID)
	}

	return channels, nil
}
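
// getVChannels returns the virtual DML channels of the target collection,
// creating the DML message stream on first use.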
func (qt *queryTask) getVChannels() ([]vChan, error) {
	collID, err := globalMetaCache.GetCollectionID(qt.ctx, qt.query.CollectionName)
	if err != nil {
		return nil, err
	}

	var channels []vChan
	channels, err = qt.chMgr.getVChannels(collID)
	if err != nil {
		err := qt.chMgr.createDMLMsgStream(collID)
		if err != nil {
			return nil, err
		}
		return qt.chMgr.getVChannels(collID)
	}

	return channels, nil
}

// IDs2Expr converts an ID slice into a boolean "in" expression on the specified field name.
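// For example, IDs2Expr("pk", []int64{1, 2, 3}) returns "pk in [ 1, 2, 3 ]".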
func IDs2Expr(fieldName string, ids []int64) string {
	idsStr := strings.Trim(strings.Join(strings.Fields(fmt.Sprint(ids)), ", "), "[]")
	return fieldName + " in [ " + idsStr + " ]"
}
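
// mergeRetrieveResults merges the partial results returned by the query nodes
// into a single QueryResults, dropping empty partial results and rows whose
// primary key has already been seen.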
func mergeRetrieveResults(retrieveResults []*internalpb.RetrieveResults) (*milvuspb.QueryResults, error) {
	var ret *milvuspb.QueryResults
	var skipDupCnt int64
	var idSet = make(map[int64]struct{})

	// merge results and remove duplicates
	for _, rr := range retrieveResults {
		// skip empty results; they would break the merge below
		if rr == nil || rr.Ids == nil || rr.Ids.GetIntId() == nil || len(rr.Ids.GetIntId().Data) == 0 {
			continue
		}

		if ret == nil {
			ret = &milvuspb.QueryResults{
				FieldsData: make([]*schemapb.FieldData, len(rr.FieldsData)),
			}
		}

		if len(ret.FieldsData) != len(rr.FieldsData) {
			return nil, fmt.Errorf("mismatch FieldData in proxy RetrieveResults, expect %d get %d", len(ret.FieldsData), len(rr.FieldsData))
		}

		for i, id := range rr.Ids.GetIntId().GetData() {
			if _, ok := idSet[id]; !ok {
				typeutil.AppendFieldData(ret.FieldsData, rr.FieldsData, int64(i))
				idSet[id] = struct{}{}
			} else {
				// duplicate primary key, skip this row
				skipDupCnt++
			}
		}
	}
	log.Debug("skip duplicated query result", zap.Int64("count", skipDupCnt))

	if ret == nil {
		ret = &milvuspb.QueryResults{
			FieldsData: []*schemapb.FieldData{},
		}
	}

	return ret, nil
}
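
// The methods below implement the task interface for queryTask: tracing
// context, task ID, name, type, timestamps, and the enqueue hook.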
func (qt *queryTask) TraceCtx() context.Context {
	return qt.ctx
}

func (qt *queryTask) ID() UniqueID {
	return qt.Base.MsgID
}

func (qt *queryTask) SetID(uid UniqueID) {
	qt.Base.MsgID = uid
}

func (qt *queryTask) Name() string {
	return RetrieveTaskName
}

func (qt *queryTask) Type() commonpb.MsgType {
	return qt.Base.MsgType
}

func (qt *queryTask) BeginTs() Timestamp {
	return qt.Base.Timestamp
}

func (qt *queryTask) EndTs() Timestamp {
	return qt.Base.Timestamp
}

func (qt *queryTask) SetTs(ts Timestamp) {
	qt.Base.Timestamp = ts
}

func (qt *queryTask) OnEnqueue() error {
	qt.Base.MsgType = commonpb.MsgType_Retrieve
	return nil
}