From 2adca8b754ffe0a50fda46d791b110112cff0521 Mon Sep 17 00:00:00 2001
From: "cai.zhang"
Date: Sat, 28 Sep 2024 17:19:21 +0800
Subject: [PATCH] fix: Fix data race for clustering compaction (#36440)

issue: #36438

Signed-off-by: Cai Zhang
---
 .../compaction/clustering_compactor.go        | 43 ++++++++++---------
 1 file changed, 22 insertions(+), 21 deletions(-)

diff --git a/internal/datanode/compaction/clustering_compactor.go b/internal/datanode/compaction/clustering_compactor.go
index 248e08be9a..7391f4ac4b 100644
--- a/internal/datanode/compaction/clustering_compactor.go
+++ b/internal/datanode/compaction/clustering_compactor.go
@@ -106,7 +106,7 @@ type clusteringCompactionTask struct {
 type ClusterBuffer struct {
 	id int
 
-	writer *SegmentWriter
+	writer atomic.Value
 	flushLock lock.RWMutex
 
 	bufferMemorySize atomic.Int64
@@ -464,7 +464,7 @@ func (t *clusteringCompactionTask) getBufferTotalUsedMemorySize() int64 {
 	var totalBufferSize int64 = 0
 	for _, buffer := range t.clusterBuffers {
 		t.clusterBufferLocks.Lock(buffer.id)
-		totalBufferSize = totalBufferSize + int64(buffer.writer.WrittenMemorySize()) + buffer.bufferMemorySize.Load()
+		totalBufferSize = totalBufferSize + int64(buffer.writer.Load().(*SegmentWriter).WrittenMemorySize()) + buffer.bufferMemorySize.Load()
 		t.clusterBufferLocks.Unlock(buffer.id)
 	}
 	return totalBufferSize
@@ -599,14 +599,14 @@ func (t *clusteringCompactionTask) mappingSegment(
 
 		if (remained+1)%100 == 0 {
 			currentBufferTotalMemorySize := t.getBufferTotalUsedMemorySize()
-			if clusterBuffer.currentSegmentRowNum.Load() > t.plan.GetMaxSegmentRows() || clusterBuffer.writer.IsFull() {
+			if clusterBuffer.currentSegmentRowNum.Load() > t.plan.GetMaxSegmentRows() || clusterBuffer.writer.Load().(*SegmentWriter).IsFull() {
 				// reach segment/binlog max size
 				flushWriterFunc := func() {
 					t.clusterBufferLocks.Lock(clusterBuffer.id)
 					currentSegmentNumRows := clusterBuffer.currentSegmentRowNum.Load()
 					// double-check the condition is still met
-					if currentSegmentNumRows > t.plan.GetMaxSegmentRows() || clusterBuffer.writer.IsFull() {
-						writer := clusterBuffer.writer
+					writer := clusterBuffer.writer.Load().(*SegmentWriter)
+					if currentSegmentNumRows > t.plan.GetMaxSegmentRows() || writer.IsFull() {
 						pack, _ := t.refreshBufferWriterWithPack(clusterBuffer)
 						log.Debug("buffer need to flush", zap.Int("bufferID", clusterBuffer.id),
 							zap.Bool("pack", pack),
@@ -677,11 +677,12 @@ func (t *clusteringCompactionTask) writeToBuffer(ctx context.Context, clusterBuf
 	t.clusterBufferLocks.Lock(clusterBuffer.id)
 	defer t.clusterBufferLocks.Unlock(clusterBuffer.id)
 	// prepare
-	if clusterBuffer.writer == nil {
+	writer := clusterBuffer.writer.Load()
+	if writer == nil || writer.(*SegmentWriter) == nil {
 		log.Warn("unexpected behavior, please check", zap.Int("buffer id", clusterBuffer.id))
 		return fmt.Errorf("unexpected behavior, please check buffer id: %d", clusterBuffer.id)
 	}
-	err := clusterBuffer.writer.Write(value)
+	err := writer.(*SegmentWriter).Write(value)
 	if err != nil {
 		return err
 	}
@@ -764,7 +765,7 @@ func (t *clusteringCompactionTask) flushLargestBuffers(ctx context.Context) erro
 	for _, buffer := range t.clusterBuffers {
 		bufferIDs = append(bufferIDs, buffer.id)
 		t.clusterBufferLocks.RLock(buffer.id)
-		bufferRowNums = append(bufferRowNums, buffer.writer.GetRowNum())
+		bufferRowNums = append(bufferRowNums, buffer.writer.Load().(*SegmentWriter).GetRowNum())
 		t.clusterBufferLocks.RUnlock(buffer.id)
 	}
 	sort.Slice(bufferIDs, func(i, j int) bool {
@@ -777,7 +778,7 @@ func (t *clusteringCompactionTask) flushLargestBuffers(ctx context.Context) erro
 		t.clusterBufferLocks.Lock(bufferId)
 		buffer := t.clusterBuffers[bufferId]
 		writer := buffer.writer
-		currentMemorySize -= int64(writer.WrittenMemorySize())
+		currentMemorySize -= int64(writer.Load().(*SegmentWriter).WrittenMemorySize())
 		if err := t.refreshBufferWriter(buffer); err != nil {
 			t.clusterBufferLocks.Unlock(bufferId)
 			return err
@@ -787,10 +788,10 @@ func (t *clusteringCompactionTask) flushLargestBuffers(ctx context.Context) erro
 		log.Info("currentMemorySize after flush buffer binlog",
 			zap.Int64("currentMemorySize", currentMemorySize),
 			zap.Int("bufferID", bufferId),
-			zap.Uint64("WrittenMemorySize()", writer.WrittenMemorySize()),
-			zap.Int64("RowNum", writer.GetRowNum()))
+			zap.Uint64("WrittenMemorySize()", writer.Load().(*SegmentWriter).WrittenMemorySize()),
+			zap.Int64("RowNum", writer.Load().(*SegmentWriter).GetRowNum()))
 		future := t.flushPool.Submit(func() (any, error) {
-			err := t.flushBinlog(ctx, buffer, writer, false)
+			err := t.flushBinlog(ctx, buffer, writer.Load().(*SegmentWriter), false)
 			if err != nil {
 				return nil, err
 			}
@@ -819,7 +820,7 @@ func (t *clusteringCompactionTask) flushAll(ctx context.Context) error {
 	for _, buffer := range t.clusterBuffers {
 		buffer := buffer
 		future := t.flushPool.Submit(func() (any, error) {
-			err := t.flushBinlog(ctx, buffer, buffer.writer, true)
+			err := t.flushBinlog(ctx, buffer, buffer.writer.Load().(*SegmentWriter), true)
 			if err != nil {
 				return nil, err
 			}
@@ -1201,11 +1202,11 @@ func (t *clusteringCompactionTask) refreshBufferWriterWithPack(buffer *ClusterBu
 	var segmentID int64
 	var err error
 	var pack bool
-	if buffer.writer != nil {
-		segmentID = buffer.writer.GetSegmentID()
-		buffer.bufferMemorySize.Add(int64(buffer.writer.WrittenMemorySize()))
+	if buffer.writer.Load() != nil && buffer.writer.Load().(*SegmentWriter) != nil {
+		segmentID = buffer.writer.Load().(*SegmentWriter).GetSegmentID()
+		buffer.bufferMemorySize.Add(int64(buffer.writer.Load().(*SegmentWriter).WrittenMemorySize()))
 	}
-	if buffer.writer == nil || buffer.currentSegmentRowNum.Load() > t.plan.GetMaxSegmentRows() {
+	if buffer.writer.Load() == nil || buffer.currentSegmentRowNum.Load() > t.plan.GetMaxSegmentRows() {
 		pack = true
 		segmentID, err = t.segIDAlloc.AllocOne()
 		if err != nil {
@@ -1219,22 +1220,22 @@ func (t *clusteringCompactionTask) refreshBufferWriterWithPack(buffer *ClusterBu
 		return pack, err
 	}
 
-	buffer.writer = writer
+	buffer.writer.Store(writer)
 	return pack, nil
 }
 
 func (t *clusteringCompactionTask) refreshBufferWriter(buffer *ClusterBuffer) error {
 	var segmentID int64
 	var err error
-	segmentID = buffer.writer.GetSegmentID()
-	buffer.bufferMemorySize.Add(int64(buffer.writer.WrittenMemorySize()))
+	segmentID = buffer.writer.Load().(*SegmentWriter).GetSegmentID()
+	buffer.bufferMemorySize.Add(int64(buffer.writer.Load().(*SegmentWriter).WrittenMemorySize()))
 
 	writer, err := NewSegmentWriter(t.plan.GetSchema(), t.plan.MaxSegmentRows, segmentID, t.partitionID, t.collectionID)
 	if err != nil {
 		return err
 	}
 
-	buffer.writer = writer
+	buffer.writer.Store(writer)
 	return nil
 }
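
Note (illustrative sketch, not part of the patch): the fix replaces the plain *SegmentWriter field with an atomic.Value so that goroutines reading the writer (getBufferTotalUsedMemorySize, flushLargestBuffers) and the code replacing it (refreshBufferWriter / refreshBufferWriterWithPack) never touch an unsynchronized pointer field. The standalone Go program below sketches the same Load/Store pattern under the race detector; clusterBuffer, segmentWriter, refreshWriter and writtenSize are stand-in names, not the Milvus types.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// segmentWriter is a stand-in for the real SegmentWriter; only the method used
// by the size-accounting path is sketched.
type segmentWriter struct {
	written uint64
}

func (w *segmentWriter) WrittenMemorySize() uint64 { return w.written }

// clusterBuffer mirrors the shape of ClusterBuffer after the patch: the writer
// lives in an atomic.Value rather than a plain pointer field.
type clusterBuffer struct {
	writer atomic.Value // always stores a *segmentWriter
}

// refreshWriter is the store side, analogous to refreshBufferWriter.
func (b *clusterBuffer) refreshWriter() {
	b.writer.Store(&segmentWriter{})
}

// writtenSize is the load side, analogous to getBufferTotalUsedMemorySize.
func (b *clusterBuffer) writtenSize() uint64 {
	w, _ := b.writer.Load().(*segmentWriter)
	if w == nil {
		return 0
	}
	return w.WrittenMemorySize()
}

func main() {
	buf := &clusterBuffer{}
	buf.refreshWriter()

	var wg sync.WaitGroup
	wg.Add(2)
	// One goroutine keeps replacing the writer while another reads its size,
	// mirroring the flush path racing with the memory-accounting path.
	go func() {
		defer wg.Done()
		for i := 0; i < 1000; i++ {
			buf.refreshWriter()
		}
	}()
	go func() {
		defer wg.Done()
		for i := 0; i < 1000; i++ {
			_ = buf.writtenSize()
		}
	}()
	wg.Wait()
	fmt.Println("final written size:", buf.writtenSize())
}

Running the sketch with `go run -race` reports no races, whereas the same two goroutines reading and overwriting a plain pointer field without locking would be flagged by the race detector.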