enhance: Apply node-indexing and cache optimization for channel dist (#32595)

See also #32165

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/32673/head
congqixia 2024-04-28 16:19:24 +08:00 committed by GitHub
parent c7807afe71
commit a239e9110e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 132 additions and 38 deletions

View File

@ -20,31 +20,96 @@ import (
"sync" "sync"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/samber/lo"
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/datapb"
. "github.com/milvus-io/milvus/pkg/util/typeutil" "github.com/milvus-io/milvus/pkg/util/typeutil"
) )
type ChannelDistFilter = func(ch *DmChannel) bool type channelDistCriterion struct {
nodeIDs typeutil.Set[int64]
collectionID int64
channelName string
hasOtherFilter bool
}
type ChannelDistFilter interface {
Match(ch *DmChannel) bool
AddFilter(*channelDistCriterion)
}
type collChannelFilter int64
func (f collChannelFilter) Match(ch *DmChannel) bool {
return ch.GetCollectionID() == int64(f)
}
func (f collChannelFilter) AddFilter(criterion *channelDistCriterion) {
criterion.collectionID = int64(f)
}
func WithCollectionID2Channel(collectionID int64) ChannelDistFilter { func WithCollectionID2Channel(collectionID int64) ChannelDistFilter {
return func(ch *DmChannel) bool { return collChannelFilter(collectionID)
return ch.GetCollectionID() == collectionID }
type nodeChannelFilter int64
func (f nodeChannelFilter) Match(ch *DmChannel) bool {
return ch.Node == int64(f)
}
func (f nodeChannelFilter) AddFilter(criterion *channelDistCriterion) {
set := typeutil.NewSet(int64(f))
if criterion.nodeIDs == nil {
criterion.nodeIDs = set
} else {
criterion.nodeIDs = criterion.nodeIDs.Intersection(set)
} }
} }
func WithNodeID2Channel(nodeID int64) ChannelDistFilter { func WithNodeID2Channel(nodeID int64) ChannelDistFilter {
return func(ch *DmChannel) bool { return nodeChannelFilter(nodeID)
return ch.Node == nodeID }
type replicaChannelFilter struct {
*Replica
}
func (f replicaChannelFilter) Match(ch *DmChannel) bool {
return ch.GetCollectionID() == f.GetCollectionID() && f.Contains(ch.Node)
}
func (f replicaChannelFilter) AddFilter(criterion *channelDistCriterion) {
criterion.collectionID = f.GetCollectionID()
set := typeutil.NewSet(f.GetNodes()...)
if criterion.nodeIDs == nil {
criterion.nodeIDs = set
} else {
criterion.nodeIDs = criterion.nodeIDs.Intersection(set)
} }
} }
func WithReplica2Channel(replica *Replica) ChannelDistFilter { func WithReplica2Channel(replica *Replica) ChannelDistFilter {
return func(ch *DmChannel) bool { return &replicaChannelFilter{
return ch.GetCollectionID() == replica.GetCollectionID() && replica.Contains(ch.Node) Replica: replica,
} }
} }
type nameChannelFilter string
func (f nameChannelFilter) Match(ch *DmChannel) bool {
return ch.GetChannelName() == string(f)
}
func (f nameChannelFilter) AddFilter(criterion *channelDistCriterion) {
criterion.channelName = string(f)
}
func WithChannelName2Channel(channelName string) ChannelDistFilter {
return nameChannelFilter(channelName)
}
type DmChannel struct { type DmChannel struct {
*datapb.VchannelInfo *datapb.VchannelInfo
Node int64 Node int64
@ -65,11 +130,43 @@ func (channel *DmChannel) Clone() *DmChannel {
} }
} }
type nodeChannels struct {
channels []*DmChannel
// collection id => channels
collChannels map[int64][]*DmChannel
// channel name => DmChannel
nameChannel map[string]*DmChannel
}
func (c nodeChannels) Filter(critertion *channelDistCriterion) []*DmChannel {
var channels []*DmChannel
switch {
case critertion.channelName != "":
if ch, ok := c.nameChannel[critertion.channelName]; ok {
channels = []*DmChannel{ch}
}
case critertion.collectionID != 0:
channels = c.collChannels[critertion.collectionID]
default:
channels = c.channels
}
return channels // lo.Filter(channels, func(ch *DmChannel, _ int) bool { return mergedFilters(ch) })
}
func composeNodeChannels(channels ...*DmChannel) nodeChannels {
return nodeChannels{
channels: channels,
collChannels: lo.GroupBy(channels, func(ch *DmChannel) int64 { return ch.GetCollectionID() }),
nameChannel: lo.SliceToMap(channels, func(ch *DmChannel) (string, *DmChannel) { return ch.GetChannelName(), ch }),
}
}
type ChannelDistManager struct { type ChannelDistManager struct {
rwmutex sync.RWMutex rwmutex sync.RWMutex
// NodeID -> Channels // NodeID -> Channels
channels map[UniqueID][]*DmChannel channels map[typeutil.UniqueID]nodeChannels
// CollectionID -> Channels // CollectionID -> Channels
collectionIndex map[int64][]*DmChannel collectionIndex map[int64][]*DmChannel
@ -77,7 +174,7 @@ type ChannelDistManager struct {
func NewChannelDistManager() *ChannelDistManager { func NewChannelDistManager() *ChannelDistManager {
return &ChannelDistManager{ return &ChannelDistManager{
channels: make(map[UniqueID][]*DmChannel), channels: make(map[typeutil.UniqueID]nodeChannels),
collectionIndex: make(map[int64][]*DmChannel), collectionIndex: make(map[int64][]*DmChannel),
} }
} }
@ -91,10 +188,9 @@ func (m *ChannelDistManager) GetShardLeader(replica *Replica, shard string) (int
for _, node := range replica.GetNodes() { for _, node := range replica.GetNodes() {
channels := m.channels[node] channels := m.channels[node]
for _, dmc := range channels { _, ok := channels.nameChannel[shard]
if dmc.ChannelName == shard { if ok {
return node, true return node, true
}
} }
} }
@ -109,10 +205,8 @@ func (m *ChannelDistManager) GetShardLeadersByReplica(replica *Replica) map[stri
ret := make(map[string]int64) ret := make(map[string]int64)
for _, node := range replica.GetNodes() { for _, node := range replica.GetNodes() {
channels := m.channels[node] channels := m.channels[node]
for _, dmc := range channels { for _, dmc := range channels.collChannels[replica.GetCollectionID()] {
if dmc.GetCollectionID() == replica.GetCollectionID() { ret[dmc.GetChannelName()] = node
ret[dmc.GetChannelName()] = node
}
} }
} }
return ret return ret
@ -123,23 +217,23 @@ func (m *ChannelDistManager) GetByFilter(filters ...ChannelDistFilter) []*DmChan
m.rwmutex.RLock() m.rwmutex.RLock()
defer m.rwmutex.RUnlock() defer m.rwmutex.RUnlock()
mergedFilters := func(ch *DmChannel) bool { criterion := &channelDistCriterion{}
for _, fn := range filters { for _, filter := range filters {
if fn != nil && !fn(ch) { filter.AddFilter(criterion)
return false
}
}
return true
} }
ret := make([]*DmChannel, 0) var candidates []nodeChannels
for _, channels := range m.channels { if criterion.nodeIDs != nil {
for _, channel := range channels { candidates = lo.Map(criterion.nodeIDs.Collect(), func(nodeID int64, _ int) nodeChannels {
if mergedFilters(channel) { return m.channels[nodeID]
ret = append(ret, channel) })
} } else {
} candidates = lo.Values(m.channels)
}
var ret []*DmChannel
for _, candidate := range candidates {
ret = append(ret, candidate.Filter(criterion)...)
} }
return ret return ret
} }
@ -150,7 +244,7 @@ func (m *ChannelDistManager) GetByCollectionAndFilter(collectionID int64, filter
mergedFilters := func(ch *DmChannel) bool { mergedFilters := func(ch *DmChannel) bool {
for _, fn := range filters { for _, fn := range filters {
if fn != nil && !fn(ch) { if fn != nil && !fn.Match(ch) {
return false return false
} }
} }
@ -169,7 +263,7 @@ func (m *ChannelDistManager) GetByCollectionAndFilter(collectionID int64, filter
return ret return ret
} }
func (m *ChannelDistManager) Update(nodeID UniqueID, channels ...*DmChannel) { func (m *ChannelDistManager) Update(nodeID typeutil.UniqueID, channels ...*DmChannel) {
m.rwmutex.Lock() m.rwmutex.Lock()
defer m.rwmutex.Unlock() defer m.rwmutex.Unlock()
@ -177,7 +271,7 @@ func (m *ChannelDistManager) Update(nodeID UniqueID, channels ...*DmChannel) {
channel.Node = nodeID channel.Node = nodeID
} }
m.channels[nodeID] = channels m.channels[nodeID] = composeNodeChannels(channels...)
m.updateCollectionIndex() m.updateCollectionIndex()
} }
@ -186,7 +280,7 @@ func (m *ChannelDistManager) Update(nodeID UniqueID, channels ...*DmChannel) {
func (m *ChannelDistManager) updateCollectionIndex() { func (m *ChannelDistManager) updateCollectionIndex() {
m.collectionIndex = make(map[int64][]*DmChannel) m.collectionIndex = make(map[int64][]*DmChannel)
for _, nodeChannels := range m.channels { for _, nodeChannels := range m.channels {
for _, channel := range nodeChannels { for _, channel := range nodeChannels.channels {
collectionID := channel.GetCollectionID() collectionID := channel.GetCollectionID()
if channels, ok := m.collectionIndex[collectionID]; !ok { if channels, ok := m.collectionIndex[collectionID]; !ok {
m.collectionIndex[collectionID] = []*DmChannel{channel} m.collectionIndex[collectionID] = []*DmChannel{channel}

View File

@ -66,7 +66,7 @@ func (suite *ChannelDistManagerSuite) TestGetBy() {
dist := suite.dist dist := suite.dist
// Test GetAll // Test GetAll
channels := dist.GetByFilter(nil) channels := dist.GetByFilter()
suite.Len(channels, 4) suite.Len(channels, 4)
// Test GetByNode // Test GetByNode