enhance: reduce stats task cost by skipping ser/de (#39568)

See #37234

---------

Signed-off-by: Ted Xu <ted.xu@zilliz.com>
pull/39665/head
Ted Xu 2025-02-06 17:14:45 +08:00 committed by GitHub
parent 2b4caba76e
commit 427b6a4c94
13 changed files with 901 additions and 596 deletions
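At a glance, the change replaces row-by-row deserialization into storage.Value with filtering and sorting done directly on Arrow record batches: delete/TTL checks become a predicate over storage.Record columns. Below is a minimal, self-contained sketch of that column-level predicate idea; the types and field IDs are toy stand-ins, not the actual Milvus interfaces.

```go
package main

import (
	"fmt"

	"github.com/apache/arrow/go/v12/arrow/array"
	"github.com/apache/arrow/go/v12/arrow/memory"
)

// record is a toy stand-in for storage.Record: Arrow columns keyed by field ID.
type record map[int64]*array.Int64

const (
	pkFieldID = 100 // hypothetical primary-key field ID
	tsFieldID = 1   // timestamp field (common.TimeStampField in Milvus)
)

func main() {
	mem := memory.DefaultAllocator
	pkb, tsb := array.NewInt64Builder(mem), array.NewInt64Builder(mem)
	pkb.AppendValues([]int64{3, 1, 2}, nil)
	tsb.AppendValues([]int64{10, 20, 30}, nil)
	rec := record{pkFieldID: pkb.NewInt64Array(), tsFieldID: tsb.NewInt64Array()}

	// The predicate reads pk and timestamp straight from the Arrow arrays;
	// no per-row struct is ever materialized.
	keep := func(r record, i int) bool {
		pk, ts := r[pkFieldID].Value(i), r[tsFieldID].Value(i)
		return pk != 1 && ts < 25 // e.g. drop deleted or expired rows
	}

	for i := 0; i < rec[pkFieldID].Len(); i++ {
		if keep(rec, i) {
			fmt.Println("keep row", i, "pk", rec[pkFieldID].Value(i))
		}
	}
}
```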


@ -1,20 +1,21 @@
package compaction
import (
"container/heap"
"context"
"fmt"
sio "io"
"math"
"time"
"github.com/apache/arrow/go/v12/arrow/array"
"github.com/samber/lo"
"go.opentelemetry.io/otel"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/flushcommon/io"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/proto/datapb"
@ -22,6 +23,24 @@ import (
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type segmentWriterWrapper struct {
*MultiSegmentWriter
}
var _ storage.RecordWriter = (*segmentWriterWrapper)(nil)
func (w *segmentWriterWrapper) GetWrittenUncompressed() uint64 {
return 0
}
func (w *segmentWriterWrapper) Write(record storage.Record) error {
return w.MultiSegmentWriter.WriteRecord(record)
}
func (w *segmentWriterWrapper) Close() error {
return nil
}
func mergeSortMultipleSegments(ctx context.Context,
plan *datapb.CompactionPlan,
collectionID, partitionID, maxRows int64,
@ -43,6 +62,7 @@ func mergeSortMultipleSegments(ctx context.Context,
logIDAlloc := allocator.NewLocalAllocator(plan.GetBeginLogID(), math.MaxInt64)
compAlloc := NewCompactionAllocator(segIDAlloc, logIDAlloc)
mWriter := NewMultiSegmentWriter(binlogIO, compAlloc, plan, maxRows, partitionID, collectionID, bm25FieldIds)
writer := &segmentWriterWrapper{MultiSegmentWriter: mWriter}
pkField, err := typeutil.GetPrimaryFieldSchema(plan.GetSchema())
if err != nil {
@ -50,8 +70,7 @@ func mergeSortMultipleSegments(ctx context.Context,
return nil, err
}
// SegmentDeserializeReaderTest(binlogPaths, t.binlogIO, writer.GetPkID())
segmentReaders := make([]*SegmentDeserializeReader, len(binlogs))
segmentReaders := make([]storage.RecordReader, len(binlogs))
segmentFilters := make([]*EntityFilter, len(binlogs))
for i, s := range binlogs {
var binlogBatchCount int
@ -75,7 +94,7 @@ func mergeSortMultipleSegments(ctx context.Context,
}
binlogPaths[idx] = batchPaths
}
segmentReaders[i] = NewSegmentDeserializeReader(ctx, binlogPaths, binlogIO, pkField.GetFieldID(), bm25FieldIds)
segmentReaders[i] = NewSegmentRecordReader(ctx, binlogPaths, binlogIO)
deltalogPaths := make([]string, 0)
for _, d := range s.GetDeltalogs() {
for _, l := range d.GetBinlogs() {
@ -89,57 +108,26 @@ func mergeSortMultipleSegments(ctx context.Context,
segmentFilters[i] = newEntityFilter(delta, collectionTtl, currentTime)
}
advanceRow := func(i int) (*storage.Value, error) {
for {
v, err := segmentReaders[i].Next()
if err != nil {
return nil, err
}
if segmentFilters[i].Filtered(v.PK.GetValue(), uint64(v.Timestamp)) {
continue
}
return v, nil
var predicate func(r storage.Record, ri, i int) bool
switch pkField.DataType {
case schemapb.DataType_Int64:
predicate = func(r storage.Record, ri, i int) bool {
pk := r.Column(pkField.FieldID).(*array.Int64).Value(i)
ts := r.Column(common.TimeStampField).(*array.Int64).Value(i)
return !segmentFilters[ri].Filtered(pk, uint64(ts))
}
case schemapb.DataType_VarChar:
predicate = func(r storage.Record, ri, i int) bool {
pk := r.Column(pkField.FieldID).(*array.String).Value(i)
ts := r.Column(common.TimeStampField).(*array.Int64).Value(i)
return !segmentFilters[ri].Filtered(pk, uint64(ts))
}
default:
log.Warn("compaction only support int64 and varchar pk field")
}
pq := make(PriorityQueue, 0)
heap.Init(&pq)
for i := range segmentReaders {
v, err := advanceRow(i)
if err != nil {
log.Warn("compact wrong, failed to advance row", zap.Error(err))
return nil, err
}
heap.Push(&pq, &PQItem{
Value: v,
Index: i,
})
}
for pq.Len() > 0 {
smallest := heap.Pop(&pq).(*PQItem)
v := smallest.Value
err := mWriter.Write(v)
if err != nil {
log.Warn("compact wrong, failed to writer row", zap.Error(err))
return nil, err
}
iv, err := advanceRow(smallest.Index)
if err != nil && err != sio.EOF {
return nil, err
}
if err == nil {
next := &PQItem{
Value: iv,
Index: smallest.Index,
}
heap.Push(&pq, next)
}
if _, err = storage.MergeSort(segmentReaders, pkField.FieldID, writer, predicate); err != nil {
return nil, err
}
res, err := mWriter.Finish()


@ -1,86 +0,0 @@
package compaction
import (
"context"
"io"
"github.com/samber/lo"
"go.uber.org/zap"
binlogIO "github.com/milvus-io/milvus/internal/flushcommon/io"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/log"
)
type SegmentDeserializeReader struct {
ctx context.Context
binlogIO binlogIO.BinlogIO
reader *storage.DeserializeReader[*storage.Value]
pos int
PKFieldID int64
binlogPaths [][]string
binlogPathPos int
bm25FieldIDs []int64
}
func NewSegmentDeserializeReader(ctx context.Context, binlogPaths [][]string, binlogIO binlogIO.BinlogIO, PKFieldID int64, bm25FieldIDs []int64) *SegmentDeserializeReader {
return &SegmentDeserializeReader{
ctx: ctx,
binlogIO: binlogIO,
reader: nil,
pos: 0,
PKFieldID: PKFieldID,
binlogPaths: binlogPaths,
binlogPathPos: 0,
bm25FieldIDs: bm25FieldIDs,
}
}
func (r *SegmentDeserializeReader) initDeserializeReader() error {
if r.binlogPathPos >= len(r.binlogPaths) {
return io.EOF
}
allValues, err := r.binlogIO.Download(r.ctx, r.binlogPaths[r.binlogPathPos])
if err != nil {
log.Warn("compact wrong, fail to download insertLogs", zap.Error(err))
return err
}
blobs := lo.Map(allValues, func(v []byte, i int) *storage.Blob {
return &storage.Blob{Key: r.binlogPaths[r.binlogPathPos][i], Value: v}
})
r.reader, err = storage.NewBinlogDeserializeReader(blobs, r.PKFieldID)
if err != nil {
log.Warn("compact wrong, failed to new insert binlogs reader", zap.Error(err))
return err
}
r.binlogPathPos++
return nil
}
func (r *SegmentDeserializeReader) Next() (*storage.Value, error) {
if r.reader == nil {
if err := r.initDeserializeReader(); err != nil {
return nil, err
}
}
if err := r.reader.Next(); err != nil {
if err == io.EOF {
r.reader.Close()
if err := r.initDeserializeReader(); err != nil {
return nil, err
}
err = r.reader.Next()
return r.reader.Value(), err
}
return nil, err
}
return r.reader.Value(), nil
}
func (r *SegmentDeserializeReader) Close() {
r.reader.Close()
}


@ -0,0 +1,31 @@
package compaction
import (
"context"
"io"
"github.com/samber/lo"
binlogIO "github.com/milvus-io/milvus/internal/flushcommon/io"
"github.com/milvus-io/milvus/internal/storage"
)
func NewSegmentRecordReader(ctx context.Context, binlogPaths [][]string, binlogIO binlogIO.BinlogIO) storage.RecordReader {
pos := 0
return &storage.CompositeBinlogRecordReader{
BlobsReader: func() ([]*storage.Blob, error) {
if pos >= len(binlogPaths) {
return nil, io.EOF
}
bytesArr, err := binlogIO.Download(ctx, binlogPaths[pos])
if err != nil {
return nil, err
}
blobs := lo.Map(bytesArr, func(v []byte, i int) *storage.Blob {
return &storage.Blob{Key: binlogPaths[pos][i], Value: v}
})
pos++
return blobs, nil
},
}
}
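NewSegmentRecordReader feeds the per-segment binlog batches to CompositeBinlogRecordReader through the pull-based ChunkedBlobsReader: each call downloads one batch and io.EOF marks exhaustion. A stand-alone sketch of that closure shape, with in-memory byte slices standing in for binlog downloads:

```go
package main

import (
	"fmt"
	"io"
)

// nextChunk yields one chunk per call and io.EOF when done, mirroring the
// shape of storage.ChunkedBlobsReader (here over in-memory chunks).
func nextChunk(chunks [][]byte) func() ([]byte, error) {
	pos := 0
	return func() ([]byte, error) {
		if pos >= len(chunks) {
			return nil, io.EOF
		}
		c := chunks[pos]
		pos++
		return c, nil
	}
}

func main() {
	read := nextChunk([][]byte{[]byte("batch-0"), []byte("batch-1")})
	for {
		c, err := read()
		if err == io.EOF {
			break
		}
		fmt.Println(string(c))
	}
}
```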


@ -19,22 +19,23 @@ package indexnode
import (
"context"
"fmt"
sio "io"
"sort"
"strconv"
"time"
"github.com/apache/arrow/go/v12/arrow/array"
"github.com/samber/lo"
"go.opentelemetry.io/otel"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/datanode/compaction"
iter "github.com/milvus-io/milvus/internal/datanode/iterators"
"github.com/milvus-io/milvus/internal/flushcommon/io"
"github.com/milvus-io/milvus/internal/metastore/kv/binlog"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/indexcgowrapper"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/proto/datapb"
"github.com/milvus-io/milvus/pkg/proto/indexcgopb"
@ -155,139 +156,309 @@ func (st *statsTask) PreExecute(ctx context.Context) error {
return nil
}
func (st *statsTask) sortSegment(ctx context.Context) ([]*datapb.FieldBinlog, error) {
// segmentRecordWriter is a wrapper of SegmentWriter to implement RecordWriter interface
type segmentRecordWriter struct {
sw *compaction.SegmentWriter
binlogMaxSize uint64
rootPath string
logID int64
maxLogID int64
binlogIO io.BinlogIO
ctx context.Context
numRows int64
bm25FieldIds []int64
lastUploads []*conc.Future[any]
binlogs map[typeutil.UniqueID]*datapb.FieldBinlog
statslog *datapb.FieldBinlog
bm25statslog []*datapb.FieldBinlog
}
var _ storage.RecordWriter = (*segmentRecordWriter)(nil)
func (srw *segmentRecordWriter) Close() error {
if !srw.sw.FlushAndIsEmpty() {
if err := srw.upload(); err != nil {
return err
}
if err := srw.waitLastUpload(); err != nil {
return err
}
}
statslog, err := srw.statSerializeWrite()
if err != nil {
log.Ctx(srw.ctx).Warn("stats wrong, failed to serialize write segment stats",
zap.Int64("remaining row count", srw.numRows), zap.Error(err))
return err
}
srw.statslog = statslog
srw.logID++
if len(srw.bm25FieldIds) > 0 {
binlogNums, bm25StatsLogs, err := srw.bm25SerializeWrite()
if err != nil {
log.Ctx(srw.ctx).Warn("compact wrong, failed to serialize write segment bm25 stats", zap.Error(err))
return err
}
srw.logID += binlogNums
srw.bm25statslog = bm25StatsLogs
}
return nil
}
func (srw *segmentRecordWriter) GetWrittenUncompressed() uint64 {
return srw.sw.WrittenMemorySize()
}
func (srw *segmentRecordWriter) Write(r storage.Record) error {
err := srw.sw.WriteRecord(r)
if err != nil {
return err
}
if srw.sw.IsFullWithBinlogMaxSize(srw.binlogMaxSize) {
return srw.upload()
}
return nil
}
func (srw *segmentRecordWriter) upload() error {
if err := srw.waitLastUpload(); err != nil {
return err
}
binlogNum, kvs, partialBinlogs, err := serializeWrite(srw.ctx, srw.rootPath, srw.logID, srw.sw)
if err != nil {
return err
}
srw.lastUploads = srw.binlogIO.AsyncUpload(srw.ctx, kvs)
if srw.binlogs == nil {
srw.binlogs = make(map[typeutil.UniqueID]*datapb.FieldBinlog)
}
mergeFieldBinlogs(srw.binlogs, partialBinlogs)
srw.logID += binlogNum
if srw.logID > srw.maxLogID {
return fmt.Errorf("log id exausted")
}
return nil
}
func (srw *segmentRecordWriter) waitLastUpload() error {
if len(srw.lastUploads) > 0 {
for _, future := range srw.lastUploads {
if _, err := future.Await(); err != nil {
return err
}
}
}
return nil
}
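The writer overlaps work by awaiting the previous batch's upload futures before starting the next async upload. A minimal stand-alone sketch of that one-upload-in-flight pattern, using plain channels rather than Milvus's conc.Future (all names here are illustrative):

```go
package main

import "fmt"

// uploadAsync pretends to upload one serialized batch and reports completion
// on a channel, standing in for binlogIO.AsyncUpload plus its futures.
func uploadAsync(batch int) <-chan error {
	done := make(chan error, 1)
	go func() {
		// a real implementation would write to object storage here
		done <- nil
	}()
	return done
}

func main() {
	var last <-chan error
	for batch := 0; batch < 3; batch++ {
		if last != nil {
			if err := <-last; err != nil { // wait for the previous upload first
				panic(err)
			}
		}
		fmt.Println("serialized batch", batch)
		last = uploadAsync(batch) // start the next upload without blocking
	}
	if last != nil {
		<-last // final wait, mirroring what Close() does before writing stats
	}
}
```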
func (srw *segmentRecordWriter) statSerializeWrite() (*datapb.FieldBinlog, error) {
ctx, span := otel.Tracer(typeutil.DataNodeRole).Start(srw.ctx, "statslog serializeWrite")
defer span.End()
sblob, err := srw.sw.Finish()
if err != nil {
return nil, err
}
key, _ := binlog.BuildLogPathWithRootPath(srw.rootPath, storage.StatsBinlog,
srw.sw.GetCollectionID(), srw.sw.GetPartitionID(), srw.sw.GetSegmentID(), srw.sw.GetPkID(), srw.logID)
kvs := map[string][]byte{key: sblob.GetValue()}
statFieldLog := &datapb.FieldBinlog{
FieldID: srw.sw.GetPkID(),
Binlogs: []*datapb.Binlog{
{
LogSize: int64(len(sblob.GetValue())),
MemorySize: int64(len(sblob.GetValue())),
LogPath: key,
EntriesNum: srw.numRows,
},
},
}
if err := srw.binlogIO.Upload(ctx, kvs); err != nil {
log.Ctx(ctx).Warn("failed to upload insert log", zap.Error(err))
return nil, err
}
return statFieldLog, nil
}
func (srw *segmentRecordWriter) bm25SerializeWrite() (int64, []*datapb.FieldBinlog, error) {
ctx, span := otel.Tracer(typeutil.DataNodeRole).Start(srw.ctx, "bm25log serializeWrite")
defer span.End()
writer := srw.sw
stats, err := writer.GetBm25StatsBlob()
if err != nil {
return 0, nil, err
}
kvs := make(map[string][]byte)
binlogs := []*datapb.FieldBinlog{}
cnt := int64(0)
for fieldID, blob := range stats {
key, _ := binlog.BuildLogPathWithRootPath(srw.rootPath, storage.BM25Binlog,
writer.GetCollectionID(), writer.GetPartitionID(), writer.GetSegmentID(), fieldID, srw.logID)
kvs[key] = blob.GetValue()
fieldLog := &datapb.FieldBinlog{
FieldID: fieldID,
Binlogs: []*datapb.Binlog{
{
LogSize: int64(len(blob.GetValue())),
MemorySize: int64(len(blob.GetValue())),
LogPath: key,
EntriesNum: srw.numRows,
},
},
}
binlogs = append(binlogs, fieldLog)
srw.logID++
cnt++
}
if err := srw.binlogIO.Upload(ctx, kvs); err != nil {
log.Ctx(ctx).Warn("failed to upload bm25 log", zap.Error(err))
return 0, nil, err
}
return cnt, binlogs, nil
}
func (st *statsTask) sort(ctx context.Context) ([]*datapb.FieldBinlog, error) {
numRows := st.req.GetNumRows()
bm25FieldIds := compaction.GetBM25FieldIDs(st.req.GetSchema())
writer, err := compaction.NewSegmentWriter(st.req.GetSchema(), numRows, statsBatchSize, st.req.GetTargetSegmentID(), st.req.GetPartitionID(), st.req.GetCollectionID(), bm25FieldIds)
pkField, err := typeutil.GetPrimaryFieldSchema(st.req.GetSchema())
if err != nil {
return nil, err
}
pkFieldID := pkField.FieldID
writer, err := compaction.NewSegmentWriter(st.req.GetSchema(), numRows, statsBatchSize,
st.req.GetTargetSegmentID(), st.req.GetPartitionID(), st.req.GetCollectionID(), bm25FieldIds)
if err != nil {
log.Ctx(ctx).Warn("sort segment wrong, unable to init segment writer",
zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err))
return nil, err
}
srw := &segmentRecordWriter{
sw: writer,
binlogMaxSize: st.req.GetBinlogMaxSize(),
rootPath: st.req.GetStorageConfig().GetRootPath(),
logID: st.req.StartLogID,
maxLogID: st.req.EndLogID,
binlogIO: st.binlogIO,
ctx: ctx,
numRows: st.req.NumRows,
bm25FieldIds: bm25FieldIds,
}
var (
flushBatchCount int // binlog batch count
allBinlogs = make(map[typeutil.UniqueID]*datapb.FieldBinlog) // All binlog meta of a segment
uploadFutures = make([]*conc.Future[any], 0)
downloadCost time.Duration
serWriteTimeCost time.Duration
sortTimeCost time.Duration
log := log.Ctx(ctx).With(
zap.String("clusterID", st.req.GetClusterID()),
zap.Int64("taskID", st.req.GetTaskID()),
zap.Int64("collectionID", st.req.GetCollectionID()),
zap.Int64("partitionID", st.req.GetPartitionID()),
zap.Int64("segmentID", st.req.GetSegmentID()),
zap.Int64s("bm25Fields", bm25FieldIds),
)
downloadStart := time.Now()
values, err := st.downloadData(ctx, numRows, writer.GetPkID(), bm25FieldIds)
deletePKs, err := st.loadDeltalogs(ctx, st.deltaLogs)
if err != nil {
log.Ctx(ctx).Warn("download data failed", zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err))
log.Warn("load deletePKs failed", zap.Error(err))
return nil, err
}
downloadCost = time.Since(downloadStart)
sortStart := time.Now()
sort.Slice(values, func(i, j int) bool {
return values[i].PK.LT(values[j].PK)
})
sortTimeCost += time.Since(sortStart)
for i, v := range values {
err := writer.Write(v)
if err != nil {
log.Ctx(ctx).Warn("write value wrong, failed to writer row", zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err))
return nil, err
}
if (i+1)%statsBatchSize == 0 && writer.IsFullWithBinlogMaxSize(st.req.GetBinlogMaxSize()) {
serWriteStart := time.Now()
binlogNum, kvs, partialBinlogs, err := serializeWrite(ctx, st.req.GetStorageConfig().GetRootPath(), st.req.GetStartLogID()+st.logIDOffset, writer)
if err != nil {
log.Ctx(ctx).Warn("stats wrong, failed to serialize writer", zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err))
return nil, err
var isValueValid func(r storage.Record, ri, i int) bool
switch pkField.DataType {
case schemapb.DataType_Int64:
isValueValid = func(r storage.Record, ri, i int) bool {
v := r.Column(pkFieldID).(*array.Int64).Value(i)
deleteTs, ok := deletePKs[v]
ts := uint64(r.Column(common.TimeStampField).(*array.Int64).Value(i))
if ok && ts < deleteTs {
return false
}
serWriteTimeCost += time.Since(serWriteStart)
uploadFutures = append(uploadFutures, st.binlogIO.AsyncUpload(ctx, kvs)...)
mergeFieldBinlogs(allBinlogs, partialBinlogs)
flushBatchCount++
st.logIDOffset += binlogNum
if st.req.GetStartLogID()+st.logIDOffset >= st.req.GetEndLogID() {
log.Ctx(ctx).Warn("binlog files too much, log is not enough", zap.Int64("taskID", st.req.GetTaskID()),
zap.Int64("binlog num", binlogNum), zap.Int64("startLogID", st.req.GetStartLogID()),
zap.Int64("endLogID", st.req.GetEndLogID()), zap.Int64("logIDOffset", st.logIDOffset))
return nil, fmt.Errorf("binlog files too much, log is not enough")
return !st.isExpiredEntity(ts)
}
case schemapb.DataType_VarChar:
isValueValid = func(r storage.Record, ri, i int) bool {
v := r.Column(pkFieldID).(*array.String).Value(i)
deleteTs, ok := deletePKs[v]
ts := uint64(r.Column(common.TimeStampField).(*array.Int64).Value(i))
if ok && ts < deleteTs {
return false
}
return !st.isExpiredEntity(ts)
}
}
if !writer.FlushAndIsEmpty() {
serWriteStart := time.Now()
binlogNum, kvs, partialBinlogs, err := serializeWrite(ctx, st.req.GetStorageConfig().GetRootPath(), st.req.GetStartLogID()+st.logIDOffset, writer)
downloadTimeCost := time.Duration(0)
rrs := make([]storage.RecordReader, len(st.insertLogs))
for i, paths := range st.insertLogs {
log := log.With(zap.Strings("paths", paths))
downloadStart := time.Now()
allValues, err := st.binlogIO.Download(ctx, paths)
if err != nil {
log.Ctx(ctx).Warn("stats wrong, failed to serialize writer", zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err))
log.Warn("download wrong, fail to download insertLogs", zap.Error(err))
return nil, err
}
serWriteTimeCost += time.Since(serWriteStart)
st.logIDOffset += binlogNum
downloadTimeCost += time.Since(downloadStart)
uploadFutures = append(uploadFutures, st.binlogIO.AsyncUpload(ctx, kvs)...)
mergeFieldBinlogs(allBinlogs, partialBinlogs)
flushBatchCount++
blobs := lo.Map(allValues, func(v []byte, i int) *storage.Blob {
return &storage.Blob{Key: paths[i], Value: v}
})
rr, err := storage.NewCompositeBinlogRecordReader(blobs)
if err != nil {
log.Warn("downloadData wrong, failed to new insert binlogs reader", zap.Error(err))
return nil, err
}
rrs[i] = rr
}
err = conc.AwaitAll(uploadFutures...)
log.Info("download data success",
zap.Int64("numRows", numRows),
zap.Duration("download binlogs elapse", downloadTimeCost),
)
numValidRows, err := storage.Sort(rrs, writer.GetPkID(), srw, isValueValid)
if err != nil {
log.Ctx(ctx).Warn("stats wrong, failed to upload kvs", zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err))
log.Warn("sort failed", zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err))
return nil, err
}
if err := srw.Close(); err != nil {
return nil, err
}
serWriteStart := time.Now()
binlogNums, sPath, err := statSerializeWrite(ctx, st.req.GetStorageConfig().GetRootPath(), st.binlogIO, st.req.GetStartLogID()+st.logIDOffset, writer, numRows)
if err != nil {
log.Ctx(ctx).Warn("stats wrong, failed to serialize write segment stats", zap.Int64("taskID", st.req.GetTaskID()),
zap.Int64("remaining row count", numRows), zap.Error(err))
return nil, err
}
serWriteTimeCost += time.Since(serWriteStart)
st.logIDOffset += binlogNums
var bm25StatsLogs []*datapb.FieldBinlog
if len(bm25FieldIds) > 0 {
binlogNums, bm25StatsLogs, err = bm25SerializeWrite(ctx, st.req.GetStorageConfig().GetRootPath(), st.binlogIO, st.req.GetStartLogID()+st.logIDOffset, writer, numRows)
if err != nil {
log.Ctx(ctx).Warn("compact wrong, failed to serialize write segment bm25 stats", zap.Error(err))
return nil, err
}
st.logIDOffset += binlogNums
if err := binlog.CompressFieldBinlogs(bm25StatsLogs); err != nil {
return nil, err
}
}
totalElapse := st.tr.RecordSpan()
insertLogs := lo.Values(allBinlogs)
insertLogs := lo.Values(srw.binlogs)
if err := binlog.CompressFieldBinlogs(insertLogs); err != nil {
return nil, err
}
statsLogs := []*datapb.FieldBinlog{sPath}
statsLogs := []*datapb.FieldBinlog{srw.statslog}
if err := binlog.CompressFieldBinlogs(statsLogs); err != nil {
return nil, err
}
bm25StatsLogs := srw.bm25statslog
if err := binlog.CompressFieldBinlogs(bm25StatsLogs); err != nil {
return nil, err
}
st.node.storePKSortStatsResult(st.req.GetClusterID(),
st.req.GetTaskID(),
st.req.GetCollectionID(),
st.req.GetPartitionID(),
st.req.GetTargetSegmentID(),
st.req.GetInsertChannel(),
int64(len(values)), insertLogs, statsLogs, bm25StatsLogs)
int64(numValidRows), insertLogs, statsLogs, bm25StatsLogs)
log.Ctx(ctx).Info("sort segment end",
log.Info("sort segment end",
zap.String("clusterID", st.req.GetClusterID()),
zap.Int64("taskID", st.req.GetTaskID()),
zap.Int64("collectionID", st.req.GetCollectionID()),
@ -296,12 +467,7 @@ func (st *statsTask) sortSegment(ctx context.Context) ([]*datapb.FieldBinlog, er
zap.String("subTaskType", st.req.GetSubJobType().String()),
zap.Int64("target segmentID", st.req.GetTargetSegmentID()),
zap.Int64("old rows", numRows),
zap.Int("valid rows", len(values)),
zap.Int("binlog batch count", flushBatchCount),
zap.Duration("download elapse", downloadCost),
zap.Duration("sort elapse", sortTimeCost),
zap.Duration("serWrite elapse", serWriteTimeCost),
zap.Duration("total elapse", totalElapse))
zap.Int("valid rows", numValidRows))
return insertLogs, nil
}
@ -313,7 +479,7 @@ func (st *statsTask) Execute(ctx context.Context) error {
insertLogs := st.req.GetInsertLogs()
var err error
if st.req.GetSubJobType() == indexpb.StatsSubJob_Sort {
insertLogs, err = st.sortSegment(ctx)
insertLogs, err = st.sort(ctx)
if err != nil {
return err
}
@ -350,99 +516,6 @@ func (st *statsTask) Reset() {
st.node = nil
}
func (st *statsTask) downloadData(ctx context.Context, numRows int64, PKFieldID int64, bm25FieldIds []int64) ([]*storage.Value, error) {
log := log.Ctx(ctx).With(
zap.String("clusterID", st.req.GetClusterID()),
zap.Int64("taskID", st.req.GetTaskID()),
zap.Int64("collectionID", st.req.GetCollectionID()),
zap.Int64("partitionID", st.req.GetPartitionID()),
zap.Int64("segmentID", st.req.GetSegmentID()),
zap.Int64s("bm25Fields", bm25FieldIds),
)
deletePKs, err := st.loadDeltalogs(ctx, st.deltaLogs)
if err != nil {
log.Warn("load deletePKs failed", zap.Error(err))
return nil, err
}
var (
remainingRowCount int64 // the number of remaining entities
expiredRowCount int64 // the number of expired entities
)
isValueDeleted := func(v *storage.Value) bool {
ts, ok := deletePKs[v.PK.GetValue()]
// insert task and delete task has the same ts when upsert
// here should be < instead of <=
// to avoid the upsert data to be deleted after compact
if ok && uint64(v.Timestamp) < ts {
return true
}
return false
}
downloadTimeCost := time.Duration(0)
values := make([]*storage.Value, 0, numRows)
for _, paths := range st.insertLogs {
log := log.With(zap.Strings("paths", paths))
downloadStart := time.Now()
allValues, err := st.binlogIO.Download(ctx, paths)
if err != nil {
log.Warn("download wrong, fail to download insertLogs", zap.Error(err))
return nil, err
}
downloadTimeCost += time.Since(downloadStart)
blobs := lo.Map(allValues, func(v []byte, i int) *storage.Blob {
return &storage.Blob{Key: paths[i], Value: v}
})
iter, err := storage.NewBinlogDeserializeReader(blobs, PKFieldID)
if err != nil {
log.Warn("downloadData wrong, failed to new insert binlogs reader", zap.Error(err))
return nil, err
}
for {
err := iter.Next()
if err != nil {
if err == sio.EOF {
break
} else {
log.Warn("downloadData wrong, failed to iter through data", zap.Error(err))
iter.Close()
return nil, err
}
}
v := iter.Value()
if isValueDeleted(v) {
continue
}
// Filtering expired entity
if st.isExpiredEntity(typeutil.Timestamp(v.Timestamp)) {
expiredRowCount++
continue
}
values = append(values, iter.Value())
remainingRowCount++
}
iter.Close()
}
log.Info("download data success",
zap.Int64("old rows", numRows),
zap.Int64("remainingRowCount", remainingRowCount),
zap.Int64("expiredRowCount", expiredRowCount),
zap.Duration("download binlogs elapse", downloadTimeCost),
)
return values, nil
}
func (st *statsTask) loadDeltalogs(ctx context.Context, dpaths []string) (map[interface{}]typeutil.Timestamp, error) {
st.tr.RecordSpan()
ctx, span := otel.Tracer(typeutil.DataNodeRole).Start(ctx, "loadDeltalogs")
@ -545,74 +618,6 @@ func serializeWrite(ctx context.Context, rootPath string, startID int64, writer
return
}
func statSerializeWrite(ctx context.Context, rootPath string, io io.BinlogIO, startID int64, writer *compaction.SegmentWriter, finalRowCount int64) (int64, *datapb.FieldBinlog, error) {
ctx, span := otel.Tracer(typeutil.DataNodeRole).Start(ctx, "statslog serializeWrite")
defer span.End()
sblob, err := writer.Finish()
if err != nil {
return 0, nil, err
}
binlogNum := int64(1)
key, _ := binlog.BuildLogPathWithRootPath(rootPath, storage.StatsBinlog, writer.GetCollectionID(), writer.GetPartitionID(), writer.GetSegmentID(), writer.GetPkID(), startID)
kvs := map[string][]byte{key: sblob.GetValue()}
statFieldLog := &datapb.FieldBinlog{
FieldID: writer.GetPkID(),
Binlogs: []*datapb.Binlog{
{
LogSize: int64(len(sblob.GetValue())),
MemorySize: int64(len(sblob.GetValue())),
LogPath: key,
EntriesNum: finalRowCount,
},
},
}
if err := io.Upload(ctx, kvs); err != nil {
log.Ctx(ctx).Warn("failed to upload insert log", zap.Error(err))
return binlogNum, nil, err
}
return binlogNum, statFieldLog, nil
}
func bm25SerializeWrite(ctx context.Context, rootPath string, io io.BinlogIO, startID int64, writer *compaction.SegmentWriter, finalRowCount int64) (int64, []*datapb.FieldBinlog, error) {
ctx, span := otel.Tracer(typeutil.DataNodeRole).Start(ctx, "bm25log serializeWrite")
defer span.End()
stats, err := writer.GetBm25StatsBlob()
if err != nil {
return 0, nil, err
}
kvs := make(map[string][]byte)
binlogs := []*datapb.FieldBinlog{}
cnt := int64(0)
for fieldID, blob := range stats {
key, _ := binlog.BuildLogPathWithRootPath(rootPath, storage.BM25Binlog, writer.GetCollectionID(), writer.GetPartitionID(), writer.GetSegmentID(), fieldID, startID+cnt)
kvs[key] = blob.GetValue()
fieldLog := &datapb.FieldBinlog{
FieldID: fieldID,
Binlogs: []*datapb.Binlog{
{
LogSize: int64(len(blob.GetValue())),
MemorySize: int64(len(blob.GetValue())),
LogPath: key,
EntriesNum: finalRowCount,
},
},
}
binlogs = append(binlogs, fieldLog)
cnt++
}
if err := io.Upload(ctx, kvs); err != nil {
log.Ctx(ctx).Warn("failed to upload bm25 log", zap.Error(err))
return 0, nil, err
}
return cnt, binlogs, nil
}
func ParseStorageConfig(s *indexpb.StorageConfig) (*indexcgopb.StorageConfig, error) {
bs, err := proto.Marshal(s)
if err != nil {


@ -81,26 +81,6 @@ func (s *TaskStatsSuite) GenSegmentWriterWithBM25(magic int64) {
s.segWriter = segWriter
}
func (s *TaskStatsSuite) Testbm25SerializeWriteError() {
s.Run("normal case", func() {
s.schema = genCollectionSchemaWithBM25()
s.mockBinlogIO.EXPECT().Upload(mock.Anything, mock.Anything).Return(nil).Once()
s.GenSegmentWriterWithBM25(0)
cnt, binlogs, err := bm25SerializeWrite(context.Background(), "root_path", s.mockBinlogIO, 0, s.segWriter, 1)
s.Require().NoError(err)
s.Equal(int64(1), cnt)
s.Equal(1, len(binlogs))
})
s.Run("upload failed", func() {
s.schema = genCollectionSchemaWithBM25()
s.mockBinlogIO.EXPECT().Upload(mock.Anything, mock.Anything).Return(fmt.Errorf("mock error")).Once()
s.GenSegmentWriterWithBM25(0)
_, _, err := bm25SerializeWrite(context.Background(), "root_path", s.mockBinlogIO, 0, s.segWriter, 1)
s.Error(err)
})
}
func (s *TaskStatsSuite) TestSortSegmentWithBM25() {
s.Run("normal case", func() {
s.schema = genCollectionSchemaWithBM25()
@ -130,10 +110,13 @@ func (s *TaskStatsSuite) TestSortSegmentWithBM25() {
InsertLogs: lo.Values(fBinlogs),
Schema: s.schema,
NumRows: 1,
StartLogID: 0,
EndLogID: 5,
BinlogMaxSize: 64 * 1024 * 1024,
}, node, s.mockBinlogIO)
err = task.PreExecute(ctx)
s.Require().NoError(err)
binlog, err := task.sortSegment(ctx)
binlog, err := task.sort(ctx)
s.Require().NoError(err)
s.Equal(5, len(binlog))
@ -174,10 +157,13 @@ func (s *TaskStatsSuite) TestSortSegmentWithBM25() {
InsertLogs: lo.Values(fBinlogs),
Schema: s.schema,
NumRows: 1,
StartLogID: 0,
EndLogID: 5,
BinlogMaxSize: 64 * 1024 * 1024,
}, node, s.mockBinlogIO)
err = task.PreExecute(ctx)
s.Require().NoError(err)
_, err = task.sortSegment(ctx)
_, err = task.sort(ctx)
s.Error(err)
})
}


@ -66,6 +66,10 @@ func generateTestSchema() *schemapb.CollectionSchema {
}
func generateTestData(num int) ([]*Blob, error) {
return generateTestDataWithSeed(1, num)
}
func generateTestDataWithSeed(seed, num int) ([]*Blob, error) {
insertCodec := NewInsertCodecWithSchema(&etcdpb.CollectionMeta{ID: 1, Schema: generateTestSchema()})
var (
@ -92,7 +96,7 @@ func generateTestData(num int) ([]*Blob, error) {
field106 [][]byte
)
for i := 1; i <= num; i++ {
for i := seed; i < seed+num; i++ {
field0 = append(field0, int64(i))
field1 = append(field1, int64(i))
field10 = append(field10, true)


@ -106,6 +106,9 @@ func ReadDescriptorEvent(buffer io.Reader) (*descriptorEvent, error) {
// Close closes the BinlogReader object.
// It mainly calls the Close method of the internal events, reclaims resources, and marks itself as closed.
func (reader *BinlogReader) Close() {
if reader == nil {
return
}
if reader.isClose {
return
}


@ -18,7 +18,6 @@ package storage
import (
"encoding/binary"
"fmt"
"io"
"math"
"sync"
@ -36,24 +35,23 @@ import (
)
type Record interface {
Schema() map[FieldID]schemapb.DataType
ArrowSchema() *arrow.Schema
Column(i FieldID) arrow.Array
Len() int
Release()
Retain()
Slice(start, end int) Record
}
type RecordReader interface {
Next() error
Record() Record
Close()
Close() error
}
type RecordWriter interface {
Write(r Record) error
GetWrittenUncompressed() uint64
Close()
Close() error
}
type (
@ -63,19 +61,19 @@ type (
// compositeRecord is a record being composed of multiple records, in which each only have 1 column
type compositeRecord struct {
recs map[FieldID]arrow.Record
schema map[FieldID]schemapb.DataType
index map[FieldID]int16
recs []arrow.Array
}
var _ Record = (*compositeRecord)(nil)
func (r *compositeRecord) Column(i FieldID) arrow.Array {
return r.recs[i].Column(0)
return r.recs[r.index[i]]
}
func (r *compositeRecord) Len() int {
for _, rec := range r.recs {
return rec.Column(0).Len()
return rec.Len()
}
return 0
}
@ -86,26 +84,21 @@ func (r *compositeRecord) Release() {
}
}
func (r *compositeRecord) Schema() map[FieldID]schemapb.DataType {
return r.schema
}
func (r *compositeRecord) ArrowSchema() *arrow.Schema {
var fields []arrow.Field
func (r *compositeRecord) Retain() {
for _, rec := range r.recs {
fields = append(fields, rec.Schema().Field(0))
rec.Retain()
}
return arrow.NewSchema(fields, nil)
}
func (r *compositeRecord) Slice(start, end int) Record {
slices := make(map[FieldID]arrow.Record)
slices := make([]arrow.Array, len(r.index))
for i, rec := range r.recs {
slices[i] = rec.NewSlice(int64(start), int64(end))
d := array.NewSliceData(rec.Data(), int64(start), int64(end))
slices[i] = array.MakeFromData(d)
}
return &compositeRecord{
recs: slices,
schema: r.schema,
index: r.index,
recs: slices,
}
}
@ -550,28 +543,20 @@ func (deser *DeserializeReader[T]) Next() error {
return nil
}
func (deser *DeserializeReader[T]) NextRecord() (Record, error) {
if len(deser.values) != 0 {
return nil, errors.New("deserialize result is not empty")
}
if err := deser.rr.Next(); err != nil {
return nil, err
}
return deser.rr.Record(), nil
}
func (deser *DeserializeReader[T]) Value() T {
return deser.values[deser.pos]
}
func (deser *DeserializeReader[T]) Close() {
func (deser *DeserializeReader[T]) Close() error {
if deser.rec != nil {
deser.rec.Release()
}
if deser.rr != nil {
deser.rr.Close()
if err := deser.rr.Close(); err != nil {
return err
}
}
return nil
}
func NewDeserializeReader[T any](rr RecordReader, deserializer Deserializer[T]) *DeserializeReader[T] {
@ -585,22 +570,12 @@ var _ Record = (*selectiveRecord)(nil)
// selectiveRecord is a Record that only contains a single field, reusing existing Record.
type selectiveRecord struct {
r Record
selectedFieldId FieldID
schema map[FieldID]schemapb.DataType
}
func (r *selectiveRecord) Schema() map[FieldID]schemapb.DataType {
return r.schema
}
func (r *selectiveRecord) ArrowSchema() *arrow.Schema {
return r.r.ArrowSchema()
r Record
fieldId FieldID
}
func (r *selectiveRecord) Column(i FieldID) arrow.Array {
if i == r.selectedFieldId {
if i == r.fieldId {
return r.r.Column(i)
}
return nil
@ -614,6 +589,10 @@ func (r *selectiveRecord) Release() {
// do nothing.
}
func (r *selectiveRecord) Retain() {
// do nothing
}
func (r *selectiveRecord) Slice(start, end int) Record {
panic("not implemented")
}
@ -664,17 +643,10 @@ func calculateArraySize(a arrow.Array) int {
return totalSize
}
func newSelectiveRecord(r Record, selectedFieldId FieldID) *selectiveRecord {
dt, ok := r.Schema()[selectedFieldId]
if !ok {
return nil
}
schema := make(map[FieldID]schemapb.DataType, 1)
schema[selectedFieldId] = dt
func newSelectiveRecord(r Record, selectedFieldId FieldID) Record {
return &selectiveRecord{
r: r,
selectedFieldId: selectedFieldId,
schema: schema,
r: r,
fieldId: selectedFieldId,
}
}
@ -682,26 +654,19 @@ var _ RecordWriter = (*CompositeRecordWriter)(nil)
type CompositeRecordWriter struct {
writers map[FieldID]RecordWriter
writtenUncompressed uint64
}
func (crw *CompositeRecordWriter) GetWrittenUncompressed() uint64 {
return crw.writtenUncompressed
s := uint64(0)
for _, w := range crw.writers {
s += w.GetWrittenUncompressed()
}
return s
}
func (crw *CompositeRecordWriter) Write(r Record) error {
if len(r.Schema()) != len(crw.writers) {
return fmt.Errorf("schema length mismatch %d, expected %d", len(r.Schema()), len(crw.writers))
}
var bytes uint64
for fid := range r.Schema() {
arr := r.Column(fid)
bytes += uint64(calculateArraySize(arr))
}
crw.writtenUncompressed += bytes
for fieldId, w := range crw.writers {
// TODO: if field is not exist, write
sr := newSelectiveRecord(r, fieldId)
if err := w.Write(sr); err != nil {
return err
@ -710,14 +675,17 @@ func (crw *CompositeRecordWriter) Write(r Record) error {
return nil
}
func (crw *CompositeRecordWriter) Close() {
func (crw *CompositeRecordWriter) Close() error {
if crw != nil {
for _, w := range crw.writers {
if w != nil {
w.Close()
if err := w.Close(); err != nil {
return err
}
}
}
}
return nil
}
func NewCompositeRecordWriter(writers map[FieldID]RecordWriter) *CompositeRecordWriter {
@ -760,8 +728,8 @@ func (sfw *singleFieldRecordWriter) GetWrittenUncompressed() uint64 {
return sfw.writtenUncompressed
}
func (sfw *singleFieldRecordWriter) Close() {
sfw.fw.Close()
func (sfw *singleFieldRecordWriter) Close() error {
return sfw.fw.Close()
}
func newSingleFieldRecordWriter(fieldId FieldID, field arrow.Field, writer io.Writer, opts ...RecordWriterOptions) (*singleFieldRecordWriter, error) {
@ -811,8 +779,8 @@ func (mfw *multiFieldRecordWriter) GetWrittenUncompressed() uint64 {
return mfw.writtenUncompressed
}
func (mfw *multiFieldRecordWriter) Close() {
mfw.fw.Close()
func (mfw *multiFieldRecordWriter) Close() error {
return mfw.fw.Close()
}
func newMultiFieldRecordWriter(fieldIds []FieldID, fields []arrow.Field, writer io.Writer) (*multiFieldRecordWriter, error) {
@ -910,18 +878,13 @@ func NewSerializeRecordWriter[T any](rw RecordWriter, serializer Serializer[T],
}
type simpleArrowRecord struct {
r arrow.Record
schema map[FieldID]schemapb.DataType
r arrow.Record
field2Col map[FieldID]int
}
var _ Record = (*simpleArrowRecord)(nil)
func (sr *simpleArrowRecord) Schema() map[FieldID]schemapb.DataType {
return sr.schema
}
func (sr *simpleArrowRecord) Column(i FieldID) arrow.Array {
colIdx, ok := sr.field2Col[i]
if !ok {
@ -938,19 +901,22 @@ func (sr *simpleArrowRecord) Release() {
sr.r.Release()
}
func (sr *simpleArrowRecord) Retain() {
sr.r.Retain()
}
func (sr *simpleArrowRecord) ArrowSchema() *arrow.Schema {
return sr.r.Schema()
}
func (sr *simpleArrowRecord) Slice(start, end int) Record {
s := sr.r.NewSlice(int64(start), int64(end))
return newSimpleArrowRecord(s, sr.schema, sr.field2Col)
return newSimpleArrowRecord(s, sr.field2Col)
}
func newSimpleArrowRecord(r arrow.Record, schema map[FieldID]schemapb.DataType, field2Col map[FieldID]int) *simpleArrowRecord {
func newSimpleArrowRecord(r arrow.Record, field2Col map[FieldID]int) *simpleArrowRecord {
return &simpleArrowRecord{
r: r,
schema: schema,
field2Col: field2Col,
}
}


@ -39,39 +39,50 @@ import (
var _ RecordReader = (*CompositeBinlogRecordReader)(nil)
// ChunkedBlobsReader returns a chunk composed of blobs, or io.EOF if no more data
type ChunkedBlobsReader func() ([]*Blob, error)
type CompositeBinlogRecordReader struct {
blobs [][]*Blob
BlobsReader ChunkedBlobsReader
blobPos int
rrs []array.RecordReader
closers []func()
fields []FieldID
brs []*BinlogReader
rrs []array.RecordReader
r compositeRecord
schema map[FieldID]schemapb.DataType
index map[FieldID]int16
r *compositeRecord
}
func (crr *CompositeBinlogRecordReader) iterateNextBatch() error {
if crr.closers != nil {
for _, close := range crr.closers {
if close != nil {
close()
}
if crr.brs != nil {
for _, er := range crr.brs {
er.Close()
}
for _, rr := range crr.rrs {
rr.Release()
}
}
crr.blobPos++
if crr.blobPos >= len(crr.blobs[0]) {
return io.EOF
blobs, err := crr.BlobsReader()
if err != nil {
return err
}
for i, b := range crr.blobs {
reader, err := NewBinlogReader(b[crr.blobPos].Value)
if crr.rrs == nil {
crr.rrs = make([]array.RecordReader, len(blobs))
crr.brs = make([]*BinlogReader, len(blobs))
crr.schema = make(map[FieldID]schemapb.DataType)
crr.index = make(map[FieldID]int16, len(blobs))
}
for i, b := range blobs {
reader, err := NewBinlogReader(b.Value)
if err != nil {
return err
}
crr.fields[i] = reader.FieldID
// TODO: assert that the schema is the same in every blob
crr.r.schema[reader.FieldID] = reader.PayloadDataType
crr.schema[reader.FieldID] = reader.PayloadDataType
er, err := reader.NextEventReader()
if err != nil {
return err
@ -81,40 +92,30 @@ func (crr *CompositeBinlogRecordReader) iterateNextBatch() error {
return err
}
crr.rrs[i] = rr
crr.closers[i] = func() {
rr.Release()
er.Close()
reader.Close()
}
crr.index[reader.FieldID] = int16(i)
crr.brs[i] = reader
}
return nil
}
func (crr *CompositeBinlogRecordReader) Next() error {
if crr.rrs == nil {
if crr.blobs == nil || len(crr.blobs) == 0 {
return io.EOF
}
crr.rrs = make([]array.RecordReader, len(crr.blobs))
crr.closers = make([]func(), len(crr.blobs))
crr.blobPos = -1
crr.fields = make([]FieldID, len(crr.rrs))
crr.r = compositeRecord{
recs: make(map[FieldID]arrow.Record, len(crr.rrs)),
schema: make(map[FieldID]schemapb.DataType, len(crr.rrs)),
}
if err := crr.iterateNextBatch(); err != nil {
return err
}
}
composeRecord := func() bool {
recs := make([]arrow.Array, len(crr.rrs))
for i, rr := range crr.rrs {
if ok := rr.Next(); !ok {
return false
}
// compose record
crr.r.recs[crr.fields[i]] = rr.Record()
recs[i] = rr.Record().Column(0)
}
crr.r = &compositeRecord{
index: crr.index,
recs: recs,
}
return true
}
@ -135,15 +136,26 @@ func (crr *CompositeBinlogRecordReader) Next() error {
}
func (crr *CompositeBinlogRecordReader) Record() Record {
return &crr.r
return crr.r
}
func (crr *CompositeBinlogRecordReader) Close() {
for _, close := range crr.closers {
if close != nil {
close()
func (crr *CompositeBinlogRecordReader) Close() error {
if crr.brs != nil {
for _, er := range crr.brs {
if er != nil {
er.Close()
}
}
}
if crr.rrs != nil {
for _, rr := range crr.rrs {
if rr != nil {
rr.Release()
}
}
}
crr.r = nil
return nil
}
func parseBlobKey(blobKey string) (colId FieldID, logId UniqueID) {
@ -177,8 +189,19 @@ func NewCompositeBinlogRecordReader(blobs []*Blob) (*CompositeBinlogRecordReader
})
sortedBlobs = append(sortedBlobs, blobsForField)
}
chunkPos := 0
return &CompositeBinlogRecordReader{
blobs: sortedBlobs,
BlobsReader: func() ([]*Blob, error) {
if len(sortedBlobs) == 0 || chunkPos >= len(sortedBlobs[0]) {
return nil, io.EOF
}
blobs := make([]*Blob, len(sortedBlobs))
for fieldPos := range blobs {
blobs[fieldPos] = sortedBlobs[fieldPos][chunkPos]
}
chunkPos++
return blobs, nil
},
}, nil
}
@ -189,17 +212,18 @@ func NewBinlogDeserializeReader(blobs []*Blob, PKfieldID UniqueID) (*Deserialize
}
return NewDeserializeReader(reader, func(r Record, v []*Value) error {
schema := reader.schema
// Note: the return value `Value` is reused.
for i := 0; i < r.Len(); i++ {
value := v[i]
if value == nil {
value = &Value{}
value.Value = make(map[FieldID]interface{}, len(r.Schema()))
value.Value = make(map[FieldID]interface{}, len(schema))
v[i] = value
}
m := value.Value.(map[FieldID]interface{})
for j, dt := range r.Schema() {
for j, dt := range schema {
if r.Column(j).IsNull(i) {
m[j] = nil
} else {
@ -219,7 +243,7 @@ func NewBinlogDeserializeReader(blobs []*Blob, PKfieldID UniqueID) (*Deserialize
value.ID = rowID
value.Timestamp = m[common.TimeStampField].(int64)
pk, err := GenPrimaryKeyByRawData(m[PKfieldID], r.Schema()[PKfieldID])
pk, err := GenPrimaryKeyByRawData(m[PKfieldID], schema[PKfieldID])
if err != nil {
return err
}
@ -239,7 +263,7 @@ func newDeltalogOneFieldReader(blobs []*Blob) (*DeserializeReader[*DeleteLog], e
}
return NewDeserializeReader(reader, func(r Record, v []*DeleteLog) error {
var fid FieldID // The only fid from delete file
for k := range r.Schema() {
for k := range reader.schema {
fid = k
break
}
@ -391,7 +415,7 @@ func ValueSerializer(v []*Value, fieldSchema []*schemapb.FieldSchema) (Record, e
field2Col[fid] = i
i++
}
return newSimpleArrowRecord(array.NewRecord(arrow.NewSchema(fields, nil), arrays, int64(len(v))), types, field2Col), nil
return newSimpleArrowRecord(array.NewRecord(arrow.NewSchema(fields, nil), arrays, int64(len(v))), field2Col), nil
}
func NewBinlogSerializeWriter(schema *schemapb.CollectionSchema, partitionID, segmentID UniqueID,
@ -529,10 +553,7 @@ func newDeltalogSerializeWriter(eventWriter *DeltalogStreamWriter, batchSize int
field2Col := map[FieldID]int{
0: 0,
}
schema := map[FieldID]schemapb.DataType{
0: schemapb.DataType_String,
}
return newSimpleArrowRecord(array.NewRecord(arrow.NewSchema(field, nil), arr, int64(len(v))), schema, field2Col), nil
return newSimpleArrowRecord(array.NewRecord(arrow.NewSchema(field, nil), arr, int64(len(v))), field2Col), nil
}, batchSize), nil
}
@ -583,12 +604,11 @@ func (crr *simpleArrowRecordReader) iterateNextBatch() error {
func (crr *simpleArrowRecordReader) Next() error {
if crr.rr == nil {
if crr.blobs == nil || len(crr.blobs) == 0 {
if len(crr.blobs) == 0 {
return io.EOF
}
crr.blobPos = -1
crr.r = simpleArrowRecord{
schema: make(map[FieldID]schemapb.DataType),
field2Col: make(map[FieldID]int),
}
if err := crr.iterateNextBatch(); err != nil {
@ -623,10 +643,11 @@ func (crr *simpleArrowRecordReader) Record() Record {
return &crr.r
}
func (crr *simpleArrowRecordReader) Close() {
func (crr *simpleArrowRecordReader) Close() error {
if crr.closer != nil {
crr.closer()
}
return nil
}
func newSimpleArrowRecordReader(blobs []*Blob) (*simpleArrowRecordReader, error) {
@ -781,11 +802,7 @@ func newDeltalogMultiFieldWriter(eventWriter *MultiFieldDeltalogStreamWriter, ba
common.RowIDField: 0,
common.TimeStampField: 1,
}
schema := map[FieldID]schemapb.DataType{
common.RowIDField: pkType,
common.TimeStampField: schemapb.DataType_Int64,
}
return newSimpleArrowRecord(array.NewRecord(arrowSchema, arr, int64(len(v))), schema, field2Col), nil
return newSimpleArrowRecord(array.NewRecord(arrowSchema, arr, int64(len(v))), field2Col), nil
}, batchSize), nil
}


@ -92,7 +92,7 @@ func TestBinlogStreamWriter(t *testing.T) {
[]arrow.Array{arr},
int64(size),
)
r := newSimpleArrowRecord(ar, map[FieldID]schemapb.DataType{1: schemapb.DataType_Bool}, map[FieldID]int{1: 0})
r := newSimpleArrowRecord(ar, map[FieldID]int{1: 0})
defer r.Release()
err = rw.Write(r)
assert.NoError(t, err)


@ -30,6 +30,25 @@ import (
"github.com/milvus-io/milvus/pkg/common"
)
type MockRecordWriter struct {
writefn func(Record) error
closefn func() error
}
var _ RecordWriter = (*MockRecordWriter)(nil)
func (w *MockRecordWriter) Write(record Record) error {
return w.writefn(record)
}
func (w *MockRecordWriter) Close() error {
return w.closefn()
}
func (w *MockRecordWriter) GetWrittenUncompressed() uint64 {
return 0
}
func TestSerDe(t *testing.T) {
type args struct {
dt schemapb.DataType
@ -101,37 +120,6 @@ func TestSerDe(t *testing.T) {
}
}
func TestArrowSchema(t *testing.T) {
fields := []arrow.Field{{Name: "1", Type: arrow.BinaryTypes.String, Nullable: true}}
builder := array.NewBuilder(memory.DefaultAllocator, arrow.BinaryTypes.String)
builder.AppendValueFromString("1")
record := array.NewRecord(arrow.NewSchema(fields, nil), []arrow.Array{builder.NewArray()}, 1)
t.Run("test composite record", func(t *testing.T) {
cr := &compositeRecord{
recs: make(map[FieldID]arrow.Record, 1),
schema: make(map[FieldID]schemapb.DataType, 1),
}
cr.recs[0] = record
cr.schema[0] = schemapb.DataType_String
expected := arrow.NewSchema(fields, nil)
assert.Equal(t, expected, cr.ArrowSchema())
})
t.Run("test simple arrow record", func(t *testing.T) {
cr := &simpleArrowRecord{
r: record,
schema: make(map[FieldID]schemapb.DataType, 1),
field2Col: make(map[FieldID]int, 1),
}
cr.schema[0] = schemapb.DataType_String
expected := arrow.NewSchema(fields, nil)
assert.Equal(t, expected, cr.ArrowSchema())
sr := newSelectiveRecord(cr, 0)
assert.Equal(t, expected, sr.ArrowSchema())
})
}
func BenchmarkDeserializeReader(b *testing.B) {
len := 1000000
blobs, err := generateTestData(len)

internal/storage/sort.go Normal file

@ -0,0 +1,241 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"container/heap"
"io"
"sort"
"github.com/apache/arrow/go/v12/arrow/array"
)
func Sort(rr []RecordReader, pkField FieldID, rw RecordWriter, predicate func(r Record, ri, i int) bool) (numRows int, err error) {
records := make([]Record, 0)
type index struct {
ri int
i int
}
indices := make([]*index, 0)
defer func() {
for _, r := range records {
r.Release()
}
}()
for _, r := range rr {
for {
err := r.Next()
if err == nil {
rec := r.Record()
rec.Retain()
ri := len(records)
records = append(records, rec)
for i := 0; i < rec.Len(); i++ {
if predicate(rec, ri, i) {
indices = append(indices, &index{ri, i})
}
}
} else if err == io.EOF {
break
} else {
return 0, err
}
}
}
if len(records) == 0 {
return 0, nil
}
switch records[0].Column(pkField).(type) {
case *array.Int64:
sort.Slice(indices, func(i, j int) bool {
pki := records[indices[i].ri].Column(pkField).(*array.Int64).Value(indices[i].i)
pkj := records[indices[j].ri].Column(pkField).(*array.Int64).Value(indices[j].i)
return pki < pkj
})
case *array.String:
sort.Slice(indices, func(i, j int) bool {
pki := records[indices[i].ri].Column(pkField).(*array.String).Value(indices[i].i)
pkj := records[indices[j].ri].Column(pkField).(*array.String).Value(indices[j].i)
return pki < pkj
})
}
writeOne := func(i *index) error {
rec := records[i.ri].Slice(i.i, i.i+1)
defer rec.Release()
return rw.Write(rec)
}
for _, i := range indices {
numRows++
if err := writeOne(i); err != nil {
return 0, err
}
}
return numRows, nil
}
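Sort never copies row data while ordering: batches stay where they are, only lightweight (record, row) indices that pass the predicate are collected and sorted by primary key, and rows are then written as one-row slices in that order. A tiny stand-alone illustration of the index sort, with int64 keys in place of Arrow columns:

```go
package main

import (
	"fmt"
	"sort"
)

func main() {
	type idx struct{ ri, i int }      // (record index, row index), as in Sort above
	pks := [][]int64{{7, 3}, {5, 1}}  // primary keys of two in-memory batches
	var order []idx
	for ri, batch := range pks {
		for i := range batch {
			order = append(order, idx{ri, i})
		}
	}
	sort.Slice(order, func(a, b int) bool {
		return pks[order[a].ri][order[a].i] < pks[order[b].ri][order[b].i]
	})
	for _, o := range order {
		fmt.Println(pks[o.ri][o.i]) // 1, 3, 5, 7
	}
}
```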
// A PriorityQueue implements heap.Interface and holds Items.
type PriorityQueue[T any] struct {
items []*T
less func(x, y *T) bool
}
var _ heap.Interface = (*PriorityQueue[any])(nil)
func (pq PriorityQueue[T]) Len() int { return len(pq.items) }
func (pq PriorityQueue[T]) Less(i, j int) bool {
return pq.less(pq.items[i], pq.items[j])
}
func (pq PriorityQueue[T]) Swap(i, j int) {
pq.items[i], pq.items[j] = pq.items[j], pq.items[i]
}
func (pq *PriorityQueue[T]) Push(x any) {
pq.items = append(pq.items, x.(*T))
}
func (pq *PriorityQueue[T]) Pop() any {
old := pq.items
n := len(old)
x := old[n-1]
old[n-1] = nil
pq.items = old[0 : n-1]
return x
}
func (pq *PriorityQueue[T]) Enqueue(x *T) {
heap.Push(pq, x)
}
func (pq *PriorityQueue[T]) Dequeue() *T {
return heap.Pop(pq).(*T)
}
func NewPriorityQueue[T any](less func(x, y *T) bool) *PriorityQueue[T] {
pq := PriorityQueue[T]{
items: make([]*T, 0),
less: less,
}
heap.Init(&pq)
return &pq
}
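A usage sketch for the generic PriorityQueue above: the less closure defines the ordering, and Enqueue/Dequeue wrap heap.Push/heap.Pop. It assumes it sits alongside the type above; ExamplePriorityQueue is only an illustrative name.

```go
package storage

import "fmt"

// ExamplePriorityQueue is a usage sketch only; it assumes the same package as
// the PriorityQueue defined above.
func ExamplePriorityQueue() {
	pq := NewPriorityQueue[int](func(x, y *int) bool { return *x < *y })
	for _, v := range []int{5, 1, 3} {
		v := v // take the address of a fresh copy on each iteration
		pq.Enqueue(&v)
	}
	for pq.Len() > 0 {
		fmt.Println(*pq.Dequeue())
	}
	// Output:
	// 1
	// 3
	// 5
}
```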
func MergeSort(rr []RecordReader, pkField FieldID, rw RecordWriter, predicate func(r Record, ri, i int) bool) (numRows int, err error) {
type index struct {
ri int
i int
}
advanceRecord := func(r RecordReader) (Record, error) {
err := r.Next()
if err != nil {
return nil, err
}
return r.Record(), nil
}
recs := make([]Record, len(rr))
for i, r := range rr {
rec, err := advanceRecord(r)
if err == io.EOF {
recs[i] = nil
continue
}
if err != nil {
return 0, err
}
recs[i] = rec
}
var pq *PriorityQueue[index]
switch recs[0].Column(pkField).(type) {
case *array.Int64:
pq = NewPriorityQueue[index](func(x, y *index) bool {
return rr[x.ri].Record().Column(pkField).(*array.Int64).Value(x.i) < rr[y.ri].Record().Column(pkField).(*array.Int64).Value(y.i)
})
case *array.String:
pq = NewPriorityQueue[index](func(x, y *index) bool {
return rr[x.ri].Record().Column(pkField).(*array.String).Value(x.i) < rr[y.ri].Record().Column(pkField).(*array.String).Value(y.i)
})
}
enqueueAll := func(ri int, r Record) {
for j := 0; j < r.Len(); j++ {
if predicate(r, ri, j) {
pq.Enqueue(&index{
ri: ri,
i: j,
})
numRows++
}
}
}
for i, v := range recs {
if v != nil {
enqueueAll(i, v)
}
}
ri, istart, iend := -1, -1, -1
for pq.Len() > 0 {
idx := pq.Dequeue()
if ri == idx.ri {
// record end of cache, do nothing
iend = idx.i + 1
} else {
if ri != -1 {
// record changed, write old one and reset
sr := rr[ri].Record().Slice(istart, iend)
err := rw.Write(sr)
sr.Release()
if err != nil {
return 0, err
}
}
ri = idx.ri
istart = idx.i
iend = idx.i + 1
}
// If the popped idx reaches the end of its record, flush the cached slice and advance that reader to the next record
if idx.i == rr[idx.ri].Record().Len()-1 {
sr := rr[ri].Record().Slice(istart, iend)
err := rw.Write(sr)
sr.Release()
if err != nil {
return 0, err
}
ri, istart, iend = -1, -1, -1
rec, err := advanceRecord(rr[idx.ri])
if err == io.EOF {
continue
}
if err != nil {
return 0, err
}
enqueueAll(idx.ri, rec)
}
}
return numRows, nil
}
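One detail of MergeSort worth noting: rows popped consecutively from the same source record are not written one by one; the loop grows [istart, iend) and flushes a single slice when the source record changes or is exhausted. A tiny stand-alone illustration of that run coalescing, with ints standing in for popped rows:

```go
package main

import "fmt"

func main() {
	// source record of each popped row, in pop order
	src := []int{0, 0, 0, 1, 1, 0}
	ri, start := -1, 0
	flush := func(end int) {
		fmt.Printf("write one slice from record %d covering pops [%d,%d)\n", ri, start, end)
	}
	for pos, s := range src {
		if s != ri {
			if ri != -1 {
				flush(pos)
			}
			ri, start = s, pos
		}
	}
	flush(len(src))
	// three writes instead of six: [0,3) from record 0, [3,5) from record 1, [5,6) from record 0
}
```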


@ -0,0 +1,162 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"testing"
"github.com/apache/arrow/go/v12/arrow/array"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/pkg/common"
)
func TestSort(t *testing.T) {
getReaders := func() []RecordReader {
blobs, err := generateTestDataWithSeed(10, 3)
assert.NoError(t, err)
reader10, err := NewCompositeBinlogRecordReader(blobs)
assert.NoError(t, err)
blobs, err = generateTestDataWithSeed(20, 3)
assert.NoError(t, err)
reader20, err := NewCompositeBinlogRecordReader(blobs)
assert.NoError(t, err)
rr := []RecordReader{reader20, reader10}
return rr
}
lastPK := int64(-1)
rw := &MockRecordWriter{
writefn: func(r Record) error {
pk := r.Column(common.RowIDField).(*array.Int64).Value(0)
assert.Greater(t, pk, lastPK)
lastPK = pk
return nil
},
closefn: func() error {
lastPK = int64(-1)
return nil
},
}
t.Run("sort", func(t *testing.T) {
gotNumRows, err := Sort(getReaders(), common.RowIDField, rw, func(r Record, ri, i int) bool {
return true
})
assert.NoError(t, err)
assert.Equal(t, 6, gotNumRows)
err = rw.Close()
assert.NoError(t, err)
})
t.Run("sort with predicate", func(t *testing.T) {
gotNumRows, err := Sort(getReaders(), common.RowIDField, rw, func(r Record, ri, i int) bool {
pk := r.Column(common.RowIDField).(*array.Int64).Value(i)
return pk >= 20
})
assert.NoError(t, err)
assert.Equal(t, 3, gotNumRows)
err = rw.Close()
assert.NoError(t, err)
})
}
func TestMergeSort(t *testing.T) {
getReaders := func() []RecordReader {
blobs, err := generateTestDataWithSeed(10, 3)
assert.NoError(t, err)
reader10, err := NewCompositeBinlogRecordReader(blobs)
assert.NoError(t, err)
blobs, err = generateTestDataWithSeed(20, 3)
assert.NoError(t, err)
reader20, err := NewCompositeBinlogRecordReader(blobs)
assert.NoError(t, err)
rr := []RecordReader{reader20, reader10}
return rr
}
lastPK := int64(-1)
rw := &MockRecordWriter{
writefn: func(r Record) error {
pk := r.Column(common.RowIDField).(*array.Int64).Value(0)
assert.Greater(t, pk, lastPK)
lastPK = pk
return nil
},
closefn: func() error {
lastPK = int64(-1)
return nil
},
}
t.Run("merge sort", func(t *testing.T) {
gotNumRows, err := MergeSort(getReaders(), common.RowIDField, rw, func(r Record, ri, i int) bool {
return true
})
assert.NoError(t, err)
assert.Equal(t, 6, gotNumRows)
err = rw.Close()
assert.NoError(t, err)
})
t.Run("merge sort with predicate", func(t *testing.T) {
gotNumRows, err := MergeSort(getReaders(), common.RowIDField, rw, func(r Record, ri, i int) bool {
pk := r.Column(common.RowIDField).(*array.Int64).Value(i)
return pk >= 20
})
assert.NoError(t, err)
assert.Equal(t, 3, gotNumRows)
err = rw.Close()
assert.NoError(t, err)
})
}
// Benchmark sort
func BenchmarkSort(b *testing.B) {
batch := 500000
blobs, err := generateTestDataWithSeed(batch, batch)
assert.NoError(b, err)
reader10, err := NewCompositeBinlogRecordReader(blobs)
assert.NoError(b, err)
blobs, err = generateTestDataWithSeed(batch*2+1, batch)
assert.NoError(b, err)
reader20, err := NewCompositeBinlogRecordReader(blobs)
assert.NoError(b, err)
rr := []RecordReader{reader20, reader10}
rw := &MockRecordWriter{
writefn: func(r Record) error {
return nil
},
closefn: func() error {
return nil
},
}
b.ResetTimer()
b.Run("sort", func(b *testing.B) {
for i := 0; i < b.N; i++ {
Sort(rr, common.RowIDField, rw, func(r Record, ri, i int) bool {
return true
})
}
})
}