feat: support to replicate collection when the services contains the system tt msg (#37559)

- issue: #37105

---------

Signed-off-by: SimFG <bang.fu@zilliz.com>
pull/37836/head
SimFG 2024-12-17 09:08:46 +08:00 committed by GitHub
parent 5d014c76c7
commit 2afe2eaf3e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
77 changed files with 2383 additions and 263 deletions

View File

@ -873,6 +873,7 @@ common:
bloomFilterType: BlockedBloomFilter # bloom filter type, support BasicBloomFilter and BlockedBloomFilter
maxBloomFalsePositive: 0.001 # max false positive rate for bloom filter
bloomFilterApplyBatchSize: 1000 # batch size when to apply pk to bloom filter
collectionReplicateEnable: false # Whether to enable collection replication.
usePartitionKeyAsClusteringKey: false # if true, do clustering compaction and segment prune on partition key field
useVectorAsClusteringKey: false # if true, do clustering compaction and segment prune on vector field
enableVectorClusteringKey: false # if true, enable vector clustering key and vector clustering compaction

View File

@ -36,6 +36,7 @@ type ROChannel interface {
GetSchema() *schemapb.CollectionSchema
GetCreateTimestamp() Timestamp
GetWatchInfo() *datapb.ChannelWatchInfo
GetDBProperties() []*commonpb.KeyValuePair
}
type RWChannel interface {
@ -48,6 +49,7 @@ func NewRWChannel(name string,
startPos []*commonpb.KeyDataPair,
schema *schemapb.CollectionSchema,
createTs uint64,
dbProperties []*commonpb.KeyValuePair,
) RWChannel {
return &StateChannel{
Name: name,
@ -55,9 +57,11 @@ func NewRWChannel(name string,
StartPositions: startPos,
Schema: schema,
CreateTimestamp: createTs,
DBProperties: dbProperties,
}
}
// TODO fubang same as StateChannel
type channelMeta struct {
Name string
CollectionID UniqueID
@ -109,6 +113,10 @@ func (ch *channelMeta) String() string {
return fmt.Sprintf("Name: %s, CollectionID: %d, StartPositions: %v", ch.Name, ch.CollectionID, ch.StartPositions)
}
func (ch *channelMeta) GetDBProperties() []*commonpb.KeyValuePair {
return nil
}
type ChannelState string
const (
@ -126,6 +134,7 @@ type StateChannel struct {
CollectionID UniqueID
StartPositions []*commonpb.KeyDataPair
Schema *schemapb.CollectionSchema
DBProperties []*commonpb.KeyValuePair
CreateTimestamp uint64
Info *datapb.ChannelWatchInfo
@ -143,6 +152,7 @@ func NewStateChannel(ch RWChannel) *StateChannel {
Schema: ch.GetSchema(),
CreateTimestamp: ch.GetCreateTimestamp(),
Info: ch.GetWatchInfo(),
DBProperties: ch.GetDBProperties(),
assignedNode: bufferID,
}
@ -156,6 +166,7 @@ func NewStateChannelByWatchInfo(nodeID int64, info *datapb.ChannelWatchInfo) *St
Name: info.GetVchan().GetChannelName(),
CollectionID: info.GetVchan().GetCollectionID(),
Schema: info.GetSchema(),
DBProperties: info.GetDbProperties(),
Info: info,
assignedNode: nodeID,
}
@ -277,3 +288,7 @@ func (c *StateChannel) Assign(nodeID int64) {
func (c *StateChannel) setState(state ChannelState) {
c.currentState = state
}
func (c *StateChannel) GetDBProperties() []*commonpb.KeyValuePair {
return c.DBProperties
}

View File

@ -736,20 +736,22 @@ func (m *ChannelManagerImpl) fillChannelWatchInfo(op *ChannelOp) error {
schema := ch.GetSchema()
if schema == nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
collInfo, err := m.h.GetCollection(ctx, ch.GetCollectionID())
if err != nil {
cancel()
return err
}
cancel()
schema = collInfo.Schema
}
info := &datapb.ChannelWatchInfo{
Vchan: reduceVChanSize(vcInfo),
StartTs: startTs,
State: inferStateByOpType(op.Type),
Schema: schema,
OpID: opID,
Vchan: reduceVChanSize(vcInfo),
StartTs: startTs,
State: inferStateByOpType(op.Type),
Schema: schema,
OpID: opID,
DbProperties: ch.GetDBProperties(),
}
ch.UpdateWatchInfo(info)
}

View File

@ -138,6 +138,12 @@ type collectionInfo struct {
VChannelNames []string
}
type dbInfo struct {
ID int64
Name string
Properties []*commonpb.KeyValuePair
}
// NewMeta creates meta from provided `kv.TxnKV`
func newMeta(ctx context.Context, catalog metastore.DataCoordCatalog, chunkManager storage.ChunkManager) (*meta, error) {
im, err := newIndexMeta(ctx, catalog)
@ -244,12 +250,12 @@ func (m *meta) reloadCollectionsFromRootcoord(ctx context.Context, broker broker
return err
}
for _, dbName := range resp.GetDbNames() {
resp, err := broker.ShowCollections(ctx, dbName)
collectionsResp, err := broker.ShowCollections(ctx, dbName)
if err != nil {
return err
}
for _, collectionID := range resp.GetCollectionIds() {
resp, err := broker.DescribeCollectionInternal(ctx, collectionID)
for _, collectionID := range collectionsResp.GetCollectionIds() {
descResp, err := broker.DescribeCollectionInternal(ctx, collectionID)
if err != nil {
return err
}
@ -259,14 +265,14 @@ func (m *meta) reloadCollectionsFromRootcoord(ctx context.Context, broker broker
}
collection := &collectionInfo{
ID: collectionID,
Schema: resp.GetSchema(),
Schema: descResp.GetSchema(),
Partitions: partitionIDs,
StartPositions: resp.GetStartPositions(),
Properties: funcutil.KeyValuePair2Map(resp.GetProperties()),
CreatedAt: resp.GetCreatedTimestamp(),
DatabaseName: resp.GetDbName(),
DatabaseID: resp.GetDbId(),
VChannelNames: resp.GetVirtualChannelNames(),
StartPositions: descResp.GetStartPositions(),
Properties: funcutil.KeyValuePair2Map(descResp.GetProperties()),
CreatedAt: descResp.GetCreatedTimestamp(),
DatabaseName: descResp.GetDbName(),
DatabaseID: descResp.GetDbId(),
VChannelNames: descResp.GetVirtualChannelNames(),
}
m.AddCollection(collection)
}

View File

@ -951,7 +951,7 @@ func (s *Server) GetChannelRecoveryInfo(ctx context.Context, req *datapb.GetChan
return resp, nil
}
channel := NewRWChannel(req.GetVchannel(), collectionID, nil, collection.Schema, 0) // TODO: remove RWChannel, just use vchannel + collectionID
channel := NewRWChannel(req.GetVchannel(), collectionID, nil, collection.Schema, 0, nil) // TODO: remove RWChannel, just use vchannel + collectionID
channelInfo := s.handler.GetDataVChanPositions(channel, allPartitionID)
if channelInfo.SeekPosition == nil {
log.Warn("channel recovery start position is not found, may collection is on creating")
@ -1230,6 +1230,7 @@ func (s *Server) WatchChannels(ctx context.Context, req *datapb.WatchChannelsReq
log := log.Ctx(ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
zap.Strings("channels", req.GetChannelNames()),
zap.Any("dbProperties", req.GetDbProperties()),
)
log.Info("receive watch channels request")
resp := &datapb.WatchChannelsResponse{
@ -1242,7 +1243,7 @@ func (s *Server) WatchChannels(ctx context.Context, req *datapb.WatchChannelsReq
}, nil
}
for _, channelName := range req.GetChannelNames() {
ch := NewRWChannel(channelName, req.GetCollectionID(), req.GetStartPositions(), req.GetSchema(), req.GetCreateTimestamp())
ch := NewRWChannel(channelName, req.GetCollectionID(), req.GetStartPositions(), req.GetSchema(), req.GetCreateTimestamp(), req.GetDbProperties())
err := s.channelManager.Watch(ctx, ch)
if err != nil {
log.Warn("fail to watch channelName", zap.Error(err))
@ -1562,6 +1563,7 @@ func (s *Server) BroadcastAlteredCollection(ctx context.Context, req *datapb.Alt
StartPositions: req.GetStartPositions(),
Properties: properties,
DatabaseID: req.GetDbID(),
DatabaseName: req.GetSchema().GetDbName(),
VChannelNames: req.GetVChannels(),
}
s.meta.AddCollection(collInfo)

View File

@ -70,7 +70,7 @@ func (s *OpRunnerSuite) SetupTest() {
Return(nil).Maybe()
dispClient := msgdispatcher.NewMockClient(s.T())
dispClient.EXPECT().Register(mock.Anything, mock.Anything, mock.Anything, mock.Anything).
dispClient.EXPECT().Register(mock.Anything, mock.Anything).
Return(make(chan *msgstream.MsgPack), nil).Maybe()
dispClient.EXPECT().Deregister(mock.Anything).Maybe()

View File

@ -350,7 +350,13 @@ func NewDataSyncService(initCtx context.Context, pipelineParams *util.PipelinePa
return nil, err
}
input, err := createNewInputFromDispatcher(initCtx, pipelineParams.DispClient, info.GetVchan().GetChannelName(), info.GetVchan().GetSeekPosition())
input, err := createNewInputFromDispatcher(initCtx,
pipelineParams.DispClient,
info.GetVchan().GetChannelName(),
info.GetVchan().GetSeekPosition(),
info.GetSchema(),
info.GetDbProperties(),
)
if err != nil {
return nil, err
}

View File

@ -23,8 +23,11 @@ import (
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/flowgraph"
pkgcommon "github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/mq/common"
@ -57,11 +60,29 @@ func newDmInputNode(dmNodeConfig *nodeConfig, input <-chan *msgstream.MsgPack) *
return node
}
func createNewInputFromDispatcher(initCtx context.Context, dispatcherClient msgdispatcher.Client, vchannel string, seekPos *msgpb.MsgPosition) (<-chan *msgstream.MsgPack, error) {
func createNewInputFromDispatcher(initCtx context.Context,
dispatcherClient msgdispatcher.Client,
vchannel string,
seekPos *msgpb.MsgPosition,
schema *schemapb.CollectionSchema,
dbProperties []*commonpb.KeyValuePair,
) (<-chan *msgstream.MsgPack, error) {
log := log.With(zap.Int64("nodeID", paramtable.GetNodeID()),
zap.String("vchannel", vchannel))
replicateID, _ := pkgcommon.GetReplicateID(schema.GetProperties())
if replicateID == "" {
log.Info("datanode consume without replicateID, try to get replicateID from dbProperties", zap.Any("dbProperties", dbProperties))
replicateID, _ = pkgcommon.GetReplicateID(dbProperties)
}
replicateConfig := msgstream.GetReplicateConfig(replicateID, schema.GetDbName(), schema.GetName())
if seekPos != nil && len(seekPos.MsgID) != 0 {
input, err := dispatcherClient.Register(initCtx, vchannel, seekPos, common.SubscriptionPositionUnknown)
input, err := dispatcherClient.Register(initCtx, &msgdispatcher.StreamConfig{
VChannel: vchannel,
Pos: seekPos,
SubPos: common.SubscriptionPositionUnknown,
ReplicateConfig: replicateConfig,
})
if err != nil {
return nil, err
}
@ -71,7 +92,12 @@ func createNewInputFromDispatcher(initCtx context.Context, dispatcherClient msgd
zap.Duration("tsLag", time.Since(tsoutil.PhysicalTime(seekPos.GetTimestamp()))))
return input, err
}
input, err := dispatcherClient.Register(initCtx, vchannel, nil, common.SubscriptionPositionEarliest)
input, err := dispatcherClient.Register(initCtx, &msgdispatcher.StreamConfig{
VChannel: vchannel,
Pos: nil,
SubPos: common.SubscriptionPositionEarliest,
ReplicateConfig: replicateConfig,
})
if err != nil {
return nil, err
}

View File

@ -62,6 +62,9 @@ func (mm *mockMsgStreamFactory) NewMsgStreamDisposer(ctx context.Context) func([
type mockTtMsgStream struct{}
func (mtm *mockTtMsgStream) SetReplicate(config *msgstream.ReplicateConfig) {
}
func (mtm *mockTtMsgStream) Close() {}
func (mtm *mockTtMsgStream) Chan() <-chan *msgstream.MsgPack {

View File

@ -65,7 +65,7 @@ func TestFlowGraphManager(t *testing.T) {
wbm.EXPECT().Register(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
dispClient := msgdispatcher.NewMockClient(t)
dispClient.EXPECT().Register(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(make(chan *msgstream.MsgPack), nil)
dispClient.EXPECT().Register(mock.Anything, mock.Anything).Return(make(chan *msgstream.MsgPack), nil)
dispClient.EXPECT().Deregister(mock.Anything)
pipelineParams := &util.PipelineParams{
@ -151,7 +151,7 @@ func newFlowGraphManager(t *testing.T) (string, FlowgraphManager) {
wbm.EXPECT().Register(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
dispClient := msgdispatcher.NewMockClient(t)
dispClient.EXPECT().Register(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(make(chan *msgstream.MsgPack), nil)
dispClient.EXPECT().Register(mock.Anything, mock.Anything).Return(make(chan *msgstream.MsgPack), nil)
pipelineParams := &util.PipelineParams{
Ctx: context.TODO(),

View File

@ -15,6 +15,7 @@ type Collection struct {
CollectionID int64
Partitions []*Partition
Name string
DBName string
Description string
AutoID bool
Fields []*Field
@ -41,6 +42,7 @@ func (c *Collection) Clone() *Collection {
DBID: c.DBID,
CollectionID: c.CollectionID,
Name: c.Name,
DBName: c.DBName,
Description: c.Description,
AutoID: c.AutoID,
Fields: CloneFields(c.Fields),
@ -99,6 +101,7 @@ func UnmarshalCollectionModel(coll *pb.CollectionInfo) *Collection {
CollectionID: coll.ID,
DBID: coll.DbId,
Name: coll.Schema.Name,
DBName: coll.Schema.DbName,
Description: coll.Schema.Description,
AutoID: coll.Schema.AutoID,
Fields: UnmarshalFieldModels(coll.GetSchema().GetFields()),
@ -154,6 +157,7 @@ func marshalCollectionModelWithConfig(coll *Collection, c *config) *pb.Collectio
Description: coll.Description,
AutoID: coll.AutoID,
EnableDynamicField: coll.EnableDynamicField,
DbName: coll.DBName,
}
if c.withFields {

View File

@ -538,6 +538,7 @@ message ChannelWatchInfo {
// watch progress, deprecated
int32 progress = 6;
int64 opID = 7;
repeated common.KeyValuePair dbProperties = 8;
}
enum CompactionType {
@ -655,6 +656,7 @@ message WatchChannelsRequest {
repeated common.KeyDataPair start_positions = 3;
schema.CollectionSchema schema = 4;
uint64 create_timestamp = 5;
repeated common.KeyValuePair db_properties = 6;
}
message WatchChannelsResponse {

View File

@ -323,6 +323,7 @@ message LoadMetaInfo {
string db_name = 5; // Only used for metrics label.
string resource_group = 6; // Only used for metrics label.
repeated int64 load_fields = 7;
repeated common.KeyValuePair db_properties = 8;
}
message WatchDmChannelsRequest {

View File

@ -146,6 +146,7 @@ service RootCoord {
message AllocTimestampRequest {
common.MsgBase base = 1;
uint32 count = 3;
uint64 blockTimestamp = 4;
}
message AllocTimestampResponse {

View File

@ -160,7 +160,7 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *p
}
globalMetaCache.RemoveCollection(ctx, request.GetDbName(), collectionName)
log.Info("complete to invalidate collection meta cache", zap.String("type", request.GetBase().GetMsgType().String()))
case commonpb.MsgType_DropDatabase:
case commonpb.MsgType_DropDatabase, commonpb.MsgType_AlterDatabase:
globalMetaCache.RemoveDatabase(ctx, request.GetDbName())
case commonpb.MsgType_AlterCollection, commonpb.MsgType_AlterCollectionField:
if request.CollectionID != UniqueID(0) {
@ -6325,13 +6325,19 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
}
if paramtable.Get().CommonCfg.TTMsgEnabled.GetAsBool() {
collectionReplicateEnable := paramtable.Get().CommonCfg.CollectionReplicateEnable.GetAsBool()
ttMsgEnabled := paramtable.Get().CommonCfg.TTMsgEnabled.GetAsBool()
// replicate message can be use in two ways, otherwise return error
// 1. collectionReplicateEnable is false and ttMsgEnabled is false, active/standby mode
// 2. collectionReplicateEnable is true and ttMsgEnabled is true, data migration mode
if (!collectionReplicateEnable && ttMsgEnabled) || (collectionReplicateEnable && !ttMsgEnabled) {
return &milvuspb.ReplicateMessageResponse{
Status: merr.Status(merr.ErrDenyReplicateMessage),
}, nil
}
var err error
var err error
if req.GetChannelName() == "" {
log.Ctx(ctx).Warn("channel name is empty")
return &milvuspb.ReplicateMessageResponse{
@ -6369,6 +6375,18 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
StartPositions: req.StartPositions,
EndPositions: req.EndPositions,
}
checkCollectionReplicateProperty := func(dbName, collectionName string) bool {
if !collectionReplicateEnable {
return true
}
replicateID, err := GetReplicateID(ctx, dbName, collectionName)
if err != nil {
log.Warn("get replicate id failed", zap.String("collectionName", collectionName), zap.Error(err))
return false
}
return replicateID != ""
}
// getTsMsgFromConsumerMsg
for i, msgBytes := range req.Msgs {
header := commonpb.MsgHeader{}
@ -6388,6 +6406,9 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
}
switch realMsg := tsMsg.(type) {
case *msgstream.InsertMsg:
if !checkCollectionReplicateProperty(realMsg.GetDbName(), realMsg.GetCollectionName()) {
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.WrapErrCollectionReplicateMode("replicate"))}, nil
}
assignedSegmentInfos, err := node.segAssigner.GetSegmentID(realMsg.GetCollectionID(), realMsg.GetPartitionID(),
realMsg.GetShardName(), uint32(realMsg.NumRows), req.EndTs)
if err != nil {
@ -6402,6 +6423,10 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
realMsg.SegmentID = assignSegmentID
break
}
case *msgstream.DeleteMsg:
if !checkCollectionReplicateProperty(realMsg.GetDbName(), realMsg.GetCollectionName()) {
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.WrapErrCollectionReplicateMode("replicate"))}, nil
}
}
msgPack.Msgs = append(msgPack.Msgs, tsMsg)
}

View File

@ -1279,6 +1279,9 @@ func TestProxy_Delete(t *testing.T) {
},
}
schema := newSchemaInfo(collSchema)
basicInfo := &collectionInfo{
collID: collectionID,
}
paramtable.Init()
t.Run("delete run failed", func(t *testing.T) {
@ -1311,6 +1314,7 @@ func TestProxy_Delete(t *testing.T) {
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(partitionID, nil)
cache.On("GetCollectionInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(basicInfo, nil)
chMgr.On("getVChannels", mock.Anything).Return(channels, nil)
chMgr.On("getChannels", mock.Anything).Return(nil, fmt.Errorf("mock error"))
globalMetaCache = cache
@ -1863,3 +1867,330 @@ func TestRegisterRestRouter(t *testing.T) {
})
}
}
func TestReplicateMessageForCollectionMode(t *testing.T) {
paramtable.Init()
ctx := context.Background()
insertMsg := &msgstream.InsertMsg{
BaseMsg: msgstream.BaseMsg{
BeginTimestamp: 10,
EndTimestamp: 10,
HashValues: []uint32{0},
MsgPosition: &msgstream.MsgPosition{
ChannelName: "foo",
MsgID: []byte("mock message id 2"),
},
},
InsertRequest: &msgpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
MsgID: 10001,
Timestamp: 10,
SourceID: -1,
},
ShardName: "foo_v1",
DbName: "default",
CollectionName: "foo_collection",
PartitionName: "_default",
DbID: 1,
CollectionID: 11,
PartitionID: 22,
SegmentID: 33,
Timestamps: []uint64{10},
RowIDs: []int64{66},
NumRows: 1,
},
}
insertMsgBytes, _ := insertMsg.Marshal(insertMsg)
deleteMsg := &msgstream.DeleteMsg{
BaseMsg: msgstream.BaseMsg{
BeginTimestamp: 20,
EndTimestamp: 20,
HashValues: []uint32{0},
MsgPosition: &msgstream.MsgPosition{
ChannelName: "foo",
MsgID: []byte("mock message id 2"),
},
},
DeleteRequest: &msgpb.DeleteRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Delete,
MsgID: 10002,
Timestamp: 20,
SourceID: -1,
},
ShardName: "foo_v1",
DbName: "default",
CollectionName: "foo_collection",
PartitionName: "_default",
DbID: 1,
CollectionID: 11,
PartitionID: 22,
},
}
deleteMsgBytes, _ := deleteMsg.Marshal(deleteMsg)
cache := globalMetaCache
defer func() { globalMetaCache = cache }()
t.Run("replicate message in the replicate collection mode", func(t *testing.T) {
defer func() {
paramtable.Get().Reset(paramtable.Get().CommonCfg.TTMsgEnabled.Key)
paramtable.Get().Reset(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key)
}()
{
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true")
paramtable.Get().Save(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key, "false")
p := &Proxy{}
p.UpdateStateCode(commonpb.StateCode_Healthy)
r, err := p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
ChannelName: "foo",
})
assert.NoError(t, err)
assert.Error(t, merr.Error(r.Status))
}
{
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false")
paramtable.Get().Save(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key, "true")
p := &Proxy{}
p.UpdateStateCode(commonpb.StateCode_Healthy)
r, err := p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
ChannelName: "foo",
})
assert.NoError(t, err)
assert.Error(t, merr.Error(r.Status))
}
})
t.Run("replicate message for the replicate collection mode", func(t *testing.T) {
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true")
paramtable.Get().Save(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key, "true")
defer func() {
paramtable.Get().Reset(paramtable.Get().CommonCfg.TTMsgEnabled.Key)
paramtable.Get().Reset(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key)
}()
mockCache := NewMockCache(t)
mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil).Twice()
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{}, nil).Twice()
globalMetaCache = mockCache
{
p := &Proxy{
replicateStreamManager: NewReplicateStreamManager(context.Background(), nil, nil),
}
p.UpdateStateCode(commonpb.StateCode_Healthy)
r, err := p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
ChannelName: "foo",
Msgs: [][]byte{insertMsgBytes.([]byte)},
})
assert.NoError(t, err)
assert.EqualValues(t, r.GetStatus().GetCode(), merr.Code(merr.ErrCollectionReplicateMode))
}
{
p := &Proxy{
replicateStreamManager: NewReplicateStreamManager(context.Background(), nil, nil),
}
p.UpdateStateCode(commonpb.StateCode_Healthy)
r, err := p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
ChannelName: "foo",
Msgs: [][]byte{deleteMsgBytes.([]byte)},
})
assert.NoError(t, err)
assert.EqualValues(t, r.GetStatus().GetCode(), merr.Code(merr.ErrCollectionReplicateMode))
}
})
}
func TestAlterCollectionReplicateProperty(t *testing.T) {
paramtable.Init()
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true")
paramtable.Get().Save(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key, "true")
defer func() {
paramtable.Get().Reset(paramtable.Get().CommonCfg.TTMsgEnabled.Key)
paramtable.Get().Reset(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key)
}()
cache := globalMetaCache
defer func() { globalMetaCache = cache }()
mockCache := NewMockCache(t)
mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{
replicateID: "local-milvus",
}, nil).Maybe()
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(1, nil).Maybe()
mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{}, nil)
globalMetaCache = mockCache
factory := newMockMsgStreamFactory()
msgStreamObj := msgstream.NewMockMsgStream(t)
msgStreamObj.EXPECT().SetRepackFunc(mock.Anything).Return().Maybe()
msgStreamObj.EXPECT().AsProducer(mock.Anything, mock.Anything).Return().Maybe()
msgStreamObj.EXPECT().ForceEnableProduce(mock.Anything).Return().Maybe()
msgStreamObj.EXPECT().Close().Return().Maybe()
mockMsgID1 := mqcommon.NewMockMessageID(t)
mockMsgID2 := mqcommon.NewMockMessageID(t)
mockMsgID2.EXPECT().Serialize().Return([]byte("mock message id 2")).Maybe()
msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{
"alter_property": {mockMsgID1, mockMsgID2},
}, nil).Maybe()
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return msgStreamObj, nil
}
resourceManager := resource.NewManager(time.Second, 2*time.Second, nil)
manager := NewReplicateStreamManager(context.Background(), factory, resourceManager)
ctx := context.Background()
var startTt uint64 = 10
startTime := time.Now()
dataCoord := &mockDataCoord{}
dataCoord.expireTime = Timestamp(1000)
segAllocator, err := newSegIDAssigner(ctx, dataCoord, func() Timestamp {
return Timestamp(time.Since(startTime).Seconds()) + startTt
})
assert.NoError(t, err)
segAllocator.Start()
mockRootcoord := mocks.NewMockRootCoordClient(t)
mockRootcoord.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request *rootcoordpb.AllocTimestampRequest, option ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) {
return &rootcoordpb.AllocTimestampResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Timestamp: Timestamp(time.Since(startTime).Seconds()) + startTt,
}, nil
})
mockRootcoord.EXPECT().AlterCollection(mock.Anything, mock.Anything).Return(&commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}, nil)
p := &Proxy{
ctx: ctx,
replicateStreamManager: manager,
segAssigner: segAllocator,
rootCoord: mockRootcoord,
}
tsoAllocatorIns := newMockTsoAllocator()
p.sched, err = newTaskScheduler(p.ctx, tsoAllocatorIns, p.factory)
assert.NoError(t, err)
p.sched.Start()
defer p.sched.Close()
p.UpdateStateCode(commonpb.StateCode_Healthy)
getInsertMsgBytes := func(channel string, ts uint64) []byte {
insertMsg := &msgstream.InsertMsg{
BaseMsg: msgstream.BaseMsg{
BeginTimestamp: ts,
EndTimestamp: ts,
HashValues: []uint32{0},
MsgPosition: &msgstream.MsgPosition{
ChannelName: channel,
MsgID: []byte("mock message id 2"),
},
},
InsertRequest: &msgpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
MsgID: 10001,
Timestamp: ts,
SourceID: -1,
},
ShardName: channel + "_v1",
DbName: "default",
CollectionName: "foo_collection",
PartitionName: "_default",
DbID: 1,
CollectionID: 11,
PartitionID: 22,
SegmentID: 33,
Timestamps: []uint64{ts},
RowIDs: []int64{66},
NumRows: 1,
},
}
insertMsgBytes, _ := insertMsg.Marshal(insertMsg)
return insertMsgBytes.([]byte)
}
getDeleteMsgBytes := func(channel string, ts uint64) []byte {
deleteMsg := &msgstream.DeleteMsg{
BaseMsg: msgstream.BaseMsg{
BeginTimestamp: ts,
EndTimestamp: ts,
HashValues: []uint32{0},
MsgPosition: &msgstream.MsgPosition{
ChannelName: "foo",
MsgID: []byte("mock message id 2"),
},
},
DeleteRequest: &msgpb.DeleteRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Delete,
MsgID: 10002,
Timestamp: ts,
SourceID: -1,
},
ShardName: channel + "_v1",
DbName: "default",
CollectionName: "foo_collection",
PartitionName: "_default",
DbID: 1,
CollectionID: 11,
PartitionID: 22,
},
}
deleteMsgBytes, _ := deleteMsg.Marshal(deleteMsg)
return deleteMsgBytes.([]byte)
}
go func() {
// replicate message
var replicateResp *milvuspb.ReplicateMessageResponse
var err error
replicateResp, err = p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
ChannelName: "alter_property_1",
Msgs: [][]byte{getInsertMsgBytes("alter_property_1", startTt+5)},
})
assert.NoError(t, err)
assert.True(t, merr.Ok(replicateResp.Status), replicateResp.Status.Reason)
replicateResp, err = p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
ChannelName: "alter_property_2",
Msgs: [][]byte{getDeleteMsgBytes("alter_property_2", startTt+5)},
})
assert.NoError(t, err)
assert.True(t, merr.Ok(replicateResp.Status), replicateResp.Status.Reason)
time.Sleep(time.Second)
replicateResp, err = p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
ChannelName: "alter_property_1",
Msgs: [][]byte{getInsertMsgBytes("alter_property_1", startTt+10)},
})
assert.NoError(t, err)
assert.False(t, merr.Ok(replicateResp.Status), replicateResp.Status.Reason)
replicateResp, err = p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
ChannelName: "alter_property_2",
Msgs: [][]byte{getInsertMsgBytes("alter_property_2", startTt+10)},
})
assert.NoError(t, err)
assert.False(t, merr.Ok(replicateResp.Status), replicateResp.Status.Reason)
}()
time.Sleep(200 * time.Millisecond)
// alter collection property
statusResp, err := p.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{
DbName: "default",
CollectionName: "foo_collection",
Properties: []*commonpb.KeyValuePair{
{
Key: "replicate.endTS",
Value: "1",
},
},
})
assert.NoError(t, err)
assert.True(t, merr.Ok(statusResp))
}

View File

@ -102,10 +102,12 @@ type collectionInfo struct {
createdUtcTimestamp uint64
consistencyLevel commonpb.ConsistencyLevel
partitionKeyIsolation bool
replicateID string
}
type databaseInfo struct {
dbID typeutil.UniqueID
properties []*commonpb.KeyValuePair
createdTimestamp uint64
}
@ -478,6 +480,7 @@ func (m *MetaCache) update(ctx context.Context, database, collectionName string,
m.collInfo[database] = make(map[string]*collectionInfo)
}
replicateID, _ := common.GetReplicateID(collection.Properties)
m.collInfo[database][collectionName] = &collectionInfo{
collID: collection.CollectionID,
schema: schemaInfo,
@ -486,6 +489,7 @@ func (m *MetaCache) update(ctx context.Context, database, collectionName string,
createdUtcTimestamp: collection.CreatedUtcTimestamp,
consistencyLevel: collection.ConsistencyLevel,
partitionKeyIsolation: isolation,
replicateID: replicateID,
}
log.Ctx(ctx).Info("meta update success", zap.String("database", database), zap.String("collectionName", collectionName),
@ -571,10 +575,19 @@ func (m *MetaCache) GetCollectionInfo(ctx context.Context, database string, coll
method := "GetCollectionInfo"
// if collInfo.collID != collectionID, means that the cache is not trustable
// try to get collection according to collectionID
if !ok || collInfo.collID != collectionID {
// Why use collectionID? Because the collectionID is not always provided in the proxy.
if !ok || (collectionID != 0 && collInfo.collID != collectionID) {
tr := timerecord.NewTimeRecorder("UpdateCache")
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
if collectionID == 0 {
collInfo, err := m.UpdateByName(ctx, database, collectionName)
if err != nil {
return nil, err
}
metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
return collInfo, nil
}
collInfo, err := m.UpdateByID(ctx, database, collectionID)
if err != nil {
return nil, err
@ -1225,6 +1238,7 @@ func (m *MetaCache) GetDatabaseInfo(ctx context.Context, database string) (*data
defer m.mu.Unlock()
dbInfo := &databaseInfo{
dbID: resp.GetDbID(),
properties: resp.Properties,
createdTimestamp: resp.GetCreatedTimestamp(),
}
m.dbInfo[database] = dbInfo

View File

@ -304,6 +304,87 @@ func TestMetaCache_GetBasicCollectionInfo(t *testing.T) {
wg.Wait()
}
func TestMetaCacheGetCollectionWithUpdate(t *testing.T) {
cache := globalMetaCache
defer func() { globalMetaCache = cache }()
ctx := context.Background()
rootCoord := mocks.NewMockRootCoordClient(t)
queryCoord := mocks.NewMockQueryCoordClient(t)
rootCoord.EXPECT().ListPolicy(mock.Anything, mock.Anything, mock.Anything).Return(&internalpb.ListPolicyResponse{Status: merr.Success()}, nil)
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, queryCoord, mgr)
assert.NoError(t, err)
t.Run("update with name", func(t *testing.T) {
rootCoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
Status: merr.Success(),
CollectionID: 1,
Schema: &schemapb.CollectionSchema{
Name: "bar",
Fields: []*schemapb.FieldSchema{
{
FieldID: 1,
Name: "p",
},
{
FieldID: 100,
Name: "pk",
},
},
},
ShardsNum: 1,
PhysicalChannelNames: []string{"by-dev-rootcoord-dml_1"},
VirtualChannelNames: []string{"by-dev-rootcoord-dml_1_1v0"},
}, nil).Once()
rootCoord.EXPECT().ShowPartitions(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.ShowPartitionsResponse{
Status: merr.Success(),
PartitionIDs: []typeutil.UniqueID{11},
PartitionNames: []string{"p1"},
CreatedTimestamps: []uint64{11},
CreatedUtcTimestamps: []uint64{11},
}, nil).Once()
queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Once()
c, err := globalMetaCache.GetCollectionInfo(ctx, "foo", "bar", 1)
assert.NoError(t, err)
assert.Equal(t, c.collID, int64(1))
assert.Equal(t, c.schema.Name, "bar")
})
t.Run("update with name", func(t *testing.T) {
rootCoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
Status: merr.Success(),
CollectionID: 1,
Schema: &schemapb.CollectionSchema{
Name: "bar",
Fields: []*schemapb.FieldSchema{
{
FieldID: 1,
Name: "p",
},
{
FieldID: 100,
Name: "pk",
},
},
},
ShardsNum: 1,
PhysicalChannelNames: []string{"by-dev-rootcoord-dml_1"},
VirtualChannelNames: []string{"by-dev-rootcoord-dml_1_1v0"},
}, nil).Once()
rootCoord.EXPECT().ShowPartitions(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.ShowPartitionsResponse{
Status: merr.Success(),
PartitionIDs: []typeutil.UniqueID{11},
PartitionNames: []string{"p1"},
CreatedTimestamps: []uint64{11},
CreatedUtcTimestamps: []uint64{11},
}, nil).Once()
queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Once()
c, err := globalMetaCache.GetCollectionInfo(ctx, "foo", "hoo", 0)
assert.NoError(t, err)
assert.Equal(t, c.collID, int64(1))
assert.Equal(t, c.schema.Name, "bar")
})
}
func TestMetaCache_GetCollectionName(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}

View File

@ -40,6 +40,9 @@ func (m *mockMsgStream) ForceEnableProduce(enabled bool) {
}
}
func (m *mockMsgStream) SetReplicate(config *msgstream.ReplicateConfig) {
}
func newMockMsgStream() *mockMsgStream {
return &mockMsgStream{}
}

View File

@ -314,6 +314,9 @@ func (ms *simpleMockMsgStream) CheckTopicValid(topic string) error {
func (ms *simpleMockMsgStream) ForceEnableProduce(enabled bool) {
}
func (ms *simpleMockMsgStream) SetReplicate(config *msgstream.ReplicateConfig) {
}
func newSimpleMockMsgStream() *simpleMockMsgStream {
return &simpleMockMsgStream{
msgChan: make(chan *msgstream.MsgPack, 1024),

View File

@ -32,6 +32,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/ctokenizer"
"github.com/milvus-io/milvus/pkg/common"
@ -1081,6 +1082,25 @@ func (t *alterCollectionTask) PreExecute(ctx context.Context) error {
}
}
_, ok := common.IsReplicateEnabled(t.Properties)
if ok {
return merr.WrapErrParameterInvalidMsg("can't set the replicate.id property")
}
endTS, ok := common.GetReplicateEndTS(t.Properties)
if ok && collBasicInfo.replicateID != "" {
allocResp, err := t.rootCoord.AllocTimestamp(ctx, &rootcoordpb.AllocTimestampRequest{
Count: 1,
BlockTimestamp: endTS,
})
if err = merr.CheckRPCCall(allocResp, err); err != nil {
return merr.WrapErrServiceInternal("alloc timestamp failed", err.Error())
}
if allocResp.GetTimestamp() <= endTS {
return merr.WrapErrServiceInternal("alter collection: alloc timestamp failed, timestamp is not greater than endTS",
fmt.Sprintf("timestamp = %d, endTS = %d", allocResp.GetTimestamp(), endTS))
}
}
return nil
}

View File

@ -2,6 +2,7 @@ package proxy
import (
"context"
"fmt"
"go.uber.org/zap"
@ -9,6 +10,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
@ -274,6 +276,34 @@ func (t *alterDatabaseTask) OnEnqueue() error {
}
func (t *alterDatabaseTask) PreExecute(ctx context.Context) error {
_, ok := common.GetReplicateID(t.Properties)
if ok {
return merr.WrapErrParameterInvalidMsg("can't set the replicate id property in alter database request")
}
endTS, ok := common.GetReplicateEndTS(t.Properties)
if !ok { // not exist replicate end ts property
return nil
}
cacheInfo, err := globalMetaCache.GetDatabaseInfo(ctx, t.DbName)
if err != nil {
return err
}
oldReplicateEnable, _ := common.IsReplicateEnabled(cacheInfo.properties)
if !oldReplicateEnable { // old replicate enable is false
return merr.WrapErrParameterInvalidMsg("can't set the replicate end ts property in alter database request when db replicate is disabled")
}
allocResp, err := t.rootCoord.AllocTimestamp(ctx, &rootcoordpb.AllocTimestampRequest{
Count: 1,
BlockTimestamp: endTS,
})
if err = merr.CheckRPCCall(allocResp, err); err != nil {
return merr.WrapErrServiceInternal("alloc timestamp failed", err.Error())
}
if allocResp.GetTimestamp() <= endTS {
return merr.WrapErrServiceInternal("alter database: alloc timestamp failed, timestamp is not greater than endTS",
fmt.Sprintf("timestamp = %d, endTS = %d", allocResp.GetTimestamp(), endTS))
}
return nil
}

View File

@ -5,6 +5,7 @@ import (
"strings"
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc/metadata"
@ -201,6 +202,163 @@ func TestAlterDatabase(t *testing.T) {
assert.Nil(t, err1)
}
func TestAlterDatabaseTaskForReplicateProperty(t *testing.T) {
rc := mocks.NewMockRootCoordClient(t)
cache := globalMetaCache
defer func() { globalMetaCache = cache }()
mockCache := NewMockCache(t)
globalMetaCache = mockCache
t.Run("replicate id", func(t *testing.T) {
task := &alterDatabaseTask{
AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{
Base: &commonpb.MsgBase{},
DbName: "test_alter_database",
Properties: []*commonpb.KeyValuePair{
{
Key: common.MmapEnabledKey,
Value: "true",
},
{
Key: common.ReplicateIDKey,
Value: "local-test",
},
},
},
rootCoord: rc,
}
err := task.PreExecute(context.Background())
assert.Error(t, err)
})
t.Run("fail to get database info", func(t *testing.T) {
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, errors.New("err")).Once()
task := &alterDatabaseTask{
AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{
Base: &commonpb.MsgBase{},
DbName: "test_alter_database",
Properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateEndTSKey,
Value: "1000",
},
},
},
rootCoord: rc,
}
err := task.PreExecute(context.Background())
assert.Error(t, err)
})
t.Run("not enable replicate", func(t *testing.T) {
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
properties: []*commonpb.KeyValuePair{},
}, nil).Once()
task := &alterDatabaseTask{
AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{
Base: &commonpb.MsgBase{},
DbName: "test_alter_database",
Properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateEndTSKey,
Value: "1000",
},
},
},
rootCoord: rc,
}
err := task.PreExecute(context.Background())
assert.Error(t, err)
})
t.Run("fail to alloc ts", func(t *testing.T) {
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateIDKey,
Value: "local-test",
},
},
}, nil).Once()
rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(nil, errors.New("err")).Once()
task := &alterDatabaseTask{
AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{
Base: &commonpb.MsgBase{},
DbName: "test_alter_database",
Properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateEndTSKey,
Value: "1000",
},
},
},
rootCoord: rc,
}
err := task.PreExecute(context.Background())
assert.Error(t, err)
})
t.Run("alloc wrong ts", func(t *testing.T) {
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateIDKey,
Value: "local-test",
},
},
}, nil).Once()
rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocTimestampResponse{
Status: merr.Success(),
Timestamp: 999,
}, nil).Once()
task := &alterDatabaseTask{
AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{
Base: &commonpb.MsgBase{},
DbName: "test_alter_database",
Properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateEndTSKey,
Value: "1000",
},
},
},
rootCoord: rc,
}
err := task.PreExecute(context.Background())
assert.Error(t, err)
})
t.Run("alloc wrong ts", func(t *testing.T) {
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateIDKey,
Value: "local-test",
},
},
}, nil).Once()
rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocTimestampResponse{
Status: merr.Success(),
Timestamp: 1001,
}, nil).Once()
task := &alterDatabaseTask{
AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{
Base: &commonpb.MsgBase{},
DbName: "test_alter_database",
Properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateEndTSKey,
Value: "1000",
},
},
},
rootCoord: rc,
}
err := task.PreExecute(context.Background())
assert.NoError(t, err)
})
}
func TestDescribeDatabaseTask(t *testing.T) {
rc := mocks.NewMockRootCoordClient(t)

View File

@ -335,6 +335,15 @@ func (dr *deleteRunner) Init(ctx context.Context) error {
return ErrWithLog(log, "Failed to get collection id", merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound))
}
replicateID, err := GetReplicateID(ctx, dr.req.GetDbName(), collName)
if err != nil {
log.Warn("get replicate info failed", zap.String("collectionName", collName), zap.Error(err))
return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound)
}
if replicateID != "" {
return merr.WrapErrCollectionReplicateMode("delete")
}
dr.schema, err = globalMetaCache.GetCollectionSchema(ctx, dr.req.GetDbName(), collName)
if err != nil {
return ErrWithLog(log, "Failed to get collection schema", err)

View File

@ -297,6 +297,45 @@ func TestDeleteRunner_Init(t *testing.T) {
assert.Error(t, dr.Init(context.Background()))
})
t.Run("fail to get collection info", func(t *testing.T) {
dr := deleteRunner{req: &milvuspb.DeleteRequest{
CollectionName: collectionName,
DbName: dbName,
}}
cache := NewMockCache(t)
cache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 0}, nil)
cache.On("GetCollectionID",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(collectionID, nil)
cache.On("GetCollectionInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil,
errors.New("mock get collection info"))
globalMetaCache = cache
assert.Error(t, dr.Init(context.Background()))
})
t.Run("deny delete in the replicate mode", func(t *testing.T) {
dr := deleteRunner{req: &milvuspb.DeleteRequest{
CollectionName: collectionName,
DbName: dbName,
}}
cache := NewMockCache(t)
cache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 0}, nil)
cache.On("GetCollectionID",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(collectionID, nil)
cache.On("GetCollectionInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{
replicateID: "local-mac",
}, nil)
globalMetaCache = cache
assert.Error(t, dr.Init(context.Background()))
})
t.Run("fail get collection schema", func(t *testing.T) {
dr := deleteRunner{req: &milvuspb.DeleteRequest{
CollectionName: collectionName,
@ -309,6 +348,7 @@ func TestDeleteRunner_Init(t *testing.T) {
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(collectionID, nil)
cache.On("GetCollectionInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil)
cache.On("GetCollectionSchema",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
@ -332,6 +372,7 @@ func TestDeleteRunner_Init(t *testing.T) {
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(collectionID, nil)
cache.On("GetCollectionInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil)
cache.On("GetCollectionSchema",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
@ -376,6 +417,7 @@ func TestDeleteRunner_Init(t *testing.T) {
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(schema, nil)
cache.On("GetCollectionInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil)
globalMetaCache = cache
assert.Error(t, dr.Init(context.Background()))
@ -402,6 +444,7 @@ func TestDeleteRunner_Init(t *testing.T) {
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(schema, nil)
cache.On("GetCollectionInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil)
cache.On("GetPartitionID",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
@ -431,6 +474,7 @@ func TestDeleteRunner_Init(t *testing.T) {
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(collectionID, nil)
cache.On("GetCollectionInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil)
cache.On("GetCollectionSchema",
mock.Anything, // context.Context
mock.AnythingOfType("string"),

View File

@ -125,6 +125,15 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
return merr.WrapErrAsInputError(merr.WrapErrParameterTooLarge("insert request size exceeds maxInsertSize"))
}
replicateID, err := GetReplicateID(it.ctx, it.insertMsg.GetDbName(), collectionName)
if err != nil {
log.Warn("get replicate id failed", zap.String("collectionName", collectionName), zap.Error(err))
return merr.WrapErrAsInputError(err)
}
if replicateID != "" {
return merr.WrapErrCollectionReplicateMode("insert")
}
schema, err := globalMetaCache.GetCollectionSchema(ctx, it.insertMsg.GetDbName(), collectionName)
if err != nil {
log.Ctx(ctx).Warn("get collection schema from global meta cache failed", zap.String("collectionName", collectionName), zap.Error(err))

View File

@ -41,6 +41,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
@ -1708,8 +1709,8 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
assert.NoError(t, err)
shardsNum := int32(2)
prefix := "TestTask_all"
dbName := ""
prefix := "TestTask_int64pk"
dbName := "int64PK"
collectionName := prefix + funcutil.GenRandomStr()
partitionName := prefix + funcutil.GenRandomStr()
@ -1726,45 +1727,43 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
}
nb := 10
t.Run("create collection", func(t *testing.T) {
schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testInt64Field, false)
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testInt64Field, false)
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
createColT := &createCollectionTask{
Condition: NewTaskCondition(ctx),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: shardsNum,
},
ctx: ctx,
rootCoord: rc,
result: nil,
schema: nil,
}
assert.NoError(t, createColT.OnEnqueue())
assert.NoError(t, createColT.PreExecute(ctx))
assert.NoError(t, createColT.Execute(ctx))
assert.NoError(t, createColT.PostExecute(ctx))
_, _ = rc.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_CreatePartition,
MsgID: 0,
Timestamp: 0,
SourceID: paramtable.GetNodeID(),
},
createColT := &createCollectionTask{
Condition: NewTaskCondition(ctx),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
PartitionName: partitionName,
})
Schema: marshaledSchema,
ShardsNum: shardsNum,
},
ctx: ctx,
rootCoord: rc,
result: nil,
schema: nil,
}
assert.NoError(t, createColT.OnEnqueue())
assert.NoError(t, createColT.PreExecute(ctx))
assert.NoError(t, createColT.Execute(ctx))
assert.NoError(t, createColT.PostExecute(ctx))
_, _ = rc.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_CreatePartition,
MsgID: 0,
Timestamp: 0,
SourceID: paramtable.GetNodeID(),
},
DbName: dbName,
CollectionName: collectionName,
PartitionName: partitionName,
})
collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
collectionID, err := globalMetaCache.GetCollectionID(ctx, dbName, collectionName)
assert.NoError(t, err)
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
@ -1957,7 +1956,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
shardsNum := int32(2)
prefix := "TestTask_all"
dbName := ""
dbName := "testvarchar"
collectionName := prefix + funcutil.GenRandomStr()
partitionName := prefix + funcutil.GenRandomStr()
@ -1975,45 +1974,43 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
}
nb := 10
t.Run("create collection", func(t *testing.T) {
schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testVarCharField, false)
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testVarCharField, false)
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
createColT := &createCollectionTask{
Condition: NewTaskCondition(ctx),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: shardsNum,
},
ctx: ctx,
rootCoord: rc,
result: nil,
schema: nil,
}
assert.NoError(t, createColT.OnEnqueue())
assert.NoError(t, createColT.PreExecute(ctx))
assert.NoError(t, createColT.Execute(ctx))
assert.NoError(t, createColT.PostExecute(ctx))
_, _ = rc.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_CreatePartition,
MsgID: 0,
Timestamp: 0,
SourceID: paramtable.GetNodeID(),
},
createColT := &createCollectionTask{
Condition: NewTaskCondition(ctx),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
PartitionName: partitionName,
})
Schema: marshaledSchema,
ShardsNum: shardsNum,
},
ctx: ctx,
rootCoord: rc,
result: nil,
schema: nil,
}
assert.NoError(t, createColT.OnEnqueue())
assert.NoError(t, createColT.PreExecute(ctx))
assert.NoError(t, createColT.Execute(ctx))
assert.NoError(t, createColT.PostExecute(ctx))
_, _ = rc.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_CreatePartition,
MsgID: 0,
Timestamp: 0,
SourceID: paramtable.GetNodeID(),
},
DbName: dbName,
CollectionName: collectionName,
PartitionName: partitionName,
})
collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
collectionID, err := globalMetaCache.GetCollectionID(ctx, dbName, collectionName)
assert.NoError(t, err)
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
@ -3444,30 +3441,28 @@ func TestPartitionKey(t *testing.T) {
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
t.Run("create collection", func(t *testing.T) {
createCollectionTask := &createCollectionTask{
Condition: NewTaskCondition(ctx),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
Base: &commonpb.MsgBase{
MsgID: UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()),
Timestamp: Timestamp(time.Now().UnixNano()),
},
DbName: "",
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: shardsNum,
NumPartitions: common.DefaultPartitionsWithPartitionKey,
createCollectionTask := &createCollectionTask{
Condition: NewTaskCondition(ctx),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
Base: &commonpb.MsgBase{
MsgID: UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()),
Timestamp: Timestamp(time.Now().UnixNano()),
},
ctx: ctx,
rootCoord: rc,
result: nil,
schema: nil,
}
err = createCollectionTask.PreExecute(ctx)
assert.NoError(t, err)
err = createCollectionTask.Execute(ctx)
assert.NoError(t, err)
})
DbName: "",
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: shardsNum,
NumPartitions: common.DefaultPartitionsWithPartitionKey,
},
ctx: ctx,
rootCoord: rc,
result: nil,
schema: nil,
}
err = createCollectionTask.PreExecute(ctx)
assert.NoError(t, err)
err = createCollectionTask.Execute(ctx)
assert.NoError(t, err)
collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
assert.NoError(t, err)
@ -3500,7 +3495,7 @@ func TestPartitionKey(t *testing.T) {
_ = segAllocator.Start()
defer segAllocator.Close()
partitionNames, err := getDefaultPartitionsInPartitionKeyMode(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
partitionNames, err := getDefaultPartitionsInPartitionKeyMode(ctx, "", collectionName)
assert.NoError(t, err)
assert.Equal(t, common.DefaultPartitionsWithPartitionKey, int64(len(partitionNames)))
@ -4269,3 +4264,136 @@ func TestTaskPartitionKeyIsolation(t *testing.T) {
"can not alter partition key isolation mode if the collection already has a vector index. Please drop the index first")
})
}
func TestAlterCollectionForReplicateProperty(t *testing.T) {
cache := globalMetaCache
defer func() { globalMetaCache = cache }()
mockCache := NewMockCache(t)
mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{
replicateID: "local-mac-1",
}, nil).Maybe()
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(1, nil).Maybe()
mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{}, nil).Maybe()
globalMetaCache = mockCache
ctx := context.Background()
mockRootcoord := mocks.NewMockRootCoordClient(t)
t.Run("invalid replicate id", func(t *testing.T) {
task := &alterCollectionTask{
AlterCollectionRequest: &milvuspb.AlterCollectionRequest{
Properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateIDKey,
Value: "xxxxx",
},
},
},
rootCoord: mockRootcoord,
}
err := task.PreExecute(ctx)
assert.Error(t, err)
})
t.Run("empty replicate id", func(t *testing.T) {
task := &alterCollectionTask{
AlterCollectionRequest: &milvuspb.AlterCollectionRequest{
CollectionName: "test",
Properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateIDKey,
Value: "",
},
},
},
rootCoord: mockRootcoord,
}
err := task.PreExecute(ctx)
assert.Error(t, err)
})
t.Run("fail to alloc ts", func(t *testing.T) {
task := &alterCollectionTask{
AlterCollectionRequest: &milvuspb.AlterCollectionRequest{
CollectionName: "test",
Properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateEndTSKey,
Value: "100",
},
},
},
rootCoord: mockRootcoord,
}
mockRootcoord.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(nil, errors.New("err")).Once()
err := task.PreExecute(ctx)
assert.Error(t, err)
})
t.Run("alloc wrong ts", func(t *testing.T) {
task := &alterCollectionTask{
AlterCollectionRequest: &milvuspb.AlterCollectionRequest{
CollectionName: "test",
Properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateEndTSKey,
Value: "100",
},
},
},
rootCoord: mockRootcoord,
}
mockRootcoord.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocTimestampResponse{
Status: merr.Success(),
Timestamp: 99,
}, nil).Once()
err := task.PreExecute(ctx)
assert.Error(t, err)
})
}
func TestInsertForReplicate(t *testing.T) {
cache := globalMetaCache
defer func() { globalMetaCache = cache }()
mockCache := NewMockCache(t)
globalMetaCache = mockCache
t.Run("get replicate id fail", func(t *testing.T) {
mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("err")).Once()
task := &insertTask{
insertMsg: &msgstream.InsertMsg{
InsertRequest: &msgpb.InsertRequest{
CollectionName: "foo",
},
},
}
err := task.PreExecute(context.Background())
assert.Error(t, err)
})
t.Run("insert with replicate id", func(t *testing.T) {
mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{
schema: &schemaInfo{
CollectionSchema: &schemapb.CollectionSchema{
Properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateIDKey,
Value: "local-mac",
},
},
},
},
replicateID: "local-mac",
}, nil).Once()
task := &insertTask{
insertMsg: &msgstream.InsertMsg{
InsertRequest: &msgpb.InsertRequest{
CollectionName: "foo",
},
},
}
err := task.PreExecute(context.Background())
assert.Error(t, err)
})
}

View File

@ -292,6 +292,15 @@ func (it *upsertTask) PreExecute(ctx context.Context) error {
Timestamp: it.EndTs(),
}
replicateID, err := GetReplicateID(ctx, it.req.GetDbName(), collectionName)
if err != nil {
log.Warn("get replicate info failed", zap.String("collectionName", collectionName), zap.Error(err))
return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound)
}
if replicateID != "" {
return merr.WrapErrCollectionReplicateMode("upsert")
}
schema, err := globalMetaCache.GetCollectionSchema(ctx, it.req.GetDbName(), collectionName)
if err != nil {
log.Warn("Failed to get collection schema",

View File

@ -19,6 +19,7 @@ import (
"context"
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
@ -325,3 +326,37 @@ func TestUpsertTask(t *testing.T) {
assert.ElementsMatch(t, channels, ut.pChannels)
})
}
func TestUpsertTaskForReplicate(t *testing.T) {
cache := globalMetaCache
defer func() { globalMetaCache = cache }()
mockCache := NewMockCache(t)
globalMetaCache = mockCache
ctx := context.Background()
t.Run("fail to get collection info", func(t *testing.T) {
ut := upsertTask{
ctx: ctx,
req: &milvuspb.UpsertRequest{
CollectionName: "col-0",
},
}
mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("foo")).Once()
err := ut.PreExecute(ctx)
assert.Error(t, err)
})
t.Run("replicate mode", func(t *testing.T) {
ut := upsertTask{
ctx: ctx,
req: &milvuspb.UpsertRequest{
CollectionName: "col-0",
},
}
mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{
replicateID: "local-mac",
}, nil).Once()
err := ut.PreExecute(ctx)
assert.Error(t, err)
})
}

View File

@ -2212,3 +2212,22 @@ func GetFailedResponse(req any, err error) any {
}
return nil
}
func GetReplicateID(ctx context.Context, database, collectionName string) (string, error) {
if globalMetaCache == nil {
return "", merr.WrapErrServiceUnavailable("internal: Milvus Proxy is not ready yet. please wait")
}
colInfo, err := globalMetaCache.GetCollectionInfo(ctx, database, collectionName, 0)
if err != nil {
return "", err
}
if colInfo.replicateID != "" {
return colInfo.replicateID, nil
}
dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, database)
if err != nil {
return "", err
}
replicateID, _ := common.GetReplicateID(dbInfo.properties)
return replicateID, nil
}

View File

@ -36,6 +36,7 @@ import (
coordMocks "github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/querycoordv2/checkers"
"github.com/milvus-io/milvus/internal/querycoordv2/dist"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
@ -614,6 +615,7 @@ func (suite *ServerSuite) hackServer() {
)
suite.broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{Schema: &schemapb.CollectionSchema{}}, nil).Maybe()
suite.broker.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&rootcoordpb.DescribeDatabaseResponse{}, nil).Maybe()
suite.broker.EXPECT().ListIndexes(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
for _, collection := range suite.collections {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe()

View File

@ -56,9 +56,7 @@ var (
)
func (s *Server) ShowCollections(ctx context.Context, req *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
log := log.Ctx(ctx).With(zap.Int64s("collections", req.GetCollectionIDs()))
log.Info("show collections request received")
log.Ctx(ctx).Debug("show collections request received", zap.Int64s("collections", req.GetCollectionIDs()))
if err := merr.CheckHealthy(s.State()); err != nil {
msg := "failed to show collections"
log.Warn(msg, zap.Error(err))

View File

@ -341,18 +341,23 @@ func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error {
collectionInfo, err := ex.broker.DescribeCollection(ctx, task.CollectionID())
if err != nil {
log.Warn("failed to get collection info")
log.Warn("failed to get collection info", zap.Error(err))
return err
}
loadFields := ex.meta.GetLoadFields(ctx, task.CollectionID())
partitions, err := utils.GetPartitions(ctx, ex.targetMgr, task.CollectionID())
if err != nil {
log.Warn("failed to get partitions of collection")
log.Warn("failed to get partitions of collection", zap.Error(err))
return err
}
indexInfo, err := ex.broker.ListIndexes(ctx, task.CollectionID())
if err != nil {
log.Warn("fail to get index meta of collection")
log.Warn("fail to get index meta of collection", zap.Error(err))
return err
}
dbResp, err := ex.broker.DescribeDatabase(ctx, collectionInfo.GetDbName())
if err != nil {
log.Warn("failed to get database info", zap.Error(err))
return err
}
loadMeta := packLoadMeta(
@ -363,6 +368,7 @@ func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error {
loadFields,
partitions...,
)
loadMeta.DbProperties = dbResp.GetProperties()
dmChannel := ex.targetMgr.GetDmChannel(ctx, task.CollectionID(), action.ChannelName(), meta.NextTarget)
if dmChannel == nil {

View File

@ -38,6 +38,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
. "github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
@ -230,6 +231,7 @@ func (suite *TaskSuite) TestSubscribeChannelTask() {
},
}, nil
})
suite.broker.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&rootcoordpb.DescribeDatabaseResponse{}, nil)
for channel, segment := range suite.growingSegments {
suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).
Return([]*datapb.SegmentInfo{

View File

@ -22,6 +22,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/json"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
@ -46,7 +47,7 @@ func TestGetPipelineJSON(t *testing.T) {
collectionManager := segments.NewMockCollectionManager(t)
segmentManager := segments.NewMockSegmentManager(t)
collectionManager.EXPECT().Get(mock.Anything).Return(&segments.Collection{})
collectionManager.EXPECT().Get(mock.Anything).Return(segments.NewTestCollection(1, querypb.LoadType_UnKnownType, &schemapb.CollectionSchema{}))
manager := &segments.Manager{
Collection: collectionManager,
Segment: segmentManager,

View File

@ -72,7 +72,7 @@ func (suite *FilterNodeSuite) TestWithLoadCollection() {
suite.validSegmentIDs = []int64{2, 3, 4, 5, 6}
// mock
collection := segments.NewCollectionWithoutSchema(suite.collectionID, querypb.LoadType_LoadCollection)
collection := segments.NewTestCollection(suite.collectionID, querypb.LoadType_LoadCollection, nil)
for _, partitionID := range suite.partitionIDs {
collection.AddPartition(partitionID)
}
@ -111,7 +111,7 @@ func (suite *FilterNodeSuite) TestWithLoadPartation() {
suite.validSegmentIDs = []int64{2, 3, 4, 5, 6}
// mock
collection := segments.NewCollectionWithoutSchema(suite.collectionID, querypb.LoadType_LoadPartition)
collection := segments.NewTestCollection(suite.collectionID, querypb.LoadType_LoadPartition, nil)
collection.AddPartition(suite.partitionIDs[0])
mockCollectionManager := segments.NewMockCollectionManager(suite.T())

View File

@ -85,7 +85,7 @@ func (m *manager) Add(collectionID UniqueID, channel string) (Pipeline, error) {
return nil, merr.WrapErrChannelNotFound(channel, "delegator not found")
}
newPipeLine, err := NewPipeLine(collectionID, channel, m.dataManager, m.dispatcher, delegator)
newPipeLine, err := NewPipeLine(collection, channel, m.dataManager, m.dispatcher, delegator)
if err != nil {
return nil, merr.WrapErrServiceUnavailable(err.Error(), "failed to create new pipeline")
}

View File

@ -24,9 +24,10 @@ import (
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/pkg/mq/common"
"github.com/milvus-io/milvus/pkg/mq/msgdispatcher"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/paramtable"
@ -73,9 +74,9 @@ func (suite *PipelineManagerTestSuite) SetupTest() {
func (suite *PipelineManagerTestSuite) TestBasic() {
// init mock
// mock collection manager
suite.collectionManager.EXPECT().Get(suite.collectionID).Return(&segments.Collection{})
suite.collectionManager.EXPECT().Get(suite.collectionID).Return(segments.NewTestCollection(suite.collectionID, querypb.LoadType_UnKnownType, &schemapb.CollectionSchema{}))
// mock mq factory
suite.msgDispatcher.EXPECT().Register(mock.Anything, suite.channel, mock.Anything, common.SubscriptionPositionUnknown).Return(suite.msgChan, nil)
suite.msgDispatcher.EXPECT().Register(mock.Anything, mock.Anything).Return(suite.msgChan, nil)
suite.msgDispatcher.EXPECT().Deregister(suite.channel)
// build manager

View File

@ -19,7 +19,9 @@ package pipeline
import (
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
base "github.com/milvus-io/milvus/internal/util/pipeline"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/mq/msgdispatcher"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
@ -45,17 +47,23 @@ func (p *pipeline) Close() {
}
func NewPipeLine(
collectionID UniqueID,
collection *Collection,
channel string,
manager *DataManager,
dispatcher msgdispatcher.Client,
delegator delegator.ShardDelegator,
) (Pipeline, error) {
collectionID := collection.ID()
replicateID, _ := common.GetReplicateID(collection.Schema().GetProperties())
if replicateID == "" {
replicateID, _ = common.GetReplicateID(collection.GetDBProperties())
}
replicateConfig := msgstream.GetReplicateConfig(replicateID, collection.GetDBName(), collection.Schema().Name)
pipelineQueueLength := paramtable.Get().QueryNodeCfg.FlowGraphMaxQueueLength.GetAsInt32()
p := &pipeline{
collectionID: collectionID,
StreamPipeline: base.NewPipelineWithStream(dispatcher, nodeCtxTtInterval, enableTtChecker, channel),
StreamPipeline: base.NewPipelineWithStream(dispatcher, nodeCtxTtInterval, enableTtChecker, channel, replicateConfig),
}
filterNode := newFilterNode(collectionID, channel, manager, delegator, pipelineQueueLength)

View File

@ -24,13 +24,14 @@ import (
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks/util/mock_segcore"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/pkg/mq/common"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/mq/msgdispatcher"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/paramtable"
@ -103,11 +104,17 @@ func (suite *PipelineTestSuite) TestBasic() {
schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true)
collection := segments.NewCollection(suite.collectionID, schema, mock_segcore.GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{
LoadType: querypb.LoadType_LoadCollection,
DbProperties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateIDKey,
Value: "local-test",
},
},
})
suite.collectionManager.EXPECT().Get(suite.collectionID).Return(collection)
// mock mq factory
suite.msgDispatcher.EXPECT().Register(mock.Anything, suite.channel, mock.Anything, common.SubscriptionPositionUnknown).Return(suite.msgChan, nil)
suite.msgDispatcher.EXPECT().Register(mock.Anything, mock.Anything).Return(suite.msgChan, nil)
suite.msgDispatcher.EXPECT().Deregister(suite.channel)
// mock delegator
@ -136,16 +143,16 @@ func (suite *PipelineTestSuite) TestBasic() {
Collection: suite.collectionManager,
Segment: suite.segmentManager,
}
pipeline, err := NewPipeLine(suite.collectionID, suite.channel, manager, suite.msgDispatcher, suite.delegator)
pipelineObj, err := NewPipeLine(collection, suite.channel, manager, suite.msgDispatcher, suite.delegator)
suite.NoError(err)
// Init Consumer
err = pipeline.ConsumeMsgStream(context.Background(), &msgpb.MsgPosition{})
err = pipelineObj.ConsumeMsgStream(context.Background(), &msgpb.MsgPosition{})
suite.NoError(err)
err = pipeline.Start()
err = pipelineObj.Start()
suite.NoError(err)
defer pipeline.Close()
defer pipelineObj.Close()
// build input msg
in := suite.buildMsgPack(schema)

View File

@ -148,6 +148,7 @@ type Collection struct {
partitions *typeutil.ConcurrentSet[int64]
loadType querypb.LoadType
dbName string
dbProperties []*commonpb.KeyValuePair
resourceGroup string
// resource group of node may be changed if node transfer,
// but Collection in Manager will be released before assign new replica of new resource group on these node.
@ -166,6 +167,10 @@ func (c *Collection) GetDBName() string {
return c.dbName
}
func (c *Collection) GetDBProperties() []*commonpb.KeyValuePair {
return c.dbProperties
}
// GetResourceGroup returns the resource group of collection.
func (c *Collection) GetResourceGroup() string {
return c.resourceGroup
@ -284,6 +289,7 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM
partitions: typeutil.NewConcurrentSet[int64](),
loadType: loadMetaInfo.GetLoadType(),
dbName: loadMetaInfo.GetDbName(),
dbProperties: loadMetaInfo.GetDbProperties(),
resourceGroup: loadMetaInfo.GetResourceGroup(),
refCount: atomic.NewUint32(0),
isGpuIndex: isGpuIndex,
@ -297,13 +303,16 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM
return coll
}
func NewCollectionWithoutSchema(collectionID int64, loadType querypb.LoadType) *Collection {
return &Collection{
// Only for test
func NewTestCollection(collectionID int64, loadType querypb.LoadType, schema *schemapb.CollectionSchema) *Collection {
col := &Collection{
id: collectionID,
partitions: typeutil.NewConcurrentSet[int64](),
loadType: loadType,
refCount: atomic.NewUint32(0),
}
col.schema.Store(schema)
return col
}
// new collection without segcore prepare

View File

@ -26,11 +26,13 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/proxyutil"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/merr"
)
@ -130,6 +132,43 @@ func (a *alterCollectionTask) Execute(ctx context.Context) error {
}))
}
oldReplicateEnable, _ := common.IsReplicateEnabled(oldColl.Properties)
replicateEnable, ok := common.IsReplicateEnabled(newColl.Properties)
if ok && !replicateEnable && oldReplicateEnable {
replicateID, _ := common.GetReplicateID(oldColl.Properties)
redoTask.AddAsyncStep(NewSimpleStep("send replicate end msg for collection", func(ctx context.Context) ([]nestedStep, error) {
msgPack := &msgstream.MsgPack{}
msg := &msgstream.ReplicateMsg{
BaseMsg: msgstream.BaseMsg{
Ctx: ctx,
BeginTimestamp: ts,
EndTimestamp: ts,
HashValues: []uint32{0},
},
ReplicateMsg: &msgpb.ReplicateMsg{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Replicate,
Timestamp: ts,
ReplicateInfo: &commonpb.ReplicateInfo{
IsReplicate: true,
ReplicateID: replicateID,
},
},
IsEnd: true,
Database: newColl.DBName,
Collection: newColl.Name,
},
}
msgPack.Msgs = append(msgPack.Msgs, msg)
log.Info("send replicate end msg",
zap.String("collection", newColl.Name),
zap.String("database", newColl.DBName),
zap.String("replicateID", replicateID),
)
return nil, a.core.chanTimeTick.broadcastDmlChannels(newColl.PhysicalChannelNames, msgPack)
}))
}
return redoTask.Execute(ctx)
}

View File

@ -19,6 +19,7 @@ package rootcoord
import (
"context"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
@ -29,6 +30,7 @@ import (
"github.com/milvus-io/milvus/internal/metastore/model"
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
)
func Test_alterCollectionTask_Prepare(t *testing.T) {
@ -217,14 +219,25 @@ func Test_alterCollectionTask_Execute(t *testing.T) {
assert.NoError(t, err)
})
t.Run("alter successfully", func(t *testing.T) {
t.Run("alter successfully2", func(t *testing.T) {
meta := mockrootcoord.NewIMetaTable(t)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(&model.Collection{CollectionID: int64(1)}, nil)
).Return(&model.Collection{
CollectionID: int64(1),
Name: "cn",
DBName: "foo",
Properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateIDKey,
Value: "local-test",
},
},
PhysicalChannelNames: []string{"by-dev-rootcoord-dml_1"},
}, nil)
meta.On("AlterCollection",
mock.Anything,
mock.Anything,
@ -237,19 +250,37 @@ func Test_alterCollectionTask_Execute(t *testing.T) {
broker.BroadcastAlteredCollectionFunc = func(ctx context.Context, req *milvuspb.AlterCollectionRequest) error {
return nil
}
packChan := make(chan *msgstream.MsgPack, 10)
ticker := newChanTimeTickSync(packChan)
ticker.addDmlChannels("by-dev-rootcoord-dml_1")
core := newTestCore(withValidProxyManager(), withMeta(meta), withBroker(broker))
core := newTestCore(withValidProxyManager(), withMeta(meta), withBroker(broker), withTtSynchronizer(ticker))
newPros := append(properties, &commonpb.KeyValuePair{
Key: common.ReplicateEndTSKey,
Value: "10000",
})
task := &alterCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.AlterCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollection},
CollectionName: "cn",
Properties: properties,
Properties: newPros,
},
}
err := task.Execute(context.Background())
assert.NoError(t, err)
time.Sleep(time.Second)
select {
case pack := <-packChan:
assert.Equal(t, commonpb.MsgType_Replicate, pack.Msgs[0].Type())
replicateMsg := pack.Msgs[0].(*msgstream.ReplicateMsg)
assert.Equal(t, "foo", replicateMsg.ReplicateMsg.GetDatabase())
assert.Equal(t, "cn", replicateMsg.ReplicateMsg.GetCollection())
assert.True(t, replicateMsg.ReplicateMsg.GetIsEnd())
default:
assert.Fail(t, "no message sent")
}
})
t.Run("test update collection props", func(t *testing.T) {

View File

@ -25,11 +25,14 @@ import (
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/util/proxyutil"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/merr"
)
@ -43,6 +46,19 @@ func (a *alterDatabaseTask) Prepare(ctx context.Context) error {
return fmt.Errorf("alter database failed, database name does not exists")
}
// TODO SimFG maybe it will support to alter the replica.id properties in the future when the database has no collections
// now it can't be because the latest database properties can't be notified to the querycoord and datacoord
replicateID, _ := common.GetReplicateID(a.Req.Properties)
if replicateID != "" {
colls, err := a.core.meta.ListCollections(ctx, a.Req.DbName, a.ts, true)
if err != nil {
return err
}
if len(colls) > 0 {
return errors.New("can't set replicate id on database with collections")
}
}
return nil
}
@ -85,6 +101,18 @@ func (a *alterDatabaseTask) Execute(ctx context.Context) error {
ts: ts,
})
redoTask.AddSyncStep(&expireCacheStep{
baseStep: baseStep{core: a.core},
dbName: newDB.Name,
ts: ts,
// make sure to send the "expire cache" request
// because it won't send this request when the length of collection names array is zero
collectionNames: []string{""},
opts: []proxyutil.ExpireCacheOpt{
proxyutil.SetMsgType(commonpb.MsgType_AlterDatabase),
},
})
oldReplicaNumber, _ := common.DatabaseLevelReplicaNumber(oldDB.Properties)
oldResourceGroups, _ := common.DatabaseLevelResourceGroups(oldDB.Properties)
newReplicaNumber, _ := common.DatabaseLevelReplicaNumber(newDB.Properties)
@ -123,6 +151,39 @@ func (a *alterDatabaseTask) Execute(ctx context.Context) error {
}))
}
oldReplicateEnable, _ := common.IsReplicateEnabled(oldDB.Properties)
newReplicateEnable, ok := common.IsReplicateEnabled(newDB.Properties)
if ok && !newReplicateEnable && oldReplicateEnable {
replicateID, _ := common.GetReplicateID(oldDB.Properties)
redoTask.AddAsyncStep(NewSimpleStep("send replicate end msg for db", func(ctx context.Context) ([]nestedStep, error) {
msgPack := &msgstream.MsgPack{}
msg := &msgstream.ReplicateMsg{
BaseMsg: msgstream.BaseMsg{
Ctx: ctx,
BeginTimestamp: ts,
EndTimestamp: ts,
HashValues: []uint32{0},
},
ReplicateMsg: &msgpb.ReplicateMsg{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Replicate,
Timestamp: ts,
ReplicateInfo: &commonpb.ReplicateInfo{
IsReplicate: true,
ReplicateID: replicateID,
},
},
IsEnd: true,
Database: newDB.Name,
Collection: "",
},
}
msgPack.Msgs = append(msgPack.Msgs, msg)
log.Info("send replicate end msg for db", zap.String("db", newDB.Name), zap.String("replicateID", replicateID))
return nil, a.core.chanTimeTick.broadcastDmlChannels(a.core.chanTimeTick.listDmlChannels(), msgPack)
}))
}
return redoTask.Execute(ctx)
}
@ -134,6 +195,14 @@ func (a *alterDatabaseTask) GetLockerKey() LockerKey {
}
func MergeProperties(oldProps []*commonpb.KeyValuePair, updatedProps []*commonpb.KeyValuePair) []*commonpb.KeyValuePair {
_, existEndTS := common.GetReplicateEndTS(updatedProps)
if existEndTS {
updatedProps = append(updatedProps, &commonpb.KeyValuePair{
Key: common.ReplicateIDKey,
Value: "",
})
}
props := make(map[string]string)
for _, prop := range oldProps {
props[prop.Key] = prop.Value

View File

@ -19,6 +19,7 @@ package rootcoord
import (
"context"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
@ -29,6 +30,8 @@ import (
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/funcutil"
)
func Test_alterDatabaseTask_Prepare(t *testing.T) {
@ -47,6 +50,76 @@ func Test_alterDatabaseTask_Prepare(t *testing.T) {
err := task.Prepare(context.Background())
assert.NoError(t, err)
})
t.Run("replicate id", func(t *testing.T) {
{
// no collections
meta := mockrootcoord.NewIMetaTable(t)
core := newTestCore(withMeta(meta))
meta.EXPECT().
ListCollections(mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return([]*model.Collection{}, nil).
Once()
task := &alterDatabaseTask{
baseTask: newBaseTask(context.Background(), core),
Req: &rootcoordpb.AlterDatabaseRequest{
DbName: "cn",
Properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateIDKey,
Value: "local-test",
},
},
},
}
err := task.Prepare(context.Background())
assert.NoError(t, err)
}
{
meta := mockrootcoord.NewIMetaTable(t)
core := newTestCore(withMeta(meta))
meta.EXPECT().ListCollections(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]*model.Collection{
{
Name: "foo",
},
}, nil).Once()
task := &alterDatabaseTask{
baseTask: newBaseTask(context.Background(), core),
Req: &rootcoordpb.AlterDatabaseRequest{
DbName: "cn",
Properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateIDKey,
Value: "local-test",
},
},
},
}
err := task.Prepare(context.Background())
assert.Error(t, err)
}
{
meta := mockrootcoord.NewIMetaTable(t)
core := newTestCore(withMeta(meta))
meta.EXPECT().ListCollections(mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(nil, errors.New("err")).
Once()
task := &alterDatabaseTask{
baseTask: newBaseTask(context.Background(), core),
Req: &rootcoordpb.AlterDatabaseRequest{
DbName: "cn",
Properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateIDKey,
Value: "local-test",
},
},
},
}
err := task.Prepare(context.Background())
assert.Error(t, err)
}
})
}
func Test_alterDatabaseTask_Execute(t *testing.T) {
@ -146,25 +219,51 @@ func Test_alterDatabaseTask_Execute(t *testing.T) {
mock.Anything,
mock.Anything,
mock.Anything,
).Return(&model.Database{ID: int64(1)}, nil)
).Return(&model.Database{
ID: int64(1),
Name: "cn",
Properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateIDKey,
Value: "local-test",
},
},
}, nil)
meta.On("AlterDatabase",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(nil)
// the chan length should larger than 4, because newChanTimeTickSync will send 4 ts messages when execute the `broadcast` step
packChan := make(chan *msgstream.MsgPack, 10)
ticker := newChanTimeTickSync(packChan)
ticker.addDmlChannels("by-dev-rootcoord-dml_1")
core := newTestCore(withMeta(meta))
core := newTestCore(withMeta(meta), withValidProxyManager(), withTtSynchronizer(ticker))
newPros := append(properties,
&commonpb.KeyValuePair{Key: common.ReplicateEndTSKey, Value: "1000"},
)
task := &alterDatabaseTask{
baseTask: newBaseTask(context.Background(), core),
Req: &rootcoordpb.AlterDatabaseRequest{
DbName: "cn",
Properties: properties,
Properties: newPros,
},
}
err := task.Execute(context.Background())
assert.NoError(t, err)
time.Sleep(time.Second)
select {
case pack := <-packChan:
assert.Equal(t, commonpb.MsgType_Replicate, pack.Msgs[0].Type())
replicateMsg := pack.Msgs[0].(*msgstream.ReplicateMsg)
assert.Equal(t, "cn", replicateMsg.ReplicateMsg.GetDatabase())
assert.True(t, replicateMsg.ReplicateMsg.GetIsEnd())
default:
assert.Fail(t, "no message sent")
}
})
t.Run("test update collection props", func(t *testing.T) {
@ -248,3 +347,26 @@ func Test_alterDatabaseTask_Execute(t *testing.T) {
assert.Empty(t, ret2)
})
}
func TestMergeProperties(t *testing.T) {
p := MergeProperties([]*commonpb.KeyValuePair{
{
Key: common.ReplicateIDKey,
Value: "local-test",
},
{
Key: "foo",
Value: "xxx",
},
}, []*commonpb.KeyValuePair{
{
Key: common.ReplicateEndTSKey,
Value: "1001",
},
})
assert.Len(t, p, 3)
m := funcutil.KeyValuePair2Map(p)
assert.Equal(t, "", m[common.ReplicateIDKey])
assert.Equal(t, "1001", m[common.ReplicateEndTSKey])
assert.Equal(t, "xxx", m["foo"])
}

View File

@ -43,6 +43,7 @@ type watchInfo struct {
vChannels []string
startPositions []*commonpb.KeyDataPair
schema *schemapb.CollectionSchema
dbProperties []*commonpb.KeyValuePair
}
// Broker communicates with other components.
@ -165,6 +166,7 @@ func (b *ServerBroker) WatchChannels(ctx context.Context, info *watchInfo) error
StartPositions: info.startPositions,
Schema: info.schema,
CreateTimestamp: info.ts,
DbProperties: info.dbProperties,
})
if err != nil {
return err

View File

@ -61,6 +61,7 @@ type createCollectionTask struct {
channels collectionChannels
dbID UniqueID
partitionNames []string
dbProperties []*commonpb.KeyValuePair
}
func (t *createCollectionTask) validate(ctx context.Context) error {
@ -424,6 +425,18 @@ func (t *createCollectionTask) Prepare(ctx context.Context) error {
return err
}
t.dbID = db.ID
dbReplicateID, _ := common.GetReplicateID(db.Properties)
if dbReplicateID != "" {
reqProperties := make([]*commonpb.KeyValuePair, 0, len(t.Req.Properties))
for _, prop := range t.Req.Properties {
if prop.Key == common.ReplicateIDKey {
continue
}
reqProperties = append(reqProperties, prop)
}
t.Req.Properties = reqProperties
}
t.dbProperties = db.Properties
if err := t.validate(ctx); err != nil {
return err
@ -565,6 +578,7 @@ func (t *createCollectionTask) Execute(ctx context.Context) error {
CollectionID: collID,
DBID: t.dbID,
Name: t.schema.Name,
DBName: t.Req.GetDbName(),
Description: t.schema.Description,
AutoID: t.schema.AutoID,
Fields: model.UnmarshalFieldModels(t.schema.Fields),
@ -644,11 +658,14 @@ func (t *createCollectionTask) Execute(ctx context.Context) error {
startPositions: toKeyDataPairs(startPositions),
schema: &schemapb.CollectionSchema{
Name: collInfo.Name,
DbName: collInfo.DBName,
Description: collInfo.Description,
AutoID: collInfo.AutoID,
Fields: model.MarshalFieldModels(collInfo.Fields),
Properties: collInfo.Properties,
Functions: model.MarshalFunctionModels(collInfo.Functions),
},
dbProperties: t.dbProperties,
},
}, &nullStep{})
undoTask.AddStep(&changeCollectionStateStep{

View File

@ -823,6 +823,70 @@ func Test_createCollectionTask_Prepare(t *testing.T) {
})
}
func TestCreateCollectionTask_Prepare_WithProperty(t *testing.T) {
paramtable.Init()
meta := mockrootcoord.NewIMetaTable(t)
t.Run("with db properties", func(t *testing.T) {
meta.EXPECT().GetDatabaseByName(mock.Anything, mock.Anything, mock.Anything).Return(&model.Database{
Name: "foo",
ID: 1,
Properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateIDKey,
Value: "local-test",
},
},
}, nil).Twice()
meta.EXPECT().ListAllAvailCollections(mock.Anything).Return(map[int64][]int64{
util.DefaultDBID: {1, 2},
}).Once()
meta.EXPECT().GetGeneralCount(mock.Anything).Return(0).Once()
defer cleanTestEnv()
collectionName := funcutil.GenRandomStr()
field1 := funcutil.GenRandomStr()
ticker := newRocksMqTtSynchronizer()
core := newTestCore(withValidIDAllocator(), withTtSynchronizer(ticker), withMeta(meta))
schema := &schemapb.CollectionSchema{
Name: collectionName,
Description: "",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
Name: field1,
DataType: schemapb.DataType_Int64,
},
},
}
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
task := createCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.CreateCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection},
CollectionName: collectionName,
Schema: marshaledSchema,
Properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateIDKey,
Value: "hoo",
},
},
},
dbID: 1,
}
task.Req.ShardsNum = common.DefaultShardsNum
err = task.Prepare(context.Background())
assert.Len(t, task.dbProperties, 1)
assert.Len(t, task.Req.Properties, 0)
assert.NoError(t, err)
})
}
func Test_createCollectionTask_Execute(t *testing.T) {
t.Run("add same collection with different parameters", func(t *testing.T) {
defer cleanTestEnv()

View File

@ -195,6 +195,9 @@ func (mt *MetaTable) reload() error {
return err
}
for _, collection := range collections {
if collection.DBName == "" {
collection.DBName = dbName
}
mt.collID2Meta[collection.CollectionID] = collection
mt.generalCnt += len(collection.Partitions) * int(collection.ShardsNum)
if collection.Available() {
@ -559,12 +562,14 @@ func filterUnavailable(coll *model.Collection) *model.Collection {
func (mt *MetaTable) getLatestCollectionByIDInternal(ctx context.Context, collectionID UniqueID, allowUnavailable bool) (*model.Collection, error) {
coll, ok := mt.collID2Meta[collectionID]
if !ok || coll == nil {
log.Warn("not found collection", zap.Int64("collectionID", collectionID))
return nil, merr.WrapErrCollectionNotFound(collectionID)
}
if allowUnavailable {
return coll.Clone(), nil
}
if !coll.Available() {
log.Warn("collection not available", zap.Int64("collectionID", collectionID), zap.Any("state", coll.State))
return nil, merr.WrapErrCollectionNotFound(collectionID)
}
return filterUnavailable(coll), nil

View File

@ -1058,6 +1058,31 @@ func newTickerWithFactory(factory msgstream.Factory) *timetickSync {
return ticker
}
func newChanTimeTickSync(packChan chan *msgstream.MsgPack) *timetickSync {
f := msgstream.NewMockMqFactory()
f.NewMsgStreamFunc = func(ctx context.Context) (msgstream.MsgStream, error) {
stream := msgstream.NewWastedMockMsgStream()
stream.BroadcastFunc = func(pack *msgstream.MsgPack) error {
log.Info("mock Broadcast")
packChan <- pack
return nil
}
stream.BroadcastMarkFunc = func(pack *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
log.Info("mock BroadcastMark")
packChan <- pack
return map[string][]msgstream.MessageID{}, nil
}
stream.AsProducerFunc = func(channels []string) {
}
stream.ChanFunc = func() <-chan *msgstream.MsgPack {
return packChan
}
return stream, nil
}
return newTickerWithFactory(f)
}
type mockDdlTsLockManager struct {
DdlTsLockManager
GetMinDdlTsFunc func() Timestamp

View File

@ -1226,6 +1226,7 @@ func convertModelToDesc(collInfo *model.Collection, aliases []string, dbName str
Fields: model.MarshalFieldModels(collInfo.Fields),
Functions: model.MarshalFunctionModels(collInfo.Functions),
EnableDynamicField: collInfo.EnableDynamicField,
Properties: collInfo.Properties,
}
resp.CollectionID = collInfo.CollectionID
resp.VirtualChannelNames = collInfo.VirtualChannelNames
@ -1745,6 +1746,19 @@ func (c *Core) AllocTimestamp(ctx context.Context, in *rootcoordpb.AllocTimestam
}, nil
}
if in.BlockTimestamp > 0 {
blockTime, _ := tsoutil.ParseTS(in.BlockTimestamp)
lastTime := c.tsoAllocator.GetLastSavedTime()
deltaDuration := blockTime.Sub(lastTime)
if deltaDuration > 0 {
log.Info("wait for block timestamp",
zap.Time("blockTime", blockTime),
zap.Time("lastTime", lastTime),
zap.Duration("delta", deltaDuration))
time.Sleep(deltaDuration + time.Millisecond*200)
}
}
ts, err := c.tsoAllocator.GenerateTSO(in.GetCount())
if err != nil {
log.Ctx(ctx).Error("failed to allocate timestamp", zap.String("role", typeutil.RootCoordRole),

View File

@ -856,6 +856,32 @@ func TestRootCoord_AllocTimestamp(t *testing.T) {
assert.Equal(t, ts-uint64(count)+1, resp.GetTimestamp())
assert.Equal(t, count, resp.GetCount())
})
t.Run("block timestamp", func(t *testing.T) {
alloc := newMockTsoAllocator()
count := uint32(10)
current := time.Now()
ts := tsoutil.ComposeTSByTime(current.Add(time.Second), 1)
alloc.GenerateTSOF = func(count uint32) (uint64, error) {
// end ts
return ts, nil
}
alloc.GetLastSavedTimeF = func() time.Time {
return current
}
ctx := context.Background()
c := newTestCore(withHealthyCode(),
withTsoAllocator(alloc))
resp, err := c.AllocTimestamp(ctx, &rootcoordpb.AllocTimestampRequest{
Count: count,
BlockTimestamp: tsoutil.ComposeTSByTime(current.Add(time.Second), 0),
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
// begin ts
assert.Equal(t, ts-uint64(count)+1, resp.GetTimestamp())
assert.Equal(t, count, resp.GetCount())
})
}
func TestRootCoord_AllocID(t *testing.T) {

View File

@ -18,6 +18,7 @@ package rootcoord
import (
"context"
"fmt"
"time"
"go.uber.org/zap"
@ -173,7 +174,6 @@ func NewCollectionLockerKey(collection string, rw bool) LockerKey {
}
func NewLockerKeyChain(lockerKeys ...LockerKey) LockerKey {
log.Info("NewLockerKeyChain", zap.Any("lockerKeys", len(lockerKeys)))
if len(lockerKeys) == 0 {
return nil
}
@ -191,3 +191,16 @@ func NewLockerKeyChain(lockerKeys ...LockerKey) LockerKey {
}
return lockerKeys[0]
}
func GetLockerKeyString(k LockerKey) string {
if k == nil {
return "nil"
}
key := k.LockKey()
level := k.Level()
wLock := k.IsWLock()
if k.Next() == nil {
return fmt.Sprintf("%s-%d-%t", key, level, wLock)
}
return fmt.Sprintf("%s-%d-%t|%s", key, level, wLock, GetLockerKeyString(k.Next()))
}

View File

@ -20,7 +20,6 @@ package rootcoord
import (
"context"
"fmt"
"testing"
"github.com/cockroachdb/errors"
@ -72,16 +71,6 @@ func TestLockerKey(t *testing.T) {
}
}
func GetLockerKeyString(k LockerKey) string {
key := k.LockKey()
level := k.Level()
wLock := k.IsWLock()
if k.Next() == nil {
return fmt.Sprintf("%s-%d-%t", key, level, wLock)
}
return fmt.Sprintf("%s-%d-%t|%s", key, level, wLock, GetLockerKeyString(k.Next()))
}
func TestGetLockerKey(t *testing.T) {
t.Run("alter alias task locker key", func(t *testing.T) {
tt := &alterAliasTask{

View File

@ -116,6 +116,7 @@ func (c *channelLifetime) Run() error {
// Build and add pipeline.
ds, err := pipeline.NewStreamingNodeDataSyncService(ctx, c.f.pipelineParams,
// TODO fubang add the db properties
&datapb.ChannelWatchInfo{Vchan: resp.GetInfo(), Schema: resp.GetSchema()}, handler.Chan(), func(t syncmgr.Task, err error) {
if err != nil || t == nil {
return

View File

@ -45,12 +45,13 @@ type StreamPipeline interface {
}
type streamPipeline struct {
pipeline *pipeline
input <-chan *msgstream.MsgPack
scanner streaming.Scanner
dispatcher msgdispatcher.Client
startOnce sync.Once
vChannel string
pipeline *pipeline
input <-chan *msgstream.MsgPack
scanner streaming.Scanner
dispatcher msgdispatcher.Client
startOnce sync.Once
vChannel string
replicateConfig *msgstream.ReplicateConfig
closeCh chan struct{} // notify work to exit
closeWg sync.WaitGroup
@ -118,7 +119,12 @@ func (p *streamPipeline) ConsumeMsgStream(ctx context.Context, position *msgpb.M
}
start := time.Now()
p.input, err = p.dispatcher.Register(ctx, p.vChannel, position, common.SubscriptionPositionUnknown)
p.input, err = p.dispatcher.Register(ctx, &msgdispatcher.StreamConfig{
VChannel: p.vChannel,
Pos: position,
SubPos: common.SubscriptionPositionUnknown,
ReplicateConfig: p.replicateConfig,
})
if err != nil {
log.Error("dispatcher register failed", zap.String("channel", position.ChannelName))
return WrapErrRegDispather(err)
@ -160,18 +166,24 @@ func (p *streamPipeline) Close() {
})
}
func NewPipelineWithStream(dispatcher msgdispatcher.Client, nodeTtInterval time.Duration, enableTtChecker bool, vChannel string) StreamPipeline {
func NewPipelineWithStream(dispatcher msgdispatcher.Client,
nodeTtInterval time.Duration,
enableTtChecker bool,
vChannel string,
replicateConfig *msgstream.ReplicateConfig,
) StreamPipeline {
pipeline := &streamPipeline{
pipeline: &pipeline{
nodes: []*nodeCtx{},
nodeTtInterval: nodeTtInterval,
enableTtChecker: enableTtChecker,
},
dispatcher: dispatcher,
vChannel: vChannel,
closeCh: make(chan struct{}),
closeWg: sync.WaitGroup{},
lastAccessTime: atomic.NewTime(time.Now()),
dispatcher: dispatcher,
vChannel: vChannel,
replicateConfig: replicateConfig,
closeCh: make(chan struct{}),
closeWg: sync.WaitGroup{},
lastAccessTime: atomic.NewTime(time.Now()),
}
return pipeline

View File

@ -25,7 +25,6 @@ import (
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/pkg/mq/common"
"github.com/milvus-io/milvus/pkg/mq/msgdispatcher"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
)
@ -47,9 +46,9 @@ func (suite *StreamPipelineSuite) SetupTest() {
suite.inChannel = make(chan *msgstream.MsgPack, 1)
suite.outChannel = make(chan msgstream.Timestamp)
suite.msgDispatcher = msgdispatcher.NewMockClient(suite.T())
suite.msgDispatcher.EXPECT().Register(mock.Anything, suite.channel, mock.Anything, common.SubscriptionPositionUnknown).Return(suite.inChannel, nil)
suite.msgDispatcher.EXPECT().Register(mock.Anything, mock.Anything).Return(suite.inChannel, nil)
suite.msgDispatcher.EXPECT().Deregister(suite.channel)
suite.pipeline = NewPipelineWithStream(suite.msgDispatcher, 0, false, suite.channel)
suite.pipeline = NewPipelineWithStream(suite.msgDispatcher, 0, false, suite.channel, nil)
suite.length = 4
}

View File

@ -191,6 +191,8 @@ const (
PartitionKeyIsolationKey = "partitionkey.isolation"
FieldSkipLoadKey = "field.skipLoad"
IndexOffsetCacheEnabledKey = "indexoffsetcache.enabled"
ReplicateIDKey = "replicate.id"
ReplicateEndTSKey = "replicate.endTS"
)
const (
@ -395,3 +397,31 @@ func ShouldFieldBeLoaded(kvs []*commonpb.KeyValuePair) (bool, error) {
}
return true, nil
}
func IsReplicateEnabled(kvs []*commonpb.KeyValuePair) (bool, bool) {
replicateID, ok := GetReplicateID(kvs)
return replicateID != "", ok
}
func GetReplicateID(kvs []*commonpb.KeyValuePair) (string, bool) {
for _, kv := range kvs {
if kv.GetKey() == ReplicateIDKey {
return kv.GetValue(), true
}
}
return "", false
}
func GetReplicateEndTS(kvs []*commonpb.KeyValuePair) (uint64, bool) {
for _, kv := range kvs {
if kv.GetKey() == ReplicateEndTSKey {
ts, err := strconv.ParseUint(kv.GetValue(), 10, 64)
if err != nil {
log.Warn("parse replicate end ts failed", zap.Error(err), zap.Stack("stack"))
return 0, false
}
return ts, true
}
}
return 0, false
}

View File

@ -177,3 +177,84 @@ func TestShouldFieldBeLoaded(t *testing.T) {
})
}
}
func TestReplicateProperty(t *testing.T) {
t.Run("ReplicateID", func(t *testing.T) {
{
p := []*commonpb.KeyValuePair{
{
Key: ReplicateIDKey,
Value: "1001",
},
}
e, ok := IsReplicateEnabled(p)
assert.True(t, e)
assert.True(t, ok)
i, ok := GetReplicateID(p)
assert.True(t, ok)
assert.Equal(t, "1001", i)
}
{
p := []*commonpb.KeyValuePair{
{
Key: ReplicateIDKey,
Value: "",
},
}
e, ok := IsReplicateEnabled(p)
assert.False(t, e)
assert.True(t, ok)
}
{
p := []*commonpb.KeyValuePair{
{
Key: "foo",
Value: "1001",
},
}
e, ok := IsReplicateEnabled(p)
assert.False(t, e)
assert.False(t, ok)
}
})
t.Run("ReplicateTS", func(t *testing.T) {
{
p := []*commonpb.KeyValuePair{
{
Key: ReplicateEndTSKey,
Value: "1001",
},
}
ts, ok := GetReplicateEndTS(p)
assert.True(t, ok)
assert.EqualValues(t, 1001, ts)
}
{
p := []*commonpb.KeyValuePair{
{
Key: ReplicateEndTSKey,
Value: "foo",
},
}
ts, ok := GetReplicateEndTS(p)
assert.False(t, ok)
assert.EqualValues(t, 0, ts)
}
{
p := []*commonpb.KeyValuePair{
{
Key: "foo",
Value: "1001",
},
}
ts, ok := GetReplicateEndTS(p)
assert.False(t, ok)
assert.EqualValues(t, 0, ts)
}
})
}

View File

@ -36,8 +36,23 @@ type (
SubPos = common.SubscriptionInitialPosition
)
type StreamConfig struct {
VChannel string
Pos *Pos
SubPos SubPos
ReplicateConfig *msgstream.ReplicateConfig
}
func NewStreamConfig(vchannel string, pos *Pos, subPos SubPos) *StreamConfig {
return &StreamConfig{
VChannel: vchannel,
Pos: pos,
SubPos: subPos,
}
}
type Client interface {
Register(ctx context.Context, vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error)
Register(ctx context.Context, streamConfig *StreamConfig) (<-chan *MsgPack, error)
Deregister(vchannel string)
Close()
}
@ -62,7 +77,8 @@ func NewClient(factory msgstream.Factory, role string, nodeID int64) Client {
}
}
func (c *client) Register(ctx context.Context, vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) {
func (c *client) Register(ctx context.Context, streamConfig *StreamConfig) (<-chan *MsgPack, error) {
vchannel := streamConfig.VChannel
log := log.With(zap.String("role", c.role),
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))
pchannel := funcutil.ToPhysicalChannel(vchannel)
@ -75,7 +91,7 @@ func (c *client) Register(ctx context.Context, vchannel string, pos *Pos, subPos
c.managers.Insert(pchannel, manager)
go manager.Run()
}
ch, err := manager.Add(ctx, vchannel, pos, subPos)
ch, err := manager.Add(ctx, streamConfig)
if err != nil {
if manager.Num() == 0 {
manager.Close()

View File

@ -34,9 +34,9 @@ import (
func TestClient(t *testing.T) {
client := NewClient(newMockFactory(), typeutil.ProxyRole, 1)
assert.NotNil(t, client)
_, err := client.Register(context.Background(), "mock_vchannel_0", nil, common.SubscriptionPositionUnknown)
_, err := client.Register(context.Background(), NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
_, err = client.Register(context.Background(), "mock_vchannel_1", nil, common.SubscriptionPositionUnknown)
_, err = client.Register(context.Background(), NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
assert.NotPanics(t, func() {
client.Deregister("mock_vchannel_0")
@ -51,7 +51,7 @@ func TestClient(t *testing.T) {
client := NewClient(newMockFactory(), typeutil.DataNodeRole, 1)
defer client.Close()
assert.NotNil(t, client)
_, err := client.Register(ctx, "mock_vchannel_1", nil, common.SubscriptionPositionUnknown)
_, err := client.Register(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown))
assert.Error(t, err)
})
}
@ -66,7 +66,7 @@ func TestClient_Concurrency(t *testing.T) {
vchannel := fmt.Sprintf("mock-vchannel-%d-%d", i, rand.Int())
wg.Add(1)
go func() {
_, err := client1.Register(context.Background(), vchannel, nil, common.SubscriptionPositionUnknown)
_, err := client1.Register(context.Background(), NewStreamConfig(vchannel, nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
for j := 0; j < rand.Intn(2); j++ {
client1.Deregister(vchannel)

View File

@ -80,10 +80,14 @@ type Dispatcher struct {
stream msgstream.MsgStream
}
func NewDispatcher(ctx context.Context,
factory msgstream.Factory, isMain bool,
pchannel string, position *Pos,
subName string, subPos SubPos,
func NewDispatcher(
ctx context.Context,
factory msgstream.Factory,
isMain bool,
pchannel string,
position *Pos,
subName string,
subPos SubPos,
lagNotifyChan chan struct{},
lagTargets *typeutil.ConcurrentMap[string, *target],
includeCurrentMsg bool,
@ -260,7 +264,8 @@ func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack {
// init packs for all targets, even though there's no msg in pack,
// but we still need to dispatch time ticks to the targets.
targetPacks := make(map[string]*MsgPack)
for vchannel := range d.targets {
replicateConfigs := make(map[string]*msgstream.ReplicateConfig)
for vchannel, t := range d.targets {
targetPacks[vchannel] = &MsgPack{
BeginTs: pack.BeginTs,
EndTs: pack.EndTs,
@ -268,6 +273,9 @@ func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack {
StartPositions: pack.StartPositions,
EndPositions: pack.EndPositions,
}
if t.replicateConfig != nil {
replicateConfigs[vchannel] = t.replicateConfig
}
}
// group messages by vchannel
for _, msg := range pack.Msgs {
@ -287,9 +295,16 @@ func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack {
collectionID = strconv.FormatInt(msg.(*msgstream.DropPartitionMsg).GetCollectionID(), 10)
}
if vchannel == "" {
// for non-dml msg, such as CreateCollection, DropCollection, ...
// we need to dispatch it to the vchannel of this collection
for k := range targetPacks {
if msg.Type() == commonpb.MsgType_Replicate {
config := replicateConfigs[k]
if config != nil && msgstream.MatchReplicateID(msg, config.ReplicateID) {
targetPacks[k].Msgs = append(targetPacks[k].Msgs, msg)
}
continue
}
if !strings.Contains(k, collectionID) {
continue
}
@ -303,9 +318,63 @@ func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack {
targetPacks[vchannel].Msgs = append(targetPacks[vchannel].Msgs, msg)
}
}
replicateEndChannels := make(map[string]struct{})
for vchannel, c := range replicateConfigs {
if len(targetPacks[vchannel].Msgs) == 0 {
delete(targetPacks, vchannel) // no replicate msg, can't send pack
continue
}
// calculate the new pack ts
beginTs := targetPacks[vchannel].Msgs[0].BeginTs()
endTs := targetPacks[vchannel].Msgs[0].EndTs()
newMsgs := make([]msgstream.TsMsg, 0)
for _, msg := range targetPacks[vchannel].Msgs {
if msg.BeginTs() < beginTs {
beginTs = msg.BeginTs()
}
if msg.EndTs() > endTs {
endTs = msg.EndTs()
}
if msg.Type() == commonpb.MsgType_Replicate {
replicateMsg := msg.(*msgstream.ReplicateMsg)
if c.CheckFunc(replicateMsg) {
replicateEndChannels[vchannel] = struct{}{}
}
continue
}
newMsgs = append(newMsgs, msg)
}
targetPacks[vchannel].Msgs = newMsgs
d.resetMsgPackTS(targetPacks[vchannel], beginTs, endTs)
}
for vchannel := range replicateEndChannels {
if t, ok := d.targets[vchannel]; ok {
t.replicateConfig = nil
log.Info("replicate end, set replicate config nil", zap.String("vchannel", vchannel))
}
}
return targetPacks
}
func (d *Dispatcher) resetMsgPackTS(pack *MsgPack, newBeginTs, newEndTs typeutil.Timestamp) {
pack.BeginTs = newBeginTs
pack.EndTs = newEndTs
startPositions := make([]*msgstream.MsgPosition, 0)
endPositions := make([]*msgstream.MsgPosition, 0)
for _, pos := range pack.StartPositions {
startPosition := typeutil.Clone(pos)
startPosition.Timestamp = newBeginTs
startPositions = append(startPositions, startPosition)
}
for _, pos := range pack.EndPositions {
endPosition := typeutil.Clone(pos)
endPosition.Timestamp = newEndTs
endPositions = append(endPositions, endPosition)
}
pack.StartPositions = startPositions
pack.EndPositions = endPositions
}
func (d *Dispatcher) nonBlockingNotify() {
select {
case d.lagNotifyChan <- struct{}{}:

View File

@ -17,6 +17,8 @@
package msgdispatcher
import (
"fmt"
"math/rand"
"sync"
"testing"
"time"
@ -26,6 +28,8 @@ import (
"github.com/stretchr/testify/mock"
"golang.org/x/net/context"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/pkg/mq/common"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
)
@ -73,7 +77,7 @@ func TestDispatcher(t *testing.T) {
output := make(chan *msgstream.MsgPack, 1024)
getTarget := func(vchannel string, pos *Pos, ch chan *msgstream.MsgPack) *target {
target := newTarget(vchannel, pos)
target := newTarget(vchannel, pos, nil)
target.ch = ch
return target
}
@ -103,7 +107,7 @@ func TestDispatcher(t *testing.T) {
t.Run("test concurrent send and close", func(t *testing.T) {
for i := 0; i < 100; i++ {
output := make(chan *msgstream.MsgPack, 1024)
target := newTarget("mock_vchannel_0", nil)
target := newTarget("mock_vchannel_0", nil, nil)
target.ch = output
assert.Equal(t, cap(output), cap(target.ch))
wg := &sync.WaitGroup{}
@ -138,3 +142,195 @@ func BenchmarkDispatcher_handle(b *testing.B) {
// BenchmarkDispatcher_handle-12 9568 122123 ns/op
// PASS
}
func TestGroupMessage(t *testing.T) {
d, err := NewDispatcher(context.Background(), newMockFactory(), true, "mock_pchannel_0", nil, "mock_subName_0"+fmt.Sprintf("%d", rand.Int()), common.SubscriptionPositionEarliest, nil, nil, false)
assert.NoError(t, err)
d.AddTarget(newTarget("mock_pchannel_0_1v0", nil, nil))
d.AddTarget(newTarget("mock_pchannel_0_2v0", nil, msgstream.GetReplicateConfig("local-test", "foo", "coo")))
{
// no replicate msg
packs := d.groupingMsgs(&MsgPack{
BeginTs: 1,
EndTs: 10,
StartPositions: []*msgstream.MsgPosition{
{
ChannelName: "mock_pchannel_0",
MsgID: []byte("1"),
Timestamp: 1,
},
},
EndPositions: []*msgstream.MsgPosition{
{
ChannelName: "mock_pchannel_0",
MsgID: []byte("10"),
Timestamp: 10,
},
},
Msgs: []msgstream.TsMsg{
&msgstream.InsertMsg{
BaseMsg: msgstream.BaseMsg{
BeginTimestamp: 5,
EndTimestamp: 5,
},
InsertRequest: &msgpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
Timestamp: 5,
},
ShardName: "mock_pchannel_0_1v0",
},
},
},
})
assert.Len(t, packs, 1)
}
{
// equal to replicateID
packs := d.groupingMsgs(&MsgPack{
BeginTs: 1,
EndTs: 10,
StartPositions: []*msgstream.MsgPosition{
{
ChannelName: "mock_pchannel_0",
MsgID: []byte("1"),
Timestamp: 1,
},
},
EndPositions: []*msgstream.MsgPosition{
{
ChannelName: "mock_pchannel_0",
MsgID: []byte("10"),
Timestamp: 10,
},
},
Msgs: []msgstream.TsMsg{
&msgstream.ReplicateMsg{
BaseMsg: msgstream.BaseMsg{
BeginTimestamp: 100,
EndTimestamp: 100,
},
ReplicateMsg: &msgpb.ReplicateMsg{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Replicate,
Timestamp: 100,
ReplicateInfo: &commonpb.ReplicateInfo{
ReplicateID: "local-test",
},
},
},
},
},
})
assert.Len(t, packs, 2)
{
replicatePack := packs["mock_pchannel_0_2v0"]
assert.EqualValues(t, 100, replicatePack.BeginTs)
assert.EqualValues(t, 100, replicatePack.EndTs)
assert.EqualValues(t, 100, replicatePack.StartPositions[0].Timestamp)
assert.EqualValues(t, 100, replicatePack.EndPositions[0].Timestamp)
assert.Len(t, replicatePack.Msgs, 0)
}
{
replicatePack := packs["mock_pchannel_0_1v0"]
assert.EqualValues(t, 1, replicatePack.BeginTs)
assert.EqualValues(t, 10, replicatePack.EndTs)
assert.EqualValues(t, 1, replicatePack.StartPositions[0].Timestamp)
assert.EqualValues(t, 10, replicatePack.EndPositions[0].Timestamp)
assert.Len(t, replicatePack.Msgs, 0)
}
}
{
// not equal to replicateID
packs := d.groupingMsgs(&MsgPack{
BeginTs: 1,
EndTs: 10,
StartPositions: []*msgstream.MsgPosition{
{
ChannelName: "mock_pchannel_0",
MsgID: []byte("1"),
Timestamp: 1,
},
},
EndPositions: []*msgstream.MsgPosition{
{
ChannelName: "mock_pchannel_0",
MsgID: []byte("10"),
Timestamp: 10,
},
},
Msgs: []msgstream.TsMsg{
&msgstream.ReplicateMsg{
BaseMsg: msgstream.BaseMsg{
BeginTimestamp: 100,
EndTimestamp: 100,
},
ReplicateMsg: &msgpb.ReplicateMsg{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Replicate,
Timestamp: 100,
ReplicateInfo: &commonpb.ReplicateInfo{
ReplicateID: "local-test-1", // not equal to replicateID
},
},
},
},
},
})
assert.Len(t, packs, 1)
replicatePack := packs["mock_pchannel_0_2v0"]
assert.Nil(t, replicatePack)
}
{
// replicate end
replicateTarget := d.targets["mock_pchannel_0_2v0"]
assert.NotNil(t, replicateTarget.replicateConfig)
packs := d.groupingMsgs(&MsgPack{
BeginTs: 1,
EndTs: 10,
StartPositions: []*msgstream.MsgPosition{
{
ChannelName: "mock_pchannel_0",
MsgID: []byte("1"),
Timestamp: 1,
},
},
EndPositions: []*msgstream.MsgPosition{
{
ChannelName: "mock_pchannel_0",
MsgID: []byte("10"),
Timestamp: 10,
},
},
Msgs: []msgstream.TsMsg{
&msgstream.ReplicateMsg{
BaseMsg: msgstream.BaseMsg{
BeginTimestamp: 100,
EndTimestamp: 100,
},
ReplicateMsg: &msgpb.ReplicateMsg{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Replicate,
Timestamp: 100,
ReplicateInfo: &commonpb.ReplicateInfo{
ReplicateID: "local-test",
},
},
IsEnd: true,
Database: "foo",
},
},
},
})
assert.Len(t, packs, 2)
replicatePack := packs["mock_pchannel_0_2v0"]
assert.EqualValues(t, 100, replicatePack.BeginTs)
assert.EqualValues(t, 100, replicatePack.EndTs)
assert.EqualValues(t, 100, replicatePack.StartPositions[0].Timestamp)
assert.EqualValues(t, 100, replicatePack.EndPositions[0].Timestamp)
assert.Nil(t, replicateTarget.replicateConfig)
}
}

View File

@ -36,7 +36,7 @@ import (
)
type DispatcherManager interface {
Add(ctx context.Context, vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error)
Add(ctx context.Context, streamConfig *StreamConfig) (<-chan *MsgPack, error)
Remove(vchannel string)
Num() int
Run()
@ -82,7 +82,8 @@ func (c *dispatcherManager) constructSubName(vchannel string, isMain bool) strin
return fmt.Sprintf("%s-%d-%s-%t", c.role, c.nodeID, vchannel, isMain)
}
func (c *dispatcherManager) Add(ctx context.Context, vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) {
func (c *dispatcherManager) Add(ctx context.Context, streamConfig *StreamConfig) (<-chan *MsgPack, error) {
vchannel := streamConfig.VChannel
log := log.With(zap.String("role", c.role),
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))
@ -102,11 +103,11 @@ func (c *dispatcherManager) Add(ctx context.Context, vchannel string, pos *Pos,
}
isMain := c.mainDispatcher == nil
d, err := NewDispatcher(ctx, c.factory, isMain, c.pchannel, pos, c.constructSubName(vchannel, isMain), subPos, c.lagNotifyChan, c.lagTargets, false)
d, err := NewDispatcher(ctx, c.factory, isMain, c.pchannel, streamConfig.Pos, c.constructSubName(vchannel, isMain), streamConfig.SubPos, c.lagNotifyChan, c.lagTargets, false)
if err != nil {
return nil, err
}
t := newTarget(vchannel, pos)
t := newTarget(vchannel, streamConfig.Pos, streamConfig.ReplicateConfig)
d.AddTarget(t)
if isMain {
c.mainDispatcher = d

View File

@ -48,7 +48,7 @@ func TestManager(t *testing.T) {
offset++
vchannel := fmt.Sprintf("mock-pchannel-dml_0_vchannelv%d", offset)
t.Logf("add vchannel, %s", vchannel)
_, err := c.Add(context.Background(), vchannel, nil, common.SubscriptionPositionUnknown)
_, err := c.Add(context.Background(), NewStreamConfig(vchannel, nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
assert.Equal(t, offset, c.Num())
}
@ -67,11 +67,11 @@ func TestManager(t *testing.T) {
ctx := context.Background()
c := NewDispatcherManager(prefix+"_pchannel_0", typeutil.ProxyRole, 1, newMockFactory())
assert.NotNil(t, c)
_, err := c.Add(ctx, "mock_vchannel_0", nil, common.SubscriptionPositionUnknown)
_, err := c.Add(ctx, NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
_, err = c.Add(ctx, "mock_vchannel_1", nil, common.SubscriptionPositionUnknown)
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
_, err = c.Add(ctx, "mock_vchannel_2", nil, common.SubscriptionPositionUnknown)
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
assert.Equal(t, 3, c.Num())
c.(*dispatcherManager).mainDispatcher.curTs.Store(1000)
@ -98,11 +98,11 @@ func TestManager(t *testing.T) {
ctx := context.Background()
c := NewDispatcherManager(prefix+"_pchannel_0", typeutil.ProxyRole, 1, newMockFactory())
assert.NotNil(t, c)
_, err := c.Add(ctx, "mock_vchannel_0", nil, common.SubscriptionPositionUnknown)
_, err := c.Add(ctx, NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
_, err = c.Add(ctx, "mock_vchannel_1", nil, common.SubscriptionPositionUnknown)
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
_, err = c.Add(ctx, "mock_vchannel_2", nil, common.SubscriptionPositionUnknown)
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
assert.Equal(t, 3, c.Num())
c.(*dispatcherManager).mainDispatcher.curTs.Store(1000)
@ -134,11 +134,11 @@ func TestManager(t *testing.T) {
c := NewDispatcherManager(prefix+"_pchannel_0", typeutil.ProxyRole, 1, newMockFactory())
go c.Run()
assert.NotNil(t, c)
_, err := c.Add(ctx, "mock_vchannel_0", nil, common.SubscriptionPositionUnknown)
_, err := c.Add(ctx, NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown))
assert.Error(t, err)
_, err = c.Add(ctx, "mock_vchannel_1", nil, common.SubscriptionPositionUnknown)
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown))
assert.Error(t, err)
_, err = c.Add(ctx, "mock_vchannel_2", nil, common.SubscriptionPositionUnknown)
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown))
assert.Error(t, err)
assert.Equal(t, 0, c.Num())
@ -153,18 +153,18 @@ func TestManager(t *testing.T) {
go c.Run()
assert.NotNil(t, c)
ctx := context.Background()
_, err := c.Add(ctx, "mock_vchannel_0", nil, common.SubscriptionPositionUnknown)
_, err := c.Add(ctx, NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
_, err = c.Add(ctx, "mock_vchannel_1", nil, common.SubscriptionPositionUnknown)
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
_, err = c.Add(ctx, "mock_vchannel_2", nil, common.SubscriptionPositionUnknown)
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
_, err = c.Add(ctx, "mock_vchannel_0", nil, common.SubscriptionPositionUnknown)
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown))
assert.Error(t, err)
_, err = c.Add(ctx, "mock_vchannel_1", nil, common.SubscriptionPositionUnknown)
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown))
assert.Error(t, err)
_, err = c.Add(ctx, "mock_vchannel_2", nil, common.SubscriptionPositionUnknown)
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown))
assert.Error(t, err)
assert.NotPanics(t, func() {
@ -325,7 +325,7 @@ func (suite *SimulationSuite) TestDispatchToVchannels() {
suite.vchannels = make(map[string]*vchannelHelper, vchannelNum)
for i := 0; i < vchannelNum; i++ {
vchannel := fmt.Sprintf("%s_%dv%d", suite.pchannel, collectionID, i)
output, err := suite.manager.Add(context.Background(), vchannel, nil, common.SubscriptionPositionEarliest)
output, err := suite.manager.Add(context.Background(), NewStreamConfig(vchannel, nil, common.SubscriptionPositionEarliest))
assert.NoError(suite.T(), err)
suite.vchannels[vchannel] = &vchannelHelper{output: output}
}
@ -360,8 +360,10 @@ func (suite *SimulationSuite) TestMerge() {
for i := 0; i < vchannelNum; i++ {
vchannel := fmt.Sprintf("%s_vchannelv%d", suite.pchannel, i)
output, err := suite.manager.Add(context.Background(), vchannel, positions[rand.Intn(len(positions))],
common.SubscriptionPositionUnknown) // seek from random position
output, err := suite.manager.Add(context.Background(), NewStreamConfig(
vchannel, positions[rand.Intn(len(positions))],
common.SubscriptionPositionUnknown,
)) // seek from random position
assert.NoError(suite.T(), err)
suite.vchannels[vchannel] = &vchannelHelper{output: output}
}
@ -402,7 +404,7 @@ func (suite *SimulationSuite) TestSplit() {
paramtable.Get().Save(targetBufSizeK, "10")
}
vchannel := fmt.Sprintf("%s_vchannelv%d", suite.pchannel, i)
_, err := suite.manager.Add(context.Background(), vchannel, nil, common.SubscriptionPositionEarliest)
_, err := suite.manager.Add(context.Background(), NewStreamConfig(vchannel, nil, common.SubscriptionPositionEarliest))
assert.NoError(suite.T(), err)
}

View File

@ -5,13 +5,8 @@ package msgdispatcher
import (
context "context"
common "github.com/milvus-io/milvus/pkg/mq/common"
mock "github.com/stretchr/testify/mock"
msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
msgstream "github.com/milvus-io/milvus/pkg/mq/msgstream"
mock "github.com/stretchr/testify/mock"
)
// MockClient is an autogenerated mock type for the Client type
@ -92,9 +87,9 @@ func (_c *MockClient_Deregister_Call) RunAndReturn(run func(string)) *MockClient
return _c
}
// Register provides a mock function with given fields: ctx, vchannel, pos, subPos
func (_m *MockClient) Register(ctx context.Context, vchannel string, pos *msgpb.MsgPosition, subPos common.SubscriptionInitialPosition) (<-chan *msgstream.MsgPack, error) {
ret := _m.Called(ctx, vchannel, pos, subPos)
// Register provides a mock function with given fields: ctx, streamConfig
func (_m *MockClient) Register(ctx context.Context, streamConfig *StreamConfig) (<-chan *msgstream.MsgPack, error) {
ret := _m.Called(ctx, streamConfig)
if len(ret) == 0 {
panic("no return value specified for Register")
@ -102,19 +97,19 @@ func (_m *MockClient) Register(ctx context.Context, vchannel string, pos *msgpb.
var r0 <-chan *msgstream.MsgPack
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, *msgpb.MsgPosition, common.SubscriptionInitialPosition) (<-chan *msgstream.MsgPack, error)); ok {
return rf(ctx, vchannel, pos, subPos)
if rf, ok := ret.Get(0).(func(context.Context, *StreamConfig) (<-chan *msgstream.MsgPack, error)); ok {
return rf(ctx, streamConfig)
}
if rf, ok := ret.Get(0).(func(context.Context, string, *msgpb.MsgPosition, common.SubscriptionInitialPosition) <-chan *msgstream.MsgPack); ok {
r0 = rf(ctx, vchannel, pos, subPos)
if rf, ok := ret.Get(0).(func(context.Context, *StreamConfig) <-chan *msgstream.MsgPack); ok {
r0 = rf(ctx, streamConfig)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(<-chan *msgstream.MsgPack)
}
}
if rf, ok := ret.Get(1).(func(context.Context, string, *msgpb.MsgPosition, common.SubscriptionInitialPosition) error); ok {
r1 = rf(ctx, vchannel, pos, subPos)
if rf, ok := ret.Get(1).(func(context.Context, *StreamConfig) error); ok {
r1 = rf(ctx, streamConfig)
} else {
r1 = ret.Error(1)
}
@ -129,16 +124,14 @@ type MockClient_Register_Call struct {
// Register is a helper method to define mock.On call
// - ctx context.Context
// - vchannel string
// - pos *msgpb.MsgPosition
// - subPos common.SubscriptionInitialPosition
func (_e *MockClient_Expecter) Register(ctx interface{}, vchannel interface{}, pos interface{}, subPos interface{}) *MockClient_Register_Call {
return &MockClient_Register_Call{Call: _e.mock.On("Register", ctx, vchannel, pos, subPos)}
// - streamConfig *StreamConfig
func (_e *MockClient_Expecter) Register(ctx interface{}, streamConfig interface{}) *MockClient_Register_Call {
return &MockClient_Register_Call{Call: _e.mock.On("Register", ctx, streamConfig)}
}
func (_c *MockClient_Register_Call) Run(run func(ctx context.Context, vchannel string, pos *msgpb.MsgPosition, subPos common.SubscriptionInitialPosition)) *MockClient_Register_Call {
func (_c *MockClient_Register_Call) Run(run func(ctx context.Context, streamConfig *StreamConfig)) *MockClient_Register_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(*msgpb.MsgPosition), args[3].(common.SubscriptionInitialPosition))
run(args[0].(context.Context), args[1].(*StreamConfig))
})
return _c
}
@ -148,7 +141,7 @@ func (_c *MockClient_Register_Call) Return(_a0 <-chan *msgstream.MsgPack, _a1 er
return _c
}
func (_c *MockClient_Register_Call) RunAndReturn(run func(context.Context, string, *msgpb.MsgPosition, common.SubscriptionInitialPosition) (<-chan *msgstream.MsgPack, error)) *MockClient_Register_Call {
func (_c *MockClient_Register_Call) RunAndReturn(run func(context.Context, *StreamConfig) (<-chan *msgstream.MsgPack, error)) *MockClient_Register_Call {
_c.Call.Return(run)
return _c
}

View File

@ -24,6 +24,7 @@ import (
"go.uber.org/zap"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/lifetime"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
@ -33,26 +34,33 @@ type target struct {
ch chan *MsgPack
pos *Pos
closeMu sync.Mutex
closeOnce sync.Once
closed bool
maxLag time.Duration
timer *time.Timer
closeMu sync.Mutex
closeOnce sync.Once
closed bool
maxLag time.Duration
timer *time.Timer
replicateConfig *msgstream.ReplicateConfig
cancelCh lifetime.SafeChan
}
func newTarget(vchannel string, pos *Pos) *target {
func newTarget(vchannel string, pos *Pos, replicateConfig *msgstream.ReplicateConfig) *target {
maxTolerantLag := paramtable.Get().MQCfg.MaxTolerantLag.GetAsDuration(time.Second)
t := &target{
vchannel: vchannel,
ch: make(chan *MsgPack, paramtable.Get().MQCfg.TargetBufSize.GetAsInt()),
pos: pos,
cancelCh: lifetime.NewSafeChan(),
maxLag: maxTolerantLag,
timer: time.NewTimer(maxTolerantLag),
vchannel: vchannel,
ch: make(chan *MsgPack, paramtable.Get().MQCfg.TargetBufSize.GetAsInt()),
pos: pos,
cancelCh: lifetime.NewSafeChan(),
maxLag: maxTolerantLag,
timer: time.NewTimer(maxTolerantLag),
replicateConfig: replicateConfig,
}
t.closed = false
if replicateConfig != nil {
log.Info("have replicate config",
zap.String("vchannel", vchannel),
zap.String("replicateID", replicateConfig.ReplicateID))
}
return t
}

View File

@ -14,7 +14,7 @@ import (
)
func TestSendTimeout(t *testing.T) {
target := newTarget("test1", &msgpb.MsgPosition{})
target := newTarget("test1", &msgpb.MsgPosition{}, nil)
time.Sleep(paramtable.Get().MQCfg.MaxTolerantLag.GetAsDuration(time.Second))

View File

@ -72,6 +72,9 @@ type mqMsgStream struct {
ttMsgEnable atomic.Value
forceEnableProduce atomic.Value
configEvent config.EventHandler
replicateID string
checkFunc CheckReplicateMsgFunc
}
// NewMqMsgStream is used to generate a new mqMsgStream object
@ -276,6 +279,23 @@ func (ms *mqMsgStream) isEnabledProduce() bool {
return ms.forceEnableProduce.Load().(bool) || ms.ttMsgEnable.Load().(bool)
}
func (ms *mqMsgStream) isSkipSystemTT() bool {
return ms.replicateID != ""
}
// checkReplicateID check the replicate id of the message, return values: isMatch, isReplicate
func (ms *mqMsgStream) checkReplicateID(msg TsMsg) (bool, bool) {
if !ms.isSkipSystemTT() {
return true, false
}
msgBase, ok := msg.(interface{ GetBase() *commonpb.MsgBase })
if !ok {
log.Warn("fail to get msg base, please check it", zap.Any("type", msg.Type()))
return false, false
}
return msgBase.GetBase().GetReplicateInfo().GetReplicateID() == ms.replicateID, true
}
func (ms *mqMsgStream) Produce(ctx context.Context, msgPack *MsgPack) error {
if !ms.isEnabledProduce() {
log.Ctx(ms.ctx).Warn("can't produce the msg in the backup instance", zap.Stack("stack"))
@ -688,9 +708,9 @@ func (ms *MqTtMsgStream) bufMsgPackToChannel() {
startBufTime := time.Now()
var endTs uint64
var size uint64
var containsDropCollectionMsg bool
var containsEndBufferMsg bool
for ms.continueBuffering(endTs, size, startBufTime) && !containsDropCollectionMsg {
for ms.continueBuffering(endTs, size, startBufTime) && !containsEndBufferMsg {
ms.consumerLock.Lock()
// wait all channels get ttMsg
for _, consumer := range ms.consumers {
@ -726,15 +746,16 @@ func (ms *MqTtMsgStream) bufMsgPackToChannel() {
timeTickMsg = v
continue
}
if v.EndTs() <= currTs {
if v.EndTs() <= currTs ||
GetReplicateID(v) != "" {
size += uint64(v.Size())
timeTickBuf = append(timeTickBuf, v)
} else {
tempBuffer = append(tempBuffer, v)
}
// when drop collection, force to exit the buffer loop
if v.Type() == commonpb.MsgType_DropCollection {
containsDropCollectionMsg = true
if v.Type() == commonpb.MsgType_DropCollection || v.Type() == commonpb.MsgType_Replicate {
containsEndBufferMsg = true
}
}
ms.chanMsgBuf[consumer] = tempBuffer
@ -860,7 +881,7 @@ func (ms *MqTtMsgStream) allChanReachSameTtMsg(chanTtMsgSync map[mqwrapper.Consu
}
for consumer := range ms.chanTtMsgTime {
ms.chanTtMsgTimeMutex.RLock()
chanTtMsgSync[consumer] = (ms.chanTtMsgTime[consumer] == maxTime)
chanTtMsgSync[consumer] = ms.chanTtMsgTime[consumer] == maxTime
ms.chanTtMsgTimeMutex.RUnlock()
}
@ -960,6 +981,10 @@ func (ms *MqTtMsgStream) Seek(ctx context.Context, msgPositions []*MsgPosition,
if err != nil {
return fmt.Errorf("failed to unmarshal tsMsg, err %s", err.Error())
}
// skip the replicate msg because it must have been consumed
if GetReplicateID(tsMsg) != "" {
continue
}
if tsMsg.Type() == commonpb.MsgType_TimeTick && tsMsg.BeginTs() >= mp.Timestamp {
runLoop = false
if time.Since(loopStarTime) > 30*time.Second {

View File

@ -708,6 +708,21 @@ func TestStream_PulsarTtMsgStream_UnMarshalHeader(t *testing.T) {
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1))
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 3))
replicatePack := MsgPack{}
replicatePack.Msgs = append(replicatePack.Msgs, &ReplicateMsg{
BaseMsg: BaseMsg{
BeginTimestamp: 0,
EndTimestamp: 0,
HashValues: []uint32{100},
},
ReplicateMsg: &msgpb.ReplicateMsg{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Replicate,
Timestamp: 100,
},
},
})
msgPack2 := MsgPack{}
msgPack2.Msgs = append(msgPack2.Msgs, getTimeTickMsg(5))
@ -721,6 +736,9 @@ func TestStream_PulsarTtMsgStream_UnMarshalHeader(t *testing.T) {
err = inputStream.Produce(ctx, &msgPack1)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
err = inputStream.Produce(ctx, &replicatePack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
_, err = inputStream.Broadcast(ctx, &msgPack2)
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))

View File

@ -0,0 +1,78 @@
/*
* Licensed to the LF AI & Data foundation under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package msgstream
import (
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
)
type ReplicateMsg struct {
BaseMsg
*msgpb.ReplicateMsg
}
var _ TsMsg = (*ReplicateMsg)(nil)
func (r *ReplicateMsg) ID() UniqueID {
return r.Base.MsgID
}
func (r *ReplicateMsg) SetID(id UniqueID) {
r.Base.MsgID = id
}
func (r *ReplicateMsg) Type() MsgType {
return r.Base.MsgType
}
func (r *ReplicateMsg) SourceID() int64 {
return r.Base.SourceID
}
func (r *ReplicateMsg) Marshal(input TsMsg) (MarshalType, error) {
replicateMsg := input.(*ReplicateMsg)
mb, err := proto.Marshal(replicateMsg.ReplicateMsg)
if err != nil {
return nil, err
}
return mb, nil
}
func (r *ReplicateMsg) Unmarshal(input MarshalType) (TsMsg, error) {
replicateMsg := &msgpb.ReplicateMsg{}
in, err := convertToByteArray(input)
if err != nil {
return nil, err
}
err = proto.Unmarshal(in, replicateMsg)
if err != nil {
return nil, err
}
rr := &ReplicateMsg{ReplicateMsg: replicateMsg}
rr.BeginTimestamp = replicateMsg.GetBase().GetTimestamp()
rr.EndTimestamp = replicateMsg.GetBase().GetTimestamp()
return rr, nil
}
func (r *ReplicateMsg) Size() int {
return proto.Size(r.ReplicateMsg)
}

View File

@ -19,7 +19,11 @@ package msgstream
import (
"context"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/mq/common"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
@ -73,6 +77,50 @@ type MsgStream interface {
ForceEnableProduce(can bool)
}
type ReplicateConfig struct {
ReplicateID string
CheckFunc CheckReplicateMsgFunc
}
type CheckReplicateMsgFunc func(*ReplicateMsg) bool
func GetReplicateConfig(replicateID, dbName, colName string) *ReplicateConfig {
if replicateID == "" {
return nil
}
replicateConfig := &ReplicateConfig{
ReplicateID: replicateID,
CheckFunc: func(msg *ReplicateMsg) bool {
if !msg.GetIsEnd() {
return false
}
log.Info("check replicate msg",
zap.String("replicateID", replicateID),
zap.String("dbName", dbName),
zap.String("colName", colName),
zap.Any("msg", msg))
if msg.GetIsCluster() {
return true
}
return msg.GetDatabase() == dbName && (msg.GetCollection() == colName || msg.GetCollection() == "")
},
}
return replicateConfig
}
func GetReplicateID(msg TsMsg) string {
msgBase, ok := msg.(interface{ GetBase() *commonpb.MsgBase })
if !ok {
log.Warn("fail to get msg base, please check it", zap.Any("type", msg.Type()))
return ""
}
return msgBase.GetBase().GetReplicateInfo().GetReplicateID()
}
func MatchReplicateID(msg TsMsg, replicateID string) bool {
return GetReplicateID(msg) == replicateID
}
type Factory interface {
NewMsgStream(ctx context.Context) (MsgStream, error)
NewTtMsgStream(ctx context.Context) (MsgStream, error)

View File

@ -24,6 +24,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/pkg/mq/common"
)
@ -80,3 +82,90 @@ func TestGetLatestMsgID(t *testing.T) {
assert.Equal(t, []byte("mock"), id)
}
}
func TestReplicateConfig(t *testing.T) {
t.Run("get replicate id", func(t *testing.T) {
{
msg := &InsertMsg{
InsertRequest: &msgpb.InsertRequest{
Base: &commonpb.MsgBase{
ReplicateInfo: &commonpb.ReplicateInfo{
ReplicateID: "local",
},
},
},
}
assert.Equal(t, "local", GetReplicateID(msg))
assert.True(t, MatchReplicateID(msg, "local"))
}
{
msg := &InsertMsg{
InsertRequest: &msgpb.InsertRequest{
Base: &commonpb.MsgBase{},
},
}
assert.Equal(t, "", GetReplicateID(msg))
assert.False(t, MatchReplicateID(msg, "local"))
}
{
msg := &MarshalFailTsMsg{}
assert.Equal(t, "", GetReplicateID(msg))
}
})
t.Run("get replicate config", func(t *testing.T) {
{
assert.Nil(t, GetReplicateConfig("", "", ""))
}
{
rc := GetReplicateConfig("local", "db", "col")
assert.Equal(t, "local", rc.ReplicateID)
checkFunc := rc.CheckFunc
assert.False(t, checkFunc(&ReplicateMsg{
ReplicateMsg: &msgpb.ReplicateMsg{},
}))
assert.True(t, checkFunc(&ReplicateMsg{
ReplicateMsg: &msgpb.ReplicateMsg{
IsEnd: true,
IsCluster: true,
},
}))
assert.False(t, checkFunc(&ReplicateMsg{
ReplicateMsg: &msgpb.ReplicateMsg{
IsEnd: true,
Database: "db1",
},
}))
assert.True(t, checkFunc(&ReplicateMsg{
ReplicateMsg: &msgpb.ReplicateMsg{
IsEnd: true,
Database: "db",
},
}))
assert.False(t, checkFunc(&ReplicateMsg{
ReplicateMsg: &msgpb.ReplicateMsg{
IsEnd: true,
Database: "db",
Collection: "col1",
},
}))
}
{
rc := GetReplicateConfig("local", "db", "col")
checkFunc := rc.CheckFunc
assert.True(t, checkFunc(&ReplicateMsg{
ReplicateMsg: &msgpb.ReplicateMsg{
IsEnd: true,
Database: "db",
},
}))
assert.False(t, checkFunc(&ReplicateMsg{
ReplicateMsg: &msgpb.ReplicateMsg{
IsEnd: true,
Database: "db1",
Collection: "col1",
},
}))
}
})
}

View File

@ -84,6 +84,7 @@ func (pudf *ProtoUDFactory) NewUnmarshalDispatcher() *ProtoUnmarshalDispatcher {
dropRoleMsg := DropRoleMsg{}
operateUserRoleMsg := OperateUserRoleMsg{}
operatePrivilegeMsg := OperatePrivilegeMsg{}
replicateMsg := ReplicateMsg{}
p := &ProtoUnmarshalDispatcher{}
p.TempMap = make(map[commonpb.MsgType]UnmarshalFunc)
@ -113,6 +114,7 @@ func (pudf *ProtoUDFactory) NewUnmarshalDispatcher() *ProtoUnmarshalDispatcher {
p.TempMap[commonpb.MsgType_DropRole] = dropRoleMsg.Unmarshal
p.TempMap[commonpb.MsgType_OperateUserRole] = operateUserRoleMsg.Unmarshal
p.TempMap[commonpb.MsgType_OperatePrivilege] = operatePrivilegeMsg.Unmarshal
p.TempMap[commonpb.MsgType_Replicate] = replicateMsg.Unmarshal
return p
}

View File

@ -70,6 +70,7 @@ var (
ErrCollectionIllegalSchema = newMilvusError("illegal collection schema", 105, false)
ErrCollectionOnRecovering = newMilvusError("collection on recovering", 106, true)
ErrCollectionVectorClusteringKeyNotAllowed = newMilvusError("vector clustering key not allowed", 107, false)
ErrCollectionReplicateMode = newMilvusError("can't operate on the collection under standby mode", 108, false)
// Partition related
ErrPartitionNotFound = newMilvusError("partition not found", 200, false)

View File

@ -330,6 +330,10 @@ func WrapErrAsInputErrorWhen(err error, targets ...milvusError) error {
return err
}
func WrapErrCollectionReplicateMode(operation string) error {
return wrapFields(ErrCollectionReplicateMode, value("operation", operation))
}
func GetErrorType(err error) ErrorType {
if merr, ok := err.(milvusError); ok {
return merr.errType

View File

@ -268,6 +268,7 @@ type commonConfig struct {
MaxBloomFalsePositive ParamItem `refreshable:"true"`
BloomFilterApplyBatchSize ParamItem `refreshable:"true"`
PanicWhenPluginFail ParamItem `refreshable:"false"`
CollectionReplicateEnable ParamItem `refreshable:"true"`
UsePartitionKeyAsClusteringKey ParamItem `refreshable:"true"`
UseVectorAsClusteringKey ParamItem `refreshable:"true"`
@ -784,6 +785,15 @@ This helps Milvus-CDC synchronize incremental data`,
}
p.TTMsgEnabled.Init(base.mgr)
p.CollectionReplicateEnable = ParamItem{
Key: "common.collectionReplicateEnable",
Version: "2.4.16",
DefaultValue: "false",
Doc: `Whether to enable collection replication.`,
Export: true,
}
p.CollectionReplicateEnable.Init(base.mgr)
p.TraceLogMode = ParamItem{
Key: "common.traceLogMode",
Version: "2.3.4",