enhance: Merge query stream result for reduce delete task (#32855)

relate: https://github.com/milvus-io/milvus/issues/32854

---------

Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
pull/33411/head
aoiasd 2024-05-27 18:15:43 +08:00 committed by GitHub
parent 066c8ea175
commit 59a7a46904
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 225 additions and 3 deletions

View File

@ -377,6 +377,7 @@ queryNode:
maxQueueLength: 16 # Maximum length of task queue in flowgraph
maxParallelism: 1024 # Maximum number of tasks executed in parallel in the flowgraph
enableSegmentPrune: false # use partition prune function on shard delegator
queryStreamBatchSize: 4194304 # return batch size of stream query
ip: # if not specified, use the first unicastable address
port: 21123
grpc:

View File

@ -317,7 +317,7 @@ func (node *QueryNode) queryStreamSegments(ctx context.Context, req *querypb.Que
}
// Send task to scheduler and wait until it finished.
task := tasks.NewQueryStreamTask(ctx, collection, node.manager, req, srv)
task := tasks.NewQueryStreamTask(ctx, collection, node.manager, req, srv, node.streamBatchSzie)
if err := node.scheduler.Add(task); err != nil {
log.Warn("failed to add query task into scheduler", zap.Error(err))
return err

View File

@ -111,7 +111,8 @@ type QueryNode struct {
loader segments.Loader
// Search/Query
scheduler tasks.Scheduler
scheduler tasks.Scheduler
streamBatchSzie int
// etcd client
etcdCli *clientv3.Client
@ -328,6 +329,7 @@ func (node *QueryNode) Init() error {
node.scheduler = tasks.NewScheduler(
schedulePolicy,
)
node.streamBatchSzie = paramtable.Get().QueryNodeCfg.QueryStreamBatchSize.GetAsInt()
log.Info("queryNode init scheduler", zap.String("policy", schedulePolicy))
node.clusterManager = cluster.NewWorkerManager(func(ctx context.Context, nodeID int64) (cluster.Worker, error) {

View File

@ -16,6 +16,7 @@ func NewQueryStreamTask(ctx context.Context,
manager *segments.Manager,
req *querypb.QueryRequest,
srv streamrpc.QueryStreamServer,
streamBatchSize int,
) *QueryStreamTask {
return &QueryStreamTask{
ctx: ctx,
@ -23,6 +24,7 @@ func NewQueryStreamTask(ctx context.Context,
segmentManager: manager,
req: req,
srv: srv,
batchSize: streamBatchSize,
notifier: make(chan error, 1),
}
}
@ -33,6 +35,7 @@ type QueryStreamTask struct {
segmentManager *segments.Manager
req *querypb.QueryRequest
srv streamrpc.QueryStreamServer
batchSize int
notifier chan error
}
@ -64,7 +67,10 @@ func (t *QueryStreamTask) Execute() error {
}
defer retrievePlan.Delete()
segments, err := segments.RetrieveStream(t.ctx, t.segmentManager, retrievePlan, t.req, t.srv)
srv := streamrpc.NewResultCacheServer(t.srv, t.batchSize)
defer srv.Flush()
segments, err := segments.RetrieveStream(t.ctx, t.segmentManager, retrievePlan, t.req, srv)
defer t.segmentManager.Segment.Unpin(segments)
if err != nil {
return err

View File

@ -5,8 +5,10 @@ import (
"io"
"sync"
"github.com/golang/protobuf/proto"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
)
@ -42,6 +44,123 @@ func NewConcurrentQueryStreamServer(srv QueryStreamServer) *ConcurrentQueryStrea
}
}
type RetrieveResultCache struct {
result *internalpb.RetrieveResults
size int
cap int
}
func (c *RetrieveResultCache) Put(result *internalpb.RetrieveResults) {
if c.result == nil {
c.result = result
c.size = proto.Size(result)
return
}
c.merge(result)
}
func (c *RetrieveResultCache) Flush() *internalpb.RetrieveResults {
result := c.result
c.result = nil
c.size = 0
return result
}
func (c *RetrieveResultCache) Alloc(result *internalpb.RetrieveResults) bool {
return proto.Size(result)+c.size <= c.cap
}
func (c *RetrieveResultCache) IsFull() bool {
return c.size > c.cap
}
func (c *RetrieveResultCache) IsEmpty() bool {
return c.size == 0
}
func (c *RetrieveResultCache) merge(result *internalpb.RetrieveResults) {
switch result.GetIds().GetIdField().(type) {
case *schemapb.IDs_IntId:
c.result.GetIds().GetIntId().Data = append(c.result.GetIds().GetIntId().GetData(), result.GetIds().GetIntId().GetData()...)
case *schemapb.IDs_StrId:
c.result.GetIds().GetStrId().Data = append(c.result.GetIds().GetStrId().GetData(), result.GetIds().GetStrId().GetData()...)
}
c.result.AllRetrieveCount = c.result.AllRetrieveCount + result.AllRetrieveCount
c.result.CostAggregation = mergeCostAggregation(c.result.GetCostAggregation(), result.GetCostAggregation())
c.size = proto.Size(c.result)
}
func mergeCostAggregation(a *internalpb.CostAggregation, b *internalpb.CostAggregation) *internalpb.CostAggregation {
if a == nil {
return b
}
if b == nil {
return a
}
return &internalpb.CostAggregation{
ResponseTime: a.GetResponseTime() + b.GetResponseTime(),
ServiceTime: a.GetServiceTime() + b.GetServiceTime(),
TotalNQ: a.GetTotalNQ() + b.GetTotalNQ(),
TotalRelatedDataSize: a.GetTotalRelatedDataSize() + b.GetTotalRelatedDataSize(),
}
}
// Merge result by size and time.
type ResultCacheServer struct {
srv QueryStreamServer
cache *RetrieveResultCache
mu sync.Mutex
}
func NewResultCacheServer(srv QueryStreamServer, cap int) *ResultCacheServer {
return &ResultCacheServer{
srv: srv,
cache: &RetrieveResultCache{cap: cap},
}
}
func (s *ResultCacheServer) Send(result *internalpb.RetrieveResults) error {
s.mu.Lock()
defer s.mu.Unlock()
if !s.cache.Alloc(result) && !s.cache.IsEmpty() {
result := s.cache.Flush()
if err := s.srv.Send(result); err != nil {
return err
}
}
s.cache.Put(result)
if s.cache.IsFull() {
result := s.cache.Flush()
if err := s.srv.Send(result); err != nil {
return err
}
}
return nil
}
func (s *ResultCacheServer) Flush() error {
s.mu.Lock()
defer s.mu.Unlock()
result := s.cache.Flush()
if result == nil {
return nil
}
if err := s.srv.Send(result); err != nil {
return err
}
return nil
}
func (s *ResultCacheServer) Context() context.Context {
return s.srv.Context()
}
// TODO LOCAL SERVER AND CLIENT FOR STANDALONE
// ONLY FOR TEST
type LocalQueryServer struct {

View File

@ -0,0 +1,84 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package streamrpc
import (
"context"
"testing"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
)
type ResultCacheServerSuite struct {
suite.Suite
}
func (s *ResultCacheServerSuite) TestSend() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
client := NewLocalQueryClient(ctx)
srv := client.CreateServer()
cacheSrv := NewResultCacheServer(srv, 1024)
err := cacheSrv.Send(&internalpb.RetrieveResults{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{1, 2, 3}}},
},
})
s.NoError(err)
s.False(cacheSrv.cache.IsEmpty())
err = cacheSrv.Send(&internalpb.RetrieveResults{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{4, 5, 6}}},
},
})
s.NoError(err)
err = cacheSrv.Flush()
s.NoError(err)
s.True(cacheSrv.cache.IsEmpty())
msg, err := client.Recv()
s.NoError(err)
// Data: []int64{1,2,3,4,5,6}
s.Equal(6, len(msg.GetIds().GetIntId().GetData()))
}
func (s *ResultCacheServerSuite) TestMerge() {
s.Nil(mergeCostAggregation(nil, nil))
cost := &internalpb.CostAggregation{}
s.Equal(cost, mergeCostAggregation(nil, cost))
s.Equal(cost, mergeCostAggregation(cost, nil))
a := &internalpb.CostAggregation{ResponseTime: 1, ServiceTime: 1, TotalNQ: 1, TotalRelatedDataSize: 1}
b := &internalpb.CostAggregation{ResponseTime: 2, ServiceTime: 2, TotalNQ: 2, TotalRelatedDataSize: 2}
c := mergeCostAggregation(a, b)
s.Equal(int64(3), c.ResponseTime)
s.Equal(int64(3), c.ServiceTime)
s.Equal(int64(3), c.TotalNQ)
s.Equal(int64(3), c.TotalRelatedDataSize)
}
func TestResultCacheServerSuite(t *testing.T) {
suite.Run(t, new(ResultCacheServerSuite))
}

View File

@ -2100,6 +2100,7 @@ type queryNodeConfig struct {
EnableSegmentPrune ParamItem `refreshable:"false"`
DefaultSegmentFilterRatio ParamItem `refreshable:"false"`
UseStreamComputing ParamItem `refreshable:"false"`
QueryStreamBatchSize ParamItem `refreshable:"false"`
}
func (p *queryNodeConfig) init(base *BaseTable) {
@ -2683,6 +2684,15 @@ user-task-polling:
Doc: "use stream search mode when searching or querying",
}
p.UseStreamComputing.Init(base.mgr)
p.QueryStreamBatchSize = ParamItem{
Key: "queryNode.queryStreamBatchSize",
Version: "2.4.1",
DefaultValue: "4194304",
Doc: "return batch size of stream query",
Export: true,
}
p.QueryStreamBatchSize.Init(base.mgr)
}
// /////////////////////////////////////////////////////////////////////////////