mirror of https://github.com/milvus-io/milvus.git
Add retrieve taskscheduler implementation (#5353)
See also: #5257 Signed-off-by: fishpenguin <kun.yu@zilliz.com>pull/5388/head
parent
bfc057d56d
commit
49443e8a33
|
@ -28,6 +28,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/proto/proxypb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
|
@ -1170,7 +1171,62 @@ func (node *ProxyNode) Search(ctx context.Context, request *milvuspb.SearchReque
|
|||
}
|
||||
|
||||
func (node *ProxyNode) Retrieve(ctx context.Context, request *milvuspb.RetrieveRequest) (*milvuspb.RetrieveResults, error) {
|
||||
return nil, nil
|
||||
rt := &RetrieveTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
RetrieveRequest: &internalpb.RetrieveRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Retrieve,
|
||||
SourceID: Params.ProxyID,
|
||||
},
|
||||
ResultChannelID: strconv.FormatInt(Params.ProxyID, 10),
|
||||
},
|
||||
queryMsgStream: node.queryMsgStream,
|
||||
resultBuf: make(chan []*internalpb.RetrieveResults),
|
||||
retrieve: request,
|
||||
}
|
||||
|
||||
err := node.sched.DqQueue.Enqueue(rt)
|
||||
if err != nil {
|
||||
return &milvuspb.RetrieveResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Debug("Retrieve",
|
||||
zap.String("role", Params.RoleName),
|
||||
zap.Int64("msgID", rt.Base.MsgID),
|
||||
zap.Uint64("timestamp", rt.Base.Timestamp),
|
||||
zap.String("db", request.DbName),
|
||||
zap.String("collection", request.CollectionName),
|
||||
zap.Any("partitions", request.PartitionNames),
|
||||
zap.Any("len(Ids)", len(request.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data)))
|
||||
defer func() {
|
||||
log.Debug("Retrieve Done",
|
||||
zap.Error(err),
|
||||
zap.String("role", Params.RoleName),
|
||||
zap.Int64("msgID", rt.Base.MsgID),
|
||||
zap.Uint64("timestamp", rt.Base.Timestamp),
|
||||
zap.String("db", request.DbName),
|
||||
zap.String("collection", request.CollectionName),
|
||||
zap.Any("partitions", request.PartitionNames),
|
||||
zap.Any("len(Ids)", len(request.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data)))
|
||||
}()
|
||||
|
||||
err = rt.WaitToFinish()
|
||||
if err != nil {
|
||||
return &milvuspb.RetrieveResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return rt.result, nil
|
||||
}
|
||||
|
||||
func (node *ProxyNode) Flush(ctx context.Context, request *milvuspb.FlushRequest) (*commonpb.Status, error) {
|
||||
|
|
|
@ -401,13 +401,17 @@ func (sched *TaskScheduler) queryResultLoop() {
|
|||
queryResultMsgStream, _ := sched.msFactory.NewQueryMsgStream(sched.ctx)
|
||||
queryResultMsgStream.AsConsumer(Params.SearchResultChannelNames, Params.ProxySubName)
|
||||
log.Debug("proxynode", zap.Strings("search result channel names", Params.SearchResultChannelNames))
|
||||
queryResultMsgStream.AsConsumer(Params.RetrieveResultChannelNames, Params.ProxySubName)
|
||||
log.Debug("proxynode", zap.Strings("Retrieve result channel names", Params.RetrieveResultChannelNames))
|
||||
log.Debug("proxynode", zap.String("proxySubName", Params.ProxySubName))
|
||||
|
||||
queryNodeNum := Params.QueryNodeNum
|
||||
|
||||
queryResultMsgStream.Start()
|
||||
defer queryResultMsgStream.Close()
|
||||
|
||||
queryResultBuf := make(map[UniqueID][]*internalpb.SearchResults)
|
||||
retrieveResultBuf := make(map[UniqueID][]*internalpb.RetrieveResults)
|
||||
|
||||
for {
|
||||
select {
|
||||
|
@ -422,41 +426,75 @@ func (sched *TaskScheduler) queryResultLoop() {
|
|||
for _, tsMsg := range msgPack.Msgs {
|
||||
sp, ctx := trace.StartSpanFromContext(tsMsg.TraceCtx())
|
||||
tsMsg.SetTraceCtx(ctx)
|
||||
searchResultMsg, _ := tsMsg.(*msgstream.SearchResultMsg)
|
||||
reqID := searchResultMsg.Base.MsgID
|
||||
reqIDStr := strconv.FormatInt(reqID, 10)
|
||||
t := sched.getTaskByReqID(reqID)
|
||||
if t == nil {
|
||||
log.Debug("proxynode", zap.String("QueryResult GetTaskByReqID failed, reqID = ", reqIDStr))
|
||||
delete(queryResultBuf, reqID)
|
||||
continue
|
||||
}
|
||||
|
||||
_, ok = queryResultBuf[reqID]
|
||||
if !ok {
|
||||
queryResultBuf[reqID] = make([]*internalpb.SearchResults, 0)
|
||||
}
|
||||
queryResultBuf[reqID] = append(queryResultBuf[reqID], &searchResultMsg.SearchResults)
|
||||
|
||||
//t := sched.getTaskByReqID(reqID)
|
||||
{
|
||||
colName := t.(*SearchTask).query.CollectionName
|
||||
log.Debug("Getcollection", zap.String("collection name", colName), zap.String("reqID", reqIDStr), zap.Int("answer cnt", len(queryResultBuf[reqID])))
|
||||
}
|
||||
if len(queryResultBuf[reqID]) == queryNodeNum {
|
||||
if searchResultMsg, srOk := tsMsg.(*msgstream.SearchResultMsg); srOk {
|
||||
reqID := searchResultMsg.Base.MsgID
|
||||
reqIDStr := strconv.FormatInt(reqID, 10)
|
||||
t := sched.getTaskByReqID(reqID)
|
||||
if t != nil {
|
||||
qt, ok := t.(*SearchTask)
|
||||
if ok {
|
||||
qt.resultBuf <- queryResultBuf[reqID]
|
||||
delete(queryResultBuf, reqID)
|
||||
}
|
||||
} else {
|
||||
|
||||
// log.Printf("task with reqID %v is nil", reqID)
|
||||
if t == nil {
|
||||
log.Debug("proxynode", zap.String("QueryResult GetTaskByReqID failed, reqID = ", reqIDStr))
|
||||
delete(queryResultBuf, reqID)
|
||||
continue
|
||||
}
|
||||
|
||||
_, ok = queryResultBuf[reqID]
|
||||
if !ok {
|
||||
queryResultBuf[reqID] = make([]*internalpb.SearchResults, 0)
|
||||
}
|
||||
queryResultBuf[reqID] = append(queryResultBuf[reqID], &searchResultMsg.SearchResults)
|
||||
|
||||
//t := sched.getTaskByReqID(reqID)
|
||||
{
|
||||
colName := t.(*SearchTask).query.CollectionName
|
||||
log.Debug("Getcollection", zap.String("collection name", colName), zap.String("reqID", reqIDStr), zap.Int("answer cnt", len(queryResultBuf[reqID])))
|
||||
}
|
||||
if len(queryResultBuf[reqID]) == queryNodeNum {
|
||||
t := sched.getTaskByReqID(reqID)
|
||||
if t != nil {
|
||||
qt, ok := t.(*SearchTask)
|
||||
if ok {
|
||||
qt.resultBuf <- queryResultBuf[reqID]
|
||||
delete(queryResultBuf, reqID)
|
||||
}
|
||||
} else {
|
||||
|
||||
// log.Printf("task with reqID %v is nil", reqID)
|
||||
}
|
||||
}
|
||||
sp.Finish()
|
||||
}
|
||||
if retrieveResultMsg, rtOk := tsMsg.(*msgstream.RetrieveResultMsg); rtOk {
|
||||
reqID := retrieveResultMsg.Base.MsgID
|
||||
reqIDStr := strconv.FormatInt(reqID, 10)
|
||||
t := sched.getTaskByReqID(reqID)
|
||||
if t == nil {
|
||||
log.Debug("proxynode", zap.String("RetrieveResult GetTaskByReqID failed, reqID = ", reqIDStr))
|
||||
delete(retrieveResultBuf, reqID)
|
||||
continue
|
||||
}
|
||||
|
||||
_, ok = retrieveResultBuf[reqID]
|
||||
if !ok {
|
||||
retrieveResultBuf[reqID] = make([]*internalpb.RetrieveResults, 0)
|
||||
}
|
||||
retrieveResultBuf[reqID] = append(retrieveResultBuf[reqID], &retrieveResultMsg.RetrieveResults)
|
||||
|
||||
{
|
||||
colName := t.(*RetrieveTask).retrieve.CollectionName
|
||||
log.Debug("Getcollection", zap.String("collection name", colName), zap.String("reqID", reqIDStr), zap.Int("answer cnt", len(retrieveResultBuf[reqID])))
|
||||
}
|
||||
if len(retrieveResultBuf[reqID]) == queryNodeNum {
|
||||
t := sched.getTaskByReqID(reqID)
|
||||
if t != nil {
|
||||
rt, ok := t.(*RetrieveTask)
|
||||
if ok {
|
||||
rt.resultBuf <- retrieveResultBuf[reqID]
|
||||
delete(retrieveResultBuf, reqID)
|
||||
}
|
||||
} else {
|
||||
}
|
||||
}
|
||||
sp.Finish()
|
||||
}
|
||||
sp.Finish()
|
||||
}
|
||||
case <-sched.ctx.Done():
|
||||
log.Debug("proxynode server is closed ...")
|
||||
|
|
Loading…
Reference in New Issue