enhance: use GPU pool for gpu tasks (#29678)

- this much improve the performance for GPU index

Signed-off-by: yah01 <yang.cen@zilliz.com>
pull/29138/merge
yah01 2024-01-04 17:50:46 +08:00 committed by GitHub
parent 4f8c540c77
commit 9e0163e12f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 76 additions and 10 deletions

View File

@ -29,13 +29,17 @@ import (
"unsafe"
"github.com/golang/protobuf/proto"
"github.com/samber/lo"
"go.uber.org/atomic"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
@ -123,6 +127,7 @@ type Collection struct {
loadType querypb.LoadType
metricType atomic.String
schema atomic.Pointer[schemapb.CollectionSchema]
isGpuIndex bool
refCount *atomic.Uint32
}
@ -137,6 +142,11 @@ func (c *Collection) Schema() *schemapb.CollectionSchema {
return c.schema.Load()
}
// IsGpuIndex returns a boolean value indicating whether the collection is using a GPU index.
func (c *Collection) IsGpuIndex() bool {
return c.isGpuIndex
}
// getPartitionIDs return partitionIDs of collection
func (c *Collection) GetPartitions() []int64 {
return c.partitions.Collect()
@ -205,6 +215,7 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM
collection := C.NewCollection(unsafe.Pointer(&schemaBlob[0]), (C.int64_t)(len(schemaBlob)))
isGpuIndex := false
if indexMeta != nil && len(indexMeta.GetIndexMetas()) > 0 && indexMeta.GetMaxIndexRowCount() > 0 {
indexMetaBlob, err := proto.Marshal(indexMeta)
if err != nil {
@ -212,6 +223,15 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM
return nil
}
C.SetIndexMeta(collection, unsafe.Pointer(&indexMetaBlob[0]), (C.int64_t)(len(indexMetaBlob)))
for _, indexMeta := range indexMeta.GetIndexMetas() {
isGpuIndex = lo.ContainsBy(indexMeta.GetIndexParams(), func(param *commonpb.KeyValuePair) bool {
return param.Key == common.IndexTypeKey && indexparamcheck.IsGpuIndex(param.Value)
})
if isGpuIndex {
break
}
}
}
coll := &Collection{
@ -220,6 +240,7 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM
partitions: typeutil.NewConcurrentSet[int64](),
loadType: loadType,
refCount: atomic.NewUint32(0),
isGpuIndex: isGpuIndex,
}
coll.schema.Store(schema)

View File

@ -30,6 +30,7 @@ func newScheduler(policy schedulePolicy) Scheduler {
receiveChan: make(chan addTaskReq, maxReceiveChanSize),
execChan: make(chan Task),
pool: conc.NewPool[any](maxReadConcurrency, conc.WithPreAlloc(true)),
gpuPool: conc.NewPool[any](paramtable.Get().QueryNodeCfg.MaxGpuReadConcurrency.GetAsInt(), conc.WithPreAlloc(true)),
schedulerCounter: schedulerCounter{},
lifetime: lifetime.NewLifetime(lifetime.Initializing),
}
@ -46,6 +47,7 @@ type scheduler struct {
receiveChan chan addTaskReq
execChan chan Task
pool *conc.Pool[any]
gpuPool *conc.Pool[any]
// wg is the waitgroup for internal worker goroutine
wg sync.WaitGroup
@ -227,7 +229,7 @@ func (s *scheduler) exec() {
continue
}
s.pool.Submit(func() (any, error) {
s.getPool(t).Submit(func() (any, error) {
// Update concurrency metric and notify task done.
metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
collector.Counter.Inc(metricsinfo.ExecuteQueueType, 1)
@ -245,6 +247,14 @@ func (s *scheduler) exec() {
}
}
func (s *scheduler) getPool(t Task) *conc.Pool[any] {
if t.IsGpuIndex() {
return s.gpuPool
}
return s.pool
}
// setupExecListener setup the execChan and next task to run.
func (s *scheduler) setupExecListener(lastWaitingTask Task) (Task, int64, chan Task) {
var execChan chan Task

View File

@ -64,6 +64,10 @@ func (t *MockTask) Username() string {
return t.username
}
func (t *MockTask) IsGpuIndex() bool {
return false
}
func (t *MockTask) TimeRecorder() *timerecord.TimeRecorder {
return t.tr
}

View File

@ -41,6 +41,10 @@ func (t *QueryStreamTask) Username() string {
return t.req.Req.GetUsername()
}
func (t *QueryStreamTask) IsGpuIndex() bool {
return false
}
// PreExecute the task, only call once.
func (t *QueryStreamTask) PreExecute() error {
return nil

View File

@ -51,6 +51,10 @@ func (t *QueryTask) Username() string {
return t.req.Req.GetUsername()
}
func (t *QueryTask) IsGpuIndex() bool {
return false
}
// PreExecute the task, only call once.
func (t *QueryTask) PreExecute() error {
// Update task wait time metric before execute

View File

@ -77,6 +77,10 @@ func (t *SearchTask) Username() string {
return t.req.Req.GetUsername()
}
func (t *SearchTask) IsGpuIndex() bool {
return t.collection.IsGpuIndex()
}
func (t *SearchTask) PreExecute() error {
// Update task wait time metric before execute
nodeID := strconv.FormatInt(paramtable.GetNodeID(), 10)

View File

@ -82,6 +82,9 @@ type Task interface {
// Return "" if the task do not contain any user info.
Username() string
// Return whether the task would be running on GPU.
IsGpuIndex() bool
// PreExecute the task, only call once.
PreExecute() error

View File

@ -16,6 +16,7 @@ type IndexType = string
// IndexType definitions
const (
IndexGpuBF IndexType = "GPU_BRUTE_FORCE"
IndexRaftIvfFlat IndexType = "GPU_IVF_FLAT"
IndexRaftIvfPQ IndexType = "GPU_IVF_PQ"
IndexRaftCagra IndexType = "GPU_CAGRA"
@ -29,3 +30,10 @@ const (
IndexHNSW IndexType = "HNSW"
IndexDISKANN IndexType = "DISKANN"
)
func IsGpuIndex(indexType IndexType) bool {
return indexType == IndexGpuBF ||
indexType == IndexRaftIvfFlat ||
indexType == IndexRaftIvfPQ ||
indexType == IndexRaftCagra
}

View File

@ -1758,6 +1758,7 @@ type queryNodeConfig struct {
MaxReceiveChanSize ParamItem `refreshable:"false"`
MaxUnsolvedQueueSize ParamItem `refreshable:"true"`
MaxReadConcurrency ParamItem `refreshable:"true"`
MaxGpuReadConcurrency ParamItem `refreshable:"false"`
MaxGroupNQ ParamItem `refreshable:"true"`
TopKMergeRatio ParamItem `refreshable:"true"`
CPURatio ParamItem `refreshable:"true"`
@ -2000,6 +2001,13 @@ Max read concurrency must greater than or equal to 1, and less than or equal to
}
p.MaxReadConcurrency.Init(base.mgr)
p.MaxGpuReadConcurrency = ParamItem{
Key: "queryNode.scheduler.maGpuReadConcurrency",
Version: "2.0.0",
DefaultValue: "8",
}
p.MaxGpuReadConcurrency.Init(base.mgr)
p.MaxUnsolvedQueueSize = ParamItem{
Key: "queryNode.scheduler.unsolvedQueueSize",
Version: "2.0.0",