Add cgo worker pool for querynode (#18461)

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/18590/head
congqixia 2022-08-09 16:34:37 +08:00 committed by GitHub
parent 002f509808
commit 179b496824
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 267 additions and 67 deletions

View File

@ -38,6 +38,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/util/concurrency"
)
// ReplicaInterface specifies all the methods that the Collection object needs to implement in QueryNode.
@ -125,6 +126,8 @@ type metaReplica struct {
sealedSegments map[UniqueID]*Segment
excludedSegments map[UniqueID][]*datapb.SegmentInfo // map[collectionID]segmentIDs
cgoPool *concurrency.Pool
}
// getSegmentsMemSize get the memory size in bytes of all the Segments
@ -530,7 +533,7 @@ func (replica *metaReplica) addSegment(segmentID UniqueID, partitionID UniqueID,
if err != nil {
return err
}
seg, err := newSegment(collection, segmentID, partitionID, collectionID, vChannelID, segType)
seg, err := newSegment(collection, segmentID, partitionID, collectionID, vChannelID, segType, replica.cgoPool)
if err != nil {
return err
}
@ -747,7 +750,7 @@ func (replica *metaReplica) freeAll() {
}
// newCollectionReplica returns a new ReplicaInterface
func newCollectionReplica() ReplicaInterface {
func newCollectionReplica(pool *concurrency.Pool) ReplicaInterface {
var replica ReplicaInterface = &metaReplica{
collections: make(map[UniqueID]*Collection),
partitions: make(map[UniqueID]*Partition),
@ -755,6 +758,7 @@ func newCollectionReplica() ReplicaInterface {
sealedSegments: make(map[UniqueID]*Segment),
excludedSegments: make(map[UniqueID][]*datapb.SegmentInfo),
cgoPool: pool,
}
return replica

View File

@ -17,12 +17,15 @@
package querynode
import (
"runtime"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/concurrency"
)
func TestMetaReplica_collection(t *testing.T) {
@ -225,6 +228,9 @@ func TestMetaReplica_segment(t *testing.T) {
assert.NoError(t, err)
defer replica.freeAll()
pool, err := concurrency.NewPool(runtime.GOMAXPROCS(0))
require.NoError(t, err)
schema := genTestCollectionSchema()
collection := replica.addCollection(defaultCollectionID, schema)
replica.addPartition(defaultCollectionID, defaultPartitionID)
@ -244,12 +250,12 @@ func TestMetaReplica_segment(t *testing.T) {
},
}
segment1, err := newSegment(collection, UniqueID(1), defaultPartitionID, defaultCollectionID, "", segmentTypeGrowing)
segment1, err := newSegment(collection, UniqueID(1), defaultPartitionID, defaultCollectionID, "", segmentTypeGrowing, pool)
assert.NoError(t, err)
err = replica.setSegment(segment1)
assert.NoError(t, err)
segment2, err := newSegment(collection, UniqueID(2), defaultPartitionID, defaultCollectionID, "", segmentTypeSealed)
segment2, err := newSegment(collection, UniqueID(2), defaultPartitionID, defaultCollectionID, "", segmentTypeSealed, pool)
assert.NoError(t, err)
segment2.setIndexedFieldInfo(fieldID, indexInfo)
err = replica.setSegment(segment2)
@ -271,27 +277,30 @@ func TestMetaReplica_segment(t *testing.T) {
assert.NoError(t, err)
defer replica.freeAll()
pool, err := concurrency.NewPool(runtime.GOMAXPROCS(0))
require.NoError(t, err)
schema := genTestCollectionSchema()
collection := replica.addCollection(defaultCollectionID, schema)
replica.addPartition(defaultCollectionID, defaultPartitionID)
replica.addPartition(defaultCollectionID, defaultPartitionID+1)
segment1, err := newSegment(collection, UniqueID(1), defaultPartitionID, defaultCollectionID, "channel1", segmentTypeGrowing)
segment1, err := newSegment(collection, UniqueID(1), defaultPartitionID, defaultCollectionID, "channel1", segmentTypeGrowing, pool)
assert.NoError(t, err)
err = replica.setSegment(segment1)
assert.NoError(t, err)
segment2, err := newSegment(collection, UniqueID(2), defaultPartitionID+1, defaultCollectionID, "channel2", segmentTypeGrowing)
segment2, err := newSegment(collection, UniqueID(2), defaultPartitionID+1, defaultCollectionID, "channel2", segmentTypeGrowing, pool)
assert.NoError(t, err)
err = replica.setSegment(segment2)
assert.NoError(t, err)
segment3, err := newSegment(collection, UniqueID(3), defaultPartitionID+1, defaultCollectionID, "channel2", segmentTypeGrowing)
segment3, err := newSegment(collection, UniqueID(3), defaultPartitionID+1, defaultCollectionID, "channel2", segmentTypeGrowing, pool)
assert.NoError(t, err)
err = replica.setSegment(segment3)
assert.NoError(t, err)
segment4, err := newSegment(collection, UniqueID(4), defaultPartitionID, defaultCollectionID, "channel1", segmentTypeSealed)
segment4, err := newSegment(collection, UniqueID(4), defaultPartitionID, defaultCollectionID, "channel1", segmentTypeSealed, pool)
assert.NoError(t, err)
err = replica.setSegment(segment4)
assert.NoError(t, err)

View File

@ -24,8 +24,10 @@ import (
"math"
"math/rand"
"path"
"runtime"
"strconv"
"github.com/milvus-io/milvus/internal/util/concurrency"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/indexcgowrapper"
"github.com/milvus-io/milvus/internal/util/typeutil"
@ -1210,12 +1212,18 @@ func genSealedSegment(schema *schemapb.CollectionSchema,
vChannel Channel,
msgLength int) (*Segment, error) {
col := newCollection(collectionID, schema)
pool, err := concurrency.NewPool(runtime.GOMAXPROCS(0))
if err != nil {
return nil, err
}
seg, err := newSegment(col,
segmentID,
partitionID,
collectionID,
vChannel,
segmentTypeSealed)
segmentTypeSealed,
pool)
if err != nil {
return nil, err
}
@ -1252,20 +1260,28 @@ func genSimpleSealedSegment(msgLength int) (*Segment, error) {
}
func genSimpleReplica() (ReplicaInterface, error) {
r := newCollectionReplica()
pool, err := concurrency.NewPool(runtime.GOMAXPROCS(0))
if err != nil {
return nil, err
}
r := newCollectionReplica(pool)
schema := genTestCollectionSchema()
r.addCollection(defaultCollectionID, schema)
err := r.addPartition(defaultCollectionID, defaultPartitionID)
err = r.addPartition(defaultCollectionID, defaultPartitionID)
return r, err
}
func genSimpleSegmentLoaderWithMqFactory(metaReplica ReplicaInterface, factory msgstream.Factory) (*segmentLoader, error) {
pool, err := concurrency.NewPool(runtime.GOMAXPROCS(1))
if err != nil {
return nil, err
}
kv, err := genEtcdKV()
if err != nil {
return nil, err
}
cm := storage.NewLocalChunkManager(storage.RootPath(defaultLocalStorage))
return newSegmentLoader(metaReplica, kv, cm, factory), nil
return newSegmentLoader(metaReplica, kv, cm, factory, pool), nil
}
func genSimpleReplicaWithSealSegment(ctx context.Context) (ReplicaInterface, error) {

View File

@ -32,6 +32,7 @@ import (
"os"
"path"
"path/filepath"
"runtime"
"strconv"
"sync"
"sync/atomic"
@ -40,6 +41,7 @@ import (
"unsafe"
"github.com/golang/protobuf/proto"
"github.com/panjf2000/ants/v2"
"go.etcd.io/etcd/api/v3/mvccpb"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/zap"
@ -53,6 +55,7 @@ import (
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util"
"github.com/milvus-io/milvus/internal/util/concurrency"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/sessionutil"
@ -114,6 +117,9 @@ type QueryNode struct {
ShardClusterService *ShardClusterService
//shard query service, handles shard-level query & search
queryShardService *queryShardService
// cgoPool is the worker pool to control concurrency of cgo call
cgoPool *concurrency.Pool
}
// NewQueryNode will return a QueryNode with abnormal state.
@ -226,13 +232,33 @@ func (node *QueryNode) Init() error {
node.etcdKV = etcdkv.NewEtcdKV(node.etcdCli, Params.EtcdCfg.MetaRootPath)
log.Info("queryNode try to connect etcd success", zap.Any("MetaRootPath", Params.EtcdCfg.MetaRootPath))
node.metaReplica = newCollectionReplica()
cpuNum := runtime.GOMAXPROCS(0)
node.cgoPool, err = concurrency.NewPool(cpuNum, ants.WithPreAlloc(true))
if err != nil {
log.Error("QueryNode init cgo pool failed", zap.Error(err))
initError = err
return
}
sig := make(chan struct{})
for i := 0; i < cpuNum; i++ {
node.cgoPool.Submit(func() (interface{}, error) {
runtime.LockOSThread()
<-sig
return nil, nil
})
}
close(sig)
node.metaReplica = newCollectionReplica(node.cgoPool)
node.loader = newSegmentLoader(
node.metaReplica,
node.etcdKV,
node.vectorStorage,
node.factory)
node.factory,
node.cgoPool)
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.metaReplica, node.tSafeReplica, node.factory)

View File

@ -21,6 +21,7 @@ import (
"io/ioutil"
"net/url"
"os"
"runtime"
"sync"
"testing"
"time"
@ -30,6 +31,7 @@ import (
"github.com/stretchr/testify/require"
"go.etcd.io/etcd/server/v3/embed"
"github.com/milvus-io/milvus/internal/util/concurrency"
"github.com/milvus-io/milvus/internal/util/dependency"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
@ -92,7 +94,13 @@ func newQueryNodeMock() *QueryNode {
factory := newMessageStreamFactory()
svr := NewQueryNode(ctx, factory)
tsReplica := newTSafeReplica()
replica := newCollectionReplica()
pool, err := concurrency.NewPool(runtime.GOMAXPROCS(0))
if err != nil {
panic(err)
}
replica := newCollectionReplica(pool)
svr.metaReplica = replica
svr.dataSyncService = newDataSyncService(ctx, svr.metaReplica, tsReplica, factory)
svr.vectorStorage, err = factory.NewVectorStorageChunkManager(ctx)
@ -103,7 +111,7 @@ func newQueryNodeMock() *QueryNode {
if err != nil {
panic(err)
}
svr.loader = newSegmentLoader(svr.metaReplica, etcdKV, svr.vectorStorage, factory)
svr.loader = newSegmentLoader(svr.metaReplica, etcdKV, svr.vectorStorage, factory, pool)
svr.etcdKV = etcdKV
return svr

View File

@ -32,6 +32,7 @@ import (
"sync"
"unsafe"
"github.com/milvus-io/milvus/internal/util/concurrency"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/metrics"
@ -93,6 +94,8 @@ type Segment struct {
indexedFieldInfos map[UniqueID]*IndexedFieldInfo
pkFilter *bloom.BloomFilter // bloom filter of pk inside a segment
pool *concurrency.Pool
}
// ID returns the identity number.
@ -161,7 +164,7 @@ func (s *Segment) hasLoadIndexForIndexedField(fieldID int64) bool {
return false
}
func newSegment(collection *Collection, segmentID UniqueID, partitionID UniqueID, collectionID UniqueID, vChannelID Channel, segType segmentType) (*Segment, error) {
func newSegment(collection *Collection, segmentID UniqueID, partitionID UniqueID, collectionID UniqueID, vChannelID Channel, segType segmentType, pool *concurrency.Pool) (*Segment, error) {
/*
CSegmentInterface
NewSegment(CCollection collection, uint64_t segment_id, SegmentType seg_type);
@ -169,9 +172,15 @@ func newSegment(collection *Collection, segmentID UniqueID, partitionID UniqueID
var segmentPtr C.CSegmentInterface
switch segType {
case segmentTypeSealed:
segmentPtr = C.NewSegment(collection.collectionPtr, C.Sealed, C.int64_t(segmentID))
pool.Submit(func() (interface{}, error) {
segmentPtr = C.NewSegment(collection.collectionPtr, C.Sealed, C.int64_t(segmentID))
return nil, nil
}).Await()
case segmentTypeGrowing:
segmentPtr = C.NewSegment(collection.collectionPtr, C.Growing, C.int64_t(segmentID))
pool.Submit(func() (interface{}, error) {
segmentPtr = C.NewSegment(collection.collectionPtr, C.Growing, C.int64_t(segmentID))
return nil, nil
}).Await()
default:
err := fmt.Errorf("illegal segment type %d when create segment %d", segType, segmentID)
log.Error("create new segment error",
@ -199,6 +208,7 @@ func newSegment(collection *Collection, segmentID UniqueID, partitionID UniqueID
indexedFieldInfos: make(map[UniqueID]*IndexedFieldInfo),
pkFilter: bloom.NewWithEstimates(bloomFilterSize, maxBloomFalsePositive),
pool: pool,
}
return segment, nil
@ -214,7 +224,10 @@ func deleteSegment(segment *Segment) {
}
cPtr := segment.segmentPtr
C.DeleteSegment(cPtr)
segment.pool.Submit(func() (interface{}, error) {
C.DeleteSegment(cPtr)
return nil, nil
}).Await()
segment.segmentPtr = nil
log.Info("delete segment from memory",
@ -232,7 +245,12 @@ func (s *Segment) getRealCount() int64 {
if s.segmentPtr == nil {
return -1
}
var rowCount = C.GetRealCount(s.segmentPtr)
var rowCount C.int64_t
s.pool.Submit(func() (interface{}, error) {
rowCount = C.GetRealCount(s.segmentPtr)
return nil, nil
}).Await()
return int64(rowCount)
}
@ -244,7 +262,12 @@ func (s *Segment) getRowCount() int64 {
if s.segmentPtr == nil {
return -1
}
var rowCount = C.GetRowCount(s.segmentPtr)
var rowCount C.int64_t
s.pool.Submit(func() (interface{}, error) {
rowCount = C.GetRowCount(s.segmentPtr)
return nil, nil
}).Await()
return int64(rowCount)
}
@ -256,7 +279,13 @@ func (s *Segment) getDeletedCount() int64 {
if s.segmentPtr == nil {
return -1
}
var deletedCount = C.GetDeletedCount(s.segmentPtr)
var deletedCount C.int64_t
s.pool.Submit(func() (interface{}, error) {
deletedCount = C.GetDeletedCount(s.segmentPtr)
return nil, nil
}).Await()
return int64(deletedCount)
}
@ -268,7 +297,11 @@ func (s *Segment) getMemSize() int64 {
if s.segmentPtr == nil {
return -1
}
var memoryUsageInBytes = C.GetMemoryUsageInBytes(s.segmentPtr)
var memoryUsageInBytes C.int64_t
s.pool.Submit(func() (interface{}, error) {
memoryUsageInBytes = C.GetMemoryUsageInBytes(s.segmentPtr)
return nil, nil
}).Await()
return int64(memoryUsageInBytes)
}
@ -298,10 +331,15 @@ func (s *Segment) search(searchReq *searchRequest) (*SearchResult, error) {
zap.Int64("segmentID", s.segmentID),
zap.String("segmentType", s.segmentType.String()),
zap.Bool("loadIndex", loadIndex))
tr := timerecord.NewTimeRecorder("cgoSearch")
status := C.Search(s.segmentPtr, searchReq.plan.cSearchPlan, searchReq.cPlaceholderGroup,
C.uint64_t(searchReq.timestamp), &searchResult.cSearchResult, C.int64_t(s.segmentID))
metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.SearchLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
var status C.CStatus
s.pool.Submit(func() (interface{}, error) {
tr := timerecord.NewTimeRecorder("cgoSearch")
status = C.Search(s.segmentPtr, searchReq.plan.cSearchPlan, searchReq.cPlaceholderGroup,
C.uint64_t(searchReq.timestamp), &searchResult.cSearchResult, C.int64_t(s.segmentID))
metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.SearchLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
return nil, nil
}).Await()
if err := HandleCStatus(&status, "Search failed"); err != nil {
return nil, err
}
@ -320,13 +358,20 @@ func (s *Segment) retrieve(plan *RetrievePlan) (*segcorepb.RetrieveResults, erro
var retrieveResult RetrieveResult
ts := C.uint64_t(plan.Timestamp)
tr := timerecord.NewTimeRecorder("cgoRetrieve")
status := C.Retrieve(s.segmentPtr, plan.cRetrievePlan, ts, &retrieveResult.cRetrieveResult)
metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
log.Debug("do retrieve on segment",
zap.Int64("msgID", plan.msgID),
zap.Int64("segmentID", s.segmentID), zap.String("segmentType", s.segmentType.String()))
var status C.CStatus
s.pool.Submit(func() (interface{}, error) {
tr := timerecord.NewTimeRecorder("cgoRetrieve")
status = C.Retrieve(s.segmentPtr, plan.cRetrievePlan, ts, &retrieveResult.cRetrieveResult)
metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
log.Debug("do retrieve on segment",
zap.Int64("msgID", plan.msgID),
zap.Int64("segmentID", s.segmentID), zap.String("segmentType", s.segmentType.String()))
return nil, nil
}).Await()
if err := HandleCStatus(&status, "Retrieve failed"); err != nil {
return nil, err
}
@ -567,8 +612,12 @@ func (s *Segment) segmentPreInsert(numOfRecords int) (int64, error) {
return 0, nil
}
var offset int64
var status C.CStatus
cOffset := (*C.int64_t)(&offset)
status := C.PreInsert(s.segmentPtr, C.int64_t(int64(numOfRecords)), cOffset)
s.pool.Submit(func() (interface{}, error) {
status = C.PreInsert(s.segmentPtr, C.int64_t(int64(numOfRecords)), cOffset)
return nil, nil
}).Await()
if err := HandleCStatus(&status, "PreInsert failed"); err != nil {
return 0, err
}
@ -580,7 +629,13 @@ func (s *Segment) segmentPreDelete(numOfRecords int) int64 {
long int
PreDelete(CSegmentInterface c_segment, long int size);
*/
var offset = C.PreDelete(s.segmentPtr, C.int64_t(int64(numOfRecords)))
var offset C.int64_t
s.pool.Submit(func() (interface{}, error) {
offset = C.PreDelete(s.segmentPtr, C.int64_t(int64(numOfRecords)))
return nil, nil
}).Await()
return int64(offset)
}
@ -605,13 +660,20 @@ func (s *Segment) segmentInsert(offset int64, entityIDs []UniqueID, timestamps [
var cEntityIdsPtr = (*C.int64_t)(&(entityIDs)[0])
var cTimestampsPtr = (*C.uint64_t)(&(timestamps)[0])
status := C.Insert(s.segmentPtr,
cOffset,
cNumOfRows,
cEntityIdsPtr,
cTimestampsPtr,
(*C.uint8_t)(unsafe.Pointer(&insertRecordBlob[0])),
(C.uint64_t)(len(insertRecordBlob)))
var status C.CStatus
s.pool.Submit(func() (interface{}, error) {
status = C.Insert(s.segmentPtr,
cOffset,
cNumOfRows,
cEntityIdsPtr,
cTimestampsPtr,
(*C.uint8_t)(unsafe.Pointer(&insertRecordBlob[0])),
(C.uint64_t)(len(insertRecordBlob)))
return nil, nil
}).Await()
if err := HandleCStatus(&status, "Insert failed"); err != nil {
return err
}
@ -677,7 +739,13 @@ func (s *Segment) segmentDelete(offset int64, entityIDs []primaryKey, timestamps
return fmt.Errorf("failed to marshal ids: %s", err)
}
status := C.Delete(s.segmentPtr, cOffset, cSize, (*C.uint8_t)(unsafe.Pointer(&dataBlob[0])), (C.uint64_t)(len(dataBlob)), cTimestampsPtr)
var status C.CStatus
s.pool.Submit(func() (interface{}, error) {
status = C.Delete(s.segmentPtr, cOffset, cSize, (*C.uint8_t)(unsafe.Pointer(&dataBlob[0])), (C.uint64_t)(len(dataBlob)), cTimestampsPtr)
return nil, nil
}).Await()
if err := HandleCStatus(&status, "Delete failed"); err != nil {
return err
}
@ -711,7 +779,12 @@ func (s *Segment) segmentLoadFieldData(fieldID int64, rowCount int64, data *sche
row_count: C.int64_t(rowCount),
}
status := C.LoadFieldData(s.segmentPtr, loadInfo)
var status C.CStatus
s.pool.Submit(func() (interface{}, error) {
status = C.LoadFieldData(s.segmentPtr, loadInfo)
return nil, nil
}).Await()
if err := HandleCStatus(&status, "LoadFieldData failed"); err != nil {
return err
}
@ -774,7 +847,12 @@ func (s *Segment) segmentLoadDeletedRecord(primaryKeys []primaryKey, timestamps
CStatus
LoadDeletedRecord(CSegmentInterface c_segment, CLoadDeletedRecordInfo deleted_record_info)
*/
status := C.LoadDeletedRecord(s.segmentPtr, loadInfo)
var status C.CStatus
s.pool.Submit(func() (interface{}, error) {
status = C.LoadDeletedRecord(s.segmentPtr, loadInfo)
return nil, nil
}).Await()
if err := HandleCStatus(&status, "LoadDeletedRecord failed"); err != nil {
return err
}
@ -807,7 +885,12 @@ func (s *Segment) segmentLoadIndexData(bytesIndex [][]byte, indexInfo *querypb.F
return errors.New(errMsg)
}
status := C.UpdateSealedSegmentIndex(s.segmentPtr, loadIndexInfo.cLoadIndexInfo)
var status C.CStatus
s.pool.Submit(func() (interface{}, error) {
status = C.UpdateSealedSegmentIndex(s.segmentPtr, loadIndexInfo.cLoadIndexInfo)
return nil, nil
}).Await()
if err := HandleCStatus(&status, "UpdateSealedSegmentIndex failed"); err != nil {
return err
}

View File

@ -61,6 +61,8 @@ type segmentLoader struct {
ioPool *concurrency.Pool
cpuPool *concurrency.Pool
// cgoPool for all cgo invocation
cgoPool *concurrency.Pool
factory msgstream.Factory
}
@ -141,7 +143,7 @@ func (loader *segmentLoader) LoadSegment(req *querypb.LoadSegmentsRequest, segme
return err
}
segment, err := newSegment(collection, segmentID, partitionID, collectionID, vChannelID, segmentType)
segment, err := newSegment(collection, segmentID, partitionID, collectionID, vChannelID, segmentType, loader.cgoPool)
if err != nil {
log.Error("load segment failed when create new segment",
zap.Int64("collectionID", collectionID),
@ -815,7 +817,8 @@ func newSegmentLoader(
metaReplica ReplicaInterface,
etcdKV *etcdkv.EtcdKV,
cm storage.ChunkManager,
factory msgstream.Factory) *segmentLoader {
factory msgstream.Factory,
pool *concurrency.Pool) *segmentLoader {
cpuNum := runtime.GOMAXPROCS(0)
// This error is not nil only if the options of creating pool is invalid
@ -850,6 +853,7 @@ func newSegmentLoader(
// init them later
ioPool: ioPool,
cpuPool: cpuPool,
cgoPool: pool,
factory: factory,
}

View File

@ -30,9 +30,11 @@ import (
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/concurrency"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func TestSegmentLoader_loadSegment(t *testing.T) {
@ -128,6 +130,9 @@ func TestSegmentLoader_loadSegmentFieldsData(t *testing.T) {
loader := node.loader
assert.NotNil(t, loader)
pool, err := concurrency.NewPool(runtime.GOMAXPROCS(0))
require.NoError(t, err)
var fieldPk *schemapb.FieldSchema
switch pkType {
case schemapb.DataType_Int64:
@ -175,7 +180,8 @@ func TestSegmentLoader_loadSegmentFieldsData(t *testing.T) {
defaultPartitionID,
defaultCollectionID,
defaultDMLChannel,
segmentTypeSealed)
segmentTypeSealed,
pool)
assert.Nil(t, err)
binlog, _, err := saveBinLog(ctx, defaultCollectionID, defaultPartitionID, defaultSegmentID, defaultMsgLength, schema)
@ -328,7 +334,7 @@ func TestSegmentLoader_testLoadGrowing(t *testing.T) {
collection, err := node.metaReplica.getCollectionByID(defaultCollectionID)
assert.NoError(t, err)
segment, err := newSegment(collection, defaultSegmentID+1, defaultPartitionID, defaultCollectionID, defaultDMLChannel, segmentTypeGrowing)
segment, err := newSegment(collection, defaultSegmentID+1, defaultPartitionID, defaultCollectionID, defaultDMLChannel, segmentTypeGrowing, loader.cgoPool)
assert.Nil(t, err)
insertData, err := genInsertData(defaultMsgLength, collection.schema)
@ -357,7 +363,7 @@ func TestSegmentLoader_testLoadGrowing(t *testing.T) {
collection, err := node.metaReplica.getCollectionByID(defaultCollectionID)
assert.NoError(t, err)
segment, err := newSegment(collection, defaultSegmentID+1, defaultPartitionID, defaultCollectionID, defaultDMLChannel, segmentTypeGrowing)
segment, err := newSegment(collection, defaultSegmentID+1, defaultPartitionID, defaultCollectionID, defaultDMLChannel, segmentTypeGrowing, node.loader.cgoPool)
assert.Nil(t, err)
insertData, err := genInsertData(defaultMsgLength, collection.schema)

View File

@ -21,6 +21,7 @@ import (
"fmt"
"log"
"math"
"runtime"
"testing"
"github.com/milvus-io/milvus/internal/proto/commonpb"
@ -29,6 +30,7 @@ import (
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/proto/datapb"
@ -36,11 +38,15 @@ import (
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/util/concurrency"
"github.com/milvus-io/milvus/internal/util/funcutil"
)
//-------------------------------------------------------------------------------------- constructor and destructor
func TestSegment_newSegment(t *testing.T) {
pool, err := concurrency.NewPool(runtime.GOMAXPROCS(0))
require.NoError(t, err)
collectionID := UniqueID(0)
schema := genTestCollectionSchema()
collectionMeta := genCollectionMeta(collectionID, schema)
@ -49,7 +55,7 @@ func TestSegment_newSegment(t *testing.T) {
assert.Equal(t, collection.ID(), collectionID)
segmentID := UniqueID(0)
segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing)
segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, pool)
assert.Nil(t, err)
assert.Equal(t, segmentID, segment.segmentID)
deleteSegment(segment)
@ -59,12 +65,15 @@ func TestSegment_newSegment(t *testing.T) {
_, err = newSegment(collection,
defaultSegmentID,
defaultPartitionID,
collectionID, "", 100)
collectionID, "", 100, pool)
assert.Error(t, err)
})
}
func TestSegment_deleteSegment(t *testing.T) {
pool, err := concurrency.NewPool(runtime.GOMAXPROCS(0))
require.NoError(t, err)
collectionID := UniqueID(0)
schema := genTestCollectionSchema()
collectionMeta := genCollectionMeta(collectionID, schema)
@ -73,7 +82,7 @@ func TestSegment_deleteSegment(t *testing.T) {
assert.Equal(t, collection.ID(), collectionID)
segmentID := UniqueID(0)
segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing)
segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, pool)
assert.Equal(t, segmentID, segment.segmentID)
assert.Nil(t, err)
@ -90,6 +99,9 @@ func TestSegment_deleteSegment(t *testing.T) {
//-------------------------------------------------------------------------------------- stats functions
func TestSegment_getRowCount(t *testing.T) {
pool, err := concurrency.NewPool(runtime.GOMAXPROCS(0))
require.NoError(t, err)
collectionID := UniqueID(0)
schema := genTestCollectionSchema()
@ -97,7 +109,7 @@ func TestSegment_getRowCount(t *testing.T) {
assert.Equal(t, collection.ID(), collectionID)
segmentID := UniqueID(0)
segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing)
segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, pool)
assert.Equal(t, segmentID, segment.segmentID)
assert.Nil(t, err)
@ -132,6 +144,9 @@ func TestSegment_getRowCount(t *testing.T) {
}
func TestSegment_retrieve(t *testing.T) {
pool, err := concurrency.NewPool(runtime.GOMAXPROCS(0))
require.NoError(t, err)
collectionID := UniqueID(0)
schema := genTestCollectionSchema()
@ -139,7 +154,7 @@ func TestSegment_retrieve(t *testing.T) {
assert.Equal(t, collection.ID(), collectionID)
segmentID := UniqueID(0)
segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing)
segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, pool)
assert.Equal(t, segmentID, segment.segmentID)
assert.Nil(t, err)
@ -211,6 +226,9 @@ func TestSegment_retrieve(t *testing.T) {
}
func TestSegment_getDeletedCount(t *testing.T) {
pool, err := concurrency.NewPool(runtime.GOMAXPROCS(0))
require.NoError(t, err)
collectionID := UniqueID(0)
schema := genTestCollectionSchema()
@ -218,7 +236,7 @@ func TestSegment_getDeletedCount(t *testing.T) {
assert.Equal(t, collection.ID(), collectionID)
segmentID := UniqueID(0)
segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing)
segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, pool)
assert.Equal(t, segmentID, segment.segmentID)
assert.Nil(t, err)
@ -260,6 +278,9 @@ func TestSegment_getDeletedCount(t *testing.T) {
}
func TestSegment_getMemSize(t *testing.T) {
pool, err := concurrency.NewPool(runtime.GOMAXPROCS(0))
require.NoError(t, err)
collectionID := UniqueID(0)
schema := genTestCollectionSchema()
@ -267,7 +288,7 @@ func TestSegment_getMemSize(t *testing.T) {
assert.Equal(t, collection.ID(), collectionID)
segmentID := UniqueID(0)
segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing)
segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, pool)
assert.Equal(t, segmentID, segment.segmentID)
assert.Nil(t, err)
@ -296,13 +317,16 @@ func TestSegment_getMemSize(t *testing.T) {
//-------------------------------------------------------------------------------------- dm & search functions
func TestSegment_segmentInsert(t *testing.T) {
pool, err := concurrency.NewPool(runtime.GOMAXPROCS(0))
require.NoError(t, err)
collectionID := UniqueID(0)
schema := genTestCollectionSchema()
collection := newCollection(collectionID, schema)
assert.Equal(t, collection.ID(), collectionID)
segmentID := UniqueID(0)
segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing)
segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, pool)
assert.Equal(t, segmentID, segment.segmentID)
assert.Nil(t, err)
@ -340,13 +364,16 @@ func TestSegment_segmentInsert(t *testing.T) {
}
func TestSegment_segmentDelete(t *testing.T) {
pool, err := concurrency.NewPool(runtime.GOMAXPROCS(0))
require.NoError(t, err)
collectionID := UniqueID(0)
schema := genTestCollectionSchema()
collection := newCollection(collectionID, schema)
assert.Equal(t, collection.ID(), collectionID)
segmentID := UniqueID(0)
segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing)
segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, pool)
assert.Equal(t, segmentID, segment.segmentID)
assert.Nil(t, err)
@ -436,13 +463,16 @@ func TestSegment_segmentSearch(t *testing.T) {
//-------------------------------------------------------------------------------------- preDm functions
func TestSegment_segmentPreInsert(t *testing.T) {
pool, err := concurrency.NewPool(runtime.GOMAXPROCS(0))
require.NoError(t, err)
collectionID := UniqueID(0)
schema := genTestCollectionSchema()
collection := newCollection(collectionID, schema)
assert.Equal(t, collection.ID(), collectionID)
segmentID := UniqueID(0)
segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing)
segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, pool)
assert.Equal(t, segmentID, segment.segmentID)
assert.Nil(t, err)
@ -455,13 +485,16 @@ func TestSegment_segmentPreInsert(t *testing.T) {
}
func TestSegment_segmentPreDelete(t *testing.T) {
pool, err := concurrency.NewPool(runtime.GOMAXPROCS(0))
require.NoError(t, err)
collectionID := UniqueID(0)
schema := genTestCollectionSchema()
collection := newCollection(collectionID, schema)
assert.Equal(t, collection.ID(), collectionID)
segmentID := UniqueID(0)
segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing)
segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, pool)
assert.Equal(t, segmentID, segment.segmentID)
assert.Nil(t, err)
@ -487,6 +520,9 @@ func TestSegment_segmentPreDelete(t *testing.T) {
}
func TestSegment_segmentLoadDeletedRecord(t *testing.T) {
pool, err := concurrency.NewPool(runtime.GOMAXPROCS(0))
require.NoError(t, err)
fieldParam := constFieldParam{
id: 100,
dataType: schemapb.DataType_Int64,
@ -505,7 +541,8 @@ func TestSegment_segmentLoadDeletedRecord(t *testing.T) {
defaultPartitionID,
defaultCollectionID,
defaultDMLChannel,
segmentTypeSealed)
segmentTypeSealed,
pool)
assert.Nil(t, err)
ids := []int64{1, 2, 3}
pks := make([]primaryKey, 0)
@ -574,6 +611,9 @@ func TestSegment_indexInfo(t *testing.T) {
}
func TestSegment_BasicMetrics(t *testing.T) {
pool, err := concurrency.NewPool(runtime.GOMAXPROCS(0))
require.NoError(t, err)
schema := genTestCollectionSchema()
collection := newCollection(defaultCollectionID, schema)
segment, err := newSegment(collection,
@ -581,7 +621,8 @@ func TestSegment_BasicMetrics(t *testing.T) {
defaultPartitionID,
defaultCollectionID,
defaultDMLChannel,
segmentTypeSealed)
segmentTypeSealed,
pool)
assert.Nil(t, err)
t.Run("test id binlog row size", func(t *testing.T) {
@ -620,6 +661,8 @@ func TestSegment_BasicMetrics(t *testing.T) {
func TestSegment_fillIndexedFieldsData(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
pool, err := concurrency.NewPool(runtime.GOMAXPROCS(0))
require.NoError(t, err)
schema := genTestCollectionSchema()
collection := newCollection(defaultCollectionID, schema)
@ -628,7 +671,8 @@ func TestSegment_fillIndexedFieldsData(t *testing.T) {
defaultPartitionID,
defaultCollectionID,
defaultDMLChannel,
segmentTypeSealed)
segmentTypeSealed,
pool)
assert.Nil(t, err)
vecCM, err := genVectorChunkManager(ctx, collection)