Fix querynode panics when watch/unsub runs concurrently (#20606)

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/20572/head
congqixia 2022-11-15 19:03:08 +08:00 committed by GitHub
parent d8f8296b03
commit ac9a993a39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 85 additions and 17 deletions

View File

@ -26,6 +26,7 @@ import (
"time" "time"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/samber/lo"
"go.uber.org/zap" "go.uber.org/zap"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
@ -34,6 +35,7 @@ import (
"github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics" "github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/metricsinfo" "github.com/milvus-io/milvus/internal/util/metricsinfo"
@ -303,6 +305,14 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
return status, nil return status, nil
} }
log := log.With(
zap.Int64("collectionID", in.GetCollectionID()),
zap.Int64("nodeID", paramtable.GetNodeID()),
zap.Strings("channels", lo.Map(in.GetInfos(), func(info *datapb.VchannelInfo, _ int) string {
return info.GetChannelName()
})),
)
task := &watchDmChannelsTask{ task := &watchDmChannelsTask{
baseTask: baseTask{ baseTask: baseTask{
ctx: ctx, ctx: ctx,
@ -313,13 +323,10 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
} }
startTs := time.Now() startTs := time.Now()
log.Info("watchDmChannels init", zap.Int64("collectionID", in.CollectionID), log.Info("watchDmChannels init")
zap.String("channelName", in.Infos[0].GetChannelName()),
zap.Int64("nodeID", paramtable.GetNodeID()))
// currently we only support loading one channel at a time // currently we only support loading one channel at a time
future := node.taskPool.Submit(func() (interface{}, error) { future := node.taskPool.Submit(func() (interface{}, error) {
log.Info("watchDmChannels start ", zap.Int64("collectionID", in.CollectionID), log.Info("watchDmChannels start ",
zap.String("channelName", in.Infos[0].GetChannelName()),
zap.Duration("timeInQueue", time.Since(startTs))) zap.Duration("timeInQueue", time.Since(startTs)))
err := task.PreExecute(ctx) err := task.PreExecute(ctx)
if err != nil { if err != nil {
@ -337,7 +344,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(), Reason: err.Error(),
} }
log.Warn("failed to subscribe channel ", zap.Error(err)) log.Warn("failed to subscribe channel", zap.Error(err))
return status, nil return status, nil
} }
@ -351,10 +358,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
return status, nil return status, nil
} }
sc, _ := node.ShardClusterService.getShardCluster(in.Infos[0].GetChannelName()) log.Info("successfully watchDmChannelsTask")
sc.SetupFirstVersion()
log.Info("successfully watchDmChannelsTask", zap.Int64("collectionID", in.CollectionID),
zap.String("channelName", in.Infos[0].GetChannelName()), zap.Int64("nodeID", paramtable.GetNodeID()))
return &commonpb.Status{ return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success, ErrorCode: commonpb.ErrorCode_Success,
}, nil }, nil

View File

@ -25,6 +25,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/commonpb"
@ -137,10 +138,47 @@ func TestImpl_WatchDmChannels(t *testing.T) {
}, },
} }
node.UpdateStateCode(commonpb.StateCode_Abnormal) node.UpdateStateCode(commonpb.StateCode_Abnormal)
defer node.UpdateStateCode(commonpb.StateCode_Healthy)
status, err := node.WatchDmChannels(ctx, req) status, err := node.WatchDmChannels(ctx, req)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode) assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
}) })
t.Run("mock release after loaded", func(t *testing.T) {
mockTSReplica := &MockTSafeReplicaInterface{}
oldTSReplica := node.tSafeReplica
defer func() {
node.tSafeReplica = oldTSReplica
}()
node.tSafeReplica = mockTSReplica
mockTSReplica.On("addTSafe", mock.Anything).Run(func(_ mock.Arguments) {
node.ShardClusterService.releaseShardCluster("1001-dmc0")
})
schema := genTestCollectionSchema()
req := &queryPb.WatchDmChannelsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchDmChannels,
MsgID: rand.Int63(),
TargetID: node.session.ServerID,
},
CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{defaultPartitionID},
Schema: schema,
Infos: []*datapb.VchannelInfo{
{
CollectionID: 1001,
ChannelName: "1001-dmc0",
},
},
}
status, err := node.WatchDmChannels(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
})
} }
func TestImpl_UnsubDmChannel(t *testing.T) { func TestImpl_UnsubDmChannel(t *testing.T) {

View File

@ -65,16 +65,19 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) {
VPChannels[v] = p VPChannels[v] = p
} }
log := log.With(
zap.Int64("collectionID", w.req.GetCollectionID()),
zap.Strings("vChannels", vChannels),
zap.Int64("replicaID", w.req.GetReplicaID()),
)
if len(VPChannels) != len(vChannels) { if len(VPChannels) != len(vChannels) {
return errors.New("get physical channels failed, illegal channel length, collectionID = " + fmt.Sprintln(collectionID)) return errors.New("get physical channels failed, illegal channel length, collectionID = " + fmt.Sprintln(collectionID))
} }
log.Info("Starting WatchDmChannels ...", log.Info("Starting WatchDmChannels ...",
zap.String("collectionName", w.req.Schema.Name), zap.String("loadType", lType.String()),
zap.Int64("collectionID", collectionID), zap.String("collectionName", w.req.GetSchema().GetName()),
zap.Int64("replicaID", w.req.GetReplicaID()),
zap.String("load type", lType.String()),
zap.Strings("vChannels", vChannels),
) )
// init collection meta // init collection meta
@ -126,7 +129,7 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) {
coll.setLoadType(lType) coll.setLoadType(lType)
log.Info("watchDMChannel, init replica done", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels)) log.Info("watchDMChannel, init replica done")
// create tSafe // create tSafe
for _, channel := range vChannels { for _, channel := range vChannels {
@ -143,7 +146,30 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) {
fg.flowGraph.Start() fg.flowGraph.Start()
} }
log.Info("WatchDmChannels done", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels)) log.Info("WatchDmChannels done")
return nil
}
// PostExecute setup ShardCluster first version and without do gc if failed.
func (w *watchDmChannelsTask) PostExecute(ctx context.Context) error {
// setup shard cluster version
var releasedChannels []string
for _, info := range w.req.GetInfos() {
sc, ok := w.node.ShardClusterService.getShardCluster(info.GetChannelName())
// shard cluster may be released by a release task
if !ok {
releasedChannels = append(releasedChannels, info.GetChannelName())
continue
}
sc.SetupFirstVersion()
}
if len(releasedChannels) > 0 {
// no clean up needed, release shall do the job
log.Warn("WatchDmChannels failed, shard cluster may be released",
zap.Strings("releasedChannels", releasedChannels),
)
return fmt.Errorf("failed to watch %v, shard cluster may be released", releasedChannels)
}
return nil return nil
} }