Fix flushManager.isFull is too slow (#28141)

Signed-off-by: longjiquan <jiquan.long@zilliz.com>
pull/28163/head
Jiquan Long 2023-11-03 14:42:17 +08:00 committed by GitHub
parent 755a592b08
commit a21042dde7
7 changed files with 137 additions and 14 deletions

@@ -59,7 +59,7 @@ type Channel interface {
getCollectionID() UniqueID
getCollectionSchema(collectionID UniqueID, ts Timestamp) (*schemapb.CollectionSchema, error)
getCollectionAndPartitionID(segID UniqueID) (collID, partitionID UniqueID, err error)
getChannelName(segID UniqueID) string
getChannelName() string
addSegment(ctx context.Context, req addSegmentReq) error
getSegment(segID UniqueID) *Segment
@@ -204,7 +204,7 @@ func (c *ChannelMeta) getCollectionAndPartitionID(segID UniqueID) (collID, parti
return 0, 0, fmt.Errorf("cannot find segment, id = %d", segID)
}
func (c *ChannelMeta) getChannelName(segID UniqueID) string {
func (c *ChannelMeta) getChannelName() string {
return c.channelName
}

@@ -775,7 +775,7 @@ func TestGetChannelWithTickler(t *testing.T) {
channel, err := getChannelWithTickler(context.TODO(), node, info, newTickler(), unflushed, flushed)
assert.NoError(t, err)
assert.NotNil(t, channel)
assert.Equal(t, channelName, channel.getChannelName(100))
assert.Equal(t, channelName, channel.getChannelName())
assert.Equal(t, int64(1), channel.getCollectionID())
assert.True(t, channel.hasSegment(100, true))
assert.True(t, channel.hasSegment(101, true))

@@ -95,6 +95,7 @@ var _ flushManager = (*rendezvousFlushManager)(nil)
type orderFlushQueue struct {
sync.Once
segmentID UniqueID
channel string
injectCh chan *taskInjection
// MsgID => flushTask
@@ -110,9 +111,10 @@ type orderFlushQueue struct {
}
// newOrderFlushQueue creates an orderFlushQueue
func newOrderFlushQueue(segID UniqueID, f notifyMetaFunc) *orderFlushQueue {
func newOrderFlushQueue(segID UniqueID, channel string, f notifyMetaFunc) *orderFlushQueue {
q := &orderFlushQueue{
segmentID: segID,
channel: channel,
notifyFunc: f,
injectCh: make(chan *taskInjection, 100),
working: typeutil.NewConcurrentMap[string, *flushTaskRunner](),
@@ -133,6 +135,7 @@ func (q *orderFlushQueue) getFlushTaskRunner(pos *msgpb.MsgPosition) *flushTaskR
t, loaded := q.working.GetOrInsert(getSyncTaskID(pos), newFlushTaskRunner(q.segmentID, q.injectCh))
// not loaded means the task runner is new, do initialization
if !loaded {
getOrCreateFlushTaskCounter().increase(q.channel)
// take over injection if task queue is handling it
q.injectMut.Lock()
q.runningTasks++
@@ -154,6 +157,7 @@ func (q *orderFlushQueue) getFlushTaskRunner(pos *msgpb.MsgPosition) *flushTaskR
func (q *orderFlushQueue) postTask(pack *segmentFlushPack, postInjection postInjectionFunc) {
// delete task from working map
q.working.GetAndRemove(getSyncTaskID(pack.pos))
getOrCreateFlushTaskCounter().decrease(q.channel)
// after decreasing the working count, check whether the flush queue is empty
q.injectMut.Lock()
q.runningTasks--
@@ -281,7 +285,7 @@ type rendezvousFlushManager struct {
// getFlushQueue gets or creates an orderFlushQueue for segment id if not found
func (m *rendezvousFlushManager) getFlushQueue(segmentID UniqueID) *orderFlushQueue {
newQueue := newOrderFlushQueue(segmentID, m.notifyFunc)
newQueue := newOrderFlushQueue(segmentID, m.getChannelName(), m.notifyFunc)
queue, _ := m.dispatcher.GetOrInsert(segmentID, newQueue)
queue.init()
return queue
@@ -420,12 +424,8 @@ func (m *rendezvousFlushManager) serializePkStatsLog(segmentID int64, flushed bo
// isFull returns true if the task pool is full
func (m *rendezvousFlushManager) isFull() bool {
var num int
m.dispatcher.Range(func(_ int64, queue *orderFlushQueue) bool {
num += queue.working.Len()
return true
})
return num >= Params.DataNodeCfg.MaxParallelSyncTaskNum.GetAsInt()
return getOrCreateFlushTaskCounter().getOrZero(m.getChannelName()) >=
int32(Params.DataNodeCfg.MaxParallelSyncTaskNum.GetAsInt())
}
// flushBufferData notifies the flush manager of insert buffer data.
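The hunk above is the heart of the fix: instead of walking every per-segment queue in the dispatcher on each isFull call (cost grows with the number of segments), the manager now reads a per-channel counter that getFlushTaskRunner and postTask keep up to date, so the check is a single map lookup plus an atomic load. The sketch below contrasts the two shapes using only standard-library types; names such as isFullByScan, isFullByCounter and maxParallel are illustrative stand-ins, not the Milvus API.

package main

import (
	"fmt"
	"sync/atomic"
)

const maxParallel = 6 // stand-in for Params.DataNodeCfg.MaxParallelSyncTaskNum

// Old shape: derive the number of in-flight tasks by scanning every queue.
func isFullByScan(working map[int64][]string) bool {
	num := 0
	for _, tasks := range working {
		num += len(tasks) // one visit per segment queue, every time isFull is called
	}
	return num >= maxParallel
}

// New shape: maintain the count as tasks start and finish,
// then the check is a single atomic load.
var inFlight atomic.Int32

func isFullByCounter() bool {
	return inFlight.Load() >= int32(maxParallel)
}

func main() {
	working := map[int64][]string{100: {"t1", "t2"}, 101: {"t3"}}
	fmt.Println(isFullByScan(working)) // false: 3 < 6

	inFlight.Add(3)                // three flush tasks created
	fmt.Println(isFullByCounter()) // false: 3 < 6
	inFlight.Add(3)                // three more created, none finished yet
	fmt.Println(isFullByCounter()) // true: 6 >= 6
}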

@@ -69,7 +69,7 @@ func TestOrderFlushQueue_Execute(t *testing.T) {
size := 1000
finish.Add(size)
q := newOrderFlushQueue(1, func(*segmentFlushPack) {
q := newOrderFlushQueue(1, "", func(*segmentFlushPack) {
counter.Inc()
finish.Done()
})
@@ -111,7 +111,7 @@ func TestOrderFlushQueue_Order(t *testing.T) {
size := 1000
finish.Add(size)
resultList := make([][]byte, 0, size)
q := newOrderFlushQueue(1, func(pack *segmentFlushPack) {
q := newOrderFlushQueue(1, "", func(pack *segmentFlushPack) {
counter.Inc()
resultList = append(resultList, pack.pos.MsgID)
finish.Done()

@@ -0,0 +1,79 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package datanode
import (
"sync"
"go.uber.org/atomic"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type flushTaskCounter struct {
inner *typeutil.ConcurrentMap[string, *atomic.Int32] // channel -> counter
}
func (c *flushTaskCounter) getOrZero(channel string) int32 {
counter, exist := c.inner.Get(channel)
if !exist {
return 0
}
return counter.Load()
}
func (c *flushTaskCounter) increaseImpl(channel string, delta int32) {
counter, _ := c.inner.GetOrInsert(channel, atomic.NewInt32(0))
counter.Add(delta)
}
func (c *flushTaskCounter) increase(channel string) {
c.increaseImpl(channel, 1)
}
func (c *flushTaskCounter) decrease(channel string) {
c.increaseImpl(channel, -1)
}
func (c *flushTaskCounter) close() {
allChannels := make([]string, 0, c.inner.Len())
c.inner.Range(func(channel string, _ *atomic.Int32) bool {
allChannels = append(allChannels, channel)
return true
})
for _, channel := range allChannels {
c.inner.Remove(channel)
}
}
func newFlushTaskCounter() *flushTaskCounter {
return &flushTaskCounter{
inner: typeutil.NewConcurrentMap[string, *atomic.Int32](),
}
}
var (
globalFlushTaskCounter *flushTaskCounter
flushTaskCounterOnce sync.Once
)
func getOrCreateFlushTaskCounter() *flushTaskCounter {
flushTaskCounterOnce.Do(func() {
globalFlushTaskCounter = newFlushTaskCounter()
})
return globalFlushTaskCounter
}
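The new file keeps one *atomic.Int32 per channel inside a ConcurrentMap, created lazily behind a sync.Once singleton, so increase/decrease from the flush queues and getOrZero from isFull never need a shared lock. Below is a standard-library-only sketch of the same pattern, showing that concurrent updates stay consistent and that an unseen channel reads as zero; the channelCounter type and the channel names in it are illustrative, not part of the commit.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// channelCounter mirrors the flushTaskCounter idea: channel name -> *atomic.Int32.
type channelCounter struct {
	inner sync.Map // string -> *atomic.Int32
}

func (c *channelCounter) add(channel string, delta int32) {
	v, _ := c.inner.LoadOrStore(channel, new(atomic.Int32))
	v.(*atomic.Int32).Add(delta)
}

func (c *channelCounter) getOrZero(channel string) int32 {
	v, ok := c.inner.Load(channel)
	if !ok {
		return 0
	}
	return v.(*atomic.Int32).Load()
}

func main() {
	c := &channelCounter{}
	var wg sync.WaitGroup
	for i := 0; i < 1000; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			c.add("dml-channel-0", 1)  // a flush task runner was created
			c.add("dml-channel-0", -1) // its flush pack was posted
		}()
	}
	wg.Wait()
	fmt.Println(c.getOrZero("dml-channel-0")) // 0: every task finished
	fmt.Println(c.getOrZero("unknown"))       // 0: channel never seen
}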

@@ -0,0 +1,44 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package datanode
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_flushTaskCounter_getOrZero(t *testing.T) {
c := newFlushTaskCounter()
defer c.close()
assert.Zero(t, c.getOrZero("non-exist"))
n := 10
channel := "channel"
assert.Zero(t, c.getOrZero(channel))
for i := 0; i < n; i++ {
c.increase(channel)
}
assert.Equal(t, int32(n), c.getOrZero(channel))
for i := 0; i < n; i++ {
c.decrease(channel)
}
assert.Zero(t, c.getOrZero(channel))
}
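The committed test drives the counter from a single goroutine. A companion test like the sketch below (not part of this commit; it assumes a "sync" import is added next to the existing ones) would also cover the concurrent path that the flush queues actually take, with increase called from many goroutines at once.

func Test_flushTaskCounter_concurrent(t *testing.T) {
	c := newFlushTaskCounter()
	defer c.close()

	const workers = 32
	const perWorker = 100
	channel := "channel"

	var wg sync.WaitGroup
	for i := 0; i < workers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := 0; j < perWorker; j++ {
				c.increase(channel)
			}
		}()
	}
	wg.Wait()

	assert.Equal(t, int32(workers*perWorker), c.getOrZero(channel))
}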

@@ -361,7 +361,7 @@ func (node *DataNode) SyncSegments(ctx context.Context, req *datapb.SyncSegments
log.Ctx(ctx).Warn("fail to get the channel", zap.Int64("segment", fromSegment), zap.Error(err))
continue
}
ds, ok = node.flowgraphManager.getFlowgraphService(channel.getChannelName(fromSegment))
ds, ok = node.flowgraphManager.getFlowgraphService(channel.getChannelName())
if !ok {
log.Ctx(ctx).Warn("fail to find flow graph service", zap.Int64("segment", fromSegment))
continue