mirror of https://github.com/milvus-io/milvus.git
feat: support to replicate collection when the services contains the system tt msg (#37559)
- issue: #37105 --------- Signed-off-by: SimFG <bang.fu@zilliz.com>pull/37836/head
parent
5d014c76c7
commit
2afe2eaf3e
|
@ -873,6 +873,7 @@ common:
|
|||
bloomFilterType: BlockedBloomFilter # bloom filter type, support BasicBloomFilter and BlockedBloomFilter
|
||||
maxBloomFalsePositive: 0.001 # max false positive rate for bloom filter
|
||||
bloomFilterApplyBatchSize: 1000 # batch size when to apply pk to bloom filter
|
||||
collectionReplicateEnable: false # Whether to enable collection replication.
|
||||
usePartitionKeyAsClusteringKey: false # if true, do clustering compaction and segment prune on partition key field
|
||||
useVectorAsClusteringKey: false # if true, do clustering compaction and segment prune on vector field
|
||||
enableVectorClusteringKey: false # if true, enable vector clustering key and vector clustering compaction
|
||||
|
|
|
@ -36,6 +36,7 @@ type ROChannel interface {
|
|||
GetSchema() *schemapb.CollectionSchema
|
||||
GetCreateTimestamp() Timestamp
|
||||
GetWatchInfo() *datapb.ChannelWatchInfo
|
||||
GetDBProperties() []*commonpb.KeyValuePair
|
||||
}
|
||||
|
||||
type RWChannel interface {
|
||||
|
@ -48,6 +49,7 @@ func NewRWChannel(name string,
|
|||
startPos []*commonpb.KeyDataPair,
|
||||
schema *schemapb.CollectionSchema,
|
||||
createTs uint64,
|
||||
dbProperties []*commonpb.KeyValuePair,
|
||||
) RWChannel {
|
||||
return &StateChannel{
|
||||
Name: name,
|
||||
|
@ -55,9 +57,11 @@ func NewRWChannel(name string,
|
|||
StartPositions: startPos,
|
||||
Schema: schema,
|
||||
CreateTimestamp: createTs,
|
||||
DBProperties: dbProperties,
|
||||
}
|
||||
}
|
||||
|
||||
// TODO fubang same as StateChannel
|
||||
type channelMeta struct {
|
||||
Name string
|
||||
CollectionID UniqueID
|
||||
|
@ -109,6 +113,10 @@ func (ch *channelMeta) String() string {
|
|||
return fmt.Sprintf("Name: %s, CollectionID: %d, StartPositions: %v", ch.Name, ch.CollectionID, ch.StartPositions)
|
||||
}
|
||||
|
||||
func (ch *channelMeta) GetDBProperties() []*commonpb.KeyValuePair {
|
||||
return nil
|
||||
}
|
||||
|
||||
type ChannelState string
|
||||
|
||||
const (
|
||||
|
@ -126,6 +134,7 @@ type StateChannel struct {
|
|||
CollectionID UniqueID
|
||||
StartPositions []*commonpb.KeyDataPair
|
||||
Schema *schemapb.CollectionSchema
|
||||
DBProperties []*commonpb.KeyValuePair
|
||||
CreateTimestamp uint64
|
||||
Info *datapb.ChannelWatchInfo
|
||||
|
||||
|
@ -143,6 +152,7 @@ func NewStateChannel(ch RWChannel) *StateChannel {
|
|||
Schema: ch.GetSchema(),
|
||||
CreateTimestamp: ch.GetCreateTimestamp(),
|
||||
Info: ch.GetWatchInfo(),
|
||||
DBProperties: ch.GetDBProperties(),
|
||||
|
||||
assignedNode: bufferID,
|
||||
}
|
||||
|
@ -156,6 +166,7 @@ func NewStateChannelByWatchInfo(nodeID int64, info *datapb.ChannelWatchInfo) *St
|
|||
Name: info.GetVchan().GetChannelName(),
|
||||
CollectionID: info.GetVchan().GetCollectionID(),
|
||||
Schema: info.GetSchema(),
|
||||
DBProperties: info.GetDbProperties(),
|
||||
Info: info,
|
||||
assignedNode: nodeID,
|
||||
}
|
||||
|
@ -277,3 +288,7 @@ func (c *StateChannel) Assign(nodeID int64) {
|
|||
func (c *StateChannel) setState(state ChannelState) {
|
||||
c.currentState = state
|
||||
}
|
||||
|
||||
func (c *StateChannel) GetDBProperties() []*commonpb.KeyValuePair {
|
||||
return c.DBProperties
|
||||
}
|
||||
|
|
|
@ -736,20 +736,22 @@ func (m *ChannelManagerImpl) fillChannelWatchInfo(op *ChannelOp) error {
|
|||
schema := ch.GetSchema()
|
||||
if schema == nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
collInfo, err := m.h.GetCollection(ctx, ch.GetCollectionID())
|
||||
if err != nil {
|
||||
cancel()
|
||||
return err
|
||||
}
|
||||
cancel()
|
||||
schema = collInfo.Schema
|
||||
}
|
||||
|
||||
info := &datapb.ChannelWatchInfo{
|
||||
Vchan: reduceVChanSize(vcInfo),
|
||||
StartTs: startTs,
|
||||
State: inferStateByOpType(op.Type),
|
||||
Schema: schema,
|
||||
OpID: opID,
|
||||
Vchan: reduceVChanSize(vcInfo),
|
||||
StartTs: startTs,
|
||||
State: inferStateByOpType(op.Type),
|
||||
Schema: schema,
|
||||
OpID: opID,
|
||||
DbProperties: ch.GetDBProperties(),
|
||||
}
|
||||
ch.UpdateWatchInfo(info)
|
||||
}
|
||||
|
|
|
@ -138,6 +138,12 @@ type collectionInfo struct {
|
|||
VChannelNames []string
|
||||
}
|
||||
|
||||
type dbInfo struct {
|
||||
ID int64
|
||||
Name string
|
||||
Properties []*commonpb.KeyValuePair
|
||||
}
|
||||
|
||||
// NewMeta creates meta from provided `kv.TxnKV`
|
||||
func newMeta(ctx context.Context, catalog metastore.DataCoordCatalog, chunkManager storage.ChunkManager) (*meta, error) {
|
||||
im, err := newIndexMeta(ctx, catalog)
|
||||
|
@ -244,12 +250,12 @@ func (m *meta) reloadCollectionsFromRootcoord(ctx context.Context, broker broker
|
|||
return err
|
||||
}
|
||||
for _, dbName := range resp.GetDbNames() {
|
||||
resp, err := broker.ShowCollections(ctx, dbName)
|
||||
collectionsResp, err := broker.ShowCollections(ctx, dbName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, collectionID := range resp.GetCollectionIds() {
|
||||
resp, err := broker.DescribeCollectionInternal(ctx, collectionID)
|
||||
for _, collectionID := range collectionsResp.GetCollectionIds() {
|
||||
descResp, err := broker.DescribeCollectionInternal(ctx, collectionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -259,14 +265,14 @@ func (m *meta) reloadCollectionsFromRootcoord(ctx context.Context, broker broker
|
|||
}
|
||||
collection := &collectionInfo{
|
||||
ID: collectionID,
|
||||
Schema: resp.GetSchema(),
|
||||
Schema: descResp.GetSchema(),
|
||||
Partitions: partitionIDs,
|
||||
StartPositions: resp.GetStartPositions(),
|
||||
Properties: funcutil.KeyValuePair2Map(resp.GetProperties()),
|
||||
CreatedAt: resp.GetCreatedTimestamp(),
|
||||
DatabaseName: resp.GetDbName(),
|
||||
DatabaseID: resp.GetDbId(),
|
||||
VChannelNames: resp.GetVirtualChannelNames(),
|
||||
StartPositions: descResp.GetStartPositions(),
|
||||
Properties: funcutil.KeyValuePair2Map(descResp.GetProperties()),
|
||||
CreatedAt: descResp.GetCreatedTimestamp(),
|
||||
DatabaseName: descResp.GetDbName(),
|
||||
DatabaseID: descResp.GetDbId(),
|
||||
VChannelNames: descResp.GetVirtualChannelNames(),
|
||||
}
|
||||
m.AddCollection(collection)
|
||||
}
|
||||
|
|
|
@ -951,7 +951,7 @@ func (s *Server) GetChannelRecoveryInfo(ctx context.Context, req *datapb.GetChan
|
|||
return resp, nil
|
||||
}
|
||||
|
||||
channel := NewRWChannel(req.GetVchannel(), collectionID, nil, collection.Schema, 0) // TODO: remove RWChannel, just use vchannel + collectionID
|
||||
channel := NewRWChannel(req.GetVchannel(), collectionID, nil, collection.Schema, 0, nil) // TODO: remove RWChannel, just use vchannel + collectionID
|
||||
channelInfo := s.handler.GetDataVChanPositions(channel, allPartitionID)
|
||||
if channelInfo.SeekPosition == nil {
|
||||
log.Warn("channel recovery start position is not found, may collection is on creating")
|
||||
|
@ -1230,6 +1230,7 @@ func (s *Server) WatchChannels(ctx context.Context, req *datapb.WatchChannelsReq
|
|||
log := log.Ctx(ctx).With(
|
||||
zap.Int64("collectionID", req.GetCollectionID()),
|
||||
zap.Strings("channels", req.GetChannelNames()),
|
||||
zap.Any("dbProperties", req.GetDbProperties()),
|
||||
)
|
||||
log.Info("receive watch channels request")
|
||||
resp := &datapb.WatchChannelsResponse{
|
||||
|
@ -1242,7 +1243,7 @@ func (s *Server) WatchChannels(ctx context.Context, req *datapb.WatchChannelsReq
|
|||
}, nil
|
||||
}
|
||||
for _, channelName := range req.GetChannelNames() {
|
||||
ch := NewRWChannel(channelName, req.GetCollectionID(), req.GetStartPositions(), req.GetSchema(), req.GetCreateTimestamp())
|
||||
ch := NewRWChannel(channelName, req.GetCollectionID(), req.GetStartPositions(), req.GetSchema(), req.GetCreateTimestamp(), req.GetDbProperties())
|
||||
err := s.channelManager.Watch(ctx, ch)
|
||||
if err != nil {
|
||||
log.Warn("fail to watch channelName", zap.Error(err))
|
||||
|
@ -1562,6 +1563,7 @@ func (s *Server) BroadcastAlteredCollection(ctx context.Context, req *datapb.Alt
|
|||
StartPositions: req.GetStartPositions(),
|
||||
Properties: properties,
|
||||
DatabaseID: req.GetDbID(),
|
||||
DatabaseName: req.GetSchema().GetDbName(),
|
||||
VChannelNames: req.GetVChannels(),
|
||||
}
|
||||
s.meta.AddCollection(collInfo)
|
||||
|
|
|
@ -70,7 +70,7 @@ func (s *OpRunnerSuite) SetupTest() {
|
|||
Return(nil).Maybe()
|
||||
|
||||
dispClient := msgdispatcher.NewMockClient(s.T())
|
||||
dispClient.EXPECT().Register(mock.Anything, mock.Anything, mock.Anything, mock.Anything).
|
||||
dispClient.EXPECT().Register(mock.Anything, mock.Anything).
|
||||
Return(make(chan *msgstream.MsgPack), nil).Maybe()
|
||||
dispClient.EXPECT().Deregister(mock.Anything).Maybe()
|
||||
|
||||
|
|
|
@ -350,7 +350,13 @@ func NewDataSyncService(initCtx context.Context, pipelineParams *util.PipelinePa
|
|||
return nil, err
|
||||
}
|
||||
|
||||
input, err := createNewInputFromDispatcher(initCtx, pipelineParams.DispClient, info.GetVchan().GetChannelName(), info.GetVchan().GetSeekPosition())
|
||||
input, err := createNewInputFromDispatcher(initCtx,
|
||||
pipelineParams.DispClient,
|
||||
info.GetVchan().GetChannelName(),
|
||||
info.GetVchan().GetSeekPosition(),
|
||||
info.GetSchema(),
|
||||
info.GetDbProperties(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -23,8 +23,11 @@ import (
|
|||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/flowgraph"
|
||||
pkgcommon "github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/mq/common"
|
||||
|
@ -57,11 +60,29 @@ func newDmInputNode(dmNodeConfig *nodeConfig, input <-chan *msgstream.MsgPack) *
|
|||
return node
|
||||
}
|
||||
|
||||
func createNewInputFromDispatcher(initCtx context.Context, dispatcherClient msgdispatcher.Client, vchannel string, seekPos *msgpb.MsgPosition) (<-chan *msgstream.MsgPack, error) {
|
||||
func createNewInputFromDispatcher(initCtx context.Context,
|
||||
dispatcherClient msgdispatcher.Client,
|
||||
vchannel string,
|
||||
seekPos *msgpb.MsgPosition,
|
||||
schema *schemapb.CollectionSchema,
|
||||
dbProperties []*commonpb.KeyValuePair,
|
||||
) (<-chan *msgstream.MsgPack, error) {
|
||||
log := log.With(zap.Int64("nodeID", paramtable.GetNodeID()),
|
||||
zap.String("vchannel", vchannel))
|
||||
replicateID, _ := pkgcommon.GetReplicateID(schema.GetProperties())
|
||||
if replicateID == "" {
|
||||
log.Info("datanode consume without replicateID, try to get replicateID from dbProperties", zap.Any("dbProperties", dbProperties))
|
||||
replicateID, _ = pkgcommon.GetReplicateID(dbProperties)
|
||||
}
|
||||
replicateConfig := msgstream.GetReplicateConfig(replicateID, schema.GetDbName(), schema.GetName())
|
||||
|
||||
if seekPos != nil && len(seekPos.MsgID) != 0 {
|
||||
input, err := dispatcherClient.Register(initCtx, vchannel, seekPos, common.SubscriptionPositionUnknown)
|
||||
input, err := dispatcherClient.Register(initCtx, &msgdispatcher.StreamConfig{
|
||||
VChannel: vchannel,
|
||||
Pos: seekPos,
|
||||
SubPos: common.SubscriptionPositionUnknown,
|
||||
ReplicateConfig: replicateConfig,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -71,7 +92,12 @@ func createNewInputFromDispatcher(initCtx context.Context, dispatcherClient msgd
|
|||
zap.Duration("tsLag", time.Since(tsoutil.PhysicalTime(seekPos.GetTimestamp()))))
|
||||
return input, err
|
||||
}
|
||||
input, err := dispatcherClient.Register(initCtx, vchannel, nil, common.SubscriptionPositionEarliest)
|
||||
input, err := dispatcherClient.Register(initCtx, &msgdispatcher.StreamConfig{
|
||||
VChannel: vchannel,
|
||||
Pos: nil,
|
||||
SubPos: common.SubscriptionPositionEarliest,
|
||||
ReplicateConfig: replicateConfig,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -62,6 +62,9 @@ func (mm *mockMsgStreamFactory) NewMsgStreamDisposer(ctx context.Context) func([
|
|||
|
||||
type mockTtMsgStream struct{}
|
||||
|
||||
func (mtm *mockTtMsgStream) SetReplicate(config *msgstream.ReplicateConfig) {
|
||||
}
|
||||
|
||||
func (mtm *mockTtMsgStream) Close() {}
|
||||
|
||||
func (mtm *mockTtMsgStream) Chan() <-chan *msgstream.MsgPack {
|
||||
|
|
|
@ -65,7 +65,7 @@ func TestFlowGraphManager(t *testing.T) {
|
|||
wbm.EXPECT().Register(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
dispClient := msgdispatcher.NewMockClient(t)
|
||||
dispClient.EXPECT().Register(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(make(chan *msgstream.MsgPack), nil)
|
||||
dispClient.EXPECT().Register(mock.Anything, mock.Anything).Return(make(chan *msgstream.MsgPack), nil)
|
||||
dispClient.EXPECT().Deregister(mock.Anything)
|
||||
|
||||
pipelineParams := &util.PipelineParams{
|
||||
|
@ -151,7 +151,7 @@ func newFlowGraphManager(t *testing.T) (string, FlowgraphManager) {
|
|||
wbm.EXPECT().Register(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
dispClient := msgdispatcher.NewMockClient(t)
|
||||
dispClient.EXPECT().Register(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(make(chan *msgstream.MsgPack), nil)
|
||||
dispClient.EXPECT().Register(mock.Anything, mock.Anything).Return(make(chan *msgstream.MsgPack), nil)
|
||||
|
||||
pipelineParams := &util.PipelineParams{
|
||||
Ctx: context.TODO(),
|
||||
|
|
|
@ -15,6 +15,7 @@ type Collection struct {
|
|||
CollectionID int64
|
||||
Partitions []*Partition
|
||||
Name string
|
||||
DBName string
|
||||
Description string
|
||||
AutoID bool
|
||||
Fields []*Field
|
||||
|
@ -41,6 +42,7 @@ func (c *Collection) Clone() *Collection {
|
|||
DBID: c.DBID,
|
||||
CollectionID: c.CollectionID,
|
||||
Name: c.Name,
|
||||
DBName: c.DBName,
|
||||
Description: c.Description,
|
||||
AutoID: c.AutoID,
|
||||
Fields: CloneFields(c.Fields),
|
||||
|
@ -99,6 +101,7 @@ func UnmarshalCollectionModel(coll *pb.CollectionInfo) *Collection {
|
|||
CollectionID: coll.ID,
|
||||
DBID: coll.DbId,
|
||||
Name: coll.Schema.Name,
|
||||
DBName: coll.Schema.DbName,
|
||||
Description: coll.Schema.Description,
|
||||
AutoID: coll.Schema.AutoID,
|
||||
Fields: UnmarshalFieldModels(coll.GetSchema().GetFields()),
|
||||
|
@ -154,6 +157,7 @@ func marshalCollectionModelWithConfig(coll *Collection, c *config) *pb.Collectio
|
|||
Description: coll.Description,
|
||||
AutoID: coll.AutoID,
|
||||
EnableDynamicField: coll.EnableDynamicField,
|
||||
DbName: coll.DBName,
|
||||
}
|
||||
|
||||
if c.withFields {
|
||||
|
|
|
@ -538,6 +538,7 @@ message ChannelWatchInfo {
|
|||
// watch progress, deprecated
|
||||
int32 progress = 6;
|
||||
int64 opID = 7;
|
||||
repeated common.KeyValuePair dbProperties = 8;
|
||||
}
|
||||
|
||||
enum CompactionType {
|
||||
|
@ -655,6 +656,7 @@ message WatchChannelsRequest {
|
|||
repeated common.KeyDataPair start_positions = 3;
|
||||
schema.CollectionSchema schema = 4;
|
||||
uint64 create_timestamp = 5;
|
||||
repeated common.KeyValuePair db_properties = 6;
|
||||
}
|
||||
|
||||
message WatchChannelsResponse {
|
||||
|
|
|
@ -323,6 +323,7 @@ message LoadMetaInfo {
|
|||
string db_name = 5; // Only used for metrics label.
|
||||
string resource_group = 6; // Only used for metrics label.
|
||||
repeated int64 load_fields = 7;
|
||||
repeated common.KeyValuePair db_properties = 8;
|
||||
}
|
||||
|
||||
message WatchDmChannelsRequest {
|
||||
|
|
|
@ -146,6 +146,7 @@ service RootCoord {
|
|||
message AllocTimestampRequest {
|
||||
common.MsgBase base = 1;
|
||||
uint32 count = 3;
|
||||
uint64 blockTimestamp = 4;
|
||||
}
|
||||
|
||||
message AllocTimestampResponse {
|
||||
|
|
|
@ -160,7 +160,7 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *p
|
|||
}
|
||||
globalMetaCache.RemoveCollection(ctx, request.GetDbName(), collectionName)
|
||||
log.Info("complete to invalidate collection meta cache", zap.String("type", request.GetBase().GetMsgType().String()))
|
||||
case commonpb.MsgType_DropDatabase:
|
||||
case commonpb.MsgType_DropDatabase, commonpb.MsgType_AlterDatabase:
|
||||
globalMetaCache.RemoveDatabase(ctx, request.GetDbName())
|
||||
case commonpb.MsgType_AlterCollection, commonpb.MsgType_AlterCollectionField:
|
||||
if request.CollectionID != UniqueID(0) {
|
||||
|
@ -6325,13 +6325,19 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
|
|||
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
|
||||
}
|
||||
|
||||
if paramtable.Get().CommonCfg.TTMsgEnabled.GetAsBool() {
|
||||
collectionReplicateEnable := paramtable.Get().CommonCfg.CollectionReplicateEnable.GetAsBool()
|
||||
ttMsgEnabled := paramtable.Get().CommonCfg.TTMsgEnabled.GetAsBool()
|
||||
|
||||
// replicate message can be use in two ways, otherwise return error
|
||||
// 1. collectionReplicateEnable is false and ttMsgEnabled is false, active/standby mode
|
||||
// 2. collectionReplicateEnable is true and ttMsgEnabled is true, data migration mode
|
||||
if (!collectionReplicateEnable && ttMsgEnabled) || (collectionReplicateEnable && !ttMsgEnabled) {
|
||||
return &milvuspb.ReplicateMessageResponse{
|
||||
Status: merr.Status(merr.ErrDenyReplicateMessage),
|
||||
}, nil
|
||||
}
|
||||
var err error
|
||||
|
||||
var err error
|
||||
if req.GetChannelName() == "" {
|
||||
log.Ctx(ctx).Warn("channel name is empty")
|
||||
return &milvuspb.ReplicateMessageResponse{
|
||||
|
@ -6369,6 +6375,18 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
|
|||
StartPositions: req.StartPositions,
|
||||
EndPositions: req.EndPositions,
|
||||
}
|
||||
checkCollectionReplicateProperty := func(dbName, collectionName string) bool {
|
||||
if !collectionReplicateEnable {
|
||||
return true
|
||||
}
|
||||
replicateID, err := GetReplicateID(ctx, dbName, collectionName)
|
||||
if err != nil {
|
||||
log.Warn("get replicate id failed", zap.String("collectionName", collectionName), zap.Error(err))
|
||||
return false
|
||||
}
|
||||
return replicateID != ""
|
||||
}
|
||||
|
||||
// getTsMsgFromConsumerMsg
|
||||
for i, msgBytes := range req.Msgs {
|
||||
header := commonpb.MsgHeader{}
|
||||
|
@ -6388,6 +6406,9 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
|
|||
}
|
||||
switch realMsg := tsMsg.(type) {
|
||||
case *msgstream.InsertMsg:
|
||||
if !checkCollectionReplicateProperty(realMsg.GetDbName(), realMsg.GetCollectionName()) {
|
||||
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.WrapErrCollectionReplicateMode("replicate"))}, nil
|
||||
}
|
||||
assignedSegmentInfos, err := node.segAssigner.GetSegmentID(realMsg.GetCollectionID(), realMsg.GetPartitionID(),
|
||||
realMsg.GetShardName(), uint32(realMsg.NumRows), req.EndTs)
|
||||
if err != nil {
|
||||
|
@ -6402,6 +6423,10 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
|
|||
realMsg.SegmentID = assignSegmentID
|
||||
break
|
||||
}
|
||||
case *msgstream.DeleteMsg:
|
||||
if !checkCollectionReplicateProperty(realMsg.GetDbName(), realMsg.GetCollectionName()) {
|
||||
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.WrapErrCollectionReplicateMode("replicate"))}, nil
|
||||
}
|
||||
}
|
||||
msgPack.Msgs = append(msgPack.Msgs, tsMsg)
|
||||
}
|
||||
|
|
|
@ -1279,6 +1279,9 @@ func TestProxy_Delete(t *testing.T) {
|
|||
},
|
||||
}
|
||||
schema := newSchemaInfo(collSchema)
|
||||
basicInfo := &collectionInfo{
|
||||
collID: collectionID,
|
||||
}
|
||||
paramtable.Init()
|
||||
|
||||
t.Run("delete run failed", func(t *testing.T) {
|
||||
|
@ -1311,6 +1314,7 @@ func TestProxy_Delete(t *testing.T) {
|
|||
mock.AnythingOfType("string"),
|
||||
mock.AnythingOfType("string"),
|
||||
).Return(partitionID, nil)
|
||||
cache.On("GetCollectionInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(basicInfo, nil)
|
||||
chMgr.On("getVChannels", mock.Anything).Return(channels, nil)
|
||||
chMgr.On("getChannels", mock.Anything).Return(nil, fmt.Errorf("mock error"))
|
||||
globalMetaCache = cache
|
||||
|
@ -1863,3 +1867,330 @@ func TestRegisterRestRouter(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateMessageForCollectionMode(t *testing.T) {
|
||||
paramtable.Init()
|
||||
ctx := context.Background()
|
||||
insertMsg := &msgstream.InsertMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
BeginTimestamp: 10,
|
||||
EndTimestamp: 10,
|
||||
HashValues: []uint32{0},
|
||||
MsgPosition: &msgstream.MsgPosition{
|
||||
ChannelName: "foo",
|
||||
MsgID: []byte("mock message id 2"),
|
||||
},
|
||||
},
|
||||
InsertRequest: &msgpb.InsertRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Insert,
|
||||
MsgID: 10001,
|
||||
Timestamp: 10,
|
||||
SourceID: -1,
|
||||
},
|
||||
ShardName: "foo_v1",
|
||||
DbName: "default",
|
||||
CollectionName: "foo_collection",
|
||||
PartitionName: "_default",
|
||||
DbID: 1,
|
||||
CollectionID: 11,
|
||||
PartitionID: 22,
|
||||
SegmentID: 33,
|
||||
Timestamps: []uint64{10},
|
||||
RowIDs: []int64{66},
|
||||
NumRows: 1,
|
||||
},
|
||||
}
|
||||
insertMsgBytes, _ := insertMsg.Marshal(insertMsg)
|
||||
deleteMsg := &msgstream.DeleteMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
BeginTimestamp: 20,
|
||||
EndTimestamp: 20,
|
||||
HashValues: []uint32{0},
|
||||
MsgPosition: &msgstream.MsgPosition{
|
||||
ChannelName: "foo",
|
||||
MsgID: []byte("mock message id 2"),
|
||||
},
|
||||
},
|
||||
DeleteRequest: &msgpb.DeleteRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Delete,
|
||||
MsgID: 10002,
|
||||
Timestamp: 20,
|
||||
SourceID: -1,
|
||||
},
|
||||
ShardName: "foo_v1",
|
||||
DbName: "default",
|
||||
CollectionName: "foo_collection",
|
||||
PartitionName: "_default",
|
||||
DbID: 1,
|
||||
CollectionID: 11,
|
||||
PartitionID: 22,
|
||||
},
|
||||
}
|
||||
deleteMsgBytes, _ := deleteMsg.Marshal(deleteMsg)
|
||||
|
||||
cache := globalMetaCache
|
||||
defer func() { globalMetaCache = cache }()
|
||||
|
||||
t.Run("replicate message in the replicate collection mode", func(t *testing.T) {
|
||||
defer func() {
|
||||
paramtable.Get().Reset(paramtable.Get().CommonCfg.TTMsgEnabled.Key)
|
||||
paramtable.Get().Reset(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key)
|
||||
}()
|
||||
|
||||
{
|
||||
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true")
|
||||
paramtable.Get().Save(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key, "false")
|
||||
p := &Proxy{}
|
||||
p.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
r, err := p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
|
||||
ChannelName: "foo",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Error(t, merr.Error(r.Status))
|
||||
}
|
||||
|
||||
{
|
||||
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false")
|
||||
paramtable.Get().Save(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key, "true")
|
||||
p := &Proxy{}
|
||||
p.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
r, err := p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
|
||||
ChannelName: "foo",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Error(t, merr.Error(r.Status))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("replicate message for the replicate collection mode", func(t *testing.T) {
|
||||
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true")
|
||||
paramtable.Get().Save(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key, "true")
|
||||
defer func() {
|
||||
paramtable.Get().Reset(paramtable.Get().CommonCfg.TTMsgEnabled.Key)
|
||||
paramtable.Get().Reset(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key)
|
||||
}()
|
||||
|
||||
mockCache := NewMockCache(t)
|
||||
mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil).Twice()
|
||||
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{}, nil).Twice()
|
||||
globalMetaCache = mockCache
|
||||
|
||||
{
|
||||
p := &Proxy{
|
||||
replicateStreamManager: NewReplicateStreamManager(context.Background(), nil, nil),
|
||||
}
|
||||
p.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
r, err := p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
|
||||
ChannelName: "foo",
|
||||
Msgs: [][]byte{insertMsgBytes.([]byte)},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, r.GetStatus().GetCode(), merr.Code(merr.ErrCollectionReplicateMode))
|
||||
}
|
||||
|
||||
{
|
||||
p := &Proxy{
|
||||
replicateStreamManager: NewReplicateStreamManager(context.Background(), nil, nil),
|
||||
}
|
||||
p.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
r, err := p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
|
||||
ChannelName: "foo",
|
||||
Msgs: [][]byte{deleteMsgBytes.([]byte)},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, r.GetStatus().GetCode(), merr.Code(merr.ErrCollectionReplicateMode))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAlterCollectionReplicateProperty(t *testing.T) {
|
||||
paramtable.Init()
|
||||
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true")
|
||||
paramtable.Get().Save(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key, "true")
|
||||
defer func() {
|
||||
paramtable.Get().Reset(paramtable.Get().CommonCfg.TTMsgEnabled.Key)
|
||||
paramtable.Get().Reset(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key)
|
||||
}()
|
||||
cache := globalMetaCache
|
||||
defer func() { globalMetaCache = cache }()
|
||||
mockCache := NewMockCache(t)
|
||||
mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{
|
||||
replicateID: "local-milvus",
|
||||
}, nil).Maybe()
|
||||
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(1, nil).Maybe()
|
||||
mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{}, nil)
|
||||
globalMetaCache = mockCache
|
||||
|
||||
factory := newMockMsgStreamFactory()
|
||||
msgStreamObj := msgstream.NewMockMsgStream(t)
|
||||
msgStreamObj.EXPECT().SetRepackFunc(mock.Anything).Return().Maybe()
|
||||
msgStreamObj.EXPECT().AsProducer(mock.Anything, mock.Anything).Return().Maybe()
|
||||
msgStreamObj.EXPECT().ForceEnableProduce(mock.Anything).Return().Maybe()
|
||||
msgStreamObj.EXPECT().Close().Return().Maybe()
|
||||
mockMsgID1 := mqcommon.NewMockMessageID(t)
|
||||
mockMsgID2 := mqcommon.NewMockMessageID(t)
|
||||
mockMsgID2.EXPECT().Serialize().Return([]byte("mock message id 2")).Maybe()
|
||||
msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{
|
||||
"alter_property": {mockMsgID1, mockMsgID2},
|
||||
}, nil).Maybe()
|
||||
|
||||
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
return msgStreamObj, nil
|
||||
}
|
||||
resourceManager := resource.NewManager(time.Second, 2*time.Second, nil)
|
||||
manager := NewReplicateStreamManager(context.Background(), factory, resourceManager)
|
||||
|
||||
ctx := context.Background()
|
||||
var startTt uint64 = 10
|
||||
startTime := time.Now()
|
||||
dataCoord := &mockDataCoord{}
|
||||
dataCoord.expireTime = Timestamp(1000)
|
||||
segAllocator, err := newSegIDAssigner(ctx, dataCoord, func() Timestamp {
|
||||
return Timestamp(time.Since(startTime).Seconds()) + startTt
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
segAllocator.Start()
|
||||
|
||||
mockRootcoord := mocks.NewMockRootCoordClient(t)
|
||||
mockRootcoord.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request *rootcoordpb.AllocTimestampRequest, option ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) {
|
||||
return &rootcoordpb.AllocTimestampResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
Timestamp: Timestamp(time.Since(startTime).Seconds()) + startTt,
|
||||
}, nil
|
||||
})
|
||||
mockRootcoord.EXPECT().AlterCollection(mock.Anything, mock.Anything).Return(&commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
}, nil)
|
||||
|
||||
p := &Proxy{
|
||||
ctx: ctx,
|
||||
replicateStreamManager: manager,
|
||||
segAssigner: segAllocator,
|
||||
rootCoord: mockRootcoord,
|
||||
}
|
||||
tsoAllocatorIns := newMockTsoAllocator()
|
||||
p.sched, err = newTaskScheduler(p.ctx, tsoAllocatorIns, p.factory)
|
||||
assert.NoError(t, err)
|
||||
p.sched.Start()
|
||||
defer p.sched.Close()
|
||||
p.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
|
||||
getInsertMsgBytes := func(channel string, ts uint64) []byte {
|
||||
insertMsg := &msgstream.InsertMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
BeginTimestamp: ts,
|
||||
EndTimestamp: ts,
|
||||
HashValues: []uint32{0},
|
||||
MsgPosition: &msgstream.MsgPosition{
|
||||
ChannelName: channel,
|
||||
MsgID: []byte("mock message id 2"),
|
||||
},
|
||||
},
|
||||
InsertRequest: &msgpb.InsertRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Insert,
|
||||
MsgID: 10001,
|
||||
Timestamp: ts,
|
||||
SourceID: -1,
|
||||
},
|
||||
ShardName: channel + "_v1",
|
||||
DbName: "default",
|
||||
CollectionName: "foo_collection",
|
||||
PartitionName: "_default",
|
||||
DbID: 1,
|
||||
CollectionID: 11,
|
||||
PartitionID: 22,
|
||||
SegmentID: 33,
|
||||
Timestamps: []uint64{ts},
|
||||
RowIDs: []int64{66},
|
||||
NumRows: 1,
|
||||
},
|
||||
}
|
||||
insertMsgBytes, _ := insertMsg.Marshal(insertMsg)
|
||||
return insertMsgBytes.([]byte)
|
||||
}
|
||||
getDeleteMsgBytes := func(channel string, ts uint64) []byte {
|
||||
deleteMsg := &msgstream.DeleteMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
BeginTimestamp: ts,
|
||||
EndTimestamp: ts,
|
||||
HashValues: []uint32{0},
|
||||
MsgPosition: &msgstream.MsgPosition{
|
||||
ChannelName: "foo",
|
||||
MsgID: []byte("mock message id 2"),
|
||||
},
|
||||
},
|
||||
DeleteRequest: &msgpb.DeleteRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Delete,
|
||||
MsgID: 10002,
|
||||
Timestamp: ts,
|
||||
SourceID: -1,
|
||||
},
|
||||
ShardName: channel + "_v1",
|
||||
DbName: "default",
|
||||
CollectionName: "foo_collection",
|
||||
PartitionName: "_default",
|
||||
DbID: 1,
|
||||
CollectionID: 11,
|
||||
PartitionID: 22,
|
||||
},
|
||||
}
|
||||
deleteMsgBytes, _ := deleteMsg.Marshal(deleteMsg)
|
||||
return deleteMsgBytes.([]byte)
|
||||
}
|
||||
|
||||
go func() {
|
||||
// replicate message
|
||||
var replicateResp *milvuspb.ReplicateMessageResponse
|
||||
var err error
|
||||
replicateResp, err = p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
|
||||
ChannelName: "alter_property_1",
|
||||
Msgs: [][]byte{getInsertMsgBytes("alter_property_1", startTt+5)},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, merr.Ok(replicateResp.Status), replicateResp.Status.Reason)
|
||||
|
||||
replicateResp, err = p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
|
||||
ChannelName: "alter_property_2",
|
||||
Msgs: [][]byte{getDeleteMsgBytes("alter_property_2", startTt+5)},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, merr.Ok(replicateResp.Status), replicateResp.Status.Reason)
|
||||
|
||||
time.Sleep(time.Second)
|
||||
|
||||
replicateResp, err = p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
|
||||
ChannelName: "alter_property_1",
|
||||
Msgs: [][]byte{getInsertMsgBytes("alter_property_1", startTt+10)},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, merr.Ok(replicateResp.Status), replicateResp.Status.Reason)
|
||||
|
||||
replicateResp, err = p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
|
||||
ChannelName: "alter_property_2",
|
||||
Msgs: [][]byte{getInsertMsgBytes("alter_property_2", startTt+10)},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, merr.Ok(replicateResp.Status), replicateResp.Status.Reason)
|
||||
}()
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// alter collection property
|
||||
statusResp, err := p.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{
|
||||
DbName: "default",
|
||||
CollectionName: "foo_collection",
|
||||
Properties: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "replicate.endTS",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, merr.Ok(statusResp))
|
||||
}
|
||||
|
|
|
@ -102,10 +102,12 @@ type collectionInfo struct {
|
|||
createdUtcTimestamp uint64
|
||||
consistencyLevel commonpb.ConsistencyLevel
|
||||
partitionKeyIsolation bool
|
||||
replicateID string
|
||||
}
|
||||
|
||||
type databaseInfo struct {
|
||||
dbID typeutil.UniqueID
|
||||
properties []*commonpb.KeyValuePair
|
||||
createdTimestamp uint64
|
||||
}
|
||||
|
||||
|
@ -478,6 +480,7 @@ func (m *MetaCache) update(ctx context.Context, database, collectionName string,
|
|||
m.collInfo[database] = make(map[string]*collectionInfo)
|
||||
}
|
||||
|
||||
replicateID, _ := common.GetReplicateID(collection.Properties)
|
||||
m.collInfo[database][collectionName] = &collectionInfo{
|
||||
collID: collection.CollectionID,
|
||||
schema: schemaInfo,
|
||||
|
@ -486,6 +489,7 @@ func (m *MetaCache) update(ctx context.Context, database, collectionName string,
|
|||
createdUtcTimestamp: collection.CreatedUtcTimestamp,
|
||||
consistencyLevel: collection.ConsistencyLevel,
|
||||
partitionKeyIsolation: isolation,
|
||||
replicateID: replicateID,
|
||||
}
|
||||
|
||||
log.Ctx(ctx).Info("meta update success", zap.String("database", database), zap.String("collectionName", collectionName),
|
||||
|
@ -571,10 +575,19 @@ func (m *MetaCache) GetCollectionInfo(ctx context.Context, database string, coll
|
|||
method := "GetCollectionInfo"
|
||||
// if collInfo.collID != collectionID, means that the cache is not trustable
|
||||
// try to get collection according to collectionID
|
||||
if !ok || collInfo.collID != collectionID {
|
||||
// Why use collectionID? Because the collectionID is not always provided in the proxy.
|
||||
if !ok || (collectionID != 0 && collInfo.collID != collectionID) {
|
||||
tr := timerecord.NewTimeRecorder("UpdateCache")
|
||||
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
|
||||
|
||||
if collectionID == 0 {
|
||||
collInfo, err := m.UpdateByName(ctx, database, collectionName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
return collInfo, nil
|
||||
}
|
||||
collInfo, err := m.UpdateByID(ctx, database, collectionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -1225,6 +1238,7 @@ func (m *MetaCache) GetDatabaseInfo(ctx context.Context, database string) (*data
|
|||
defer m.mu.Unlock()
|
||||
dbInfo := &databaseInfo{
|
||||
dbID: resp.GetDbID(),
|
||||
properties: resp.Properties,
|
||||
createdTimestamp: resp.GetCreatedTimestamp(),
|
||||
}
|
||||
m.dbInfo[database] = dbInfo
|
||||
|
|
|
@ -304,6 +304,87 @@ func TestMetaCache_GetBasicCollectionInfo(t *testing.T) {
|
|||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestMetaCacheGetCollectionWithUpdate(t *testing.T) {
|
||||
cache := globalMetaCache
|
||||
defer func() { globalMetaCache = cache }()
|
||||
ctx := context.Background()
|
||||
rootCoord := mocks.NewMockRootCoordClient(t)
|
||||
queryCoord := mocks.NewMockQueryCoordClient(t)
|
||||
rootCoord.EXPECT().ListPolicy(mock.Anything, mock.Anything, mock.Anything).Return(&internalpb.ListPolicyResponse{Status: merr.Success()}, nil)
|
||||
mgr := newShardClientMgr()
|
||||
err := InitMetaCache(ctx, rootCoord, queryCoord, mgr)
|
||||
assert.NoError(t, err)
|
||||
t.Run("update with name", func(t *testing.T) {
|
||||
rootCoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
|
||||
Status: merr.Success(),
|
||||
CollectionID: 1,
|
||||
Schema: &schemapb.CollectionSchema{
|
||||
Name: "bar",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 1,
|
||||
Name: "p",
|
||||
},
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: "pk",
|
||||
},
|
||||
},
|
||||
},
|
||||
ShardsNum: 1,
|
||||
PhysicalChannelNames: []string{"by-dev-rootcoord-dml_1"},
|
||||
VirtualChannelNames: []string{"by-dev-rootcoord-dml_1_1v0"},
|
||||
}, nil).Once()
|
||||
rootCoord.EXPECT().ShowPartitions(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.ShowPartitionsResponse{
|
||||
Status: merr.Success(),
|
||||
PartitionIDs: []typeutil.UniqueID{11},
|
||||
PartitionNames: []string{"p1"},
|
||||
CreatedTimestamps: []uint64{11},
|
||||
CreatedUtcTimestamps: []uint64{11},
|
||||
}, nil).Once()
|
||||
queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Once()
|
||||
c, err := globalMetaCache.GetCollectionInfo(ctx, "foo", "bar", 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.collID, int64(1))
|
||||
assert.Equal(t, c.schema.Name, "bar")
|
||||
})
|
||||
|
||||
t.Run("update with name", func(t *testing.T) {
|
||||
rootCoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
|
||||
Status: merr.Success(),
|
||||
CollectionID: 1,
|
||||
Schema: &schemapb.CollectionSchema{
|
||||
Name: "bar",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 1,
|
||||
Name: "p",
|
||||
},
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: "pk",
|
||||
},
|
||||
},
|
||||
},
|
||||
ShardsNum: 1,
|
||||
PhysicalChannelNames: []string{"by-dev-rootcoord-dml_1"},
|
||||
VirtualChannelNames: []string{"by-dev-rootcoord-dml_1_1v0"},
|
||||
}, nil).Once()
|
||||
rootCoord.EXPECT().ShowPartitions(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.ShowPartitionsResponse{
|
||||
Status: merr.Success(),
|
||||
PartitionIDs: []typeutil.UniqueID{11},
|
||||
PartitionNames: []string{"p1"},
|
||||
CreatedTimestamps: []uint64{11},
|
||||
CreatedUtcTimestamps: []uint64{11},
|
||||
}, nil).Once()
|
||||
queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Once()
|
||||
c, err := globalMetaCache.GetCollectionInfo(ctx, "foo", "hoo", 0)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.collID, int64(1))
|
||||
assert.Equal(t, c.schema.Name, "bar")
|
||||
})
|
||||
}
|
||||
|
||||
func TestMetaCache_GetCollectionName(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rootCoord := &MockRootCoordClientInterface{}
|
||||
|
|
|
@ -40,6 +40,9 @@ func (m *mockMsgStream) ForceEnableProduce(enabled bool) {
|
|||
}
|
||||
}
|
||||
|
||||
func (m *mockMsgStream) SetReplicate(config *msgstream.ReplicateConfig) {
|
||||
}
|
||||
|
||||
func newMockMsgStream() *mockMsgStream {
|
||||
return &mockMsgStream{}
|
||||
}
|
||||
|
|
|
@ -314,6 +314,9 @@ func (ms *simpleMockMsgStream) CheckTopicValid(topic string) error {
|
|||
func (ms *simpleMockMsgStream) ForceEnableProduce(enabled bool) {
|
||||
}
|
||||
|
||||
func (ms *simpleMockMsgStream) SetReplicate(config *msgstream.ReplicateConfig) {
|
||||
}
|
||||
|
||||
func newSimpleMockMsgStream() *simpleMockMsgStream {
|
||||
return &simpleMockMsgStream{
|
||||
msgChan: make(chan *msgstream.MsgPack, 1024),
|
||||
|
|
|
@ -32,6 +32,7 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/indexpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/ctokenizer"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
|
@ -1081,6 +1082,25 @@ func (t *alterCollectionTask) PreExecute(ctx context.Context) error {
|
|||
}
|
||||
}
|
||||
|
||||
_, ok := common.IsReplicateEnabled(t.Properties)
|
||||
if ok {
|
||||
return merr.WrapErrParameterInvalidMsg("can't set the replicate.id property")
|
||||
}
|
||||
endTS, ok := common.GetReplicateEndTS(t.Properties)
|
||||
if ok && collBasicInfo.replicateID != "" {
|
||||
allocResp, err := t.rootCoord.AllocTimestamp(ctx, &rootcoordpb.AllocTimestampRequest{
|
||||
Count: 1,
|
||||
BlockTimestamp: endTS,
|
||||
})
|
||||
if err = merr.CheckRPCCall(allocResp, err); err != nil {
|
||||
return merr.WrapErrServiceInternal("alloc timestamp failed", err.Error())
|
||||
}
|
||||
if allocResp.GetTimestamp() <= endTS {
|
||||
return merr.WrapErrServiceInternal("alter collection: alloc timestamp failed, timestamp is not greater than endTS",
|
||||
fmt.Sprintf("timestamp = %d, endTS = %d", allocResp.GetTimestamp(), endTS))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ package proxy
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
|
@ -9,6 +10,7 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
|
@ -274,6 +276,34 @@ func (t *alterDatabaseTask) OnEnqueue() error {
|
|||
}
|
||||
|
||||
func (t *alterDatabaseTask) PreExecute(ctx context.Context) error {
|
||||
_, ok := common.GetReplicateID(t.Properties)
|
||||
if ok {
|
||||
return merr.WrapErrParameterInvalidMsg("can't set the replicate id property in alter database request")
|
||||
}
|
||||
endTS, ok := common.GetReplicateEndTS(t.Properties)
|
||||
if !ok { // not exist replicate end ts property
|
||||
return nil
|
||||
}
|
||||
cacheInfo, err := globalMetaCache.GetDatabaseInfo(ctx, t.DbName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
oldReplicateEnable, _ := common.IsReplicateEnabled(cacheInfo.properties)
|
||||
if !oldReplicateEnable { // old replicate enable is false
|
||||
return merr.WrapErrParameterInvalidMsg("can't set the replicate end ts property in alter database request when db replicate is disabled")
|
||||
}
|
||||
allocResp, err := t.rootCoord.AllocTimestamp(ctx, &rootcoordpb.AllocTimestampRequest{
|
||||
Count: 1,
|
||||
BlockTimestamp: endTS,
|
||||
})
|
||||
if err = merr.CheckRPCCall(allocResp, err); err != nil {
|
||||
return merr.WrapErrServiceInternal("alloc timestamp failed", err.Error())
|
||||
}
|
||||
if allocResp.GetTimestamp() <= endTS {
|
||||
return merr.WrapErrServiceInternal("alter database: alloc timestamp failed, timestamp is not greater than endTS",
|
||||
fmt.Sprintf("timestamp = %d, endTS = %d", allocResp.GetTimestamp(), endTS))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
@ -201,6 +202,163 @@ func TestAlterDatabase(t *testing.T) {
|
|||
assert.Nil(t, err1)
|
||||
}
|
||||
|
||||
func TestAlterDatabaseTaskForReplicateProperty(t *testing.T) {
|
||||
rc := mocks.NewMockRootCoordClient(t)
|
||||
cache := globalMetaCache
|
||||
defer func() { globalMetaCache = cache }()
|
||||
mockCache := NewMockCache(t)
|
||||
globalMetaCache = mockCache
|
||||
|
||||
t.Run("replicate id", func(t *testing.T) {
|
||||
task := &alterDatabaseTask{
|
||||
AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{
|
||||
Base: &commonpb.MsgBase{},
|
||||
DbName: "test_alter_database",
|
||||
Properties: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.MmapEnabledKey,
|
||||
Value: "true",
|
||||
},
|
||||
{
|
||||
Key: common.ReplicateIDKey,
|
||||
Value: "local-test",
|
||||
},
|
||||
},
|
||||
},
|
||||
rootCoord: rc,
|
||||
}
|
||||
err := task.PreExecute(context.Background())
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("fail to get database info", func(t *testing.T) {
|
||||
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, errors.New("err")).Once()
|
||||
task := &alterDatabaseTask{
|
||||
AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{
|
||||
Base: &commonpb.MsgBase{},
|
||||
DbName: "test_alter_database",
|
||||
Properties: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.ReplicateEndTSKey,
|
||||
Value: "1000",
|
||||
},
|
||||
},
|
||||
},
|
||||
rootCoord: rc,
|
||||
}
|
||||
err := task.PreExecute(context.Background())
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("not enable replicate", func(t *testing.T) {
|
||||
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
|
||||
properties: []*commonpb.KeyValuePair{},
|
||||
}, nil).Once()
|
||||
task := &alterDatabaseTask{
|
||||
AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{
|
||||
Base: &commonpb.MsgBase{},
|
||||
DbName: "test_alter_database",
|
||||
Properties: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.ReplicateEndTSKey,
|
||||
Value: "1000",
|
||||
},
|
||||
},
|
||||
},
|
||||
rootCoord: rc,
|
||||
}
|
||||
err := task.PreExecute(context.Background())
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("fail to alloc ts", func(t *testing.T) {
|
||||
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
|
||||
properties: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.ReplicateIDKey,
|
||||
Value: "local-test",
|
||||
},
|
||||
},
|
||||
}, nil).Once()
|
||||
rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(nil, errors.New("err")).Once()
|
||||
task := &alterDatabaseTask{
|
||||
AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{
|
||||
Base: &commonpb.MsgBase{},
|
||||
DbName: "test_alter_database",
|
||||
Properties: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.ReplicateEndTSKey,
|
||||
Value: "1000",
|
||||
},
|
||||
},
|
||||
},
|
||||
rootCoord: rc,
|
||||
}
|
||||
err := task.PreExecute(context.Background())
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("alloc wrong ts", func(t *testing.T) {
|
||||
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
|
||||
properties: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.ReplicateIDKey,
|
||||
Value: "local-test",
|
||||
},
|
||||
},
|
||||
}, nil).Once()
|
||||
rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocTimestampResponse{
|
||||
Status: merr.Success(),
|
||||
Timestamp: 999,
|
||||
}, nil).Once()
|
||||
task := &alterDatabaseTask{
|
||||
AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{
|
||||
Base: &commonpb.MsgBase{},
|
||||
DbName: "test_alter_database",
|
||||
Properties: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.ReplicateEndTSKey,
|
||||
Value: "1000",
|
||||
},
|
||||
},
|
||||
},
|
||||
rootCoord: rc,
|
||||
}
|
||||
err := task.PreExecute(context.Background())
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("alloc wrong ts", func(t *testing.T) {
|
||||
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
|
||||
properties: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.ReplicateIDKey,
|
||||
Value: "local-test",
|
||||
},
|
||||
},
|
||||
}, nil).Once()
|
||||
rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocTimestampResponse{
|
||||
Status: merr.Success(),
|
||||
Timestamp: 1001,
|
||||
}, nil).Once()
|
||||
task := &alterDatabaseTask{
|
||||
AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{
|
||||
Base: &commonpb.MsgBase{},
|
||||
DbName: "test_alter_database",
|
||||
Properties: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.ReplicateEndTSKey,
|
||||
Value: "1000",
|
||||
},
|
||||
},
|
||||
},
|
||||
rootCoord: rc,
|
||||
}
|
||||
err := task.PreExecute(context.Background())
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDescribeDatabaseTask(t *testing.T) {
|
||||
rc := mocks.NewMockRootCoordClient(t)
|
||||
|
||||
|
|
|
@ -335,6 +335,15 @@ func (dr *deleteRunner) Init(ctx context.Context) error {
|
|||
return ErrWithLog(log, "Failed to get collection id", merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound))
|
||||
}
|
||||
|
||||
replicateID, err := GetReplicateID(ctx, dr.req.GetDbName(), collName)
|
||||
if err != nil {
|
||||
log.Warn("get replicate info failed", zap.String("collectionName", collName), zap.Error(err))
|
||||
return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound)
|
||||
}
|
||||
if replicateID != "" {
|
||||
return merr.WrapErrCollectionReplicateMode("delete")
|
||||
}
|
||||
|
||||
dr.schema, err = globalMetaCache.GetCollectionSchema(ctx, dr.req.GetDbName(), collName)
|
||||
if err != nil {
|
||||
return ErrWithLog(log, "Failed to get collection schema", err)
|
||||
|
|
|
@ -297,6 +297,45 @@ func TestDeleteRunner_Init(t *testing.T) {
|
|||
assert.Error(t, dr.Init(context.Background()))
|
||||
})
|
||||
|
||||
t.Run("fail to get collection info", func(t *testing.T) {
|
||||
dr := deleteRunner{req: &milvuspb.DeleteRequest{
|
||||
CollectionName: collectionName,
|
||||
DbName: dbName,
|
||||
}}
|
||||
cache := NewMockCache(t)
|
||||
cache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 0}, nil)
|
||||
cache.On("GetCollectionID",
|
||||
mock.Anything, // context.Context
|
||||
mock.AnythingOfType("string"),
|
||||
mock.AnythingOfType("string"),
|
||||
).Return(collectionID, nil)
|
||||
cache.On("GetCollectionInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil,
|
||||
errors.New("mock get collection info"))
|
||||
|
||||
globalMetaCache = cache
|
||||
assert.Error(t, dr.Init(context.Background()))
|
||||
})
|
||||
|
||||
t.Run("deny delete in the replicate mode", func(t *testing.T) {
|
||||
dr := deleteRunner{req: &milvuspb.DeleteRequest{
|
||||
CollectionName: collectionName,
|
||||
DbName: dbName,
|
||||
}}
|
||||
cache := NewMockCache(t)
|
||||
cache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 0}, nil)
|
||||
cache.On("GetCollectionID",
|
||||
mock.Anything, // context.Context
|
||||
mock.AnythingOfType("string"),
|
||||
mock.AnythingOfType("string"),
|
||||
).Return(collectionID, nil)
|
||||
cache.On("GetCollectionInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{
|
||||
replicateID: "local-mac",
|
||||
}, nil)
|
||||
|
||||
globalMetaCache = cache
|
||||
assert.Error(t, dr.Init(context.Background()))
|
||||
})
|
||||
|
||||
t.Run("fail get collection schema", func(t *testing.T) {
|
||||
dr := deleteRunner{req: &milvuspb.DeleteRequest{
|
||||
CollectionName: collectionName,
|
||||
|
@ -309,6 +348,7 @@ func TestDeleteRunner_Init(t *testing.T) {
|
|||
mock.AnythingOfType("string"),
|
||||
mock.AnythingOfType("string"),
|
||||
).Return(collectionID, nil)
|
||||
cache.On("GetCollectionInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil)
|
||||
cache.On("GetCollectionSchema",
|
||||
mock.Anything, // context.Context
|
||||
mock.AnythingOfType("string"),
|
||||
|
@ -332,6 +372,7 @@ func TestDeleteRunner_Init(t *testing.T) {
|
|||
mock.AnythingOfType("string"),
|
||||
mock.AnythingOfType("string"),
|
||||
).Return(collectionID, nil)
|
||||
cache.On("GetCollectionInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil)
|
||||
cache.On("GetCollectionSchema",
|
||||
mock.Anything, // context.Context
|
||||
mock.AnythingOfType("string"),
|
||||
|
@ -376,6 +417,7 @@ func TestDeleteRunner_Init(t *testing.T) {
|
|||
mock.AnythingOfType("string"),
|
||||
mock.AnythingOfType("string"),
|
||||
).Return(schema, nil)
|
||||
cache.On("GetCollectionInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil)
|
||||
|
||||
globalMetaCache = cache
|
||||
assert.Error(t, dr.Init(context.Background()))
|
||||
|
@ -402,6 +444,7 @@ func TestDeleteRunner_Init(t *testing.T) {
|
|||
mock.AnythingOfType("string"),
|
||||
mock.AnythingOfType("string"),
|
||||
).Return(schema, nil)
|
||||
cache.On("GetCollectionInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil)
|
||||
cache.On("GetPartitionID",
|
||||
mock.Anything, // context.Context
|
||||
mock.AnythingOfType("string"),
|
||||
|
@ -431,6 +474,7 @@ func TestDeleteRunner_Init(t *testing.T) {
|
|||
mock.AnythingOfType("string"),
|
||||
mock.AnythingOfType("string"),
|
||||
).Return(collectionID, nil)
|
||||
cache.On("GetCollectionInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil)
|
||||
cache.On("GetCollectionSchema",
|
||||
mock.Anything, // context.Context
|
||||
mock.AnythingOfType("string"),
|
||||
|
|
|
@ -125,6 +125,15 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
|
|||
return merr.WrapErrAsInputError(merr.WrapErrParameterTooLarge("insert request size exceeds maxInsertSize"))
|
||||
}
|
||||
|
||||
replicateID, err := GetReplicateID(it.ctx, it.insertMsg.GetDbName(), collectionName)
|
||||
if err != nil {
|
||||
log.Warn("get replicate id failed", zap.String("collectionName", collectionName), zap.Error(err))
|
||||
return merr.WrapErrAsInputError(err)
|
||||
}
|
||||
if replicateID != "" {
|
||||
return merr.WrapErrCollectionReplicateMode("insert")
|
||||
}
|
||||
|
||||
schema, err := globalMetaCache.GetCollectionSchema(ctx, it.insertMsg.GetDbName(), collectionName)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("get collection schema from global meta cache failed", zap.String("collectionName", collectionName), zap.Error(err))
|
||||
|
|
|
@ -41,6 +41,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/proto/indexpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
|
@ -1708,8 +1709,8 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
shardsNum := int32(2)
|
||||
prefix := "TestTask_all"
|
||||
dbName := ""
|
||||
prefix := "TestTask_int64pk"
|
||||
dbName := "int64PK"
|
||||
collectionName := prefix + funcutil.GenRandomStr()
|
||||
partitionName := prefix + funcutil.GenRandomStr()
|
||||
|
||||
|
@ -1726,45 +1727,43 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
|
|||
}
|
||||
nb := 10
|
||||
|
||||
t.Run("create collection", func(t *testing.T) {
|
||||
schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testInt64Field, false)
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
assert.NoError(t, err)
|
||||
schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testInt64Field, false)
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
assert.NoError(t, err)
|
||||
|
||||
createColT := &createCollectionTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
|
||||
Base: nil,
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
Schema: marshaledSchema,
|
||||
ShardsNum: shardsNum,
|
||||
},
|
||||
ctx: ctx,
|
||||
rootCoord: rc,
|
||||
result: nil,
|
||||
schema: nil,
|
||||
}
|
||||
|
||||
assert.NoError(t, createColT.OnEnqueue())
|
||||
assert.NoError(t, createColT.PreExecute(ctx))
|
||||
assert.NoError(t, createColT.Execute(ctx))
|
||||
assert.NoError(t, createColT.PostExecute(ctx))
|
||||
|
||||
_, _ = rc.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_CreatePartition,
|
||||
MsgID: 0,
|
||||
Timestamp: 0,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
createColT := &createCollectionTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
|
||||
Base: nil,
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
PartitionName: partitionName,
|
||||
})
|
||||
Schema: marshaledSchema,
|
||||
ShardsNum: shardsNum,
|
||||
},
|
||||
ctx: ctx,
|
||||
rootCoord: rc,
|
||||
result: nil,
|
||||
schema: nil,
|
||||
}
|
||||
|
||||
assert.NoError(t, createColT.OnEnqueue())
|
||||
assert.NoError(t, createColT.PreExecute(ctx))
|
||||
assert.NoError(t, createColT.Execute(ctx))
|
||||
assert.NoError(t, createColT.PostExecute(ctx))
|
||||
|
||||
_, _ = rc.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_CreatePartition,
|
||||
MsgID: 0,
|
||||
Timestamp: 0,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
PartitionName: partitionName,
|
||||
})
|
||||
|
||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
|
||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, dbName, collectionName)
|
||||
assert.NoError(t, err)
|
||||
|
||||
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
|
||||
|
@ -1957,7 +1956,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
|
|||
|
||||
shardsNum := int32(2)
|
||||
prefix := "TestTask_all"
|
||||
dbName := ""
|
||||
dbName := "testvarchar"
|
||||
collectionName := prefix + funcutil.GenRandomStr()
|
||||
partitionName := prefix + funcutil.GenRandomStr()
|
||||
|
||||
|
@ -1975,45 +1974,43 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
|
|||
}
|
||||
nb := 10
|
||||
|
||||
t.Run("create collection", func(t *testing.T) {
|
||||
schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testVarCharField, false)
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
assert.NoError(t, err)
|
||||
schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testVarCharField, false)
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
assert.NoError(t, err)
|
||||
|
||||
createColT := &createCollectionTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
|
||||
Base: nil,
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
Schema: marshaledSchema,
|
||||
ShardsNum: shardsNum,
|
||||
},
|
||||
ctx: ctx,
|
||||
rootCoord: rc,
|
||||
result: nil,
|
||||
schema: nil,
|
||||
}
|
||||
|
||||
assert.NoError(t, createColT.OnEnqueue())
|
||||
assert.NoError(t, createColT.PreExecute(ctx))
|
||||
assert.NoError(t, createColT.Execute(ctx))
|
||||
assert.NoError(t, createColT.PostExecute(ctx))
|
||||
|
||||
_, _ = rc.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_CreatePartition,
|
||||
MsgID: 0,
|
||||
Timestamp: 0,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
createColT := &createCollectionTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
|
||||
Base: nil,
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
PartitionName: partitionName,
|
||||
})
|
||||
Schema: marshaledSchema,
|
||||
ShardsNum: shardsNum,
|
||||
},
|
||||
ctx: ctx,
|
||||
rootCoord: rc,
|
||||
result: nil,
|
||||
schema: nil,
|
||||
}
|
||||
|
||||
assert.NoError(t, createColT.OnEnqueue())
|
||||
assert.NoError(t, createColT.PreExecute(ctx))
|
||||
assert.NoError(t, createColT.Execute(ctx))
|
||||
assert.NoError(t, createColT.PostExecute(ctx))
|
||||
|
||||
_, _ = rc.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_CreatePartition,
|
||||
MsgID: 0,
|
||||
Timestamp: 0,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
PartitionName: partitionName,
|
||||
})
|
||||
|
||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
|
||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, dbName, collectionName)
|
||||
assert.NoError(t, err)
|
||||
|
||||
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
|
||||
|
@ -3444,30 +3441,28 @@ func TestPartitionKey(t *testing.T) {
|
|||
marshaledSchema, err := proto.Marshal(schema)
|
||||
assert.NoError(t, err)
|
||||
|
||||
t.Run("create collection", func(t *testing.T) {
|
||||
createCollectionTask := &createCollectionTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgID: UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()),
|
||||
Timestamp: Timestamp(time.Now().UnixNano()),
|
||||
},
|
||||
DbName: "",
|
||||
CollectionName: collectionName,
|
||||
Schema: marshaledSchema,
|
||||
ShardsNum: shardsNum,
|
||||
NumPartitions: common.DefaultPartitionsWithPartitionKey,
|
||||
createCollectionTask := &createCollectionTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgID: UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()),
|
||||
Timestamp: Timestamp(time.Now().UnixNano()),
|
||||
},
|
||||
ctx: ctx,
|
||||
rootCoord: rc,
|
||||
result: nil,
|
||||
schema: nil,
|
||||
}
|
||||
err = createCollectionTask.PreExecute(ctx)
|
||||
assert.NoError(t, err)
|
||||
err = createCollectionTask.Execute(ctx)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
DbName: "",
|
||||
CollectionName: collectionName,
|
||||
Schema: marshaledSchema,
|
||||
ShardsNum: shardsNum,
|
||||
NumPartitions: common.DefaultPartitionsWithPartitionKey,
|
||||
},
|
||||
ctx: ctx,
|
||||
rootCoord: rc,
|
||||
result: nil,
|
||||
schema: nil,
|
||||
}
|
||||
err = createCollectionTask.PreExecute(ctx)
|
||||
assert.NoError(t, err)
|
||||
err = createCollectionTask.Execute(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
|
||||
assert.NoError(t, err)
|
||||
|
@ -3500,7 +3495,7 @@ func TestPartitionKey(t *testing.T) {
|
|||
_ = segAllocator.Start()
|
||||
defer segAllocator.Close()
|
||||
|
||||
partitionNames, err := getDefaultPartitionsInPartitionKeyMode(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
|
||||
partitionNames, err := getDefaultPartitionsInPartitionKeyMode(ctx, "", collectionName)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, common.DefaultPartitionsWithPartitionKey, int64(len(partitionNames)))
|
||||
|
||||
|
@ -4269,3 +4264,136 @@ func TestTaskPartitionKeyIsolation(t *testing.T) {
|
|||
"can not alter partition key isolation mode if the collection already has a vector index. Please drop the index first")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAlterCollectionForReplicateProperty(t *testing.T) {
|
||||
cache := globalMetaCache
|
||||
defer func() { globalMetaCache = cache }()
|
||||
mockCache := NewMockCache(t)
|
||||
mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{
|
||||
replicateID: "local-mac-1",
|
||||
}, nil).Maybe()
|
||||
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(1, nil).Maybe()
|
||||
mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{}, nil).Maybe()
|
||||
globalMetaCache = mockCache
|
||||
ctx := context.Background()
|
||||
mockRootcoord := mocks.NewMockRootCoordClient(t)
|
||||
t.Run("invalid replicate id", func(t *testing.T) {
|
||||
task := &alterCollectionTask{
|
||||
AlterCollectionRequest: &milvuspb.AlterCollectionRequest{
|
||||
Properties: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.ReplicateIDKey,
|
||||
Value: "xxxxx",
|
||||
},
|
||||
},
|
||||
},
|
||||
rootCoord: mockRootcoord,
|
||||
}
|
||||
|
||||
err := task.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("empty replicate id", func(t *testing.T) {
|
||||
task := &alterCollectionTask{
|
||||
AlterCollectionRequest: &milvuspb.AlterCollectionRequest{
|
||||
CollectionName: "test",
|
||||
Properties: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.ReplicateIDKey,
|
||||
Value: "",
|
||||
},
|
||||
},
|
||||
},
|
||||
rootCoord: mockRootcoord,
|
||||
}
|
||||
|
||||
err := task.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("fail to alloc ts", func(t *testing.T) {
|
||||
task := &alterCollectionTask{
|
||||
AlterCollectionRequest: &milvuspb.AlterCollectionRequest{
|
||||
CollectionName: "test",
|
||||
Properties: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.ReplicateEndTSKey,
|
||||
Value: "100",
|
||||
},
|
||||
},
|
||||
},
|
||||
rootCoord: mockRootcoord,
|
||||
}
|
||||
|
||||
mockRootcoord.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(nil, errors.New("err")).Once()
|
||||
err := task.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("alloc wrong ts", func(t *testing.T) {
|
||||
task := &alterCollectionTask{
|
||||
AlterCollectionRequest: &milvuspb.AlterCollectionRequest{
|
||||
CollectionName: "test",
|
||||
Properties: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.ReplicateEndTSKey,
|
||||
Value: "100",
|
||||
},
|
||||
},
|
||||
},
|
||||
rootCoord: mockRootcoord,
|
||||
}
|
||||
|
||||
mockRootcoord.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocTimestampResponse{
|
||||
Status: merr.Success(),
|
||||
Timestamp: 99,
|
||||
}, nil).Once()
|
||||
err := task.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestInsertForReplicate(t *testing.T) {
|
||||
cache := globalMetaCache
|
||||
defer func() { globalMetaCache = cache }()
|
||||
mockCache := NewMockCache(t)
|
||||
globalMetaCache = mockCache
|
||||
|
||||
t.Run("get replicate id fail", func(t *testing.T) {
|
||||
mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("err")).Once()
|
||||
task := &insertTask{
|
||||
insertMsg: &msgstream.InsertMsg{
|
||||
InsertRequest: &msgpb.InsertRequest{
|
||||
CollectionName: "foo",
|
||||
},
|
||||
},
|
||||
}
|
||||
err := task.PreExecute(context.Background())
|
||||
assert.Error(t, err)
|
||||
})
|
||||
t.Run("insert with replicate id", func(t *testing.T) {
|
||||
mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{
|
||||
schema: &schemaInfo{
|
||||
CollectionSchema: &schemapb.CollectionSchema{
|
||||
Properties: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.ReplicateIDKey,
|
||||
Value: "local-mac",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
replicateID: "local-mac",
|
||||
}, nil).Once()
|
||||
task := &insertTask{
|
||||
insertMsg: &msgstream.InsertMsg{
|
||||
InsertRequest: &msgpb.InsertRequest{
|
||||
CollectionName: "foo",
|
||||
},
|
||||
},
|
||||
}
|
||||
err := task.PreExecute(context.Background())
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -292,6 +292,15 @@ func (it *upsertTask) PreExecute(ctx context.Context) error {
|
|||
Timestamp: it.EndTs(),
|
||||
}
|
||||
|
||||
replicateID, err := GetReplicateID(ctx, it.req.GetDbName(), collectionName)
|
||||
if err != nil {
|
||||
log.Warn("get replicate info failed", zap.String("collectionName", collectionName), zap.Error(err))
|
||||
return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound)
|
||||
}
|
||||
if replicateID != "" {
|
||||
return merr.WrapErrCollectionReplicateMode("upsert")
|
||||
}
|
||||
|
||||
schema, err := globalMetaCache.GetCollectionSchema(ctx, it.req.GetDbName(), collectionName)
|
||||
if err != nil {
|
||||
log.Warn("Failed to get collection schema",
|
||||
|
|
|
@ -19,6 +19,7 @@ import (
|
|||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
|
||||
|
@ -325,3 +326,37 @@ func TestUpsertTask(t *testing.T) {
|
|||
assert.ElementsMatch(t, channels, ut.pChannels)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpsertTaskForReplicate(t *testing.T) {
|
||||
cache := globalMetaCache
|
||||
defer func() { globalMetaCache = cache }()
|
||||
mockCache := NewMockCache(t)
|
||||
globalMetaCache = mockCache
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("fail to get collection info", func(t *testing.T) {
|
||||
ut := upsertTask{
|
||||
ctx: ctx,
|
||||
req: &milvuspb.UpsertRequest{
|
||||
CollectionName: "col-0",
|
||||
},
|
||||
}
|
||||
mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("foo")).Once()
|
||||
err := ut.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("replicate mode", func(t *testing.T) {
|
||||
ut := upsertTask{
|
||||
ctx: ctx,
|
||||
req: &milvuspb.UpsertRequest{
|
||||
CollectionName: "col-0",
|
||||
},
|
||||
}
|
||||
mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{
|
||||
replicateID: "local-mac",
|
||||
}, nil).Once()
|
||||
err := ut.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -2212,3 +2212,22 @@ func GetFailedResponse(req any, err error) any {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetReplicateID(ctx context.Context, database, collectionName string) (string, error) {
|
||||
if globalMetaCache == nil {
|
||||
return "", merr.WrapErrServiceUnavailable("internal: Milvus Proxy is not ready yet. please wait")
|
||||
}
|
||||
colInfo, err := globalMetaCache.GetCollectionInfo(ctx, database, collectionName, 0)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if colInfo.replicateID != "" {
|
||||
return colInfo.replicateID, nil
|
||||
}
|
||||
dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, database)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
replicateID, _ := common.GetReplicateID(dbInfo.properties)
|
||||
return replicateID, nil
|
||||
}
|
||||
|
|
|
@ -36,6 +36,7 @@ import (
|
|||
coordMocks "github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/checkers"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/dist"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
|
||||
|
@ -614,6 +615,7 @@ func (suite *ServerSuite) hackServer() {
|
|||
)
|
||||
|
||||
suite.broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{Schema: &schemapb.CollectionSchema{}}, nil).Maybe()
|
||||
suite.broker.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&rootcoordpb.DescribeDatabaseResponse{}, nil).Maybe()
|
||||
suite.broker.EXPECT().ListIndexes(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
||||
for _, collection := range suite.collections {
|
||||
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe()
|
||||
|
|
|
@ -56,9 +56,7 @@ var (
|
|||
)
|
||||
|
||||
func (s *Server) ShowCollections(ctx context.Context, req *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
|
||||
log := log.Ctx(ctx).With(zap.Int64s("collections", req.GetCollectionIDs()))
|
||||
|
||||
log.Info("show collections request received")
|
||||
log.Ctx(ctx).Debug("show collections request received", zap.Int64s("collections", req.GetCollectionIDs()))
|
||||
if err := merr.CheckHealthy(s.State()); err != nil {
|
||||
msg := "failed to show collections"
|
||||
log.Warn(msg, zap.Error(err))
|
||||
|
|
|
@ -341,18 +341,23 @@ func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error {
|
|||
|
||||
collectionInfo, err := ex.broker.DescribeCollection(ctx, task.CollectionID())
|
||||
if err != nil {
|
||||
log.Warn("failed to get collection info")
|
||||
log.Warn("failed to get collection info", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
loadFields := ex.meta.GetLoadFields(ctx, task.CollectionID())
|
||||
partitions, err := utils.GetPartitions(ctx, ex.targetMgr, task.CollectionID())
|
||||
if err != nil {
|
||||
log.Warn("failed to get partitions of collection")
|
||||
log.Warn("failed to get partitions of collection", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
indexInfo, err := ex.broker.ListIndexes(ctx, task.CollectionID())
|
||||
if err != nil {
|
||||
log.Warn("fail to get index meta of collection")
|
||||
log.Warn("fail to get index meta of collection", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
dbResp, err := ex.broker.DescribeDatabase(ctx, collectionInfo.GetDbName())
|
||||
if err != nil {
|
||||
log.Warn("failed to get database info", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
loadMeta := packLoadMeta(
|
||||
|
@ -363,6 +368,7 @@ func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error {
|
|||
loadFields,
|
||||
partitions...,
|
||||
)
|
||||
loadMeta.DbProperties = dbResp.GetProperties()
|
||||
|
||||
dmChannel := ex.targetMgr.GetDmChannel(ctx, task.CollectionID(), action.ChannelName(), meta.NextTarget)
|
||||
if dmChannel == nil {
|
||||
|
|
|
@ -38,6 +38,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/indexpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
|
||||
. "github.com/milvus-io/milvus/internal/querycoordv2/params"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/session"
|
||||
|
@ -230,6 +231,7 @@ func (suite *TaskSuite) TestSubscribeChannelTask() {
|
|||
},
|
||||
}, nil
|
||||
})
|
||||
suite.broker.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&rootcoordpb.DescribeDatabaseResponse{}, nil)
|
||||
for channel, segment := range suite.growingSegments {
|
||||
suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).
|
||||
Return([]*datapb.SegmentInfo{
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/json"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
|
||||
|
@ -46,7 +47,7 @@ func TestGetPipelineJSON(t *testing.T) {
|
|||
|
||||
collectionManager := segments.NewMockCollectionManager(t)
|
||||
segmentManager := segments.NewMockSegmentManager(t)
|
||||
collectionManager.EXPECT().Get(mock.Anything).Return(&segments.Collection{})
|
||||
collectionManager.EXPECT().Get(mock.Anything).Return(segments.NewTestCollection(1, querypb.LoadType_UnKnownType, &schemapb.CollectionSchema{}))
|
||||
manager := &segments.Manager{
|
||||
Collection: collectionManager,
|
||||
Segment: segmentManager,
|
||||
|
|
|
@ -72,7 +72,7 @@ func (suite *FilterNodeSuite) TestWithLoadCollection() {
|
|||
suite.validSegmentIDs = []int64{2, 3, 4, 5, 6}
|
||||
|
||||
// mock
|
||||
collection := segments.NewCollectionWithoutSchema(suite.collectionID, querypb.LoadType_LoadCollection)
|
||||
collection := segments.NewTestCollection(suite.collectionID, querypb.LoadType_LoadCollection, nil)
|
||||
for _, partitionID := range suite.partitionIDs {
|
||||
collection.AddPartition(partitionID)
|
||||
}
|
||||
|
@ -111,7 +111,7 @@ func (suite *FilterNodeSuite) TestWithLoadPartation() {
|
|||
suite.validSegmentIDs = []int64{2, 3, 4, 5, 6}
|
||||
|
||||
// mock
|
||||
collection := segments.NewCollectionWithoutSchema(suite.collectionID, querypb.LoadType_LoadPartition)
|
||||
collection := segments.NewTestCollection(suite.collectionID, querypb.LoadType_LoadPartition, nil)
|
||||
collection.AddPartition(suite.partitionIDs[0])
|
||||
|
||||
mockCollectionManager := segments.NewMockCollectionManager(suite.T())
|
||||
|
|
|
@ -85,7 +85,7 @@ func (m *manager) Add(collectionID UniqueID, channel string) (Pipeline, error) {
|
|||
return nil, merr.WrapErrChannelNotFound(channel, "delegator not found")
|
||||
}
|
||||
|
||||
newPipeLine, err := NewPipeLine(collectionID, channel, m.dataManager, m.dispatcher, delegator)
|
||||
newPipeLine, err := NewPipeLine(collection, channel, m.dataManager, m.dispatcher, delegator)
|
||||
if err != nil {
|
||||
return nil, merr.WrapErrServiceUnavailable(err.Error(), "failed to create new pipeline")
|
||||
}
|
||||
|
|
|
@ -24,9 +24,10 @@ import (
|
|||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/milvus-io/milvus/pkg/mq/common"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgdispatcher"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
|
@ -73,9 +74,9 @@ func (suite *PipelineManagerTestSuite) SetupTest() {
|
|||
func (suite *PipelineManagerTestSuite) TestBasic() {
|
||||
// init mock
|
||||
// mock collection manager
|
||||
suite.collectionManager.EXPECT().Get(suite.collectionID).Return(&segments.Collection{})
|
||||
suite.collectionManager.EXPECT().Get(suite.collectionID).Return(segments.NewTestCollection(suite.collectionID, querypb.LoadType_UnKnownType, &schemapb.CollectionSchema{}))
|
||||
// mock mq factory
|
||||
suite.msgDispatcher.EXPECT().Register(mock.Anything, suite.channel, mock.Anything, common.SubscriptionPositionUnknown).Return(suite.msgChan, nil)
|
||||
suite.msgDispatcher.EXPECT().Register(mock.Anything, mock.Anything).Return(suite.msgChan, nil)
|
||||
suite.msgDispatcher.EXPECT().Deregister(suite.channel)
|
||||
|
||||
// build manager
|
||||
|
|
|
@ -19,7 +19,9 @@ package pipeline
|
|||
import (
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
|
||||
base "github.com/milvus-io/milvus/internal/util/pipeline"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgdispatcher"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
|
@ -45,17 +47,23 @@ func (p *pipeline) Close() {
|
|||
}
|
||||
|
||||
func NewPipeLine(
|
||||
collectionID UniqueID,
|
||||
collection *Collection,
|
||||
channel string,
|
||||
manager *DataManager,
|
||||
dispatcher msgdispatcher.Client,
|
||||
delegator delegator.ShardDelegator,
|
||||
) (Pipeline, error) {
|
||||
collectionID := collection.ID()
|
||||
replicateID, _ := common.GetReplicateID(collection.Schema().GetProperties())
|
||||
if replicateID == "" {
|
||||
replicateID, _ = common.GetReplicateID(collection.GetDBProperties())
|
||||
}
|
||||
replicateConfig := msgstream.GetReplicateConfig(replicateID, collection.GetDBName(), collection.Schema().Name)
|
||||
pipelineQueueLength := paramtable.Get().QueryNodeCfg.FlowGraphMaxQueueLength.GetAsInt32()
|
||||
|
||||
p := &pipeline{
|
||||
collectionID: collectionID,
|
||||
StreamPipeline: base.NewPipelineWithStream(dispatcher, nodeCtxTtInterval, enableTtChecker, channel),
|
||||
StreamPipeline: base.NewPipelineWithStream(dispatcher, nodeCtxTtInterval, enableTtChecker, channel, replicateConfig),
|
||||
}
|
||||
|
||||
filterNode := newFilterNode(collectionID, channel, manager, delegator, pipelineQueueLength)
|
||||
|
|
|
@ -24,13 +24,14 @@ import (
|
|||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/mocks/util/mock_segcore"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/milvus-io/milvus/pkg/mq/common"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgdispatcher"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
|
@ -103,11 +104,17 @@ func (suite *PipelineTestSuite) TestBasic() {
|
|||
schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true)
|
||||
collection := segments.NewCollection(suite.collectionID, schema, mock_segcore.GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{
|
||||
LoadType: querypb.LoadType_LoadCollection,
|
||||
DbProperties: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.ReplicateIDKey,
|
||||
Value: "local-test",
|
||||
},
|
||||
},
|
||||
})
|
||||
suite.collectionManager.EXPECT().Get(suite.collectionID).Return(collection)
|
||||
|
||||
// mock mq factory
|
||||
suite.msgDispatcher.EXPECT().Register(mock.Anything, suite.channel, mock.Anything, common.SubscriptionPositionUnknown).Return(suite.msgChan, nil)
|
||||
suite.msgDispatcher.EXPECT().Register(mock.Anything, mock.Anything).Return(suite.msgChan, nil)
|
||||
suite.msgDispatcher.EXPECT().Deregister(suite.channel)
|
||||
|
||||
// mock delegator
|
||||
|
@ -136,16 +143,16 @@ func (suite *PipelineTestSuite) TestBasic() {
|
|||
Collection: suite.collectionManager,
|
||||
Segment: suite.segmentManager,
|
||||
}
|
||||
pipeline, err := NewPipeLine(suite.collectionID, suite.channel, manager, suite.msgDispatcher, suite.delegator)
|
||||
pipelineObj, err := NewPipeLine(collection, suite.channel, manager, suite.msgDispatcher, suite.delegator)
|
||||
suite.NoError(err)
|
||||
|
||||
// Init Consumer
|
||||
err = pipeline.ConsumeMsgStream(context.Background(), &msgpb.MsgPosition{})
|
||||
err = pipelineObj.ConsumeMsgStream(context.Background(), &msgpb.MsgPosition{})
|
||||
suite.NoError(err)
|
||||
|
||||
err = pipeline.Start()
|
||||
err = pipelineObj.Start()
|
||||
suite.NoError(err)
|
||||
defer pipeline.Close()
|
||||
defer pipelineObj.Close()
|
||||
|
||||
// build input msg
|
||||
in := suite.buildMsgPack(schema)
|
||||
|
|
|
@ -148,6 +148,7 @@ type Collection struct {
|
|||
partitions *typeutil.ConcurrentSet[int64]
|
||||
loadType querypb.LoadType
|
||||
dbName string
|
||||
dbProperties []*commonpb.KeyValuePair
|
||||
resourceGroup string
|
||||
// resource group of node may be changed if node transfer,
|
||||
// but Collection in Manager will be released before assign new replica of new resource group on these node.
|
||||
|
@ -166,6 +167,10 @@ func (c *Collection) GetDBName() string {
|
|||
return c.dbName
|
||||
}
|
||||
|
||||
func (c *Collection) GetDBProperties() []*commonpb.KeyValuePair {
|
||||
return c.dbProperties
|
||||
}
|
||||
|
||||
// GetResourceGroup returns the resource group of collection.
|
||||
func (c *Collection) GetResourceGroup() string {
|
||||
return c.resourceGroup
|
||||
|
@ -284,6 +289,7 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM
|
|||
partitions: typeutil.NewConcurrentSet[int64](),
|
||||
loadType: loadMetaInfo.GetLoadType(),
|
||||
dbName: loadMetaInfo.GetDbName(),
|
||||
dbProperties: loadMetaInfo.GetDbProperties(),
|
||||
resourceGroup: loadMetaInfo.GetResourceGroup(),
|
||||
refCount: atomic.NewUint32(0),
|
||||
isGpuIndex: isGpuIndex,
|
||||
|
@ -297,13 +303,16 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM
|
|||
return coll
|
||||
}
|
||||
|
||||
func NewCollectionWithoutSchema(collectionID int64, loadType querypb.LoadType) *Collection {
|
||||
return &Collection{
|
||||
// Only for test
|
||||
func NewTestCollection(collectionID int64, loadType querypb.LoadType, schema *schemapb.CollectionSchema) *Collection {
|
||||
col := &Collection{
|
||||
id: collectionID,
|
||||
partitions: typeutil.NewConcurrentSet[int64](),
|
||||
loadType: loadType,
|
||||
refCount: atomic.NewUint32(0),
|
||||
}
|
||||
col.schema.Store(schema)
|
||||
return col
|
||||
}
|
||||
|
||||
// new collection without segcore prepare
|
||||
|
|
|
@ -26,11 +26,13 @@ import (
|
|||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
"github.com/milvus-io/milvus/internal/metastore/model"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/util/proxyutil"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
|
@ -130,6 +132,43 @@ func (a *alterCollectionTask) Execute(ctx context.Context) error {
|
|||
}))
|
||||
}
|
||||
|
||||
oldReplicateEnable, _ := common.IsReplicateEnabled(oldColl.Properties)
|
||||
replicateEnable, ok := common.IsReplicateEnabled(newColl.Properties)
|
||||
if ok && !replicateEnable && oldReplicateEnable {
|
||||
replicateID, _ := common.GetReplicateID(oldColl.Properties)
|
||||
redoTask.AddAsyncStep(NewSimpleStep("send replicate end msg for collection", func(ctx context.Context) ([]nestedStep, error) {
|
||||
msgPack := &msgstream.MsgPack{}
|
||||
msg := &msgstream.ReplicateMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
Ctx: ctx,
|
||||
BeginTimestamp: ts,
|
||||
EndTimestamp: ts,
|
||||
HashValues: []uint32{0},
|
||||
},
|
||||
ReplicateMsg: &msgpb.ReplicateMsg{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Replicate,
|
||||
Timestamp: ts,
|
||||
ReplicateInfo: &commonpb.ReplicateInfo{
|
||||
IsReplicate: true,
|
||||
ReplicateID: replicateID,
|
||||
},
|
||||
},
|
||||
IsEnd: true,
|
||||
Database: newColl.DBName,
|
||||
Collection: newColl.Name,
|
||||
},
|
||||
}
|
||||
msgPack.Msgs = append(msgPack.Msgs, msg)
|
||||
log.Info("send replicate end msg",
|
||||
zap.String("collection", newColl.Name),
|
||||
zap.String("database", newColl.DBName),
|
||||
zap.String("replicateID", replicateID),
|
||||
)
|
||||
return nil, a.core.chanTimeTick.broadcastDmlChannels(newColl.PhysicalChannelNames, msgPack)
|
||||
}))
|
||||
}
|
||||
|
||||
return redoTask.Execute(ctx)
|
||||
}
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ package rootcoord
|
|||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -29,6 +30,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/metastore/model"
|
||||
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
)
|
||||
|
||||
func Test_alterCollectionTask_Prepare(t *testing.T) {
|
||||
|
@ -217,14 +219,25 @@ func Test_alterCollectionTask_Execute(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("alter successfully", func(t *testing.T) {
|
||||
t.Run("alter successfully2", func(t *testing.T) {
|
||||
meta := mockrootcoord.NewIMetaTable(t)
|
||||
meta.On("GetCollectionByName",
|
||||
mock.Anything,
|
||||
mock.Anything,
|
||||
mock.Anything,
|
||||
mock.Anything,
|
||||
).Return(&model.Collection{CollectionID: int64(1)}, nil)
|
||||
).Return(&model.Collection{
|
||||
CollectionID: int64(1),
|
||||
Name: "cn",
|
||||
DBName: "foo",
|
||||
Properties: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.ReplicateIDKey,
|
||||
Value: "local-test",
|
||||
},
|
||||
},
|
||||
PhysicalChannelNames: []string{"by-dev-rootcoord-dml_1"},
|
||||
}, nil)
|
||||
meta.On("AlterCollection",
|
||||
mock.Anything,
|
||||
mock.Anything,
|
||||
|
@ -237,19 +250,37 @@ func Test_alterCollectionTask_Execute(t *testing.T) {
|
|||
broker.BroadcastAlteredCollectionFunc = func(ctx context.Context, req *milvuspb.AlterCollectionRequest) error {
|
||||
return nil
|
||||
}
|
||||
packChan := make(chan *msgstream.MsgPack, 10)
|
||||
ticker := newChanTimeTickSync(packChan)
|
||||
ticker.addDmlChannels("by-dev-rootcoord-dml_1")
|
||||
|
||||
core := newTestCore(withValidProxyManager(), withMeta(meta), withBroker(broker))
|
||||
core := newTestCore(withValidProxyManager(), withMeta(meta), withBroker(broker), withTtSynchronizer(ticker))
|
||||
newPros := append(properties, &commonpb.KeyValuePair{
|
||||
Key: common.ReplicateEndTSKey,
|
||||
Value: "10000",
|
||||
})
|
||||
task := &alterCollectionTask{
|
||||
baseTask: newBaseTask(context.Background(), core),
|
||||
Req: &milvuspb.AlterCollectionRequest{
|
||||
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollection},
|
||||
CollectionName: "cn",
|
||||
Properties: properties,
|
||||
Properties: newPros,
|
||||
},
|
||||
}
|
||||
|
||||
err := task.Execute(context.Background())
|
||||
assert.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
select {
|
||||
case pack := <-packChan:
|
||||
assert.Equal(t, commonpb.MsgType_Replicate, pack.Msgs[0].Type())
|
||||
replicateMsg := pack.Msgs[0].(*msgstream.ReplicateMsg)
|
||||
assert.Equal(t, "foo", replicateMsg.ReplicateMsg.GetDatabase())
|
||||
assert.Equal(t, "cn", replicateMsg.ReplicateMsg.GetCollection())
|
||||
assert.True(t, replicateMsg.ReplicateMsg.GetIsEnd())
|
||||
default:
|
||||
assert.Fail(t, "no message sent")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("test update collection props", func(t *testing.T) {
|
||||
|
|
|
@ -25,11 +25,14 @@ import (
|
|||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
"github.com/milvus-io/milvus/internal/metastore/model"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
|
||||
"github.com/milvus-io/milvus/internal/util/proxyutil"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
|
@ -43,6 +46,19 @@ func (a *alterDatabaseTask) Prepare(ctx context.Context) error {
|
|||
return fmt.Errorf("alter database failed, database name does not exists")
|
||||
}
|
||||
|
||||
// TODO SimFG maybe it will support to alter the replica.id properties in the future when the database has no collections
|
||||
// now it can't be because the latest database properties can't be notified to the querycoord and datacoord
|
||||
replicateID, _ := common.GetReplicateID(a.Req.Properties)
|
||||
if replicateID != "" {
|
||||
colls, err := a.core.meta.ListCollections(ctx, a.Req.DbName, a.ts, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(colls) > 0 {
|
||||
return errors.New("can't set replicate id on database with collections")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -85,6 +101,18 @@ func (a *alterDatabaseTask) Execute(ctx context.Context) error {
|
|||
ts: ts,
|
||||
})
|
||||
|
||||
redoTask.AddSyncStep(&expireCacheStep{
|
||||
baseStep: baseStep{core: a.core},
|
||||
dbName: newDB.Name,
|
||||
ts: ts,
|
||||
// make sure to send the "expire cache" request
|
||||
// because it won't send this request when the length of collection names array is zero
|
||||
collectionNames: []string{""},
|
||||
opts: []proxyutil.ExpireCacheOpt{
|
||||
proxyutil.SetMsgType(commonpb.MsgType_AlterDatabase),
|
||||
},
|
||||
})
|
||||
|
||||
oldReplicaNumber, _ := common.DatabaseLevelReplicaNumber(oldDB.Properties)
|
||||
oldResourceGroups, _ := common.DatabaseLevelResourceGroups(oldDB.Properties)
|
||||
newReplicaNumber, _ := common.DatabaseLevelReplicaNumber(newDB.Properties)
|
||||
|
@ -123,6 +151,39 @@ func (a *alterDatabaseTask) Execute(ctx context.Context) error {
|
|||
}))
|
||||
}
|
||||
|
||||
oldReplicateEnable, _ := common.IsReplicateEnabled(oldDB.Properties)
|
||||
newReplicateEnable, ok := common.IsReplicateEnabled(newDB.Properties)
|
||||
if ok && !newReplicateEnable && oldReplicateEnable {
|
||||
replicateID, _ := common.GetReplicateID(oldDB.Properties)
|
||||
redoTask.AddAsyncStep(NewSimpleStep("send replicate end msg for db", func(ctx context.Context) ([]nestedStep, error) {
|
||||
msgPack := &msgstream.MsgPack{}
|
||||
msg := &msgstream.ReplicateMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
Ctx: ctx,
|
||||
BeginTimestamp: ts,
|
||||
EndTimestamp: ts,
|
||||
HashValues: []uint32{0},
|
||||
},
|
||||
ReplicateMsg: &msgpb.ReplicateMsg{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Replicate,
|
||||
Timestamp: ts,
|
||||
ReplicateInfo: &commonpb.ReplicateInfo{
|
||||
IsReplicate: true,
|
||||
ReplicateID: replicateID,
|
||||
},
|
||||
},
|
||||
IsEnd: true,
|
||||
Database: newDB.Name,
|
||||
Collection: "",
|
||||
},
|
||||
}
|
||||
msgPack.Msgs = append(msgPack.Msgs, msg)
|
||||
log.Info("send replicate end msg for db", zap.String("db", newDB.Name), zap.String("replicateID", replicateID))
|
||||
return nil, a.core.chanTimeTick.broadcastDmlChannels(a.core.chanTimeTick.listDmlChannels(), msgPack)
|
||||
}))
|
||||
}
|
||||
|
||||
return redoTask.Execute(ctx)
|
||||
}
|
||||
|
||||
|
@ -134,6 +195,14 @@ func (a *alterDatabaseTask) GetLockerKey() LockerKey {
|
|||
}
|
||||
|
||||
func MergeProperties(oldProps []*commonpb.KeyValuePair, updatedProps []*commonpb.KeyValuePair) []*commonpb.KeyValuePair {
|
||||
_, existEndTS := common.GetReplicateEndTS(updatedProps)
|
||||
if existEndTS {
|
||||
updatedProps = append(updatedProps, &commonpb.KeyValuePair{
|
||||
Key: common.ReplicateIDKey,
|
||||
Value: "",
|
||||
})
|
||||
}
|
||||
|
||||
props := make(map[string]string)
|
||||
for _, prop := range oldProps {
|
||||
props[prop.Key] = prop.Value
|
||||
|
|
|
@ -19,6 +19,7 @@ package rootcoord
|
|||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -29,6 +30,8 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
|
||||
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
)
|
||||
|
||||
func Test_alterDatabaseTask_Prepare(t *testing.T) {
|
||||
|
@ -47,6 +50,76 @@ func Test_alterDatabaseTask_Prepare(t *testing.T) {
|
|||
err := task.Prepare(context.Background())
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("replicate id", func(t *testing.T) {
|
||||
{
|
||||
// no collections
|
||||
meta := mockrootcoord.NewIMetaTable(t)
|
||||
core := newTestCore(withMeta(meta))
|
||||
meta.EXPECT().
|
||||
ListCollections(mock.Anything, mock.Anything, mock.Anything, mock.Anything).
|
||||
Return([]*model.Collection{}, nil).
|
||||
Once()
|
||||
task := &alterDatabaseTask{
|
||||
baseTask: newBaseTask(context.Background(), core),
|
||||
Req: &rootcoordpb.AlterDatabaseRequest{
|
||||
DbName: "cn",
|
||||
Properties: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.ReplicateIDKey,
|
||||
Value: "local-test",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := task.Prepare(context.Background())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
{
|
||||
meta := mockrootcoord.NewIMetaTable(t)
|
||||
core := newTestCore(withMeta(meta))
|
||||
meta.EXPECT().ListCollections(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]*model.Collection{
|
||||
{
|
||||
Name: "foo",
|
||||
},
|
||||
}, nil).Once()
|
||||
task := &alterDatabaseTask{
|
||||
baseTask: newBaseTask(context.Background(), core),
|
||||
Req: &rootcoordpb.AlterDatabaseRequest{
|
||||
DbName: "cn",
|
||||
Properties: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.ReplicateIDKey,
|
||||
Value: "local-test",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := task.Prepare(context.Background())
|
||||
assert.Error(t, err)
|
||||
}
|
||||
{
|
||||
meta := mockrootcoord.NewIMetaTable(t)
|
||||
core := newTestCore(withMeta(meta))
|
||||
meta.EXPECT().ListCollections(mock.Anything, mock.Anything, mock.Anything, mock.Anything).
|
||||
Return(nil, errors.New("err")).
|
||||
Once()
|
||||
task := &alterDatabaseTask{
|
||||
baseTask: newBaseTask(context.Background(), core),
|
||||
Req: &rootcoordpb.AlterDatabaseRequest{
|
||||
DbName: "cn",
|
||||
Properties: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.ReplicateIDKey,
|
||||
Value: "local-test",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := task.Prepare(context.Background())
|
||||
assert.Error(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_alterDatabaseTask_Execute(t *testing.T) {
|
||||
|
@ -146,25 +219,51 @@ func Test_alterDatabaseTask_Execute(t *testing.T) {
|
|||
mock.Anything,
|
||||
mock.Anything,
|
||||
mock.Anything,
|
||||
).Return(&model.Database{ID: int64(1)}, nil)
|
||||
).Return(&model.Database{
|
||||
ID: int64(1),
|
||||
Name: "cn",
|
||||
Properties: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.ReplicateIDKey,
|
||||
Value: "local-test",
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
meta.On("AlterDatabase",
|
||||
mock.Anything,
|
||||
mock.Anything,
|
||||
mock.Anything,
|
||||
mock.Anything,
|
||||
).Return(nil)
|
||||
// the chan length should larger than 4, because newChanTimeTickSync will send 4 ts messages when execute the `broadcast` step
|
||||
packChan := make(chan *msgstream.MsgPack, 10)
|
||||
ticker := newChanTimeTickSync(packChan)
|
||||
ticker.addDmlChannels("by-dev-rootcoord-dml_1")
|
||||
|
||||
core := newTestCore(withMeta(meta))
|
||||
core := newTestCore(withMeta(meta), withValidProxyManager(), withTtSynchronizer(ticker))
|
||||
newPros := append(properties,
|
||||
&commonpb.KeyValuePair{Key: common.ReplicateEndTSKey, Value: "1000"},
|
||||
)
|
||||
task := &alterDatabaseTask{
|
||||
baseTask: newBaseTask(context.Background(), core),
|
||||
Req: &rootcoordpb.AlterDatabaseRequest{
|
||||
DbName: "cn",
|
||||
Properties: properties,
|
||||
Properties: newPros,
|
||||
},
|
||||
}
|
||||
|
||||
err := task.Execute(context.Background())
|
||||
assert.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
select {
|
||||
case pack := <-packChan:
|
||||
assert.Equal(t, commonpb.MsgType_Replicate, pack.Msgs[0].Type())
|
||||
replicateMsg := pack.Msgs[0].(*msgstream.ReplicateMsg)
|
||||
assert.Equal(t, "cn", replicateMsg.ReplicateMsg.GetDatabase())
|
||||
assert.True(t, replicateMsg.ReplicateMsg.GetIsEnd())
|
||||
default:
|
||||
assert.Fail(t, "no message sent")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("test update collection props", func(t *testing.T) {
|
||||
|
@ -248,3 +347,26 @@ func Test_alterDatabaseTask_Execute(t *testing.T) {
|
|||
assert.Empty(t, ret2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMergeProperties(t *testing.T) {
|
||||
p := MergeProperties([]*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.ReplicateIDKey,
|
||||
Value: "local-test",
|
||||
},
|
||||
{
|
||||
Key: "foo",
|
||||
Value: "xxx",
|
||||
},
|
||||
}, []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.ReplicateEndTSKey,
|
||||
Value: "1001",
|
||||
},
|
||||
})
|
||||
assert.Len(t, p, 3)
|
||||
m := funcutil.KeyValuePair2Map(p)
|
||||
assert.Equal(t, "", m[common.ReplicateIDKey])
|
||||
assert.Equal(t, "1001", m[common.ReplicateEndTSKey])
|
||||
assert.Equal(t, "xxx", m["foo"])
|
||||
}
|
||||
|
|
|
@ -43,6 +43,7 @@ type watchInfo struct {
|
|||
vChannels []string
|
||||
startPositions []*commonpb.KeyDataPair
|
||||
schema *schemapb.CollectionSchema
|
||||
dbProperties []*commonpb.KeyValuePair
|
||||
}
|
||||
|
||||
// Broker communicates with other components.
|
||||
|
@ -165,6 +166,7 @@ func (b *ServerBroker) WatchChannels(ctx context.Context, info *watchInfo) error
|
|||
StartPositions: info.startPositions,
|
||||
Schema: info.schema,
|
||||
CreateTimestamp: info.ts,
|
||||
DbProperties: info.dbProperties,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -61,6 +61,7 @@ type createCollectionTask struct {
|
|||
channels collectionChannels
|
||||
dbID UniqueID
|
||||
partitionNames []string
|
||||
dbProperties []*commonpb.KeyValuePair
|
||||
}
|
||||
|
||||
func (t *createCollectionTask) validate(ctx context.Context) error {
|
||||
|
@ -424,6 +425,18 @@ func (t *createCollectionTask) Prepare(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
t.dbID = db.ID
|
||||
dbReplicateID, _ := common.GetReplicateID(db.Properties)
|
||||
if dbReplicateID != "" {
|
||||
reqProperties := make([]*commonpb.KeyValuePair, 0, len(t.Req.Properties))
|
||||
for _, prop := range t.Req.Properties {
|
||||
if prop.Key == common.ReplicateIDKey {
|
||||
continue
|
||||
}
|
||||
reqProperties = append(reqProperties, prop)
|
||||
}
|
||||
t.Req.Properties = reqProperties
|
||||
}
|
||||
t.dbProperties = db.Properties
|
||||
|
||||
if err := t.validate(ctx); err != nil {
|
||||
return err
|
||||
|
@ -565,6 +578,7 @@ func (t *createCollectionTask) Execute(ctx context.Context) error {
|
|||
CollectionID: collID,
|
||||
DBID: t.dbID,
|
||||
Name: t.schema.Name,
|
||||
DBName: t.Req.GetDbName(),
|
||||
Description: t.schema.Description,
|
||||
AutoID: t.schema.AutoID,
|
||||
Fields: model.UnmarshalFieldModels(t.schema.Fields),
|
||||
|
@ -644,11 +658,14 @@ func (t *createCollectionTask) Execute(ctx context.Context) error {
|
|||
startPositions: toKeyDataPairs(startPositions),
|
||||
schema: &schemapb.CollectionSchema{
|
||||
Name: collInfo.Name,
|
||||
DbName: collInfo.DBName,
|
||||
Description: collInfo.Description,
|
||||
AutoID: collInfo.AutoID,
|
||||
Fields: model.MarshalFieldModels(collInfo.Fields),
|
||||
Properties: collInfo.Properties,
|
||||
Functions: model.MarshalFunctionModels(collInfo.Functions),
|
||||
},
|
||||
dbProperties: t.dbProperties,
|
||||
},
|
||||
}, &nullStep{})
|
||||
undoTask.AddStep(&changeCollectionStateStep{
|
||||
|
|
|
@ -823,6 +823,70 @@ func Test_createCollectionTask_Prepare(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestCreateCollectionTask_Prepare_WithProperty(t *testing.T) {
|
||||
paramtable.Init()
|
||||
meta := mockrootcoord.NewIMetaTable(t)
|
||||
t.Run("with db properties", func(t *testing.T) {
|
||||
meta.EXPECT().GetDatabaseByName(mock.Anything, mock.Anything, mock.Anything).Return(&model.Database{
|
||||
Name: "foo",
|
||||
ID: 1,
|
||||
Properties: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.ReplicateIDKey,
|
||||
Value: "local-test",
|
||||
},
|
||||
},
|
||||
}, nil).Twice()
|
||||
meta.EXPECT().ListAllAvailCollections(mock.Anything).Return(map[int64][]int64{
|
||||
util.DefaultDBID: {1, 2},
|
||||
}).Once()
|
||||
meta.EXPECT().GetGeneralCount(mock.Anything).Return(0).Once()
|
||||
|
||||
defer cleanTestEnv()
|
||||
|
||||
collectionName := funcutil.GenRandomStr()
|
||||
field1 := funcutil.GenRandomStr()
|
||||
|
||||
ticker := newRocksMqTtSynchronizer()
|
||||
core := newTestCore(withValidIDAllocator(), withTtSynchronizer(ticker), withMeta(meta))
|
||||
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
Description: "",
|
||||
AutoID: false,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
Name: field1,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
},
|
||||
}
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
assert.NoError(t, err)
|
||||
|
||||
task := createCollectionTask{
|
||||
baseTask: newBaseTask(context.Background(), core),
|
||||
Req: &milvuspb.CreateCollectionRequest{
|
||||
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection},
|
||||
CollectionName: collectionName,
|
||||
Schema: marshaledSchema,
|
||||
Properties: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.ReplicateIDKey,
|
||||
Value: "hoo",
|
||||
},
|
||||
},
|
||||
},
|
||||
dbID: 1,
|
||||
}
|
||||
task.Req.ShardsNum = common.DefaultShardsNum
|
||||
err = task.Prepare(context.Background())
|
||||
assert.Len(t, task.dbProperties, 1)
|
||||
assert.Len(t, task.Req.Properties, 0)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_createCollectionTask_Execute(t *testing.T) {
|
||||
t.Run("add same collection with different parameters", func(t *testing.T) {
|
||||
defer cleanTestEnv()
|
||||
|
|
|
@ -195,6 +195,9 @@ func (mt *MetaTable) reload() error {
|
|||
return err
|
||||
}
|
||||
for _, collection := range collections {
|
||||
if collection.DBName == "" {
|
||||
collection.DBName = dbName
|
||||
}
|
||||
mt.collID2Meta[collection.CollectionID] = collection
|
||||
mt.generalCnt += len(collection.Partitions) * int(collection.ShardsNum)
|
||||
if collection.Available() {
|
||||
|
@ -559,12 +562,14 @@ func filterUnavailable(coll *model.Collection) *model.Collection {
|
|||
func (mt *MetaTable) getLatestCollectionByIDInternal(ctx context.Context, collectionID UniqueID, allowUnavailable bool) (*model.Collection, error) {
|
||||
coll, ok := mt.collID2Meta[collectionID]
|
||||
if !ok || coll == nil {
|
||||
log.Warn("not found collection", zap.Int64("collectionID", collectionID))
|
||||
return nil, merr.WrapErrCollectionNotFound(collectionID)
|
||||
}
|
||||
if allowUnavailable {
|
||||
return coll.Clone(), nil
|
||||
}
|
||||
if !coll.Available() {
|
||||
log.Warn("collection not available", zap.Int64("collectionID", collectionID), zap.Any("state", coll.State))
|
||||
return nil, merr.WrapErrCollectionNotFound(collectionID)
|
||||
}
|
||||
return filterUnavailable(coll), nil
|
||||
|
|
|
@ -1058,6 +1058,31 @@ func newTickerWithFactory(factory msgstream.Factory) *timetickSync {
|
|||
return ticker
|
||||
}
|
||||
|
||||
func newChanTimeTickSync(packChan chan *msgstream.MsgPack) *timetickSync {
|
||||
f := msgstream.NewMockMqFactory()
|
||||
f.NewMsgStreamFunc = func(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
stream := msgstream.NewWastedMockMsgStream()
|
||||
stream.BroadcastFunc = func(pack *msgstream.MsgPack) error {
|
||||
log.Info("mock Broadcast")
|
||||
packChan <- pack
|
||||
return nil
|
||||
}
|
||||
stream.BroadcastMarkFunc = func(pack *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
|
||||
log.Info("mock BroadcastMark")
|
||||
packChan <- pack
|
||||
return map[string][]msgstream.MessageID{}, nil
|
||||
}
|
||||
stream.AsProducerFunc = func(channels []string) {
|
||||
}
|
||||
stream.ChanFunc = func() <-chan *msgstream.MsgPack {
|
||||
return packChan
|
||||
}
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
return newTickerWithFactory(f)
|
||||
}
|
||||
|
||||
type mockDdlTsLockManager struct {
|
||||
DdlTsLockManager
|
||||
GetMinDdlTsFunc func() Timestamp
|
||||
|
|
|
@ -1226,6 +1226,7 @@ func convertModelToDesc(collInfo *model.Collection, aliases []string, dbName str
|
|||
Fields: model.MarshalFieldModels(collInfo.Fields),
|
||||
Functions: model.MarshalFunctionModels(collInfo.Functions),
|
||||
EnableDynamicField: collInfo.EnableDynamicField,
|
||||
Properties: collInfo.Properties,
|
||||
}
|
||||
resp.CollectionID = collInfo.CollectionID
|
||||
resp.VirtualChannelNames = collInfo.VirtualChannelNames
|
||||
|
@ -1745,6 +1746,19 @@ func (c *Core) AllocTimestamp(ctx context.Context, in *rootcoordpb.AllocTimestam
|
|||
}, nil
|
||||
}
|
||||
|
||||
if in.BlockTimestamp > 0 {
|
||||
blockTime, _ := tsoutil.ParseTS(in.BlockTimestamp)
|
||||
lastTime := c.tsoAllocator.GetLastSavedTime()
|
||||
deltaDuration := blockTime.Sub(lastTime)
|
||||
if deltaDuration > 0 {
|
||||
log.Info("wait for block timestamp",
|
||||
zap.Time("blockTime", blockTime),
|
||||
zap.Time("lastTime", lastTime),
|
||||
zap.Duration("delta", deltaDuration))
|
||||
time.Sleep(deltaDuration + time.Millisecond*200)
|
||||
}
|
||||
}
|
||||
|
||||
ts, err := c.tsoAllocator.GenerateTSO(in.GetCount())
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error("failed to allocate timestamp", zap.String("role", typeutil.RootCoordRole),
|
||||
|
|
|
@ -856,6 +856,32 @@ func TestRootCoord_AllocTimestamp(t *testing.T) {
|
|||
assert.Equal(t, ts-uint64(count)+1, resp.GetTimestamp())
|
||||
assert.Equal(t, count, resp.GetCount())
|
||||
})
|
||||
|
||||
t.Run("block timestamp", func(t *testing.T) {
|
||||
alloc := newMockTsoAllocator()
|
||||
count := uint32(10)
|
||||
current := time.Now()
|
||||
ts := tsoutil.ComposeTSByTime(current.Add(time.Second), 1)
|
||||
alloc.GenerateTSOF = func(count uint32) (uint64, error) {
|
||||
// end ts
|
||||
return ts, nil
|
||||
}
|
||||
alloc.GetLastSavedTimeF = func() time.Time {
|
||||
return current
|
||||
}
|
||||
ctx := context.Background()
|
||||
c := newTestCore(withHealthyCode(),
|
||||
withTsoAllocator(alloc))
|
||||
resp, err := c.AllocTimestamp(ctx, &rootcoordpb.AllocTimestampRequest{
|
||||
Count: count,
|
||||
BlockTimestamp: tsoutil.ComposeTSByTime(current.Add(time.Second), 0),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
|
||||
// begin ts
|
||||
assert.Equal(t, ts-uint64(count)+1, resp.GetTimestamp())
|
||||
assert.Equal(t, count, resp.GetCount())
|
||||
})
|
||||
}
|
||||
|
||||
func TestRootCoord_AllocID(t *testing.T) {
|
||||
|
|
|
@ -18,6 +18,7 @@ package rootcoord
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
@ -173,7 +174,6 @@ func NewCollectionLockerKey(collection string, rw bool) LockerKey {
|
|||
}
|
||||
|
||||
func NewLockerKeyChain(lockerKeys ...LockerKey) LockerKey {
|
||||
log.Info("NewLockerKeyChain", zap.Any("lockerKeys", len(lockerKeys)))
|
||||
if len(lockerKeys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
@ -191,3 +191,16 @@ func NewLockerKeyChain(lockerKeys ...LockerKey) LockerKey {
|
|||
}
|
||||
return lockerKeys[0]
|
||||
}
|
||||
|
||||
func GetLockerKeyString(k LockerKey) string {
|
||||
if k == nil {
|
||||
return "nil"
|
||||
}
|
||||
key := k.LockKey()
|
||||
level := k.Level()
|
||||
wLock := k.IsWLock()
|
||||
if k.Next() == nil {
|
||||
return fmt.Sprintf("%s-%d-%t", key, level, wLock)
|
||||
}
|
||||
return fmt.Sprintf("%s-%d-%t|%s", key, level, wLock, GetLockerKeyString(k.Next()))
|
||||
}
|
||||
|
|
|
@ -20,7 +20,6 @@ package rootcoord
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
@ -72,16 +71,6 @@ func TestLockerKey(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func GetLockerKeyString(k LockerKey) string {
|
||||
key := k.LockKey()
|
||||
level := k.Level()
|
||||
wLock := k.IsWLock()
|
||||
if k.Next() == nil {
|
||||
return fmt.Sprintf("%s-%d-%t", key, level, wLock)
|
||||
}
|
||||
return fmt.Sprintf("%s-%d-%t|%s", key, level, wLock, GetLockerKeyString(k.Next()))
|
||||
}
|
||||
|
||||
func TestGetLockerKey(t *testing.T) {
|
||||
t.Run("alter alias task locker key", func(t *testing.T) {
|
||||
tt := &alterAliasTask{
|
||||
|
|
|
@ -116,6 +116,7 @@ func (c *channelLifetime) Run() error {
|
|||
|
||||
// Build and add pipeline.
|
||||
ds, err := pipeline.NewStreamingNodeDataSyncService(ctx, c.f.pipelineParams,
|
||||
// TODO fubang add the db properties
|
||||
&datapb.ChannelWatchInfo{Vchan: resp.GetInfo(), Schema: resp.GetSchema()}, handler.Chan(), func(t syncmgr.Task, err error) {
|
||||
if err != nil || t == nil {
|
||||
return
|
||||
|
|
|
@ -45,12 +45,13 @@ type StreamPipeline interface {
|
|||
}
|
||||
|
||||
type streamPipeline struct {
|
||||
pipeline *pipeline
|
||||
input <-chan *msgstream.MsgPack
|
||||
scanner streaming.Scanner
|
||||
dispatcher msgdispatcher.Client
|
||||
startOnce sync.Once
|
||||
vChannel string
|
||||
pipeline *pipeline
|
||||
input <-chan *msgstream.MsgPack
|
||||
scanner streaming.Scanner
|
||||
dispatcher msgdispatcher.Client
|
||||
startOnce sync.Once
|
||||
vChannel string
|
||||
replicateConfig *msgstream.ReplicateConfig
|
||||
|
||||
closeCh chan struct{} // notify work to exit
|
||||
closeWg sync.WaitGroup
|
||||
|
@ -118,7 +119,12 @@ func (p *streamPipeline) ConsumeMsgStream(ctx context.Context, position *msgpb.M
|
|||
}
|
||||
|
||||
start := time.Now()
|
||||
p.input, err = p.dispatcher.Register(ctx, p.vChannel, position, common.SubscriptionPositionUnknown)
|
||||
p.input, err = p.dispatcher.Register(ctx, &msgdispatcher.StreamConfig{
|
||||
VChannel: p.vChannel,
|
||||
Pos: position,
|
||||
SubPos: common.SubscriptionPositionUnknown,
|
||||
ReplicateConfig: p.replicateConfig,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error("dispatcher register failed", zap.String("channel", position.ChannelName))
|
||||
return WrapErrRegDispather(err)
|
||||
|
@ -160,18 +166,24 @@ func (p *streamPipeline) Close() {
|
|||
})
|
||||
}
|
||||
|
||||
func NewPipelineWithStream(dispatcher msgdispatcher.Client, nodeTtInterval time.Duration, enableTtChecker bool, vChannel string) StreamPipeline {
|
||||
func NewPipelineWithStream(dispatcher msgdispatcher.Client,
|
||||
nodeTtInterval time.Duration,
|
||||
enableTtChecker bool,
|
||||
vChannel string,
|
||||
replicateConfig *msgstream.ReplicateConfig,
|
||||
) StreamPipeline {
|
||||
pipeline := &streamPipeline{
|
||||
pipeline: &pipeline{
|
||||
nodes: []*nodeCtx{},
|
||||
nodeTtInterval: nodeTtInterval,
|
||||
enableTtChecker: enableTtChecker,
|
||||
},
|
||||
dispatcher: dispatcher,
|
||||
vChannel: vChannel,
|
||||
closeCh: make(chan struct{}),
|
||||
closeWg: sync.WaitGroup{},
|
||||
lastAccessTime: atomic.NewTime(time.Now()),
|
||||
dispatcher: dispatcher,
|
||||
vChannel: vChannel,
|
||||
replicateConfig: replicateConfig,
|
||||
closeCh: make(chan struct{}),
|
||||
closeWg: sync.WaitGroup{},
|
||||
lastAccessTime: atomic.NewTime(time.Now()),
|
||||
}
|
||||
|
||||
return pipeline
|
||||
|
|
|
@ -25,7 +25,6 @@ import (
|
|||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
"github.com/milvus-io/milvus/pkg/mq/common"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgdispatcher"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
)
|
||||
|
@ -47,9 +46,9 @@ func (suite *StreamPipelineSuite) SetupTest() {
|
|||
suite.inChannel = make(chan *msgstream.MsgPack, 1)
|
||||
suite.outChannel = make(chan msgstream.Timestamp)
|
||||
suite.msgDispatcher = msgdispatcher.NewMockClient(suite.T())
|
||||
suite.msgDispatcher.EXPECT().Register(mock.Anything, suite.channel, mock.Anything, common.SubscriptionPositionUnknown).Return(suite.inChannel, nil)
|
||||
suite.msgDispatcher.EXPECT().Register(mock.Anything, mock.Anything).Return(suite.inChannel, nil)
|
||||
suite.msgDispatcher.EXPECT().Deregister(suite.channel)
|
||||
suite.pipeline = NewPipelineWithStream(suite.msgDispatcher, 0, false, suite.channel)
|
||||
suite.pipeline = NewPipelineWithStream(suite.msgDispatcher, 0, false, suite.channel, nil)
|
||||
suite.length = 4
|
||||
}
|
||||
|
||||
|
|
|
@ -191,6 +191,8 @@ const (
|
|||
PartitionKeyIsolationKey = "partitionkey.isolation"
|
||||
FieldSkipLoadKey = "field.skipLoad"
|
||||
IndexOffsetCacheEnabledKey = "indexoffsetcache.enabled"
|
||||
ReplicateIDKey = "replicate.id"
|
||||
ReplicateEndTSKey = "replicate.endTS"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -395,3 +397,31 @@ func ShouldFieldBeLoaded(kvs []*commonpb.KeyValuePair) (bool, error) {
|
|||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func IsReplicateEnabled(kvs []*commonpb.KeyValuePair) (bool, bool) {
|
||||
replicateID, ok := GetReplicateID(kvs)
|
||||
return replicateID != "", ok
|
||||
}
|
||||
|
||||
func GetReplicateID(kvs []*commonpb.KeyValuePair) (string, bool) {
|
||||
for _, kv := range kvs {
|
||||
if kv.GetKey() == ReplicateIDKey {
|
||||
return kv.GetValue(), true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func GetReplicateEndTS(kvs []*commonpb.KeyValuePair) (uint64, bool) {
|
||||
for _, kv := range kvs {
|
||||
if kv.GetKey() == ReplicateEndTSKey {
|
||||
ts, err := strconv.ParseUint(kv.GetValue(), 10, 64)
|
||||
if err != nil {
|
||||
log.Warn("parse replicate end ts failed", zap.Error(err), zap.Stack("stack"))
|
||||
return 0, false
|
||||
}
|
||||
return ts, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
|
|
@ -177,3 +177,84 @@ func TestShouldFieldBeLoaded(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateProperty(t *testing.T) {
|
||||
t.Run("ReplicateID", func(t *testing.T) {
|
||||
{
|
||||
p := []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: ReplicateIDKey,
|
||||
Value: "1001",
|
||||
},
|
||||
}
|
||||
e, ok := IsReplicateEnabled(p)
|
||||
assert.True(t, e)
|
||||
assert.True(t, ok)
|
||||
i, ok := GetReplicateID(p)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "1001", i)
|
||||
}
|
||||
|
||||
{
|
||||
p := []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: ReplicateIDKey,
|
||||
Value: "",
|
||||
},
|
||||
}
|
||||
e, ok := IsReplicateEnabled(p)
|
||||
assert.False(t, e)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
{
|
||||
p := []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "foo",
|
||||
Value: "1001",
|
||||
},
|
||||
}
|
||||
e, ok := IsReplicateEnabled(p)
|
||||
assert.False(t, e)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ReplicateTS", func(t *testing.T) {
|
||||
{
|
||||
p := []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: ReplicateEndTSKey,
|
||||
Value: "1001",
|
||||
},
|
||||
}
|
||||
ts, ok := GetReplicateEndTS(p)
|
||||
assert.True(t, ok)
|
||||
assert.EqualValues(t, 1001, ts)
|
||||
}
|
||||
|
||||
{
|
||||
p := []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: ReplicateEndTSKey,
|
||||
Value: "foo",
|
||||
},
|
||||
}
|
||||
ts, ok := GetReplicateEndTS(p)
|
||||
assert.False(t, ok)
|
||||
assert.EqualValues(t, 0, ts)
|
||||
}
|
||||
|
||||
{
|
||||
p := []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "foo",
|
||||
Value: "1001",
|
||||
},
|
||||
}
|
||||
ts, ok := GetReplicateEndTS(p)
|
||||
assert.False(t, ok)
|
||||
assert.EqualValues(t, 0, ts)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -36,8 +36,23 @@ type (
|
|||
SubPos = common.SubscriptionInitialPosition
|
||||
)
|
||||
|
||||
type StreamConfig struct {
|
||||
VChannel string
|
||||
Pos *Pos
|
||||
SubPos SubPos
|
||||
ReplicateConfig *msgstream.ReplicateConfig
|
||||
}
|
||||
|
||||
func NewStreamConfig(vchannel string, pos *Pos, subPos SubPos) *StreamConfig {
|
||||
return &StreamConfig{
|
||||
VChannel: vchannel,
|
||||
Pos: pos,
|
||||
SubPos: subPos,
|
||||
}
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
Register(ctx context.Context, vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error)
|
||||
Register(ctx context.Context, streamConfig *StreamConfig) (<-chan *MsgPack, error)
|
||||
Deregister(vchannel string)
|
||||
Close()
|
||||
}
|
||||
|
@ -62,7 +77,8 @@ func NewClient(factory msgstream.Factory, role string, nodeID int64) Client {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *client) Register(ctx context.Context, vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) {
|
||||
func (c *client) Register(ctx context.Context, streamConfig *StreamConfig) (<-chan *MsgPack, error) {
|
||||
vchannel := streamConfig.VChannel
|
||||
log := log.With(zap.String("role", c.role),
|
||||
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))
|
||||
pchannel := funcutil.ToPhysicalChannel(vchannel)
|
||||
|
@ -75,7 +91,7 @@ func (c *client) Register(ctx context.Context, vchannel string, pos *Pos, subPos
|
|||
c.managers.Insert(pchannel, manager)
|
||||
go manager.Run()
|
||||
}
|
||||
ch, err := manager.Add(ctx, vchannel, pos, subPos)
|
||||
ch, err := manager.Add(ctx, streamConfig)
|
||||
if err != nil {
|
||||
if manager.Num() == 0 {
|
||||
manager.Close()
|
||||
|
|
|
@ -34,9 +34,9 @@ import (
|
|||
func TestClient(t *testing.T) {
|
||||
client := NewClient(newMockFactory(), typeutil.ProxyRole, 1)
|
||||
assert.NotNil(t, client)
|
||||
_, err := client.Register(context.Background(), "mock_vchannel_0", nil, common.SubscriptionPositionUnknown)
|
||||
_, err := client.Register(context.Background(), NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown))
|
||||
assert.NoError(t, err)
|
||||
_, err = client.Register(context.Background(), "mock_vchannel_1", nil, common.SubscriptionPositionUnknown)
|
||||
_, err = client.Register(context.Background(), NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown))
|
||||
assert.NoError(t, err)
|
||||
assert.NotPanics(t, func() {
|
||||
client.Deregister("mock_vchannel_0")
|
||||
|
@ -51,7 +51,7 @@ func TestClient(t *testing.T) {
|
|||
client := NewClient(newMockFactory(), typeutil.DataNodeRole, 1)
|
||||
defer client.Close()
|
||||
assert.NotNil(t, client)
|
||||
_, err := client.Register(ctx, "mock_vchannel_1", nil, common.SubscriptionPositionUnknown)
|
||||
_, err := client.Register(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
@ -66,7 +66,7 @@ func TestClient_Concurrency(t *testing.T) {
|
|||
vchannel := fmt.Sprintf("mock-vchannel-%d-%d", i, rand.Int())
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
_, err := client1.Register(context.Background(), vchannel, nil, common.SubscriptionPositionUnknown)
|
||||
_, err := client1.Register(context.Background(), NewStreamConfig(vchannel, nil, common.SubscriptionPositionUnknown))
|
||||
assert.NoError(t, err)
|
||||
for j := 0; j < rand.Intn(2); j++ {
|
||||
client1.Deregister(vchannel)
|
||||
|
|
|
@ -80,10 +80,14 @@ type Dispatcher struct {
|
|||
stream msgstream.MsgStream
|
||||
}
|
||||
|
||||
func NewDispatcher(ctx context.Context,
|
||||
factory msgstream.Factory, isMain bool,
|
||||
pchannel string, position *Pos,
|
||||
subName string, subPos SubPos,
|
||||
func NewDispatcher(
|
||||
ctx context.Context,
|
||||
factory msgstream.Factory,
|
||||
isMain bool,
|
||||
pchannel string,
|
||||
position *Pos,
|
||||
subName string,
|
||||
subPos SubPos,
|
||||
lagNotifyChan chan struct{},
|
||||
lagTargets *typeutil.ConcurrentMap[string, *target],
|
||||
includeCurrentMsg bool,
|
||||
|
@ -260,7 +264,8 @@ func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack {
|
|||
// init packs for all targets, even though there's no msg in pack,
|
||||
// but we still need to dispatch time ticks to the targets.
|
||||
targetPacks := make(map[string]*MsgPack)
|
||||
for vchannel := range d.targets {
|
||||
replicateConfigs := make(map[string]*msgstream.ReplicateConfig)
|
||||
for vchannel, t := range d.targets {
|
||||
targetPacks[vchannel] = &MsgPack{
|
||||
BeginTs: pack.BeginTs,
|
||||
EndTs: pack.EndTs,
|
||||
|
@ -268,6 +273,9 @@ func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack {
|
|||
StartPositions: pack.StartPositions,
|
||||
EndPositions: pack.EndPositions,
|
||||
}
|
||||
if t.replicateConfig != nil {
|
||||
replicateConfigs[vchannel] = t.replicateConfig
|
||||
}
|
||||
}
|
||||
// group messages by vchannel
|
||||
for _, msg := range pack.Msgs {
|
||||
|
@ -287,9 +295,16 @@ func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack {
|
|||
collectionID = strconv.FormatInt(msg.(*msgstream.DropPartitionMsg).GetCollectionID(), 10)
|
||||
}
|
||||
if vchannel == "" {
|
||||
// for non-dml msg, such as CreateCollection, DropCollection, ...
|
||||
// we need to dispatch it to the vchannel of this collection
|
||||
for k := range targetPacks {
|
||||
if msg.Type() == commonpb.MsgType_Replicate {
|
||||
config := replicateConfigs[k]
|
||||
if config != nil && msgstream.MatchReplicateID(msg, config.ReplicateID) {
|
||||
targetPacks[k].Msgs = append(targetPacks[k].Msgs, msg)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if !strings.Contains(k, collectionID) {
|
||||
continue
|
||||
}
|
||||
|
@ -303,9 +318,63 @@ func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack {
|
|||
targetPacks[vchannel].Msgs = append(targetPacks[vchannel].Msgs, msg)
|
||||
}
|
||||
}
|
||||
replicateEndChannels := make(map[string]struct{})
|
||||
for vchannel, c := range replicateConfigs {
|
||||
if len(targetPacks[vchannel].Msgs) == 0 {
|
||||
delete(targetPacks, vchannel) // no replicate msg, can't send pack
|
||||
continue
|
||||
}
|
||||
// calculate the new pack ts
|
||||
beginTs := targetPacks[vchannel].Msgs[0].BeginTs()
|
||||
endTs := targetPacks[vchannel].Msgs[0].EndTs()
|
||||
newMsgs := make([]msgstream.TsMsg, 0)
|
||||
for _, msg := range targetPacks[vchannel].Msgs {
|
||||
if msg.BeginTs() < beginTs {
|
||||
beginTs = msg.BeginTs()
|
||||
}
|
||||
if msg.EndTs() > endTs {
|
||||
endTs = msg.EndTs()
|
||||
}
|
||||
if msg.Type() == commonpb.MsgType_Replicate {
|
||||
replicateMsg := msg.(*msgstream.ReplicateMsg)
|
||||
if c.CheckFunc(replicateMsg) {
|
||||
replicateEndChannels[vchannel] = struct{}{}
|
||||
}
|
||||
continue
|
||||
}
|
||||
newMsgs = append(newMsgs, msg)
|
||||
}
|
||||
targetPacks[vchannel].Msgs = newMsgs
|
||||
d.resetMsgPackTS(targetPacks[vchannel], beginTs, endTs)
|
||||
}
|
||||
for vchannel := range replicateEndChannels {
|
||||
if t, ok := d.targets[vchannel]; ok {
|
||||
t.replicateConfig = nil
|
||||
log.Info("replicate end, set replicate config nil", zap.String("vchannel", vchannel))
|
||||
}
|
||||
}
|
||||
return targetPacks
|
||||
}
|
||||
|
||||
func (d *Dispatcher) resetMsgPackTS(pack *MsgPack, newBeginTs, newEndTs typeutil.Timestamp) {
|
||||
pack.BeginTs = newBeginTs
|
||||
pack.EndTs = newEndTs
|
||||
startPositions := make([]*msgstream.MsgPosition, 0)
|
||||
endPositions := make([]*msgstream.MsgPosition, 0)
|
||||
for _, pos := range pack.StartPositions {
|
||||
startPosition := typeutil.Clone(pos)
|
||||
startPosition.Timestamp = newBeginTs
|
||||
startPositions = append(startPositions, startPosition)
|
||||
}
|
||||
for _, pos := range pack.EndPositions {
|
||||
endPosition := typeutil.Clone(pos)
|
||||
endPosition.Timestamp = newEndTs
|
||||
endPositions = append(endPositions, endPosition)
|
||||
}
|
||||
pack.StartPositions = startPositions
|
||||
pack.EndPositions = endPositions
|
||||
}
|
||||
|
||||
func (d *Dispatcher) nonBlockingNotify() {
|
||||
select {
|
||||
case d.lagNotifyChan <- struct{}{}:
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
package msgdispatcher
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -26,6 +28,8 @@ import (
|
|||
"github.com/stretchr/testify/mock"
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
"github.com/milvus-io/milvus/pkg/mq/common"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
)
|
||||
|
@ -73,7 +77,7 @@ func TestDispatcher(t *testing.T) {
|
|||
output := make(chan *msgstream.MsgPack, 1024)
|
||||
|
||||
getTarget := func(vchannel string, pos *Pos, ch chan *msgstream.MsgPack) *target {
|
||||
target := newTarget(vchannel, pos)
|
||||
target := newTarget(vchannel, pos, nil)
|
||||
target.ch = ch
|
||||
return target
|
||||
}
|
||||
|
@ -103,7 +107,7 @@ func TestDispatcher(t *testing.T) {
|
|||
t.Run("test concurrent send and close", func(t *testing.T) {
|
||||
for i := 0; i < 100; i++ {
|
||||
output := make(chan *msgstream.MsgPack, 1024)
|
||||
target := newTarget("mock_vchannel_0", nil)
|
||||
target := newTarget("mock_vchannel_0", nil, nil)
|
||||
target.ch = output
|
||||
assert.Equal(t, cap(output), cap(target.ch))
|
||||
wg := &sync.WaitGroup{}
|
||||
|
@ -138,3 +142,195 @@ func BenchmarkDispatcher_handle(b *testing.B) {
|
|||
// BenchmarkDispatcher_handle-12 9568 122123 ns/op
|
||||
// PASS
|
||||
}
|
||||
|
||||
func TestGroupMessage(t *testing.T) {
|
||||
d, err := NewDispatcher(context.Background(), newMockFactory(), true, "mock_pchannel_0", nil, "mock_subName_0"+fmt.Sprintf("%d", rand.Int()), common.SubscriptionPositionEarliest, nil, nil, false)
|
||||
assert.NoError(t, err)
|
||||
d.AddTarget(newTarget("mock_pchannel_0_1v0", nil, nil))
|
||||
d.AddTarget(newTarget("mock_pchannel_0_2v0", nil, msgstream.GetReplicateConfig("local-test", "foo", "coo")))
|
||||
{
|
||||
// no replicate msg
|
||||
packs := d.groupingMsgs(&MsgPack{
|
||||
BeginTs: 1,
|
||||
EndTs: 10,
|
||||
StartPositions: []*msgstream.MsgPosition{
|
||||
{
|
||||
ChannelName: "mock_pchannel_0",
|
||||
MsgID: []byte("1"),
|
||||
Timestamp: 1,
|
||||
},
|
||||
},
|
||||
EndPositions: []*msgstream.MsgPosition{
|
||||
{
|
||||
ChannelName: "mock_pchannel_0",
|
||||
MsgID: []byte("10"),
|
||||
Timestamp: 10,
|
||||
},
|
||||
},
|
||||
Msgs: []msgstream.TsMsg{
|
||||
&msgstream.InsertMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
BeginTimestamp: 5,
|
||||
EndTimestamp: 5,
|
||||
},
|
||||
InsertRequest: &msgpb.InsertRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Insert,
|
||||
Timestamp: 5,
|
||||
},
|
||||
ShardName: "mock_pchannel_0_1v0",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.Len(t, packs, 1)
|
||||
}
|
||||
|
||||
{
|
||||
// equal to replicateID
|
||||
packs := d.groupingMsgs(&MsgPack{
|
||||
BeginTs: 1,
|
||||
EndTs: 10,
|
||||
StartPositions: []*msgstream.MsgPosition{
|
||||
{
|
||||
ChannelName: "mock_pchannel_0",
|
||||
MsgID: []byte("1"),
|
||||
Timestamp: 1,
|
||||
},
|
||||
},
|
||||
EndPositions: []*msgstream.MsgPosition{
|
||||
{
|
||||
ChannelName: "mock_pchannel_0",
|
||||
MsgID: []byte("10"),
|
||||
Timestamp: 10,
|
||||
},
|
||||
},
|
||||
Msgs: []msgstream.TsMsg{
|
||||
&msgstream.ReplicateMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
BeginTimestamp: 100,
|
||||
EndTimestamp: 100,
|
||||
},
|
||||
ReplicateMsg: &msgpb.ReplicateMsg{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Replicate,
|
||||
Timestamp: 100,
|
||||
ReplicateInfo: &commonpb.ReplicateInfo{
|
||||
ReplicateID: "local-test",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.Len(t, packs, 2)
|
||||
{
|
||||
replicatePack := packs["mock_pchannel_0_2v0"]
|
||||
assert.EqualValues(t, 100, replicatePack.BeginTs)
|
||||
assert.EqualValues(t, 100, replicatePack.EndTs)
|
||||
assert.EqualValues(t, 100, replicatePack.StartPositions[0].Timestamp)
|
||||
assert.EqualValues(t, 100, replicatePack.EndPositions[0].Timestamp)
|
||||
assert.Len(t, replicatePack.Msgs, 0)
|
||||
}
|
||||
{
|
||||
replicatePack := packs["mock_pchannel_0_1v0"]
|
||||
assert.EqualValues(t, 1, replicatePack.BeginTs)
|
||||
assert.EqualValues(t, 10, replicatePack.EndTs)
|
||||
assert.EqualValues(t, 1, replicatePack.StartPositions[0].Timestamp)
|
||||
assert.EqualValues(t, 10, replicatePack.EndPositions[0].Timestamp)
|
||||
assert.Len(t, replicatePack.Msgs, 0)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// not equal to replicateID
|
||||
packs := d.groupingMsgs(&MsgPack{
|
||||
BeginTs: 1,
|
||||
EndTs: 10,
|
||||
StartPositions: []*msgstream.MsgPosition{
|
||||
{
|
||||
ChannelName: "mock_pchannel_0",
|
||||
MsgID: []byte("1"),
|
||||
Timestamp: 1,
|
||||
},
|
||||
},
|
||||
EndPositions: []*msgstream.MsgPosition{
|
||||
{
|
||||
ChannelName: "mock_pchannel_0",
|
||||
MsgID: []byte("10"),
|
||||
Timestamp: 10,
|
||||
},
|
||||
},
|
||||
Msgs: []msgstream.TsMsg{
|
||||
&msgstream.ReplicateMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
BeginTimestamp: 100,
|
||||
EndTimestamp: 100,
|
||||
},
|
||||
ReplicateMsg: &msgpb.ReplicateMsg{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Replicate,
|
||||
Timestamp: 100,
|
||||
ReplicateInfo: &commonpb.ReplicateInfo{
|
||||
ReplicateID: "local-test-1", // not equal to replicateID
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.Len(t, packs, 1)
|
||||
replicatePack := packs["mock_pchannel_0_2v0"]
|
||||
assert.Nil(t, replicatePack)
|
||||
}
|
||||
|
||||
{
|
||||
// replicate end
|
||||
replicateTarget := d.targets["mock_pchannel_0_2v0"]
|
||||
assert.NotNil(t, replicateTarget.replicateConfig)
|
||||
packs := d.groupingMsgs(&MsgPack{
|
||||
BeginTs: 1,
|
||||
EndTs: 10,
|
||||
StartPositions: []*msgstream.MsgPosition{
|
||||
{
|
||||
ChannelName: "mock_pchannel_0",
|
||||
MsgID: []byte("1"),
|
||||
Timestamp: 1,
|
||||
},
|
||||
},
|
||||
EndPositions: []*msgstream.MsgPosition{
|
||||
{
|
||||
ChannelName: "mock_pchannel_0",
|
||||
MsgID: []byte("10"),
|
||||
Timestamp: 10,
|
||||
},
|
||||
},
|
||||
Msgs: []msgstream.TsMsg{
|
||||
&msgstream.ReplicateMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
BeginTimestamp: 100,
|
||||
EndTimestamp: 100,
|
||||
},
|
||||
ReplicateMsg: &msgpb.ReplicateMsg{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Replicate,
|
||||
Timestamp: 100,
|
||||
ReplicateInfo: &commonpb.ReplicateInfo{
|
||||
ReplicateID: "local-test",
|
||||
},
|
||||
},
|
||||
IsEnd: true,
|
||||
Database: "foo",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.Len(t, packs, 2)
|
||||
replicatePack := packs["mock_pchannel_0_2v0"]
|
||||
assert.EqualValues(t, 100, replicatePack.BeginTs)
|
||||
assert.EqualValues(t, 100, replicatePack.EndTs)
|
||||
assert.EqualValues(t, 100, replicatePack.StartPositions[0].Timestamp)
|
||||
assert.EqualValues(t, 100, replicatePack.EndPositions[0].Timestamp)
|
||||
assert.Nil(t, replicateTarget.replicateConfig)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -36,7 +36,7 @@ import (
|
|||
)
|
||||
|
||||
type DispatcherManager interface {
|
||||
Add(ctx context.Context, vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error)
|
||||
Add(ctx context.Context, streamConfig *StreamConfig) (<-chan *MsgPack, error)
|
||||
Remove(vchannel string)
|
||||
Num() int
|
||||
Run()
|
||||
|
@ -82,7 +82,8 @@ func (c *dispatcherManager) constructSubName(vchannel string, isMain bool) strin
|
|||
return fmt.Sprintf("%s-%d-%s-%t", c.role, c.nodeID, vchannel, isMain)
|
||||
}
|
||||
|
||||
func (c *dispatcherManager) Add(ctx context.Context, vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) {
|
||||
func (c *dispatcherManager) Add(ctx context.Context, streamConfig *StreamConfig) (<-chan *MsgPack, error) {
|
||||
vchannel := streamConfig.VChannel
|
||||
log := log.With(zap.String("role", c.role),
|
||||
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))
|
||||
|
||||
|
@ -102,11 +103,11 @@ func (c *dispatcherManager) Add(ctx context.Context, vchannel string, pos *Pos,
|
|||
}
|
||||
|
||||
isMain := c.mainDispatcher == nil
|
||||
d, err := NewDispatcher(ctx, c.factory, isMain, c.pchannel, pos, c.constructSubName(vchannel, isMain), subPos, c.lagNotifyChan, c.lagTargets, false)
|
||||
d, err := NewDispatcher(ctx, c.factory, isMain, c.pchannel, streamConfig.Pos, c.constructSubName(vchannel, isMain), streamConfig.SubPos, c.lagNotifyChan, c.lagTargets, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t := newTarget(vchannel, pos)
|
||||
t := newTarget(vchannel, streamConfig.Pos, streamConfig.ReplicateConfig)
|
||||
d.AddTarget(t)
|
||||
if isMain {
|
||||
c.mainDispatcher = d
|
||||
|
|
|
@ -48,7 +48,7 @@ func TestManager(t *testing.T) {
|
|||
offset++
|
||||
vchannel := fmt.Sprintf("mock-pchannel-dml_0_vchannelv%d", offset)
|
||||
t.Logf("add vchannel, %s", vchannel)
|
||||
_, err := c.Add(context.Background(), vchannel, nil, common.SubscriptionPositionUnknown)
|
||||
_, err := c.Add(context.Background(), NewStreamConfig(vchannel, nil, common.SubscriptionPositionUnknown))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, offset, c.Num())
|
||||
}
|
||||
|
@ -67,11 +67,11 @@ func TestManager(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
c := NewDispatcherManager(prefix+"_pchannel_0", typeutil.ProxyRole, 1, newMockFactory())
|
||||
assert.NotNil(t, c)
|
||||
_, err := c.Add(ctx, "mock_vchannel_0", nil, common.SubscriptionPositionUnknown)
|
||||
_, err := c.Add(ctx, NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown))
|
||||
assert.NoError(t, err)
|
||||
_, err = c.Add(ctx, "mock_vchannel_1", nil, common.SubscriptionPositionUnknown)
|
||||
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown))
|
||||
assert.NoError(t, err)
|
||||
_, err = c.Add(ctx, "mock_vchannel_2", nil, common.SubscriptionPositionUnknown)
|
||||
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, c.Num())
|
||||
c.(*dispatcherManager).mainDispatcher.curTs.Store(1000)
|
||||
|
@ -98,11 +98,11 @@ func TestManager(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
c := NewDispatcherManager(prefix+"_pchannel_0", typeutil.ProxyRole, 1, newMockFactory())
|
||||
assert.NotNil(t, c)
|
||||
_, err := c.Add(ctx, "mock_vchannel_0", nil, common.SubscriptionPositionUnknown)
|
||||
_, err := c.Add(ctx, NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown))
|
||||
assert.NoError(t, err)
|
||||
_, err = c.Add(ctx, "mock_vchannel_1", nil, common.SubscriptionPositionUnknown)
|
||||
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown))
|
||||
assert.NoError(t, err)
|
||||
_, err = c.Add(ctx, "mock_vchannel_2", nil, common.SubscriptionPositionUnknown)
|
||||
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, c.Num())
|
||||
c.(*dispatcherManager).mainDispatcher.curTs.Store(1000)
|
||||
|
@ -134,11 +134,11 @@ func TestManager(t *testing.T) {
|
|||
c := NewDispatcherManager(prefix+"_pchannel_0", typeutil.ProxyRole, 1, newMockFactory())
|
||||
go c.Run()
|
||||
assert.NotNil(t, c)
|
||||
_, err := c.Add(ctx, "mock_vchannel_0", nil, common.SubscriptionPositionUnknown)
|
||||
_, err := c.Add(ctx, NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown))
|
||||
assert.Error(t, err)
|
||||
_, err = c.Add(ctx, "mock_vchannel_1", nil, common.SubscriptionPositionUnknown)
|
||||
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown))
|
||||
assert.Error(t, err)
|
||||
_, err = c.Add(ctx, "mock_vchannel_2", nil, common.SubscriptionPositionUnknown)
|
||||
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown))
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, 0, c.Num())
|
||||
|
||||
|
@ -153,18 +153,18 @@ func TestManager(t *testing.T) {
|
|||
go c.Run()
|
||||
assert.NotNil(t, c)
|
||||
ctx := context.Background()
|
||||
_, err := c.Add(ctx, "mock_vchannel_0", nil, common.SubscriptionPositionUnknown)
|
||||
_, err := c.Add(ctx, NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown))
|
||||
assert.NoError(t, err)
|
||||
_, err = c.Add(ctx, "mock_vchannel_1", nil, common.SubscriptionPositionUnknown)
|
||||
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown))
|
||||
assert.NoError(t, err)
|
||||
_, err = c.Add(ctx, "mock_vchannel_2", nil, common.SubscriptionPositionUnknown)
|
||||
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = c.Add(ctx, "mock_vchannel_0", nil, common.SubscriptionPositionUnknown)
|
||||
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown))
|
||||
assert.Error(t, err)
|
||||
_, err = c.Add(ctx, "mock_vchannel_1", nil, common.SubscriptionPositionUnknown)
|
||||
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown))
|
||||
assert.Error(t, err)
|
||||
_, err = c.Add(ctx, "mock_vchannel_2", nil, common.SubscriptionPositionUnknown)
|
||||
_, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown))
|
||||
assert.Error(t, err)
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
|
@ -325,7 +325,7 @@ func (suite *SimulationSuite) TestDispatchToVchannels() {
|
|||
suite.vchannels = make(map[string]*vchannelHelper, vchannelNum)
|
||||
for i := 0; i < vchannelNum; i++ {
|
||||
vchannel := fmt.Sprintf("%s_%dv%d", suite.pchannel, collectionID, i)
|
||||
output, err := suite.manager.Add(context.Background(), vchannel, nil, common.SubscriptionPositionEarliest)
|
||||
output, err := suite.manager.Add(context.Background(), NewStreamConfig(vchannel, nil, common.SubscriptionPositionEarliest))
|
||||
assert.NoError(suite.T(), err)
|
||||
suite.vchannels[vchannel] = &vchannelHelper{output: output}
|
||||
}
|
||||
|
@ -360,8 +360,10 @@ func (suite *SimulationSuite) TestMerge() {
|
|||
|
||||
for i := 0; i < vchannelNum; i++ {
|
||||
vchannel := fmt.Sprintf("%s_vchannelv%d", suite.pchannel, i)
|
||||
output, err := suite.manager.Add(context.Background(), vchannel, positions[rand.Intn(len(positions))],
|
||||
common.SubscriptionPositionUnknown) // seek from random position
|
||||
output, err := suite.manager.Add(context.Background(), NewStreamConfig(
|
||||
vchannel, positions[rand.Intn(len(positions))],
|
||||
common.SubscriptionPositionUnknown,
|
||||
)) // seek from random position
|
||||
assert.NoError(suite.T(), err)
|
||||
suite.vchannels[vchannel] = &vchannelHelper{output: output}
|
||||
}
|
||||
|
@ -402,7 +404,7 @@ func (suite *SimulationSuite) TestSplit() {
|
|||
paramtable.Get().Save(targetBufSizeK, "10")
|
||||
}
|
||||
vchannel := fmt.Sprintf("%s_vchannelv%d", suite.pchannel, i)
|
||||
_, err := suite.manager.Add(context.Background(), vchannel, nil, common.SubscriptionPositionEarliest)
|
||||
_, err := suite.manager.Add(context.Background(), NewStreamConfig(vchannel, nil, common.SubscriptionPositionEarliest))
|
||||
assert.NoError(suite.T(), err)
|
||||
}
|
||||
|
||||
|
|
|
@ -5,13 +5,8 @@ package msgdispatcher
|
|||
import (
|
||||
context "context"
|
||||
|
||||
common "github.com/milvus-io/milvus/pkg/mq/common"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
|
||||
msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
|
||||
msgstream "github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// MockClient is an autogenerated mock type for the Client type
|
||||
|
@ -92,9 +87,9 @@ func (_c *MockClient_Deregister_Call) RunAndReturn(run func(string)) *MockClient
|
|||
return _c
|
||||
}
|
||||
|
||||
// Register provides a mock function with given fields: ctx, vchannel, pos, subPos
|
||||
func (_m *MockClient) Register(ctx context.Context, vchannel string, pos *msgpb.MsgPosition, subPos common.SubscriptionInitialPosition) (<-chan *msgstream.MsgPack, error) {
|
||||
ret := _m.Called(ctx, vchannel, pos, subPos)
|
||||
// Register provides a mock function with given fields: ctx, streamConfig
|
||||
func (_m *MockClient) Register(ctx context.Context, streamConfig *StreamConfig) (<-chan *msgstream.MsgPack, error) {
|
||||
ret := _m.Called(ctx, streamConfig)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Register")
|
||||
|
@ -102,19 +97,19 @@ func (_m *MockClient) Register(ctx context.Context, vchannel string, pos *msgpb.
|
|||
|
||||
var r0 <-chan *msgstream.MsgPack
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, *msgpb.MsgPosition, common.SubscriptionInitialPosition) (<-chan *msgstream.MsgPack, error)); ok {
|
||||
return rf(ctx, vchannel, pos, subPos)
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *StreamConfig) (<-chan *msgstream.MsgPack, error)); ok {
|
||||
return rf(ctx, streamConfig)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, *msgpb.MsgPosition, common.SubscriptionInitialPosition) <-chan *msgstream.MsgPack); ok {
|
||||
r0 = rf(ctx, vchannel, pos, subPos)
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *StreamConfig) <-chan *msgstream.MsgPack); ok {
|
||||
r0 = rf(ctx, streamConfig)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(<-chan *msgstream.MsgPack)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string, *msgpb.MsgPosition, common.SubscriptionInitialPosition) error); ok {
|
||||
r1 = rf(ctx, vchannel, pos, subPos)
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *StreamConfig) error); ok {
|
||||
r1 = rf(ctx, streamConfig)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
@ -129,16 +124,14 @@ type MockClient_Register_Call struct {
|
|||
|
||||
// Register is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - vchannel string
|
||||
// - pos *msgpb.MsgPosition
|
||||
// - subPos common.SubscriptionInitialPosition
|
||||
func (_e *MockClient_Expecter) Register(ctx interface{}, vchannel interface{}, pos interface{}, subPos interface{}) *MockClient_Register_Call {
|
||||
return &MockClient_Register_Call{Call: _e.mock.On("Register", ctx, vchannel, pos, subPos)}
|
||||
// - streamConfig *StreamConfig
|
||||
func (_e *MockClient_Expecter) Register(ctx interface{}, streamConfig interface{}) *MockClient_Register_Call {
|
||||
return &MockClient_Register_Call{Call: _e.mock.On("Register", ctx, streamConfig)}
|
||||
}
|
||||
|
||||
func (_c *MockClient_Register_Call) Run(run func(ctx context.Context, vchannel string, pos *msgpb.MsgPosition, subPos common.SubscriptionInitialPosition)) *MockClient_Register_Call {
|
||||
func (_c *MockClient_Register_Call) Run(run func(ctx context.Context, streamConfig *StreamConfig)) *MockClient_Register_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(string), args[2].(*msgpb.MsgPosition), args[3].(common.SubscriptionInitialPosition))
|
||||
run(args[0].(context.Context), args[1].(*StreamConfig))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
@ -148,7 +141,7 @@ func (_c *MockClient_Register_Call) Return(_a0 <-chan *msgstream.MsgPack, _a1 er
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockClient_Register_Call) RunAndReturn(run func(context.Context, string, *msgpb.MsgPosition, common.SubscriptionInitialPosition) (<-chan *msgstream.MsgPack, error)) *MockClient_Register_Call {
|
||||
func (_c *MockClient_Register_Call) RunAndReturn(run func(context.Context, *StreamConfig) (<-chan *msgstream.MsgPack, error)) *MockClient_Register_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/util/lifetime"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
@ -33,26 +34,33 @@ type target struct {
|
|||
ch chan *MsgPack
|
||||
pos *Pos
|
||||
|
||||
closeMu sync.Mutex
|
||||
closeOnce sync.Once
|
||||
closed bool
|
||||
maxLag time.Duration
|
||||
timer *time.Timer
|
||||
closeMu sync.Mutex
|
||||
closeOnce sync.Once
|
||||
closed bool
|
||||
maxLag time.Duration
|
||||
timer *time.Timer
|
||||
replicateConfig *msgstream.ReplicateConfig
|
||||
|
||||
cancelCh lifetime.SafeChan
|
||||
}
|
||||
|
||||
func newTarget(vchannel string, pos *Pos) *target {
|
||||
func newTarget(vchannel string, pos *Pos, replicateConfig *msgstream.ReplicateConfig) *target {
|
||||
maxTolerantLag := paramtable.Get().MQCfg.MaxTolerantLag.GetAsDuration(time.Second)
|
||||
t := &target{
|
||||
vchannel: vchannel,
|
||||
ch: make(chan *MsgPack, paramtable.Get().MQCfg.TargetBufSize.GetAsInt()),
|
||||
pos: pos,
|
||||
cancelCh: lifetime.NewSafeChan(),
|
||||
maxLag: maxTolerantLag,
|
||||
timer: time.NewTimer(maxTolerantLag),
|
||||
vchannel: vchannel,
|
||||
ch: make(chan *MsgPack, paramtable.Get().MQCfg.TargetBufSize.GetAsInt()),
|
||||
pos: pos,
|
||||
cancelCh: lifetime.NewSafeChan(),
|
||||
maxLag: maxTolerantLag,
|
||||
timer: time.NewTimer(maxTolerantLag),
|
||||
replicateConfig: replicateConfig,
|
||||
}
|
||||
t.closed = false
|
||||
if replicateConfig != nil {
|
||||
log.Info("have replicate config",
|
||||
zap.String("vchannel", vchannel),
|
||||
zap.String("replicateID", replicateConfig.ReplicateID))
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ import (
|
|||
)
|
||||
|
||||
func TestSendTimeout(t *testing.T) {
|
||||
target := newTarget("test1", &msgpb.MsgPosition{})
|
||||
target := newTarget("test1", &msgpb.MsgPosition{}, nil)
|
||||
|
||||
time.Sleep(paramtable.Get().MQCfg.MaxTolerantLag.GetAsDuration(time.Second))
|
||||
|
||||
|
|
|
@ -72,6 +72,9 @@ type mqMsgStream struct {
|
|||
ttMsgEnable atomic.Value
|
||||
forceEnableProduce atomic.Value
|
||||
configEvent config.EventHandler
|
||||
|
||||
replicateID string
|
||||
checkFunc CheckReplicateMsgFunc
|
||||
}
|
||||
|
||||
// NewMqMsgStream is used to generate a new mqMsgStream object
|
||||
|
@ -276,6 +279,23 @@ func (ms *mqMsgStream) isEnabledProduce() bool {
|
|||
return ms.forceEnableProduce.Load().(bool) || ms.ttMsgEnable.Load().(bool)
|
||||
}
|
||||
|
||||
func (ms *mqMsgStream) isSkipSystemTT() bool {
|
||||
return ms.replicateID != ""
|
||||
}
|
||||
|
||||
// checkReplicateID check the replicate id of the message, return values: isMatch, isReplicate
|
||||
func (ms *mqMsgStream) checkReplicateID(msg TsMsg) (bool, bool) {
|
||||
if !ms.isSkipSystemTT() {
|
||||
return true, false
|
||||
}
|
||||
msgBase, ok := msg.(interface{ GetBase() *commonpb.MsgBase })
|
||||
if !ok {
|
||||
log.Warn("fail to get msg base, please check it", zap.Any("type", msg.Type()))
|
||||
return false, false
|
||||
}
|
||||
return msgBase.GetBase().GetReplicateInfo().GetReplicateID() == ms.replicateID, true
|
||||
}
|
||||
|
||||
func (ms *mqMsgStream) Produce(ctx context.Context, msgPack *MsgPack) error {
|
||||
if !ms.isEnabledProduce() {
|
||||
log.Ctx(ms.ctx).Warn("can't produce the msg in the backup instance", zap.Stack("stack"))
|
||||
|
@ -688,9 +708,9 @@ func (ms *MqTtMsgStream) bufMsgPackToChannel() {
|
|||
startBufTime := time.Now()
|
||||
var endTs uint64
|
||||
var size uint64
|
||||
var containsDropCollectionMsg bool
|
||||
var containsEndBufferMsg bool
|
||||
|
||||
for ms.continueBuffering(endTs, size, startBufTime) && !containsDropCollectionMsg {
|
||||
for ms.continueBuffering(endTs, size, startBufTime) && !containsEndBufferMsg {
|
||||
ms.consumerLock.Lock()
|
||||
// wait all channels get ttMsg
|
||||
for _, consumer := range ms.consumers {
|
||||
|
@ -726,15 +746,16 @@ func (ms *MqTtMsgStream) bufMsgPackToChannel() {
|
|||
timeTickMsg = v
|
||||
continue
|
||||
}
|
||||
if v.EndTs() <= currTs {
|
||||
if v.EndTs() <= currTs ||
|
||||
GetReplicateID(v) != "" {
|
||||
size += uint64(v.Size())
|
||||
timeTickBuf = append(timeTickBuf, v)
|
||||
} else {
|
||||
tempBuffer = append(tempBuffer, v)
|
||||
}
|
||||
// when drop collection, force to exit the buffer loop
|
||||
if v.Type() == commonpb.MsgType_DropCollection {
|
||||
containsDropCollectionMsg = true
|
||||
if v.Type() == commonpb.MsgType_DropCollection || v.Type() == commonpb.MsgType_Replicate {
|
||||
containsEndBufferMsg = true
|
||||
}
|
||||
}
|
||||
ms.chanMsgBuf[consumer] = tempBuffer
|
||||
|
@ -860,7 +881,7 @@ func (ms *MqTtMsgStream) allChanReachSameTtMsg(chanTtMsgSync map[mqwrapper.Consu
|
|||
}
|
||||
for consumer := range ms.chanTtMsgTime {
|
||||
ms.chanTtMsgTimeMutex.RLock()
|
||||
chanTtMsgSync[consumer] = (ms.chanTtMsgTime[consumer] == maxTime)
|
||||
chanTtMsgSync[consumer] = ms.chanTtMsgTime[consumer] == maxTime
|
||||
ms.chanTtMsgTimeMutex.RUnlock()
|
||||
}
|
||||
|
||||
|
@ -960,6 +981,10 @@ func (ms *MqTtMsgStream) Seek(ctx context.Context, msgPositions []*MsgPosition,
|
|||
if err != nil {
|
||||
return fmt.Errorf("failed to unmarshal tsMsg, err %s", err.Error())
|
||||
}
|
||||
// skip the replicate msg because it must have been consumed
|
||||
if GetReplicateID(tsMsg) != "" {
|
||||
continue
|
||||
}
|
||||
if tsMsg.Type() == commonpb.MsgType_TimeTick && tsMsg.BeginTs() >= mp.Timestamp {
|
||||
runLoop = false
|
||||
if time.Since(loopStarTime) > 30*time.Second {
|
||||
|
|
|
@ -708,6 +708,21 @@ func TestStream_PulsarTtMsgStream_UnMarshalHeader(t *testing.T) {
|
|||
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1))
|
||||
msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 3))
|
||||
|
||||
replicatePack := MsgPack{}
|
||||
replicatePack.Msgs = append(replicatePack.Msgs, &ReplicateMsg{
|
||||
BaseMsg: BaseMsg{
|
||||
BeginTimestamp: 0,
|
||||
EndTimestamp: 0,
|
||||
HashValues: []uint32{100},
|
||||
},
|
||||
ReplicateMsg: &msgpb.ReplicateMsg{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Replicate,
|
||||
Timestamp: 100,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
msgPack2 := MsgPack{}
|
||||
msgPack2.Msgs = append(msgPack2.Msgs, getTimeTickMsg(5))
|
||||
|
||||
|
@ -721,6 +736,9 @@ func TestStream_PulsarTtMsgStream_UnMarshalHeader(t *testing.T) {
|
|||
err = inputStream.Produce(ctx, &msgPack1)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
|
||||
|
||||
err = inputStream.Produce(ctx, &replicatePack)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
|
||||
|
||||
_, err = inputStream.Broadcast(ctx, &msgPack2)
|
||||
require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
|
||||
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
/*
|
||||
* Licensed to the LF AI & Data foundation under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package msgstream
|
||||
|
||||
import (
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
)
|
||||
|
||||
type ReplicateMsg struct {
|
||||
BaseMsg
|
||||
*msgpb.ReplicateMsg
|
||||
}
|
||||
|
||||
var _ TsMsg = (*ReplicateMsg)(nil)
|
||||
|
||||
func (r *ReplicateMsg) ID() UniqueID {
|
||||
return r.Base.MsgID
|
||||
}
|
||||
|
||||
func (r *ReplicateMsg) SetID(id UniqueID) {
|
||||
r.Base.MsgID = id
|
||||
}
|
||||
|
||||
func (r *ReplicateMsg) Type() MsgType {
|
||||
return r.Base.MsgType
|
||||
}
|
||||
|
||||
func (r *ReplicateMsg) SourceID() int64 {
|
||||
return r.Base.SourceID
|
||||
}
|
||||
|
||||
func (r *ReplicateMsg) Marshal(input TsMsg) (MarshalType, error) {
|
||||
replicateMsg := input.(*ReplicateMsg)
|
||||
mb, err := proto.Marshal(replicateMsg.ReplicateMsg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mb, nil
|
||||
}
|
||||
|
||||
func (r *ReplicateMsg) Unmarshal(input MarshalType) (TsMsg, error) {
|
||||
replicateMsg := &msgpb.ReplicateMsg{}
|
||||
in, err := convertToByteArray(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = proto.Unmarshal(in, replicateMsg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rr := &ReplicateMsg{ReplicateMsg: replicateMsg}
|
||||
rr.BeginTimestamp = replicateMsg.GetBase().GetTimestamp()
|
||||
rr.EndTimestamp = replicateMsg.GetBase().GetTimestamp()
|
||||
|
||||
return rr, nil
|
||||
}
|
||||
|
||||
func (r *ReplicateMsg) Size() int {
|
||||
return proto.Size(r.ReplicateMsg)
|
||||
}
|
|
@ -19,7 +19,11 @@ package msgstream
|
|||
import (
|
||||
"context"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/mq/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
@ -73,6 +77,50 @@ type MsgStream interface {
|
|||
ForceEnableProduce(can bool)
|
||||
}
|
||||
|
||||
type ReplicateConfig struct {
|
||||
ReplicateID string
|
||||
CheckFunc CheckReplicateMsgFunc
|
||||
}
|
||||
|
||||
type CheckReplicateMsgFunc func(*ReplicateMsg) bool
|
||||
|
||||
func GetReplicateConfig(replicateID, dbName, colName string) *ReplicateConfig {
|
||||
if replicateID == "" {
|
||||
return nil
|
||||
}
|
||||
replicateConfig := &ReplicateConfig{
|
||||
ReplicateID: replicateID,
|
||||
CheckFunc: func(msg *ReplicateMsg) bool {
|
||||
if !msg.GetIsEnd() {
|
||||
return false
|
||||
}
|
||||
log.Info("check replicate msg",
|
||||
zap.String("replicateID", replicateID),
|
||||
zap.String("dbName", dbName),
|
||||
zap.String("colName", colName),
|
||||
zap.Any("msg", msg))
|
||||
if msg.GetIsCluster() {
|
||||
return true
|
||||
}
|
||||
return msg.GetDatabase() == dbName && (msg.GetCollection() == colName || msg.GetCollection() == "")
|
||||
},
|
||||
}
|
||||
return replicateConfig
|
||||
}
|
||||
|
||||
func GetReplicateID(msg TsMsg) string {
|
||||
msgBase, ok := msg.(interface{ GetBase() *commonpb.MsgBase })
|
||||
if !ok {
|
||||
log.Warn("fail to get msg base, please check it", zap.Any("type", msg.Type()))
|
||||
return ""
|
||||
}
|
||||
return msgBase.GetBase().GetReplicateInfo().GetReplicateID()
|
||||
}
|
||||
|
||||
func MatchReplicateID(msg TsMsg, replicateID string) bool {
|
||||
return GetReplicateID(msg) == replicateID
|
||||
}
|
||||
|
||||
type Factory interface {
|
||||
NewMsgStream(ctx context.Context) (MsgStream, error)
|
||||
NewTtMsgStream(ctx context.Context) (MsgStream, error)
|
||||
|
|
|
@ -24,6 +24,8 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
"github.com/milvus-io/milvus/pkg/mq/common"
|
||||
)
|
||||
|
||||
|
@ -80,3 +82,90 @@ func TestGetLatestMsgID(t *testing.T) {
|
|||
assert.Equal(t, []byte("mock"), id)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateConfig(t *testing.T) {
|
||||
t.Run("get replicate id", func(t *testing.T) {
|
||||
{
|
||||
msg := &InsertMsg{
|
||||
InsertRequest: &msgpb.InsertRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
ReplicateInfo: &commonpb.ReplicateInfo{
|
||||
ReplicateID: "local",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
assert.Equal(t, "local", GetReplicateID(msg))
|
||||
assert.True(t, MatchReplicateID(msg, "local"))
|
||||
}
|
||||
{
|
||||
msg := &InsertMsg{
|
||||
InsertRequest: &msgpb.InsertRequest{
|
||||
Base: &commonpb.MsgBase{},
|
||||
},
|
||||
}
|
||||
assert.Equal(t, "", GetReplicateID(msg))
|
||||
assert.False(t, MatchReplicateID(msg, "local"))
|
||||
}
|
||||
{
|
||||
msg := &MarshalFailTsMsg{}
|
||||
assert.Equal(t, "", GetReplicateID(msg))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("get replicate config", func(t *testing.T) {
|
||||
{
|
||||
assert.Nil(t, GetReplicateConfig("", "", ""))
|
||||
}
|
||||
{
|
||||
rc := GetReplicateConfig("local", "db", "col")
|
||||
assert.Equal(t, "local", rc.ReplicateID)
|
||||
checkFunc := rc.CheckFunc
|
||||
assert.False(t, checkFunc(&ReplicateMsg{
|
||||
ReplicateMsg: &msgpb.ReplicateMsg{},
|
||||
}))
|
||||
assert.True(t, checkFunc(&ReplicateMsg{
|
||||
ReplicateMsg: &msgpb.ReplicateMsg{
|
||||
IsEnd: true,
|
||||
IsCluster: true,
|
||||
},
|
||||
}))
|
||||
assert.False(t, checkFunc(&ReplicateMsg{
|
||||
ReplicateMsg: &msgpb.ReplicateMsg{
|
||||
IsEnd: true,
|
||||
Database: "db1",
|
||||
},
|
||||
}))
|
||||
assert.True(t, checkFunc(&ReplicateMsg{
|
||||
ReplicateMsg: &msgpb.ReplicateMsg{
|
||||
IsEnd: true,
|
||||
Database: "db",
|
||||
},
|
||||
}))
|
||||
assert.False(t, checkFunc(&ReplicateMsg{
|
||||
ReplicateMsg: &msgpb.ReplicateMsg{
|
||||
IsEnd: true,
|
||||
Database: "db",
|
||||
Collection: "col1",
|
||||
},
|
||||
}))
|
||||
}
|
||||
{
|
||||
rc := GetReplicateConfig("local", "db", "col")
|
||||
checkFunc := rc.CheckFunc
|
||||
assert.True(t, checkFunc(&ReplicateMsg{
|
||||
ReplicateMsg: &msgpb.ReplicateMsg{
|
||||
IsEnd: true,
|
||||
Database: "db",
|
||||
},
|
||||
}))
|
||||
assert.False(t, checkFunc(&ReplicateMsg{
|
||||
ReplicateMsg: &msgpb.ReplicateMsg{
|
||||
IsEnd: true,
|
||||
Database: "db1",
|
||||
Collection: "col1",
|
||||
},
|
||||
}))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -84,6 +84,7 @@ func (pudf *ProtoUDFactory) NewUnmarshalDispatcher() *ProtoUnmarshalDispatcher {
|
|||
dropRoleMsg := DropRoleMsg{}
|
||||
operateUserRoleMsg := OperateUserRoleMsg{}
|
||||
operatePrivilegeMsg := OperatePrivilegeMsg{}
|
||||
replicateMsg := ReplicateMsg{}
|
||||
|
||||
p := &ProtoUnmarshalDispatcher{}
|
||||
p.TempMap = make(map[commonpb.MsgType]UnmarshalFunc)
|
||||
|
@ -113,6 +114,7 @@ func (pudf *ProtoUDFactory) NewUnmarshalDispatcher() *ProtoUnmarshalDispatcher {
|
|||
p.TempMap[commonpb.MsgType_DropRole] = dropRoleMsg.Unmarshal
|
||||
p.TempMap[commonpb.MsgType_OperateUserRole] = operateUserRoleMsg.Unmarshal
|
||||
p.TempMap[commonpb.MsgType_OperatePrivilege] = operatePrivilegeMsg.Unmarshal
|
||||
p.TempMap[commonpb.MsgType_Replicate] = replicateMsg.Unmarshal
|
||||
|
||||
return p
|
||||
}
|
||||
|
|
|
@ -70,6 +70,7 @@ var (
|
|||
ErrCollectionIllegalSchema = newMilvusError("illegal collection schema", 105, false)
|
||||
ErrCollectionOnRecovering = newMilvusError("collection on recovering", 106, true)
|
||||
ErrCollectionVectorClusteringKeyNotAllowed = newMilvusError("vector clustering key not allowed", 107, false)
|
||||
ErrCollectionReplicateMode = newMilvusError("can't operate on the collection under standby mode", 108, false)
|
||||
|
||||
// Partition related
|
||||
ErrPartitionNotFound = newMilvusError("partition not found", 200, false)
|
||||
|
|
|
@ -330,6 +330,10 @@ func WrapErrAsInputErrorWhen(err error, targets ...milvusError) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func WrapErrCollectionReplicateMode(operation string) error {
|
||||
return wrapFields(ErrCollectionReplicateMode, value("operation", operation))
|
||||
}
|
||||
|
||||
func GetErrorType(err error) ErrorType {
|
||||
if merr, ok := err.(milvusError); ok {
|
||||
return merr.errType
|
||||
|
|
|
@ -268,6 +268,7 @@ type commonConfig struct {
|
|||
MaxBloomFalsePositive ParamItem `refreshable:"true"`
|
||||
BloomFilterApplyBatchSize ParamItem `refreshable:"true"`
|
||||
PanicWhenPluginFail ParamItem `refreshable:"false"`
|
||||
CollectionReplicateEnable ParamItem `refreshable:"true"`
|
||||
|
||||
UsePartitionKeyAsClusteringKey ParamItem `refreshable:"true"`
|
||||
UseVectorAsClusteringKey ParamItem `refreshable:"true"`
|
||||
|
@ -784,6 +785,15 @@ This helps Milvus-CDC synchronize incremental data`,
|
|||
}
|
||||
p.TTMsgEnabled.Init(base.mgr)
|
||||
|
||||
p.CollectionReplicateEnable = ParamItem{
|
||||
Key: "common.collectionReplicateEnable",
|
||||
Version: "2.4.16",
|
||||
DefaultValue: "false",
|
||||
Doc: `Whether to enable collection replication.`,
|
||||
Export: true,
|
||||
}
|
||||
p.CollectionReplicateEnable.Init(base.mgr)
|
||||
|
||||
p.TraceLogMode = ParamItem{
|
||||
Key: "common.traceLogMode",
|
||||
Version: "2.3.4",
|
||||
|
|
Loading…
Reference in New Issue