milvus/internal/rootcoord/dml_channels.go

310 lines
7.6 KiB
Go

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"container/heap"
"context"
"fmt"
"sync"
"github.com/milvus-io/milvus/internal/metrics"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/mq/msgstream"
)
// dmlMsgStream wraps one DML producer msgstream together with the counters
// used by channelsHeap to pick the next channel to hand out.
// mutex guards refcnt and used; broadcast/broadcastMark also hold its read
// lock while sending on ms.
type dmlMsgStream struct {
ms msgstream.MsgStream
mutex sync.RWMutex
refcnt int64 // current in-use count, changed by IncRefcnt/DecRefCnt
used int64 // total times booked in the current run (BookUsage); not stored in meta, so only approximate
idx int64 // index used to build the channel name via genChannelName
pos int // current position in the heap slice, maintained by channelsHeap.Swap/Push
}
// RefCnt reads the current reference count (refcnt) while holding the
// read lock, so it is safe to call concurrently with IncRefcnt/DecRefCnt.
func (dms *dmlMsgStream) RefCnt() int64 {
	dms.mutex.RLock()
	count := dms.refcnt
	dms.mutex.RUnlock()
	return count
}
// Used returns how many times this stream has been booked via BookUsage in
// the current run. The value is read under the read lock.
func (dms *dmlMsgStream) Used() int64 {
	dms.mutex.RLock()
	defer dms.mutex.RUnlock()
	return dms.used
}
// IncRefcnt adds one to the reference count under the write lock.
func (dms *dmlMsgStream) IncRefcnt() {
	dms.mutex.Lock()
	dms.refcnt = dms.refcnt + 1
	dms.mutex.Unlock()
}
// BookUsage records one more reservation of this stream by bumping the
// used counter under the write lock. It does not touch refcnt.
func (dms *dmlMsgStream) BookUsage() {
	dms.mutex.Lock()
	dms.used = dms.used + 1
	dms.mutex.Unlock()
}
// DecRefCnt subtracts one from the reference count (used is left alone).
// The count never drops below zero: if it is already zero, a warning is
// logged and nothing changes.
func (dms *dmlMsgStream) DecRefCnt() {
	dms.mutex.Lock()
	defer dms.mutex.Unlock()
	if dms.refcnt <= 0 {
		log.Warn("Try to remove channel with no ref count", zap.Int64("idx", dms.idx))
		return
	}
	dms.refcnt--
}
// channelsHeap implements heap.Interface so that it behaves like a priority
// queue of dmlMsgStreams: the element ordered first by Less (fewest refs,
// then fewest bookings, then smallest idx) is kept at index 0.
type channelsHeap []*dmlMsgStream
// Len implements heap.Interface and reports how many streams the heap holds.
func (h channelsHeap) Len() int {
	size := len(h)
	return size
}
// Less implements heap.Interface. Streams are ordered by:
//  1. lower reference count first,
//  2. then fewer bookings (used) first,
//  3. then the smaller idx, giving a deterministic tie-break.
func (h channelsHeap) Less(i int, j int) bool {
	a, b := h[i], h[j]
	if refA, refB := a.RefCnt(), b.RefCnt(); refA != refB {
		return refA < refB
	}
	if usedA, usedB := a.Used(), b.Used(); usedA != usedB {
		return usedA < usedB
	}
	return a.idx < b.idx
}
// Swap implements heap.Interface. It exchanges the streams at i and j and
// records their new slice positions in pos, so callers can later pass
// dms.pos to heap.Fix.
func (h channelsHeap) Swap(i int, j int) {
	first, second := h[i], h[j]
	h[i], h[j] = second, first
	second.pos = i
	first.pos = j
}
// Push implements heap.Interface: it appends x, which must be a
// *dmlMsgStream, to the end of the slice. Call it through heap.Push, which
// then restores the heap ordering.
//
// pos is set to the new tail index here because heap.Push only updates pos
// (via Swap) when the item actually moves up. Without this, an item that
// stays at the tail would keep a stale pos left over from an earlier Pop,
// and a later heap.Fix(&h, dms.pos) would re-sift the wrong element.
func (h *channelsHeap) Push(x interface{}) {
	item := x.(*dmlMsgStream)
	item.pos = len(*h)
	*h = append(*h, item)
}
// Pop implements heap.Interface by removing and returning the last element.
// heap.Pop moves the minimum to the tail before calling this. The vacated
// slot is set to nil so the backing array no longer references the stream.
func (h *channelsHeap) Pop() interface{} {
	last := len(*h) - 1
	item := (*h)[last]
	(*h)[last] = nil
	*h = (*h)[:last]
	return item
}
// dmlChannels owns the DML producer msgstreams created in newDmlChannels,
// named "<namePrefix>_<idx>". getChannelNames hands out the least-loaded
// ones according to channelsHeap ordering.
type dmlChannels struct {
ctx context.Context
factory msgstream.Factory
namePrefix string
capacity int64 // number of channels created by newDmlChannels
// pool maintains channelName => dmlMsgStream mapping, stable
pool sync.Map
// mut protects channelsHeap only
mut sync.Mutex
// channelsHeap is the heap to pop next dms for use
channelsHeap channelsHeap
}
// newDmlChannels creates chanNum producer msgstreams named
// "<chanNamePrefix>_<i>" for i in [0, chanNum), registers each one in the
// pool, and builds the selection heap over them. It panics if the factory
// fails to create any msgstream.
func newDmlChannels(ctx context.Context, factory msgstream.Factory, chanNamePrefix string, chanNum int64) *dmlChannels {
	d := &dmlChannels{
		ctx:          ctx,
		factory:      factory,
		namePrefix:   chanNamePrefix,
		capacity:     chanNum,
		channelsHeap: make([]*dmlMsgStream, 0, chanNum),
	}

	for idx := int64(0); idx < chanNum; idx++ {
		chanName := genChannelName(d.namePrefix, idx)
		stream, err := factory.NewMsgStream(ctx)
		if err != nil {
			log.Error("Failed to add msgstream", zap.String("name", chanName), zap.Error(err))
			panic("Failed to add msgstream")
		}
		stream.AsProducer([]string{chanName})

		// refcnt and used start at their zero values.
		entry := &dmlMsgStream{ms: stream, idx: idx, pos: int(idx)}
		d.pool.Store(chanName, entry)
		d.channelsHeap = append(d.channelsHeap, entry)
	}

	heap.Init(&d.channelsHeap)
	log.Debug("init dml channels", zap.Int64("num", chanNum))

	total := float64(chanNum)
	metrics.RootCoordNumOfDMLChannel.Add(total)
	metrics.RootCoordNumOfMsgStream.Add(total)
	return d
}
// getChannelNames selects the count highest-priority channels from the heap
// (see channelsHeap.Less), books one usage on each, and returns their names
// in selection order. It returns nil when more channels are requested than
// exist.
func (d *dmlChannels) getChannelNames(count int) []string {
	d.mut.Lock()
	defer d.mut.Unlock()

	if len(d.channelsHeap) < count {
		return nil
	}

	// Pop every pick first so the same channel cannot be chosen twice, then
	// push them all back once their usage has been bumped.
	picked := make([]*dmlMsgStream, 0, count)
	names := make([]string, 0, count)
	for len(picked) < count {
		dms := heap.Pop(&d.channelsHeap).(*dmlMsgStream)
		dms.BookUsage()
		picked = append(picked, dms)
		names = append(names, genChannelName(d.namePrefix, dms.idx))
	}
	for _, dms := range picked {
		heap.Push(&d.channelsHeap, dms)
	}
	return names
}
// listChannels returns the names of every pooled channel whose reference
// count is positive. The order follows sync.Map.Range iteration, which is
// unspecified. It returns nil when no channel is referenced.
func (d *dmlChannels) listChannels() []string {
	var names []string
	collect := func(_, value interface{}) bool {
		if dms := value.(*dmlMsgStream); dms.RefCnt() > 0 {
			names = append(names, genChannelName(d.namePrefix, dms.idx))
		}
		return true
	}
	d.pool.Range(collect)
	return names
}
// getChannelNum reports how many channels currently have a positive
// reference count, i.e. the length of listChannels.
func (d *dmlChannels) getChannelNum() int {
	active := d.listChannels()
	return len(active)
}
// broadcast sends pack on every named channel that has a positive reference
// count and silently skips unreferenced ones. It stops at and returns the
// first Broadcast error. An unknown channel name panics.
func (d *dmlChannels) broadcast(chanNames []string, pack *msgstream.MsgPack) error {
	for _, chanName := range chanNames {
		val, found := d.pool.Load(chanName)
		if !found {
			log.Error("invalid channel name", zap.String("chanName", chanName))
			panic("invalid channel name: " + chanName)
		}
		dms := val.(*dmlMsgStream)

		// The read lock blocks IncRefcnt/DecRefCnt (which take the write lock)
		// for the duration of the send.
		var sendErr error
		dms.mutex.RLock()
		if dms.refcnt > 0 {
			sendErr = dms.ms.Broadcast(pack)
		}
		dms.mutex.RUnlock()

		if sendErr != nil {
			log.Error("Broadcast failed", zap.Error(sendErr), zap.String("chanName", chanName))
			return sendErr
		}
	}
	return nil
}
// broadcastMark behaves like broadcast but also collects, for each channel
// name reported by BroadcastMark, the serialized message ID. If an ID list
// holds several entries, the last one wins. On error it returns the
// positions gathered so far together with the error. An unknown channel
// name panics.
func (d *dmlChannels) broadcastMark(chanNames []string, pack *msgstream.MsgPack) (map[string][]byte, error) {
	positions := make(map[string][]byte)
	for _, chanName := range chanNames {
		val, found := d.pool.Load(chanName)
		if !found {
			log.Error("invalid channel name", zap.String("chanName", chanName))
			panic("invalid channel name: " + chanName)
		}
		dms := val.(*dmlMsgStream)

		dms.mutex.RLock()
		if dms.refcnt <= 0 {
			dms.mutex.RUnlock()
			continue
		}
		ids, err := dms.ms.BroadcastMark(pack)
		if err != nil {
			log.Error("BroadcastMark failed", zap.Error(err), zap.String("chanName", chanName))
			dms.mutex.RUnlock()
			return positions, err
		}
		// Each idList is expected to hold a single ID; flatten by iterating.
		for name, idList := range ids {
			for _, id := range idList {
				positions[name] = id.Serialize()
			}
		}
		dms.mutex.RUnlock()
	}
	return positions, nil
}
// addChannels increments the reference count of each named channel and
// re-sifts it in the heap, holding d.mut around each update. An unknown
// channel name panics.
func (d *dmlChannels) addChannels(names ...string) {
	lookup := func(chanName string) *dmlMsgStream {
		val, found := d.pool.Load(chanName)
		if !found {
			log.Error("invalid channel name", zap.String("chanName", chanName))
			panic("invalid channel name: " + chanName)
		}
		return val.(*dmlMsgStream)
	}

	for _, chanName := range names {
		dms := lookup(chanName)
		d.mut.Lock()
		dms.IncRefcnt()
		heap.Fix(&d.channelsHeap, dms.pos)
		d.mut.Unlock()
	}
}
// removeChannels decrements the reference count of each named channel
// (never below zero) and re-sifts it in the heap, holding d.mut around each
// update. An unknown channel name panics.
func (d *dmlChannels) removeChannels(names ...string) {
	lookup := func(chanName string) *dmlMsgStream {
		val, found := d.pool.Load(chanName)
		if !found {
			log.Error("invalid channel name", zap.String("chanName", chanName))
			panic("invalid channel name: " + chanName)
		}
		return val.(*dmlMsgStream)
	}

	for _, chanName := range names {
		dms := lookup(chanName)
		d.mut.Lock()
		dms.DecRefCnt()
		heap.Fix(&d.channelsHeap, dms.pos)
		d.mut.Unlock()
	}
}
// genChannelName builds a DML channel name of the form "<prefix>_<idx>",
// e.g. genChannelName("dml", 3) returns "dml_3".
func genChannelName(prefix string, idx int64) string {
	return prefix + "_" + fmt.Sprint(idx)
}