fix: split stream query results to avoid the gRPC response-too-large error (#36090)

Related issue: https://github.com/milvus-io/milvus/issues/36089

---------

Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
pull/36250/head
aoiasd 2024-09-13 15:07:09 +08:00 committed by GitHub
parent f0f2fb4cf0
commit c22a2cebb6
7 changed files with 188 additions and 22 deletions

View File

@@ -459,7 +459,8 @@ queryNode:
maxQueueLength: 16 # The maximum size of task queue cache in flow graph in query node.
maxParallelism: 1024 # Maximum number of tasks executed in parallel in the flowgraph
enableSegmentPrune: false # use partition stats to prune data in search/query on shard delegator
queryStreamBatchSize: 4194304 # return batch size of stream query
queryStreamBatchSize: 4194304 # the minimum size, in bytes, of each batch returned by a stream query
queryStreamMaxBatchSize: 134217728 # the maximum size, in bytes, of each batch returned by a stream query
bloomFilterApplyParallelFactor: 4 # parallel factor when to apply pk to bloom filter, default to 4*CPU_CORE_NUM
workerPooling:
size: 10 # the size for worker querynode client pool
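Taken together, the two settings bound every streamed reply: the query node buffers results until at least queryStreamBatchSize bytes have accumulated, and any payload larger than queryStreamMaxBatchSize is split before it is sent, which keeps each gRPC message under the receiver's size limit. A standalone sketch of the arithmetic behind the defaults (illustrative only, not part of this patch):

package main

import "fmt"

func main() {
	const (
		minBatchSize = 4 * 1024 * 1024   // queryStreamBatchSize default, 4194304 bytes
		maxBatchSize = 128 * 1024 * 1024 // queryStreamMaxBatchSize default, 134217728 bytes
	)
	// The splitter in this patch budgets 8 bytes per int64 primary key, so a single
	// streamed message carries at most maxBatchSize/8 of them.
	fmt.Println("int64 PKs per message:", maxBatchSize/8) // 16777216
	// String primary keys are instead chunked by their accumulated byte length.
	fmt.Println("flush threshold (bytes):", minBatchSize) // 4194304
}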

View File

@@ -38,6 +38,7 @@ import (
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
)
@@ -318,7 +319,10 @@ func (node *QueryNode) queryStreamSegments(ctx context.Context, req *querypb.Que
}
// Send task to scheduler and wait until it finished.
task := tasks.NewQueryStreamTask(ctx, collection, node.manager, req, srv, node.streamBatchSzie)
task := tasks.NewQueryStreamTask(ctx, collection, node.manager, req, srv,
paramtable.Get().QueryNodeCfg.QueryStreamBatchSize.GetAsInt(),
paramtable.Get().QueryNodeCfg.QueryStreamMaxBatchSize.GetAsInt())
if err := node.scheduler.Add(task); err != nil {
log.Warn("failed to add query task into scheduler", zap.Error(err))
return err
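Instead of a value cached on the QueryNode at init time, both limits are now read from paramtable each time a stream task is created (the cached streamBatchSzie field is removed in the next file). A minimal sketch of reading them, assuming the standard paramtable.Init entry point used across the code base:

package main

import (
	"fmt"

	"github.com/milvus-io/milvus/pkg/util/paramtable"
)

func main() {
	// Init loads the built-in defaults plus any milvus.yaml / environment overrides.
	paramtable.Init()
	cfg := paramtable.Get().QueryNodeCfg
	minBatch := cfg.QueryStreamBatchSize.GetAsInt()    // 4194304 unless overridden
	maxBatch := cfg.QueryStreamMaxBatchSize.GetAsInt() // 134217728 unless overridden
	fmt.Println("stream query batch bounds:", minBatch, maxBatch)
}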

View File

@@ -109,8 +109,7 @@ type QueryNode struct {
loader segments.Loader
// Search/Query
scheduler tasks.Scheduler
streamBatchSzie int
scheduler tasks.Scheduler
// etcd client
etcdCli *clientv3.Client
@@ -316,9 +315,8 @@ func (node *QueryNode) Init() error {
node.scheduler = tasks.NewScheduler(
schedulePolicy,
)
node.streamBatchSzie = paramtable.Get().QueryNodeCfg.QueryStreamBatchSize.GetAsInt()
log.Info("queryNode init scheduler", zap.String("policy", schedulePolicy))
log.Info("queryNode init scheduler", zap.String("policy", schedulePolicy))
node.clusterManager = cluster.NewWorkerManager(func(ctx context.Context, nodeID int64) (cluster.Worker, error) {
if nodeID == node.GetNodeID() {
return NewLocalWorker(node), nil

View File

@@ -16,7 +16,8 @@ func NewQueryStreamTask(ctx context.Context,
manager *segments.Manager,
req *querypb.QueryRequest,
srv streamrpc.QueryStreamServer,
streamBatchSize int,
minMsgSize int,
maxMsgSize int,
) *QueryStreamTask {
return &QueryStreamTask{
ctx: ctx,
@@ -24,7 +25,8 @@ func NewQueryStreamTask(ctx context.Context,
segmentManager: manager,
req: req,
srv: srv,
batchSize: streamBatchSize,
minMsgSize: minMsgSize,
maxMsgSize: maxMsgSize,
notifier: make(chan error, 1),
}
}
@@ -35,7 +37,8 @@ type QueryStreamTask struct {
segmentManager *segments.Manager
req *querypb.QueryRequest
srv streamrpc.QueryStreamServer
batchSize int
minMsgSize int
maxMsgSize int
notifier chan error
}
@@ -67,7 +70,7 @@ func (t *QueryStreamTask) Execute() error {
}
defer retrievePlan.Delete()
srv := streamrpc.NewResultCacheServer(t.srv, t.batchSize)
srv := streamrpc.NewResultCacheServer(t.srv, t.minMsgSize, t.maxMsgSize)
defer srv.Flush()
segments, err := segments.RetrieveStream(t.ctx, t.segmentManager, retrievePlan, t.req, srv)
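The task forwards both thresholds to the result cache server it wraps around the gRPC stream: minMsgSize is the point at which buffered results are flushed downstream, and maxMsgSize is the ceiling a single outgoing message may not exceed. A toy sketch of that decision rule (the type and helper below are hypothetical stand-ins, not the real streamrpc cache):

package main

import "fmt"

// toyCache mimics only the buffering decision; the real RetrieveResultCache
// in the streamrpc package also merges IDs and cost aggregation.
type toyCache struct {
	size, minFlush, maxMsg int
}

// put adds n bytes and reports what the Send path would do next.
func (c *toyCache) put(n int) string {
	c.size += n
	switch {
	case c.size < c.minFlush:
		return "keep buffering"
	case c.size <= c.maxMsg:
		c.size = 0 // flushed downstream as a single message
		return "flush as one message"
	default:
		c.size = 0 // in the real code a small tail may be put back into the cache
		return "split into chunks no larger than maxMsg"
	}
}

func main() {
	c := &toyCache{minFlush: 4 << 20, maxMsg: 128 << 20}
	fmt.Println(c.put(1 << 20))   // keep buffering
	fmt.Println(c.put(5 << 20))   // flush as one message
	fmt.Println(c.put(200 << 20)) // split into chunks no larger than maxMsg
}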

View File

@@ -10,6 +10,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/pkg/util/merr"
)
type QueryStreamServer interface {
@@ -102,25 +103,67 @@ func mergeCostAggregation(a *internalpb.CostAggregation, b *internalpb.CostAggre
return &internalpb.CostAggregation{
ResponseTime: a.GetResponseTime() + b.GetResponseTime(),
ServiceTime: a.GetServiceTime() + b.GetServiceTime(),
TotalNQ: a.GetTotalNQ() + b.GetTotalNQ(),
TotalNQ: a.GetTotalNQ(),
TotalRelatedDataSize: a.GetTotalRelatedDataSize() + b.GetTotalRelatedDataSize(),
}
}
// Merge result by size and time.
type ResultCacheServer struct {
srv QueryStreamServer
cache *RetrieveResultCache
mu sync.Mutex
mu sync.Mutex
srv QueryStreamServer
cache *RetrieveResultCache
maxMsgSize int
}
func NewResultCacheServer(srv QueryStreamServer, cap int) *ResultCacheServer {
func NewResultCacheServer(srv QueryStreamServer, cap int, maxMsgSize int) *ResultCacheServer {
return &ResultCacheServer{
srv: srv,
cache: &RetrieveResultCache{cap: cap},
srv: srv,
cache: &RetrieveResultCache{cap: cap},
maxMsgSize: maxMsgSize,
}
}
func (s *ResultCacheServer) splitMsgToMaxSize(result *internalpb.RetrieveResults) []*internalpb.RetrieveResults {
newpks := make([]*schemapb.IDs, 0)
switch result.GetIds().GetIdField().(type) {
case *schemapb.IDs_IntId:
pks := result.GetIds().GetIntId().Data
batch := s.maxMsgSize / 8
for start := 0; start < len(pks); start += batch {
newpks = append(newpks, &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: pks[start:min(start+batch, len(pks))]}}})
}
case *schemapb.IDs_StrId:
pks := result.GetIds().GetStrId().Data
start := 0
size := 0
for i, pk := range pks {
if size+len(pk) > s.maxMsgSize {
newpks = append(newpks, &schemapb.IDs{IdField: &schemapb.IDs_StrId{StrId: &schemapb.StringArray{Data: pks[start:i]}}})
start = i
size = 0
}
size += len(pk)
}
if size > 0 {
newpks = append(newpks, &schemapb.IDs{IdField: &schemapb.IDs_StrId{StrId: &schemapb.StringArray{Data: pks[start:]}}})
}
}
results := make([]*internalpb.RetrieveResults, len(newpks))
for i, pks := range newpks {
results[i] = &internalpb.RetrieveResults{
Status: merr.Status(nil),
Ids: pks,
}
}
results[len(results)-1].AllRetrieveCount = result.AllRetrieveCount
results[len(results)-1].CostAggregation = result.CostAggregation
return results
}
func (s *ResultCacheServer) Send(result *internalpb.RetrieveResults) error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -133,11 +176,23 @@ func (s *ResultCacheServer) Send(result *internalpb.RetrieveResults) error {
}
s.cache.Put(result)
if s.cache.IsFull() {
if s.cache.IsFull() && s.cache.size <= s.maxMsgSize {
result := s.cache.Flush()
if err := s.srv.Send(result); err != nil {
return err
}
} else if s.cache.IsFull() && s.cache.size > s.maxMsgSize {
results := s.splitMsgToMaxSize(s.cache.Flush())
if proto.Size(results[len(results)-1]) < s.cache.cap {
s.cache.Put(results[len(results)-1])
results = results[:len(results)-1]
}
for _, result := range results {
if err := s.srv.Send(result); err != nil {
return err
}
}
}
return nil
}
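Splitting only has to chunk the primary keys, since a stream query reply is essentially a set of IDs plus cost metadata, which the real function re-attaches to the last chunk. A self-contained sketch of the two strategies used above, with plain slices standing in for the schemapb ID wrappers: int64 keys are chunked by a fixed 8-byte budget, string keys by their accumulated length.

package main

import "fmt"

// splitInt64PKs chunks int64 primary keys assuming 8 bytes per key,
// mirroring the maxMsgSize/8 batch used by splitMsgToMaxSize.
func splitInt64PKs(pks []int64, maxBytes int) [][]int64 {
	batch := maxBytes / 8
	var out [][]int64
	for start := 0; start < len(pks); start += batch {
		end := start + batch
		if end > len(pks) {
			end = len(pks)
		}
		out = append(out, pks[start:end])
	}
	return out
}

// splitStrPKs chunks string primary keys so that the accumulated byte
// length of each chunk stays at or below maxBytes.
func splitStrPKs(pks []string, maxBytes int) [][]string {
	var out [][]string
	start, size := 0, 0
	for i, pk := range pks {
		if size+len(pk) > maxBytes {
			out = append(out, pks[start:i])
			start, size = i, 0
		}
		size += len(pk)
	}
	if size > 0 {
		out = append(out, pks[start:])
	}
	return out
}

func main() {
	ints := make([]int64, 300)
	fmt.Println(len(splitInt64PKs(ints, 1024))) // 1024/8 = 128 keys per chunk -> 3 chunks
	strs := []string{"aaaa", "bbbb", "cccc", "dddd"}
	fmt.Println(len(splitStrPKs(strs, 8))) // two 4-byte keys per chunk -> 2 chunks
}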

View File

@@ -18,6 +18,9 @@ package streamrpc
import (
"context"
"io"
"math"
"strconv"
"testing"
"github.com/stretchr/testify/suite"
@@ -36,7 +39,7 @@ func (s *ResultCacheServerSuite) TestSend() {
client := NewLocalQueryClient(ctx)
srv := client.CreateServer()
cacheSrv := NewResultCacheServer(srv, 1024)
cacheSrv := NewResultCacheServer(srv, 1024, math.MaxInt)
err := cacheSrv.Send(&internalpb.RetrieveResults{
Ids: &schemapb.IDs{
@@ -63,6 +66,98 @@ func (s *ResultCacheServerSuite) TestSend() {
s.Equal(6, len(msg.GetIds().GetIntId().GetData()))
}
func generateIntIds(num int) *schemapb.IDs {
data := make([]int64, num)
for i := 0; i < num; i++ {
data[i] = int64(i)
}
return &schemapb.IDs{
IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: data}},
}
}
func generateStrIds(num int) *schemapb.IDs {
data := make([]string, num)
for i := 0; i < num; i++ {
data[i] = strconv.FormatInt(int64(i), 10)
}
return &schemapb.IDs{
IdField: &schemapb.IDs_StrId{StrId: &schemapb.StringArray{Data: data}},
}
}
func (s *ResultCacheServerSuite) TestSplit() {
s.Run("split int64 message", func() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
client := NewLocalQueryClient(ctx)
srv := client.CreateServer()
cacheSrv := NewResultCacheServer(srv, 1024, 1024)
err := cacheSrv.Send(&internalpb.RetrieveResults{
Ids: generateIntIds(1024),
})
s.NoError(err)
err = cacheSrv.Flush()
s.NoError(err)
srv.FinishSend(nil)
rev := 0
for {
result, err := client.Recv()
if err != nil {
s.Equal(err, io.EOF)
break
}
cnt := len(result.Ids.GetIntId().GetData())
rev += cnt
s.LessOrEqual(4*cnt, 1024)
}
s.Equal(1024, rev)
})
s.Run("split string message", func() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
client := NewLocalQueryClient(ctx)
srv := client.CreateServer()
cacheSrv := NewResultCacheServer(srv, 1024, 1024)
err := cacheSrv.Send(&internalpb.RetrieveResults{
Ids: generateStrIds(2048),
})
s.NoError(err)
err = cacheSrv.Flush()
s.NoError(err)
srv.FinishSend(nil)
rev := 0
for {
result, err := client.Recv()
if err != nil {
s.Equal(err, io.EOF)
break
}
rev += len(result.Ids.GetStrId().GetData())
size := 0
for _, str := range result.Ids.GetStrId().GetData() {
size += len(str)
}
s.LessOrEqual(size, 1024)
}
s.Equal(rev, 2048)
})
}
func (s *ResultCacheServerSuite) TestMerge() {
s.Nil(mergeCostAggregation(nil, nil))
@@ -70,12 +165,12 @@ func (s *ResultCacheServerSuite) TestMerge() {
s.Equal(cost, mergeCostAggregation(nil, cost))
s.Equal(cost, mergeCostAggregation(cost, nil))
a := &internalpb.CostAggregation{ResponseTime: 1, ServiceTime: 1, TotalNQ: 1, TotalRelatedDataSize: 1}
a := &internalpb.CostAggregation{ResponseTime: 1, ServiceTime: 1, TotalNQ: 2, TotalRelatedDataSize: 1}
b := &internalpb.CostAggregation{ResponseTime: 2, ServiceTime: 2, TotalNQ: 2, TotalRelatedDataSize: 2}
c := mergeCostAggregation(a, b)
s.Equal(int64(3), c.ResponseTime)
s.Equal(int64(3), c.ServiceTime)
s.Equal(int64(3), c.TotalNQ)
s.Equal(int64(2), c.TotalNQ)
s.Equal(int64(3), c.TotalRelatedDataSize)
}

View File

@@ -2404,6 +2404,7 @@ type queryNodeConfig struct {
DefaultSegmentFilterRatio ParamItem `refreshable:"false"`
UseStreamComputing ParamItem `refreshable:"false"`
QueryStreamBatchSize ParamItem `refreshable:"false"`
QueryStreamMaxBatchSize ParamItem `refreshable:"false"`
BloomFilterApplyParallelFactor ParamItem `refreshable:"true"`
// worker
@@ -3108,11 +3109,20 @@ user-task-polling:
Key: "queryNode.queryStreamBatchSize",
Version: "2.4.1",
DefaultValue: "4194304",
Doc: "return batch size of stream query",
Doc: "return min batch size of stream query",
Export: true,
}
p.QueryStreamBatchSize.Init(base.mgr)
p.QueryStreamMaxBatchSize = ParamItem{
Key: "queryNode.queryStreamMaxBatchSize",
Version: "2.4.10",
DefaultValue: "134217728",
Doc: "return max batch size of stream query",
Export: true,
}
p.QueryStreamMaxBatchSize.Init(base.mgr)
p.BloomFilterApplyParallelFactor = ParamItem{
Key: "queryNode.bloomFilterApplyParallelFactor",
FallbackKeys: []string{"queryNode.bloomFilterApplyBatchSize"},