mirror of https://github.com/milvus-io/milvus.git
Move AddQueryChannel to taskScheduler in query node (#12294)
Signed-off-by: bigsheeper <yihao.dai@zilliz.com>pull/12432/head
parent
001752eb91
commit
496f3d0009
|
@ -15,9 +15,6 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
|
@ -28,7 +25,6 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
||||
queryPb "github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/util/metricsinfo"
|
||||
"github.com/milvus-io/milvus/internal/util/mqclient"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
|
@ -96,92 +92,43 @@ func (node *QueryNode) AddQueryChannel(ctx context.Context, in *queryPb.AddQuery
|
|||
}
|
||||
return status, err
|
||||
}
|
||||
collectionID := in.CollectionID
|
||||
if node.queryService == nil {
|
||||
errMsg := "null query service, collectionID = " + fmt.Sprintln(collectionID)
|
||||
status := &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: errMsg,
|
||||
}
|
||||
return status, errors.New(errMsg)
|
||||
dct := &addQueryChannelTask{
|
||||
baseTask: baseTask{
|
||||
ctx: ctx,
|
||||
done: make(chan error),
|
||||
},
|
||||
req: in,
|
||||
node: node,
|
||||
}
|
||||
|
||||
if node.queryService.hasQueryCollection(collectionID) {
|
||||
log.Debug("queryCollection has been existed when addQueryChannel",
|
||||
zap.Any("collectionID", collectionID),
|
||||
)
|
||||
status := &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
}
|
||||
return status, nil
|
||||
}
|
||||
|
||||
// add search collection
|
||||
err := node.queryService.addQueryCollection(collectionID)
|
||||
err := node.scheduler.queue.Enqueue(dct)
|
||||
if err != nil {
|
||||
status := &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}
|
||||
log.Error(err.Error())
|
||||
return status, err
|
||||
}
|
||||
log.Debug("add query collection", zap.Any("collectionID", collectionID))
|
||||
log.Debug("addQueryChannelTask Enqueue done", zap.Any("collectionID", in.CollectionID))
|
||||
|
||||
// add request channel
|
||||
sc, err := node.queryService.getQueryCollection(in.CollectionID)
|
||||
if err != nil {
|
||||
status := &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}
|
||||
return status, err
|
||||
}
|
||||
consumeChannels := []string{in.RequestChannelID}
|
||||
consumeSubName := Params.MsgChannelSubName + "-" + strconv.FormatInt(collectionID, 10) + "-" + strconv.Itoa(rand.Int())
|
||||
|
||||
if Params.skipQueryChannelRecovery {
|
||||
log.Debug("Skip query channel seek back ", zap.Strings("channels", consumeChannels),
|
||||
zap.String("seek position", string(in.SeekPosition.MsgID)),
|
||||
zap.Uint64("ts", in.SeekPosition.Timestamp))
|
||||
sc.queryMsgStream.AsConsumerWithPosition(consumeChannels, consumeSubName, mqclient.SubscriptionPositionLatest)
|
||||
} else {
|
||||
sc.queryMsgStream.AsConsumer(consumeChannels, consumeSubName)
|
||||
if in.SeekPosition == nil || len(in.SeekPosition.MsgID) == 0 {
|
||||
// as consumer
|
||||
log.Debug("querynode AsConsumer: " + strings.Join(consumeChannels, ", ") + " : " + consumeSubName)
|
||||
} else {
|
||||
// seek query channel
|
||||
err = sc.queryMsgStream.Seek([]*internalpb.MsgPosition{in.SeekPosition})
|
||||
if err != nil {
|
||||
status := &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}
|
||||
return status, err
|
||||
waitFunc := func() (*commonpb.Status, error) {
|
||||
err = dct.WaitToFinish()
|
||||
if err != nil {
|
||||
status := &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}
|
||||
log.Debug("querynode seek query channel: ", zap.Any("consumeChannels", consumeChannels),
|
||||
zap.String("seek position", string(in.SeekPosition.MsgID)))
|
||||
log.Error(err.Error())
|
||||
return status, err
|
||||
}
|
||||
log.Debug("addQueryChannelTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// add result channel
|
||||
producerChannels := []string{in.ResultChannelID}
|
||||
sc.queryResultMsgStream.AsProducer(producerChannels)
|
||||
log.Debug("querynode AsProducer: " + strings.Join(producerChannels, ", "))
|
||||
|
||||
// init global sealed segments
|
||||
for _, segment := range in.GlobalSealedSegments {
|
||||
sc.globalSegmentManager.addGlobalSegmentInfo(segment)
|
||||
}
|
||||
|
||||
// start queryCollection, message stream need to asConsumer before start
|
||||
sc.start()
|
||||
log.Debug("start query collection", zap.Any("collectionID", collectionID))
|
||||
|
||||
status := &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
}
|
||||
return status, nil
|
||||
return waitFunc()
|
||||
}
|
||||
|
||||
// RemoveQueryChannel remove queryChannel of the collection to stop receiving query message
|
||||
|
@ -261,7 +208,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *queryPb.WatchDmC
|
|||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}
|
||||
log.Warn(err.Error())
|
||||
log.Error(err.Error())
|
||||
return status, err
|
||||
}
|
||||
log.Debug("watchDmChannelsTask Enqueue done", zap.Any("collectionID", in.CollectionID))
|
||||
|
@ -273,7 +220,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *queryPb.WatchDmC
|
|||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}
|
||||
log.Warn(err.Error())
|
||||
log.Error(err.Error())
|
||||
return status, err
|
||||
}
|
||||
log.Debug("watchDmChannelsTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
|
||||
|
@ -311,7 +258,7 @@ func (node *QueryNode) WatchDeltaChannels(ctx context.Context, in *queryPb.Watch
|
|||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}
|
||||
log.Warn(err.Error())
|
||||
log.Error(err.Error())
|
||||
return status, err
|
||||
}
|
||||
log.Debug("watchDeltaChannelsTask Enqueue done", zap.Any("collectionID", in.CollectionID))
|
||||
|
@ -323,7 +270,7 @@ func (node *QueryNode) WatchDeltaChannels(ctx context.Context, in *queryPb.Watch
|
|||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}
|
||||
log.Warn(err.Error())
|
||||
log.Error(err.Error())
|
||||
return status, err
|
||||
}
|
||||
log.Debug("watchDeltaChannelsTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
|
||||
|
@ -361,7 +308,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *queryPb.LoadSegment
|
|||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}
|
||||
log.Warn(err.Error())
|
||||
log.Error(err.Error())
|
||||
return status, err
|
||||
}
|
||||
segmentIDs := make([]UniqueID, 0)
|
||||
|
@ -377,7 +324,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *queryPb.LoadSegment
|
|||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}
|
||||
log.Warn(err.Error())
|
||||
log.Error(err.Error())
|
||||
return status, err
|
||||
}
|
||||
log.Debug("loadSegmentsTask WaitToFinish done", zap.Int64s("segmentIDs", segmentIDs))
|
||||
|
@ -415,7 +362,7 @@ func (node *QueryNode) ReleaseCollection(ctx context.Context, in *queryPb.Releas
|
|||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}
|
||||
log.Warn(err.Error())
|
||||
log.Error(err.Error())
|
||||
return status, err
|
||||
}
|
||||
log.Debug("releaseCollectionTask Enqueue done", zap.Any("collectionID", in.CollectionID))
|
||||
|
@ -423,7 +370,7 @@ func (node *QueryNode) ReleaseCollection(ctx context.Context, in *queryPb.Releas
|
|||
func() {
|
||||
err = dct.WaitToFinish()
|
||||
if err != nil {
|
||||
log.Warn(err.Error())
|
||||
log.Error(err.Error())
|
||||
return
|
||||
}
|
||||
log.Debug("releaseCollectionTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
|
||||
|
@ -461,7 +408,7 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, in *queryPb.Releas
|
|||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
}
|
||||
log.Warn(err.Error())
|
||||
log.Error(err.Error())
|
||||
return status, err
|
||||
}
|
||||
log.Debug("releasePartitionsTask Enqueue done", zap.Any("collectionID", in.CollectionID))
|
||||
|
@ -469,7 +416,7 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, in *queryPb.Releas
|
|||
func() {
|
||||
err = dct.WaitToFinish()
|
||||
if err != nil {
|
||||
log.Warn(err.Error())
|
||||
log.Error(err.Error())
|
||||
return
|
||||
}
|
||||
log.Debug("releasePartitionsTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
|
||||
|
|
|
@ -77,158 +77,28 @@ func TestImpl_GetStatisticsChannel(t *testing.T) {
|
|||
func TestImpl_AddQueryChannel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
t.Run("test addQueryChannel", func(t *testing.T) {
|
||||
node, err := genSimpleQueryNode(ctx)
|
||||
assert.NoError(t, err)
|
||||
node, err := genSimpleQueryNode(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
req := &queryPb.AddQueryChannelRequest{
|
||||
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels),
|
||||
NodeID: 0,
|
||||
CollectionID: defaultCollectionID,
|
||||
RequestChannelID: genQueryChannel(),
|
||||
ResultChannelID: genQueryResultChannel(),
|
||||
}
|
||||
req := &queryPb.AddQueryChannelRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_LoadCollection,
|
||||
MsgID: rand.Int63(),
|
||||
},
|
||||
NodeID: 0,
|
||||
CollectionID: defaultCollectionID,
|
||||
RequestChannelID: genQueryChannel(),
|
||||
ResultChannelID: genQueryResultChannel(),
|
||||
}
|
||||
|
||||
status, err := node.AddQueryChannel(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
|
||||
})
|
||||
status, err := node.AddQueryChannel(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
|
||||
|
||||
t.Run("test addQueryChannel has queryCollection", func(t *testing.T) {
|
||||
node, err := genSimpleQueryNode(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = node.queryService.addQueryCollection(defaultCollectionID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
req := &queryPb.AddQueryChannelRequest{
|
||||
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels),
|
||||
NodeID: 0,
|
||||
CollectionID: defaultCollectionID,
|
||||
RequestChannelID: genQueryChannel(),
|
||||
ResultChannelID: genQueryResultChannel(),
|
||||
}
|
||||
|
||||
status, err := node.AddQueryChannel(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
|
||||
})
|
||||
|
||||
t.Run("test node is abnormal", func(t *testing.T) {
|
||||
node, err := genSimpleQueryNode(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
node.UpdateStateCode(internalpb.StateCode_Abnormal)
|
||||
status, err := node.AddQueryChannel(ctx, nil)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
|
||||
})
|
||||
|
||||
t.Run("test nil query service", func(t *testing.T) {
|
||||
node, err := genSimpleQueryNode(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
req := &queryPb.AddQueryChannelRequest{
|
||||
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels),
|
||||
CollectionID: defaultCollectionID,
|
||||
}
|
||||
|
||||
node.queryService = nil
|
||||
status, err := node.AddQueryChannel(ctx, req)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
|
||||
})
|
||||
|
||||
t.Run("test add query collection failed", func(t *testing.T) {
|
||||
node, err := genSimpleQueryNode(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = node.streaming.replica.removeCollection(defaultCollectionID)
|
||||
assert.NoError(t, err)
|
||||
err = node.historical.replica.removeCollection(defaultCollectionID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
req := &queryPb.AddQueryChannelRequest{
|
||||
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels),
|
||||
NodeID: 0,
|
||||
CollectionID: defaultCollectionID,
|
||||
RequestChannelID: genQueryChannel(),
|
||||
ResultChannelID: genQueryResultChannel(),
|
||||
}
|
||||
|
||||
status, err := node.AddQueryChannel(ctx, req)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
|
||||
})
|
||||
|
||||
t.Run("test init global sealed segments", func(t *testing.T) {
|
||||
node, err := genSimpleQueryNode(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
req := &queryPb.AddQueryChannelRequest{
|
||||
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels),
|
||||
NodeID: 0,
|
||||
CollectionID: defaultCollectionID,
|
||||
RequestChannelID: genQueryChannel(),
|
||||
ResultChannelID: genQueryResultChannel(),
|
||||
GlobalSealedSegments: []*queryPb.SegmentInfo{{
|
||||
SegmentID: defaultSegmentID,
|
||||
CollectionID: defaultCollectionID,
|
||||
PartitionID: defaultPartitionID,
|
||||
}},
|
||||
}
|
||||
|
||||
status, err := node.AddQueryChannel(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
|
||||
})
|
||||
|
||||
t.Run("test not init global sealed segments", func(t *testing.T) {
|
||||
node, err := genSimpleQueryNode(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
req := &queryPb.AddQueryChannelRequest{
|
||||
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels),
|
||||
NodeID: 0,
|
||||
CollectionID: defaultCollectionID,
|
||||
RequestChannelID: genQueryChannel(),
|
||||
ResultChannelID: genQueryResultChannel(),
|
||||
GlobalSealedSegments: []*queryPb.SegmentInfo{{
|
||||
SegmentID: defaultSegmentID,
|
||||
CollectionID: 1000,
|
||||
PartitionID: defaultPartitionID,
|
||||
}},
|
||||
}
|
||||
|
||||
status, err := node.AddQueryChannel(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
|
||||
})
|
||||
|
||||
t.Run("test seek error", func(t *testing.T) {
|
||||
node, err := genSimpleQueryNode(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
position := &internalpb.MsgPosition{
|
||||
ChannelName: genQueryChannel(),
|
||||
MsgID: []byte{1, 2, 3},
|
||||
MsgGroup: defaultSubName,
|
||||
Timestamp: 0,
|
||||
}
|
||||
|
||||
req := &queryPb.AddQueryChannelRequest{
|
||||
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels),
|
||||
NodeID: 0,
|
||||
CollectionID: defaultCollectionID,
|
||||
RequestChannelID: genQueryChannel(),
|
||||
ResultChannelID: genQueryResultChannel(),
|
||||
SeekPosition: position,
|
||||
}
|
||||
|
||||
status, err := node.AddQueryChannel(ctx, req)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
|
||||
})
|
||||
node.UpdateStateCode(internalpb.StateCode_Abnormal)
|
||||
status, err = node.AddQueryChannel(ctx, req)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
|
||||
}
|
||||
|
||||
func TestImpl_RemoveQueryChannel(t *testing.T) {
|
||||
|
|
|
@ -17,6 +17,7 @@ import (
|
|||
"fmt"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
@ -26,6 +27,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
queryPb "github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/rootcoord"
|
||||
"github.com/milvus-io/milvus/internal/util/mqclient"
|
||||
)
|
||||
|
||||
type task interface {
|
||||
|
@ -46,6 +48,12 @@ type baseTask struct {
|
|||
id UniqueID
|
||||
}
|
||||
|
||||
type addQueryChannelTask struct {
|
||||
baseTask
|
||||
req *queryPb.AddQueryChannelRequest
|
||||
node *QueryNode
|
||||
}
|
||||
|
||||
type watchDmChannelsTask struct {
|
||||
baseTask
|
||||
req *queryPb.WatchDmChannelsRequest
|
||||
|
@ -93,6 +101,105 @@ func (b *baseTask) Notify(err error) {
|
|||
b.done <- err
|
||||
}
|
||||
|
||||
// addQueryChannel
|
||||
func (r *addQueryChannelTask) Timestamp() Timestamp {
|
||||
if r.req.Base == nil {
|
||||
log.Warn("nil base req in addQueryChannelTask", zap.Any("collectionID", r.req.CollectionID))
|
||||
return 0
|
||||
}
|
||||
return r.req.Base.Timestamp
|
||||
}
|
||||
|
||||
func (r *addQueryChannelTask) OnEnqueue() error {
|
||||
if r.req == nil || r.req.Base == nil {
|
||||
r.SetID(rand.Int63n(100000000000))
|
||||
} else {
|
||||
r.SetID(r.req.Base.MsgID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *addQueryChannelTask) PreExecute(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *addQueryChannelTask) Execute(ctx context.Context) error {
|
||||
log.Debug("Execute addQueryChannelTask",
|
||||
zap.Any("collectionID", r.req.CollectionID))
|
||||
|
||||
collectionID := r.req.CollectionID
|
||||
if r.node.queryService == nil {
|
||||
errMsg := "null query service, collectionID = " + fmt.Sprintln(collectionID)
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
|
||||
if r.node.queryService.hasQueryCollection(collectionID) {
|
||||
log.Debug("queryCollection has been existed when addQueryChannel",
|
||||
zap.Any("collectionID", collectionID),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// add search collection
|
||||
err := r.node.queryService.addQueryCollection(collectionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debug("add query collection", zap.Any("collectionID", collectionID))
|
||||
|
||||
// add request channel
|
||||
sc, err := r.node.queryService.getQueryCollection(collectionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
consumeChannels := []string{r.req.RequestChannelID}
|
||||
consumeSubName := Params.MsgChannelSubName + "-" + strconv.FormatInt(collectionID, 10) + "-" + strconv.Itoa(rand.Int())
|
||||
|
||||
if Params.skipQueryChannelRecovery {
|
||||
log.Debug("Skip query channel seek back ", zap.Strings("channels", consumeChannels),
|
||||
zap.String("seek position", string(r.req.SeekPosition.MsgID)),
|
||||
zap.Uint64("ts", r.req.SeekPosition.Timestamp))
|
||||
sc.queryMsgStream.AsConsumerWithPosition(consumeChannels, consumeSubName, mqclient.SubscriptionPositionLatest)
|
||||
} else {
|
||||
sc.queryMsgStream.AsConsumer(consumeChannels, consumeSubName)
|
||||
if r.req.SeekPosition == nil || len(r.req.SeekPosition.MsgID) == 0 {
|
||||
// as consumer
|
||||
log.Debug("querynode AsConsumer: " + strings.Join(consumeChannels, ", ") + " : " + consumeSubName)
|
||||
} else {
|
||||
// seek query channel
|
||||
err = sc.queryMsgStream.Seek([]*internalpb.MsgPosition{r.req.SeekPosition})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debug("querynode seek query channel: ", zap.Any("consumeChannels", consumeChannels),
|
||||
zap.String("seek position", string(r.req.SeekPosition.MsgID)))
|
||||
}
|
||||
}
|
||||
|
||||
// add result channel
|
||||
producerChannels := []string{r.req.ResultChannelID}
|
||||
sc.queryResultMsgStream.AsProducer(producerChannels)
|
||||
log.Debug("querynode AsProducer: " + strings.Join(producerChannels, ", "))
|
||||
|
||||
// init global sealed segments
|
||||
for _, segment := range r.req.GlobalSealedSegments {
|
||||
sc.globalSegmentManager.addGlobalSegmentInfo(segment)
|
||||
}
|
||||
|
||||
// start queryCollection, message stream need to asConsumer before start
|
||||
sc.start()
|
||||
log.Debug("start query collection", zap.Any("collectionID", collectionID))
|
||||
|
||||
log.Debug("addQueryChannelTask done",
|
||||
zap.Any("collectionID", r.req.CollectionID),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *addQueryChannelTask) PostExecute(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// watchDmChannelsTask
|
||||
func (w *watchDmChannelsTask) Timestamp() Timestamp {
|
||||
if w.req.Base == nil {
|
||||
|
|
|
@ -47,7 +47,7 @@ type baseTaskQueue struct {
|
|||
scheduler *taskScheduler
|
||||
}
|
||||
|
||||
type loadAndReleaseTaskQueue struct {
|
||||
type queryNodeTaskQueue struct {
|
||||
baseTaskQueue
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
@ -148,15 +148,15 @@ func (queue *baseTaskQueue) Enqueue(t task) error {
|
|||
return queue.addUnissuedTask(t)
|
||||
}
|
||||
|
||||
// loadAndReleaseTaskQueue
|
||||
func (queue *loadAndReleaseTaskQueue) Enqueue(t task) error {
|
||||
// queryNodeTaskQueue
|
||||
func (queue *queryNodeTaskQueue) Enqueue(t task) error {
|
||||
queue.mu.Lock()
|
||||
defer queue.mu.Unlock()
|
||||
return queue.baseTaskQueue.Enqueue(t)
|
||||
}
|
||||
|
||||
func newLoadAndReleaseTaskQueue(scheduler *taskScheduler) *loadAndReleaseTaskQueue {
|
||||
return &loadAndReleaseTaskQueue{
|
||||
func newQueryNodeTaskQueue(scheduler *taskScheduler) *queryNodeTaskQueue {
|
||||
return &queryNodeTaskQueue{
|
||||
baseTaskQueue: baseTaskQueue{
|
||||
unissuedTasks: list.New(),
|
||||
activeTasks: make(map[UniqueID]task),
|
||||
|
|
|
@ -29,7 +29,7 @@ func TestBaseTaskQueue_addUnissuedTask(t *testing.T) {
|
|||
}
|
||||
|
||||
t.Run("test full", func(t *testing.T) {
|
||||
taskQueue := newLoadAndReleaseTaskQueue(s)
|
||||
taskQueue := newQueryNodeTaskQueue(s)
|
||||
task := &mockTask{}
|
||||
for i := 0; i < maxTaskNum; i++ {
|
||||
err := taskQueue.addUnissuedTask(task)
|
||||
|
@ -40,7 +40,7 @@ func TestBaseTaskQueue_addUnissuedTask(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("add task to front", func(t *testing.T) {
|
||||
taskQueue := newLoadAndReleaseTaskQueue(s)
|
||||
taskQueue := newQueryNodeTaskQueue(s)
|
||||
mt := &mockTask{
|
||||
timestamp: 1000,
|
||||
}
|
||||
|
|
|
@ -32,7 +32,7 @@ func newTaskScheduler(ctx context.Context) *taskScheduler {
|
|||
ctx: ctx1,
|
||||
cancel: cancel,
|
||||
}
|
||||
s.queue = newLoadAndReleaseTaskQueue(s)
|
||||
s.queue = newQueryNodeTaskQueue(s)
|
||||
return s
|
||||
}
|
||||
|
||||
|
@ -61,7 +61,7 @@ func (s *taskScheduler) processTask(t task, q taskQueue) {
|
|||
err = t.PostExecute(s.ctx)
|
||||
}
|
||||
|
||||
func (s *taskScheduler) loadAndReleaseLoop() {
|
||||
func (s *taskScheduler) taskLoop() {
|
||||
defer s.wg.Done()
|
||||
for {
|
||||
select {
|
||||
|
@ -78,7 +78,7 @@ func (s *taskScheduler) loadAndReleaseLoop() {
|
|||
|
||||
func (s *taskScheduler) Start() {
|
||||
s.wg.Add(1)
|
||||
go s.loadAndReleaseLoop()
|
||||
go s.taskLoop()
|
||||
}
|
||||
|
||||
func (s *taskScheduler) Close() {
|
||||
|
|
|
@ -25,6 +25,191 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
func TestTask_AddQueryChannel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
genAddQueryChanelRequest := func() *querypb.AddQueryChannelRequest {
|
||||
return &querypb.AddQueryChannelRequest{
|
||||
Base: genCommonMsgBase(commonpb.MsgType_LoadCollection),
|
||||
NodeID: 0,
|
||||
CollectionID: defaultCollectionID,
|
||||
RequestChannelID: genQueryChannel(),
|
||||
ResultChannelID: genQueryResultChannel(),
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("test timestamp", func(t *testing.T) {
|
||||
task := addQueryChannelTask{
|
||||
req: genAddQueryChanelRequest(),
|
||||
}
|
||||
timestamp := Timestamp(1000)
|
||||
task.req.Base.Timestamp = timestamp
|
||||
resT := task.Timestamp()
|
||||
assert.Equal(t, timestamp, resT)
|
||||
task.req.Base = nil
|
||||
resT = task.Timestamp()
|
||||
assert.Equal(t, Timestamp(0), resT)
|
||||
})
|
||||
|
||||
t.Run("test OnEnqueue", func(t *testing.T) {
|
||||
task := addQueryChannelTask{
|
||||
req: genAddQueryChanelRequest(),
|
||||
}
|
||||
err := task.OnEnqueue()
|
||||
assert.NoError(t, err)
|
||||
task.req.Base = nil
|
||||
err = task.OnEnqueue()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("test execute", func(t *testing.T) {
|
||||
node, err := genSimpleQueryNode(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
task := addQueryChannelTask{
|
||||
req: genAddQueryChanelRequest(),
|
||||
node: node,
|
||||
}
|
||||
|
||||
err = task.Execute(ctx)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("test execute has queryCollection", func(t *testing.T) {
|
||||
node, err := genSimpleQueryNode(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = node.queryService.addQueryCollection(defaultCollectionID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
task := addQueryChannelTask{
|
||||
req: genAddQueryChanelRequest(),
|
||||
node: node,
|
||||
}
|
||||
|
||||
err = task.Execute(ctx)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("test execute nil query service", func(t *testing.T) {
|
||||
node, err := genSimpleQueryNode(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
node.queryService = nil
|
||||
|
||||
task := addQueryChannelTask{
|
||||
req: genAddQueryChanelRequest(),
|
||||
node: node,
|
||||
}
|
||||
|
||||
err = task.Execute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("test execute add query collection failed", func(t *testing.T) {
|
||||
node, err := genSimpleQueryNode(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = node.streaming.replica.removeCollection(defaultCollectionID)
|
||||
assert.NoError(t, err)
|
||||
err = node.historical.replica.removeCollection(defaultCollectionID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
task := addQueryChannelTask{
|
||||
req: genAddQueryChanelRequest(),
|
||||
node: node,
|
||||
}
|
||||
|
||||
err = task.Execute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("test execute init global sealed segments", func(t *testing.T) {
|
||||
node, err := genSimpleQueryNode(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
task := addQueryChannelTask{
|
||||
req: genAddQueryChanelRequest(),
|
||||
node: node,
|
||||
}
|
||||
|
||||
task.req.GlobalSealedSegments = []*querypb.SegmentInfo{{
|
||||
SegmentID: defaultSegmentID,
|
||||
CollectionID: defaultCollectionID,
|
||||
PartitionID: defaultPartitionID,
|
||||
}}
|
||||
|
||||
err = task.Execute(ctx)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("test execute not init global sealed segments", func(t *testing.T) {
|
||||
node, err := genSimpleQueryNode(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
task := addQueryChannelTask{
|
||||
req: genAddQueryChanelRequest(),
|
||||
node: node,
|
||||
}
|
||||
|
||||
task.req.GlobalSealedSegments = []*querypb.SegmentInfo{{
|
||||
SegmentID: defaultSegmentID,
|
||||
CollectionID: 1000,
|
||||
PartitionID: defaultPartitionID,
|
||||
}}
|
||||
|
||||
err = task.Execute(ctx)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("test execute seek error", func(t *testing.T) {
|
||||
node, err := genSimpleQueryNode(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
position := &internalpb.MsgPosition{
|
||||
ChannelName: genQueryChannel(),
|
||||
MsgID: []byte{1, 2, 3},
|
||||
MsgGroup: defaultSubName,
|
||||
Timestamp: 0,
|
||||
}
|
||||
|
||||
task := addQueryChannelTask{
|
||||
req: genAddQueryChanelRequest(),
|
||||
node: node,
|
||||
}
|
||||
|
||||
task.req.SeekPosition = position
|
||||
|
||||
err = task.Execute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("test execute skipQueryChannelRecovery", func(t *testing.T) {
|
||||
node, err := genSimpleQueryNode(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
position := &internalpb.MsgPosition{
|
||||
ChannelName: genQueryChannel(),
|
||||
MsgID: []byte{1, 2, 3},
|
||||
MsgGroup: defaultSubName,
|
||||
Timestamp: 0,
|
||||
}
|
||||
|
||||
task := addQueryChannelTask{
|
||||
req: genAddQueryChanelRequest(),
|
||||
node: node,
|
||||
}
|
||||
|
||||
task.req.SeekPosition = position
|
||||
|
||||
Params.skipQueryChannelRecovery = true
|
||||
|
||||
err = task.Execute(ctx)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTask_watchDmChannelsTask(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
|
Loading…
Reference in New Issue