diff --git a/configs/milvus.yaml b/configs/milvus.yaml
index 91cfcc554a..5d630ebc0e 100644
--- a/configs/milvus.yaml
+++ b/configs/milvus.yaml
@@ -377,6 +377,7 @@ queryNode:
       maxQueueLength: 16 # Maximum length of task queue in flowgraph
       maxParallelism: 1024 # Maximum number of tasks executed in parallel in the flowgraph
   enableSegmentPrune: false # use partition prune function on shard delegator
+  queryStreamBatchSize: 4194304 # the batch size (in bytes) used when returning stream query results
   ip: # if not specified, use the first unicastable address
   port: 21123
   grpc:
diff --git a/internal/querynodev2/handlers.go b/internal/querynodev2/handlers.go
index 1b25fd7c25..170af4c39e 100644
--- a/internal/querynodev2/handlers.go
+++ b/internal/querynodev2/handlers.go
@@ -317,7 +317,7 @@ func (node *QueryNode) queryStreamSegments(ctx context.Context, req *querypb.Que
     }

     // Send task to scheduler and wait until it finished.
-    task := tasks.NewQueryStreamTask(ctx, collection, node.manager, req, srv)
+    task := tasks.NewQueryStreamTask(ctx, collection, node.manager, req, srv, node.streamBatchSize)
     if err := node.scheduler.Add(task); err != nil {
         log.Warn("failed to add query task into scheduler", zap.Error(err))
         return err
diff --git a/internal/querynodev2/server.go b/internal/querynodev2/server.go
index d142d72f2b..c9a3d5cf42 100644
--- a/internal/querynodev2/server.go
+++ b/internal/querynodev2/server.go
@@ -111,7 +111,8 @@ type QueryNode struct {
     loader segments.Loader

     // Search/Query
-    scheduler tasks.Scheduler
+    scheduler       tasks.Scheduler
+    streamBatchSize int

     // etcd client
     etcdCli *clientv3.Client
@@ -328,6 +329,7 @@ func (node *QueryNode) Init() error {
         node.scheduler = tasks.NewScheduler(
             schedulePolicy,
         )
+        node.streamBatchSize = paramtable.Get().QueryNodeCfg.QueryStreamBatchSize.GetAsInt()
         log.Info("queryNode init scheduler", zap.String("policy", schedulePolicy))

         node.clusterManager = cluster.NewWorkerManager(func(ctx context.Context, nodeID int64) (cluster.Worker, error) {
diff --git a/internal/querynodev2/tasks/query_stream_task.go b/internal/querynodev2/tasks/query_stream_task.go
index 5840efa6c1..6c85535bbe 100644
--- a/internal/querynodev2/tasks/query_stream_task.go
+++ b/internal/querynodev2/tasks/query_stream_task.go
@@ -16,6 +16,7 @@ func NewQueryStreamTask(ctx context.Context,
     manager *segments.Manager,
     req *querypb.QueryRequest,
     srv streamrpc.QueryStreamServer,
+    streamBatchSize int,
 ) *QueryStreamTask {
     return &QueryStreamTask{
         ctx:            ctx,
@@ -23,6 +24,7 @@ func NewQueryStreamTask(ctx context.Context,
         segmentManager: manager,
         req:            req,
         srv:            srv,
+        batchSize:      streamBatchSize,
         notifier:       make(chan error, 1),
     }
 }
@@ -33,6 +35,7 @@ type QueryStreamTask struct {
     segmentManager *segments.Manager
     req            *querypb.QueryRequest
     srv            streamrpc.QueryStreamServer
+    batchSize      int
     notifier       chan error
 }

@@ -64,7 +67,10 @@ func (t *QueryStreamTask) Execute() error {
     }
     defer retrievePlan.Delete()

-    segments, err := segments.RetrieveStream(t.ctx, t.segmentManager, retrievePlan, t.req, t.srv)
+    srv := streamrpc.NewResultCacheServer(t.srv, t.batchSize)
+    defer srv.Flush()
+
+    segments, err := segments.RetrieveStream(t.ctx, t.segmentManager, retrievePlan, t.req, srv)
     defer t.segmentManager.Segment.Unpin(segments)
     if err != nil {
         return err
diff --git a/internal/util/streamrpc/streamer.go b/internal/util/streamrpc/streamer.go
index 53571672ee..79f47c8bc3 100644
--- a/internal/util/streamrpc/streamer.go
+++ b/internal/util/streamrpc/streamer.go
@@ -5,8 +5,10 @@ import (
     "io"
     "sync"

+    "github.com/golang/protobuf/proto"
     "google.golang.org/grpc"

+    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/internalpb" ) @@ -42,6 +44,123 @@ func NewConcurrentQueryStreamServer(srv QueryStreamServer) *ConcurrentQueryStrea } } +type RetrieveResultCache struct { + result *internalpb.RetrieveResults + size int + cap int +} + +func (c *RetrieveResultCache) Put(result *internalpb.RetrieveResults) { + if c.result == nil { + c.result = result + c.size = proto.Size(result) + return + } + + c.merge(result) +} + +func (c *RetrieveResultCache) Flush() *internalpb.RetrieveResults { + result := c.result + c.result = nil + c.size = 0 + return result +} + +func (c *RetrieveResultCache) Alloc(result *internalpb.RetrieveResults) bool { + return proto.Size(result)+c.size <= c.cap +} + +func (c *RetrieveResultCache) IsFull() bool { + return c.size > c.cap +} + +func (c *RetrieveResultCache) IsEmpty() bool { + return c.size == 0 +} + +func (c *RetrieveResultCache) merge(result *internalpb.RetrieveResults) { + switch result.GetIds().GetIdField().(type) { + case *schemapb.IDs_IntId: + c.result.GetIds().GetIntId().Data = append(c.result.GetIds().GetIntId().GetData(), result.GetIds().GetIntId().GetData()...) + case *schemapb.IDs_StrId: + c.result.GetIds().GetStrId().Data = append(c.result.GetIds().GetStrId().GetData(), result.GetIds().GetStrId().GetData()...) + } + c.result.AllRetrieveCount = c.result.AllRetrieveCount + result.AllRetrieveCount + c.result.CostAggregation = mergeCostAggregation(c.result.GetCostAggregation(), result.GetCostAggregation()) + c.size = proto.Size(c.result) +} + +func mergeCostAggregation(a *internalpb.CostAggregation, b *internalpb.CostAggregation) *internalpb.CostAggregation { + if a == nil { + return b + } + if b == nil { + return a + } + + return &internalpb.CostAggregation{ + ResponseTime: a.GetResponseTime() + b.GetResponseTime(), + ServiceTime: a.GetServiceTime() + b.GetServiceTime(), + TotalNQ: a.GetTotalNQ() + b.GetTotalNQ(), + TotalRelatedDataSize: a.GetTotalRelatedDataSize() + b.GetTotalRelatedDataSize(), + } +} + +// Merge result by size and time. +type ResultCacheServer struct { + srv QueryStreamServer + cache *RetrieveResultCache + mu sync.Mutex +} + +func NewResultCacheServer(srv QueryStreamServer, cap int) *ResultCacheServer { + return &ResultCacheServer{ + srv: srv, + cache: &RetrieveResultCache{cap: cap}, + } +} + +func (s *ResultCacheServer) Send(result *internalpb.RetrieveResults) error { + s.mu.Lock() + defer s.mu.Unlock() + + if !s.cache.Alloc(result) && !s.cache.IsEmpty() { + result := s.cache.Flush() + if err := s.srv.Send(result); err != nil { + return err + } + } + + s.cache.Put(result) + if s.cache.IsFull() { + result := s.cache.Flush() + if err := s.srv.Send(result); err != nil { + return err + } + } + return nil +} + +func (s *ResultCacheServer) Flush() error { + s.mu.Lock() + defer s.mu.Unlock() + + result := s.cache.Flush() + if result == nil { + return nil + } + + if err := s.srv.Send(result); err != nil { + return err + } + return nil +} + +func (s *ResultCacheServer) Context() context.Context { + return s.srv.Context() +} + // TODO LOCAL SERVER AND CLIENT FOR STANDALONE // ONLY FOR TEST type LocalQueryServer struct { diff --git a/internal/util/streamrpc/streamer_test.go b/internal/util/streamrpc/streamer_test.go new file mode 100644 index 0000000000..de1482adb9 --- /dev/null +++ b/internal/util/streamrpc/streamer_test.go @@ -0,0 +1,84 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. 
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package streamrpc
+
+import (
+    "context"
+    "testing"
+
+    "github.com/stretchr/testify/suite"
+
+    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+    "github.com/milvus-io/milvus/internal/proto/internalpb"
+)
+
+type ResultCacheServerSuite struct {
+    suite.Suite
+}
+
+func (s *ResultCacheServerSuite) TestSend() {
+    ctx, cancel := context.WithCancel(context.Background())
+    defer cancel()
+
+    client := NewLocalQueryClient(ctx)
+    srv := client.CreateServer()
+    cacheSrv := NewResultCacheServer(srv, 1024)
+
+    err := cacheSrv.Send(&internalpb.RetrieveResults{
+        Ids: &schemapb.IDs{
+            IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{1, 2, 3}}},
+        },
+    })
+    s.NoError(err)
+    s.False(cacheSrv.cache.IsEmpty())
+
+    err = cacheSrv.Send(&internalpb.RetrieveResults{
+        Ids: &schemapb.IDs{
+            IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{4, 5, 6}}},
+        },
+    })
+    s.NoError(err)
+
+    err = cacheSrv.Flush()
+    s.NoError(err)
+    s.True(cacheSrv.cache.IsEmpty())
+
+    msg, err := client.Recv()
+    s.NoError(err)
+    // Data: []int64{1, 2, 3, 4, 5, 6}
+    s.Equal(6, len(msg.GetIds().GetIntId().GetData()))
+}
+
+func (s *ResultCacheServerSuite) TestMerge() {
+    s.Nil(mergeCostAggregation(nil, nil))
+
+    cost := &internalpb.CostAggregation{}
+    s.Equal(cost, mergeCostAggregation(nil, cost))
+    s.Equal(cost, mergeCostAggregation(cost, nil))
+
+    a := &internalpb.CostAggregation{ResponseTime: 1, ServiceTime: 1, TotalNQ: 1, TotalRelatedDataSize: 1}
+    b := &internalpb.CostAggregation{ResponseTime: 2, ServiceTime: 2, TotalNQ: 2, TotalRelatedDataSize: 2}
+    c := mergeCostAggregation(a, b)
+    s.Equal(int64(3), c.ResponseTime)
+    s.Equal(int64(3), c.ServiceTime)
+    s.Equal(int64(3), c.TotalNQ)
+    s.Equal(int64(3), c.TotalRelatedDataSize)
+}
+
+func TestResultCacheServerSuite(t *testing.T) {
+    suite.Run(t, new(ResultCacheServerSuite))
+}
diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go
index 317ad937f8..c74bbefd35 100644
--- a/pkg/util/paramtable/component_param.go
+++ b/pkg/util/paramtable/component_param.go
@@ -2100,6 +2100,7 @@ type queryNodeConfig struct {
     EnableSegmentPrune        ParamItem `refreshable:"false"`
     DefaultSegmentFilterRatio ParamItem `refreshable:"false"`
     UseStreamComputing        ParamItem `refreshable:"false"`
+    QueryStreamBatchSize      ParamItem `refreshable:"false"`
 }

 func (p *queryNodeConfig) init(base *BaseTable) {
@@ -2683,6 +2684,15 @@ user-task-polling:
         Doc:     "use stream search mode when searching or querying",
     }
     p.UseStreamComputing.Init(base.mgr)
+
+    p.QueryStreamBatchSize = ParamItem{
+        Key:          "queryNode.queryStreamBatchSize",
+        Version:      "2.4.1",
+        DefaultValue: "4194304",
+        Doc:          "the batch size (in bytes) used when returning stream query results",
+        Export:       true,
+    }
+    p.QueryStreamBatchSize.Init(base.mgr)
 }

 // /////////////////////////////////////////////////////////////////////////////
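
The patch wires the new queryNode.queryStreamBatchSize parameter (4194304 bytes, i.e. 4 MiB, by default) into QueryStreamTask, which wraps the outgoing stream in the ResultCacheServer added to streamer.go so that many small per-segment results are coalesced before each gRPC send. Below is a minimal sketch of that usage pattern; the package name, the helper streamBatched, and the results slice are illustrative only, while NewResultCacheServer, Send, and Flush are the APIs introduced above.

package example

import (
    "github.com/milvus-io/milvus/internal/proto/internalpb"
    "github.com/milvus-io/milvus/internal/util/streamrpc"
)

// streamBatched is a hypothetical caller. It wraps an existing
// QueryStreamServer so that results are merged until their serialized
// size approaches batchSize bytes before being forwarded downstream.
func streamBatched(srv streamrpc.QueryStreamServer, batchSize int, results []*internalpb.RetrieveResults) error {
    cacheSrv := streamrpc.NewResultCacheServer(srv, batchSize)

    for _, r := range results {
        // Send buffers r and forwards a merged batch to srv once the
        // accumulated serialized size would exceed batchSize (a single
        // oversized result is forwarded on its own).
        if err := cacheSrv.Send(r); err != nil {
            return err
        }
    }

    // Flush delivers the final partial batch; without it the tail of the
    // stream (anything smaller than batchSize) would never be sent.
    return cacheSrv.Flush()
}

Because batching is measured with proto.Size rather than row counts, batches stay near the configured byte budget regardless of how wide each result row is, which is why the knob is expressed in bytes.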