From 9e0163e12f5473bb221b18cd0511ff2a2f381c69 Mon Sep 17 00:00:00 2001 From: yah01 Date: Thu, 4 Jan 2024 17:50:46 +0800 Subject: [PATCH] enhance: use GPU pool for gpu tasks (#29678) - this much improve the performance for GPU index Signed-off-by: yah01 --- internal/querynodev2/segments/collection.go | 21 +++++++++++++++ .../tasks/concurrent_safe_scheduler.go | 12 ++++++++- internal/querynodev2/tasks/mock_task_test.go | 4 +++ .../querynodev2/tasks/query_stream_task.go | 4 +++ internal/querynodev2/tasks/query_task.go | 4 +++ internal/querynodev2/tasks/task.go | 4 +++ internal/querynodev2/tasks/tasks.go | 3 +++ pkg/util/indexparamcheck/index_type.go | 8 ++++++ pkg/util/paramtable/component_param.go | 26 ++++++++++++------- 9 files changed, 76 insertions(+), 10 deletions(-) diff --git a/internal/querynodev2/segments/collection.go b/internal/querynodev2/segments/collection.go index 2b36f7f7fb..cd7ad2ce25 100644 --- a/internal/querynodev2/segments/collection.go +++ b/internal/querynodev2/segments/collection.go @@ -29,13 +29,17 @@ import ( "unsafe" "github.com/golang/protobuf/proto" + "github.com/samber/lo" "go.uber.org/atomic" "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/segcorepb" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -123,6 +127,7 @@ type Collection struct { loadType querypb.LoadType metricType atomic.String schema atomic.Pointer[schemapb.CollectionSchema] + isGpuIndex bool refCount *atomic.Uint32 } @@ -137,6 +142,11 @@ func (c *Collection) Schema() *schemapb.CollectionSchema { return c.schema.Load() } +// IsGpuIndex returns a boolean value indicating whether the collection is using a GPU index. +func (c *Collection) IsGpuIndex() bool { + return c.isGpuIndex +} + // getPartitionIDs return partitionIDs of collection func (c *Collection) GetPartitions() []int64 { return c.partitions.Collect() @@ -205,6 +215,7 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM collection := C.NewCollection(unsafe.Pointer(&schemaBlob[0]), (C.int64_t)(len(schemaBlob))) + isGpuIndex := false if indexMeta != nil && len(indexMeta.GetIndexMetas()) > 0 && indexMeta.GetMaxIndexRowCount() > 0 { indexMetaBlob, err := proto.Marshal(indexMeta) if err != nil { @@ -212,6 +223,15 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM return nil } C.SetIndexMeta(collection, unsafe.Pointer(&indexMetaBlob[0]), (C.int64_t)(len(indexMetaBlob))) + + for _, indexMeta := range indexMeta.GetIndexMetas() { + isGpuIndex = lo.ContainsBy(indexMeta.GetIndexParams(), func(param *commonpb.KeyValuePair) bool { + return param.Key == common.IndexTypeKey && indexparamcheck.IsGpuIndex(param.Value) + }) + if isGpuIndex { + break + } + } } coll := &Collection{ @@ -220,6 +240,7 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM partitions: typeutil.NewConcurrentSet[int64](), loadType: loadType, refCount: atomic.NewUint32(0), + isGpuIndex: isGpuIndex, } coll.schema.Store(schema) diff --git a/internal/querynodev2/tasks/concurrent_safe_scheduler.go b/internal/querynodev2/tasks/concurrent_safe_scheduler.go index 7968cd172b..0a2396e8e3 100644 --- a/internal/querynodev2/tasks/concurrent_safe_scheduler.go +++ b/internal/querynodev2/tasks/concurrent_safe_scheduler.go @@ -30,6 +30,7 @@ func newScheduler(policy schedulePolicy) Scheduler { receiveChan: make(chan addTaskReq, maxReceiveChanSize), execChan: make(chan Task), pool: conc.NewPool[any](maxReadConcurrency, conc.WithPreAlloc(true)), + gpuPool: conc.NewPool[any](paramtable.Get().QueryNodeCfg.MaxGpuReadConcurrency.GetAsInt(), conc.WithPreAlloc(true)), schedulerCounter: schedulerCounter{}, lifetime: lifetime.NewLifetime(lifetime.Initializing), } @@ -46,6 +47,7 @@ type scheduler struct { receiveChan chan addTaskReq execChan chan Task pool *conc.Pool[any] + gpuPool *conc.Pool[any] // wg is the waitgroup for internal worker goroutine wg sync.WaitGroup @@ -227,7 +229,7 @@ func (s *scheduler) exec() { continue } - s.pool.Submit(func() (any, error) { + s.getPool(t).Submit(func() (any, error) { // Update concurrency metric and notify task done. metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc() collector.Counter.Inc(metricsinfo.ExecuteQueueType, 1) @@ -245,6 +247,14 @@ func (s *scheduler) exec() { } } +func (s *scheduler) getPool(t Task) *conc.Pool[any] { + if t.IsGpuIndex() { + return s.gpuPool + } + + return s.pool +} + // setupExecListener setup the execChan and next task to run. func (s *scheduler) setupExecListener(lastWaitingTask Task) (Task, int64, chan Task) { var execChan chan Task diff --git a/internal/querynodev2/tasks/mock_task_test.go b/internal/querynodev2/tasks/mock_task_test.go index 7aac1aa24f..63b53f0a27 100644 --- a/internal/querynodev2/tasks/mock_task_test.go +++ b/internal/querynodev2/tasks/mock_task_test.go @@ -64,6 +64,10 @@ func (t *MockTask) Username() string { return t.username } +func (t *MockTask) IsGpuIndex() bool { + return false +} + func (t *MockTask) TimeRecorder() *timerecord.TimeRecorder { return t.tr } diff --git a/internal/querynodev2/tasks/query_stream_task.go b/internal/querynodev2/tasks/query_stream_task.go index 96149cc172..8a2c525578 100644 --- a/internal/querynodev2/tasks/query_stream_task.go +++ b/internal/querynodev2/tasks/query_stream_task.go @@ -41,6 +41,10 @@ func (t *QueryStreamTask) Username() string { return t.req.Req.GetUsername() } +func (t *QueryStreamTask) IsGpuIndex() bool { + return false +} + // PreExecute the task, only call once. func (t *QueryStreamTask) PreExecute() error { return nil diff --git a/internal/querynodev2/tasks/query_task.go b/internal/querynodev2/tasks/query_task.go index 7bc7a62956..0460be5223 100644 --- a/internal/querynodev2/tasks/query_task.go +++ b/internal/querynodev2/tasks/query_task.go @@ -51,6 +51,10 @@ func (t *QueryTask) Username() string { return t.req.Req.GetUsername() } +func (t *QueryTask) IsGpuIndex() bool { + return false +} + // PreExecute the task, only call once. func (t *QueryTask) PreExecute() error { // Update task wait time metric before execute diff --git a/internal/querynodev2/tasks/task.go b/internal/querynodev2/tasks/task.go index e86339a382..9fbf12545a 100644 --- a/internal/querynodev2/tasks/task.go +++ b/internal/querynodev2/tasks/task.go @@ -77,6 +77,10 @@ func (t *SearchTask) Username() string { return t.req.Req.GetUsername() } +func (t *SearchTask) IsGpuIndex() bool { + return t.collection.IsGpuIndex() +} + func (t *SearchTask) PreExecute() error { // Update task wait time metric before execute nodeID := strconv.FormatInt(paramtable.GetNodeID(), 10) diff --git a/internal/querynodev2/tasks/tasks.go b/internal/querynodev2/tasks/tasks.go index 6a0d55b2ed..9032654195 100644 --- a/internal/querynodev2/tasks/tasks.go +++ b/internal/querynodev2/tasks/tasks.go @@ -82,6 +82,9 @@ type Task interface { // Return "" if the task do not contain any user info. Username() string + // Return whether the task would be running on GPU. + IsGpuIndex() bool + // PreExecute the task, only call once. PreExecute() error diff --git a/pkg/util/indexparamcheck/index_type.go b/pkg/util/indexparamcheck/index_type.go index dba1b6cdf3..ebef1bc7a6 100644 --- a/pkg/util/indexparamcheck/index_type.go +++ b/pkg/util/indexparamcheck/index_type.go @@ -16,6 +16,7 @@ type IndexType = string // IndexType definitions const ( + IndexGpuBF IndexType = "GPU_BRUTE_FORCE" IndexRaftIvfFlat IndexType = "GPU_IVF_FLAT" IndexRaftIvfPQ IndexType = "GPU_IVF_PQ" IndexRaftCagra IndexType = "GPU_CAGRA" @@ -29,3 +30,10 @@ const ( IndexHNSW IndexType = "HNSW" IndexDISKANN IndexType = "DISKANN" ) + +func IsGpuIndex(indexType IndexType) bool { + return indexType == IndexGpuBF || + indexType == IndexRaftIvfFlat || + indexType == IndexRaftIvfPQ || + indexType == IndexRaftCagra +} diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index 2a4f4f10dc..a6d9e83a76 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -1754,15 +1754,16 @@ type queryNodeConfig struct { // chunk cache ReadAheadPolicy ParamItem `refreshable:"false"` - GroupEnabled ParamItem `refreshable:"true"` - MaxReceiveChanSize ParamItem `refreshable:"false"` - MaxUnsolvedQueueSize ParamItem `refreshable:"true"` - MaxReadConcurrency ParamItem `refreshable:"true"` - MaxGroupNQ ParamItem `refreshable:"true"` - TopKMergeRatio ParamItem `refreshable:"true"` - CPURatio ParamItem `refreshable:"true"` - MaxTimestampLag ParamItem `refreshable:"true"` - GCEnabled ParamItem `refreshable:"true"` + GroupEnabled ParamItem `refreshable:"true"` + MaxReceiveChanSize ParamItem `refreshable:"false"` + MaxUnsolvedQueueSize ParamItem `refreshable:"true"` + MaxReadConcurrency ParamItem `refreshable:"true"` + MaxGpuReadConcurrency ParamItem `refreshable:"false"` + MaxGroupNQ ParamItem `refreshable:"true"` + TopKMergeRatio ParamItem `refreshable:"true"` + CPURatio ParamItem `refreshable:"true"` + MaxTimestampLag ParamItem `refreshable:"true"` + GCEnabled ParamItem `refreshable:"true"` GCHelperEnabled ParamItem `refreshable:"false"` MinimumGOGCConfig ParamItem `refreshable:"false"` @@ -2000,6 +2001,13 @@ Max read concurrency must greater than or equal to 1, and less than or equal to } p.MaxReadConcurrency.Init(base.mgr) + p.MaxGpuReadConcurrency = ParamItem{ + Key: "queryNode.scheduler.maGpuReadConcurrency", + Version: "2.0.0", + DefaultValue: "8", + } + p.MaxGpuReadConcurrency.Init(base.mgr) + p.MaxUnsolvedQueueSize = ParamItem{ Key: "queryNode.scheduler.unsolvedQueueSize", Version: "2.0.0",