Add memory message stream for search

Signed-off-by: groot <yihua.mo@zilliz.com>
pull/4973/head^2
groot 2021-03-19 20:16:04 +08:00 committed by yefu.chen
parent 0823382876
commit 2280791128
9 changed files with 570 additions and 4 deletions

View File

@ -0,0 +1,168 @@
package memms
import (
"errors"
"sync"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
)
// Mmq is the package-level MemMQ instance; it is assigned by InitMmq.
var Mmq *MemMQ
// once makes the initialization performed inside InitMmq run a single time.
var once sync.Once
// Consumer is one consumer group's subscription to a single channel.
type Consumer struct {
// GroupName identifies the consumer group; CreateConsumerGroup keeps at most
// one Consumer per group name on a channel.
GroupName string
// ChannelName is the channel this consumer is registered on.
ChannelName string
// MsgChan delivers the produced message packs. A nil pack is sent on it when
// the consumer group or its channel is destroyed.
MsgChan chan *msgstream.MsgPack
}
// MemMQ is an in-memory message queue that hands message packs to the
// consumers registered on each channel.
type MemMQ struct {
// consumers maps a channel name to the consumers registered on that channel.
consumers map[string][]*Consumer
// consumerMu guards consumers.
consumerMu sync.Mutex
}
// CreateChannel registers channelName with an empty consumer list. Calling it
// for a channel that already exists leaves that channel untouched. It always
// returns nil.
func (mmq *MemMQ) CreateChannel(channelName string) error {
	mmq.consumerMu.Lock()
	defer mmq.consumerMu.Unlock()

	if _, exists := mmq.consumers[channelName]; exists {
		return nil
	}
	// The list must be non-nil: Produce treats a nil list as "channel missing".
	mmq.consumers[channelName] = []*Consumer{}
	return nil
}
// DestroyChannel removes channelName from the queue. Every consumer that was
// registered on it first receives a nil pack so the receiving side can shut
// itself down. Destroying an unknown channel is a no-op. It always returns nil.
func (mmq *MemMQ) DestroyChannel(channelName string) error {
	mmq.consumerMu.Lock()
	defer mmq.consumerMu.Unlock()

	// Ranging over the (possibly missing, hence nil) list covers both cases.
	for _, registered := range mmq.consumers[channelName] {
		registered.MsgChan <- nil
	}
	delete(mmq.consumers, channelName)
	return nil
}
// CreateConsumerGroup returns the consumer registered for groupName on
// channelName, creating the channel entry and/or the consumer when they do not
// exist yet. A newly created consumer gets a MsgChan buffered for 1024 packs.
// The returned error is always nil.
func (mmq *MemMQ) CreateConsumerGroup(groupName string, channelName string) (*Consumer, error) {
	mmq.consumerMu.Lock()
	defer mmq.consumerMu.Unlock()

	registered := mmq.consumers[channelName]

	// Reuse the consumer when the group is already subscribed.
	for i := range registered {
		if registered[i].GroupName == groupName {
			return registered[i], nil
		}
	}

	// Otherwise register a fresh one; appending to a nil list also creates
	// the channel entry.
	created := &Consumer{
		GroupName:   groupName,
		ChannelName: channelName,
		MsgChan:     make(chan *msgstream.MsgPack, 1024),
	}
	mmq.consumers[channelName] = append(registered, created)
	return created, nil
}
// DestroyConsumerGroup unregisters the consumer of groupName from channelName.
// The removed consumer receives a nil pack so the receiving side can shut
// itself down. Unknown channels and unknown groups are ignored. The channel
// itself stays registered, possibly with an empty consumer list. It always
// returns nil.
func (mmq *MemMQ) DestroyConsumerGroup(groupName string, channelName string) error {
	mmq.consumerMu.Lock()
	defer mmq.consumerMu.Unlock()

	registered, found := mmq.consumers[channelName]
	if !found {
		return nil
	}

	// Kept non-nil on purpose so the channel still counts as existing.
	kept := make([]*Consumer, 0)
	for i := range registered {
		current := registered[i]
		if current.GroupName != groupName {
			kept = append(kept, current)
			continue
		}
		// nil is the shutdown marker for the receiving side.
		current.MsgChan <- nil
	}
	mmq.consumers[channelName] = kept
	return nil
}
// Produce delivers msgPack to every consumer registered on channelName.
// A nil msgPack is ignored and reported as success. An error is returned when
// the channel has not been created.
func (mmq *MemMQ) Produce(channelName string, msgPack *MsgPack) error {
	if msgPack == nil {
		return nil
	}

	mmq.consumerMu.Lock()
	defer mmq.consumerMu.Unlock()

	targets := mmq.consumers[channelName]
	if targets == nil {
		return errors.New("Channel " + channelName + " doesn't exist")
	}
	for i := range targets {
		targets[i].MsgChan <- msgPack
	}
	return nil
}
// Broadcast delivers msgPack to every consumer of every channel. A nil msgPack
// is ignored. It always returns nil.
func (mmq *MemMQ) Broadcast(msgPack *MsgPack) error {
	if msgPack != nil {
		mmq.consumerMu.Lock()
		defer mmq.consumerMu.Unlock()

		for _, group := range mmq.consumers {
			for i := range group {
				group[i].MsgChan <- msgPack
			}
		}
	}
	return nil
}
// Consume blocks until the next message pack is available for the consumer
// group groupName on channelName and returns it.
//
// An error is returned when no such consumer group is registered on the
// channel; without this check the lookup result would be a nil *Consumer and
// reading its MsgChan would panic. A nil pack with a nil error means the
// consumer's channel was closed or the nil marker sent by DestroyChannel /
// DestroyConsumerGroup was received.
func (mmq *MemMQ) Consume(groupName string, channelName string) (*MsgPack, error) {
	var consumer *Consumer

	// Only the lookup is done under the lock: Produce takes the same lock, so
	// blocking on MsgChan while holding it would stall producers.
	mmq.consumerMu.Lock()
	for _, c := range mmq.consumers[channelName] {
		if c.GroupName == groupName {
			consumer = c
			break
		}
	}
	mmq.consumerMu.Unlock()

	if consumer == nil {
		return nil, errors.New("consumer group " + groupName + " doesn't exist on channel " + channelName)
	}

	msg, ok := <-consumer.MsgChan
	if !ok {
		return nil, nil
	}
	return msg, nil
}
// InitMmq creates the package-level Mmq with an empty channel table. Only the
// first call has an effect; later calls leave the existing instance in place.
// It always returns nil.
func InitMmq() error {
	once.Do(func() {
		// The zero value of the mutex field is ready to use.
		Mmq = &MemMQ{
			consumers: make(map[string][]*Consumer),
		}
	})
	return nil
}

View File

@ -0,0 +1,206 @@
package memms
import (
"context"
"errors"
"log"
"strconv"
"sync"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/msgstream/util"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
)
// Aliases for the msgstream types, so code in this package can refer to them
// by their short names.
type TsMsg = msgstream.TsMsg
type MsgPack = msgstream.MsgPack
type MsgType = msgstream.MsgType
type UniqueID = msgstream.UniqueID
type BaseMsg = msgstream.BaseMsg
type Timestamp = msgstream.Timestamp
type IntPrimaryKey = msgstream.IntPrimaryKey
type TimeTickMsg = msgstream.TimeTickMsg
type QueryNodeStatsMsg = msgstream.QueryNodeStatsMsg
type RepackFunc = msgstream.RepackFunc
// MemMsgStream is a message stream that produces to and consumes from the
// package-level in-memory queue Mmq.
type MemMsgStream struct {
// ctx is derived from the context given to NewMemMsgStream; Close cancels it.
ctx context.Context
// streamCancel cancels ctx.
streamCancel func()
// repackFunc, when set through SetRepackFunc, replaces the default repack
// functions Produce would otherwise pick by message type.
repackFunc msgstream.RepackFunc
// consumers holds the consumer groups created by AsConsumer.
consumers []*Consumer
// producers holds the channel names registered by AsProducer.
producers []string
// receiveBuf collects the packs forwarded by the receiveMsg goroutines; it is
// what Consume and Chan read from.
receiveBuf chan *msgstream.MsgPack
// wait tracks the running receiveMsg goroutines so Close can wait for them.
wait sync.WaitGroup
}
// NewMemMsgStream builds a stream with no producers and no consumers.
// The stream's context is a cancellable child of ctx, and receiveBufSize is the
// capacity of the buffer that consumed packs are collected in. The returned
// error is always nil.
func NewMemMsgStream(ctx context.Context, receiveBufSize int64) (*MemMsgStream, error) {
	derivedCtx, cancel := context.WithCancel(ctx)
	return &MemMsgStream{
		ctx:          derivedCtx,
		streamCancel: cancel,
		receiveBuf:   make(chan *msgstream.MsgPack, receiveBufSize),
		consumers:    []*Consumer{},
		producers:    []string{},
	}, nil
}
// Start is a no-op; the receiving goroutines are launched by AsConsumer.
func (mms *MemMsgStream) Start() {
}
// Close unregisters every consumer group this stream created from Mmq, cancels
// the stream context and then waits for the receiveMsg goroutines to finish.
func (mms *MemMsgStream) Close() {
	for i := range mms.consumers {
		owned := mms.consumers[i]
		Mmq.DestroyConsumerGroup(owned.GroupName, owned.ChannelName)
	}
	mms.streamCancel()
	mms.wait.Wait()
}
// SetRepackFunc installs the repack function Produce uses instead of the
// default ones it selects by message type.
func (mms *MemMsgStream) SetRepackFunc(repackFunc RepackFunc) {
mms.repackFunc = repackFunc
}
// AsProducer creates each of the given channels in Mmq and records it as a
// producer target of this stream. It panics if a channel cannot be created.
func (mms *MemMsgStream) AsProducer(channels []string) {
	for i := range channels {
		name := channels[i]
		if err := Mmq.CreateChannel(name); err != nil {
			panic("Failed to create producer " + name + ", error = " + err.Error())
		}
		mms.producers = append(mms.producers, name)
	}
}
// AsConsumer subscribes groupName to each of the given channels and starts one
// receiveMsg goroutine per subscription. Channels whose consumer group cannot
// be created are skipped.
func (mms *MemMsgStream) AsConsumer(channels []string, groupName string) {
	for i := range channels {
		subscribed, err := Mmq.CreateConsumerGroup(groupName, channels[i])
		if err != nil {
			continue
		}
		mms.consumers = append(mms.consumers, subscribed)
		mms.wait.Add(1)
		// The goroutine works on its own copy of the Consumer value.
		go mms.receiveMsg(*subscribed)
	}
}
// Produce distributes the messages of pack over this stream's producer
// channels and hands the resulting packs to Mmq.
//
// Each message is assigned one bucket (an index into the producer channel
// list) per hash key: search results go to the channel named by their
// ResultChannelID, everything else to hashKey modulo the number of producers.
// The messages are then regrouped per bucket, either by the function installed
// with SetRepackFunc or by a default chosen from the type of the first
// message.
//
// A nil or empty pack is logged and reported as success. An error is returned
// when the stream has no producers, when a search result carries a
// ResultChannelID that is not a valid producer index (unparsable, negative or
// too large), when repacking fails or yields a bucket outside the producer
// list, or when Mmq rejects a pack.
func (mms *MemMsgStream) Produce(ctx context.Context, pack *msgstream.MsgPack) error {
	if pack == nil || len(pack.Msgs) == 0 {
		log.Printf("Warning: Receive empty msgPack")
		return nil
	}
	tsMsgs := pack.Msgs

	producerCount := len(mms.producers)
	if producerCount == 0 {
		return errors.New("nil producer in msg stream")
	}

	reBucketValues := make([][]int32, len(tsMsgs))
	for msgIdx, tsMsg := range tsMsgs {
		hashValues := tsMsg.HashKeys()
		bucketValues := make([]int32, len(hashValues))
		for index, hashValue := range hashValues {
			if tsMsg.Type() != commonpb.MsgType_SearchResult {
				bucketValues[index] = int32(hashValue % uint32(producerCount))
				continue
			}

			// Search results are routed explicitly by their result channel.
			searchResult, ok := tsMsg.(*msgstream.SearchResultMsg)
			if !ok {
				return errors.New("msg of type SearchResult is not a SearchResultMsg")
			}
			channelIdx, err := strconv.ParseInt(searchResult.ResultChannelID, 10, 64)
			if err != nil {
				// Previously ignored, which silently routed to channel 0.
				return errors.New("invalid result channel id: " + err.Error())
			}
			// The lower bound matters as well: a negative id would index
			// mms.producers out of range further down.
			if channelIdx < 0 || channelIdx >= int64(producerCount) {
				return errors.New("Failed to produce rmq msg to unKnow channel")
			}
			bucketValues[index] = int32(channelIdx)
		}
		reBucketValues[msgIdx] = bucketValues
	}

	var result map[int32]*msgstream.MsgPack
	var err error
	if mms.repackFunc != nil {
		result, err = mms.repackFunc(tsMsgs, reBucketValues)
	} else {
		switch tsMsgs[0].Type() {
		case commonpb.MsgType_Insert:
			result, err = util.InsertRepackFunc(tsMsgs, reBucketValues)
		case commonpb.MsgType_Delete:
			result, err = util.DeleteRepackFunc(tsMsgs, reBucketValues)
		default:
			result, err = util.DefaultRepackFunc(tsMsgs, reBucketValues)
		}
	}
	if err != nil {
		return err
	}

	for bucket, repacked := range result {
		// A custom repack function is free to return any key; never trust it
		// as an index.
		if bucket < 0 || int(bucket) >= producerCount {
			return errors.New("repack returned bucket " + strconv.Itoa(int(bucket)) + " outside the producer list")
		}
		if err := Mmq.Produce(mms.producers[bucket], repacked); err != nil {
			return err
		}
	}
	return nil
}
// Broadcast hands msgPack to every producer channel of this stream, stopping
// at and returning the first error reported by Mmq.
func (mms *MemMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) error {
	for i := range mms.producers {
		if err := Mmq.Produce(mms.producers[i], msgPack); err != nil {
			return err
		}
	}
	return nil
}
// Consume blocks until a pack arrives in the receive buffer and returns it.
// It returns a nil pack when the receive buffer has been closed or the stream
// context is done. The returned context is always nil.
func (mms *MemMsgStream) Consume() (*msgstream.MsgPack, context.Context) {
	select {
	case <-mms.ctx.Done():
		log.Printf("context closed")
		return nil, nil
	case pack, open := <-mms.receiveBuf:
		if open {
			return pack, nil
		}
		log.Println("buf chan closed")
		return nil, nil
	}
}
// receiveMsg forwards the packs arriving on consumer.MsgChan into the stream's
// receive buffer. One such goroutine runs per subscribed channel, so readers
// only have to watch a single channel (the original note says this exists to
// solve a search timeout problem caused by select-case).
//
// It returns when the stream context is done or when a nil pack, the shutdown
// marker sent on destroy, is received. The hand-off into the receive buffer
// also watches the context: with a plain blocking send, a full buffer would
// keep this goroutine alive after cancellation and Close would wait on it
// forever.
func (mms *MemMsgStream) receiveMsg(consumer Consumer) {
	defer mms.wait.Done()
	for {
		select {
		case <-mms.ctx.Done():
			return
		case msg := <-consumer.MsgChan:
			if msg == nil {
				return
			}
			select {
			case mms.receiveBuf <- msg:
			case <-mms.ctx.Done():
				return
			}
		}
	}
}
// Chan exposes the receive buffer as a receive-only channel.
func (mms *MemMsgStream) Chan() <-chan *msgstream.MsgPack {
return mms.receiveBuf
}
// Seek is not implemented for the in-memory stream and always returns an
// error.
func (mms *MemMsgStream) Seek(offset *msgstream.MsgPosition) error {
return errors.New("MemMsgStream seek not implemented")
}

View File

@ -0,0 +1,181 @@
package memms
import (
"context"
"log"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
)
// getTsMsg builds a test message of the requested type whose only hash value
// is hashValue and whose message and source IDs are reqID. Search and
// SearchResult are supported; any other type yields nil.
func getTsMsg(msgType MsgType, reqID UniqueID, hashValue uint32) TsMsg {
	base := BaseMsg{
		BeginTimestamp: 0,
		EndTimestamp:   0,
		HashValues:     []uint32{hashValue},
	}

	if msgType == commonpb.MsgType_Search {
		return &msgstream.SearchMsg{
			BaseMsg: base,
			SearchRequest: internalpb.SearchRequest{
				Base: &commonpb.MsgBase{
					MsgType:   commonpb.MsgType_Search,
					MsgID:     reqID,
					Timestamp: 11,
					SourceID:  reqID,
				},
				Query:           nil,
				ResultChannelID: "0",
			},
		}
	}

	if msgType == commonpb.MsgType_SearchResult {
		return &msgstream.SearchResultMsg{
			BaseMsg: base,
			SearchResults: internalpb.SearchResults{
				Base: &commonpb.MsgBase{
					MsgType:   commonpb.MsgType_SearchResult,
					MsgID:     reqID,
					Timestamp: 1,
					SourceID:  reqID,
				},
				Status:          &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
				ResultChannelID: "0",
			},
		}
	}

	return nil
}
// createProducer initializes the global queue and returns a started stream
// that produces to the given channels. It terminates the process if the
// stream cannot be created.
func createProducer(channels []string) *MemMsgStream {
	InitMmq()

	stream, err := NewMemMsgStream(context.Background(), 1024)
	if err != nil {
		log.Fatalf("new msgstream error = %v", err)
	}

	stream.AsProducer(channels)
	stream.Start()
	return stream
}
// createCondumers returns one consuming stream per channel, in the order of
// channels. Each stream subscribes to its channel with the group name
// "<channel>_consumer". It terminates the process if a stream cannot be
// created.
func createCondumers(channels []string) []*MemMsgStream {
	streams := []*MemMsgStream{}
	for i := range channels {
		name := channels[i]
		stream, err := NewMemMsgStream(context.Background(), 1024)
		if err != nil {
			log.Fatalf("new msgstream error = %v", err)
		}
		stream.AsConsumer([]string{name}, name+"_consumer")
		streams = append(streams, stream)
	}
	return streams
}
// TestStream_GlobalMmq_Func exercises the global queue directly: channel and
// consumer bookkeeping, Produce, Broadcast, consumer teardown and channel
// destruction.
//
// Failures are reported through t rather than log.Fatalf, which would exit the
// process, skip the deferred Close and abort every other test in the binary.
// assert.Equal is called as (expected, actual) so failure output is labelled
// correctly.
func TestStream_GlobalMmq_Func(t *testing.T) {
	channels := []string{"red", "blue", "black", "green"}
	produceStream := createProducer(channels)
	defer produceStream.Close()

	consumerStreams := createCondumers(channels)

	// validate channel and consumer count
	assert.Equal(t, len(channels), len(Mmq.consumers), "global mmq channel error")
	for _, consumers := range Mmq.consumers {
		assert.Equal(t, 1, len(consumers), "global mmq consumer error")
	}

	// validate msg produce/consume
	msg := msgstream.MsgPack{}
	if err := Mmq.Produce(channels[0], &msg); err != nil {
		t.Fatalf("global mmq produce error = %v", err)
	}
	cm, _ := consumerStreams[0].Consume()
	assert.Equal(t, &msg, cm, "global mmq consume error")

	if err := Mmq.Broadcast(&msg); err != nil {
		t.Fatalf("global mmq broadcast error = %v", err)
	}
	for _, cs := range consumerStreams {
		cm, _ := cs.Consume()
		assert.Equal(t, &msg, cm, "global mmq consume error")
	}

	// validate consumer close: the channels stay, their consumer lists empty
	for _, cs := range consumerStreams {
		cs.Close()
	}
	assert.Equal(t, len(channels), len(Mmq.consumers), "global mmq channel error")
	for _, consumers := range Mmq.consumers {
		assert.Equal(t, 0, len(consumers), "global mmq consumer error")
	}

	// validate channel destroy
	for _, channel := range channels {
		Mmq.DestroyChannel(channel)
	}
	assert.Equal(t, 0, len(Mmq.consumers), "global mmq channel error")
}
// TestStream_MemMsgStream_Produce checks that a search message produced
// through a MemMsgStream reaches the consumer of the channel selected by its
// hash value.
//
// Failures are reported through t rather than log.Fatalf so the deferred
// Close calls still run. The producer is closed once, by the deferred call.
func TestStream_MemMsgStream_Produce(t *testing.T) {
	channels := []string{"red", "blue", "black", "green"}
	produceStream := createProducer(channels)
	defer produceStream.Close()

	consumerStreams := createCondumers(channels)
	for _, cs := range consumerStreams {
		defer cs.Close()
	}

	msgPack := msgstream.MsgPack{}
	// With four producer channels, hash value 2 maps to channel index 2.
	var hashValue uint32 = 2
	msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Search, 1, hashValue))
	if err := produceStream.Produce(context.Background(), &msgPack); err != nil {
		t.Fatalf("msgstream produce error = %v", err)
	}

	msg, _ := consumerStreams[hashValue].Consume()
	if msg == nil {
		t.Fatalf("msgstream consume error")
	}
}
// TestStream_MemMsgStream_BroadCast checks that a pack broadcast through a
// MemMsgStream reaches the consumer of every channel.
//
// Failures are reported through t rather than log.Fatalf so the deferred
// Close calls still run.
func TestStream_MemMsgStream_BroadCast(t *testing.T) {
	channels := []string{"red", "blue", "black", "green"}
	produceStream := createProducer(channels)
	defer produceStream.Close()

	consumerStreams := createCondumers(channels)
	for _, cs := range consumerStreams {
		defer cs.Close()
	}

	msgPack := msgstream.MsgPack{}
	msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Search, 1, 100))
	if err := produceStream.Broadcast(context.Background(), &msgPack); err != nil {
		t.Fatalf("msgstream broadcast error = %v", err)
	}

	for _, consumer := range consumerStreams {
		msg, _ := consumer.Consume()
		if msg == nil {
			t.Fatalf("msgstream consume error")
		}
	}
}

View File

@ -40,4 +40,5 @@ type Factory interface {
SetParams(params map[string]interface{}) error
NewMsgStream(ctx context.Context) (MsgStream, error)
NewTtMsgStream(ctx context.Context) (MsgStream, error)
NewQueryMsgStream(ctx context.Context) (MsgStream, error)
}

View File

@ -31,6 +31,10 @@ func (f *Factory) NewTtMsgStream(ctx context.Context) (msgstream.MsgStream, erro
return newPulsarTtMsgStream(ctx, f.PulsarAddress, f.ReceiveBufSize, f.PulsarBufSize, f.dispatcherFactory.NewUnmarshalDispatcher())
}
// NewQueryMsgStream returns the stream used for query traffic; this factory
// simply delegates to NewMsgStream.
func (f *Factory) NewQueryMsgStream(ctx context.Context) (msgstream.MsgStream, error) {
return f.NewMsgStream(ctx)
}
func NewFactory() msgstream.Factory {
f := &Factory{
dispatcherFactory: msgstream.ProtoUDFactory{},

View File

@ -6,6 +6,7 @@ import (
"github.com/mitchellh/mapstructure"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/msgstream/memms"
"github.com/zilliztech/milvus-distributed/internal/util/rocksmq/server/rocksmq"
)
@ -32,6 +33,11 @@ func (f *Factory) NewTtMsgStream(ctx context.Context) (msgstream.MsgStream, erro
return newRmqTtMsgStream(ctx, f.ReceiveBufSize, f.RmqBufSize, f.dispatcherFactory.NewUnmarshalDispatcher())
}
// NewQueryMsgStream returns the stream used for query traffic. This factory
// backs it with the in-memory message stream: the global in-memory queue is
// initialized first (a no-op after the first call), then a MemMsgStream with
// the factory's receive buffer size is created.
//
// Errors from either step are returned to the caller instead of being dropped.
func (f *Factory) NewQueryMsgStream(ctx context.Context) (msgstream.MsgStream, error) {
	if err := memms.InitMmq(); err != nil {
		return nil, err
	}
	stream, err := memms.NewMemMsgStream(ctx, f.ReceiveBufSize)
	if err != nil {
		// Return an untyped nil so callers never get a non-nil interface
		// wrapping a nil *MemMsgStream.
		return nil, err
	}
	return stream, nil
}
func NewFactory() msgstream.Factory {
f := &Factory{
dispatcherFactory: msgstream.ProtoUDFactory{},

View File

@ -152,7 +152,7 @@ func (node *ProxyNode) Init() error {
return err
}
node.queryMsgStream, _ = node.msFactory.NewMsgStream(node.ctx)
node.queryMsgStream, _ = node.msFactory.NewQueryMsgStream(node.ctx)
node.queryMsgStream.AsProducer(Params.SearchChannelNames)
// FIXME(wxyu): use log.Debug instead
log.Debug("proxynode", zap.Strings("proxynode AsProducer:", Params.SearchChannelNames))

View File

@ -387,7 +387,7 @@ func (sched *TaskScheduler) queryLoop() {
func (sched *TaskScheduler) queryResultLoop() {
defer sched.wg.Done()
queryResultMsgStream, _ := sched.msFactory.NewMsgStream(sched.ctx)
queryResultMsgStream, _ := sched.msFactory.NewQueryMsgStream(sched.ctx)
queryResultMsgStream.AsConsumer(Params.SearchResultChannelNames, Params.ProxySubName)
log.Debug("proxynode", zap.Strings("search result channel names", Params.SearchResultChannelNames))
log.Debug("proxynode", zap.String("proxySubName", Params.ProxySubName))

View File

@ -41,8 +41,8 @@ type ResultEntityIds []UniqueID
func newSearchService(ctx context.Context, replica ReplicaInterface, factory msgstream.Factory) *searchService {
receiveBufSize := Params.SearchReceiveBufSize
searchStream, _ := factory.NewMsgStream(ctx)
searchResultStream, _ := factory.NewMsgStream(ctx)
searchStream, _ := factory.NewQueryMsgStream(ctx)
searchResultStream, _ := factory.NewQueryMsgStream(ctx)
// query node doesn't need to consumer any search or search result channel actively.
consumeChannels := Params.SearchChannelNames