mirror of https://github.com/milvus-io/milvus.git
Refactor insert channels used in Proxy
Signed-off-by: dragondriver <jiquan.long@zilliz.com>
pull/4973/head^2
parent 3a67dda06c
commit 996111bf8c

@@ -414,13 +414,10 @@ func (t *ShowPartitionReqTask) IgnoreTimeStamp() bool {
 }
 
 func (t *ShowPartitionReqTask) Execute() error {
-	coll, err := t.core.MetaTable.GetCollectionByID(t.Req.CollectionID)
+	coll, err := t.core.MetaTable.GetCollectionByName(t.Req.CollectionName)
 	if err != nil {
 		return err
 	}
-	if coll.Schema.Name != t.Req.CollectionName {
-		return errors.Errorf("collection %s not exist", t.Req.CollectionName)
-	}
 	for _, partID := range coll.PartitionIDs {
 		partMeta, err := t.core.MetaTable.GetPartitionByID(partID)
 		if err != nil {

@@ -40,6 +40,7 @@ func (node *NodeImpl) CreateCollection(request *milvuspb.CreateCollectionRequest
 		Condition:               NewTaskCondition(ctx),
 		CreateCollectionRequest: request,
 		masterClient:            node.masterClient,
+		dataServiceClient:       node.dataServiceClient,
 	}
 	var cancel func()
 	cct.ctx, cancel = context.WithTimeout(ctx, reqTimeoutInterval)

@@ -79,6 +80,7 @@ func (node *NodeImpl) DropCollection(request *milvuspb.DropCollectionRequest) (*
 		Condition:             NewTaskCondition(ctx),
 		DropCollectionRequest: request,
 		masterClient:          node.masterClient,
+		dataServiceClient:     node.dataServiceClient,
 	}
 	var cancel func()
 	dct.ctx, cancel = context.WithTimeout(ctx, reqTimeoutInterval)

@@ -569,8 +571,9 @@ func (node *NodeImpl) Insert(request *milvuspb.InsertRequest) (*milvuspb.InsertR
 	span.SetTag("partition tag", request.PartitionName)
 	log.Println("insert into: ", request.CollectionName)
 	it := &InsertTask{
-		ctx:       ctx,
-		Condition: NewTaskCondition(ctx),
+		ctx:               ctx,
+		Condition:         NewTaskCondition(ctx),
+		dataServiceClient: node.dataServiceClient,
 		BaseInsertTask: BaseInsertTask{
 			BaseMsg: msgstream.BaseMsg{
 				HashValues: request.HashKeys,

@@ -585,8 +588,7 @@ func (node *NodeImpl) Insert(request *milvuspb.InsertRequest) (*milvuspb.InsertR
 				RowData: request.RowData,
 			},
 		},
-		manipulationMsgStream: node.manipulationMsgStream,
-		rowIDAllocator:        node.idAllocator,
+		rowIDAllocator: node.idAllocator,
 	}
 	if len(it.PartitionName) <= 0 {
 		it.PartitionName = Params.DefaultPartitionTag

@@ -0,0 +1,164 @@
+package proxynode
+
+import (
+	"context"
+	"log"
+	"reflect"
+	"sort"
+	"strconv"
+	"sync"
+
+	"github.com/zilliztech/milvus-distributed/internal/msgstream/pulsarms"
+
+	"github.com/zilliztech/milvus-distributed/internal/errors"
+
+	"github.com/zilliztech/milvus-distributed/internal/msgstream"
+)
+
+func SliceContain(s interface{}, item interface{}) bool {
+	ss := reflect.ValueOf(s)
+	if ss.Kind() != reflect.Slice {
+		panic("SliceContain expect a slice")
+	}
+
+	for i := 0; i < ss.Len(); i++ {
+		if ss.Index(i).Interface() == item {
+			return true
+		}
+	}
+
+	return false
+}
+
+func SliceSetEqual(s1 interface{}, s2 interface{}) bool {
+	ss1 := reflect.ValueOf(s1)
+	ss2 := reflect.ValueOf(s2)
+	if ss1.Kind() != reflect.Slice {
+		panic("expect a slice")
+	}
+	if ss2.Kind() != reflect.Slice {
+		panic("expect a slice")
+	}
+	if ss1.Len() != ss2.Len() {
+		return false
+	}
+	for i := 0; i < ss1.Len(); i++ {
+		if !SliceContain(s2, ss1.Index(i).Interface()) {
+			return false
+		}
+	}
+	return true
+}
+
+func SortedSliceEqual(s1 interface{}, s2 interface{}) bool {
+	ss1 := reflect.ValueOf(s1)
+	ss2 := reflect.ValueOf(s2)
+	if ss1.Kind() != reflect.Slice {
+		panic("expect a slice")
+	}
+	if ss2.Kind() != reflect.Slice {
+		panic("expect a slice")
+	}
+	if ss1.Len() != ss2.Len() {
+		return false
+	}
+	for i := 0; i < ss1.Len(); i++ {
+		if ss2.Index(i).Interface() != ss1.Index(i).Interface() {
+			return false
+		}
+	}
+	return true
+}
+
+type InsertChannelsMap struct {
+	collectionID2InsertChannels map[UniqueID]int      // the value of map is the location of insertChannels & insertMsgStreams
+	insertChannels              [][]string            // it's a little confusing to use []string as the key of map
+	insertMsgStreams            []msgstream.MsgStream // maybe there's a better way to implement Set, just agilely now
+	droppedBitMap               []int                 // 0 -> normal, 1 -> dropped
+	mtx                         sync.RWMutex
+	nodeInstance                *NodeImpl
+}
+
+func (m *InsertChannelsMap) createInsertMsgStream(collID UniqueID, channels []string) error {
+	m.mtx.Lock()
+	defer m.mtx.Unlock()
+
+	_, ok := m.collectionID2InsertChannels[collID]
+	if ok {
+		return errors.New("impossible and forbidden to create message stream twice")
+	}
+	sort.Slice(channels, func(i, j int) bool {
+		return channels[i] <= channels[j]
+	})
+	for loc, existedChannels := range m.insertChannels {
+		if m.droppedBitMap[loc] == 0 && SortedSliceEqual(existedChannels, channels) {
+			m.collectionID2InsertChannels[collID] = loc
+			return nil
+		}
+	}
+	m.insertChannels = append(m.insertChannels, channels)
+	m.collectionID2InsertChannels[collID] = len(m.insertChannels) - 1
+	stream := pulsarms.NewPulsarMsgStream(context.Background(), Params.MsgStreamInsertBufSize)
+	stream.SetPulsarClient(Params.PulsarAddress)
+	stream.CreatePulsarProducers(channels)
+	repack := func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
+		return insertRepackFunc(tsMsgs, hashKeys, m.nodeInstance.segAssigner, true)
+	}
+	stream.SetRepackFunc(repack)
+	stream.Start()
+	m.insertMsgStreams = append(m.insertMsgStreams, stream)
+	m.droppedBitMap = append(m.droppedBitMap, 0)
+
+	return nil
+}
+
+func (m *InsertChannelsMap) closeInsertMsgStream(collID UniqueID) error {
+	m.mtx.Lock()
+	defer m.mtx.Unlock()
+
+	loc, ok := m.collectionID2InsertChannels[collID]
+	if !ok {
+		return errors.New("cannot find collection with id: " + strconv.Itoa(int(collID)))
+	}
+	if m.droppedBitMap[loc] != 0 {
+		return errors.New("insert message stream already closed")
+	}
+	m.insertMsgStreams[loc].Close()
+	log.Print("close insert message stream ...")
+
+	m.droppedBitMap[loc] = 1
+	delete(m.collectionID2InsertChannels, collID)
+
+	return nil
+}
+
+func (m *InsertChannelsMap) getInsertMsgStream(collID UniqueID) (msgstream.MsgStream, error) {
+	m.mtx.RLock()
+	defer m.mtx.RUnlock()
+
+	loc, ok := m.collectionID2InsertChannels[collID]
+	if !ok {
+		return nil, errors.New("cannot find collection with id: " + strconv.Itoa(int(collID)))
+	}
+
+	if m.droppedBitMap[loc] != 0 {
+		return nil, errors.New("insert message stream already closed")
+	}
+
+	return m.insertMsgStreams[loc], nil
+}
+
+func newInsertChannelsMap(node *NodeImpl) *InsertChannelsMap {
+	return &InsertChannelsMap{
+		collectionID2InsertChannels: make(map[UniqueID]int),
+		insertChannels:              make([][]string, 0),
+		insertMsgStreams:            make([]msgstream.MsgStream, 0),
+		nodeInstance:                node,
+	}
+}
+
+var globalInsertChannelsMap *InsertChannelsMap
+
+func initGlobalInsertChannelsMap(node *NodeImpl) {
+	globalInsertChannelsMap = newInsertChannelsMap(node)
+}

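The comments on InsertChannelsMap above note that keying a map by []string is awkward, which is why the struct keeps parallel insertChannels / insertMsgStreams slices plus a dropped bitmap and stores slice indices in the map. A minimal, self-contained sketch of the alternative those comments hint at, with hypothetical names that are not part of this commit: derive a canonical string key from the sorted channel list and key the stream map on that directly.

package main

import (
	"fmt"
	"sort"
	"strings"
)

// channelSetKey builds a canonical key from a channel list: sort a copy and
// join it, so two collections with the same channel set yield the same key.
func channelSetKey(channels []string) string {
	cp := append([]string(nil), channels...)
	sort.Strings(cp)
	return strings.Join(cp, ",")
}

func main() {
	// The map value would be a msgstream.MsgStream in the real code; a string
	// stands in here so the example stays self-contained.
	streams := make(map[string]string)

	a := []string{"insert-ch-1", "insert-ch-0"}
	b := []string{"insert-ch-0", "insert-ch-1"}

	streams[channelSetKey(a)] = "shared stream"
	_, reused := streams[channelSetKey(b)] // same key, so the existing stream is reused
	fmt.Println(reused)                    // true
}
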
@@ -244,6 +244,9 @@ func (node *NodeImpl) Init() error {
 }
 
 func (node *NodeImpl) Start() error {
+	initGlobalInsertChannelsMap(node)
+	log.Println("init global insert channels map ...")
+
 	initGlobalMetaCache(node.ctx, node)
 	log.Println("init global meta cache ...")
 

@@ -7,6 +7,8 @@ import (
 	"math"
 	"strconv"
 
+	"github.com/zilliztech/milvus-distributed/internal/proto/datapb"
+
 	"github.com/opentracing/opentracing-go"
 	oplog "github.com/opentracing/opentracing-go/log"
 

@@ -41,10 +43,10 @@ type BaseInsertTask = msgstream.InsertMsg
 type InsertTask struct {
 	BaseInsertTask
 	Condition
-	result                *milvuspb.InsertResponse
-	manipulationMsgStream *pulsarms.PulsarMsgStream
-	ctx                   context.Context
-	rowIDAllocator        *allocator.IDAllocator
+	dataServiceClient DataServiceClient
+	result            *milvuspb.InsertResponse
+	ctx               context.Context
+	rowIDAllocator    *allocator.IDAllocator
 }
 
 func (it *InsertTask) OnEnqueue() error {

@@ -161,8 +163,6 @@ func (it *InsertTask) Execute() error {
 	}
 	tsMsg.SetMsgContext(ctx)
 	span.LogFields(oplog.String("send msg", "send msg"))
-	msgPack.Msgs[0] = tsMsg
-	err = it.manipulationMsgStream.Produce(msgPack)
 
 	it.result = &milvuspb.InsertResponse{
 		Status: &commonpb.Status{

@@ -171,11 +171,45 @@ func (it *InsertTask) Execute() error {
 		RowIDBegin: rowIDBegin,
 		RowIDEnd:   rowIDEnd,
 	}
+
+	msgPack.Msgs[0] = tsMsg
+
+	stream, err := globalInsertChannelsMap.getInsertMsgStream(description.CollectionID)
+	if err != nil {
+		collectionInsertChannels, err := it.dataServiceClient.GetInsertChannels(&datapb.InsertChannelRequest{
+			Base: &commonpb.MsgBase{
+				MsgType:   commonpb.MsgType_kInsert, // todo
+				MsgID:     it.Base.MsgID,            // todo
+				Timestamp: 0,                        // todo
+				SourceID:  Params.ProxyID,
+			},
+			DbID:         0, // todo
+			CollectionID: description.CollectionID,
+		})
+		if err != nil {
+			return err
+		}
+		err = globalInsertChannelsMap.createInsertMsgStream(description.CollectionID, collectionInsertChannels)
+		if err != nil {
+			return err
+		}
+	}
+	stream, err = globalInsertChannelsMap.getInsertMsgStream(description.CollectionID)
+	if err != nil {
+		it.result.Status.ErrorCode = commonpb.ErrorCode_UNEXPECTED_ERROR
+		it.result.Status.Reason = err.Error()
+		span.LogFields(oplog.Error(err))
+		return err
+	}
+
+	err = stream.Produce(msgPack)
 	if err != nil {
 		it.result.Status.ErrorCode = commonpb.ErrorCode_UNEXPECTED_ERROR
 		it.result.Status.Reason = err.Error()
 		span.LogFields(oplog.Error(err))
 		return err
 	}
+
 	return nil
 }
+

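The hunk above switches InsertTask.Execute from producing on the task's own manipulationMsgStream to looking the stream up in globalInsertChannelsMap, registering it on a miss from the channels returned by the data service, and only then producing. A minimal, self-contained sketch of that lookup/create-on-miss shape follows, with hypothetical, simplified types; the real code holds msgstream.MsgStream values and calls the data service client.

package main

import (
	"errors"
	"fmt"
	"sync"
)

// registry is a stripped-down stand-in for InsertChannelsMap: it maps a
// collection ID to the channel list its stream was created with.
type registry struct {
	mtx     sync.RWMutex
	streams map[int64][]string
}

func (r *registry) get(collID int64) ([]string, error) {
	r.mtx.RLock()
	defer r.mtx.RUnlock()
	s, ok := r.streams[collID]
	if !ok {
		return nil, errors.New("no insert stream for collection")
	}
	return s, nil
}

func (r *registry) create(collID int64, channels []string) {
	r.mtx.Lock()
	defer r.mtx.Unlock()
	r.streams[collID] = channels
}

// fetchChannels stands in for dataServiceClient.GetInsertChannels.
func fetchChannels(collID int64) []string {
	return []string{fmt.Sprintf("insert-%d-a", collID), fmt.Sprintf("insert-%d-b", collID)}
}

func main() {
	r := &registry{streams: make(map[int64][]string)}
	collID := int64(42)

	// Lookup first; on a miss, fetch the channel list and register it once.
	if _, err := r.get(collID); err != nil {
		r.create(collID, fetchChannels(collID))
	}

	stream, err := r.get(collID) // retry the lookup, then produce on the stream
	if err != nil {
		panic(err)
	}
	fmt.Println("produce on", stream)
}
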
@@ -188,10 +222,11 @@ func (it *InsertTask) PostExecute() error {
 type CreateCollectionTask struct {
 	Condition
 	*milvuspb.CreateCollectionRequest
-	masterClient MasterClient
-	result       *commonpb.Status
-	ctx          context.Context
-	schema       *schemapb.CollectionSchema
+	masterClient      MasterClient
+	dataServiceClient DataServiceClient
+	result            *commonpb.Status
+	ctx               context.Context
+	schema            *schemapb.CollectionSchema
 }
 
 func (cct *CreateCollectionTask) OnEnqueue() error {

@@ -293,7 +328,37 @@ func (cct *CreateCollectionTask) PreExecute() error {
 func (cct *CreateCollectionTask) Execute() error {
 	var err error
 	cct.result, err = cct.masterClient.CreateCollection(cct.CreateCollectionRequest)
-	return err
+	if err != nil {
+		return err
+	}
+	if cct.result.ErrorCode == commonpb.ErrorCode_SUCCESS {
+		err = globalMetaCache.Sync(cct.CollectionName)
+		if err != nil {
+			return err
+		}
+		desc, err := globalMetaCache.Get(cct.CollectionName)
+		if err != nil {
+			return err
+		}
+		collectionInsertChannels, err := cct.dataServiceClient.GetInsertChannels(&datapb.InsertChannelRequest{
+			Base: &commonpb.MsgBase{
+				MsgType:   commonpb.MsgType_kInsert, // todo
+				MsgID:     cct.Base.MsgID,           // todo
+				Timestamp: 0,                        // todo
+				SourceID:  Params.ProxyID,
+			},
+			DbID:         0, // todo
+			CollectionID: desc.CollectionID,
+		})
+		if err != nil {
+			return err
+		}
+		err = globalInsertChannelsMap.createInsertMsgStream(desc.CollectionID, collectionInsertChannels)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
 }
 
 func (cct *CreateCollectionTask) PostExecute() error {

@@ -303,9 +368,10 @@ func (cct *CreateCollectionTask) PostExecute() error {
 type DropCollectionTask struct {
 	Condition
 	*milvuspb.DropCollectionRequest
-	masterClient MasterClient
-	result       *commonpb.Status
-	ctx          context.Context
+	masterClient      MasterClient
+	dataServiceClient DataServiceClient
+	result            *commonpb.Status
+	ctx               context.Context
 }
 
 func (dct *DropCollectionTask) OnEnqueue() error {

@@ -350,6 +416,11 @@ func (dct *DropCollectionTask) PreExecute() error {
 func (dct *DropCollectionTask) Execute() error {
 	var err error
 	dct.result, err = dct.masterClient.DropCollection(dct.DropCollectionRequest)
+	if dct.result.ErrorCode == commonpb.ErrorCode_SUCCESS {
+		_ = globalMetaCache.Sync(dct.CollectionName)
+		desc, _ := globalMetaCache.Get(dct.CollectionName)
+		_ = globalInsertChannelsMap.closeInsertMsgStream(desc.CollectionID)
+	}
 	return err
 }
 