mirror of https://github.com/milvus-io/milvus.git
Fix querynode panics when watch/unsub runs concurrently (#20606)
Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/20572/head
parent d8f8296b03
commit ac9a993a39
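This commit fixes a QueryNode panic when a WatchDmChannels task races with a concurrent release/unsubscribe. Before the fix (see the `@@ -351,10` hunk below), `WatchDmChannels` discarded the ok flag returned by `getShardCluster` and called `SetupFirstVersion` on the result, so a release that removed the shard cluster in between dereferenced a nil pointer. The fix moves that setup into a new `watchDmChannelsTask.PostExecute` that checks the flag and returns an error instead. A minimal, self-contained Go sketch of the race and the fix pattern; the `registry` and `shardCluster` types here are illustrative, not Milvus code:

```go
package main

import (
	"fmt"
	"sync"
)

type shardCluster struct{ version int64 }

func (sc *shardCluster) SetupFirstVersion() { sc.version = 1 }

// registry mimics ShardClusterService: a mutex-guarded channel -> cluster map.
type registry struct {
	mu       sync.Mutex
	clusters map[string]*shardCluster
}

func (r *registry) getShardCluster(channel string) (*shardCluster, bool) {
	r.mu.Lock()
	defer r.mu.Unlock()
	sc, ok := r.clusters[channel]
	return sc, ok
}

func (r *registry) releaseShardCluster(channel string) {
	r.mu.Lock()
	defer r.mu.Unlock()
	delete(r.clusters, channel)
}

func main() {
	r := &registry{clusters: map[string]*shardCluster{"1001-dmc0": {}}}

	// A concurrent unsub/release may remove the cluster at any point;
	// which branch runs below depends on timing.
	go r.releaseShardCluster("1001-dmc0")

	// Pre-fix pattern: the ok flag was discarded, so sc could be nil here.
	//   sc, _ := r.getShardCluster("1001-dmc0")
	//   sc.SetupFirstVersion() // panic: nil pointer dereference

	// Post-fix pattern (what the new PostExecute does): check ok and
	// surface an error instead of panicking.
	if sc, ok := r.getShardCluster("1001-dmc0"); ok {
		sc.SetupFirstVersion()
	} else {
		fmt.Println("failed to watch [1001-dmc0], shard cluster may be released")
	}
}
```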
@@ -26,6 +26,7 @@ import (
 	"time"
 
 	"github.com/golang/protobuf/proto"
+	"github.com/samber/lo"
 	"go.uber.org/zap"
 	"golang.org/x/sync/errgroup"
 
@@ -34,6 +35,7 @@ import (
 	"github.com/milvus-io/milvus/internal/common"
 	"github.com/milvus-io/milvus/internal/log"
 	"github.com/milvus-io/milvus/internal/metrics"
+	"github.com/milvus-io/milvus/internal/proto/datapb"
 	"github.com/milvus-io/milvus/internal/proto/internalpb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/milvus-io/milvus/internal/util/metricsinfo"
@@ -303,6 +305,14 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
 		return status, nil
 	}
 
+	log := log.With(
+		zap.Int64("collectionID", in.GetCollectionID()),
+		zap.Int64("nodeID", paramtable.GetNodeID()),
+		zap.Strings("channels", lo.Map(in.GetInfos(), func(info *datapb.VchannelInfo, _ int) string {
+			return info.GetChannelName()
+		})),
+	)
+
 	task := &watchDmChannelsTask{
 		baseTask: baseTask{
 			ctx: ctx,
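The hunk above scopes the logger to the request and uses `lo.Map` from samber/lo (the import added above) to project channel names out of the request infos. A minimal standalone sketch of that helper; the `info` type here is a hypothetical stand-in for `datapb.VchannelInfo`:

```go
package main

import (
	"fmt"

	"github.com/samber/lo"
)

// info stands in for datapb.VchannelInfo; only the projected field matters.
type info struct{ ChannelName string }

func main() {
	infos := []*info{{ChannelName: "by-dev-dml_0"}, {ChannelName: "by-dev-dml_1"}}

	// lo.Map passes each element (and its index) through the callback,
	// as the new log.With(... zap.Strings ...) call does with channel names.
	names := lo.Map(infos, func(i *info, _ int) string {
		return i.ChannelName
	})
	fmt.Println(names) // [by-dev-dml_0 by-dev-dml_1]
}
```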
@@ -313,13 +323,10 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
 	}
 
 	startTs := time.Now()
-	log.Info("watchDmChannels init", zap.Int64("collectionID", in.CollectionID),
-		zap.String("channelName", in.Infos[0].GetChannelName()),
-		zap.Int64("nodeID", paramtable.GetNodeID()))
+	log.Info("watchDmChannels init")
 	// currently we only support loading one channel at a time
 	future := node.taskPool.Submit(func() (interface{}, error) {
-		log.Info("watchDmChannels start ", zap.Int64("collectionID", in.CollectionID),
-			zap.String("channelName", in.Infos[0].GetChannelName()),
+		log.Info("watchDmChannels start ",
 			zap.Duration("timeInQueue", time.Since(startTs)))
 		err := task.PreExecute(ctx)
 		if err != nil {
@@ -337,7 +344,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
 				ErrorCode: commonpb.ErrorCode_UnexpectedError,
 				Reason:    err.Error(),
 			}
-			log.Warn("failed to subscribe channel ", zap.Error(err))
+			log.Warn("failed to subscribe channel", zap.Error(err))
 			return status, nil
 		}
 
@@ -351,10 +358,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
 		return status, nil
 	}
 
-	sc, _ := node.ShardClusterService.getShardCluster(in.Infos[0].GetChannelName())
-	sc.SetupFirstVersion()
-	log.Info("successfully watchDmChannelsTask", zap.Int64("collectionID", in.CollectionID),
-		zap.String("channelName", in.Infos[0].GetChannelName()), zap.Int64("nodeID", paramtable.GetNodeID()))
+	log.Info("successfully watchDmChannelsTask")
 	return &commonpb.Status{
 		ErrorCode: commonpb.ErrorCode_Success,
 	}, nil
@@ -25,6 +25,7 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
 	"github.com/stretchr/testify/require"
 
 	"github.com/milvus-io/milvus-proto/go-api/commonpb"
@@ -137,10 +138,47 @@ func TestImpl_WatchDmChannels(t *testing.T) {
 			},
 		}
 		node.UpdateStateCode(commonpb.StateCode_Abnormal)
+		defer node.UpdateStateCode(commonpb.StateCode_Healthy)
 		status, err := node.WatchDmChannels(ctx, req)
 		assert.NoError(t, err)
 		assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
 	})
 
+	t.Run("mock release after loaded", func(t *testing.T) {
+		mockTSReplica := &MockTSafeReplicaInterface{}
+
+		oldTSReplica := node.tSafeReplica
+		defer func() {
+			node.tSafeReplica = oldTSReplica
+		}()
+		node.tSafeReplica = mockTSReplica
+		mockTSReplica.On("addTSafe", mock.Anything).Run(func(_ mock.Arguments) {
+			node.ShardClusterService.releaseShardCluster("1001-dmc0")
+		})
+		schema := genTestCollectionSchema()
+		req := &queryPb.WatchDmChannelsRequest{
+			Base: &commonpb.MsgBase{
+				MsgType:  commonpb.MsgType_WatchDmChannels,
+				MsgID:    rand.Int63(),
+				TargetID: node.session.ServerID,
+			},
+			CollectionID: defaultCollectionID,
+			PartitionIDs: []UniqueID{defaultPartitionID},
+			Schema:       schema,
+			Infos: []*datapb.VchannelInfo{
+				{
+					CollectionID: 1001,
+					ChannelName:  "1001-dmc0",
+				},
+			},
+		}
+
+		status, err := node.WatchDmChannels(ctx, req)
+		assert.NoError(t, err)
+		assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
+	})
 }
 
 func TestImpl_UnsubDmChannel(t *testing.T) {
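The new "mock release after loaded" subtest makes the race deterministic: the mocked tSafe replica's `addTSafe` hook releases the shard cluster at the exact moment the watch task registers tSafe, so the setup step must observe the missing cluster and fail cleanly instead of panicking. A minimal sketch of the testify `On(...).Run(...)` injection technique it relies on; the `registry` type and `Add` method are illustrative:

```go
package main

import (
	"fmt"

	"github.com/stretchr/testify/mock"
)

// registry is a toy mocked dependency; embedding mock.Mock gives it the
// testify expectation machinery.
type registry struct{ mock.Mock }

func (r *registry) Add(channel string) {
	r.Called(channel)
}

func main() {
	released := map[string]bool{}

	r := &registry{}
	// Run injects a side effect at the moment the code under test calls
	// Add, the same trick the subtest plays via the addTSafe hook to
	// simulate a concurrent release mid-watch.
	r.On("Add", mock.Anything).Run(func(args mock.Arguments) {
		released[args.String(0)] = true
	})

	r.Add("1001-dmc0")
	fmt.Println(released["1001-dmc0"]) // true
}
```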
@@ -65,16 +65,19 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) {
 		VPChannels[v] = p
 	}
 
+	log := log.With(
+		zap.Int64("collectionID", w.req.GetCollectionID()),
+		zap.Strings("vChannels", vChannels),
+		zap.Int64("replicaID", w.req.GetReplicaID()),
+	)
+
 	if len(VPChannels) != len(vChannels) {
 		return errors.New("get physical channels failed, illegal channel length, collectionID = " + fmt.Sprintln(collectionID))
 	}
 
 	log.Info("Starting WatchDmChannels ...",
-		zap.String("collectionName", w.req.Schema.Name),
-		zap.Int64("collectionID", collectionID),
-		zap.Int64("replicaID", w.req.GetReplicaID()),
-		zap.String("load type", lType.String()),
-		zap.Strings("vChannels", vChannels),
+		zap.String("loadType", lType.String()),
+		zap.String("collectionName", w.req.GetSchema().GetName()),
 	)
 
 	// init collection meta
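Here, as in the service layer above, per-call zap fields are replaced by a request-scoped child logger bound once and reused. A minimal sketch of the pattern with plain zap (Milvus's internal log package wraps zap, so the semantics are assumed to match):

```go
package main

import (
	"go.uber.org/zap"
)

func main() {
	base, _ := zap.NewDevelopment()
	defer base.Sync()

	// Shadowing the broader logger with a field-scoped child, as the diff
	// does with log := log.With(...): every later log line carries the
	// collectionID and vChannels fields without repeating them per call.
	log := base.With(
		zap.Int64("collectionID", 1001),
		zap.Strings("vChannels", []string{"1001-dmc0"}),
	)

	log.Info("Starting WatchDmChannels ...")
	log.Info("WatchDmChannels done") // same fields, attached once
}
```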
@@ -126,7 +129,7 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) {
 
 	coll.setLoadType(lType)
 
-	log.Info("watchDMChannel, init replica done", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels))
+	log.Info("watchDMChannel, init replica done")
 
 	// create tSafe
 	for _, channel := range vChannels {
@@ -143,7 +146,30 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) {
 		fg.flowGraph.Start()
 	}
 
-	log.Info("WatchDmChannels done", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels))
+	log.Info("WatchDmChannels done")
+	return nil
+}
+
+// PostExecute sets up the ShardCluster first version; it performs no GC when it fails.
+func (w *watchDmChannelsTask) PostExecute(ctx context.Context) error {
+	// setup shard cluster version
+	var releasedChannels []string
+	for _, info := range w.req.GetInfos() {
+		sc, ok := w.node.ShardClusterService.getShardCluster(info.GetChannelName())
+		// shard cluster may be released by a release task
+		if !ok {
+			releasedChannels = append(releasedChannels, info.GetChannelName())
+			continue
+		}
+		sc.SetupFirstVersion()
+	}
+	if len(releasedChannels) > 0 {
+		// no clean up needed, release shall do the job
+		log.Warn("WatchDmChannels failed, shard cluster may be released",
+			zap.Strings("releasedChannels", releasedChannels),
+		)
+		return fmt.Errorf("failed to watch %v, shard cluster may be released", releasedChannels)
+	}
 	return nil
 }
 