diff --git a/cmd/distributed/roles/roles.go b/cmd/distributed/roles/roles.go index 3369acb45e..590d73b07f 100644 --- a/cmd/distributed/roles/roles.go +++ b/cmd/distributed/roles/roles.go @@ -15,7 +15,7 @@ import ( "github.com/zilliztech/milvus-distributed/internal/msgstream" "github.com/zilliztech/milvus-distributed/internal/msgstream/pulsarms" "github.com/zilliztech/milvus-distributed/internal/msgstream/rmqms" - "github.com/zilliztech/milvus-distributed/internal/util/rocksmq" + "github.com/zilliztech/milvus-distributed/internal/util/rocksmq/server/rocksmq" ) func newMsgFactory(localMsg bool) msgstream.Factory { diff --git a/internal/msgstream/rmqms/factory.go b/internal/msgstream/rmqms/factory.go index 462978b95a..1a7b44d9f5 100644 --- a/internal/msgstream/rmqms/factory.go +++ b/internal/msgstream/rmqms/factory.go @@ -6,7 +6,7 @@ import ( "github.com/mitchellh/mapstructure" "github.com/zilliztech/milvus-distributed/internal/msgstream" - "github.com/zilliztech/milvus-distributed/internal/util/rocksmq" + "github.com/zilliztech/milvus-distributed/internal/util/rocksmq/server/rocksmq" ) type Factory struct { diff --git a/internal/msgstream/rmqms/rmq_msgstream.go b/internal/msgstream/rmqms/rmq_msgstream.go index d8b3cc5f73..6600b24536 100644 --- a/internal/msgstream/rmqms/rmq_msgstream.go +++ b/internal/msgstream/rmqms/rmq_msgstream.go @@ -12,7 +12,7 @@ import ( "github.com/gogo/protobuf/proto" "github.com/zilliztech/milvus-distributed/internal/msgstream/util" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" - rocksmq "github.com/zilliztech/milvus-distributed/internal/util/rocksmq" + "github.com/zilliztech/milvus-distributed/internal/util/rocksmq/server/rocksmq" "github.com/zilliztech/milvus-distributed/internal/msgstream" ) @@ -71,18 +71,18 @@ func newRmqMsgStream(ctx context.Context, receiveBufSize int64, rmqBufSize int64 return stream, nil } -func (ms *RmqMsgStream) Start() { +func (rms *RmqMsgStream) Start() { } -func (ms *RmqMsgStream) Close() { - ms.streamCancel() +func (rms *RmqMsgStream) Close() { + rms.streamCancel() - for _, producer := range ms.producers { + for _, producer := range rms.producers { if producer != "" { _ = rocksmq.Rmq.DestroyChannel(producer) } } - for _, consumer := range ms.consumers { + for _, consumer := range rms.consumers { _ = rocksmq.Rmq.DestroyConsumerGroup(consumer.GroupName, consumer.ChannelName) close(consumer.MsgMutex) } @@ -92,15 +92,15 @@ type propertiesReaderWriter struct { ppMap map[string]string } -func (ms *RmqMsgStream) SetRepackFunc(repackFunc RepackFunc) { - ms.repackFunc = repackFunc +func (rms *RmqMsgStream) SetRepackFunc(repackFunc RepackFunc) { + rms.repackFunc = repackFunc } -func (ms *RmqMsgStream) AsProducer(channels []string) { +func (rms *RmqMsgStream) AsProducer(channels []string) { for _, channel := range channels { err := rocksmq.Rmq.CreateChannel(channel) if err == nil { - ms.producers = append(ms.producers, channel) + rms.producers = append(rms.producers, channel) } else { errMsg := "Failed to create producer " + channel + ", error = " + err.Error() panic(errMsg) @@ -108,31 +108,31 @@ func (ms *RmqMsgStream) AsProducer(channels []string) { } } -func (ms *RmqMsgStream) AsConsumer(channels []string, groupName string) { +func (rms *RmqMsgStream) AsConsumer(channels []string, groupName string) { for _, channelName := range channels { consumer, err := rocksmq.Rmq.CreateConsumerGroup(groupName, channelName) if err == nil { - consumer.MsgMutex = make(chan struct{}, ms.rmqBufSize) + consumer.MsgMutex = make(chan struct{}, rms.rmqBufSize) //consumer.MsgMutex <- struct{}{} - ms.consumers = append(ms.consumers, *consumer) - ms.consumerChannels = append(ms.consumerChannels, channelName) - ms.consumerReflects = append(ms.consumerReflects, reflect.SelectCase{ + rms.consumers = append(rms.consumers, *consumer) + rms.consumerChannels = append(rms.consumerChannels, channelName) + rms.consumerReflects = append(rms.consumerReflects, reflect.SelectCase{ Dir: reflect.SelectRecv, Chan: reflect.ValueOf(consumer.MsgMutex), }) - ms.wait.Add(1) - go ms.receiveMsg(*consumer) + rms.wait.Add(1) + go rms.receiveMsg(*consumer) } } } -func (ms *RmqMsgStream) Produce(ctx context.Context, pack *msgstream.MsgPack) error { +func (rms *RmqMsgStream) Produce(ctx context.Context, pack *msgstream.MsgPack) error { tsMsgs := pack.Msgs if len(tsMsgs) <= 0 { log.Printf("Warning: Receive empty msgPack") return nil } - if len(ms.producers) <= 0 { + if len(rms.producers) <= 0 { return errors.New("nil producer in msg stream") } reBucketValues := make([][]int32, len(tsMsgs)) @@ -144,21 +144,21 @@ func (ms *RmqMsgStream) Produce(ctx context.Context, pack *msgstream.MsgPack) er searchResult := tsMsg.(*msgstream.SearchResultMsg) channelID := searchResult.ResultChannelID channelIDInt, _ := strconv.ParseInt(channelID, 10, 64) - if channelIDInt >= int64(len(ms.producers)) { + if channelIDInt >= int64(len(rms.producers)) { return errors.New("Failed to produce rmq msg to unKnow channel") } bucketValues[index] = int32(channelIDInt) continue } - bucketValues[index] = int32(hashValue % uint32(len(ms.producers))) + bucketValues[index] = int32(hashValue % uint32(len(rms.producers))) } reBucketValues[channelID] = bucketValues } var result map[int32]*msgstream.MsgPack var err error - if ms.repackFunc != nil { - result, err = ms.repackFunc(tsMsgs, reBucketValues) + if rms.repackFunc != nil { + result, err = rms.repackFunc(tsMsgs, reBucketValues) } else { msgType := (tsMsgs[0]).Type() switch msgType { @@ -187,7 +187,7 @@ func (ms *RmqMsgStream) Produce(ctx context.Context, pack *msgstream.MsgPack) er msg := make([]rocksmq.ProducerMessage, 0) msg = append(msg, *rocksmq.NewProducerMessage(m)) - if err := rocksmq.Rmq.Produce(ms.producers[k], msg); err != nil { + if err := rocksmq.Rmq.Produce(rms.producers[k], msg); err != nil { return err } } @@ -195,8 +195,8 @@ func (ms *RmqMsgStream) Produce(ctx context.Context, pack *msgstream.MsgPack) er return nil } -func (ms *RmqMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) error { - producerLen := len(ms.producers) +func (rms *RmqMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) error { + producerLen := len(rms.producers) for _, v := range msgPack.Msgs { mb, err := v.Marshal(v) if err != nil { @@ -210,7 +210,7 @@ func (ms *RmqMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) error { msg = append(msg, *rocksmq.NewProducerMessage(m)) for i := 0; i < producerLen; i++ { - if err := rocksmq.Rmq.Produce(ms.producers[i], msg); err != nil { + if err := rocksmq.Rmq.Produce(rms.producers[i], msg); err != nil { return err } } @@ -218,16 +218,16 @@ func (ms *RmqMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) error { return nil } -func (ms *RmqMsgStream) Consume() (*msgstream.MsgPack, context.Context) { +func (rms *RmqMsgStream) Consume() (*msgstream.MsgPack, context.Context) { for { select { - case cm, ok := <-ms.receiveBuf: + case cm, ok := <-rms.receiveBuf: if !ok { log.Println("buf chan closed") return nil, nil } return cm, nil - case <-ms.ctx.Done(): + case <-rms.ctx.Done(): log.Printf("context closed") return nil, nil } @@ -238,12 +238,12 @@ func (ms *RmqMsgStream) Consume() (*msgstream.MsgPack, context.Context) { receiveMsg func is used to solve search timeout problem which is caused by selectcase */ -func (ms *RmqMsgStream) receiveMsg(consumer rocksmq.Consumer) { - defer ms.wait.Done() +func (rms *RmqMsgStream) receiveMsg(consumer rocksmq.Consumer) { + defer rms.wait.Done() for { select { - case <-ms.ctx.Done(): + case <-rms.ctx.Done(): return case _, ok := <-consumer.MsgMutex: if !ok { @@ -266,7 +266,7 @@ func (ms *RmqMsgStream) receiveMsg(consumer rocksmq.Consumer) { log.Printf("Failed to unmar`shal message header, error = %v", err) continue } - tsMsg, err := ms.unmarshal.Unmarshal(rmqMsg.Payload, headerMsg.Base.MsgType) + tsMsg, err := rms.unmarshal.Unmarshal(rmqMsg.Payload, headerMsg.Base.MsgType) if err != nil { log.Printf("Failed to unmarshal tsMsg, error = %v", err) continue @@ -276,24 +276,24 @@ func (ms *RmqMsgStream) receiveMsg(consumer rocksmq.Consumer) { if len(tsMsgList) > 0 { msgPack := util.MsgPack{Msgs: tsMsgList} - ms.receiveBuf <- &msgPack + rms.receiveBuf <- &msgPack } } } } -func (ms *RmqMsgStream) Chan() <-chan *msgstream.MsgPack { - return ms.receiveBuf +func (rms *RmqMsgStream) Chan() <-chan *msgstream.MsgPack { + return rms.receiveBuf } -func (ms *RmqMsgStream) Seek(offset *msgstream.MsgPosition) error { - for i := 0; i < len(ms.consumers); i++ { - if ms.consumers[i].ChannelName == offset.ChannelName { +func (rms *RmqMsgStream) Seek(offset *msgstream.MsgPosition) error { + for i := 0; i < len(rms.consumers); i++ { + if rms.consumers[i].ChannelName == offset.ChannelName { messageID, err := strconv.ParseInt(offset.MsgID, 10, 64) if err != nil { return err } - err = rocksmq.Rmq.Seek(ms.consumers[i].GroupName, ms.consumers[i].ChannelName, messageID) + err = rocksmq.Rmq.Seek(rms.consumers[i].GroupName, rms.consumers[i].ChannelName, messageID) if err != nil { return err } @@ -325,64 +325,64 @@ func newRmqTtMsgStream(ctx context.Context, receiveBufSize int64, rmqBufSize int }, nil } -func (ms *RmqTtMsgStream) AsConsumer(channels []string, +func (rtms *RmqTtMsgStream) AsConsumer(channels []string, groupName string) { for _, channelName := range channels { consumer, err := rocksmq.Rmq.CreateConsumerGroup(groupName, channelName) if err != nil { panic(err.Error()) } - consumer.MsgMutex = make(chan struct{}, ms.rmqBufSize) + consumer.MsgMutex = make(chan struct{}, rtms.rmqBufSize) //consumer.MsgMutex <- struct{}{} - ms.consumers = append(ms.consumers, *consumer) - ms.consumerChannels = append(ms.consumerChannels, consumer.ChannelName) - ms.consumerReflects = append(ms.consumerReflects, reflect.SelectCase{ + rtms.consumers = append(rtms.consumers, *consumer) + rtms.consumerChannels = append(rtms.consumerChannels, consumer.ChannelName) + rtms.consumerReflects = append(rtms.consumerReflects, reflect.SelectCase{ Dir: reflect.SelectRecv, Chan: reflect.ValueOf(consumer.MsgMutex), }) } } -func (ms *RmqTtMsgStream) Start() { - ms.wait = &sync.WaitGroup{} - if ms.consumers != nil { - ms.wait.Add(1) - go ms.bufMsgPackToChannel() +func (rtms *RmqTtMsgStream) Start() { + rtms.wait = &sync.WaitGroup{} + if rtms.consumers != nil { + rtms.wait.Add(1) + go rtms.bufMsgPackToChannel() } } -func (ms *RmqTtMsgStream) bufMsgPackToChannel() { - defer ms.wait.Done() - ms.unsolvedBuf = make(map[rocksmq.Consumer][]TsMsg) +func (rtms *RmqTtMsgStream) bufMsgPackToChannel() { + defer rtms.wait.Done() + rtms.unsolvedBuf = make(map[rocksmq.Consumer][]TsMsg) isChannelReady := make(map[rocksmq.Consumer]bool) eofMsgTimeStamp := make(map[rocksmq.Consumer]Timestamp) for { select { - case <-ms.ctx.Done(): + case <-rtms.ctx.Done(): return default: wg := sync.WaitGroup{} findMapMutex := sync.RWMutex{} - ms.consumerLock.Lock() - for _, consumer := range ms.consumers { + rtms.consumerLock.Lock() + for _, consumer := range rtms.consumers { if isChannelReady[consumer] { continue } wg.Add(1) - go ms.findTimeTick(consumer, eofMsgTimeStamp, &wg, &findMapMutex) + go rtms.findTimeTick(consumer, eofMsgTimeStamp, &wg, &findMapMutex) } wg.Wait() timeStamp, ok := checkTimeTickMsg(eofMsgTimeStamp, isChannelReady, &findMapMutex) - ms.consumerLock.Unlock() - if !ok || timeStamp <= ms.lastTimeStamp { + rtms.consumerLock.Unlock() + if !ok || timeStamp <= rtms.lastTimeStamp { //log.Printf("All timeTick's timestamps are inconsistent") continue } timeTickBuf := make([]TsMsg, 0) msgPositions := make([]*msgstream.MsgPosition, 0) - ms.unsolvedMutex.Lock() - for consumer, msgs := range ms.unsolvedBuf { + rtms.unsolvedMutex.Lock() + for consumer, msgs := range rtms.unsolvedBuf { if len(msgs) == 0 { continue } @@ -399,7 +399,7 @@ func (ms *RmqTtMsgStream) bufMsgPackToChannel() { tempBuffer = append(tempBuffer, v) } } - ms.unsolvedBuf[consumer] = tempBuffer + rtms.unsolvedBuf[consumer] = tempBuffer if len(tempBuffer) > 0 { msgPositions = append(msgPositions, &msgstream.MsgPosition{ @@ -415,29 +415,29 @@ func (ms *RmqTtMsgStream) bufMsgPackToChannel() { }) } } - ms.unsolvedMutex.Unlock() + rtms.unsolvedMutex.Unlock() msgPack := MsgPack{ - BeginTs: ms.lastTimeStamp, + BeginTs: rtms.lastTimeStamp, EndTs: timeStamp, Msgs: timeTickBuf, StartPositions: msgPositions, } - ms.receiveBuf <- &msgPack - ms.lastTimeStamp = timeStamp + rtms.receiveBuf <- &msgPack + rtms.lastTimeStamp = timeStamp } } } -func (ms *RmqTtMsgStream) findTimeTick(consumer rocksmq.Consumer, +func (rtms *RmqTtMsgStream) findTimeTick(consumer rocksmq.Consumer, eofMsgMap map[rocksmq.Consumer]Timestamp, wg *sync.WaitGroup, findMapMutex *sync.RWMutex) { defer wg.Done() for { select { - case <-ms.ctx.Done(): + case <-rtms.ctx.Done(): return case _, ok := <-consumer.MsgMutex: if !ok { @@ -460,7 +460,7 @@ func (ms *RmqTtMsgStream) findTimeTick(consumer rocksmq.Consumer, log.Printf("Failed to unmarshal message header, error = %v", err) continue } - tsMsg, err := ms.unmarshal.Unmarshal(rmqMsg.Payload, headerMsg.Base.MsgType) + tsMsg, err := rtms.unmarshal.Unmarshal(rmqMsg.Payload, headerMsg.Base.MsgType) if err != nil { log.Printf("Failed to unmarshal tsMsg, error = %v", err) continue @@ -471,9 +471,9 @@ func (ms *RmqTtMsgStream) findTimeTick(consumer rocksmq.Consumer, MsgID: strconv.Itoa(int(rmqMsg.MsgID)), }) - ms.unsolvedMutex.Lock() - ms.unsolvedBuf[consumer] = append(ms.unsolvedBuf[consumer], tsMsg) - ms.unsolvedMutex.Unlock() + rtms.unsolvedMutex.Lock() + rtms.unsolvedBuf[consumer] = append(rtms.unsolvedBuf[consumer], tsMsg) + rtms.unsolvedMutex.Unlock() if headerMsg.Base.MsgType == commonpb.MsgType_TimeTick { findMapMutex.Lock() @@ -487,12 +487,12 @@ func (ms *RmqTtMsgStream) findTimeTick(consumer rocksmq.Consumer, } } -func (ms *RmqTtMsgStream) Seek(mp *msgstream.MsgPosition) error { +func (rtms *RmqTtMsgStream) Seek(mp *msgstream.MsgPosition) error { var consumer rocksmq.Consumer var msgID UniqueID - for index, channel := range ms.consumerChannels { + for index, channel := range rtms.consumerChannels { if filepath.Base(channel) == filepath.Base(mp.ChannelName) { - consumer = ms.consumers[index] + consumer = rtms.consumers[index] if len(mp.MsgID) == 0 { msgID = -1 break @@ -512,8 +512,8 @@ func (ms *RmqTtMsgStream) Seek(mp *msgstream.MsgPosition) error { if msgID == -1 { return nil } - ms.unsolvedMutex.Lock() - ms.unsolvedBuf[consumer] = make([]TsMsg, 0) + rtms.unsolvedMutex.Lock() + rtms.unsolvedBuf[consumer] = make([]TsMsg, 0) // When rmq seek is called, msgMutex can't be used before current msgs all consumed, because // new msgMutex is not generated. So just try to consume msgs @@ -531,7 +531,7 @@ func (ms *RmqTtMsgStream) Seek(mp *msgstream.MsgPosition) error { log.Printf("Failed to unmarshal message header, error = %v", err) return err } - tsMsg, err := ms.unmarshal.Unmarshal(rmqMsg[0].Payload, headerMsg.Base.MsgType) + tsMsg, err := rtms.unmarshal.Unmarshal(rmqMsg[0].Payload, headerMsg.Base.MsgType) if err != nil { log.Printf("Failed to unmarshal tsMsg, error = %v", err) return err @@ -539,7 +539,7 @@ func (ms *RmqTtMsgStream) Seek(mp *msgstream.MsgPosition) error { if headerMsg.Base.MsgType == commonpb.MsgType_TimeTick { if tsMsg.BeginTs() >= mp.Timestamp { - ms.unsolvedMutex.Unlock() + rtms.unsolvedMutex.Unlock() return nil } continue @@ -549,7 +549,7 @@ func (ms *RmqTtMsgStream) Seek(mp *msgstream.MsgPosition) error { ChannelName: filepath.Base(consumer.ChannelName), MsgID: strconv.Itoa(int(rmqMsg[0].MsgID)), }) - ms.unsolvedBuf[consumer] = append(ms.unsolvedBuf[consumer], tsMsg) + rtms.unsolvedBuf[consumer] = append(rtms.unsolvedBuf[consumer], tsMsg) } } } @@ -557,7 +557,7 @@ func (ms *RmqTtMsgStream) Seek(mp *msgstream.MsgPosition) error { //for { // select { - // case <-ms.ctx.Done(): + // case <-rtms.ctx.Done(): // return nil // case num, ok := <-consumer.MsgNum: // if !ok { @@ -575,14 +575,14 @@ func (ms *RmqTtMsgStream) Seek(mp *msgstream.MsgPosition) error { // if err != nil { // log.Printf("Failed to unmarshal message header, error = %v", err) // } - // tsMsg, err := ms.unmarshal.Unmarshal(rmqMsg[j].Payload, headerMsg.Base.MsgType) + // tsMsg, err := rtms.unmarshal.Unmarshal(rmqMsg[j].Payload, headerMsg.Base.MsgType) // if err != nil { // log.Printf("Failed to unmarshal tsMsg, error = %v", err) // } // // if headerMsg.Base.MsgType == commonpb.MsgType_kTimeTick { // if tsMsg.BeginTs() >= mp.Timestamp { - // ms.unsolvedMutex.Unlock() + // rtms.unsolvedMutex.Unlock() // return nil // } // continue @@ -592,7 +592,7 @@ func (ms *RmqTtMsgStream) Seek(mp *msgstream.MsgPosition) error { // ChannelName: filepath.Base(consumer.ChannelName), // MsgID: strconv.Itoa(int(rmqMsg[j].MsgID)), // }) - // ms.unsolvedBuf[consumer] = append(ms.unsolvedBuf[consumer], tsMsg) + // rtms.unsolvedBuf[consumer] = append(rtms.unsolvedBuf[consumer], tsMsg) // } // } // } diff --git a/internal/msgstream/rmqms/rmq_msgstream_test.go b/internal/msgstream/rmqms/rmq_msgstream_test.go index 4b335af1d3..81d511922a 100644 --- a/internal/msgstream/rmqms/rmq_msgstream_test.go +++ b/internal/msgstream/rmqms/rmq_msgstream_test.go @@ -10,7 +10,7 @@ import ( "github.com/zilliztech/milvus-distributed/internal/allocator" etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd" - "github.com/zilliztech/milvus-distributed/internal/util/rocksmq" + "github.com/zilliztech/milvus-distributed/internal/util/rocksmq/server/rocksmq" "go.etcd.io/etcd/clientv3" "github.com/zilliztech/milvus-distributed/internal/msgstream" diff --git a/internal/util/rocksmq/client/rocksmq/client.go b/internal/util/rocksmq/client/rocksmq/client.go new file mode 100644 index 0000000000..0220499de6 --- /dev/null +++ b/internal/util/rocksmq/client/rocksmq/client.go @@ -0,0 +1,26 @@ +package rocksmq + +import ( + server "github.com/zilliztech/milvus-distributed/internal/util/rocksmq/server/rocksmq" +) + +type RocksMQ = server.RocksMQ + +func NewClient(options ClientOptions) (Client, error) { + return newClient(options) +} + +type ClientOptions struct { + server *RocksMQ +} + +type Client interface { + // Create a producer instance + CreateProducer(options ProducerOptions) (Producer, error) + + // Create a consumer instance and subscribe a topic + Subscribe(options ConsumerOptions) (Consumer, error) + + // Close the client and free associated resources + Close() +} diff --git a/internal/util/rocksmq/client/rocksmq/client_impl.go b/internal/util/rocksmq/client/rocksmq/client_impl.go new file mode 100644 index 0000000000..58035d2684 --- /dev/null +++ b/internal/util/rocksmq/client/rocksmq/client_impl.go @@ -0,0 +1,52 @@ +package rocksmq + +type client struct { + server *RocksMQ +} + +func newClient(options ClientOptions) (*client, error) { + if options.server == nil { + return nil, newError(InvalidConfiguration, "Server is nil") + } + + c := &client{ + server: options.server, + } + return c, nil +} + +func (c *client) CreateProducer(options ProducerOptions) (Producer, error) { + // Create a producer + producer, err := newProducer(c, options) + if err != nil { + return nil, err + } + + // Create a topic in rocksmq, ignore if topic exists + err = c.server.CreateChannel(options.Topic) + if err != nil { + return nil, err + } + + return producer, nil +} + +func (c *client) Subscribe(options ConsumerOptions) (Consumer, error) { + // Create a consumer + consumer, err := newConsumer(c, options) + if err != nil { + return nil, err + } + + // Create a consumergroup in rocksmq, raise error if consumergroup exists + _, err = c.server.CreateConsumerGroup(options.SubscriptionName, options.Topic) + if err != nil { + return nil, err + } + + return consumer, nil +} + +func (c *client) Close() { + // TODO: free resources +} diff --git a/internal/util/rocksmq/client/rocksmq/client_impl_test.go b/internal/util/rocksmq/client/rocksmq/client_impl_test.go new file mode 100644 index 0000000000..89994af529 --- /dev/null +++ b/internal/util/rocksmq/client/rocksmq/client_impl_test.go @@ -0,0 +1,45 @@ +package rocksmq + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestClient(t *testing.T) { + client, err := NewClient(ClientOptions{}) + assert.Nil(t, client) + assert.NotNil(t, err) + assert.Equal(t, InvalidConfiguration, err.(*Error).Result()) +} + +func TestCreateProducer(t *testing.T) { + client, err := NewClient(ClientOptions{ + server: newMockRocksMQ(), + }) + assert.NoError(t, err) + + producer, err := client.CreateProducer(ProducerOptions{ + Topic: newTopicName(), + }) + assert.NoError(t, err) + assert.NotNil(t, producer) + + client.Close() +} + +func TestSubscribe(t *testing.T) { + client, err := NewClient(ClientOptions{ + server: newMockRocksMQ(), + }) + assert.NoError(t, err) + + consumer, err := client.Subscribe(ConsumerOptions{ + Topic: newTopicName(), + SubscriptionName: newConsumerName(), + }) + assert.NoError(t, err) + assert.NotNil(t, consumer) + + client.Close() +} diff --git a/internal/util/rocksmq/client/rocksmq/consumer.go b/internal/util/rocksmq/client/rocksmq/consumer.go new file mode 100644 index 0000000000..1de9234797 --- /dev/null +++ b/internal/util/rocksmq/client/rocksmq/consumer.go @@ -0,0 +1,43 @@ +package rocksmq + +import ( + "context" +) + +type SubscriptionInitialPosition int + +const ( + SubscriptionPositionLatest SubscriptionInitialPosition = iota + SubscriptionPositionEarliest +) + +type ConsumerOptions struct { + // The topic that this consumer will subscribe on + Topic string + + // The subscription name for this consumer + SubscriptionName string + + // InitialPosition at which the cursor will be set when subscribe + // Default is `Latest` + SubscriptionInitialPosition + + // Message for this consumer + // When a message is received, it will be pushed to this channel for consumption + MessageChannel chan ConsumerMessage +} + +type ConsumerMessage struct { + Payload []byte +} + +type Consumer interface { + // returns the substription for the consumer + Subscription() string + + // Receive a single message + Receive(ctx context.Context) (ConsumerMessage, error) + + // TODO: Chan returns a channel to consume messages from + // Chan() <-chan ConsumerMessage +} diff --git a/internal/util/rocksmq/client/rocksmq/consumer_impl.go b/internal/util/rocksmq/client/rocksmq/consumer_impl.go new file mode 100644 index 0000000000..05fdce4ce3 --- /dev/null +++ b/internal/util/rocksmq/client/rocksmq/consumer_impl.go @@ -0,0 +1,60 @@ +package rocksmq + +import ( + "context" +) + +type consumer struct { + topic string + client *client + consumerName string + options ConsumerOptions + + messageCh chan ConsumerMessage +} + +func newConsumer(c *client, options ConsumerOptions) (*consumer, error) { + if c == nil { + return nil, newError(InvalidConfiguration, "client is nil") + } + + if options.Topic == "" { + return nil, newError(InvalidConfiguration, "Topic is empty") + } + + if options.SubscriptionName == "" { + return nil, newError(InvalidConfiguration, "SubscriptionName is empty") + } + + messageCh := options.MessageChannel + if options.MessageChannel == nil { + messageCh = make(chan ConsumerMessage, 10) + } + + return &consumer{ + topic: options.Topic, + client: c, + consumerName: options.SubscriptionName, + options: options, + messageCh: messageCh, + }, nil +} + +func (c *consumer) Subscription() string { + return c.consumerName +} + +func (c *consumer) Receive(ctx context.Context) (ConsumerMessage, error) { + msgs, err := c.client.server.Consume(c.consumerName, c.topic, 1) + if err != nil { + return ConsumerMessage{}, err + } + + if len(msgs) == 0 { + return ConsumerMessage{}, nil + } + + return ConsumerMessage{ + Payload: msgs[0].Payload, + }, nil +} diff --git a/internal/util/rocksmq/client/rocksmq/consumer_impl_test.go b/internal/util/rocksmq/client/rocksmq/consumer_impl_test.go new file mode 100644 index 0000000000..fc17daf04b --- /dev/null +++ b/internal/util/rocksmq/client/rocksmq/consumer_impl_test.go @@ -0,0 +1,42 @@ +package rocksmq + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConsumer(t *testing.T) { + consumer, err := newConsumer(nil, ConsumerOptions{ + Topic: newTopicName(), + SubscriptionName: newConsumerName(), + SubscriptionInitialPosition: SubscriptionPositionLatest, + }) + assert.Nil(t, consumer) + assert.NotNil(t, err) + assert.Equal(t, InvalidConfiguration, err.(*Error).Result()) + + consumer, err = newConsumer(newMockClient(), ConsumerOptions{}) + assert.Nil(t, consumer) + assert.NotNil(t, err) + assert.Equal(t, InvalidConfiguration, err.(*Error).Result()) + + consumer, err = newConsumer(newMockClient(), ConsumerOptions{ + Topic: newTopicName(), + }) + assert.Nil(t, consumer) + assert.NotNil(t, err) + assert.Equal(t, InvalidConfiguration, err.(*Error).Result()) +} + +func TestSubscription(t *testing.T) { + topicName := newTopicName() + consumerName := newConsumerName() + consumer, err := newConsumer(newMockClient(), ConsumerOptions{ + Topic: topicName, + SubscriptionName: consumerName, + }) + assert.NotNil(t, consumer) + assert.Nil(t, err) + assert.Equal(t, consumerName, consumer.Subscription()) +} diff --git a/internal/util/rocksmq/client/rocksmq/error.go b/internal/util/rocksmq/client/rocksmq/error.go new file mode 100644 index 0000000000..a3801d1511 --- /dev/null +++ b/internal/util/rocksmq/client/rocksmq/error.go @@ -0,0 +1,44 @@ +package rocksmq + +import "fmt" + +type Result int + +const ( + Ok Result = iota + UnknownError + InvalidConfiguration +) + +type Error struct { + msg string + result Result +} + +func (e *Error) Result() Result { + return e.result +} + +func (e *Error) Error() string { + return e.msg +} + +func newError(result Result, msg string) error { + return &Error{ + msg: fmt.Sprintf("%s: %s", msg, getResultStr(result)), + result: result, + } +} + +func getResultStr(r Result) string { + switch r { + case Ok: + return "OK" + case UnknownError: + return "UnknownError" + case InvalidConfiguration: + return "InvalidConfiguration" + default: + return fmt.Sprintf("Result(%d)", r) + } +} diff --git a/internal/util/rocksmq/client/rocksmq/producer.go b/internal/util/rocksmq/client/rocksmq/producer.go new file mode 100644 index 0000000000..a927b9dafb --- /dev/null +++ b/internal/util/rocksmq/client/rocksmq/producer.go @@ -0,0 +1,17 @@ +package rocksmq + +type ProducerOptions struct { + Topic string +} + +type ProducerMessage struct { + Payload []byte +} + +type Producer interface { + // return the topic which producer is publishing to + Topic() string + + // publish a message + Send(message *ProducerMessage) error +} diff --git a/internal/util/rocksmq/client/rocksmq/producer_impl.go b/internal/util/rocksmq/client/rocksmq/producer_impl.go new file mode 100644 index 0000000000..072c2d4cf9 --- /dev/null +++ b/internal/util/rocksmq/client/rocksmq/producer_impl.go @@ -0,0 +1,38 @@ +package rocksmq + +import ( + server "github.com/zilliztech/milvus-distributed/internal/util/rocksmq/server/rocksmq" +) + +type producer struct { + // client which the producer belong to + c *client + topic string +} + +func newProducer(c *client, options ProducerOptions) (*producer, error) { + if c == nil { + return nil, newError(InvalidConfiguration, "client is nil") + } + + if options.Topic == "" { + return nil, newError(InvalidConfiguration, "Topic is empty") + } + + return &producer{ + c: c, + topic: options.Topic, + }, nil +} + +func (p *producer) Topic() string { + return p.topic +} + +func (p *producer) Send(message *ProducerMessage) error { + return p.c.server.Produce(p.topic, []server.ProducerMessage{ + { + Payload: message.Payload, + }, + }) +} diff --git a/internal/util/rocksmq/client/rocksmq/producer_impl_test.go b/internal/util/rocksmq/client/rocksmq/producer_impl_test.go new file mode 100644 index 0000000000..2cb36ffac0 --- /dev/null +++ b/internal/util/rocksmq/client/rocksmq/producer_impl_test.go @@ -0,0 +1,33 @@ +package rocksmq + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestProducer(t *testing.T) { + // invalid client + producer, err := newProducer(nil, ProducerOptions{ + Topic: newTopicName(), + }) + assert.Nil(t, producer) + assert.NotNil(t, err) + assert.Equal(t, InvalidConfiguration, err.(*Error).Result()) + + // invalid produceroptions + producer, err = newProducer(newMockClient(), ProducerOptions{}) + assert.Nil(t, producer) + assert.NotNil(t, err) + assert.Equal(t, InvalidConfiguration, err.(*Error).Result()) +} + +func TestProducerTopic(t *testing.T) { + topicName := newTopicName() + producer, err := newProducer(newMockClient(), ProducerOptions{ + Topic: topicName, + }) + assert.NotNil(t, producer) + assert.Nil(t, err) + assert.Equal(t, topicName, producer.Topic()) +} diff --git a/internal/util/rocksmq/client/rocksmq/test_helper.go b/internal/util/rocksmq/client/rocksmq/test_helper.go new file mode 100644 index 0000000000..3eaab78692 --- /dev/null +++ b/internal/util/rocksmq/client/rocksmq/test_helper.go @@ -0,0 +1,27 @@ +package rocksmq + +import ( + "fmt" + "time" + + server "github.com/zilliztech/milvus-distributed/internal/util/rocksmq/server/rocksmq" +) + +func newTopicName() string { + return fmt.Sprintf("my-topic-%v", time.Now().Nanosecond()) +} + +func newConsumerName() string { + return fmt.Sprintf("my-consumer-%v", time.Now().Nanosecond()) +} + +func newMockRocksMQ() *RocksMQ { + return &server.RocksMQ{} +} + +func newMockClient() *client { + client, _ := newClient(ClientOptions{ + server: newMockRocksMQ(), + }) + return client +} diff --git a/internal/util/rocksmq/global_rmq.go b/internal/util/rocksmq/server/rocksmq/global_rmq.go similarity index 100% rename from internal/util/rocksmq/global_rmq.go rename to internal/util/rocksmq/server/rocksmq/global_rmq.go diff --git a/internal/util/rocksmq/rocksmq.go b/internal/util/rocksmq/server/rocksmq/rocksmq.go similarity index 99% rename from internal/util/rocksmq/rocksmq.go rename to internal/util/rocksmq/server/rocksmq/rocksmq.go index 4854d14a73..3f22245015 100644 --- a/internal/util/rocksmq/rocksmq.go +++ b/internal/util/rocksmq/server/rocksmq/rocksmq.go @@ -55,7 +55,7 @@ func combKey(channelName string, id UniqueID) (string, error) { } type ProducerMessage struct { - payload []byte + Payload []byte } type ConsumerMessage struct { @@ -111,7 +111,7 @@ func NewRocksMQ(name string, idAllocator allocator.GIDAllocator) (*RocksMQ, erro func NewProducerMessage(data []byte) *ProducerMessage { return &ProducerMessage{ - payload: data, + Payload: data, } } @@ -229,7 +229,7 @@ func (rmq *RocksMQ) Produce(channelName string, messages []ProducerMessage) erro return err } - batch.Put([]byte(key), messages[i].payload) + batch.Put([]byte(key), messages[i].Payload) } err = rmq.store.Write(gorocksdb.NewDefaultWriteOptions(), batch) diff --git a/internal/util/rocksmq/rocksmq_test.go b/internal/util/rocksmq/server/rocksmq/rocksmq_test.go similarity index 94% rename from internal/util/rocksmq/rocksmq_test.go rename to internal/util/rocksmq/server/rocksmq/rocksmq_test.go index d69df08bd9..053bc11e10 100644 --- a/internal/util/rocksmq/rocksmq_test.go +++ b/internal/util/rocksmq/server/rocksmq/rocksmq_test.go @@ -47,15 +47,15 @@ func TestRocksMQ(t *testing.T) { msgA := "a_message" pMsgs := make([]ProducerMessage, 1) - pMsgA := ProducerMessage{payload: []byte(msgA)} + pMsgA := ProducerMessage{Payload: []byte(msgA)} pMsgs[0] = pMsgA _ = idAllocator.UpdateID() err = rmq.Produce(channelName, pMsgs) assert.Nil(t, err) - pMsgB := ProducerMessage{payload: []byte("b_message")} - pMsgC := ProducerMessage{payload: []byte("c_message")} + pMsgB := ProducerMessage{Payload: []byte("b_message")} + pMsgC := ProducerMessage{Payload: []byte("c_message")} pMsgs[0] = pMsgB pMsgs = append(pMsgs, pMsgC) @@ -106,7 +106,7 @@ func TestRocksMQ_Loop(t *testing.T) { // Produce one message once for i := 0; i < loopNum; i++ { msg := "message_" + strconv.Itoa(i) - pMsg := ProducerMessage{payload: []byte(msg)} + pMsg := ProducerMessage{Payload: []byte(msg)} pMsgs := make([]ProducerMessage, 1) pMsgs[0] = pMsg err := rmq.Produce(channelName, pMsgs) @@ -117,7 +117,7 @@ func TestRocksMQ_Loop(t *testing.T) { pMsgs := make([]ProducerMessage, loopNum) for i := 0; i < loopNum; i++ { msg := "message_" + strconv.Itoa(i+loopNum) - pMsg := ProducerMessage{payload: []byte(msg)} + pMsg := ProducerMessage{Payload: []byte(msg)} pMsgs[i] = pMsg } err = rmq.Produce(channelName, pMsgs) @@ -178,8 +178,8 @@ func TestRocksMQ_Goroutines(t *testing.T) { group.Add(2) msg0 := "message_" + strconv.Itoa(i) msg1 := "message_" + strconv.Itoa(i+1) - pMsg0 := ProducerMessage{payload: []byte(msg0)} - pMsg1 := ProducerMessage{payload: []byte(msg1)} + pMsg0 := ProducerMessage{Payload: []byte(msg0)} + pMsg1 := ProducerMessage{Payload: []byte(msg1)} pMsgs := make([]ProducerMessage, 2) pMsgs[0] = pMsg0 pMsgs[1] = pMsg1 @@ -245,7 +245,7 @@ func TestRocksMQ_Throughout(t *testing.T) { pt0 := time.Now().UnixNano() / int64(time.Millisecond) for i := 0; i < entityNum; i++ { msg := "message_" + strconv.Itoa(i) - pMsg := ProducerMessage{payload: []byte(msg)} + pMsg := ProducerMessage{Payload: []byte(msg)} assert.Nil(t, idAllocator.UpdateID()) err := rmq.Produce(channelName, []ProducerMessage{pMsg}) assert.Nil(t, err) @@ -303,8 +303,8 @@ func TestRocksMQ_MultiChan(t *testing.T) { for i := 0; i < loopNum; i++ { msg0 := "for_chann0_" + strconv.Itoa(i) msg1 := "for_chann1_" + strconv.Itoa(i) - pMsg0 := ProducerMessage{payload: []byte(msg0)} - pMsg1 := ProducerMessage{payload: []byte(msg1)} + pMsg0 := ProducerMessage{Payload: []byte(msg0)} + pMsg1 := ProducerMessage{Payload: []byte(msg1)} err = rmq.Produce(channelName0, []ProducerMessage{pMsg0}) assert.Nil(t, err) err = rmq.Produce(channelName1, []ProducerMessage{pMsg1})