mirror of https://github.com/milvus-io/milvus.git
Use asynchronous functions of load and release
Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
pull/4973/head^2
parent bea21a823c
commit 39458697c7
@@ -0,0 +1,388 @@
package querynode

import (
    "context"
    "errors"
    "fmt"
    "strconv"
    "strings"

    "go.uber.org/zap"

    "github.com/zilliztech/milvus-distributed/internal/log"
    "github.com/zilliztech/milvus-distributed/internal/msgstream"
    "github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
    "github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
    "github.com/zilliztech/milvus-distributed/internal/proto/milvuspb"
    queryPb "github.com/zilliztech/milvus-distributed/internal/proto/querypb"
    "github.com/zilliztech/milvus-distributed/internal/util/typeutil"
)

func (node *QueryNode) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
    stats := &internalpb.ComponentStates{
        Status: &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_Success,
        },
    }
    code, ok := node.stateCode.Load().(internalpb.StateCode)
    if !ok {
        errMsg := "unexpected error in type assertion"
        stats.Status = &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_UnexpectedError,
            Reason:    errMsg,
        }
        return stats, errors.New(errMsg)
    }
    info := &internalpb.ComponentInfo{
        NodeID:    Params.QueryNodeID,
        Role:      typeutil.QueryNodeRole,
        StateCode: code,
    }
    stats.State = info
    return stats, nil
}

func (node *QueryNode) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
    return &milvuspb.StringResponse{
        Status: &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_Success,
            Reason:    "",
        },
        Value: Params.QueryTimeTickChannelName,
    }, nil
}

func (node *QueryNode) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
    return &milvuspb.StringResponse{
        Status: &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_Success,
            Reason:    "",
        },
        Value: Params.StatsChannelName,
    }, nil
}

func (node *QueryNode) AddQueryChannel(ctx context.Context, in *queryPb.AddQueryChannelRequest) (*commonpb.Status, error) {
    if node.searchService == nil || node.searchService.searchMsgStream == nil {
        errMsg := "null search service or null search message stream"
        status := &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_UnexpectedError,
            Reason:    errMsg,
        }

        return status, errors.New(errMsg)
    }

    // add request channel
    consumeChannels := []string{in.RequestChannelID}
    consumeSubName := Params.MsgChannelSubName
    node.searchService.searchMsgStream.AsConsumer(consumeChannels, consumeSubName)
    log.Debug("querynode AsConsumer: " + strings.Join(consumeChannels, ", ") + " : " + consumeSubName)

    // add result channel
    producerChannels := []string{in.ResultChannelID}
    node.searchService.searchResultMsgStream.AsProducer(producerChannels)
    log.Debug("querynode AsProducer: " + strings.Join(producerChannels, ", "))

    status := &commonpb.Status{
        ErrorCode: commonpb.ErrorCode_Success,
    }
    return status, nil
}

func (node *QueryNode) RemoveQueryChannel(ctx context.Context, in *queryPb.RemoveQueryChannelRequest) (*commonpb.Status, error) {
    // if node.searchService == nil || node.searchService.searchMsgStream == nil {
    //     errMsg := "null search service or null search result message stream"
    //     status := &commonpb.Status{
    //         ErrorCode: commonpb.ErrorCode_UnexpectedError,
    //         Reason:    errMsg,
    //     }

    //     return status, errors.New(errMsg)
    // }

    // searchStream, ok := node.searchService.searchMsgStream.(*pulsarms.PulsarMsgStream)
    // if !ok {
    //     errMsg := "type assertion failed for search message stream"
    //     status := &commonpb.Status{
    //         ErrorCode: commonpb.ErrorCode_UnexpectedError,
    //         Reason:    errMsg,
    //     }

    //     return status, errors.New(errMsg)
    // }

    // resultStream, ok := node.searchService.searchResultMsgStream.(*pulsarms.PulsarMsgStream)
    // if !ok {
    //     errMsg := "type assertion failed for search result message stream"
    //     status := &commonpb.Status{
    //         ErrorCode: commonpb.ErrorCode_UnexpectedError,
    //         Reason:    errMsg,
    //     }

    //     return status, errors.New(errMsg)
    // }

    // // remove request channel
    // consumeChannels := []string{in.RequestChannelID}
    // consumeSubName := Params.MsgChannelSubName
    // // TODO: searchStream.RemovePulsarConsumers(producerChannels)
    // searchStream.AsConsumer(consumeChannels, consumeSubName)

    // // remove result channel
    // producerChannels := []string{in.ResultChannelID}
    // // TODO: resultStream.RemovePulsarProducer(producerChannels)
    // resultStream.AsProducer(producerChannels)

    status := &commonpb.Status{
        ErrorCode: commonpb.ErrorCode_Success,
    }
    return status, nil
}

func (node *QueryNode) WatchDmChannels(ctx context.Context, in *queryPb.WatchDmChannelsRequest) (*commonpb.Status, error) {
    log.Debug("starting WatchDmChannels ...", zap.String("ChannelIDs", fmt.Sprintln(in.ChannelIDs)))
    collectionID := in.CollectionID
    ds, err := node.getDataSyncService(collectionID)
    if err != nil || ds.dmStream == nil {
        errMsg := "null data sync service or null data manipulation stream, collectionID = " + fmt.Sprintln(collectionID)
        status := &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_UnexpectedError,
            Reason:    errMsg,
        }
        log.Error(errMsg)
        return status, errors.New(errMsg)
    }

    switch t := ds.dmStream.(type) {
    case *msgstream.MqTtMsgStream:
    default:
        _ = t
        errMsg := "type assertion failed for dm message stream"
        status := &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_UnexpectedError,
            Reason:    errMsg,
        }
        log.Error(errMsg)
        return status, errors.New(errMsg)
    }

    getUniqueSubName := func() string {
        prefixName := Params.MsgChannelSubName
        return prefixName + "-" + strconv.FormatInt(collectionID, 10)
    }

    // add request channel
    consumeChannels := in.ChannelIDs
    toSeekInfo := make([]*internalpb.MsgPosition, 0)
    toDirSubChannels := make([]string, 0)

    consumeSubName := getUniqueSubName()

    for _, info := range in.Infos {
        if len(info.Pos.MsgID) == 0 {
            toDirSubChannels = append(toDirSubChannels, info.ChannelID)
            continue
        }
        info.Pos.MsgGroup = consumeSubName
        toSeekInfo = append(toSeekInfo, info.Pos)

        log.Debug("prevent inserting segments", zap.String("segmentIDs", fmt.Sprintln(info.ExcludedSegments)))
        err := node.replica.addExcludedSegments(collectionID, info.ExcludedSegments)
        if err != nil {
            status := &commonpb.Status{
                ErrorCode: commonpb.ErrorCode_UnexpectedError,
                Reason:    err.Error(),
            }
            log.Error(err.Error())
            return status, err
        }
    }

    ds.dmStream.AsConsumer(toDirSubChannels, consumeSubName)
    for _, pos := range toSeekInfo {
        err := ds.dmStream.Seek(pos)
        if err != nil {
            errMsg := "msgStream seek error :" + err.Error()
            status := &commonpb.Status{
                ErrorCode: commonpb.ErrorCode_UnexpectedError,
                Reason:    errMsg,
            }
            log.Error(errMsg)
            return status, errors.New(errMsg)
        }
    }
    log.Debug("querynode AsConsumer: " + strings.Join(consumeChannels, ", ") + " : " + consumeSubName)

    status := &commonpb.Status{
        ErrorCode: commonpb.ErrorCode_Success,
    }
    log.Debug("WatchDmChannels done", zap.String("ChannelIDs", fmt.Sprintln(in.ChannelIDs)))
    return status, nil
}

func (node *QueryNode) LoadSegments(ctx context.Context, in *queryPb.LoadSegmentsRequest) (*commonpb.Status, error) {
    dct := &loadSegmentsTask{
        baseTask: baseTask{
            ctx:  ctx,
            done: make(chan error),
        },
        req:  in,
        node: node,
    }

    err := node.scheduler.queue.Enqueue(dct)
    if err != nil {
        status := &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_UnexpectedError,
            Reason:    err.Error(),
        }
        log.Error(err.Error())
        return status, err
    }
    log.Debug("loadSegmentsTask Enqueue done", zap.Any("collectionID", in.CollectionID))

    err = dct.WaitToFinish()
    if err != nil {
        status := &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_UnexpectedError,
            Reason:    err.Error(),
        }
        log.Error(err.Error())
        return status, err
    }
    log.Debug("loadSegmentsTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))

    status := &commonpb.Status{
        ErrorCode: commonpb.ErrorCode_Success,
    }
    return status, nil
}

func (node *QueryNode) ReleaseCollection(ctx context.Context, in *queryPb.ReleaseCollectionRequest) (*commonpb.Status, error) {
    dct := &releaseCollectionTask{
        baseTask: baseTask{
            ctx:  ctx,
            done: make(chan error),
        },
        req:  in,
        node: node,
    }

    err := node.scheduler.queue.Enqueue(dct)
    if err != nil {
        status := &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_UnexpectedError,
            Reason:    err.Error(),
        }
        log.Error(err.Error())
        return status, err
    }
    log.Debug("releaseCollectionTask Enqueue done", zap.Any("collectionID", in.CollectionID))

    err = dct.WaitToFinish()
    if err != nil {
        status := &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_UnexpectedError,
            Reason:    err.Error(),
        }
        log.Error(err.Error())
        return status, err
    }
    log.Debug("releaseCollectionTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))

    status := &commonpb.Status{
        ErrorCode: commonpb.ErrorCode_Success,
    }
    return status, nil
}

func (node *QueryNode) ReleasePartitions(ctx context.Context, in *queryPb.ReleasePartitionsRequest) (*commonpb.Status, error) {
    dct := &releasePartitionsTask{
        baseTask: baseTask{
            ctx:  ctx,
            done: make(chan error),
        },
        req:  in,
        node: node,
    }

    err := node.scheduler.queue.Enqueue(dct)
    if err != nil {
        status := &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_UnexpectedError,
            Reason:    err.Error(),
        }
        log.Error(err.Error())
        return status, err
    }
    log.Debug("releasePartitionsTask Enqueue done", zap.Any("collectionID", in.CollectionID))

    err = dct.WaitToFinish()
    if err != nil {
        status := &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_UnexpectedError,
            Reason:    err.Error(),
        }
        log.Error(err.Error())
        return status, err
    }
    log.Debug("releasePartitionsTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))

    status := &commonpb.Status{
        ErrorCode: commonpb.ErrorCode_Success,
    }
    return status, nil
}

// deprecated
func (node *QueryNode) ReleaseSegments(ctx context.Context, in *queryPb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
    status := &commonpb.Status{
        ErrorCode: commonpb.ErrorCode_Success,
    }
    for _, id := range in.SegmentIDs {
        err2 := node.loadService.segLoader.replica.removeSegment(id)
        if err2 != nil {
            // not return, try to release all segments
            status.ErrorCode = commonpb.ErrorCode_UnexpectedError
            status.Reason = err2.Error()
        }
    }
    return status, nil
}

func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *queryPb.GetSegmentInfoRequest) (*queryPb.GetSegmentInfoResponse, error) {
    infos := make([]*queryPb.SegmentInfo, 0)
    for _, id := range in.SegmentIDs {
        segment, err := node.replica.getSegmentByID(id)
        if err != nil {
            continue
        }
        var indexName string
        var indexID int64
        // TODO:: segment has multi vec column
        if len(segment.indexInfos) > 0 {
            for fieldID := range segment.indexInfos {
                indexName = segment.getIndexName(fieldID)
                indexID = segment.getIndexID(fieldID)
                break
            }
        }
        info := &queryPb.SegmentInfo{
            SegmentID:    segment.ID(),
            CollectionID: segment.collectionID,
            PartitionID:  segment.partitionID,
            MemSize:      segment.getMemSize(),
            NumRows:      segment.getRowCount(),
            IndexName:    indexName,
            IndexID:      indexID,
        }
        infos = append(infos, info)
    }
    return &queryPb.GetSegmentInfoResponse{
        Status: &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_Success,
        },
        Infos: infos,
    }, nil
}
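The new LoadSegments, ReleaseCollection, and ReleasePartitions handlers above all follow the same pattern: wrap the request in a task, enqueue it on the node's scheduler, block on WaitToFinish, and map the resulting error into a commonpb.Status. A minimal, self-contained sketch of that pattern is below; the task and scheduler types here are simplified stand-ins for illustration, not the repository's actual types.

package main

import (
    "context"
    "errors"
    "fmt"
)

// demoTask and demoScheduler are illustrative stand-ins for the real task/scheduler types.
type demoTask struct {
    ctx  context.Context
    done chan error
    run  func() error
}

// WaitToFinish blocks until the scheduler reports a result or the caller's context ends.
func (t *demoTask) WaitToFinish() error {
    select {
    case <-t.ctx.Done():
        return errors.New("task timeout")
    case err := <-t.done:
        return err
    }
}

type demoScheduler struct{ queue chan *demoTask }

// Enqueue hands the task to the scheduler without blocking the handler.
func (s *demoScheduler) Enqueue(t *demoTask) error {
    select {
    case s.queue <- t:
        return nil
    default:
        return errors.New("task queue is full")
    }
}

// loop executes tasks and notifies the waiting handler with the result.
func (s *demoScheduler) loop() {
    for t := range s.queue {
        t.done <- t.run()
    }
}

// loadSegments shows the handler-side shape: enqueue, wait, map error to a status.
func loadSegments(ctx context.Context, s *demoScheduler) string {
    t := &demoTask{ctx: ctx, done: make(chan error), run: func() error { return nil }}
    if err := s.Enqueue(t); err != nil {
        return "UnexpectedError: " + err.Error()
    }
    if err := t.WaitToFinish(); err != nil {
        return "UnexpectedError: " + err.Error()
    }
    return "Success"
}

func main() {
    s := &demoScheduler{queue: make(chan *demoTask, 16)}
    go s.loop()
    fmt.Println(loadSegments(context.Background(), s))
}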
@@ -17,8 +17,6 @@ import (
    "errors"
    "fmt"
    "math/rand"
    "strconv"
    "strings"
    "sync"
    "sync/atomic"
    "time"

@@ -29,10 +27,8 @@ import (
    "github.com/zilliztech/milvus-distributed/internal/msgstream"
    "github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
    "github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
    "github.com/zilliztech/milvus-distributed/internal/proto/milvuspb"
    queryPb "github.com/zilliztech/milvus-distributed/internal/proto/querypb"
    "github.com/zilliztech/milvus-distributed/internal/types"
    "github.com/zilliztech/milvus-distributed/internal/util/typeutil"
)

type QueryNode struct {

@@ -59,6 +55,7 @@ type QueryNode struct {
    dataService types.DataService

    msFactory msgstream.Factory
    scheduler *taskScheduler
}

func NewQueryNode(ctx context.Context, queryNodeID UniqueID, factory msgstream.Factory) *QueryNode {

@@ -77,6 +74,7 @@ func NewQueryNode(ctx context.Context, queryNodeID UniqueID, factory msgstream.F
        msFactory: factory,
    }

    node.scheduler = newTaskScheduler(ctx1)
    node.replica = newCollectionReplica()
    node.UpdateStateCode(internalpb.StateCode_Abnormal)
    return node

@@ -96,6 +94,7 @@ func NewQueryNodeWithoutID(ctx context.Context, factory msgstream.Factory) *Quer
        msFactory: factory,
    }

    node.scheduler = newTaskScheduler(ctx1)
    node.replica = newCollectionReplica()
    node.UpdateStateCode(internalpb.StateCode_Abnormal)

@@ -167,14 +166,14 @@ func (node *QueryNode) Start() error {

    // init services and manager
    node.searchService = newSearchService(node.queryNodeLoopCtx, node.replica, node.msFactory)
    //node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)

    node.loadService = newLoadService(node.queryNodeLoopCtx, node.masterService, node.dataService, node.indexService, node.replica)
    node.statsService = newStatsService(node.queryNodeLoopCtx, node.replica, node.loadService.segLoader.indexLoader.fieldStatsChan, node.msFactory)

    // start task scheduler
    go node.scheduler.Start()

    // start services
    go node.searchService.start()
    //go node.metaService.start()
    go node.loadService.start()
    go node.statsService.start()
    node.UpdateStateCode(internalpb.StateCode_Healthy)

@@ -267,366 +266,3 @@ func (node *QueryNode) removeDataSyncService(collectionID UniqueID) {
    defer node.dsServicesMu.Unlock()
    delete(node.dataSyncServices, collectionID)
}

func (node *QueryNode) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
    stats := &internalpb.ComponentStates{
        Status: &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_Success,
        },
    }
    code, ok := node.stateCode.Load().(internalpb.StateCode)
    if !ok {
        errMsg := "unexpected error in type assertion"
        stats.Status = &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_UnexpectedError,
            Reason:    errMsg,
        }
        return stats, errors.New(errMsg)
    }
    info := &internalpb.ComponentInfo{
        NodeID:    Params.QueryNodeID,
        Role:      typeutil.QueryNodeRole,
        StateCode: code,
    }
    stats.State = info
    return stats, nil
}

func (node *QueryNode) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
    return &milvuspb.StringResponse{
        Status: &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_Success,
            Reason:    "",
        },
        Value: Params.QueryTimeTickChannelName,
    }, nil
}

func (node *QueryNode) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
    return &milvuspb.StringResponse{
        Status: &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_Success,
            Reason:    "",
        },
        Value: Params.StatsChannelName,
    }, nil
}

func (node *QueryNode) AddQueryChannel(ctx context.Context, in *queryPb.AddQueryChannelRequest) (*commonpb.Status, error) {
    if node.searchService == nil || node.searchService.searchMsgStream == nil {
        errMsg := "null search service or null search message stream"
        status := &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_UnexpectedError,
            Reason:    errMsg,
        }

        return status, errors.New(errMsg)
    }

    // add request channel
    consumeChannels := []string{in.RequestChannelID}
    consumeSubName := Params.MsgChannelSubName
    node.searchService.searchMsgStream.AsConsumer(consumeChannels, consumeSubName)
    log.Debug("querynode AsConsumer: " + strings.Join(consumeChannels, ", ") + " : " + consumeSubName)

    // add result channel
    producerChannels := []string{in.ResultChannelID}
    node.searchService.searchResultMsgStream.AsProducer(producerChannels)
    log.Debug("querynode AsProducer: " + strings.Join(producerChannels, ", "))

    status := &commonpb.Status{
        ErrorCode: commonpb.ErrorCode_Success,
    }
    return status, nil
}

func (node *QueryNode) RemoveQueryChannel(ctx context.Context, in *queryPb.RemoveQueryChannelRequest) (*commonpb.Status, error) {
    // if node.searchService == nil || node.searchService.searchMsgStream == nil {
    //     errMsg := "null search service or null search result message stream"
    //     status := &commonpb.Status{
    //         ErrorCode: commonpb.ErrorCode_UnexpectedError,
    //         Reason:    errMsg,
    //     }

    //     return status, errors.New(errMsg)
    // }

    // searchStream, ok := node.searchService.searchMsgStream.(*pulsarms.PulsarMsgStream)
    // if !ok {
    //     errMsg := "type assertion failed for search message stream"
    //     status := &commonpb.Status{
    //         ErrorCode: commonpb.ErrorCode_UnexpectedError,
    //         Reason:    errMsg,
    //     }

    //     return status, errors.New(errMsg)
    // }

    // resultStream, ok := node.searchService.searchResultMsgStream.(*pulsarms.PulsarMsgStream)
    // if !ok {
    //     errMsg := "type assertion failed for search result message stream"
    //     status := &commonpb.Status{
    //         ErrorCode: commonpb.ErrorCode_UnexpectedError,
    //         Reason:    errMsg,
    //     }

    //     return status, errors.New(errMsg)
    // }

    // // remove request channel
    // consumeChannels := []string{in.RequestChannelID}
    // consumeSubName := Params.MsgChannelSubName
    // // TODO: searchStream.RemovePulsarConsumers(producerChannels)
    // searchStream.AsConsumer(consumeChannels, consumeSubName)

    // // remove result channel
    // producerChannels := []string{in.ResultChannelID}
    // // TODO: resultStream.RemovePulsarProducer(producerChannels)
    // resultStream.AsProducer(producerChannels)

    status := &commonpb.Status{
        ErrorCode: commonpb.ErrorCode_Success,
    }
    return status, nil
}

func (node *QueryNode) WatchDmChannels(ctx context.Context, in *queryPb.WatchDmChannelsRequest) (*commonpb.Status, error) {
    log.Debug("starting WatchDmChannels ...", zap.String("ChannelIDs", fmt.Sprintln(in.ChannelIDs)))
    collectionID := in.CollectionID
    ds, err := node.getDataSyncService(collectionID)
    if err != nil || ds.dmStream == nil {
        errMsg := "null data sync service or null data manipulation stream, collectionID = " + fmt.Sprintln(collectionID)
        status := &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_UnexpectedError,
            Reason:    errMsg,
        }
        log.Error(errMsg)
        return status, errors.New(errMsg)
    }

    switch t := ds.dmStream.(type) {
    case *msgstream.MqTtMsgStream:
    default:
        _ = t
        errMsg := "type assertion failed for dm message stream"
        status := &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_UnexpectedError,
            Reason:    errMsg,
        }
        log.Error(errMsg)
        return status, errors.New(errMsg)
    }

    getUniqueSubName := func() string {
        prefixName := Params.MsgChannelSubName
        return prefixName + "-" + strconv.FormatInt(collectionID, 10)
    }

    // add request channel
    consumeChannels := in.ChannelIDs
    toSeekInfo := make([]*internalpb.MsgPosition, 0)
    toDirSubChannels := make([]string, 0)

    consumeSubName := getUniqueSubName()

    for _, info := range in.Infos {
        if len(info.Pos.MsgID) == 0 {
            toDirSubChannels = append(toDirSubChannels, info.ChannelID)
            continue
        }
        info.Pos.MsgGroup = consumeSubName
        toSeekInfo = append(toSeekInfo, info.Pos)

        log.Debug("prevent inserting segments", zap.String("segmentIDs", fmt.Sprintln(info.ExcludedSegments)))
        err := node.replica.addExcludedSegments(collectionID, info.ExcludedSegments)
        if err != nil {
            status := &commonpb.Status{
                ErrorCode: commonpb.ErrorCode_UnexpectedError,
                Reason:    err.Error(),
            }
            log.Error(err.Error())
            return status, err
        }
    }

    ds.dmStream.AsConsumer(toDirSubChannels, consumeSubName)
    for _, pos := range toSeekInfo {
        err := ds.dmStream.Seek(pos)
        if err != nil {
            errMsg := "msgStream seek error :" + err.Error()
            status := &commonpb.Status{
                ErrorCode: commonpb.ErrorCode_UnexpectedError,
                Reason:    errMsg,
            }
            log.Error(errMsg)
            return status, errors.New(errMsg)
        }
    }
    log.Debug("querynode AsConsumer: " + strings.Join(consumeChannels, ", ") + " : " + consumeSubName)

    status := &commonpb.Status{
        ErrorCode: commonpb.ErrorCode_Success,
    }
    log.Debug("WatchDmChannels done", zap.String("ChannelIDs", fmt.Sprintln(in.ChannelIDs)))
    return status, nil
}

func (node *QueryNode) LoadSegments(ctx context.Context, in *queryPb.LoadSegmentsRequest) (*commonpb.Status, error) {
    // TODO: support db
    collectionID := in.CollectionID
    partitionID := in.PartitionID
    segmentIDs := in.SegmentIDs
    fieldIDs := in.FieldIDs
    schema := in.Schema

    log.Debug("query node load segment", zap.String("loadSegmentRequest", fmt.Sprintln(in)))

    status := &commonpb.Status{
        ErrorCode: commonpb.ErrorCode_Success,
    }
    hasCollection := node.replica.hasCollection(collectionID)
    hasPartition := node.replica.hasPartition(partitionID)
    if !hasCollection {
        // loading init
        err := node.replica.addCollection(collectionID, schema)
        if err != nil {
            status.ErrorCode = commonpb.ErrorCode_UnexpectedError
            status.Reason = err.Error()
            return status, err
        }
        node.replica.initExcludedSegments(collectionID)
        newDS := newDataSyncService(node.queryNodeLoopCtx, node.replica, node.msFactory, collectionID)
        // ignore duplicated dataSyncService error
        node.addDataSyncService(collectionID, newDS)
        ds, err := node.getDataSyncService(collectionID)
        if err != nil {
            status.ErrorCode = commonpb.ErrorCode_UnexpectedError
            status.Reason = err.Error()
            return status, err
        }
        go ds.start()
        node.searchService.startSearchCollection(collectionID)
    }
    if !hasPartition {
        err := node.replica.addPartition(collectionID, partitionID)
        if err != nil {
            status.ErrorCode = commonpb.ErrorCode_UnexpectedError
            status.Reason = err.Error()
            return status, err
        }
    }
    err := node.replica.enablePartition(partitionID)
    if err != nil {
        status.ErrorCode = commonpb.ErrorCode_UnexpectedError
        status.Reason = err.Error()
        return status, err
    }

    if len(segmentIDs) == 0 {
        return status, nil
    }

    err = node.loadService.loadSegmentPassively(collectionID, partitionID, segmentIDs, fieldIDs)
    if err != nil {
        status.ErrorCode = commonpb.ErrorCode_UnexpectedError
        status.Reason = err.Error()
        return status, err
    }

    log.Debug("LoadSegments done", zap.String("segmentIDs", fmt.Sprintln(in.SegmentIDs)))
    return status, nil
}

func (node *QueryNode) ReleaseCollection(ctx context.Context, in *queryPb.ReleaseCollectionRequest) (*commonpb.Status, error) {
    ds, err := node.getDataSyncService(in.CollectionID)
    if err == nil && ds != nil {
        ds.close()
        node.removeDataSyncService(in.CollectionID)
        node.replica.removeTSafe(in.CollectionID)
        node.replica.removeExcludedSegments(in.CollectionID)
    }

    if node.searchService.hasSearchCollection(in.CollectionID) {
        node.searchService.stopSearchCollection(in.CollectionID)
    }

    err = node.replica.removeCollection(in.CollectionID)
    if err != nil {
        status := &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_UnexpectedError,
            Reason:    err.Error(),
        }
        return status, err
    }

    log.Debug("ReleaseCollection done", zap.Int64("collectionID", in.CollectionID))
    return &commonpb.Status{
        ErrorCode: commonpb.ErrorCode_Success,
    }, nil
}

func (node *QueryNode) ReleasePartitions(ctx context.Context, in *queryPb.ReleasePartitionsRequest) (*commonpb.Status, error) {
    status := &commonpb.Status{
        ErrorCode: commonpb.ErrorCode_Success,
    }
    for _, id := range in.PartitionIDs {
        err := node.loadService.segLoader.replica.removePartition(id)
        if err != nil {
            // not return, try to release all partitions
            status.ErrorCode = commonpb.ErrorCode_UnexpectedError
            status.Reason = err.Error()
        }
    }
    return status, nil
}

func (node *QueryNode) ReleaseSegments(ctx context.Context, in *queryPb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
    status := &commonpb.Status{
        ErrorCode: commonpb.ErrorCode_Success,
    }
    for _, id := range in.SegmentIDs {
        err2 := node.loadService.segLoader.replica.removeSegment(id)
        if err2 != nil {
            // not return, try to release all segments
            status.ErrorCode = commonpb.ErrorCode_UnexpectedError
            status.Reason = err2.Error()
        }
    }
    return status, nil
}

func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *queryPb.GetSegmentInfoRequest) (*queryPb.GetSegmentInfoResponse, error) {
    infos := make([]*queryPb.SegmentInfo, 0)
    for _, id := range in.SegmentIDs {
        segment, err := node.replica.getSegmentByID(id)
        if err != nil {
            continue
        }
        var indexName string
        var indexID int64
        // TODO:: segment has multi vec column
        if len(segment.indexInfos) > 0 {
            for fieldID := range segment.indexInfos {
                indexName = segment.getIndexName(fieldID)
                indexID = segment.getIndexID(fieldID)
                break
            }
        }
        info := &queryPb.SegmentInfo{
            SegmentID:    segment.ID(),
            CollectionID: segment.collectionID,
            PartitionID:  segment.partitionID,
            MemSize:      segment.getMemSize(),
            NumRows:      segment.getRowCount(),
            IndexName:    indexName,
            IndexID:      indexID,
        }
        infos = append(infos, info)
    }
    return &queryPb.GetSegmentInfoResponse{
        Status: &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_Success,
        },
        Infos: infos,
    }, nil
}
@@ -0,0 +1,222 @@
package querynode

import (
    "context"
    "errors"
    "fmt"
    "math/rand"

    "go.uber.org/zap"

    "github.com/zilliztech/milvus-distributed/internal/log"
    queryPb "github.com/zilliztech/milvus-distributed/internal/proto/querypb"
)

type task interface {
    ID() UniqueID       // return ReqID
    SetID(uid UniqueID) // set ReqID
    Timestamp() Timestamp
    PreExecute(ctx context.Context) error
    Execute(ctx context.Context) error
    PostExecute(ctx context.Context) error
    WaitToFinish() error
    Notify(err error)
    OnEnqueue() error
}

type baseTask struct {
    done chan error
    ctx  context.Context
    id   UniqueID
}

type loadSegmentsTask struct {
    baseTask
    req  *queryPb.LoadSegmentsRequest
    node *QueryNode
}

type releaseCollectionTask struct {
    baseTask
    req  *queryPb.ReleaseCollectionRequest
    node *QueryNode
}

type releasePartitionsTask struct {
    baseTask
    req  *queryPb.ReleasePartitionsRequest
    node *QueryNode
}

func (b *baseTask) ID() UniqueID {
    return b.id
}

func (b *baseTask) SetID(uid UniqueID) {
    b.id = uid
}

func (b *baseTask) WaitToFinish() error {
    select {
    case <-b.ctx.Done():
        return errors.New("task timeout")
    case err := <-b.done:
        return err
    }
}

func (b *baseTask) Notify(err error) {
    b.done <- err
}

// loadSegmentsTask
func (l *loadSegmentsTask) Timestamp() Timestamp {
    return l.req.Base.Timestamp
}

func (l *loadSegmentsTask) OnEnqueue() error {
    if l.req == nil || l.req.Base == nil {
        l.SetID(rand.Int63n(100000000000))
    } else {
        l.SetID(l.req.Base.MsgID)
    }
    return nil
}

func (l *loadSegmentsTask) PreExecute(ctx context.Context) error {
    return nil
}

func (l *loadSegmentsTask) Execute(ctx context.Context) error {
    // TODO: support db
    collectionID := l.req.CollectionID
    partitionID := l.req.PartitionID
    segmentIDs := l.req.SegmentIDs
    fieldIDs := l.req.FieldIDs
    schema := l.req.Schema

    log.Debug("query node load segment", zap.String("loadSegmentRequest", fmt.Sprintln(l.req)))

    hasCollection := l.node.replica.hasCollection(collectionID)
    hasPartition := l.node.replica.hasPartition(partitionID)
    if !hasCollection {
        // loading init
        err := l.node.replica.addCollection(collectionID, schema)
        if err != nil {
            return err
        }
        l.node.replica.initExcludedSegments(collectionID)
        newDS := newDataSyncService(l.node.queryNodeLoopCtx, l.node.replica, l.node.msFactory, collectionID)
        // ignore duplicated dataSyncService error
        _ = l.node.addDataSyncService(collectionID, newDS)
        ds, err := l.node.getDataSyncService(collectionID)
        if err != nil {
            return err
        }
        go ds.start()
        l.node.searchService.startSearchCollection(collectionID)
    }
    if !hasPartition {
        err := l.node.replica.addPartition(collectionID, partitionID)
        if err != nil {
            return err
        }
    }
    err := l.node.replica.enablePartition(partitionID)
    if err != nil {
        return err
    }

    if len(segmentIDs) == 0 {
        return nil
    }

    err = l.node.loadService.loadSegmentPassively(collectionID, partitionID, segmentIDs, fieldIDs)
    if err != nil {
        return err
    }

    log.Debug("LoadSegments done", zap.String("segmentIDs", fmt.Sprintln(l.req.SegmentIDs)))
    return nil
}

func (l *loadSegmentsTask) PostExecute(ctx context.Context) error {
    return nil
}

// releaseCollectionTask
func (r *releaseCollectionTask) Timestamp() Timestamp {
    return r.req.Base.Timestamp
}

func (r *releaseCollectionTask) OnEnqueue() error {
    if r.req == nil || r.req.Base == nil {
        r.SetID(rand.Int63n(100000000000))
    } else {
        r.SetID(r.req.Base.MsgID)
    }
    return nil
}

func (r *releaseCollectionTask) PreExecute(ctx context.Context) error {
    return nil
}

func (r *releaseCollectionTask) Execute(ctx context.Context) error {
    ds, err := r.node.getDataSyncService(r.req.CollectionID)
    if err == nil && ds != nil {
        ds.close()
        r.node.removeDataSyncService(r.req.CollectionID)
        r.node.replica.removeTSafe(r.req.CollectionID)
        r.node.replica.removeExcludedSegments(r.req.CollectionID)
    }

    if r.node.searchService.hasSearchCollection(r.req.CollectionID) {
        r.node.searchService.stopSearchCollection(r.req.CollectionID)
    }

    err = r.node.replica.removeCollection(r.req.CollectionID)
    if err != nil {
        return err
    }

    log.Debug("ReleaseCollection done", zap.Int64("collectionID", r.req.CollectionID))
    return nil
}

func (r *releaseCollectionTask) PostExecute(ctx context.Context) error {
    return nil
}

// releasePartitionsTask
func (r *releasePartitionsTask) Timestamp() Timestamp {
    return r.req.Base.Timestamp
}

func (r *releasePartitionsTask) OnEnqueue() error {
    if r.req == nil || r.req.Base == nil {
        r.SetID(rand.Int63n(100000000000))
    } else {
        r.SetID(r.req.Base.MsgID)
    }
    return nil
}

func (r *releasePartitionsTask) PreExecute(ctx context.Context) error {
    return nil
}

func (r *releasePartitionsTask) Execute(ctx context.Context) error {
    for _, id := range r.req.PartitionIDs {
        err := r.node.loadService.segLoader.replica.removePartition(id)
        if err != nil {
            // not return, try to release all partitions
            log.Error(err.Error())
        }
    }
    return nil
}

func (r *releasePartitionsTask) PostExecute(ctx context.Context) error {
    return nil
}
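The baseTask above is the rendezvous point between a gRPC handler and the scheduler: Notify sends the execution result on the done channel, while WaitToFinish returns either that result or a "task timeout" error once the caller's context ends. A small, self-contained sketch of those semantics follows; the waiter type and its fields are illustrative only.

package main

import (
    "context"
    "errors"
    "fmt"
    "time"
)

// waiter mirrors the done-channel contract of baseTask (illustrative names).
type waiter struct {
    ctx  context.Context
    done chan error
}

// Notify delivers the execution result to whoever is blocked in WaitToFinish.
func (w *waiter) Notify(err error) { w.done <- err }

// WaitToFinish returns the delivered result, or a timeout error if the context ends first.
func (w *waiter) WaitToFinish() error {
    select {
    case <-w.ctx.Done():
        return errors.New("task timeout")
    case err := <-w.done:
        return err
    }
}

func main() {
    // Case 1: the worker notifies first, so the caller sees the worker's result (nil here).
    w1 := &waiter{ctx: context.Background(), done: make(chan error)}
    go func() { w1.Notify(nil) }()
    fmt.Println("finished:", w1.WaitToFinish())

    // Case 2: the context expires before any Notify, so the caller gets "task timeout".
    ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
    defer cancel()
    w2 := &waiter{ctx: ctx, done: make(chan error)}
    fmt.Println("cancelled:", w2.WaitToFinish())
}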
@@ -0,0 +1,153 @@
package querynode

import (
    "container/list"
    "errors"
    "sync"

    "go.uber.org/zap"

    "github.com/zilliztech/milvus-distributed/internal/log"
)

const maxTaskNum = 1024

type taskQueue interface {
    utChan() <-chan int
    utEmpty() bool
    utFull() bool
    addUnissuedTask(t task) error
    PopUnissuedTask() task
    AddActiveTask(t task)
    PopActiveTask(tID UniqueID) task
    Enqueue(t task) error
}

type baseTaskQueue struct {
    utMu          sync.Mutex // guards unissuedTasks
    unissuedTasks *list.List

    atMu        sync.Mutex // guards activeTasks
    activeTasks map[UniqueID]task

    maxTaskNum int64    // maxTaskNum should keep still
    utBufChan  chan int // to block scheduler

    scheduler *taskScheduler
}

type loadAndReleaseTaskQueue struct {
    baseTaskQueue
    mu sync.Mutex
}

// baseTaskQueue
func (queue *baseTaskQueue) utChan() <-chan int {
    return queue.utBufChan
}

func (queue *baseTaskQueue) utEmpty() bool {
    return queue.unissuedTasks.Len() == 0
}

func (queue *baseTaskQueue) utFull() bool {
    return int64(queue.unissuedTasks.Len()) >= queue.maxTaskNum
}

func (queue *baseTaskQueue) addUnissuedTask(t task) error {
    queue.utMu.Lock()
    defer queue.utMu.Unlock()

    if queue.utFull() {
        return errors.New("task queue is full")
    }

    if queue.unissuedTasks.Len() <= 0 {
        queue.unissuedTasks.PushBack(t)
        queue.utBufChan <- 1
        return nil
    }

    if t.Timestamp() >= queue.unissuedTasks.Back().Value.(task).Timestamp() {
        queue.unissuedTasks.PushBack(t)
        queue.utBufChan <- 1
        return nil
    }

    for e := queue.unissuedTasks.Front(); e != nil; e = e.Next() {
        if t.Timestamp() <= e.Value.(task).Timestamp() {
            queue.unissuedTasks.InsertBefore(t, e)
            queue.utBufChan <- 1
            return nil
        }
    }
    return errors.New("unexpected error in addUnissuedTask")
}

func (queue *baseTaskQueue) PopUnissuedTask() task {
    queue.utMu.Lock()
    defer queue.utMu.Unlock()

    if queue.unissuedTasks.Len() <= 0 {
        log.Fatal("unissued task list is empty!")
        return nil
    }

    ft := queue.unissuedTasks.Front()
    queue.unissuedTasks.Remove(ft)

    return ft.Value.(task)
}

func (queue *baseTaskQueue) AddActiveTask(t task) {
    queue.atMu.Lock()
    defer queue.atMu.Unlock()

    tID := t.ID()
    _, ok := queue.activeTasks[tID]
    if ok {
        log.Warn("queryNode", zap.Int64("task with ID already in active task list!", tID))
    }

    queue.activeTasks[tID] = t
}

func (queue *baseTaskQueue) PopActiveTask(tID UniqueID) task {
    queue.atMu.Lock()
    defer queue.atMu.Unlock()

    t, ok := queue.activeTasks[tID]
    if ok {
        delete(queue.activeTasks, tID)
        return t
    }
    log.Debug("queryNode", zap.Int64("cannot found ID in the active task list!", tID))
    return nil
}

func (queue *baseTaskQueue) Enqueue(t task) error {
    err := t.OnEnqueue()
    if err != nil {
        return err
    }
    return queue.addUnissuedTask(t)
}

// loadAndReleaseTaskQueue
func (queue *loadAndReleaseTaskQueue) Enqueue(t task) error {
    queue.mu.Lock()
    defer queue.mu.Unlock()
    return queue.baseTaskQueue.Enqueue(t)
}

func newLoadAndReleaseTaskQueue(scheduler *taskScheduler) *loadAndReleaseTaskQueue {
    return &loadAndReleaseTaskQueue{
        baseTaskQueue: baseTaskQueue{
            unissuedTasks: list.New(),
            activeTasks:   make(map[UniqueID]task),
            maxTaskNum:    maxTaskNum,
            utBufChan:     make(chan int, maxTaskNum),
            scheduler:     scheduler,
        },
    }
}
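addUnissuedTask above keeps the pending list ordered by task timestamp: it appends when the new task is not older than the current tail, and otherwise inserts it before the first element with a larger-or-equal timestamp. A self-contained sketch of that insertion rule using container/list is below, with a plain struct standing in for the task interface.

package main

import (
    "container/list"
    "fmt"
)

type item struct{ ts uint64 }

// insertByTimestamp mirrors the ordering rule of addUnissuedTask:
// push to the back if the new item is newest, otherwise insert before
// the first element whose timestamp is larger or equal.
func insertByTimestamp(l *list.List, it item) {
    if l.Len() == 0 || it.ts >= l.Back().Value.(item).ts {
        l.PushBack(it)
        return
    }
    for e := l.Front(); e != nil; e = e.Next() {
        if it.ts <= e.Value.(item).ts {
            l.InsertBefore(it, e)
            return
        }
    }
}

func main() {
    l := list.New()
    for _, ts := range []uint64{30, 10, 20, 25, 5} {
        insertByTimestamp(l, item{ts: ts})
    }
    for e := l.Front(); e != nil; e = e.Next() {
        fmt.Print(e.Value.(item).ts, " ") // prints: 5 10 20 25 30
    }
    fmt.Println()
}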
@@ -0,0 +1,76 @@
package querynode

import (
    "context"
    "sync"

    "github.com/zilliztech/milvus-distributed/internal/log"
)

type taskScheduler struct {
    ctx    context.Context
    cancel context.CancelFunc

    wg    sync.WaitGroup
    queue taskQueue
}

func newTaskScheduler(ctx context.Context) *taskScheduler {
    ctx1, cancel := context.WithCancel(ctx)
    s := &taskScheduler{
        ctx:    ctx1,
        cancel: cancel,
    }
    s.queue = newLoadAndReleaseTaskQueue(s)
    return s
}

func (s *taskScheduler) processTask(t task, q taskQueue) {
    // TODO: ctx?
    err := t.PreExecute(s.ctx)

    defer func() {
        t.Notify(err)
    }()
    if err != nil {
        log.Error(err.Error())
        return
    }

    q.AddActiveTask(t)
    defer func() {
        q.PopActiveTask(t.ID())
    }()

    err = t.Execute(s.ctx)
    if err != nil {
        log.Error(err.Error())
        return
    }
    err = t.PostExecute(s.ctx)
}

func (s *taskScheduler) loadAndReleaseLoop() {
    defer s.wg.Done()
    for {
        select {
        case <-s.ctx.Done():
            return
        case <-s.queue.utChan():
            if !s.queue.utEmpty() {
                t := s.queue.PopUnissuedTask()
                go s.processTask(t, s.queue)
            }
        }
    }
}

func (s *taskScheduler) Start() {
    s.wg.Add(1)
    go s.loadAndReleaseLoop()
}

func (s *taskScheduler) Close() {
    s.cancel()
    s.wg.Wait()
}
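The scheduler above blocks on a buffered signal channel that receives one token per enqueued task, and its loop exits when the context is cancelled; Close cancels that context and waits on the WaitGroup for the loop goroutine to return. A minimal, self-contained sketch of that loop-and-shutdown shape is below; the miniScheduler type and its members are illustrative, not the repository's real types.

package main

import (
    "context"
    "fmt"
    "sync"
)

// miniScheduler mimics the shape of taskScheduler: a signal channel fed once per
// enqueued job, a loop that drains it, and a Close that cancels the context and
// waits for the loop to finish.
type miniScheduler struct {
    ctx    context.Context
    cancel context.CancelFunc
    wg     sync.WaitGroup
    sig    chan int    // one token per enqueued job, like utBufChan
    jobs   chan func() // stand-in for the ordered unissued-task list
}

func newMiniScheduler(parent context.Context) *miniScheduler {
    ctx, cancel := context.WithCancel(parent)
    return &miniScheduler{
        ctx:    ctx,
        cancel: cancel,
        sig:    make(chan int, 16),
        jobs:   make(chan func(), 16),
    }
}

func (s *miniScheduler) Enqueue(job func()) {
    s.jobs <- job
    s.sig <- 1
}

func (s *miniScheduler) loop() {
    defer s.wg.Done()
    for {
        select {
        case <-s.ctx.Done():
            return
        case <-s.sig:
            job := <-s.jobs
            job() // the real scheduler runs each task in its own goroutine
        }
    }
}

func (s *miniScheduler) Start() { s.wg.Add(1); go s.loop() }
func (s *miniScheduler) Close() { s.cancel(); s.wg.Wait() }

func main() {
    s := newMiniScheduler(context.Background())
    s.Start()
    done := make(chan struct{})
    s.Enqueue(func() { fmt.Println("job executed"); close(done) })
    <-done
    s.Close()
}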
@@ -498,6 +498,10 @@ func (qs *QueryService) LoadPartitions(ctx context.Context, req *querypb.LoadPar
    segment2Node := qs.shuffleSegmentsToQueryNode(toLoadSegmentIDs)
    for nodeID, assignedSegmentIDs := range segment2Node {
        loadSegmentRequest := &querypb.LoadSegmentsRequest{
            // TODO: use unique id allocator to assign reqID
            Base: &commonpb.MsgBase{
                MsgID: rand.Int63n(10000000000),
            },
            CollectionID: collectionID,
            PartitionID:  partitionID,
            SegmentIDs:   assignedSegmentIDs,