feat: support bm25 logs mixcompaction (#36072)

Related: https://github.com/milvus-io/milvus/issues/35853

---------

Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
aoiasd 2024-10-14 16:57:22 +08:00 committed by GitHub
parent c96bbe19ba
commit 5ec4163d0f
22 changed files with 938 additions and 32 deletions


@ -17,11 +17,11 @@
package datacoord
import (
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
)


@ -1532,6 +1532,7 @@ func (m *meta) completeMixCompactionMutation(t *datapb.CompactionTask, result *d
Binlogs: compactToSegment.GetInsertLogs(),
Statslogs: compactToSegment.GetField2StatslogPaths(),
Deltalogs: compactToSegment.GetDeltalogs(),
Bm25Statslogs: compactToSegment.GetBm25Logs(),
CreatedByCompaction: true,
CompactionFrom: compactFromSegIDs,
@ -1985,6 +1986,7 @@ func (m *meta) SaveStatsResultSegment(oldSegmentID int64, result *workerpb.Stats
MaxRowNum: cloned.GetMaxRowNum(),
Binlogs: result.GetInsertLogs(),
Statslogs: result.GetStatsLogs(),
Bm25Statslogs: result.GetBm25Logs(),
TextStatsLogs: result.GetTextStatsLogs(),
CreatedByCompaction: true,
CompactionFrom: []int64{oldSegmentID},


@ -178,7 +178,9 @@ func (st *statsTask) PreCheck(ctx context.Context, dependency *taskScheduler) bo
return false
}
start, end, err := dependency.allocator.AllocN(segment.getSegmentSize() / Params.DataNodeCfg.BinLogMaxSize.GetAsInt64() * int64(len(collInfo.Schema.GetFields())) * 2)
binlogNum := (segment.getSegmentSize()/Params.DataNodeCfg.BinLogMaxSize.GetAsInt64() + 1) * int64(len(collInfo.Schema.GetFields())) * 100
// binlogNum + BM25logNum + statslogNum
start, end, err := dependency.allocator.AllocN(binlogNum + int64(len(collInfo.Schema.GetFunctions())) + 1)
if err != nil {
log.Warn("stats task alloc logID failed", zap.Int64("collectionID", segment.GetCollectionID()), zap.Error(err))
st.SetState(indexpb.JobState_JobStateInit, err.Error())


@ -40,6 +40,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/flushcommon/io"
"github.com/milvus-io/milvus/internal/metastore/kv/binlog"
"github.com/milvus-io/milvus/internal/proto/clusteringpb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/etcdpb"
@ -101,6 +102,8 @@ type clusteringCompactionTask struct {
// vector
segmentIDOffsetMapping map[int64]string
offsetToBufferFunc func(int64, []uint32) *ClusterBuffer
// bm25
bm25FieldIds []int64
}
type ClusterBuffer struct {
@ -115,6 +118,8 @@ type ClusterBuffer struct {
currentSegmentRowNum atomic.Int64
// segID -> fieldID -> binlogs
flushedBinlogs map[typeutil.UniqueID]map[typeutil.UniqueID]*datapb.FieldBinlog
// segID -> fieldID -> bm25 stats
flushedBM25stats map[typeutil.UniqueID]map[int64]*storage.BM25Stats
uploadedSegments []*datapb.CompactionSegment
uploadedSegmentStats map[typeutil.UniqueID]storage.SegmentStats
@ -205,6 +210,13 @@ func (t *clusteringCompactionTask) init() error {
t.clusteringKeyField = field
}
}
for _, function := range t.plan.Schema.Functions {
if function.GetType() == schemapb.FunctionType_BM25 {
t.bm25FieldIds = append(t.bm25FieldIds, function.GetOutputFieldIds()[0])
}
}
t.primaryKeyField = pkField
t.isVectorClusteringKey = typeutil.IsVectorType(t.clusteringKeyField.DataType)
t.currentTs = tsoutil.GetCurrentTime()
@ -310,6 +322,7 @@ func (t *clusteringCompactionTask) getScalarAnalyzeResult(ctx context.Context) e
id: id,
flushedRowNum: map[typeutil.UniqueID]atomic.Int64{},
flushedBinlogs: make(map[typeutil.UniqueID]map[typeutil.UniqueID]*datapb.FieldBinlog, 0),
flushedBM25stats: make(map[int64]map[int64]*storage.BM25Stats, 0),
uploadedSegments: make([]*datapb.CompactionSegment, 0),
uploadedSegmentStats: make(map[typeutil.UniqueID]storage.SegmentStats, 0),
clusteringKeyFieldStats: fieldStats,
@ -461,6 +474,7 @@ func (t *clusteringCompactionTask) mapping(ctx context.Context,
Field2StatslogPaths: seg.GetField2StatslogPaths(),
Deltalogs: seg.GetDeltalogs(),
Channel: seg.GetChannel(),
Bm25Logs: seg.GetBm25Logs(),
}
log.Debug("put segment into final compaction result", zap.String("segment", se.String()))
resultSegments = append(resultSegments, se)
@ -566,6 +580,7 @@ func (t *clusteringCompactionTask) mappingSegment(
blobs := lo.Map(allValues, func(v []byte, i int) *storage.Blob {
return &storage.Blob{Key: paths[i], Value: v}
})
pkIter, err := storage.NewBinlogDeserializeReader(blobs, t.primaryKeyField.GetFieldID())
if err != nil {
log.Warn("new insert binlogs Itr wrong", zap.Strings("paths", paths), zap.Error(err))
@ -892,6 +907,15 @@ func (t *clusteringCompactionTask) packBufferToSegment(ctx context.Context, buff
Field2StatslogPaths: []*datapb.FieldBinlog{statsLogs},
Channel: t.plan.GetChannel(),
}
if len(t.bm25FieldIds) > 0 {
bm25Logs, err := t.generateBM25Stats(ctx, segmentID, buffer.flushedBM25stats[segmentID])
if err != nil {
return err
}
seg.Bm25Logs = bm25Logs
}
buffer.uploadedSegments = append(buffer.uploadedSegments, seg)
segmentStats := storage.SegmentStats{
FieldStats: []storage.FieldStats{buffer.clusteringKeyFieldStats.Clone()},
@ -914,6 +938,10 @@ func (t *clusteringCompactionTask) packBufferToSegment(ctx context.Context, buff
// clear segment binlogs cache
delete(buffer.flushedBinlogs, segmentID)
if len(t.bm25FieldIds) > 0 {
delete(buffer.flushedBM25stats, segmentID)
}
return nil
}
@ -966,6 +994,23 @@ func (t *clusteringCompactionTask) flushBinlog(ctx context.Context, buffer *Clus
buffer.flushedBinlogs[segmentID] = make(map[typeutil.UniqueID]*datapb.FieldBinlog)
}
// if the segment has bm25 fields, cache the bm25 stats
if len(t.bm25FieldIds) > 0 {
statsMap, ok := buffer.flushedBM25stats[segmentID]
if !ok || statsMap == nil {
buffer.flushedBM25stats[segmentID] = make(map[int64]*storage.BM25Stats)
statsMap = buffer.flushedBM25stats[segmentID]
}
for fieldID, newstats := range writer.GetBm25Stats() {
if stats, ok := statsMap[fieldID]; ok {
stats.Merge(newstats)
} else {
statsMap[fieldID] = newstats
}
}
}
for fID, path := range partialBinlogs {
tmpBinlog, ok := buffer.flushedBinlogs[segmentID][fID]
if !ok {
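For context, a minimal sketch of the merge semantics this cache relies on, using the storage.BM25Stats methods exercised elsewhere in the patch (the token map below is made up):

// accumulate one flush's stats into the per-segment cache
cached := storage.NewBM25Stats()
batch := storage.NewBM25Stats()
batch.Append(map[uint32]float32{1: 2, 7: 1}) // one row: token id -> term frequency
cached.Merge(batch)                          // repeated fieldID -> Merge, as above
_ = cached.NumRow()                          // later recorded as EntriesNum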
@ -1230,7 +1275,7 @@ func (t *clusteringCompactionTask) refreshBufferWriterWithPack(buffer *ClusterBu
buffer.currentSegmentRowNum.Store(0)
}
writer, err := NewSegmentWriter(t.plan.GetSchema(), t.plan.MaxSegmentRows, segmentID, t.partitionID, t.collectionID)
writer, err := NewSegmentWriter(t.plan.GetSchema(), t.plan.MaxSegmentRows, segmentID, t.partitionID, t.collectionID, t.bm25FieldIds)
if err != nil {
return pack, err
}
@ -1245,7 +1290,7 @@ func (t *clusteringCompactionTask) refreshBufferWriter(buffer *ClusterBuffer) er
segmentID = buffer.writer.Load().(*SegmentWriter).GetSegmentID()
buffer.bufferMemorySize.Add(int64(buffer.writer.Load().(*SegmentWriter).WrittenMemorySize()))
writer, err := NewSegmentWriter(t.plan.GetSchema(), t.plan.MaxSegmentRows, segmentID, t.partitionID, t.collectionID)
writer, err := NewSegmentWriter(t.plan.GetSchema(), t.plan.MaxSegmentRows, segmentID, t.partitionID, t.collectionID, t.bm25FieldIds)
if err != nil {
return err
}
@ -1270,6 +1315,48 @@ func (t *clusteringCompactionTask) checkBuffersAfterCompaction() error {
return nil
}
func (t *clusteringCompactionTask) generateBM25Stats(ctx context.Context, segmentID int64, statsMap map[int64]*storage.BM25Stats) ([]*datapb.FieldBinlog, error) {
binlogs := []*datapb.FieldBinlog{}
kvs := map[string][]byte{}
logID, _, err := t.logIDAlloc.Alloc(uint32(len(statsMap)))
if err != nil {
return nil, err
}
for fieldID, stats := range statsMap {
key, _ := binlog.BuildLogPath(storage.BM25Binlog, t.collectionID, t.partitionID, segmentID, fieldID, logID)
bytes, err := stats.Serialize()
if err != nil {
log.Warn("failed to seralize bm25 stats", zap.Int64("collection", t.collectionID),
zap.Int64("partition", t.partitionID), zap.Int64("segment", segmentID), zap.Error(err))
return nil, err
}
kvs[key] = bytes
binlogs = append(binlogs, &datapb.FieldBinlog{
FieldID: fieldID,
Binlogs: []*datapb.Binlog{{
LogSize: int64(len(bytes)),
MemorySize: int64(len(bytes)),
LogPath: key,
EntriesNum: stats.NumRow(),
}},
})
logID++
}
if err := t.binlogIO.Upload(ctx, kvs); err != nil {
log.Warn("failed to upload bm25 stats log",
zap.Int64("collection", t.collectionID),
zap.Int64("partition", t.partitionID),
zap.Int64("segment", segmentID),
zap.Error(err))
return nil, err
}
return binlogs, nil
}
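A usage sketch inside the task, mirroring the unit test added below; each BM25 output field yields one FieldBinlog whose EntriesNum comes from the stats row count:

statsMap := map[int64]*storage.BM25Stats{102: storage.NewBM25Stats()}
statsMap[102].Append(map[uint32]float32{1: 1})
binlogs, err := t.generateBM25Stats(ctx, segmentID, statsMap)
// on success: len(binlogs) == 1, binlogs[0].FieldID == 102,
// and binlogs[0].Binlogs[0].GetEntriesNum() == 1
_, _ = binlogs, err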
func (t *clusteringCompactionTask) generatePkStats(ctx context.Context, segmentID int64,
numRows int64, binlogPaths [][]string,
) (*datapb.FieldBinlog, error) {


@ -51,7 +51,6 @@ type ClusteringCompactionTaskSuite struct {
mockBinlogIO *io.MockBinlogIO
mockAlloc *allocator.MockAllocator
mockID atomic.Int64
segWriter *SegmentWriter
task *clusteringCompactionTask
@ -172,7 +171,7 @@ func (s *ClusteringCompactionTaskSuite) TestCompactionInit() {
func (s *ClusteringCompactionTaskSuite) TestScalarCompactionNormal() {
schema := genCollectionSchema()
var segmentID int64 = 1001
segWriter, err := NewSegmentWriter(schema, 1000, segmentID, PartitionID, CollectionID)
segWriter, err := NewSegmentWriter(schema, 1000, segmentID, PartitionID, CollectionID, []int64{})
s.Require().NoError(err)
for i := 0; i < 10240; i++ {
v := storage.Value{
@ -186,6 +185,7 @@ func (s *ClusteringCompactionTaskSuite) TestScalarCompactionNormal() {
segWriter.FlushAndIsFull()
kvs, fBinlogs, err := serializeWrite(context.TODO(), s.mockAlloc, segWriter)
s.NoError(err)
s.mockBinlogIO.EXPECT().Download(mock.Anything, mock.Anything).Return(lo.Values(kvs), nil)
s.plan.SegmentBinlogs = []*datapb.CompactionSegmentBinlogs{
@ -237,6 +237,89 @@ func (s *ClusteringCompactionTaskSuite) TestScalarCompactionNormal() {
s.Equal(totalRowNum, statsRowNum)
}
func (s *ClusteringCompactionTaskSuite) TestCompactionWithBM25Function() {
schema := genCollectionSchemaWithBM25()
var segmentID int64 = 1001
segWriter, err := NewSegmentWriter(schema, 1000, segmentID, PartitionID, CollectionID, []int64{102})
s.Require().NoError(err)
for i := 0; i < 10240; i++ {
v := storage.Value{
PK: storage.NewInt64PrimaryKey(int64(i)),
Timestamp: int64(tsoutil.ComposeTSByTime(getMilvusBirthday(), 0)),
Value: genRowWithBM25(int64(i)),
}
err = segWriter.Write(&v)
s.Require().NoError(err)
}
segWriter.FlushAndIsFull()
kvs, fBinlogs, err := serializeWrite(context.TODO(), s.mockAlloc, segWriter)
s.NoError(err)
s.mockBinlogIO.EXPECT().Download(mock.Anything, mock.Anything).Return(lo.Values(kvs), nil)
s.plan.SegmentBinlogs = []*datapb.CompactionSegmentBinlogs{
{
SegmentID: segmentID,
FieldBinlogs: lo.Values(fBinlogs),
},
}
s.task.bm25FieldIds = []int64{102}
s.task.plan.Schema = schema
s.task.plan.ClusteringKeyField = 100
s.task.plan.PreferSegmentRows = 2048
s.task.plan.MaxSegmentRows = 2048
s.task.plan.PreAllocatedSegmentIDs = &datapb.IDRange{
Begin: time.Now().UnixMilli(),
End: time.Now().UnixMilli() + 1000,
}
// each row is 8 + 8 + 8 + 7 + 8 = 39 bytes (row_id, ts, pk, "varchar", one sparse entry),
// so 1024 rows occupy 39*1024 = 39936 bytes and the
// writer will automatically flush after 1024 rows.
paramtable.Get().Save(paramtable.Get().DataNodeCfg.BinLogMaxSize.Key, "39935")
defer paramtable.Get().Reset(paramtable.Get().DataNodeCfg.BinLogMaxSize.Key)
compactionResult, err := s.task.Compact()
s.Require().NoError(err)
s.Equal(5, len(s.task.clusterBuffers))
s.Equal(5, len(compactionResult.GetSegments()))
totalBinlogNum := 0
totalRowNum := int64(0)
for _, fb := range compactionResult.GetSegments()[0].GetInsertLogs() {
for _, b := range fb.GetBinlogs() {
totalBinlogNum++
if fb.GetFieldID() == 100 {
totalRowNum += b.GetEntriesNum()
}
}
}
statsBinlogNum := 0
statsRowNum := int64(0)
for _, sb := range compactionResult.GetSegments()[0].GetField2StatslogPaths() {
for _, b := range sb.GetBinlogs() {
statsBinlogNum++
statsRowNum += b.GetEntriesNum()
}
}
s.Equal(2, totalBinlogNum/len(schema.GetFields()))
s.Equal(1, statsBinlogNum)
s.Equal(totalRowNum, statsRowNum)
bm25BinlogNum := 0
bm25RowNum := int64(0)
for _, bmb := range compactionResult.GetSegments()[0].GetBm25Logs() {
for _, b := range bmb.GetBinlogs() {
bm25BinlogNum++
bm25RowNum += b.GetEntriesNum()
}
}
s.Equal(1, bm25BinlogNum)
s.Equal(totalRowNum, bm25RowNum)
}
func (s *ClusteringCompactionTaskSuite) TestCheckBuffersAfterCompaction() {
s.Run("no leak", func() {
task := &clusteringCompactionTask{clusterBuffers: []*ClusterBuffer{{}}}
@ -263,6 +346,71 @@ func (s *ClusteringCompactionTaskSuite) TestCheckBuffersAfterCompaction() {
})
}
func (s *ClusteringCompactionTaskSuite) TestGenerateBM25Stats() {
s.Run("normal case", func() {
segmentID := int64(1)
task := &clusteringCompactionTask{
collectionID: 111,
partitionID: 222,
bm25FieldIds: []int64{102},
logIDAlloc: s.mockAlloc,
binlogIO: s.mockBinlogIO,
}
statsMap := make(map[int64]*storage.BM25Stats)
statsMap[102] = storage.NewBM25Stats()
statsMap[102].Append(map[uint32]float32{1: 1})
binlogs, err := task.generateBM25Stats(context.Background(), segmentID, statsMap)
s.NoError(err)
s.Equal(1, len(binlogs))
s.Equal(1, len(binlogs[0].Binlogs))
s.Equal(int64(102), binlogs[0].FieldID)
s.Equal(int64(1), binlogs[0].Binlogs[0].GetEntriesNum())
})
s.Run("alloc ID failed", func() {
segmentID := int64(1)
mockAlloc := allocator.NewMockAllocator(s.T())
mockAlloc.EXPECT().Alloc(mock.Anything).Return(0, 0, fmt.Errorf("mock error")).Once()
task := &clusteringCompactionTask{
collectionID: 111,
partitionID: 222,
bm25FieldIds: []int64{102},
logIDAlloc: mockAlloc,
}
statsMap := make(map[int64]*storage.BM25Stats)
statsMap[102] = storage.NewBM25Stats()
statsMap[102].Append(map[uint32]float32{1: 1})
_, err := task.generateBM25Stats(context.Background(), segmentID, statsMap)
s.Error(err)
})
s.Run("upload failed", func() {
segmentID := int64(1)
mockBinlogIO := io.NewMockBinlogIO(s.T())
mockBinlogIO.EXPECT().Upload(mock.Anything, mock.Anything).Return(fmt.Errorf("mock error")).Once()
task := &clusteringCompactionTask{
collectionID: 111,
partitionID: 222,
bm25FieldIds: []int64{102},
logIDAlloc: s.mockAlloc,
binlogIO: mockBinlogIO,
}
statsMap := make(map[int64]*storage.BM25Stats)
statsMap[102] = storage.NewBM25Stats()
statsMap[102].Append(map[uint32]float32{1: 1})
_, err := task.generateBM25Stats(context.Background(), segmentID, statsMap)
s.Error(err)
})
}
func (s *ClusteringCompactionTaskSuite) TestGeneratePkStats() {
pkField := &schemapb.FieldSchema{
FieldID: 100,
@ -304,7 +452,7 @@ func (s *ClusteringCompactionTaskSuite) TestGeneratePkStats() {
s.Run("upload failed", func() {
schema := genCollectionSchema()
segWriter, err := NewSegmentWriter(schema, 1000, SegmentID, PartitionID, CollectionID)
segWriter, err := NewSegmentWriter(schema, 1000, SegmentID, PartitionID, CollectionID, []int64{})
s.Require().NoError(err)
for i := 0; i < 2000; i++ {
v := storage.Value{
@ -403,3 +551,64 @@ func genCollectionSchema() *schemapb.CollectionSchema {
},
}
}
func genCollectionSchemaWithBM25() *schemapb.CollectionSchema {
return &schemapb.CollectionSchema{
Name: "schema",
Description: "schema",
Fields: []*schemapb.FieldSchema{
{
FieldID: common.RowIDField,
Name: "row_id",
DataType: schemapb.DataType_Int64,
},
{
FieldID: common.TimeStampField,
Name: "Timestamp",
DataType: schemapb.DataType_Int64,
},
{
FieldID: 100,
Name: "pk",
DataType: schemapb.DataType_Int64,
IsPrimaryKey: true,
},
{
FieldID: 101,
Name: "text",
DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{
Key: common.MaxLengthKey,
Value: "8",
},
},
},
{
FieldID: 102,
Name: "sparse",
DataType: schemapb.DataType_SparseFloatVector,
},
},
Functions: []*schemapb.FunctionSchema{{
Name: "BM25",
Id: 100,
Type: schemapb.FunctionType_BM25,
InputFieldNames: []string{"text"},
InputFieldIds: []int64{101},
OutputFieldNames: []string{"sparse"},
OutputFieldIds: []int64{102},
}},
}
}
func genRowWithBM25(magic int64) map[int64]interface{} {
ts := tsoutil.ComposeTSByTime(getMilvusBirthday(), 0)
return map[int64]interface{}{
common.RowIDField: magic,
common.TimeStampField: int64(ts),
100: magic,
101: "varchar",
102: typeutil.CreateAndSortSparseFloatRow(map[uint32]float32{1: 1}),
}
}


@ -214,7 +214,7 @@ func uploadStatsBlobs(ctx context.Context, collectionID, partitionID, segmentID,
},
}
if err := io.Upload(ctx, kvs); err != nil {
log.Warn("failed to upload insert log", zap.Error(err))
log.Warn("failed to upload stats log", zap.Error(err))
return nil, err
}
@ -229,3 +229,45 @@ func mergeFieldBinlogs(base, paths map[typeutil.UniqueID]*datapb.FieldBinlog) {
base[fID].Binlogs = append(base[fID].Binlogs, fpath.GetBinlogs()...)
}
}
func bm25SerializeWrite(ctx context.Context, io io.BinlogIO, allocator allocator.Interface, writer *SegmentWriter) ([]*datapb.FieldBinlog, error) {
ctx, span := otel.Tracer(typeutil.DataNodeRole).Start(ctx, "bm25 stats log serializeWrite")
defer span.End()
stats, err := writer.GetBm25StatsBlob()
if err != nil {
return nil, err
}
logID, _, err := allocator.Alloc(uint32(len(stats)))
if err != nil {
return nil, err
}
kvs := make(map[string][]byte)
binlogs := []*datapb.FieldBinlog{}
for fieldID, blob := range stats {
key, _ := binlog.BuildLogPath(storage.BM25Binlog, writer.GetCollectionID(), writer.GetPartitionID(), writer.GetSegmentID(), fieldID, logID)
kvs[key] = blob.GetValue()
fieldLog := &datapb.FieldBinlog{
FieldID: fieldID,
Binlogs: []*datapb.Binlog{
{
LogSize: int64(len(blob.GetValue())),
MemorySize: int64(len(blob.GetValue())),
LogPath: key,
EntriesNum: writer.GetRowNum(),
},
},
}
binlogs = append(binlogs, fieldLog)
}
if err := io.Upload(ctx, kvs); err != nil {
log.Warn("failed to upload bm25 log", zap.Error(err))
return nil, err
}
return binlogs, nil
}


@ -28,6 +28,7 @@ func mergeSortMultipleSegments(ctx context.Context,
tr *timerecord.TimeRecorder,
currentTs typeutil.Timestamp,
collectionTtl int64,
bm25FieldIds []int64,
) ([]*datapb.CompactionSegment, error) {
_ = tr.RecordSpan()
@ -39,7 +40,7 @@ func mergeSortMultipleSegments(ctx context.Context,
segIDAlloc := allocator.NewLocalAllocator(plan.GetPreAllocatedSegmentIDs().GetBegin(), plan.GetPreAllocatedSegmentIDs().GetEnd())
logIDAlloc := allocator.NewLocalAllocator(plan.GetBeginLogID(), math.MaxInt64)
compAlloc := NewCompactionAllocator(segIDAlloc, logIDAlloc)
mWriter := NewMultiSegmentWriter(binlogIO, compAlloc, plan, maxRows, partitionID, collectionID)
mWriter := NewMultiSegmentWriter(binlogIO, compAlloc, plan, maxRows, partitionID, collectionID, bm25FieldIds)
var (
expiredRowCount int64 // the number of expired entities
@ -87,7 +88,7 @@ func mergeSortMultipleSegments(ctx context.Context,
}
binlogPaths[idx] = batchPaths
}
segmentReaders[i] = NewSegmentDeserializeReader(ctx, binlogPaths, binlogIO, pkField.GetFieldID())
segmentReaders[i] = NewSegmentDeserializeReader(ctx, binlogPaths, binlogIO, pkField.GetFieldID(), bm25FieldIds)
}
pq := make(PriorityQueue, 0)


@ -57,6 +57,8 @@ type mixCompactionTask struct {
maxRows int64
pkID int64
bm25FieldIDs []int64
done chan struct{}
tr *timerecord.TimeRecorder
}
@ -97,6 +99,7 @@ func (t *mixCompactionTask) preCompact() error {
t.collectionID = t.plan.GetSegmentBinlogs()[0].GetCollectionID()
t.partitionID = t.plan.GetSegmentBinlogs()[0].GetPartitionID()
t.targetSize = t.plan.GetMaxSize()
t.bm25FieldIDs = GetBM25FieldIDs(t.plan.GetSchema())
currSize := int64(0)
for _, segmentBinlog := range t.plan.GetSegmentBinlogs() {
@ -140,7 +143,7 @@ func (t *mixCompactionTask) mergeSplit(
segIDAlloc := allocator.NewLocalAllocator(t.plan.GetPreAllocatedSegmentIDs().GetBegin(), t.plan.GetPreAllocatedSegmentIDs().GetEnd())
logIDAlloc := allocator.NewLocalAllocator(t.plan.GetBeginLogID(), math.MaxInt64)
compAlloc := NewCompactionAllocator(segIDAlloc, logIDAlloc)
mWriter := NewMultiSegmentWriter(t.binlogIO, compAlloc, t.plan, t.maxRows, t.partitionID, t.collectionID)
mWriter := NewMultiSegmentWriter(t.binlogIO, compAlloc, t.plan, t.maxRows, t.partitionID, t.collectionID, t.bm25FieldIDs)
deletedRowCount := int64(0)
expiredRowCount := int64(0)
@ -285,7 +288,7 @@ func (t *mixCompactionTask) Compact() (*datapb.CompactionPlanResult, error) {
if allSorted && len(t.plan.GetSegmentBinlogs()) > 1 {
log.Info("all segments are sorted, use merge sort")
res, err = mergeSortMultipleSegments(ctxTimeout, t.plan, t.collectionID, t.partitionID, t.maxRows, t.binlogIO,
t.plan.GetSegmentBinlogs(), deltaPk2Ts, t.tr, t.currentTs, t.plan.GetCollectionTtl())
t.plan.GetSegmentBinlogs(), deltaPk2Ts, t.tr, t.currentTs, t.plan.GetCollectionTtl(), t.bm25FieldIDs)
if err != nil {
log.Warn("compact wrong, fail to merge sort segments", zap.Error(err))
return nil, err
@ -341,3 +344,12 @@ func (t *mixCompactionTask) GetCollection() typeutil.UniqueID {
func (t *mixCompactionTask) GetSlotUsage() int64 {
return t.plan.GetSlotUsage()
}
func GetBM25FieldIDs(coll *schemapb.CollectionSchema) []int64 {
return lo.FilterMap(coll.GetFunctions(), func(function *schemapb.FunctionSchema, _ int) (int64, bool) {
if function.GetType() == schemapb.FunctionType_BM25 {
return function.GetOutputFieldIds()[0], true
}
return 0, false
})
}
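Usage sketch, matching the schema helper and test added below:

schema := &schemapb.CollectionSchema{
	Functions: []*schemapb.FunctionSchema{{
		Type:           schemapb.FunctionType_BM25,
		OutputFieldIds: []int64{102},
	}},
}
fmt.Println(GetBM25FieldIDs(schema)) // [102]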


@ -91,6 +91,30 @@ func (s *MixCompactionTaskSuite) SetupTest() {
s.task.plan = s.plan
}
func (s *MixCompactionTaskSuite) SetupBM25() {
s.mockBinlogIO = io.NewMockBinlogIO(s.T())
s.meta = genTestCollectionMetaWithBM25()
s.plan = &datapb.CompactionPlan{
PlanID: 999,
SegmentBinlogs: []*datapb.CompactionSegmentBinlogs{{
SegmentID: 100,
FieldBinlogs: nil,
Field2StatslogPaths: nil,
Deltalogs: nil,
}},
TimeoutInSeconds: 10,
Type: datapb.CompactionType_MixCompaction,
Schema: s.meta.GetSchema(),
BeginLogID: 19530,
PreAllocatedSegmentIDs: &datapb.IDRange{Begin: 19531, End: math.MaxInt64},
MaxSize: 64 * 1024 * 1024,
}
s.task = NewMixCompactionTask(context.Background(), s.mockBinlogIO, s.plan)
s.task.plan = s.plan
}
func (s *MixCompactionTaskSuite) SetupSubTest() {
s.SetupTest()
}
@ -210,6 +234,55 @@ func (s *MixCompactionTaskSuite) TestCompactTwoToOne() {
s.Empty(segment.Deltalogs)
}
func (s *MixCompactionTaskSuite) TestCompactTwoToOneWithBM25() {
s.SetupBM25()
segments := []int64{5, 6, 7}
alloc := allocator.NewLocalAllocator(7777777, math.MaxInt64)
s.mockBinlogIO.EXPECT().Upload(mock.Anything, mock.Anything).Return(nil)
s.task.plan.SegmentBinlogs = make([]*datapb.CompactionSegmentBinlogs, 0)
for _, segID := range segments {
s.initSegBufferWithBM25(segID)
kvs, fBinlogs, err := serializeWrite(context.TODO(), alloc, s.segWriter)
s.Require().NoError(err)
s.mockBinlogIO.EXPECT().Download(mock.Anything, mock.MatchedBy(func(keys []string) bool {
left, right := lo.Difference(keys, lo.Keys(kvs))
return len(left) == 0 && len(right) == 0
})).Return(lo.Values(kvs), nil).Once()
s.plan.SegmentBinlogs = append(s.plan.SegmentBinlogs, &datapb.CompactionSegmentBinlogs{
SegmentID: segID,
FieldBinlogs: lo.Values(fBinlogs),
})
}
// append an empty segment
seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{
CollectionID: CollectionID,
PartitionID: PartitionID,
ID: 99999,
NumOfRows: 0,
}, pkoracle.NewBloomFilterSet(), nil)
s.plan.SegmentBinlogs = append(s.plan.SegmentBinlogs, &datapb.CompactionSegmentBinlogs{
SegmentID: seg.SegmentID(),
})
result, err := s.task.Compact()
s.NoError(err)
s.NotNil(result)
s.Equal(s.task.plan.GetPlanID(), result.GetPlanID())
s.Equal(1, len(result.GetSegments()))
segment := result.GetSegments()[0]
s.EqualValues(19531, segment.GetSegmentID())
s.EqualValues(3, segment.GetNumOfRows())
s.NotEmpty(segment.InsertLogs)
s.NotEmpty(segment.Bm25Logs)
s.NotEmpty(segment.Field2StatslogPaths)
s.Empty(segment.Deltalogs)
}
func (s *MixCompactionTaskSuite) TestCompactSortedSegment() {
segments := []int64{1001, 1002, 1003}
alloc := allocator.NewLocalAllocator(100, math.MaxInt64)
@ -316,6 +389,16 @@ func (s *MixCompactionTaskSuite) TestMergeNoExpiration() {
}
}
func (s *MixCompactionTaskSuite) TestGetBM25FieldIDs() {
fieldIDs := GetBM25FieldIDs(&schemapb.CollectionSchema{
Functions: []*schemapb.FunctionSchema{{}},
})
s.Equal(0, len(fieldIDs))
fieldIDs = GetBM25FieldIDs(genCollectionSchemaWithBM25())
s.Equal(1, len(fieldIDs))
}
func (s *MixCompactionTaskSuite) TestMergeDeltalogsMultiSegment() {
tests := []struct {
segIDA int64
@ -540,7 +623,7 @@ func getRow(magic int64) map[int64]interface{} {
}
func (s *MixCompactionTaskSuite) initMultiRowsSegBuffer(magic, numRows, step int64) {
segWriter, err := NewSegmentWriter(s.meta.GetSchema(), 65535, magic, PartitionID, CollectionID)
segWriter, err := NewSegmentWriter(s.meta.GetSchema(), 65535, magic, PartitionID, CollectionID, []int64{})
s.Require().NoError(err)
for i := int64(0); i < numRows; i++ {
@ -558,8 +641,24 @@ func (s *MixCompactionTaskSuite) initMultiRowsSegBuffer(magic, numRows, step int
s.segWriter = segWriter
}
func (s *MixCompactionTaskSuite) initSegBufferWithBM25(magic int64) {
segWriter, err := NewSegmentWriter(s.meta.GetSchema(), 100, magic, PartitionID, CollectionID, []int64{102})
s.Require().NoError(err)
v := storage.Value{
PK: storage.NewInt64PrimaryKey(magic),
Timestamp: int64(tsoutil.ComposeTSByTime(getMilvusBirthday(), 0)),
Value: genRowWithBM25(magic),
}
err = segWriter.Write(&v)
s.Require().NoError(err)
segWriter.FlushAndIsFull()
s.segWriter = segWriter
}
func (s *MixCompactionTaskSuite) initSegBuffer(magic int64) {
segWriter, err := NewSegmentWriter(s.meta.GetSchema(), 100, magic, PartitionID, CollectionID)
segWriter, err := NewSegmentWriter(s.meta.GetSchema(), 100, magic, PartitionID, CollectionID, []int64{})
s.Require().NoError(err)
v := storage.Value{
@ -608,6 +707,14 @@ func getInt64DeltaBlobs(segID int64, pks []int64, tss []uint64) (*storage.Blob,
return blob, err
}
func genTestCollectionMetaWithBM25() *etcdpb.CollectionMeta {
return &etcdpb.CollectionMeta{
ID: CollectionID,
PartitionTags: []string{"partition_0", "partition_1"},
Schema: genCollectionSchemaWithBM25(),
}
}
func genTestCollectionMeta() *etcdpb.CollectionMeta {
return &etcdpb.CollectionMeta{
ID: CollectionID,


@ -21,9 +21,11 @@ type SegmentDeserializeReader struct {
PKFieldID int64
binlogPaths [][]string
binlogPathPos int
bm25FieldIDs []int64
}
func NewSegmentDeserializeReader(ctx context.Context, binlogPaths [][]string, binlogIO binlogIO.BinlogIO, PKFieldID int64) *SegmentDeserializeReader {
func NewSegmentDeserializeReader(ctx context.Context, binlogPaths [][]string, binlogIO binlogIO.BinlogIO, PKFieldID int64, bm25FieldIDs []int64) *SegmentDeserializeReader {
return &SegmentDeserializeReader{
ctx: ctx,
binlogIO: binlogIO,
@ -32,6 +34,7 @@ func NewSegmentDeserializeReader(ctx context.Context, binlogPaths [][]string, bi
PKFieldID: PKFieldID,
binlogPaths: binlogPaths,
binlogPathPos: 0,
bm25FieldIDs: bm25FieldIDs,
}
}


@ -6,6 +6,7 @@ package compaction
import (
"context"
"fmt"
"math"
"github.com/samber/lo"
@ -48,6 +49,7 @@ type MultiSegmentWriter struct {
res []*datapb.CompactionSegment
// DO NOT leave it empty if all segments are deleted; just return a segment with zero meta for datacoord
bm25Fields []int64
}
type compactionAlloactor struct {
@ -70,7 +72,7 @@ func (alloc *compactionAlloactor) getLogIDAllocator() allocator.Interface {
return alloc.logIDAlloc
}
func NewMultiSegmentWriter(binlogIO io.BinlogIO, allocator *compactionAlloactor, plan *datapb.CompactionPlan, maxRows int64, partitionID, collectionID int64) *MultiSegmentWriter {
func NewMultiSegmentWriter(binlogIO io.BinlogIO, allocator *compactionAlloactor, plan *datapb.CompactionPlan, maxRows int64, partitionID, collectionID int64, bm25Fields []int64) *MultiSegmentWriter {
return &MultiSegmentWriter{
binlogIO: binlogIO,
allocator: allocator,
@ -88,6 +90,7 @@ func NewMultiSegmentWriter(binlogIO io.BinlogIO, allocator *compactionAlloactor,
cachedMeta: make(map[typeutil.UniqueID]map[typeutil.UniqueID]*datapb.FieldBinlog),
res: make([]*datapb.CompactionSegment, 0),
bm25Fields: bm25Fields,
}
}
@ -116,13 +119,24 @@ func (w *MultiSegmentWriter) finishCurrent() error {
return err
}
w.res = append(w.res, &datapb.CompactionSegment{
result := &datapb.CompactionSegment{
SegmentID: writer.GetSegmentID(),
InsertLogs: lo.Values(allBinlogs),
Field2StatslogPaths: []*datapb.FieldBinlog{sPath},
NumOfRows: writer.GetRowNum(),
Channel: w.channel,
})
}
if len(w.bm25Fields) > 0 {
bmBinlogs, err := bm25SerializeWrite(context.TODO(), w.binlogIO, w.allocator.getLogIDAllocator(), writer)
if err != nil {
log.Warn("compact wrong, failed to serialize write segment bm25 stats", zap.Error(err))
return err
}
result.Bm25Logs = bmBinlogs
}
w.res = append(w.res, result)
log.Info("Segment writer flushed a segment",
zap.Int64("segmentID", writer.GetSegmentID()),
@ -139,7 +153,7 @@ func (w *MultiSegmentWriter) addNewWriter() error {
if err != nil {
return err
}
writer, err := NewSegmentWriter(w.schema, w.maxRows, newSegmentID, w.partitionID, w.collectionID)
writer, err := NewSegmentWriter(w.schema, w.maxRows, newSegmentID, w.partitionID, w.collectionID, w.bm25Fields)
if err != nil {
return err
}
@ -307,7 +321,9 @@ type SegmentWriter struct {
tsFrom typeutil.Timestamp
tsTo typeutil.Timestamp
pkstats *storage.PrimaryKeyStats
pkstats *storage.PrimaryKeyStats
bm25Stats map[int64]*storage.BM25Stats
segmentID int64
partitionID int64
collectionID int64
@ -350,6 +366,19 @@ func (w *SegmentWriter) Write(v *storage.Value) error {
}
w.pkstats.Update(v.PK)
for fieldID, stats := range w.bm25Stats {
data, ok := v.Value.(map[storage.FieldID]interface{})[fieldID]
if !ok {
return fmt.Errorf("bm25 field value not found")
}
bytes, ok := data.([]byte)
if !ok {
return fmt.Errorf("bm25 field value not sparse bytes")
}
stats.AppendBytes(bytes)
}
w.rowCount.Inc()
return w.writer.Write(v)
}
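The value handed to Write must therefore carry each BM25 output field as serialized sparse bytes; a sketch of a conforming row, following genRowWithBM25 below (field IDs 100/101/102 come from the test schema):

row := map[int64]interface{}{
	common.RowIDField:     int64(1),
	common.TimeStampField: int64(1),
	100:                   int64(1),    // pk
	101:                   "some text", // BM25 input varchar
	102:                   typeutil.CreateAndSortSparseFloatRow(map[uint32]float32{1: 1}),
}
// Write asserts row[102] is []byte and feeds it to BM25Stats.AppendBytes.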
@ -360,6 +389,28 @@ func (w *SegmentWriter) Finish() (*storage.Blob, error) {
return codec.SerializePkStats(w.pkstats, w.GetRowNum())
}
func (w *SegmentWriter) GetBm25Stats() map[int64]*storage.BM25Stats {
return w.bm25Stats
}
func (w *SegmentWriter) GetBm25StatsBlob() (map[int64]*storage.Blob, error) {
result := make(map[int64]*storage.Blob)
for fieldID, stats := range w.bm25Stats {
bytes, err := stats.Serialize()
if err != nil {
return nil, err
}
result[fieldID] = &storage.Blob{
Key: fmt.Sprintf("%d", fieldID),
Value: bytes,
RowNum: stats.NumRow(),
MemorySize: int64(len(bytes)),
}
}
return result, nil
}
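A sketch of consuming the returned blob map, assuming a *SegmentWriter w as above; Key is the field ID rendered as a string and RowNum mirrors the stats row count:

blobs, err := w.GetBm25StatsBlob()
if err == nil {
	for fieldID, blob := range blobs {
		// blob.Value is the serialized stats; blob.MemorySize == len(blob.Value)
		fmt.Println(fieldID, blob.Key, blob.RowNum)
	}
}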
func (w *SegmentWriter) IsFull() bool {
return w.writer.WrittenMemorySize() > paramtable.Get().DataNodeCfg.BinLogMaxSize.GetAsUint64()
}
@ -420,7 +471,7 @@ func (w *SegmentWriter) clear() {
w.tsTo = 0
}
func NewSegmentWriter(sch *schemapb.CollectionSchema, maxCount int64, segID, partID, collID int64) (*SegmentWriter, error) {
func NewSegmentWriter(sch *schemapb.CollectionSchema, maxCount int64, segID, partID, collID int64, Bm25Fields []int64) (*SegmentWriter, error) {
writer, closers, err := newBinlogWriter(collID, partID, segID, sch)
if err != nil {
return nil, err
@ -444,6 +495,7 @@ func NewSegmentWriter(sch *schemapb.CollectionSchema, maxCount int64, segID, par
tsTo: 0,
pkstats: stats,
bm25Stats: make(map[int64]*storage.BM25Stats),
sch: sch,
segmentID: segID,
partitionID: partID,
@ -452,6 +504,9 @@ func NewSegmentWriter(sch *schemapb.CollectionSchema, maxCount int64, segID, par
syncedSize: atomic.NewInt64(0),
}
for _, fieldID := range Bm25Fields {
segWriter.bm25Stats[fieldID] = storage.NewBM25Stats()
}
return &segWriter, nil
}
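Call-site sketch: pass the BM25 output field IDs so the writer tracks stats, or an empty slice for collections without BM25 functions (as in the updated tests):

// with BM25 stats tracked for output field 102
writer, err := NewSegmentWriter(schema, 1000, segmentID, partitionID, collectionID, []int64{102})
// without BM25 fields
writer, err = NewSegmentWriter(schema, 1000, segmentID, partitionID, collectionID, []int64{})
_, _ = writer, err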


@ -0,0 +1,73 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package compaction
import (
"testing"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
)
func TestSegmentWriterSuite(t *testing.T) {
suite.Run(t, new(SegmentWriteSuite))
}
type SegmentWriteSuite struct {
suite.Suite
collectionID int64
partitionID int64
}
func (s *SegmentWriteSuite) SetupSuite() {
s.collectionID = 100
s.partitionID = 101
}
func (s *SegmentWriteSuite) TestWriteFailed() {
s.Run("get bm25 field failed", func() {
schema := genCollectionSchemaWithBM25()
// init segment writer with invalid bm25 fieldID
writer, err := NewSegmentWriter(schema, 1024, 1, s.partitionID, s.collectionID, []int64{1000})
s.Require().NoError(err)
v := storage.Value{
PK: storage.NewInt64PrimaryKey(int64(0)),
Timestamp: int64(tsoutil.ComposeTSByTime(getMilvusBirthday(), 0)),
Value: genRowWithBM25(int64(0)),
}
err = writer.Write(&v)
s.Error(err)
})
s.Run("parse bm25 field data failed", func() {
schema := genCollectionSchemaWithBM25()
// init segment writer with wrong field as bm25 sparse field
writer, err := NewSegmentWriter(schema, 1024, 1, s.partitionID, s.collectionID, []int64{101})
s.Require().NoError(err)
v := storage.Value{
PK: storage.NewInt64PrimaryKey(int64(0)),
Timestamp: int64(tsoutil.ComposeTSByTime(getMilvusBirthday(), 0)),
Value: genRowWithBM25(int64(0)),
}
err = writer.Write(&v)
s.Error(err)
})
}


@ -514,6 +514,7 @@ func (i *IndexNode) QueryJobsV2(ctx context.Context, req *workerpb.QueryJobsV2Re
InsertLogs: info.insertLogs,
StatsLogs: info.statsLogs,
TextStatsLogs: info.textStatsLogs,
Bm25Logs: info.bm25Logs,
NumRows: info.numRows,
})
}


@ -155,7 +155,9 @@ func (st *statsTask) PreExecute(ctx context.Context) error {
func (st *statsTask) sortSegment(ctx context.Context) ([]*datapb.FieldBinlog, error) {
numRows := st.req.GetNumRows()
writer, err := compaction.NewSegmentWriter(st.req.GetSchema(), numRows, st.req.GetTargetSegmentID(), st.req.GetPartitionID(), st.req.GetCollectionID())
bm25FieldIds := compaction.GetBM25FieldIDs(st.req.GetSchema())
writer, err := compaction.NewSegmentWriter(st.req.GetSchema(), numRows, st.req.GetTargetSegmentID(), st.req.GetPartitionID(), st.req.GetCollectionID(), bm25FieldIds)
if err != nil {
log.Warn("sort segment wrong, unable to init segment writer",
zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err))
@ -174,7 +176,7 @@ func (st *statsTask) sortSegment(ctx context.Context) ([]*datapb.FieldBinlog, er
uploadTimeCost := time.Duration(0)
sortTimeCost := time.Duration(0)
values, err := st.downloadData(ctx, numRows, writer.GetPkID())
values, err := st.downloadData(ctx, numRows, writer.GetPkID(), bm25FieldIds)
if err != nil {
log.Warn("download data failed", zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err))
return nil, err
@ -255,6 +257,20 @@ func (st *statsTask) sortSegment(ctx context.Context) ([]*datapb.FieldBinlog, er
st.logIDOffset += binlogNums
var bm25StatsLogs []*datapb.FieldBinlog
if len(bm25FieldIds) > 0 {
binlogNums, bm25StatsLogs, err = bm25SerializeWrite(ctx, st.binlogIO, st.req.GetStartLogID()+st.logIDOffset, writer, numRows)
if err != nil {
log.Warn("compact wrong, failed to serialize write segment bm25 stats", zap.Error(err))
return nil, err
}
st.logIDOffset += binlogNums
if err := binlog.CompressFieldBinlogs(bm25StatsLogs); err != nil {
return nil, err
}
}
totalElapse := st.tr.RecordSpan()
insertLogs := lo.Values(allBinlogs)
@ -273,7 +289,7 @@ func (st *statsTask) sortSegment(ctx context.Context) ([]*datapb.FieldBinlog, er
st.req.GetPartitionID(),
st.req.GetTargetSegmentID(),
st.req.GetInsertChannel(),
int64(len(values)), insertLogs, statsLogs)
int64(len(values)), insertLogs, statsLogs, bm25StatsLogs)
log.Info("sort segment end",
zap.String("clusterID", st.req.GetClusterID()),
@ -322,8 +338,6 @@ func (st *statsTask) Execute(ctx context.Context) error {
}
}
// TODO support bm25
return nil
}
@ -340,13 +354,14 @@ func (st *statsTask) Reset() {
st.node = nil
}
func (st *statsTask) downloadData(ctx context.Context, numRows int64, PKFieldID int64) ([]*storage.Value, error) {
func (st *statsTask) downloadData(ctx context.Context, numRows int64, PKFieldID int64, bm25FieldIds []int64) ([]*storage.Value, error) {
log := log.Ctx(ctx).With(
zap.String("clusterID", st.req.GetClusterID()),
zap.Int64("taskID", st.req.GetTaskID()),
zap.Int64("collectionID", st.req.GetCollectionID()),
zap.Int64("partitionID", st.req.GetPartitionID()),
zap.Int64("segmentID", st.req.GetSegmentID()),
zap.Int64s("bm25Fields", bm25FieldIds),
)
deletePKs, err := st.loadDeltalogs(ctx, st.deltaLogs)
@ -564,6 +579,44 @@ func statSerializeWrite(ctx context.Context, io io.BinlogIO, startID int64, writ
return binlogNum, statFieldLog, nil
}
func bm25SerializeWrite(ctx context.Context, io io.BinlogIO, startID int64, writer *compaction.SegmentWriter, finalRowCount int64) (int64, []*datapb.FieldBinlog, error) {
ctx, span := otel.Tracer(typeutil.DataNodeRole).Start(ctx, "bm25log serializeWrite")
defer span.End()
stats, err := writer.GetBm25StatsBlob()
if err != nil {
return 0, nil, err
}
kvs := make(map[string][]byte)
binlogs := []*datapb.FieldBinlog{}
cnt := int64(0)
for fieldID, blob := range stats {
key, _ := binlog.BuildLogPath(storage.BM25Binlog, writer.GetCollectionID(), writer.GetPartitionID(), writer.GetSegmentID(), fieldID, startID+cnt)
kvs[key] = blob.GetValue()
fieldLog := &datapb.FieldBinlog{
FieldID: fieldID,
Binlogs: []*datapb.Binlog{
{
LogSize: int64(len(blob.GetValue())),
MemorySize: int64(len(blob.GetValue())),
LogPath: key,
EntriesNum: finalRowCount,
},
},
}
binlogs = append(binlogs, fieldLog)
cnt++
}
if err := io.Upload(ctx, kvs); err != nil {
log.Warn("failed to upload bm25 log", zap.Error(err))
return 0, nil, err
}
return cnt, binlogs, nil
}
func buildTextLogPrefix(rootPath string, collID, partID, segID, fieldID, version int64) string {
return fmt.Sprintf("%s/%s/%d/%d/%d/%d/%d", rootPath, common.TextIndexPath, collID, partID, segID, fieldID, version)
}

View File

@ -0,0 +1,247 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package indexnode
import (
"context"
"fmt"
"testing"
"time"
"github.com/samber/lo"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/datanode/compaction"
"github.com/milvus-io/milvus/internal/flushcommon/io"
"github.com/milvus-io/milvus/internal/proto/workerpb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
func TestTaskStatsSuite(t *testing.T) {
suite.Run(t, new(TaskStatsSuite))
}
type TaskStatsSuite struct {
suite.Suite
collectionID int64
partitionID int64
clusterID string
schema *schemapb.CollectionSchema
mockBinlogIO *io.MockBinlogIO
segWriter *compaction.SegmentWriter
}
func (s *TaskStatsSuite) SetupSuite() {
s.collectionID = 100
s.partitionID = 101
s.clusterID = "102"
}
func (s *TaskStatsSuite) SetupSubTest() {
paramtable.Init()
s.mockBinlogIO = io.NewMockBinlogIO(s.T())
}
func (s *TaskStatsSuite) GenSegmentWriterWithBM25(magic int64) {
segWriter, err := compaction.NewSegmentWriter(s.schema, 100, magic, s.partitionID, s.collectionID, []int64{102})
s.Require().NoError(err)
v := storage.Value{
PK: storage.NewInt64PrimaryKey(magic),
Timestamp: int64(tsoutil.ComposeTSByTime(getMilvusBirthday(), 0)),
Value: genRowWithBM25(magic),
}
err = segWriter.Write(&v)
s.Require().NoError(err)
segWriter.FlushAndIsFull()
s.segWriter = segWriter
}
func (s *TaskStatsSuite) Testbm25SerializeWriteError() {
s.Run("normal case", func() {
s.schema = genCollectionSchemaWithBM25()
s.mockBinlogIO.EXPECT().Upload(mock.Anything, mock.Anything).Return(nil).Once()
s.GenSegmentWriterWithBM25(0)
cnt, binlogs, err := bm25SerializeWrite(context.Background(), s.mockBinlogIO, 0, s.segWriter, 1)
s.Require().NoError(err)
s.Equal(int64(1), cnt)
s.Equal(1, len(binlogs))
})
s.Run("upload failed", func() {
s.schema = genCollectionSchemaWithBM25()
s.mockBinlogIO.EXPECT().Upload(mock.Anything, mock.Anything).Return(fmt.Errorf("mock error")).Once()
s.GenSegmentWriterWithBM25(0)
_, _, err := bm25SerializeWrite(context.Background(), s.mockBinlogIO, 0, s.segWriter, 1)
s.Error(err)
})
}
func (s *TaskStatsSuite) TestSortSegmentWithBM25() {
s.Run("normal case", func() {
s.schema = genCollectionSchemaWithBM25()
s.GenSegmentWriterWithBM25(0)
_, kvs, fBinlogs, err := serializeWrite(context.TODO(), 0, s.segWriter)
s.NoError(err)
s.mockBinlogIO.EXPECT().Download(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, paths []string) ([][]byte, error) {
result := make([][]byte, len(paths))
for i, path := range paths {
result[i] = kvs[path]
}
return result, nil
})
s.mockBinlogIO.EXPECT().Upload(mock.Anything, mock.Anything).Return(nil)
ctx, cancel := context.WithCancel(context.Background())
testTaskKey := taskKey{ClusterID: s.clusterID, TaskID: 100}
node := &IndexNode{statsTasks: map[taskKey]*statsTaskInfo{testTaskKey: {segID: 1}}}
task := newStatsTask(ctx, cancel, &workerpb.CreateStatsRequest{
CollectionID: s.collectionID,
PartitionID: s.partitionID,
ClusterID: s.clusterID,
TaskID: testTaskKey.TaskID,
TargetSegmentID: 1,
InsertLogs: lo.Values(fBinlogs),
Schema: s.schema,
NumRows: 1,
}, node, s.mockBinlogIO)
err = task.PreExecute(ctx)
s.Require().NoError(err)
binlog, err := task.sortSegment(ctx)
s.Require().NoError(err)
s.Equal(5, len(binlog))
// check bm25 log
s.Equal(1, len(node.statsTasks))
for key, task := range node.statsTasks {
s.Equal(testTaskKey.ClusterID, key.ClusterID)
s.Equal(testTaskKey.TaskID, key.TaskID)
s.Equal(1, len(task.bm25Logs))
}
})
s.Run("upload bm25 binlog failed", func() {
s.schema = genCollectionSchemaWithBM25()
s.GenSegmentWriterWithBM25(0)
_, kvs, fBinlogs, err := serializeWrite(context.TODO(), 0, s.segWriter)
s.NoError(err)
s.mockBinlogIO.EXPECT().Download(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, paths []string) ([][]byte, error) {
result := make([][]byte, len(paths))
for i, path := range paths {
result[i] = kvs[path]
}
return result, nil
})
s.mockBinlogIO.EXPECT().Upload(mock.Anything, mock.Anything).Return(nil).Times(2)
s.mockBinlogIO.EXPECT().Upload(mock.Anything, mock.Anything).Return(fmt.Errorf("mock error")).Once()
ctx, cancel := context.WithCancel(context.Background())
testTaskKey := taskKey{ClusterID: s.clusterID, TaskID: 100}
node := &IndexNode{statsTasks: map[taskKey]*statsTaskInfo{testTaskKey: {segID: 1}}}
task := newStatsTask(ctx, cancel, &workerpb.CreateStatsRequest{
CollectionID: s.collectionID,
PartitionID: s.partitionID,
ClusterID: s.clusterID,
TaskID: testTaskKey.TaskID,
TargetSegmentID: 1,
InsertLogs: lo.Values(fBinlogs),
Schema: s.schema,
NumRows: 1,
}, node, s.mockBinlogIO)
err = task.PreExecute(ctx)
s.Require().NoError(err)
_, err = task.sortSegment(ctx)
s.Error(err)
})
}
func genCollectionSchemaWithBM25() *schemapb.CollectionSchema {
return &schemapb.CollectionSchema{
Name: "schema",
Description: "schema",
Fields: []*schemapb.FieldSchema{
{
FieldID: common.RowIDField,
Name: "row_id",
DataType: schemapb.DataType_Int64,
},
{
FieldID: common.TimeStampField,
Name: "Timestamp",
DataType: schemapb.DataType_Int64,
},
{
FieldID: 100,
Name: "pk",
DataType: schemapb.DataType_Int64,
IsPrimaryKey: true,
},
{
FieldID: 101,
Name: "text",
DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{
Key: common.MaxLengthKey,
Value: "8",
},
},
},
{
FieldID: 102,
Name: "sparse",
DataType: schemapb.DataType_SparseFloatVector,
},
},
Functions: []*schemapb.FunctionSchema{{
Name: "BM25",
Id: 100,
Type: schemapb.FunctionType_BM25,
InputFieldNames: []string{"text"},
InputFieldIds: []int64{101},
OutputFieldNames: []string{"sparse"},
OutputFieldIds: []int64{102},
}},
}
}
func genRowWithBM25(magic int64) map[int64]interface{} {
ts := tsoutil.ComposeTSByTime(getMilvusBirthday(), 0)
return map[int64]interface{}{
common.RowIDField: magic,
common.TimeStampField: int64(ts),
100: magic,
101: "varchar",
102: typeutil.CreateAndSortSparseFloatRow(map[uint32]float32{1: 1}),
}
}
func getMilvusBirthday() time.Time {
return time.Date(2019, time.Month(5), 30, 0, 0, 0, 0, time.UTC)
}


@ -323,6 +323,7 @@ type statsTaskInfo struct {
insertLogs []*datapb.FieldBinlog
statsLogs []*datapb.FieldBinlog
textStatsLogs map[int64]*datapb.TextIndexStats
bm25Logs []*datapb.FieldBinlog
}
func (i *IndexNode) loadOrStoreStatsTask(clusterID string, taskID UniqueID, info *statsTaskInfo) *statsTaskInfo {
@ -370,6 +371,7 @@ func (i *IndexNode) storePKSortStatsResult(
numRows int64,
insertLogs []*datapb.FieldBinlog,
statsLogs []*datapb.FieldBinlog,
bm25Logs []*datapb.FieldBinlog,
) {
key := taskKey{ClusterID: ClusterID, TaskID: taskID}
i.stateLock.Lock()
@ -382,6 +384,7 @@ func (i *IndexNode) storePKSortStatsResult(
info.numRows = numRows
info.insertLogs = insertLogs
info.statsLogs = statsLogs
info.bm25Logs = bm25Logs
return
}
}
@ -424,6 +427,7 @@ func (i *IndexNode) getStatsTaskInfo(clusterID string, taskID UniqueID) *statsTa
insertLogs: info.insertLogs,
statsLogs: info.statsLogs,
textStatsLogs: info.textStatsLogs,
bm25Logs: info.bm25Logs,
}
}
return nil


@ -79,6 +79,7 @@ func (s *statsTaskInfoSuite) Test_Methods() {
s.node.storePKSortStatsResult(s.cluster, s.taskID, 1, 2, 3, "ch1", 65535,
[]*datapb.FieldBinlog{{FieldID: 100, Binlogs: []*datapb.Binlog{{LogID: 1}}}},
[]*datapb.FieldBinlog{{FieldID: 100, Binlogs: []*datapb.Binlog{{LogID: 2}}}},
[]*datapb.FieldBinlog{},
)
})


@ -63,6 +63,10 @@ func CompressCompactionBinlogs(binlogs []*datapb.CompactionSegment) error {
if err != nil {
return err
}
err = CompressFieldBinlogs(binlog.GetBm25Logs())
if err != nil {
return err
}
}
return nil
}


@ -194,6 +194,7 @@ func buildBinlogKvs(collectionID, partitionID, segmentID typeutil.UniqueID, binl
kv[key] = string(binlogBytes)
}
// bm25log
for _, bm25log := range bm25logs {
if err := checkLogID(bm25log); err != nil {
return nil, err


@ -198,6 +198,7 @@ message StatsResult {
repeated data.FieldBinlog stats_logs = 9;
map<int64, data.TextIndexStats> text_stats_logs = 10;
int64 num_rows = 11;
repeated data.FieldBinlog bm25_logs = 12;
}
message StatsResults {


@ -54,7 +54,8 @@ func (s *bm25Stats) Merge(stats map[int64]*storage.BM25Stats) {
if stats, ok := s.stats[fieldID]; ok {
stats.Merge(newstats)
} else {
log.Panic("merge failed, BM25 stats not exist", zap.Int64("fieldID", fieldID))
s.stats[fieldID] = storage.NewBM25Stats()
s.stats[fieldID].Merge(newstats)
}
}
}
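The previous log.Panic on an unseen fieldID is replaced by lazily creating the entry; the updated IDFOracle test below asserts exactly this:

suite.NotPanics(func() {
	stats.Merge(map[int64]*storage.BM25Stats{103: storage.NewBM25Stats()}) // field 103 created on first Merge
})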


@ -178,15 +178,15 @@ func (suite *IDFOracleSuite) TestStats() {
OutputFieldIds: []int64{102},
}})
suite.Panics(func() {
suite.NotPanics(func() {
stats.Merge(map[int64]*storage.BM25Stats{103: storage.NewBM25Stats()})
})
suite.Panics(func() {
stats.Minus(map[int64]*storage.BM25Stats{103: storage.NewBM25Stats()})
stats.Minus(map[int64]*storage.BM25Stats{104: storage.NewBM25Stats()})
})
_, err := stats.GetStats(103)
_, err := stats.GetStats(104)
suite.Error(err)
_, err = stats.GetStats(102)