Move AddQueryChannel to taskScheduler in query node (#12294)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
pull/12432/head
bigsheeper 2021-11-30 10:29:43 +08:00 committed by GitHub
parent 001752eb91
commit 496f3d0009
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 354 additions and 245 deletions

View File

@ -15,9 +15,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"math/rand"
"strconv"
"strings"
"go.uber.org/zap" "go.uber.org/zap"
@ -28,7 +25,6 @@ import (
"github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/proto/milvuspb"
queryPb "github.com/milvus-io/milvus/internal/proto/querypb" queryPb "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/metricsinfo" "github.com/milvus-io/milvus/internal/util/metricsinfo"
"github.com/milvus-io/milvus/internal/util/mqclient"
"github.com/milvus-io/milvus/internal/util/typeutil" "github.com/milvus-io/milvus/internal/util/typeutil"
) )
@ -96,92 +92,43 @@ func (node *QueryNode) AddQueryChannel(ctx context.Context, in *queryPb.AddQuery
} }
return status, err return status, err
} }
collectionID := in.CollectionID dct := &addQueryChannelTask{
if node.queryService == nil { baseTask: baseTask{
errMsg := "null query service, collectionID = " + fmt.Sprintln(collectionID) ctx: ctx,
status := &commonpb.Status{ done: make(chan error),
ErrorCode: commonpb.ErrorCode_UnexpectedError, },
Reason: errMsg, req: in,
} node: node,
return status, errors.New(errMsg)
} }
if node.queryService.hasQueryCollection(collectionID) { err := node.scheduler.queue.Enqueue(dct)
log.Debug("queryCollection has been existed when addQueryChannel",
zap.Any("collectionID", collectionID),
)
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}
return status, nil
}
// add search collection
err := node.queryService.addQueryCollection(collectionID)
if err != nil { if err != nil {
status := &commonpb.Status{ status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(), Reason: err.Error(),
} }
log.Error(err.Error())
return status, err return status, err
} }
log.Debug("add query collection", zap.Any("collectionID", collectionID)) log.Debug("addQueryChannelTask Enqueue done", zap.Any("collectionID", in.CollectionID))
// add request channel waitFunc := func() (*commonpb.Status, error) {
sc, err := node.queryService.getQueryCollection(in.CollectionID) err = dct.WaitToFinish()
if err != nil { if err != nil {
status := &commonpb.Status{ status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(), Reason: err.Error(),
}
return status, err
}
consumeChannels := []string{in.RequestChannelID}
consumeSubName := Params.MsgChannelSubName + "-" + strconv.FormatInt(collectionID, 10) + "-" + strconv.Itoa(rand.Int())
if Params.skipQueryChannelRecovery {
log.Debug("Skip query channel seek back ", zap.Strings("channels", consumeChannels),
zap.String("seek position", string(in.SeekPosition.MsgID)),
zap.Uint64("ts", in.SeekPosition.Timestamp))
sc.queryMsgStream.AsConsumerWithPosition(consumeChannels, consumeSubName, mqclient.SubscriptionPositionLatest)
} else {
sc.queryMsgStream.AsConsumer(consumeChannels, consumeSubName)
if in.SeekPosition == nil || len(in.SeekPosition.MsgID) == 0 {
// as consumer
log.Debug("querynode AsConsumer: " + strings.Join(consumeChannels, ", ") + " : " + consumeSubName)
} else {
// seek query channel
err = sc.queryMsgStream.Seek([]*internalpb.MsgPosition{in.SeekPosition})
if err != nil {
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
}
return status, err
} }
log.Debug("querynode seek query channel: ", zap.Any("consumeChannels", consumeChannels), log.Error(err.Error())
zap.String("seek position", string(in.SeekPosition.MsgID))) return status, err
} }
log.Debug("addQueryChannelTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}, nil
} }
// add result channel return waitFunc()
producerChannels := []string{in.ResultChannelID}
sc.queryResultMsgStream.AsProducer(producerChannels)
log.Debug("querynode AsProducer: " + strings.Join(producerChannels, ", "))
// init global sealed segments
for _, segment := range in.GlobalSealedSegments {
sc.globalSegmentManager.addGlobalSegmentInfo(segment)
}
// start queryCollection, message stream need to asConsumer before start
sc.start()
log.Debug("start query collection", zap.Any("collectionID", collectionID))
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}
return status, nil
} }
// RemoveQueryChannel remove queryChannel of the collection to stop receiving query message // RemoveQueryChannel remove queryChannel of the collection to stop receiving query message
@ -261,7 +208,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *queryPb.WatchDmC
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(), Reason: err.Error(),
} }
log.Warn(err.Error()) log.Error(err.Error())
return status, err return status, err
} }
log.Debug("watchDmChannelsTask Enqueue done", zap.Any("collectionID", in.CollectionID)) log.Debug("watchDmChannelsTask Enqueue done", zap.Any("collectionID", in.CollectionID))
@ -273,7 +220,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *queryPb.WatchDmC
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(), Reason: err.Error(),
} }
log.Warn(err.Error()) log.Error(err.Error())
return status, err return status, err
} }
log.Debug("watchDmChannelsTask WaitToFinish done", zap.Any("collectionID", in.CollectionID)) log.Debug("watchDmChannelsTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
@ -311,7 +258,7 @@ func (node *QueryNode) WatchDeltaChannels(ctx context.Context, in *queryPb.Watch
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(), Reason: err.Error(),
} }
log.Warn(err.Error()) log.Error(err.Error())
return status, err return status, err
} }
log.Debug("watchDeltaChannelsTask Enqueue done", zap.Any("collectionID", in.CollectionID)) log.Debug("watchDeltaChannelsTask Enqueue done", zap.Any("collectionID", in.CollectionID))
@ -323,7 +270,7 @@ func (node *QueryNode) WatchDeltaChannels(ctx context.Context, in *queryPb.Watch
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(), Reason: err.Error(),
} }
log.Warn(err.Error()) log.Error(err.Error())
return status, err return status, err
} }
log.Debug("watchDeltaChannelsTask WaitToFinish done", zap.Any("collectionID", in.CollectionID)) log.Debug("watchDeltaChannelsTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
@ -361,7 +308,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *queryPb.LoadSegment
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(), Reason: err.Error(),
} }
log.Warn(err.Error()) log.Error(err.Error())
return status, err return status, err
} }
segmentIDs := make([]UniqueID, 0) segmentIDs := make([]UniqueID, 0)
@ -377,7 +324,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *queryPb.LoadSegment
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(), Reason: err.Error(),
} }
log.Warn(err.Error()) log.Error(err.Error())
return status, err return status, err
} }
log.Debug("loadSegmentsTask WaitToFinish done", zap.Int64s("segmentIDs", segmentIDs)) log.Debug("loadSegmentsTask WaitToFinish done", zap.Int64s("segmentIDs", segmentIDs))
@ -415,7 +362,7 @@ func (node *QueryNode) ReleaseCollection(ctx context.Context, in *queryPb.Releas
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(), Reason: err.Error(),
} }
log.Warn(err.Error()) log.Error(err.Error())
return status, err return status, err
} }
log.Debug("releaseCollectionTask Enqueue done", zap.Any("collectionID", in.CollectionID)) log.Debug("releaseCollectionTask Enqueue done", zap.Any("collectionID", in.CollectionID))
@ -423,7 +370,7 @@ func (node *QueryNode) ReleaseCollection(ctx context.Context, in *queryPb.Releas
func() { func() {
err = dct.WaitToFinish() err = dct.WaitToFinish()
if err != nil { if err != nil {
log.Warn(err.Error()) log.Error(err.Error())
return return
} }
log.Debug("releaseCollectionTask WaitToFinish done", zap.Any("collectionID", in.CollectionID)) log.Debug("releaseCollectionTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
@ -461,7 +408,7 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, in *queryPb.Releas
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(), Reason: err.Error(),
} }
log.Warn(err.Error()) log.Error(err.Error())
return status, err return status, err
} }
log.Debug("releasePartitionsTask Enqueue done", zap.Any("collectionID", in.CollectionID)) log.Debug("releasePartitionsTask Enqueue done", zap.Any("collectionID", in.CollectionID))
@ -469,7 +416,7 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, in *queryPb.Releas
func() { func() {
err = dct.WaitToFinish() err = dct.WaitToFinish()
if err != nil { if err != nil {
log.Warn(err.Error()) log.Error(err.Error())
return return
} }
log.Debug("releasePartitionsTask WaitToFinish done", zap.Any("collectionID", in.CollectionID)) log.Debug("releasePartitionsTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))

View File

@ -77,158 +77,28 @@ func TestImpl_GetStatisticsChannel(t *testing.T) {
func TestImpl_AddQueryChannel(t *testing.T) { func TestImpl_AddQueryChannel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
t.Run("test addQueryChannel", func(t *testing.T) { node, err := genSimpleQueryNode(ctx)
node, err := genSimpleQueryNode(ctx) assert.NoError(t, err)
assert.NoError(t, err)
req := &queryPb.AddQueryChannelRequest{ req := &queryPb.AddQueryChannelRequest{
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels), Base: &commonpb.MsgBase{
NodeID: 0, MsgType: commonpb.MsgType_LoadCollection,
CollectionID: defaultCollectionID, MsgID: rand.Int63(),
RequestChannelID: genQueryChannel(), },
ResultChannelID: genQueryResultChannel(), NodeID: 0,
} CollectionID: defaultCollectionID,
RequestChannelID: genQueryChannel(),
ResultChannelID: genQueryResultChannel(),
}
status, err := node.AddQueryChannel(ctx, req) status, err := node.AddQueryChannel(ctx, req)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode) assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
})
t.Run("test addQueryChannel has queryCollection", func(t *testing.T) { node.UpdateStateCode(internalpb.StateCode_Abnormal)
node, err := genSimpleQueryNode(ctx) status, err = node.AddQueryChannel(ctx, req)
assert.NoError(t, err) assert.Error(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
err = node.queryService.addQueryCollection(defaultCollectionID)
assert.NoError(t, err)
req := &queryPb.AddQueryChannelRequest{
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels),
NodeID: 0,
CollectionID: defaultCollectionID,
RequestChannelID: genQueryChannel(),
ResultChannelID: genQueryResultChannel(),
}
status, err := node.AddQueryChannel(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
})
t.Run("test node is abnormal", func(t *testing.T) {
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
node.UpdateStateCode(internalpb.StateCode_Abnormal)
status, err := node.AddQueryChannel(ctx, nil)
assert.Error(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
})
t.Run("test nil query service", func(t *testing.T) {
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
req := &queryPb.AddQueryChannelRequest{
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels),
CollectionID: defaultCollectionID,
}
node.queryService = nil
status, err := node.AddQueryChannel(ctx, req)
assert.Error(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
})
t.Run("test add query collection failed", func(t *testing.T) {
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
err = node.streaming.replica.removeCollection(defaultCollectionID)
assert.NoError(t, err)
err = node.historical.replica.removeCollection(defaultCollectionID)
assert.NoError(t, err)
req := &queryPb.AddQueryChannelRequest{
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels),
NodeID: 0,
CollectionID: defaultCollectionID,
RequestChannelID: genQueryChannel(),
ResultChannelID: genQueryResultChannel(),
}
status, err := node.AddQueryChannel(ctx, req)
assert.Error(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
})
t.Run("test init global sealed segments", func(t *testing.T) {
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
req := &queryPb.AddQueryChannelRequest{
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels),
NodeID: 0,
CollectionID: defaultCollectionID,
RequestChannelID: genQueryChannel(),
ResultChannelID: genQueryResultChannel(),
GlobalSealedSegments: []*queryPb.SegmentInfo{{
SegmentID: defaultSegmentID,
CollectionID: defaultCollectionID,
PartitionID: defaultPartitionID,
}},
}
status, err := node.AddQueryChannel(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
})
t.Run("test not init global sealed segments", func(t *testing.T) {
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
req := &queryPb.AddQueryChannelRequest{
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels),
NodeID: 0,
CollectionID: defaultCollectionID,
RequestChannelID: genQueryChannel(),
ResultChannelID: genQueryResultChannel(),
GlobalSealedSegments: []*queryPb.SegmentInfo{{
SegmentID: defaultSegmentID,
CollectionID: 1000,
PartitionID: defaultPartitionID,
}},
}
status, err := node.AddQueryChannel(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
})
t.Run("test seek error", func(t *testing.T) {
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
position := &internalpb.MsgPosition{
ChannelName: genQueryChannel(),
MsgID: []byte{1, 2, 3},
MsgGroup: defaultSubName,
Timestamp: 0,
}
req := &queryPb.AddQueryChannelRequest{
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels),
NodeID: 0,
CollectionID: defaultCollectionID,
RequestChannelID: genQueryChannel(),
ResultChannelID: genQueryResultChannel(),
SeekPosition: position,
}
status, err := node.AddQueryChannel(ctx, req)
assert.Error(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
})
} }
func TestImpl_RemoveQueryChannel(t *testing.T) { func TestImpl_RemoveQueryChannel(t *testing.T) {

View File

@ -17,6 +17,7 @@ import (
"fmt" "fmt"
"math/rand" "math/rand"
"strconv" "strconv"
"strings"
"time" "time"
"go.uber.org/zap" "go.uber.org/zap"
@ -26,6 +27,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/internalpb"
queryPb "github.com/milvus-io/milvus/internal/proto/querypb" queryPb "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/rootcoord" "github.com/milvus-io/milvus/internal/rootcoord"
"github.com/milvus-io/milvus/internal/util/mqclient"
) )
type task interface { type task interface {
@ -46,6 +48,12 @@ type baseTask struct {
id UniqueID id UniqueID
} }
// addQueryChannelTask wires a query node up to a collection's query
// channel (as consumer) and query result channel (as producer). It is
// created and enqueued on the node's task scheduler by
// QueryNode.AddQueryChannel.
type addQueryChannelTask struct {
	baseTask
	req  *queryPb.AddQueryChannelRequest // channels, seek position and sealed segments to register
	node *QueryNode                      // node whose queryService is mutated in Execute
}
type watchDmChannelsTask struct { type watchDmChannelsTask struct {
baseTask baseTask
req *queryPb.WatchDmChannelsRequest req *queryPb.WatchDmChannelsRequest
@ -93,6 +101,105 @@ func (b *baseTask) Notify(err error) {
b.done <- err b.done <- err
} }
// addQueryChannel

// Timestamp returns the timestamp carried by the request's MsgBase.
// A request without a base yields 0 and logs a warning.
func (r *addQueryChannelTask) Timestamp() Timestamp {
	base := r.req.Base
	if base == nil {
		log.Warn("nil base req in addQueryChannelTask", zap.Any("collectionID", r.req.CollectionID))
		return 0
	}
	return base.Timestamp
}
// OnEnqueue assigns the task ID: the request's MsgID when a base is
// available, otherwise a random placeholder ID so the scheduler can
// still track the task.
func (r *addQueryChannelTask) OnEnqueue() error {
	if r.req != nil && r.req.Base != nil {
		r.SetID(r.req.Base.MsgID)
		return nil
	}
	r.SetID(rand.Int63n(100000000000))
	return nil
}
// PreExecute performs no validation for addQueryChannelTask; all checks
// (nil queryService, existing queryCollection) happen in Execute.
func (r *addQueryChannelTask) PreExecute(ctx context.Context) error {
	return nil
}
// Execute registers the query channel described by the request on the
// query node:
//  1. ensure the node's queryService exists and the collection's
//     queryCollection is created (a pre-existing one is a no-op success),
//  2. subscribe the queryCollection's query message stream to the request
//     channel — at the latest position when channel recovery is skipped,
//     otherwise seeking to the provided position when one is given,
//  3. register the result channel as a producer,
//  4. record the request's global sealed segments, and
//  5. start the queryCollection (streams must be consumers before start).
func (r *addQueryChannelTask) Execute(ctx context.Context) error {
	log.Debug("Execute addQueryChannelTask",
		zap.Any("collectionID", r.req.CollectionID))

	collectionID := r.req.CollectionID
	if r.node.queryService == nil {
		errMsg := "null query service, collectionID = " + fmt.Sprintln(collectionID)
		return errors.New(errMsg)
	}

	if r.node.queryService.hasQueryCollection(collectionID) {
		log.Debug("queryCollection has been existed when addQueryChannel",
			zap.Any("collectionID", collectionID),
		)
		return nil
	}

	// add search collection
	err := r.node.queryService.addQueryCollection(collectionID)
	if err != nil {
		return err
	}
	log.Debug("add query collection", zap.Any("collectionID", collectionID))

	// add request channel
	sc, err := r.node.queryService.getQueryCollection(collectionID)
	if err != nil {
		return err
	}
	consumeChannels := []string{r.req.RequestChannelID}
	// unique subscription name per collection to avoid consumer clashes
	consumeSubName := Params.MsgChannelSubName + "-" + strconv.FormatInt(collectionID, 10) + "-" + strconv.Itoa(rand.Int())

	if Params.skipQueryChannelRecovery {
		// Fix: the original dereferenced r.req.SeekPosition unconditionally
		// in this log call, panicking when no seek position was supplied
		// (the non-skip branch below explicitly handles a nil SeekPosition).
		// Only log the position fields when a position is present.
		if sp := r.req.SeekPosition; sp != nil {
			log.Debug("Skip query channel seek back ", zap.Strings("channels", consumeChannels),
				zap.String("seek position", string(sp.MsgID)),
				zap.Uint64("ts", sp.Timestamp))
		} else {
			log.Debug("Skip query channel seek back ", zap.Strings("channels", consumeChannels))
		}
		sc.queryMsgStream.AsConsumerWithPosition(consumeChannels, consumeSubName, mqclient.SubscriptionPositionLatest)
	} else {
		sc.queryMsgStream.AsConsumer(consumeChannels, consumeSubName)
		if r.req.SeekPosition == nil || len(r.req.SeekPosition.MsgID) == 0 {
			// as consumer
			log.Debug("querynode AsConsumer: " + strings.Join(consumeChannels, ", ") + " : " + consumeSubName)
		} else {
			// seek query channel
			err = sc.queryMsgStream.Seek([]*internalpb.MsgPosition{r.req.SeekPosition})
			if err != nil {
				return err
			}
			log.Debug("querynode seek query channel: ", zap.Any("consumeChannels", consumeChannels),
				zap.String("seek position", string(r.req.SeekPosition.MsgID)))
		}
	}

	// add result channel
	producerChannels := []string{r.req.ResultChannelID}
	sc.queryResultMsgStream.AsProducer(producerChannels)
	log.Debug("querynode AsProducer: " + strings.Join(producerChannels, ", "))

	// init global sealed segments
	for _, segment := range r.req.GlobalSealedSegments {
		sc.globalSegmentManager.addGlobalSegmentInfo(segment)
	}

	// start queryCollection, message stream need to asConsumer before start
	sc.start()
	log.Debug("start query collection", zap.Any("collectionID", collectionID))

	log.Debug("addQueryChannelTask done",
		zap.Any("collectionID", r.req.CollectionID),
	)
	return nil
}
// PostExecute performs no cleanup for addQueryChannelTask; the
// queryCollection started in Execute keeps running after the task ends.
func (r *addQueryChannelTask) PostExecute(ctx context.Context) error {
	return nil
}
// watchDmChannelsTask // watchDmChannelsTask
func (w *watchDmChannelsTask) Timestamp() Timestamp { func (w *watchDmChannelsTask) Timestamp() Timestamp {
if w.req.Base == nil { if w.req.Base == nil {

View File

@ -47,7 +47,7 @@ type baseTaskQueue struct {
scheduler *taskScheduler scheduler *taskScheduler
} }
type loadAndReleaseTaskQueue struct { type queryNodeTaskQueue struct {
baseTaskQueue baseTaskQueue
mu sync.Mutex mu sync.Mutex
} }
@ -148,15 +148,15 @@ func (queue *baseTaskQueue) Enqueue(t task) error {
return queue.addUnissuedTask(t) return queue.addUnissuedTask(t)
} }
// loadAndReleaseTaskQueue // queryNodeTaskQueue
func (queue *loadAndReleaseTaskQueue) Enqueue(t task) error { func (queue *queryNodeTaskQueue) Enqueue(t task) error {
queue.mu.Lock() queue.mu.Lock()
defer queue.mu.Unlock() defer queue.mu.Unlock()
return queue.baseTaskQueue.Enqueue(t) return queue.baseTaskQueue.Enqueue(t)
} }
func newLoadAndReleaseTaskQueue(scheduler *taskScheduler) *loadAndReleaseTaskQueue { func newQueryNodeTaskQueue(scheduler *taskScheduler) *queryNodeTaskQueue {
return &loadAndReleaseTaskQueue{ return &queryNodeTaskQueue{
baseTaskQueue: baseTaskQueue{ baseTaskQueue: baseTaskQueue{
unissuedTasks: list.New(), unissuedTasks: list.New(),
activeTasks: make(map[UniqueID]task), activeTasks: make(map[UniqueID]task),

View File

@ -29,7 +29,7 @@ func TestBaseTaskQueue_addUnissuedTask(t *testing.T) {
} }
t.Run("test full", func(t *testing.T) { t.Run("test full", func(t *testing.T) {
taskQueue := newLoadAndReleaseTaskQueue(s) taskQueue := newQueryNodeTaskQueue(s)
task := &mockTask{} task := &mockTask{}
for i := 0; i < maxTaskNum; i++ { for i := 0; i < maxTaskNum; i++ {
err := taskQueue.addUnissuedTask(task) err := taskQueue.addUnissuedTask(task)
@ -40,7 +40,7 @@ func TestBaseTaskQueue_addUnissuedTask(t *testing.T) {
}) })
t.Run("add task to front", func(t *testing.T) { t.Run("add task to front", func(t *testing.T) {
taskQueue := newLoadAndReleaseTaskQueue(s) taskQueue := newQueryNodeTaskQueue(s)
mt := &mockTask{ mt := &mockTask{
timestamp: 1000, timestamp: 1000,
} }

View File

@ -32,7 +32,7 @@ func newTaskScheduler(ctx context.Context) *taskScheduler {
ctx: ctx1, ctx: ctx1,
cancel: cancel, cancel: cancel,
} }
s.queue = newLoadAndReleaseTaskQueue(s) s.queue = newQueryNodeTaskQueue(s)
return s return s
} }
@ -61,7 +61,7 @@ func (s *taskScheduler) processTask(t task, q taskQueue) {
err = t.PostExecute(s.ctx) err = t.PostExecute(s.ctx)
} }
func (s *taskScheduler) loadAndReleaseLoop() { func (s *taskScheduler) taskLoop() {
defer s.wg.Done() defer s.wg.Done()
for { for {
select { select {
@ -78,7 +78,7 @@ func (s *taskScheduler) loadAndReleaseLoop() {
func (s *taskScheduler) Start() { func (s *taskScheduler) Start() {
s.wg.Add(1) s.wg.Add(1)
go s.loadAndReleaseLoop() go s.taskLoop()
} }
func (s *taskScheduler) Close() { func (s *taskScheduler) Close() {

View File

@ -25,6 +25,191 @@ import (
"github.com/milvus-io/milvus/internal/util/typeutil" "github.com/milvus-io/milvus/internal/util/typeutil"
) )
// TestTask_AddQueryChannel covers addQueryChannelTask: the Timestamp and
// OnEnqueue accessors, plus Execute under its different queryService
// states (fresh node, pre-existing queryCollection, nil service, removed
// collection, sealed-segment registration, and seek behavior).
func TestTask_AddQueryChannel(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// helper producing a fresh, valid request with unique channel names
	genAddQueryChanelRequest := func() *querypb.AddQueryChannelRequest {
		return &querypb.AddQueryChannelRequest{
			Base:             genCommonMsgBase(commonpb.MsgType_LoadCollection),
			NodeID:           0,
			CollectionID:     defaultCollectionID,
			RequestChannelID: genQueryChannel(),
			ResultChannelID:  genQueryResultChannel(),
		}
	}

	t.Run("test timestamp", func(t *testing.T) {
		// Timestamp echoes Base.Timestamp and falls back to 0 on a nil Base.
		task := addQueryChannelTask{
			req: genAddQueryChanelRequest(),
		}
		timestamp := Timestamp(1000)
		task.req.Base.Timestamp = timestamp
		resT := task.Timestamp()
		assert.Equal(t, timestamp, resT)
		task.req.Base = nil
		resT = task.Timestamp()
		assert.Equal(t, Timestamp(0), resT)
	})

	t.Run("test OnEnqueue", func(t *testing.T) {
		// OnEnqueue succeeds both with a Base (uses MsgID) and without
		// one (random ID fallback).
		task := addQueryChannelTask{
			req: genAddQueryChanelRequest(),
		}
		err := task.OnEnqueue()
		assert.NoError(t, err)
		task.req.Base = nil
		err = task.OnEnqueue()
		assert.NoError(t, err)
	})

	t.Run("test execute", func(t *testing.T) {
		// happy path: fresh node, valid request
		node, err := genSimpleQueryNode(ctx)
		assert.NoError(t, err)

		task := addQueryChannelTask{
			req:  genAddQueryChanelRequest(),
			node: node,
		}
		err = task.Execute(ctx)
		assert.NoError(t, err)
	})

	t.Run("test execute has queryCollection", func(t *testing.T) {
		// a pre-registered queryCollection makes Execute a no-op success
		node, err := genSimpleQueryNode(ctx)
		assert.NoError(t, err)

		err = node.queryService.addQueryCollection(defaultCollectionID)
		assert.NoError(t, err)

		task := addQueryChannelTask{
			req:  genAddQueryChanelRequest(),
			node: node,
		}
		err = task.Execute(ctx)
		assert.NoError(t, err)
	})

	t.Run("test execute nil query service", func(t *testing.T) {
		// Execute must error (not panic) when the node has no queryService
		node, err := genSimpleQueryNode(ctx)
		assert.NoError(t, err)

		node.queryService = nil
		task := addQueryChannelTask{
			req:  genAddQueryChanelRequest(),
			node: node,
		}
		err = task.Execute(ctx)
		assert.Error(t, err)
	})

	t.Run("test execute add query collection failed", func(t *testing.T) {
		// removing the collection from both replicas makes
		// addQueryCollection fail, which Execute surfaces
		node, err := genSimpleQueryNode(ctx)
		assert.NoError(t, err)

		err = node.streaming.replica.removeCollection(defaultCollectionID)
		assert.NoError(t, err)

		err = node.historical.replica.removeCollection(defaultCollectionID)
		assert.NoError(t, err)

		task := addQueryChannelTask{
			req:  genAddQueryChanelRequest(),
			node: node,
		}
		err = task.Execute(ctx)
		assert.Error(t, err)
	})

	t.Run("test execute init global sealed segments", func(t *testing.T) {
		// sealed segments belonging to the target collection are registered
		node, err := genSimpleQueryNode(ctx)
		assert.NoError(t, err)

		task := addQueryChannelTask{
			req:  genAddQueryChanelRequest(),
			node: node,
		}
		task.req.GlobalSealedSegments = []*querypb.SegmentInfo{{
			SegmentID:    defaultSegmentID,
			CollectionID: defaultCollectionID,
			PartitionID:  defaultPartitionID,
		}}
		err = task.Execute(ctx)
		assert.NoError(t, err)
	})

	t.Run("test execute not init global sealed segments", func(t *testing.T) {
		// a sealed segment for an unrelated collection (1000) does not
		// make Execute fail
		node, err := genSimpleQueryNode(ctx)
		assert.NoError(t, err)

		task := addQueryChannelTask{
			req:  genAddQueryChanelRequest(),
			node: node,
		}
		task.req.GlobalSealedSegments = []*querypb.SegmentInfo{{
			SegmentID:    defaultSegmentID,
			CollectionID: 1000,
			PartitionID:  defaultPartitionID,
		}}
		err = task.Execute(ctx)
		assert.NoError(t, err)
	})

	t.Run("test execute seek error", func(t *testing.T) {
		// presumably the short/invalid MsgID makes the stream Seek fail,
		// which Execute must surface as an error — TODO confirm which
		// step fails against the mq client implementation
		node, err := genSimpleQueryNode(ctx)
		assert.NoError(t, err)

		position := &internalpb.MsgPosition{
			ChannelName: genQueryChannel(),
			MsgID:       []byte{1, 2, 3},
			MsgGroup:    defaultSubName,
			Timestamp:   0,
		}

		task := addQueryChannelTask{
			req:  genAddQueryChanelRequest(),
			node: node,
		}
		task.req.SeekPosition = position
		err = task.Execute(ctx)
		assert.Error(t, err)
	})

	t.Run("test execute skipQueryChannelRecovery", func(t *testing.T) {
		// with recovery skipped, the consumer subscribes at the latest
		// position and the otherwise-bad seek position is never used
		// NOTE(review): this flips the package-global
		// Params.skipQueryChannelRecovery and never restores it, so
		// subtests are order-dependent and later tests in the package
		// see the changed value — confirm that is intended.
		node, err := genSimpleQueryNode(ctx)
		assert.NoError(t, err)

		position := &internalpb.MsgPosition{
			ChannelName: genQueryChannel(),
			MsgID:       []byte{1, 2, 3},
			MsgGroup:    defaultSubName,
			Timestamp:   0,
		}

		task := addQueryChannelTask{
			req:  genAddQueryChanelRequest(),
			node: node,
		}
		task.req.SeekPosition = position
		Params.skipQueryChannelRecovery = true
		err = task.Execute(ctx)
		assert.NoError(t, err)
	})
}
func TestTask_watchDmChannelsTask(t *testing.T) { func TestTask_watchDmChannelsTask(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()