mirror of https://github.com/milvus-io/milvus.git
Move AddQueryChannel to taskScheduler in query node (#12294)
Signed-off-by: bigsheeper <yihao.dai@zilliz.com>pull/12432/head
parent
001752eb91
commit
496f3d0009
|
@ -15,9 +15,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
|
||||||
|
@ -28,7 +25,6 @@ import (
|
||||||
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
||||||
queryPb "github.com/milvus-io/milvus/internal/proto/querypb"
|
queryPb "github.com/milvus-io/milvus/internal/proto/querypb"
|
||||||
"github.com/milvus-io/milvus/internal/util/metricsinfo"
|
"github.com/milvus-io/milvus/internal/util/metricsinfo"
|
||||||
"github.com/milvus-io/milvus/internal/util/mqclient"
|
|
||||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -96,92 +92,43 @@ func (node *QueryNode) AddQueryChannel(ctx context.Context, in *queryPb.AddQuery
|
||||||
}
|
}
|
||||||
return status, err
|
return status, err
|
||||||
}
|
}
|
||||||
collectionID := in.CollectionID
|
dct := &addQueryChannelTask{
|
||||||
if node.queryService == nil {
|
baseTask: baseTask{
|
||||||
errMsg := "null query service, collectionID = " + fmt.Sprintln(collectionID)
|
ctx: ctx,
|
||||||
status := &commonpb.Status{
|
done: make(chan error),
|
||||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
},
|
||||||
Reason: errMsg,
|
req: in,
|
||||||
}
|
node: node,
|
||||||
return status, errors.New(errMsg)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if node.queryService.hasQueryCollection(collectionID) {
|
err := node.scheduler.queue.Enqueue(dct)
|
||||||
log.Debug("queryCollection has been existed when addQueryChannel",
|
|
||||||
zap.Any("collectionID", collectionID),
|
|
||||||
)
|
|
||||||
status := &commonpb.Status{
|
|
||||||
ErrorCode: commonpb.ErrorCode_Success,
|
|
||||||
}
|
|
||||||
return status, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// add search collection
|
|
||||||
err := node.queryService.addQueryCollection(collectionID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
status := &commonpb.Status{
|
status := &commonpb.Status{
|
||||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
Reason: err.Error(),
|
Reason: err.Error(),
|
||||||
}
|
}
|
||||||
|
log.Error(err.Error())
|
||||||
return status, err
|
return status, err
|
||||||
}
|
}
|
||||||
log.Debug("add query collection", zap.Any("collectionID", collectionID))
|
log.Debug("addQueryChannelTask Enqueue done", zap.Any("collectionID", in.CollectionID))
|
||||||
|
|
||||||
// add request channel
|
waitFunc := func() (*commonpb.Status, error) {
|
||||||
sc, err := node.queryService.getQueryCollection(in.CollectionID)
|
err = dct.WaitToFinish()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
status := &commonpb.Status{
|
status := &commonpb.Status{
|
||||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
Reason: err.Error(),
|
Reason: err.Error(),
|
||||||
}
|
|
||||||
return status, err
|
|
||||||
}
|
|
||||||
consumeChannels := []string{in.RequestChannelID}
|
|
||||||
consumeSubName := Params.MsgChannelSubName + "-" + strconv.FormatInt(collectionID, 10) + "-" + strconv.Itoa(rand.Int())
|
|
||||||
|
|
||||||
if Params.skipQueryChannelRecovery {
|
|
||||||
log.Debug("Skip query channel seek back ", zap.Strings("channels", consumeChannels),
|
|
||||||
zap.String("seek position", string(in.SeekPosition.MsgID)),
|
|
||||||
zap.Uint64("ts", in.SeekPosition.Timestamp))
|
|
||||||
sc.queryMsgStream.AsConsumerWithPosition(consumeChannels, consumeSubName, mqclient.SubscriptionPositionLatest)
|
|
||||||
} else {
|
|
||||||
sc.queryMsgStream.AsConsumer(consumeChannels, consumeSubName)
|
|
||||||
if in.SeekPosition == nil || len(in.SeekPosition.MsgID) == 0 {
|
|
||||||
// as consumer
|
|
||||||
log.Debug("querynode AsConsumer: " + strings.Join(consumeChannels, ", ") + " : " + consumeSubName)
|
|
||||||
} else {
|
|
||||||
// seek query channel
|
|
||||||
err = sc.queryMsgStream.Seek([]*internalpb.MsgPosition{in.SeekPosition})
|
|
||||||
if err != nil {
|
|
||||||
status := &commonpb.Status{
|
|
||||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
|
||||||
Reason: err.Error(),
|
|
||||||
}
|
|
||||||
return status, err
|
|
||||||
}
|
}
|
||||||
log.Debug("querynode seek query channel: ", zap.Any("consumeChannels", consumeChannels),
|
log.Error(err.Error())
|
||||||
zap.String("seek position", string(in.SeekPosition.MsgID)))
|
return status, err
|
||||||
}
|
}
|
||||||
|
log.Debug("addQueryChannelTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
|
||||||
|
return &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_Success,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// add result channel
|
return waitFunc()
|
||||||
producerChannels := []string{in.ResultChannelID}
|
|
||||||
sc.queryResultMsgStream.AsProducer(producerChannels)
|
|
||||||
log.Debug("querynode AsProducer: " + strings.Join(producerChannels, ", "))
|
|
||||||
|
|
||||||
// init global sealed segments
|
|
||||||
for _, segment := range in.GlobalSealedSegments {
|
|
||||||
sc.globalSegmentManager.addGlobalSegmentInfo(segment)
|
|
||||||
}
|
|
||||||
|
|
||||||
// start queryCollection, message stream need to asConsumer before start
|
|
||||||
sc.start()
|
|
||||||
log.Debug("start query collection", zap.Any("collectionID", collectionID))
|
|
||||||
|
|
||||||
status := &commonpb.Status{
|
|
||||||
ErrorCode: commonpb.ErrorCode_Success,
|
|
||||||
}
|
|
||||||
return status, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveQueryChannel remove queryChannel of the collection to stop receiving query message
|
// RemoveQueryChannel remove queryChannel of the collection to stop receiving query message
|
||||||
|
@ -261,7 +208,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *queryPb.WatchDmC
|
||||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
Reason: err.Error(),
|
Reason: err.Error(),
|
||||||
}
|
}
|
||||||
log.Warn(err.Error())
|
log.Error(err.Error())
|
||||||
return status, err
|
return status, err
|
||||||
}
|
}
|
||||||
log.Debug("watchDmChannelsTask Enqueue done", zap.Any("collectionID", in.CollectionID))
|
log.Debug("watchDmChannelsTask Enqueue done", zap.Any("collectionID", in.CollectionID))
|
||||||
|
@ -273,7 +220,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *queryPb.WatchDmC
|
||||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
Reason: err.Error(),
|
Reason: err.Error(),
|
||||||
}
|
}
|
||||||
log.Warn(err.Error())
|
log.Error(err.Error())
|
||||||
return status, err
|
return status, err
|
||||||
}
|
}
|
||||||
log.Debug("watchDmChannelsTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
|
log.Debug("watchDmChannelsTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
|
||||||
|
@ -311,7 +258,7 @@ func (node *QueryNode) WatchDeltaChannels(ctx context.Context, in *queryPb.Watch
|
||||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
Reason: err.Error(),
|
Reason: err.Error(),
|
||||||
}
|
}
|
||||||
log.Warn(err.Error())
|
log.Error(err.Error())
|
||||||
return status, err
|
return status, err
|
||||||
}
|
}
|
||||||
log.Debug("watchDeltaChannelsTask Enqueue done", zap.Any("collectionID", in.CollectionID))
|
log.Debug("watchDeltaChannelsTask Enqueue done", zap.Any("collectionID", in.CollectionID))
|
||||||
|
@ -323,7 +270,7 @@ func (node *QueryNode) WatchDeltaChannels(ctx context.Context, in *queryPb.Watch
|
||||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
Reason: err.Error(),
|
Reason: err.Error(),
|
||||||
}
|
}
|
||||||
log.Warn(err.Error())
|
log.Error(err.Error())
|
||||||
return status, err
|
return status, err
|
||||||
}
|
}
|
||||||
log.Debug("watchDeltaChannelsTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
|
log.Debug("watchDeltaChannelsTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
|
||||||
|
@ -361,7 +308,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *queryPb.LoadSegment
|
||||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
Reason: err.Error(),
|
Reason: err.Error(),
|
||||||
}
|
}
|
||||||
log.Warn(err.Error())
|
log.Error(err.Error())
|
||||||
return status, err
|
return status, err
|
||||||
}
|
}
|
||||||
segmentIDs := make([]UniqueID, 0)
|
segmentIDs := make([]UniqueID, 0)
|
||||||
|
@ -377,7 +324,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *queryPb.LoadSegment
|
||||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
Reason: err.Error(),
|
Reason: err.Error(),
|
||||||
}
|
}
|
||||||
log.Warn(err.Error())
|
log.Error(err.Error())
|
||||||
return status, err
|
return status, err
|
||||||
}
|
}
|
||||||
log.Debug("loadSegmentsTask WaitToFinish done", zap.Int64s("segmentIDs", segmentIDs))
|
log.Debug("loadSegmentsTask WaitToFinish done", zap.Int64s("segmentIDs", segmentIDs))
|
||||||
|
@ -415,7 +362,7 @@ func (node *QueryNode) ReleaseCollection(ctx context.Context, in *queryPb.Releas
|
||||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
Reason: err.Error(),
|
Reason: err.Error(),
|
||||||
}
|
}
|
||||||
log.Warn(err.Error())
|
log.Error(err.Error())
|
||||||
return status, err
|
return status, err
|
||||||
}
|
}
|
||||||
log.Debug("releaseCollectionTask Enqueue done", zap.Any("collectionID", in.CollectionID))
|
log.Debug("releaseCollectionTask Enqueue done", zap.Any("collectionID", in.CollectionID))
|
||||||
|
@ -423,7 +370,7 @@ func (node *QueryNode) ReleaseCollection(ctx context.Context, in *queryPb.Releas
|
||||||
func() {
|
func() {
|
||||||
err = dct.WaitToFinish()
|
err = dct.WaitToFinish()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn(err.Error())
|
log.Error(err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Debug("releaseCollectionTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
|
log.Debug("releaseCollectionTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
|
||||||
|
@ -461,7 +408,7 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, in *queryPb.Releas
|
||||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
Reason: err.Error(),
|
Reason: err.Error(),
|
||||||
}
|
}
|
||||||
log.Warn(err.Error())
|
log.Error(err.Error())
|
||||||
return status, err
|
return status, err
|
||||||
}
|
}
|
||||||
log.Debug("releasePartitionsTask Enqueue done", zap.Any("collectionID", in.CollectionID))
|
log.Debug("releasePartitionsTask Enqueue done", zap.Any("collectionID", in.CollectionID))
|
||||||
|
@ -469,7 +416,7 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, in *queryPb.Releas
|
||||||
func() {
|
func() {
|
||||||
err = dct.WaitToFinish()
|
err = dct.WaitToFinish()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn(err.Error())
|
log.Error(err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Debug("releasePartitionsTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
|
log.Debug("releasePartitionsTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
|
||||||
|
|
|
@ -77,158 +77,28 @@ func TestImpl_GetStatisticsChannel(t *testing.T) {
|
||||||
func TestImpl_AddQueryChannel(t *testing.T) {
|
func TestImpl_AddQueryChannel(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
t.Run("test addQueryChannel", func(t *testing.T) {
|
node, err := genSimpleQueryNode(ctx)
|
||||||
node, err := genSimpleQueryNode(ctx)
|
assert.NoError(t, err)
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
req := &queryPb.AddQueryChannelRequest{
|
req := &queryPb.AddQueryChannelRequest{
|
||||||
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels),
|
Base: &commonpb.MsgBase{
|
||||||
NodeID: 0,
|
MsgType: commonpb.MsgType_LoadCollection,
|
||||||
CollectionID: defaultCollectionID,
|
MsgID: rand.Int63(),
|
||||||
RequestChannelID: genQueryChannel(),
|
},
|
||||||
ResultChannelID: genQueryResultChannel(),
|
NodeID: 0,
|
||||||
}
|
CollectionID: defaultCollectionID,
|
||||||
|
RequestChannelID: genQueryChannel(),
|
||||||
|
ResultChannelID: genQueryResultChannel(),
|
||||||
|
}
|
||||||
|
|
||||||
status, err := node.AddQueryChannel(ctx, req)
|
status, err := node.AddQueryChannel(ctx, req)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
|
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("test addQueryChannel has queryCollection", func(t *testing.T) {
|
node.UpdateStateCode(internalpb.StateCode_Abnormal)
|
||||||
node, err := genSimpleQueryNode(ctx)
|
status, err = node.AddQueryChannel(ctx, req)
|
||||||
assert.NoError(t, err)
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
|
||||||
err = node.queryService.addQueryCollection(defaultCollectionID)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
req := &queryPb.AddQueryChannelRequest{
|
|
||||||
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels),
|
|
||||||
NodeID: 0,
|
|
||||||
CollectionID: defaultCollectionID,
|
|
||||||
RequestChannelID: genQueryChannel(),
|
|
||||||
ResultChannelID: genQueryResultChannel(),
|
|
||||||
}
|
|
||||||
|
|
||||||
status, err := node.AddQueryChannel(ctx, req)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("test node is abnormal", func(t *testing.T) {
|
|
||||||
node, err := genSimpleQueryNode(ctx)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
node.UpdateStateCode(internalpb.StateCode_Abnormal)
|
|
||||||
status, err := node.AddQueryChannel(ctx, nil)
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("test nil query service", func(t *testing.T) {
|
|
||||||
node, err := genSimpleQueryNode(ctx)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
req := &queryPb.AddQueryChannelRequest{
|
|
||||||
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels),
|
|
||||||
CollectionID: defaultCollectionID,
|
|
||||||
}
|
|
||||||
|
|
||||||
node.queryService = nil
|
|
||||||
status, err := node.AddQueryChannel(ctx, req)
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("test add query collection failed", func(t *testing.T) {
|
|
||||||
node, err := genSimpleQueryNode(ctx)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
err = node.streaming.replica.removeCollection(defaultCollectionID)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
err = node.historical.replica.removeCollection(defaultCollectionID)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
req := &queryPb.AddQueryChannelRequest{
|
|
||||||
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels),
|
|
||||||
NodeID: 0,
|
|
||||||
CollectionID: defaultCollectionID,
|
|
||||||
RequestChannelID: genQueryChannel(),
|
|
||||||
ResultChannelID: genQueryResultChannel(),
|
|
||||||
}
|
|
||||||
|
|
||||||
status, err := node.AddQueryChannel(ctx, req)
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("test init global sealed segments", func(t *testing.T) {
|
|
||||||
node, err := genSimpleQueryNode(ctx)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
req := &queryPb.AddQueryChannelRequest{
|
|
||||||
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels),
|
|
||||||
NodeID: 0,
|
|
||||||
CollectionID: defaultCollectionID,
|
|
||||||
RequestChannelID: genQueryChannel(),
|
|
||||||
ResultChannelID: genQueryResultChannel(),
|
|
||||||
GlobalSealedSegments: []*queryPb.SegmentInfo{{
|
|
||||||
SegmentID: defaultSegmentID,
|
|
||||||
CollectionID: defaultCollectionID,
|
|
||||||
PartitionID: defaultPartitionID,
|
|
||||||
}},
|
|
||||||
}
|
|
||||||
|
|
||||||
status, err := node.AddQueryChannel(ctx, req)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("test not init global sealed segments", func(t *testing.T) {
|
|
||||||
node, err := genSimpleQueryNode(ctx)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
req := &queryPb.AddQueryChannelRequest{
|
|
||||||
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels),
|
|
||||||
NodeID: 0,
|
|
||||||
CollectionID: defaultCollectionID,
|
|
||||||
RequestChannelID: genQueryChannel(),
|
|
||||||
ResultChannelID: genQueryResultChannel(),
|
|
||||||
GlobalSealedSegments: []*queryPb.SegmentInfo{{
|
|
||||||
SegmentID: defaultSegmentID,
|
|
||||||
CollectionID: 1000,
|
|
||||||
PartitionID: defaultPartitionID,
|
|
||||||
}},
|
|
||||||
}
|
|
||||||
|
|
||||||
status, err := node.AddQueryChannel(ctx, req)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("test seek error", func(t *testing.T) {
|
|
||||||
node, err := genSimpleQueryNode(ctx)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
position := &internalpb.MsgPosition{
|
|
||||||
ChannelName: genQueryChannel(),
|
|
||||||
MsgID: []byte{1, 2, 3},
|
|
||||||
MsgGroup: defaultSubName,
|
|
||||||
Timestamp: 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
req := &queryPb.AddQueryChannelRequest{
|
|
||||||
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels),
|
|
||||||
NodeID: 0,
|
|
||||||
CollectionID: defaultCollectionID,
|
|
||||||
RequestChannelID: genQueryChannel(),
|
|
||||||
ResultChannelID: genQueryResultChannel(),
|
|
||||||
SeekPosition: position,
|
|
||||||
}
|
|
||||||
|
|
||||||
status, err := node.AddQueryChannel(ctx, req)
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestImpl_RemoveQueryChannel(t *testing.T) {
|
func TestImpl_RemoveQueryChannel(t *testing.T) {
|
||||||
|
|
|
@ -17,6 +17,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
@ -26,6 +27,7 @@ import (
|
||||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||||
queryPb "github.com/milvus-io/milvus/internal/proto/querypb"
|
queryPb "github.com/milvus-io/milvus/internal/proto/querypb"
|
||||||
"github.com/milvus-io/milvus/internal/rootcoord"
|
"github.com/milvus-io/milvus/internal/rootcoord"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/mqclient"
|
||||||
)
|
)
|
||||||
|
|
||||||
type task interface {
|
type task interface {
|
||||||
|
@ -46,6 +48,12 @@ type baseTask struct {
|
||||||
id UniqueID
|
id UniqueID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type addQueryChannelTask struct {
|
||||||
|
baseTask
|
||||||
|
req *queryPb.AddQueryChannelRequest
|
||||||
|
node *QueryNode
|
||||||
|
}
|
||||||
|
|
||||||
type watchDmChannelsTask struct {
|
type watchDmChannelsTask struct {
|
||||||
baseTask
|
baseTask
|
||||||
req *queryPb.WatchDmChannelsRequest
|
req *queryPb.WatchDmChannelsRequest
|
||||||
|
@ -93,6 +101,105 @@ func (b *baseTask) Notify(err error) {
|
||||||
b.done <- err
|
b.done <- err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// addQueryChannel
|
||||||
|
func (r *addQueryChannelTask) Timestamp() Timestamp {
|
||||||
|
if r.req.Base == nil {
|
||||||
|
log.Warn("nil base req in addQueryChannelTask", zap.Any("collectionID", r.req.CollectionID))
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return r.req.Base.Timestamp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *addQueryChannelTask) OnEnqueue() error {
|
||||||
|
if r.req == nil || r.req.Base == nil {
|
||||||
|
r.SetID(rand.Int63n(100000000000))
|
||||||
|
} else {
|
||||||
|
r.SetID(r.req.Base.MsgID)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *addQueryChannelTask) PreExecute(ctx context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *addQueryChannelTask) Execute(ctx context.Context) error {
|
||||||
|
log.Debug("Execute addQueryChannelTask",
|
||||||
|
zap.Any("collectionID", r.req.CollectionID))
|
||||||
|
|
||||||
|
collectionID := r.req.CollectionID
|
||||||
|
if r.node.queryService == nil {
|
||||||
|
errMsg := "null query service, collectionID = " + fmt.Sprintln(collectionID)
|
||||||
|
return errors.New(errMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.node.queryService.hasQueryCollection(collectionID) {
|
||||||
|
log.Debug("queryCollection has been existed when addQueryChannel",
|
||||||
|
zap.Any("collectionID", collectionID),
|
||||||
|
)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// add search collection
|
||||||
|
err := r.node.queryService.addQueryCollection(collectionID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Debug("add query collection", zap.Any("collectionID", collectionID))
|
||||||
|
|
||||||
|
// add request channel
|
||||||
|
sc, err := r.node.queryService.getQueryCollection(collectionID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
consumeChannels := []string{r.req.RequestChannelID}
|
||||||
|
consumeSubName := Params.MsgChannelSubName + "-" + strconv.FormatInt(collectionID, 10) + "-" + strconv.Itoa(rand.Int())
|
||||||
|
|
||||||
|
if Params.skipQueryChannelRecovery {
|
||||||
|
log.Debug("Skip query channel seek back ", zap.Strings("channels", consumeChannels),
|
||||||
|
zap.String("seek position", string(r.req.SeekPosition.MsgID)),
|
||||||
|
zap.Uint64("ts", r.req.SeekPosition.Timestamp))
|
||||||
|
sc.queryMsgStream.AsConsumerWithPosition(consumeChannels, consumeSubName, mqclient.SubscriptionPositionLatest)
|
||||||
|
} else {
|
||||||
|
sc.queryMsgStream.AsConsumer(consumeChannels, consumeSubName)
|
||||||
|
if r.req.SeekPosition == nil || len(r.req.SeekPosition.MsgID) == 0 {
|
||||||
|
// as consumer
|
||||||
|
log.Debug("querynode AsConsumer: " + strings.Join(consumeChannels, ", ") + " : " + consumeSubName)
|
||||||
|
} else {
|
||||||
|
// seek query channel
|
||||||
|
err = sc.queryMsgStream.Seek([]*internalpb.MsgPosition{r.req.SeekPosition})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Debug("querynode seek query channel: ", zap.Any("consumeChannels", consumeChannels),
|
||||||
|
zap.String("seek position", string(r.req.SeekPosition.MsgID)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// add result channel
|
||||||
|
producerChannels := []string{r.req.ResultChannelID}
|
||||||
|
sc.queryResultMsgStream.AsProducer(producerChannels)
|
||||||
|
log.Debug("querynode AsProducer: " + strings.Join(producerChannels, ", "))
|
||||||
|
|
||||||
|
// init global sealed segments
|
||||||
|
for _, segment := range r.req.GlobalSealedSegments {
|
||||||
|
sc.globalSegmentManager.addGlobalSegmentInfo(segment)
|
||||||
|
}
|
||||||
|
|
||||||
|
// start queryCollection, message stream need to asConsumer before start
|
||||||
|
sc.start()
|
||||||
|
log.Debug("start query collection", zap.Any("collectionID", collectionID))
|
||||||
|
|
||||||
|
log.Debug("addQueryChannelTask done",
|
||||||
|
zap.Any("collectionID", r.req.CollectionID),
|
||||||
|
)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *addQueryChannelTask) PostExecute(ctx context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// watchDmChannelsTask
|
// watchDmChannelsTask
|
||||||
func (w *watchDmChannelsTask) Timestamp() Timestamp {
|
func (w *watchDmChannelsTask) Timestamp() Timestamp {
|
||||||
if w.req.Base == nil {
|
if w.req.Base == nil {
|
||||||
|
|
|
@ -47,7 +47,7 @@ type baseTaskQueue struct {
|
||||||
scheduler *taskScheduler
|
scheduler *taskScheduler
|
||||||
}
|
}
|
||||||
|
|
||||||
type loadAndReleaseTaskQueue struct {
|
type queryNodeTaskQueue struct {
|
||||||
baseTaskQueue
|
baseTaskQueue
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
@ -148,15 +148,15 @@ func (queue *baseTaskQueue) Enqueue(t task) error {
|
||||||
return queue.addUnissuedTask(t)
|
return queue.addUnissuedTask(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadAndReleaseTaskQueue
|
// queryNodeTaskQueue
|
||||||
func (queue *loadAndReleaseTaskQueue) Enqueue(t task) error {
|
func (queue *queryNodeTaskQueue) Enqueue(t task) error {
|
||||||
queue.mu.Lock()
|
queue.mu.Lock()
|
||||||
defer queue.mu.Unlock()
|
defer queue.mu.Unlock()
|
||||||
return queue.baseTaskQueue.Enqueue(t)
|
return queue.baseTaskQueue.Enqueue(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newLoadAndReleaseTaskQueue(scheduler *taskScheduler) *loadAndReleaseTaskQueue {
|
func newQueryNodeTaskQueue(scheduler *taskScheduler) *queryNodeTaskQueue {
|
||||||
return &loadAndReleaseTaskQueue{
|
return &queryNodeTaskQueue{
|
||||||
baseTaskQueue: baseTaskQueue{
|
baseTaskQueue: baseTaskQueue{
|
||||||
unissuedTasks: list.New(),
|
unissuedTasks: list.New(),
|
||||||
activeTasks: make(map[UniqueID]task),
|
activeTasks: make(map[UniqueID]task),
|
||||||
|
|
|
@ -29,7 +29,7 @@ func TestBaseTaskQueue_addUnissuedTask(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("test full", func(t *testing.T) {
|
t.Run("test full", func(t *testing.T) {
|
||||||
taskQueue := newLoadAndReleaseTaskQueue(s)
|
taskQueue := newQueryNodeTaskQueue(s)
|
||||||
task := &mockTask{}
|
task := &mockTask{}
|
||||||
for i := 0; i < maxTaskNum; i++ {
|
for i := 0; i < maxTaskNum; i++ {
|
||||||
err := taskQueue.addUnissuedTask(task)
|
err := taskQueue.addUnissuedTask(task)
|
||||||
|
@ -40,7 +40,7 @@ func TestBaseTaskQueue_addUnissuedTask(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("add task to front", func(t *testing.T) {
|
t.Run("add task to front", func(t *testing.T) {
|
||||||
taskQueue := newLoadAndReleaseTaskQueue(s)
|
taskQueue := newQueryNodeTaskQueue(s)
|
||||||
mt := &mockTask{
|
mt := &mockTask{
|
||||||
timestamp: 1000,
|
timestamp: 1000,
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,7 +32,7 @@ func newTaskScheduler(ctx context.Context) *taskScheduler {
|
||||||
ctx: ctx1,
|
ctx: ctx1,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
}
|
}
|
||||||
s.queue = newLoadAndReleaseTaskQueue(s)
|
s.queue = newQueryNodeTaskQueue(s)
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -61,7 +61,7 @@ func (s *taskScheduler) processTask(t task, q taskQueue) {
|
||||||
err = t.PostExecute(s.ctx)
|
err = t.PostExecute(s.ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *taskScheduler) loadAndReleaseLoop() {
|
func (s *taskScheduler) taskLoop() {
|
||||||
defer s.wg.Done()
|
defer s.wg.Done()
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
@ -78,7 +78,7 @@ func (s *taskScheduler) loadAndReleaseLoop() {
|
||||||
|
|
||||||
func (s *taskScheduler) Start() {
|
func (s *taskScheduler) Start() {
|
||||||
s.wg.Add(1)
|
s.wg.Add(1)
|
||||||
go s.loadAndReleaseLoop()
|
go s.taskLoop()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *taskScheduler) Close() {
|
func (s *taskScheduler) Close() {
|
||||||
|
|
|
@ -25,6 +25,191 @@ import (
|
||||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestTask_AddQueryChannel(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
genAddQueryChanelRequest := func() *querypb.AddQueryChannelRequest {
|
||||||
|
return &querypb.AddQueryChannelRequest{
|
||||||
|
Base: genCommonMsgBase(commonpb.MsgType_LoadCollection),
|
||||||
|
NodeID: 0,
|
||||||
|
CollectionID: defaultCollectionID,
|
||||||
|
RequestChannelID: genQueryChannel(),
|
||||||
|
ResultChannelID: genQueryResultChannel(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("test timestamp", func(t *testing.T) {
|
||||||
|
task := addQueryChannelTask{
|
||||||
|
req: genAddQueryChanelRequest(),
|
||||||
|
}
|
||||||
|
timestamp := Timestamp(1000)
|
||||||
|
task.req.Base.Timestamp = timestamp
|
||||||
|
resT := task.Timestamp()
|
||||||
|
assert.Equal(t, timestamp, resT)
|
||||||
|
task.req.Base = nil
|
||||||
|
resT = task.Timestamp()
|
||||||
|
assert.Equal(t, Timestamp(0), resT)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("test OnEnqueue", func(t *testing.T) {
|
||||||
|
task := addQueryChannelTask{
|
||||||
|
req: genAddQueryChanelRequest(),
|
||||||
|
}
|
||||||
|
err := task.OnEnqueue()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
task.req.Base = nil
|
||||||
|
err = task.OnEnqueue()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("test execute", func(t *testing.T) {
|
||||||
|
node, err := genSimpleQueryNode(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
task := addQueryChannelTask{
|
||||||
|
req: genAddQueryChanelRequest(),
|
||||||
|
node: node,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = task.Execute(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("test execute has queryCollection", func(t *testing.T) {
|
||||||
|
node, err := genSimpleQueryNode(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = node.queryService.addQueryCollection(defaultCollectionID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
task := addQueryChannelTask{
|
||||||
|
req: genAddQueryChanelRequest(),
|
||||||
|
node: node,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = task.Execute(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("test execute nil query service", func(t *testing.T) {
|
||||||
|
node, err := genSimpleQueryNode(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
node.queryService = nil
|
||||||
|
|
||||||
|
task := addQueryChannelTask{
|
||||||
|
req: genAddQueryChanelRequest(),
|
||||||
|
node: node,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = task.Execute(ctx)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("test execute add query collection failed", func(t *testing.T) {
|
||||||
|
node, err := genSimpleQueryNode(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = node.streaming.replica.removeCollection(defaultCollectionID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
err = node.historical.replica.removeCollection(defaultCollectionID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
task := addQueryChannelTask{
|
||||||
|
req: genAddQueryChanelRequest(),
|
||||||
|
node: node,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = task.Execute(ctx)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("test execute init global sealed segments", func(t *testing.T) {
|
||||||
|
node, err := genSimpleQueryNode(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
task := addQueryChannelTask{
|
||||||
|
req: genAddQueryChanelRequest(),
|
||||||
|
node: node,
|
||||||
|
}
|
||||||
|
|
||||||
|
task.req.GlobalSealedSegments = []*querypb.SegmentInfo{{
|
||||||
|
SegmentID: defaultSegmentID,
|
||||||
|
CollectionID: defaultCollectionID,
|
||||||
|
PartitionID: defaultPartitionID,
|
||||||
|
}}
|
||||||
|
|
||||||
|
err = task.Execute(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("test execute not init global sealed segments", func(t *testing.T) {
|
||||||
|
node, err := genSimpleQueryNode(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
task := addQueryChannelTask{
|
||||||
|
req: genAddQueryChanelRequest(),
|
||||||
|
node: node,
|
||||||
|
}
|
||||||
|
|
||||||
|
task.req.GlobalSealedSegments = []*querypb.SegmentInfo{{
|
||||||
|
SegmentID: defaultSegmentID,
|
||||||
|
CollectionID: 1000,
|
||||||
|
PartitionID: defaultPartitionID,
|
||||||
|
}}
|
||||||
|
|
||||||
|
err = task.Execute(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("test execute seek error", func(t *testing.T) {
|
||||||
|
node, err := genSimpleQueryNode(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
position := &internalpb.MsgPosition{
|
||||||
|
ChannelName: genQueryChannel(),
|
||||||
|
MsgID: []byte{1, 2, 3},
|
||||||
|
MsgGroup: defaultSubName,
|
||||||
|
Timestamp: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
task := addQueryChannelTask{
|
||||||
|
req: genAddQueryChanelRequest(),
|
||||||
|
node: node,
|
||||||
|
}
|
||||||
|
|
||||||
|
task.req.SeekPosition = position
|
||||||
|
|
||||||
|
err = task.Execute(ctx)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("test execute skipQueryChannelRecovery", func(t *testing.T) {
|
||||||
|
node, err := genSimpleQueryNode(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
position := &internalpb.MsgPosition{
|
||||||
|
ChannelName: genQueryChannel(),
|
||||||
|
MsgID: []byte{1, 2, 3},
|
||||||
|
MsgGroup: defaultSubName,
|
||||||
|
Timestamp: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
task := addQueryChannelTask{
|
||||||
|
req: genAddQueryChanelRequest(),
|
||||||
|
node: node,
|
||||||
|
}
|
||||||
|
|
||||||
|
task.req.SeekPosition = position
|
||||||
|
|
||||||
|
Params.skipQueryChannelRecovery = true
|
||||||
|
|
||||||
|
err = task.Execute(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestTask_watchDmChannelsTask(t *testing.T) {
|
func TestTask_watchDmChannelsTask(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
Loading…
Reference in New Issue