mirror of https://github.com/milvus-io/milvus.git

Refactor repack logic for insertion (#5399)

Signed-off-by: zhenshan.cao <zhenshan.cao@zilliz.com>
Branch: pull/5417/head
parent 1c49ddc86a
commit 6766169878
go.mod
1
go.mod
|
@ -26,6 +26,7 @@ require (
|
|||
github.com/pierrec/lz4 v2.5.2+incompatible // indirect
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/prometheus/client_golang v1.7.1
|
||||
github.com/quasilyte/go-ruleguard v0.2.1 // indirect
|
||||
github.com/sirupsen/logrus v1.6.0 // indirect
|
||||
github.com/spaolacci/murmur3 v1.1.0
|
||||
github.com/spf13/cast v1.3.0
|
||||
|
|
go.sum (6 changes)

@@ -320,6 +320,8 @@ github.com/prometheus/procfs v0.1.3 h1:F0+tqvhOksq22sc6iCHF5WGlWjdwj92p0udFh1VFB
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
github.com/protocolbuffers/protobuf v3.17.0+incompatible h1:MYhKKlaNOl8FB3F4u6oM2AlpcyLtT+p8Ec1w/9YeHss=
github.com/quasilyte/go-ruleguard v0.2.1 h1:56eRm0daAyny9UhJnmtJW/UyLZQusukBAB8oT8AHKHo=
github.com/quasilyte/go-ruleguard v0.2.1/go.mod h1:hN2rVc/uS4bQhQKTio2XaSJSafJwqBUWWwtssT3cQmc=
github.com/rivo/tview v0.0.0-20200219210816-cd38d7432498/go.mod h1:6lkG1x+13OShEf0EaOCaTQYyB7d5nSbb181KtjlS+84=
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=

@@ -380,6 +382,7 @@ github.com/yahoo/athenz v1.8.55/go.mod h1:G7LLFUH7Z/r4QAB7FfudfuA7Am/eCzO1GlzBhD
github.com/yahoo/athenz v1.9.16 h1:2s8KtIxwAbcJIYySsfrT/t/WO0Ss5O7BPGUN/q8x2bg=
github.com/yahoo/athenz v1.9.16/go.mod h1:guj+0Ut6F33wj+OcSRlw69O0itsR7tVocv15F2wJnIo=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0=
go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=

@@ -453,6 +456,7 @@ golang.org/x/net v0.0.0-20190921015927-1a5e07d1ff72/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb h1:eBmm0M9fYhWpKZLjQUUKka/LtIxf46G4fxeEz5KJr9U=

@@ -468,6 +472,7 @@ golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=

@@ -531,6 +536,7 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20200812195022-5ae4c3c160a0/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a h1:CB3a9Nez8M13wwlr/E2YtwoU+qYHKfC+JrDa45RXXoQ=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -67,6 +67,10 @@ func (mms *MemMsgStream) SetRepackFunc(repackFunc RepackFunc) {
	mms.repackFunc = repackFunc
}

func (mms *MemMsgStream) GetProduceChannels() []string {
	return mms.producers
}

func (mms *MemMsgStream) AsProducer(channels []string) {
	for _, channel := range channels {
		err := Mmq.CreateChannel(channel)
@@ -170,6 +170,7 @@ func (ms *mqMsgStream) ComputeProduceChannelIndexes(tsMsgs []TsMsg) [][]int32 {
	}
	reBucketValues := make([][]int32, len(tsMsgs))
	channelNum := uint32(len(ms.producerChannels))

	if channelNum == 0 {
		return nil
	}

@@ -184,6 +185,10 @@ func (ms *mqMsgStream) ComputeProduceChannelIndexes(tsMsgs []TsMsg) [][]int32 {
	return reBucketValues
}

func (ms *mqMsgStream) GetProduceChannels() []string {
	return ms.producerChannels
}

func (ms *mqMsgStream) Produce(msgPack *MsgPack) error {
	tsMsgs := msgPack.Msgs
	if len(tsMsgs) <= 0 {
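ComputeProduceChannelIndexes turns each message's per-row hash values into producer-channel indexes; the guard above returns nil when the stream has no producer channels. The reduction itself is not visible in this hunk, but it plausibly reduces each hash value modulo channelNum. A minimal, self-contained sketch of that bucketing (the function name and inputs are illustrative, not from the diff):

```go
// bucketByHash sketches the assumed bucketing: each uint32 row hash is
// reduced modulo the number of producer channels to get a channel index.
func bucketByHash(hashValuesPerMsg [][]uint32, channelNum uint32) [][]int32 {
	if channelNum == 0 {
		return nil // mirrors the guard in the hunk above
	}
	reBucketValues := make([][]int32, len(hashValuesPerMsg))
	for idx, hashValues := range hashValuesPerMsg {
		bucketValues := make([]int32, len(hashValues))
		for index, hashValue := range hashValues {
			bucketValues[index] = int32(hashValue % channelNum)
		}
		reBucketValues[idx] = bucketValues
	}
	return reBucketValues
}
```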
@@ -41,6 +41,7 @@ type MsgStream interface {
	AsConsumer(channels []string, subName string)
	SetRepackFunc(repackFunc RepackFunc)
	ComputeProduceChannelIndexes(tsMsgs []TsMsg) [][]int32
	GetProduceChannels() []string
	Produce(*MsgPack) error
	Broadcast(*MsgPack) error
	Consume() *MsgPack
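Together, the two produce-side accessors let a caller repack messages before producing, which is what the new _assignSegmentID in task.go below relies on. A hypothetical caller, sketched under the assumption that indexes[i] holds one channel index per row of message i:

```go
// groupByChannel is a hypothetical helper (not in the diff) combining the
// two interface methods: compute per-row channel indexes, then resolve them
// to channel names and group each message under every channel it touches.
func groupByChannel(stream msgstream.MsgStream, pack *msgstream.MsgPack) map[string][]msgstream.TsMsg {
	indexes := stream.ComputeProduceChannelIndexes(pack.Msgs)
	names := stream.GetProduceChannels()
	out := make(map[string][]msgstream.TsMsg)
	for i, msg := range pack.Msgs {
		seen := make(map[int32]bool) // avoid adding a message to the same channel twice
		for _, idx := range indexes[i] {
			if !seen[idx] {
				seen[idx] = true
				out[names[idx]] = append(out[names[idx]], msg)
			}
		}
	}
	return out
}
```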
@@ -76,6 +76,10 @@ func (ms *SimpleMsgStream) Broadcast(pack *MsgPack) error {
	return nil
}

func (ms *SimpleMsgStream) GetProduceChannels() []string {
	return nil
}

func (ms *SimpleMsgStream) Consume() *MsgPack {
	if ms.getMsgCount() <= 0 {
		return nil
@@ -1058,6 +1058,7 @@ func (node *ProxyNode) Insert(ctx context.Context, request *milvuspb.InsertRequest
			},
		},
		rowIDAllocator: node.idAllocator,
		segIDAssigner:  node.segAssigner,
	}
	if len(it.PartitionName) <= 0 {
		it.PartitionName = Params.DefaultPartitionName
@@ -61,10 +61,7 @@ func (m *insertChannelsMap) CreateInsertMsgStream(collID UniqueID, channels []string) error {
	stream, _ := m.msFactory.NewMsgStream(context.Background())
	stream.AsProducer(channels)
	log.Debug("proxynode", zap.Strings("proxynode AsProducer: ", channels))
	repack := func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
		return insertRepackFunc(tsMsgs, hashKeys, m.nodeInstance.segAssigner, true)
	}
	stream.SetRepackFunc(repack)
	stream.SetRepackFunc(insertRepackFunc)
	stream.Start()
	m.insertMsgStreams = append(m.insertMsgStreams, stream)
	m.droppedBitMap = append(m.droppedBitMap, 0)
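The removed closure existed only to bind segIDAssigner and the together flag into the shape SetRepackFunc expects; after this commit insertRepackFunc matches that shape directly (see the next file). For reference, the RepackFunc type as implied by the closure's signature above:

```go
// Signature implied by the removed closure; the type itself is declared
// in the msgstream package, so names here follow that assumption.
type RepackFunc func(msgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error)
```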
@@ -12,290 +12,23 @@
package proxynode

import (
	"errors"
	"sort"
	"unsafe"

	"go.uber.org/zap"

	"github.com/milvus-io/milvus/internal/log"
	"github.com/milvus-io/milvus/internal/msgstream"
	"github.com/milvus-io/milvus/internal/proto/commonpb"
	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/internal/util/typeutil"
)

func insertRepackFunc(tsMsgs []msgstream.TsMsg,
	hashKeys [][]int32,
	segIDAssigner *SegIDAssigner,
	together bool) (map[int32]*msgstream.MsgPack, error) {
	hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {

	result := make(map[int32]*msgstream.MsgPack)

	channelCountMap := make(map[UniqueID]map[int32]uint32)    // reqID --> channelID to count
	channelMaxTSMap := make(map[UniqueID]map[int32]Timestamp) // reqID --> channelID to max Timestamp
	reqSchemaMap := make(map[UniqueID][]UniqueID)             // reqID --> channelID [2]UniqueID {CollectionID, PartitionID}
	channelNamesMap := make(map[UniqueID][]string)            // collectionID --> channelNames

	for i, request := range tsMsgs {
		if request.Type() != commonpb.MsgType_Insert {
			return nil, errors.New("msg's must be Insert")
		}
		insertRequest, ok := request.(*msgstream.InsertMsg)
		if !ok {
			return nil, errors.New("msg's must be Insert")
		}

		keys := hashKeys[i]
		timestampLen := len(insertRequest.Timestamps)
		rowIDLen := len(insertRequest.RowIDs)
		rowDataLen := len(insertRequest.RowData)
		keysLen := len(keys)

		if keysLen != timestampLen || keysLen != rowIDLen || keysLen != rowDataLen {
			return nil, errors.New("the length of hashValue, timestamps, rowIDs, RowData are not equal")
		}

		reqID := insertRequest.Base.MsgID
		if _, ok := channelCountMap[reqID]; !ok {
			channelCountMap[reqID] = make(map[int32]uint32)
		}

		if _, ok := channelMaxTSMap[reqID]; !ok {
			channelMaxTSMap[reqID] = make(map[int32]Timestamp)
		}

		if _, ok := reqSchemaMap[reqID]; !ok {
			reqSchemaMap[reqID] = []UniqueID{insertRequest.CollectionID, insertRequest.PartitionID}
		}

		for idx, channelID := range keys {
			channelCountMap[reqID][channelID]++
			if _, ok := channelMaxTSMap[reqID][channelID]; !ok {
				channelMaxTSMap[reqID][channelID] = typeutil.ZeroTimestamp
			}
			ts := insertRequest.Timestamps[idx]
			if channelMaxTSMap[reqID][channelID] < ts {
				channelMaxTSMap[reqID][channelID] = ts
			}
		}

		collID := insertRequest.CollectionID
		if _, ok := channelNamesMap[collID]; !ok {
			channelNames, err := globalInsertChannelsMap.GetInsertChannels(collID)
			if err != nil {
				return nil, err
			}
			channelNamesMap[collID] = channelNames
		}
	}

	var getChannelName = func(collID UniqueID, channelID int32) string {
		if _, ok := channelNamesMap[collID]; !ok {
			return ""
		}
		names := channelNamesMap[collID]
		return names[channelID]
	}

	reqSegCountMap := make(map[UniqueID]map[int32]map[UniqueID]uint32)

	for reqID, countInfo := range channelCountMap {
		if _, ok := reqSegCountMap[reqID]; !ok {
			reqSegCountMap[reqID] = make(map[int32]map[UniqueID]uint32)
		}
		schema := reqSchemaMap[reqID]
		collID, partitionID := schema[0], schema[1]
		for channelID, count := range countInfo {
			ts, ok := channelMaxTSMap[reqID][channelID]
			if !ok {
				ts = typeutil.ZeroTimestamp
				log.Debug("Warning: did not get max Timstamp!")
			}
			channelName := getChannelName(collID, channelID)
			if channelName == "" {
				return nil, errors.New("ProxyNode, repack_func, can not found channelName")
			}
			mapInfo, err := segIDAssigner.GetSegmentID(collID, partitionID, channelName, count, ts)
			if err != nil {
				return nil, err
			}
			reqSegCountMap[reqID][channelID] = make(map[UniqueID]uint32)
			reqSegCountMap[reqID][channelID] = mapInfo
			log.Debug("proxynode", zap.Int64("repackFunc, reqSegCountMap, reqID", reqID), zap.Any("mapinfo", mapInfo))
		}
	}

	reqSegAccumulateCountMap := make(map[UniqueID]map[int32][]uint32)
	reqSegIDMap := make(map[UniqueID]map[int32][]UniqueID)
	reqSegAllocateCounter := make(map[UniqueID]map[int32]uint32)

	for reqID, channelInfo := range reqSegCountMap {
		if _, ok := reqSegAccumulateCountMap[reqID]; !ok {
			reqSegAccumulateCountMap[reqID] = make(map[int32][]uint32)
		}
		if _, ok := reqSegIDMap[reqID]; !ok {
			reqSegIDMap[reqID] = make(map[int32][]UniqueID)
		}
		if _, ok := reqSegAllocateCounter[reqID]; !ok {
			reqSegAllocateCounter[reqID] = make(map[int32]uint32)
		}
		for channelID, segInfo := range channelInfo {
			reqSegAllocateCounter[reqID][channelID] = 0
			keys := make([]UniqueID, len(segInfo))
			i := 0
			for key := range segInfo {
				keys[i] = key
				i++
			}
			sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] })
			accumulate := uint32(0)
			for _, key := range keys {
				accumulate += segInfo[key]
				if _, ok := reqSegAccumulateCountMap[reqID][channelID]; !ok {
					reqSegAccumulateCountMap[reqID][channelID] = make([]uint32, 0)
				}
				reqSegAccumulateCountMap[reqID][channelID] = append(
					reqSegAccumulateCountMap[reqID][channelID],
					accumulate,
				)
				if _, ok := reqSegIDMap[reqID][channelID]; !ok {
					reqSegIDMap[reqID][channelID] = make([]UniqueID, 0)
				}
				reqSegIDMap[reqID][channelID] = append(
					reqSegIDMap[reqID][channelID],
					key,
				)
			}
		}
	}

	var getSegmentID = func(reqID UniqueID, channelID int32) UniqueID {
		reqSegAllocateCounter[reqID][channelID]++
		cur := reqSegAllocateCounter[reqID][channelID]
		accumulateSlice := reqSegAccumulateCountMap[reqID][channelID]
		segIDSlice := reqSegIDMap[reqID][channelID]
		for index, count := range accumulateSlice {
			if cur <= count {
				return segIDSlice[index]
			}
		}
		log.Warn("Can't Found SegmentID")
		return 0
	}

	factor := 10
	threshold := Params.PulsarMaxMessageSize / factor
	log.Debug("proxynode", zap.Int("threshold of message size: ", threshold))
	// not accurate
	getSizeOfInsertMsg := func(msg *msgstream.InsertMsg) int {
		// if real struct, call unsafe.Sizeof directly,
		// if reference, dereference and then call unsafe.Sizeof,
		// if slice, todo: a common function to calculate size of slice,
		// if map, a little complicated
		size := 0
		size += int(unsafe.Sizeof(msg.Ctx))
		size += int(unsafe.Sizeof(msg.BeginTimestamp))
		size += int(unsafe.Sizeof(msg.EndTimestamp))
		size += int(unsafe.Sizeof(msg.HashValues))
		size += len(msg.HashValues) * 4
		size += int(unsafe.Sizeof(*msg.MsgPosition))
		size += int(unsafe.Sizeof(*msg.Base))
		size += int(unsafe.Sizeof(msg.DbName))
		size += int(unsafe.Sizeof(msg.CollectionName))
		size += int(unsafe.Sizeof(msg.PartitionName))
		size += int(unsafe.Sizeof(msg.DbID))
		size += int(unsafe.Sizeof(msg.CollectionID))
		size += int(unsafe.Sizeof(msg.PartitionID))
		size += int(unsafe.Sizeof(msg.SegmentID))
		size += int(unsafe.Sizeof(msg.ChannelID))
		size += int(unsafe.Sizeof(msg.Timestamps))
		size += int(unsafe.Sizeof(msg.RowIDs))
		size += len(msg.RowIDs) * 8
		for _, blob := range msg.RowData {
			size += int(unsafe.Sizeof(blob.Value))
			size += len(blob.Value)
		}

		//log.Debug("proxynode", zap.Int("insert message size", size))
		return size
	}
	// not accurate
	// getSizeOfMsgPack := func(mp *msgstream.MsgPack) int {
	// 	size := 0
	// 	for _, msg := range mp.Msgs {
	// 		insertMsg, ok := msg.(*msgstream.InsertMsg)
	// 		if !ok {
	// 			log.Panic("only insert message is supported!")
	// 		}
	// 		size += getSizeOfInsertMsg(insertMsg)
	// 	}
	// 	return size
	// }

	for i, request := range tsMsgs {
		insertRequest := request.(*msgstream.InsertMsg)
		keys := hashKeys[i]
		reqID := insertRequest.Base.MsgID
		collectionName := insertRequest.CollectionName
		collectionID := insertRequest.CollectionID
		partitionID := insertRequest.PartitionID
		partitionName := insertRequest.PartitionName
		proxyID := insertRequest.Base.SourceID
		channelNames := channelNamesMap[collectionID]
		for index, key := range keys {
			ts := insertRequest.Timestamps[index]
			rowID := insertRequest.RowIDs[index]
			row := insertRequest.RowData[index]
			if len(keys) > 0 {
				key := keys[0]
				_, ok := result[key]
				if !ok {
					msgPack := msgstream.MsgPack{}
					result[key] = &msgPack
				}
				segmentID := getSegmentID(reqID, key)
				channelID := channelNames[int(key)%len(channelNames)]
				sliceRequest := internalpb.InsertRequest{
					Base: &commonpb.MsgBase{
						MsgType:   commonpb.MsgType_Insert,
						MsgID:     reqID,
						Timestamp: ts,
						SourceID:  proxyID,
					},
					CollectionID:   collectionID,
					PartitionID:    partitionID,
					CollectionName: collectionName,
					PartitionName:  partitionName,
					SegmentID:      segmentID,
					// todo rename to ChannelName
					// ChannelID: strconv.FormatInt(int64(key), 10),
					ChannelID:  channelID,
					Timestamps: []uint64{ts},
					RowIDs:     []int64{rowID},
					RowData:    []*commonpb.Blob{row},
				}
				insertMsg := &msgstream.InsertMsg{
					BaseMsg: msgstream.BaseMsg{
						Ctx: request.TraceCtx(),
					},
					InsertRequest: sliceRequest,
				}
				if together { // all rows with same hash value are accumulated to only one message
					msgNums := len(result[key].Msgs)
					if len(result[key].Msgs) <= 0 {
						result[key].Msgs = append(result[key].Msgs, insertMsg)
					} else if getSizeOfInsertMsg(result[key].Msgs[msgNums-1].(*msgstream.InsertMsg)) >= threshold {
						result[key].Msgs = append(result[key].Msgs, insertMsg)
					} else {
						accMsgs, _ := result[key].Msgs[msgNums-1].(*msgstream.InsertMsg)
						accMsgs.Timestamps = append(accMsgs.Timestamps, ts)
						accMsgs.RowIDs = append(accMsgs.RowIDs, rowID)
						accMsgs.RowData = append(accMsgs.RowData, row)
					}
				} else { // every row is a message
					result[key].Msgs = append(result[key].Msgs, insertMsg)
				result[key] = &msgstream.MsgPack{}
				}
				result[key].Msgs = append(result[key].Msgs, request)
			}
		}
	}

	return result, nil
}
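This hunk shrinks insertRepackFunc from 290 lines to 23: segment assignment, per-row message splitting, and size-based accumulation all move out (into _assignSegmentID in task.go, below), and old and new bodies appear interleaved above. Pieced together from the new two-parameter signature and the short lines that remain (the result map, keys[0], the append of the original message), the simplified function plausibly reads as follows; exact blank lines and error text may differ from the committed source:

```go
// Sketch of the simplified repack assembled from the surviving new lines:
// group each whole message under its first hash key and leave segment
// assignment entirely to the caller.
func insertRepackFunc(tsMsgs []msgstream.TsMsg,
	hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
	result := make(map[int32]*msgstream.MsgPack)
	for i, request := range tsMsgs {
		if request.Type() != commonpb.MsgType_Insert {
			return nil, errors.New("msg's must be Insert")
		}
		keys := hashKeys[i]
		if len(keys) > 0 {
			key := keys[0] // the whole message rides on its first hash key
			if _, ok := result[key]; !ok {
				result[key] = &msgstream.MsgPack{}
			}
			result[key].Msgs = append(result[key].Msgs, request)
		}
	}
	return result, nil
}
```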
@@ -18,8 +18,10 @@ import (
	"math"
	"regexp"
	"runtime"
	"sort"
	"strconv"
	"time"
	"unsafe"

	"github.com/milvus-io/milvus/internal/proto/planpb"
@@ -99,6 +101,7 @@ type InsertTask struct {
	dataService    types.DataService
	result         *milvuspb.InsertResponse
	rowIDAllocator *allocator.IDAllocator
	segIDAssigner  *SegIDAssigner
}

func (it *InsertTask) TraceCtx() context.Context {
@@ -160,6 +163,211 @@ func (it *InsertTask) PreExecute(ctx context.Context) error {
	return nil
}

func (it *InsertTask) _assignSegmentID(stream msgstream.MsgStream, pack *msgstream.MsgPack) (*msgstream.MsgPack, error) {
	newPack := &msgstream.MsgPack{
		BeginTs:        pack.BeginTs,
		EndTs:          pack.EndTs,
		StartPositions: pack.StartPositions,
		EndPositions:   pack.EndPositions,
		Msgs:           nil,
	}
	tsMsgs := pack.Msgs
	hashKeys := stream.ComputeProduceChannelIndexes(tsMsgs)
	reqID := it.Base.MsgID
	channelCountMap := make(map[int32]uint32)    // channelID to count
	channelMaxTSMap := make(map[int32]Timestamp) // channelID to max Timestamp
	channelNames := stream.GetProduceChannels()
	log.Debug("_assignSemgentID, produceChannels:", zap.Any("Channels", channelNames))

	for i, request := range tsMsgs {
		if request.Type() != commonpb.MsgType_Insert {
			return nil, fmt.Errorf("msg's must be Insert")
		}
		insertRequest, ok := request.(*msgstream.InsertMsg)
		if !ok {
			return nil, fmt.Errorf("msg's must be Insert")
		}

		keys := hashKeys[i]
		timestampLen := len(insertRequest.Timestamps)
		rowIDLen := len(insertRequest.RowIDs)
		rowDataLen := len(insertRequest.RowData)
		keysLen := len(keys)

		if keysLen != timestampLen || keysLen != rowIDLen || keysLen != rowDataLen {
			return nil, fmt.Errorf("the length of hashValue, timestamps, rowIDs, RowData are not equal")
		}

		for idx, channelID := range keys {
			channelCountMap[channelID]++
			if _, ok := channelMaxTSMap[channelID]; !ok {
				channelMaxTSMap[channelID] = typeutil.ZeroTimestamp
			}
			ts := insertRequest.Timestamps[idx]
			if channelMaxTSMap[channelID] < ts {
				channelMaxTSMap[channelID] = ts
			}
		}
	}

	reqSegCountMap := make(map[int32]map[UniqueID]uint32)

	for channelID, count := range channelCountMap {
		ts, ok := channelMaxTSMap[channelID]
		if !ok {
			ts = typeutil.ZeroTimestamp
			log.Debug("Warning: did not get max Timestamp!")
		}
		channelName := channelNames[channelID]
		if channelName == "" {
			return nil, fmt.Errorf("ProxyNode, repack_func, can not found channelName")
		}
		mapInfo, err := it.segIDAssigner.GetSegmentID(it.CollectionID, it.PartitionID, channelName, count, ts)
		if err != nil {
			return nil, err
		}
		reqSegCountMap[channelID] = make(map[UniqueID]uint32)
		reqSegCountMap[channelID] = mapInfo
		log.Debug("ProxyNode", zap.Int64("repackFunc, reqSegCountMap, reqID", reqID), zap.Any("mapinfo", mapInfo))
	}

	reqSegAccumulateCountMap := make(map[int32][]uint32)
	reqSegIDMap := make(map[int32][]UniqueID)
	reqSegAllocateCounter := make(map[int32]uint32)

	for channelID, segInfo := range reqSegCountMap {
		reqSegAllocateCounter[channelID] = 0
		keys := make([]UniqueID, len(segInfo))
		i := 0
		for key := range segInfo {
			keys[i] = key
			i++
		}
		sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] })
		accumulate := uint32(0)
		for _, key := range keys {
			accumulate += segInfo[key]
			if _, ok := reqSegAccumulateCountMap[channelID]; !ok {
				reqSegAccumulateCountMap[channelID] = make([]uint32, 0)
			}
			reqSegAccumulateCountMap[channelID] = append(
				reqSegAccumulateCountMap[channelID],
				accumulate,
			)
			if _, ok := reqSegIDMap[channelID]; !ok {
				reqSegIDMap[channelID] = make([]UniqueID, 0)
			}
			reqSegIDMap[channelID] = append(
				reqSegIDMap[channelID],
				key,
			)
		}
	}

	var getSegmentID = func(channelID int32) UniqueID {
		reqSegAllocateCounter[channelID]++
		cur := reqSegAllocateCounter[channelID]
		accumulateSlice := reqSegAccumulateCountMap[channelID]
		segIDSlice := reqSegIDMap[channelID]
		for index, count := range accumulateSlice {
			if cur <= count {
				return segIDSlice[index]
			}
		}
		log.Warn("Can't Found SegmentID")
		return 0
	}

	factor := 10
	threshold := Params.PulsarMaxMessageSize / factor
	log.Debug("ProxyNode", zap.Int("threshold of message size: ", threshold))
	// not accurate
	getFixedSizeOfInsertMsg := func(msg *msgstream.InsertMsg) int {
		size := 0

		size += int(unsafe.Sizeof(*msg.Base))
		size += int(unsafe.Sizeof(msg.DbName))
		size += int(unsafe.Sizeof(msg.CollectionName))
		size += int(unsafe.Sizeof(msg.PartitionName))
		size += int(unsafe.Sizeof(msg.DbID))
		size += int(unsafe.Sizeof(msg.CollectionID))
		size += int(unsafe.Sizeof(msg.PartitionID))
		size += int(unsafe.Sizeof(msg.SegmentID))
		size += int(unsafe.Sizeof(msg.ChannelID))
		size += int(unsafe.Sizeof(msg.Timestamps))
		size += int(unsafe.Sizeof(msg.RowIDs))
		return size
	}

	result := make(map[int32]msgstream.TsMsg)
	curMsgSizeMap := make(map[int32]int)

	for i, request := range tsMsgs {
		insertRequest := request.(*msgstream.InsertMsg)
		keys := hashKeys[i]
		collectionName := insertRequest.CollectionName
		collectionID := insertRequest.CollectionID
		partitionID := insertRequest.PartitionID
		partitionName := insertRequest.PartitionName
		proxyID := insertRequest.Base.SourceID
		for index, key := range keys {
			ts := insertRequest.Timestamps[index]
			rowID := insertRequest.RowIDs[index]
			row := insertRequest.RowData[index]
			segmentID := getSegmentID(key)
			_, ok := result[key]
			if !ok {
				sliceRequest := internalpb.InsertRequest{
					Base: &commonpb.MsgBase{
						MsgType:   commonpb.MsgType_Insert,
						MsgID:     reqID,
						Timestamp: ts,
						SourceID:  proxyID,
					},
					CollectionID:   collectionID,
					PartitionID:    partitionID,
					CollectionName: collectionName,
					PartitionName:  partitionName,
					SegmentID:      segmentID,
					// todo rename to ChannelName
					ChannelID: channelNames[key],
				}
				insertMsg := &msgstream.InsertMsg{
					BaseMsg: msgstream.BaseMsg{
						Ctx: request.TraceCtx(),
					},
					InsertRequest: sliceRequest,
				}
				result[key] = insertMsg
				curMsgSizeMap[key] = getFixedSizeOfInsertMsg(insertMsg)
			}
			curMsg := result[key].(*msgstream.InsertMsg)
			curMsgSize := curMsgSizeMap[key]
			curMsg.HashValues = append(curMsg.HashValues, insertRequest.HashValues[index])
			curMsg.Timestamps = append(curMsg.Timestamps, ts)
			curMsg.RowIDs = append(curMsg.RowIDs, rowID)
			curMsg.RowData = append(curMsg.RowData, row)
			curMsgSize += 4 + 8 + int(unsafe.Sizeof(row.Value))
			curMsgSize += len(row.Value)

			if curMsgSize >= threshold {
				newPack.Msgs = append(newPack.Msgs, curMsg)
				delete(result, key)
				curMsgSize = 0
			}

			curMsgSizeMap[key] = curMsgSize
		}
	}
	for _, msg := range result {
		if msg != nil {
			newPack.Msgs = append(newPack.Msgs, msg)
		}
	}

	return newPack, nil
}

func (it *InsertTask) Execute(ctx context.Context) error {
	collectionName := it.BaseInsertTask.CollectionName
	collSchema, err := globalMetaCache.GetCollectionSchema(ctx, collectionName)
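The getSegmentID closure above allocates rows to segments by prefix sums: GetSegmentID returns a per-segment row budget, the budgets are accumulated in sorted segment order, and the nth allocation call lands in the first bucket whose running total covers n. A self-contained illustration of that technique, with made-up segment IDs and counts:

```go
package main

import (
	"fmt"
	"sort"
)

// Sketch of the prefix-sum allocation used by getSegmentID above.
// Segment IDs and row counts below are made-up example values.
func main() {
	segInfo := map[int64]uint32{101: 3, 102: 2} // segmentID -> row budget

	ids := make([]int64, 0, len(segInfo))
	for id := range segInfo {
		ids = append(ids, id)
	}
	sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] })

	var accumulate []uint32 // prefix sums over sorted IDs: [3, 5]
	sum := uint32(0)
	for _, id := range ids {
		sum += segInfo[id]
		accumulate = append(accumulate, sum)
	}

	counter := uint32(0)
	next := func() int64 {
		counter++ // the nth call lands in the first bucket whose prefix sum covers n
		for i, c := range accumulate {
			if counter <= c {
				return ids[i]
			}
		}
		return 0 // budget exhausted
	}

	for i := 0; i < 5; i++ {
		fmt.Println(next()) // prints 101, 101, 101, 102, 102
	}
}
```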
@@ -254,7 +462,14 @@ func (it *InsertTask) Execute(ctx context.Context) error {
		return err
	}

	err = stream.Produce(&msgPack)
	// Assign SegmentID
	var pack *msgstream.MsgPack
	pack, err = it._assignSegmentID(stream, &msgPack)
	if err != nil {
		return err
	}

	err = stream.Produce(pack)
	if err != nil {
		it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
		it.result.Status.Reason = err.Error()