From 1d48d0aeb28a002fa4c76230c5a457103cd889c2 Mon Sep 17 00:00:00 2001
From: SimFG <bang.fu@zilliz.com>
Date: Tue, 14 May 2024 14:59:33 +0800
Subject: [PATCH] enhance: use different values to get related data size
 according to segment type (#33017)

issue: #30436
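
The reported CostAggregation.TotalRelatedDataSize is now computed per segment
type: sealed segments sum the binlog, statslog and deltalog sizes recorded in
their load info, while growing segments keep reporting their in-memory size.
A minimal sketch of the helper introduced in
internal/querynodev2/segments/utils.go (a simplified restatement of the code
in the diff below, not the verbatim implementation):

    func GetSegmentRelatedDataSize(segment Segment) int64 {
        if segment.Type() == SegmentTypeSealed {
            // sealed: sum LogSize over BinlogPaths, Statslogs and Deltalogs
            return calculateSegmentLogSize(segment.LoadInfo())
        }
        // growing: fall back to the in-memory size
        return segment.MemSize()
    }

Query and search tasks then accumulate this value over the segments they
touched when filling CostAggregation.TotalRelatedDataSize.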

Signed-off-by: SimFG <bang.fu@zilliz.com>
---
 internal/proto/query_coord.proto            |  2 +-
 internal/querycoordv2/utils/types.go        | 43 ----------------
 internal/querynodev2/segments/retrieve.go   |  2 +-
 internal/querynodev2/segments/utils.go      | 38 ++++++++++++++
 internal/querynodev2/segments/utils_test.go | 57 +++++++++++++++++++++
 internal/querynodev2/tasks/query_task.go    |  2 +-
 internal/querynodev2/tasks/search_task.go   | 12 ++++-
 7 files changed, 108 insertions(+), 48 deletions(-)

diff --git a/internal/proto/query_coord.proto b/internal/proto/query_coord.proto
index 5c926df966..2f5facab9d 100644
--- a/internal/proto/query_coord.proto
+++ b/internal/proto/query_coord.proto
@@ -352,7 +352,7 @@ message SegmentLoadInfo {
     repeated data.FieldBinlog deltalogs = 9;
     repeated int64 compactionFrom = 10;  // segmentIDs compacted from
     repeated FieldIndexInfo index_infos = 11;
-    int64 segment_size = 12;
+    int64 segment_size = 12 [deprecated = true];
     string insert_channel = 13;
     msg.MsgPosition start_position = 14;
     msg.MsgPosition delta_position = 15;
diff --git a/internal/querycoordv2/utils/types.go b/internal/querycoordv2/utils/types.go
index aa4481cd34..acd58b7709 100644
--- a/internal/querycoordv2/utils/types.go
+++ b/internal/querycoordv2/utils/types.go
@@ -86,52 +86,9 @@ func PackSegmentLoadInfo(segment *datapb.SegmentInfo, channelCheckpoint *msgpb.M
 		Level:          segment.GetLevel(),
 		StorageVersion: segment.GetStorageVersion(),
 	}
-	loadInfo.SegmentSize = calculateSegmentSize(loadInfo)
 	return loadInfo
 }
 
-func calculateSegmentSize(segmentLoadInfo *querypb.SegmentLoadInfo) int64 {
-	segmentSize := int64(0)
-
-	fieldIndex := make(map[int64]*querypb.FieldIndexInfo)
-	for _, index := range segmentLoadInfo.IndexInfos {
-		if index.EnableIndex {
-			fieldID := index.FieldID
-			fieldIndex[fieldID] = index
-		}
-	}
-
-	for _, fieldBinlog := range segmentLoadInfo.BinlogPaths {
-		fieldID := fieldBinlog.FieldID
-		if index, ok := fieldIndex[fieldID]; ok {
-			segmentSize += index.IndexSize
-		} else {
-			segmentSize += getFieldSizeFromFieldBinlog(fieldBinlog)
-		}
-	}
-
-	// Get size of state data
-	for _, fieldBinlog := range segmentLoadInfo.Statslogs {
-		segmentSize += getFieldSizeFromFieldBinlog(fieldBinlog)
-	}
-
-	// Get size of delete data
-	for _, fieldBinlog := range segmentLoadInfo.Deltalogs {
-		segmentSize += getFieldSizeFromFieldBinlog(fieldBinlog)
-	}
-
-	return segmentSize
-}
-
-func getFieldSizeFromFieldBinlog(fieldBinlog *datapb.FieldBinlog) int64 {
-	fieldSize := int64(0)
-	for _, binlog := range fieldBinlog.Binlogs {
-		fieldSize += binlog.LogSize
-	}
-
-	return fieldSize
-}
-
 func MergeDmChannelInfo(infos []*datapb.VchannelInfo) *meta.DmChannel {
 	var dmChannel *meta.DmChannel
 
diff --git a/internal/querynodev2/segments/retrieve.go b/internal/querynodev2/segments/retrieve.go
index a29590839f..9d6eecc2fe 100644
--- a/internal/querynodev2/segments/retrieve.go
+++ b/internal/querynodev2/segments/retrieve.go
@@ -146,7 +146,7 @@ func retrieveOnSegmentsWithStream(ctx context.Context, segments []Segment, segTy
 					Ids:        result.GetIds(),
 					FieldsData: result.GetFieldsData(),
 					CostAggregation: &internalpb.CostAggregation{
-						TotalRelatedDataSize: segment.MemSize(),
+						TotalRelatedDataSize: GetSegmentRelatedDataSize(segment),
 					},
 					AllRetrieveCount: result.GetAllRetrieveCount(),
 				}); err != nil {
diff --git a/internal/querynodev2/segments/utils.go b/internal/querynodev2/segments/utils.go
index 2a9ff5ac5c..c247bdc443 100644
--- a/internal/querynodev2/segments/utils.go
+++ b/internal/querynodev2/segments/utils.go
@@ -23,7 +23,9 @@ import (
 	"go.uber.org/zap"
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	"github.com/milvus-io/milvus/internal/proto/datapb"
 	"github.com/milvus-io/milvus/internal/proto/internalpb"
+	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/milvus-io/milvus/internal/querynodev2/segments/metricsutil"
 	"github.com/milvus-io/milvus/internal/storage"
 	"github.com/milvus-io/milvus/pkg/common"
@@ -206,3 +208,39 @@ func withLazyLoadTimeoutContext(ctx context.Context) (context.Context, context.C
 	// TODO: use context.WithTimeoutCause instead of contextutil.WithTimeoutCause in go1.21
 	return contextutil.WithTimeoutCause(ctx, lazyLoadTimeout, errLazyLoadTimeout)
 }
+
+func GetSegmentRelatedDataSize(segment Segment) int64 {
+	if segment.Type() == SegmentTypeSealed {
+		return calculateSegmentLogSize(segment.LoadInfo())
+	}
+	return segment.MemSize()
+}
+
+func calculateSegmentLogSize(segmentLoadInfo *querypb.SegmentLoadInfo) int64 {
+	segmentSize := int64(0)
+
+	for _, fieldBinlog := range segmentLoadInfo.BinlogPaths {
+		segmentSize += getFieldSizeFromFieldBinlog(fieldBinlog)
+	}
+
+	// Get size of stats data
+	for _, fieldBinlog := range segmentLoadInfo.Statslogs {
+		segmentSize += getFieldSizeFromFieldBinlog(fieldBinlog)
+	}
+
+	// Get size of delete data
+	for _, fieldBinlog := range segmentLoadInfo.Deltalogs {
+		segmentSize += getFieldSizeFromFieldBinlog(fieldBinlog)
+	}
+
+	return segmentSize
+}
+
+func getFieldSizeFromFieldBinlog(fieldBinlog *datapb.FieldBinlog) int64 {
+	fieldSize := int64(0)
+	for _, binlog := range fieldBinlog.Binlogs {
+		fieldSize += binlog.LogSize
+	}
+
+	return fieldSize
+}
diff --git a/internal/querynodev2/segments/utils_test.go b/internal/querynodev2/segments/utils_test.go
index 95d81e4575..6ad5c92291 100644
--- a/internal/querynodev2/segments/utils_test.go
+++ b/internal/querynodev2/segments/utils_test.go
@@ -4,6 +4,9 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+
+	"github.com/milvus-io/milvus/internal/proto/datapb"
+	"github.com/milvus-io/milvus/internal/proto/querypb"
 )
 
 func TestFilterZeroValuesFromSlice(t *testing.T) {
@@ -18,3 +21,57 @@ func TestFilterZeroValuesFromSlice(t *testing.T) {
 	assert.Equal(t, 3, len(filteredInts))
 	assert.EqualValues(t, []int64{10, 5, 13}, filteredInts)
 }
+
+func TestGetSegmentRelatedDataSize(t *testing.T) {
+	t.Run("seal segment", func(t *testing.T) {
+		segment := NewMockSegment(t)
+		segment.EXPECT().Type().Return(SegmentTypeSealed)
+		segment.EXPECT().LoadInfo().Return(&querypb.SegmentLoadInfo{
+			BinlogPaths: []*datapb.FieldBinlog{
+				{
+					Binlogs: []*datapb.Binlog{
+						{
+							LogSize: 10,
+						},
+						{
+							LogSize: 20,
+						},
+					},
+				},
+				{
+					Binlogs: []*datapb.Binlog{
+						{
+							LogSize: 30,
+						},
+					},
+				},
+			},
+			Deltalogs: []*datapb.FieldBinlog{
+				{
+					Binlogs: []*datapb.Binlog{
+						{
+							LogSize: 30,
+						},
+					},
+				},
+			},
+			Statslogs: []*datapb.FieldBinlog{
+				{
+					Binlogs: []*datapb.Binlog{
+						{
+							LogSize: 10,
+						},
+					},
+				},
+			},
+		})
+		assert.EqualValues(t, 100, GetSegmentRelatedDataSize(segment))
+	})
+
+	t.Run("growing segment", func(t *testing.T) {
+		segment := NewMockSegment(t)
+		segment.EXPECT().Type().Return(SegmentTypeGrowing)
+		segment.EXPECT().MemSize().Return(int64(100))
+		assert.EqualValues(t, 100, GetSegmentRelatedDataSize(segment))
+	})
+}
diff --git a/internal/querynodev2/tasks/query_task.go b/internal/querynodev2/tasks/query_task.go
index 668489a5d1..5c330c681a 100644
--- a/internal/querynodev2/tasks/query_task.go
+++ b/internal/querynodev2/tasks/query_task.go
@@ -144,7 +144,7 @@ func (t *QueryTask) Execute() error {
 	}
 
 	relatedDataSize := lo.Reduce(querySegments, func(acc int64, seg segments.Segment, _ int) int64 {
-		return acc + seg.MemSize()
+		return acc + segments.GetSegmentRelatedDataSize(seg)
 	}, 0)
 
 	t.result = &internalpb.RetrieveResults{
diff --git a/internal/querynodev2/tasks/search_task.go b/internal/querynodev2/tasks/search_task.go
index 3fc6c3c287..6b1f4e1f67 100644
--- a/internal/querynodev2/tasks/search_task.go
+++ b/internal/querynodev2/tasks/search_task.go
@@ -215,7 +215,7 @@ func (t *SearchTask) Execute() error {
 	}
 
 	relatedDataSize := lo.Reduce(searchedSegments, func(acc int64, seg segments.Segment, _ int) int64 {
-		return acc + seg.MemSize()
+		return acc + segments.GetSegmentRelatedDataSize(seg)
 	}, 0)
 
 	tr.RecordSpan()
@@ -445,6 +445,7 @@ func (t *StreamingSearchTask) Execute() error {
 
 	// 1. search&&reduce or streaming-search&&streaming-reduce
 	metricType := searchReq.Plan().GetMetricType()
+	var relatedDataSize int64
 	if req.GetScope() == querypb.DataScope_Historical {
 		streamReduceFunc := func(result *segments.SearchResult) error {
 			reduceErr := t.streamReduce(t.ctx, searchReq.Plan(), result, t.originNqs, t.originTopks)
@@ -470,6 +471,9 @@ func (t *StreamingSearchTask) Execute() error {
 			log.Error("Failed to get stream-reduced search result")
 			return err
 		}
+		relatedDataSize = lo.Reduce(pinnedSegments, func(acc int64, seg segments.Segment, _ int) int64 {
+			return acc + segments.GetSegmentRelatedDataSize(seg)
+		}, 0)
 	} else if req.GetScope() == querypb.DataScope_Streaming {
 		results, pinnedSegments, err := segments.SearchStreaming(
 			t.ctx,
@@ -507,6 +511,9 @@ func (t *StreamingSearchTask) Execute() error {
 			metrics.ReduceSegments,
 			metrics.BatchReduce).
 			Observe(float64(tr.RecordSpan().Milliseconds()))
+		relatedDataSize = lo.Reduce(pinnedSegments, func(acc int64, seg segments.Segment, _ int) int64 {
+			return acc + segments.GetSegmentRelatedDataSize(seg)
+		}, 0)
 	}
 
 	// 2. reorganize blobs to original search request
@@ -539,7 +546,8 @@ func (t *StreamingSearchTask) Execute() error {
 			SlicedOffset:   1,
 			SlicedNumCount: 1,
 			CostAggregation: &internalpb.CostAggregation{
-				ServiceTime: tr.ElapseSpan().Milliseconds(),
+				ServiceTime:          tr.ElapseSpan().Milliseconds(),
+				TotalRelatedDataSize: relatedDataSize,
 			},
 		}
 	}