fix: Fix data race for clustering compaction (#36440)

issue: #36438

Signed-off-by: Cai Zhang <cai.zhang@zilliz.com>
cai.zhang 2024-09-28 17:19:21 +08:00 committed by GitHub
parent 31353ae406
commit 2adca8b754
1 changed file with 22 additions and 21 deletions
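The race: goroutines submitted to the flush pool read each ClusterBuffer's writer pointer while refreshBufferWriter/refreshBufferWriterWithPack replace it, so the bare *SegmentWriter field is read and written concurrently without synchronization. The diff below fixes this by storing the pointer in an atomic.Value, so every access goes through Load/Store. A minimal, self-contained sketch of the same pattern (not the Milvus code; SegmentWriter here is a stand-in type with only a segmentID field):

	package main

	import (
		"fmt"
		"sync"
		"sync/atomic"
	)

	// Stand-in for the real SegmentWriter.
	type SegmentWriter struct{ segmentID int64 }

	// As in the diff: the writer pointer lives in an atomic.Value,
	// so concurrent Load/Store need no extra locking.
	type ClusterBuffer struct {
		writer atomic.Value // always holds a *SegmentWriter
	}

	func main() {
		buf := &ClusterBuffer{}
		buf.writer.Store(&SegmentWriter{segmentID: 1})

		var wg sync.WaitGroup
		wg.Add(2)
		go func() { // refresher: keeps swapping in fresh writers
			defer wg.Done()
			for i := int64(2); i <= 1000; i++ {
				buf.writer.Store(&SegmentWriter{segmentID: i})
			}
		}()
		go func() { // reader: loads the current writer concurrently
			defer wg.Done()
			for i := 0; i < 1000; i++ {
				w := buf.writer.Load().(*SegmentWriter)
				_ = w.segmentID
			}
		}()
		wg.Wait()
		fmt.Println("final segment:", buf.writer.Load().(*SegmentWriter).segmentID)
	}

Run with `go run -race`: with atomic.Value the reader/refresher pair is race-free, while the same program with a plain pointer field is flagged by the race detector.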

@@ -106,7 +106,7 @@ type clusteringCompactionTask struct {
 type ClusterBuffer struct {
     id int

-    writer *SegmentWriter
+    writer atomic.Value

     flushLock lock.RWMutex
     bufferMemorySize atomic.Int64
@@ -464,7 +464,7 @@ func (t *clusteringCompactionTask) getBufferTotalUsedMemorySize() int64 {
     var totalBufferSize int64 = 0
     for _, buffer := range t.clusterBuffers {
         t.clusterBufferLocks.Lock(buffer.id)
-        totalBufferSize = totalBufferSize + int64(buffer.writer.WrittenMemorySize()) + buffer.bufferMemorySize.Load()
+        totalBufferSize = totalBufferSize + int64(buffer.writer.Load().(*SegmentWriter).WrittenMemorySize()) + buffer.bufferMemorySize.Load()
         t.clusterBufferLocks.Unlock(buffer.id)
     }
     return totalBufferSize
@@ -599,14 +599,14 @@ func (t *clusteringCompactionTask) mappingSegment(
             if (remained+1)%100 == 0 {
                 currentBufferTotalMemorySize := t.getBufferTotalUsedMemorySize()
-                if clusterBuffer.currentSegmentRowNum.Load() > t.plan.GetMaxSegmentRows() || clusterBuffer.writer.IsFull() {
+                if clusterBuffer.currentSegmentRowNum.Load() > t.plan.GetMaxSegmentRows() || clusterBuffer.writer.Load().(*SegmentWriter).IsFull() {
                     // reach segment/binlog max size
                     flushWriterFunc := func() {
                         t.clusterBufferLocks.Lock(clusterBuffer.id)
                         currentSegmentNumRows := clusterBuffer.currentSegmentRowNum.Load()
                         // double-check the condition is still met
-                        if currentSegmentNumRows > t.plan.GetMaxSegmentRows() || clusterBuffer.writer.IsFull() {
-                            writer := clusterBuffer.writer
+                        writer := clusterBuffer.writer.Load().(*SegmentWriter)
+                        if currentSegmentNumRows > t.plan.GetMaxSegmentRows() || writer.IsFull() {
                             pack, _ := t.refreshBufferWriterWithPack(clusterBuffer)
                             log.Debug("buffer need to flush", zap.Int("bufferID", clusterBuffer.id),
                                 zap.Bool("pack", pack),
@@ -677,11 +677,12 @@ func (t *clusteringCompactionTask) writeToBuffer(ctx context.Context, clusterBuf
     t.clusterBufferLocks.Lock(clusterBuffer.id)
     defer t.clusterBufferLocks.Unlock(clusterBuffer.id)
     // prepare
-    if clusterBuffer.writer == nil {
+    writer := clusterBuffer.writer.Load()
+    if writer == nil || writer.(*SegmentWriter) == nil {
         log.Warn("unexpected behavior, please check", zap.Int("buffer id", clusterBuffer.id))
         return fmt.Errorf("unexpected behavior, please check buffer id: %d", clusterBuffer.id)
     }
-    err := clusterBuffer.writer.Write(value)
+    err := writer.(*SegmentWriter).Write(value)
     if err != nil {
         return err
     }
@@ -764,7 +765,7 @@ func (t *clusteringCompactionTask) flushLargestBuffers(ctx context.Context) erro
     for _, buffer := range t.clusterBuffers {
         bufferIDs = append(bufferIDs, buffer.id)
         t.clusterBufferLocks.RLock(buffer.id)
-        bufferRowNums = append(bufferRowNums, buffer.writer.GetRowNum())
+        bufferRowNums = append(bufferRowNums, buffer.writer.Load().(*SegmentWriter).GetRowNum())
         t.clusterBufferLocks.RUnlock(buffer.id)
     }
     sort.Slice(bufferIDs, func(i, j int) bool {
@@ -777,7 +778,7 @@ func (t *clusteringCompactionTask) flushLargestBuffers(ctx context.Context) erro
         t.clusterBufferLocks.Lock(bufferId)
         buffer := t.clusterBuffers[bufferId]
         writer := buffer.writer
-        currentMemorySize -= int64(writer.WrittenMemorySize())
+        currentMemorySize -= int64(writer.Load().(*SegmentWriter).WrittenMemorySize())
         if err := t.refreshBufferWriter(buffer); err != nil {
             t.clusterBufferLocks.Unlock(bufferId)
             return err
@@ -787,10 +788,10 @@
         log.Info("currentMemorySize after flush buffer binlog",
             zap.Int64("currentMemorySize", currentMemorySize),
             zap.Int("bufferID", bufferId),
-            zap.Uint64("WrittenMemorySize()", writer.WrittenMemorySize()),
-            zap.Int64("RowNum", writer.GetRowNum()))
+            zap.Uint64("WrittenMemorySize()", writer.Load().(*SegmentWriter).WrittenMemorySize()),
+            zap.Int64("RowNum", writer.Load().(*SegmentWriter).GetRowNum()))
         future := t.flushPool.Submit(func() (any, error) {
-            err := t.flushBinlog(ctx, buffer, writer, false)
+            err := t.flushBinlog(ctx, buffer, writer.Load().(*SegmentWriter), false)
             if err != nil {
                 return nil, err
             }
@@ -819,7 +820,7 @@ func (t *clusteringCompactionTask) flushAll(ctx context.Context) error {
     for _, buffer := range t.clusterBuffers {
         buffer := buffer
         future := t.flushPool.Submit(func() (any, error) {
-            err := t.flushBinlog(ctx, buffer, buffer.writer, true)
+            err := t.flushBinlog(ctx, buffer, buffer.writer.Load().(*SegmentWriter), true)
             if err != nil {
                 return nil, err
             }
@@ -1201,11 +1202,11 @@ func (t *clusteringCompactionTask) refreshBufferWriterWithPack(buffer *ClusterBu
     var segmentID int64
     var err error
     var pack bool
-    if buffer.writer != nil {
-        segmentID = buffer.writer.GetSegmentID()
-        buffer.bufferMemorySize.Add(int64(buffer.writer.WrittenMemorySize()))
+    if buffer.writer.Load() != nil && buffer.writer.Load().(*SegmentWriter) != nil {
+        segmentID = buffer.writer.Load().(*SegmentWriter).GetSegmentID()
+        buffer.bufferMemorySize.Add(int64(buffer.writer.Load().(*SegmentWriter).WrittenMemorySize()))
     }
-    if buffer.writer == nil || buffer.currentSegmentRowNum.Load() > t.plan.GetMaxSegmentRows() {
+    if buffer.writer.Load() == nil || buffer.currentSegmentRowNum.Load() > t.plan.GetMaxSegmentRows() {
         pack = true
         segmentID, err = t.segIDAlloc.AllocOne()
         if err != nil {
@@ -1219,22 +1220,22 @@ func (t *clusteringCompactionTask) refreshBufferWriterWithPack(buffer *ClusterBu
         return pack, err
     }
-    buffer.writer = writer
+    buffer.writer.Store(writer)
     return pack, nil
 }

 func (t *clusteringCompactionTask) refreshBufferWriter(buffer *ClusterBuffer) error {
     var segmentID int64
     var err error
-    segmentID = buffer.writer.GetSegmentID()
-    buffer.bufferMemorySize.Add(int64(buffer.writer.WrittenMemorySize()))
+    segmentID = buffer.writer.Load().(*SegmentWriter).GetSegmentID()
+    buffer.bufferMemorySize.Add(int64(buffer.writer.Load().(*SegmentWriter).WrittenMemorySize()))
     writer, err := NewSegmentWriter(t.plan.GetSchema(), t.plan.MaxSegmentRows, segmentID, t.partitionID, t.collectionID)
     if err != nil {
         return err
     }
-    buffer.writer = writer
+    buffer.writer.Store(writer)
     return nil
 }
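A note on the pattern (an observation about the fix, not part of the change): atomic.Value requires every Store to use the same concrete type and forces the .(*SegmentWriter) assertion at every read site. On Go 1.19 and later, the generic atomic.Pointer[T] from sync/atomic would express the same thing with typed, assertion-free accesses, roughly:

	writer atomic.Pointer[SegmentWriter] // field instead of atomic.Value

	w := buffer.writer.Load()      // returns *SegmentWriter directly, may be nil
	buffer.writer.Store(newWriter) // newWriter is a *SegmentWriter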