Fix segmentPtr concurrent visit and destroy (#19560)

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/19528/head
congqixia 2022-09-29 21:54:55 +08:00 committed by GitHub
parent ccf30d358c
commit d79f88c5f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 143 additions and 81 deletions

View File

@ -166,7 +166,7 @@ func (dNode *deleteNode) delete(deleteData *deleteData, segmentID UniqueID, wg *
return fmt.Errorf("getSegmentByID failed, err = %s", err)
}
if targetSegment.segmentType != segmentTypeSealed {
if targetSegment.getType() != segmentTypeSealed {
return fmt.Errorf("unexpected segmentType when delete, segmentID = %d, segmentType = %s", segmentID, targetSegment.segmentType.String())
}

View File

@ -175,6 +175,8 @@ func (iNode *insertNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
// 2. do preInsert
for segmentID := range iData.insertRecords {
log := log.With(
zap.Int64("segmentID", segmentID))
var targetSegment, err = iNode.metaReplica.getSegmentByID(segmentID, segmentTypeGrowing)
if err != nil {
// should not happen, segment should be created before
@ -187,6 +189,10 @@ func (iNode *insertNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
if targetSegment != nil {
offset, err := targetSegment.segmentPreInsert(numOfRecords)
if err != nil {
if errors.Is(err, errSegmentUnhealthy) {
log.Debug("segment removed before preInsert")
continue
}
// error occurs when cgo function `PreInsert` failed
err = fmt.Errorf("segmentPreInsert failed, segmentID = %d, err = %s", segmentID, err)
log.Error(err.Error(), zap.Int64("collectionID", iNode.collectionID), zap.String("channel", iNode.channel))
@ -256,7 +262,9 @@ func (iNode *insertNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
panic(err)
}
offset := segment.segmentPreDelete(len(pks))
delData.deleteOffset[segmentID] = offset
if offset >= 0 {
delData.deleteOffset[segmentID] = offset
}
}
// 3. do delete
@ -364,6 +372,9 @@ func filterSegmentsByPKs(pks []primaryKey, timestamps []Timestamp, segment *Segm
// insert would execute insert operations for specific growing segment
func (iNode *insertNode) insert(iData *insertData, segmentID UniqueID) error {
log := log.With(
zap.Int64("collectionID", iNode.collectionID),
zap.Int64("segmentID", segmentID))
var targetSegment, err = iNode.metaReplica.getSegmentByID(segmentID, segmentTypeGrowing)
if err != nil {
return fmt.Errorf("getSegmentByID failed, err = %s", err)
@ -379,15 +390,22 @@ func (iNode *insertNode) insert(iData *insertData, segmentID UniqueID) error {
err = targetSegment.segmentInsert(offsets, ids, timestamps, insertRecord)
if err != nil {
if errors.Is(err, errSegmentUnhealthy) {
log.Debug("segment removed before insert")
return nil
}
return fmt.Errorf("segmentInsert failed, segmentID = %d, err = %s", segmentID, err)
}
log.Debug("Do insert done", zap.Int("len", len(iData.insertIDs[segmentID])), zap.Int64("collectionID", targetSegment.collectionID), zap.Int64("segmentID", segmentID))
log.Debug("Do insert done", zap.Int("len", len(iData.insertIDs[segmentID])))
return nil
}
// delete would execute delete operations for specific growing segment
func (iNode *insertNode) delete(deleteData *deleteData, segmentID UniqueID) error {
log := log.With(
zap.Int64("collectionID", iNode.collectionID),
zap.Int64("segmentID", segmentID))
targetSegment, err := iNode.metaReplica.getSegmentByID(segmentID, segmentTypeGrowing)
if err != nil {
if errors.Is(err, ErrSegmentNotFound) {
@ -400,7 +418,7 @@ func (iNode *insertNode) delete(deleteData *deleteData, segmentID UniqueID) erro
return fmt.Errorf("getSegmentByID failed, err = %s", err)
}
if targetSegment.segmentType != segmentTypeGrowing {
if targetSegment.getType() != segmentTypeGrowing {
return fmt.Errorf("unexpected segmentType when delete, segmentType = %s", targetSegment.segmentType.String())
}
@ -410,10 +428,14 @@ func (iNode *insertNode) delete(deleteData *deleteData, segmentID UniqueID) erro
err = targetSegment.segmentDelete(offset, ids, timestamps)
if err != nil {
if errors.Is(err, errSegmentUnhealthy) {
log.Debug("segment removed before delete")
return nil
}
return fmt.Errorf("segmentDelete failed, err = %s", err)
}
log.Debug("Do delete done", zap.Int("len", len(deleteData.deleteIDs[segmentID])), zap.Int64("segmentID", segmentID))
log.Debug("Do delete done", zap.Int("len", len(deleteData.deleteIDs[segmentID])))
return nil
}

View File

@ -887,7 +887,7 @@ func (replica *metaReplica) getSegmentInfo(segment *Segment) *querypb.SegmentInf
IndexName: indexName,
IndexID: indexID,
DmChannel: segment.vChannelID,
SegmentState: segment.segmentType,
SegmentState: segment.getType(),
IndexInfos: indexInfos,
NodeIds: []UniqueID{Params.QueryNodeCfg.GetNodeID()},
}

View File

@ -226,7 +226,7 @@ func TestStreaming_search(t *testing.T) {
seg, err := streaming.getSegmentByID(defaultSegmentID, segmentTypeGrowing)
assert.NoError(t, err)
seg.segmentPtr = nil
seg.setUnhealthy()
_, _, _, err = searchStreaming(context.TODO(), streaming, searchReq,
defaultCollectionID,

View File

@ -36,12 +36,14 @@ import (
"github.com/milvus-io/milvus/internal/util/concurrency"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/util/timerecord"
"github.com/bits-and-blooms/bloom/v3"
"github.com/golang/protobuf/proto"
"go.uber.org/atomic"
"go.uber.org/zap"
"github.com/milvus-io/milvus/api/commonpb"
@ -66,6 +68,8 @@ const (
maxBloomFalsePositive float64 = 0.005
)
var errSegmentUnhealthy = errors.New("segment unhealthy")
// IndexedFieldInfo contains binlog info of vector field
type IndexedFieldInfo struct {
fieldBinlog *datapb.FieldBinlog
@ -74,6 +78,7 @@ type IndexedFieldInfo struct {
// Segment is a wrapper of the underlying C-structure segment.
type Segment struct {
mut sync.RWMutex // protects segmentPtr
segmentPtr C.CSegmentInterface
segmentID UniqueID
@ -85,16 +90,13 @@ type Segment struct {
lastMemSize int64
lastRowCount int64
rmMutex sync.RWMutex // guards recentlyModified
recentlyModified bool
typeMu sync.RWMutex // guards segmentType
segmentType segmentType
recentlyModified *atomic.Bool
segmentType *atomic.Int32
destroyed *atomic.Bool
idBinlogRowSizes []int64
indexedFieldMutex sync.RWMutex // guards indexedFieldInfos
indexedFieldInfos map[UniqueID]*IndexedFieldInfo
indexedFieldInfos *typeutil.ConcurrentMap[UniqueID, *IndexedFieldInfo]
pkFilter *bloom.BloomFilter // bloom filter of pk inside a segment
@ -115,56 +117,52 @@ func (s *Segment) getIDBinlogRowSizes() []int64 {
}
func (s *Segment) setRecentlyModified(modify bool) {
s.rmMutex.Lock()
defer s.rmMutex.Unlock()
s.recentlyModified = modify
s.recentlyModified.Store(modify)
}
func (s *Segment) getRecentlyModified() bool {
s.rmMutex.RLock()
defer s.rmMutex.RUnlock()
return s.recentlyModified
return s.recentlyModified.Load()
}
func (s *Segment) setType(segType segmentType) {
s.typeMu.Lock()
defer s.typeMu.Unlock()
s.segmentType = segType
s.segmentType.Store(int32(segType))
}
func (s *Segment) getType() segmentType {
s.typeMu.RLock()
defer s.typeMu.RUnlock()
return s.segmentType
return commonpb.SegmentState(s.segmentType.Load())
}
func (s *Segment) setIndexedFieldInfo(fieldID UniqueID, info *IndexedFieldInfo) {
s.indexedFieldMutex.Lock()
defer s.indexedFieldMutex.Unlock()
s.indexedFieldInfos[fieldID] = info
s.indexedFieldInfos.Insert(fieldID, info)
}
func (s *Segment) getIndexedFieldInfo(fieldID UniqueID) (*IndexedFieldInfo, error) {
s.indexedFieldMutex.RLock()
defer s.indexedFieldMutex.RUnlock()
if info, ok := s.indexedFieldInfos[fieldID]; ok {
return &IndexedFieldInfo{
fieldBinlog: info.fieldBinlog,
indexInfo: info.indexInfo,
}, nil
info, ok := s.indexedFieldInfos.Get(fieldID)
if !ok {
return nil, fmt.Errorf("Invalid fieldID %d", fieldID)
}
return nil, fmt.Errorf("Invalid fieldID %d", fieldID)
return &IndexedFieldInfo{
fieldBinlog: info.fieldBinlog,
indexInfo: info.indexInfo,
}, nil
}
func (s *Segment) hasLoadIndexForIndexedField(fieldID int64) bool {
s.indexedFieldMutex.RLock()
defer s.indexedFieldMutex.RUnlock()
if fieldInfo, ok := s.indexedFieldInfos[fieldID]; ok {
return fieldInfo.indexInfo != nil && fieldInfo.indexInfo.EnableIndex
fieldInfo, ok := s.indexedFieldInfos.Get(fieldID)
if !ok {
return false
}
return fieldInfo.indexInfo != nil && fieldInfo.indexInfo.EnableIndex
}
return false
// healthy reports whether the segment has not been marked as destroyed,
// i.e. whether it is still safe to use `segmentPtr`.
// Callers shall hold mut.RLock before checking this flag.
func (s *Segment) healthy() bool {
return !s.destroyed.Load()
}
// setUnhealthy sets the destroyed flag, so every later healthy() call
// returns false.
func (s *Segment) setUnhealthy() {
s.destroyed.Store(true)
}
func newSegment(collection *Collection,
@ -210,13 +208,15 @@ func newSegment(collection *Collection,
var segment = &Segment{
segmentPtr: segmentPtr,
segmentType: segType,
segmentType: atomic.NewInt32(int32(segType)),
segmentID: segmentID,
partitionID: partitionID,
collectionID: collectionID,
version: version,
vChannelID: vChannelID,
indexedFieldInfos: make(map[UniqueID]*IndexedFieldInfo),
indexedFieldInfos: typeutil.NewConcurrentMap[int64, *IndexedFieldInfo](),
recentlyModified: atomic.NewBool(false),
destroyed: atomic.NewBool(false),
pkFilter: bloom.NewWithEstimates(bloomFilterSize, maxBloomFalsePositive),
pool: pool,
@ -230,16 +230,22 @@ func deleteSegment(segment *Segment) {
void
deleteSegment(CSegmentInterface segment);
*/
if segment.segmentPtr == nil {
var cPtr C.CSegmentInterface
// wait for all in-flight read ops to finish before invalidating the pointer
segment.mut.Lock()
segment.setUnhealthy()
cPtr = segment.segmentPtr
segment.segmentPtr = nil
segment.mut.Unlock()
if cPtr == nil {
return
}
cPtr := segment.segmentPtr
segment.pool.Submit(func() (interface{}, error) {
C.DeleteSegment(cPtr)
return nil, nil
}).Await()
segment.segmentPtr = nil
log.Info("delete segment from memory",
zap.Int64("collectionID", segment.collectionID),
@ -253,7 +259,9 @@ func (s *Segment) getRealCount() int64 {
int64_t
GetRealCount(CSegmentInterface c_segment);
*/
if s.segmentPtr == nil {
s.mut.RLock()
defer s.mut.RUnlock()
if !s.healthy() {
return -1
}
var rowCount C.int64_t
@ -270,7 +278,9 @@ func (s *Segment) getRowCount() int64 {
long int
getRowCount(CSegmentInterface c_segment);
*/
if s.segmentPtr == nil {
s.mut.RLock()
defer s.mut.RUnlock()
if !s.healthy() {
return -1
}
var rowCount C.int64_t
@ -287,7 +297,9 @@ func (s *Segment) getDeletedCount() int64 {
long int
getDeletedCount(CSegmentInterface c_segment);
*/
if s.segmentPtr == nil {
s.mut.RLock()
defer s.mut.RUnlock()
if !s.healthy() {
return -1
}
@ -305,7 +317,9 @@ func (s *Segment) getMemSize() int64 {
long int
GetMemoryUsageInBytes(CSegmentInterface c_segment);
*/
if s.segmentPtr == nil {
s.mut.RLock()
defer s.mut.RUnlock()
if !s.healthy() {
return -1
}
var memoryUsageInBytes C.int64_t
@ -327,8 +341,10 @@ func (s *Segment) search(searchReq *searchRequest) (*SearchResult, error) {
long int* result_ids,
float* result_distances);
*/
if s.segmentPtr == nil {
return nil, errors.New("null seg core pointer")
s.mut.RLock()
defer s.mut.RUnlock()
if !s.healthy() {
return nil, fmt.Errorf("%w(segmentID=%d)", errSegmentUnhealthy, s.segmentID)
}
if searchReq.plan == nil {
@ -363,8 +379,10 @@ func (s *Segment) search(searchReq *searchRequest) (*SearchResult, error) {
}
func (s *Segment) retrieve(plan *RetrievePlan) (*segcorepb.RetrieveResults, error) {
if s.segmentPtr == nil {
return nil, errors.New("null seg core pointer")
s.mut.RLock()
defer s.mut.RUnlock()
if !s.healthy() {
return nil, fmt.Errorf("%w(segmentID=%d)", errSegmentUnhealthy, s.segmentID)
}
var retrieveResult RetrieveResult
@ -621,9 +639,16 @@ func (s *Segment) segmentPreInsert(numOfRecords int) (int64, error) {
long int
PreInsert(CSegmentInterface c_segment, long int size);
*/
if s.segmentType != segmentTypeGrowing {
if s.getType() != segmentTypeGrowing {
return 0, nil
}
s.mut.RLock()
defer s.mut.RUnlock()
if !s.healthy() {
return -1, fmt.Errorf("%w(segmentID=%d)", errSegmentUnhealthy, s.segmentID)
}
var offset int64
var status C.CStatus
cOffset := (*C.int64_t)(&offset)
@ -642,6 +667,11 @@ func (s *Segment) segmentPreDelete(numOfRecords int) int64 {
long int
PreDelete(CSegmentInterface c_segment, long int size);
*/
s.mut.RLock()
defer s.mut.RUnlock()
if !s.healthy() {
return -1
}
var offset C.int64_t
s.pool.Submit(func() (interface{}, error) {
@ -654,12 +684,14 @@ func (s *Segment) segmentPreDelete(numOfRecords int) int64 {
}
func (s *Segment) segmentInsert(offset int64, entityIDs []UniqueID, timestamps []Timestamp, record *segcorepb.InsertRecord) error {
if s.segmentType != segmentTypeGrowing {
if s.getType() != segmentTypeGrowing {
return fmt.Errorf("unexpected segmentType when segmentInsert, segmentType = %s", s.segmentType.String())
}
if s.segmentPtr == nil {
return errors.New("null seg core pointer")
s.mut.RLock()
defer s.mut.RUnlock()
if !s.healthy() {
return fmt.Errorf("%w(segmentID=%d)", errSegmentUnhealthy, s.segmentID)
}
insertRecordBlob, err := proto.Marshal(record)
@ -708,8 +740,10 @@ func (s *Segment) segmentDelete(offset int64, entityIDs []primaryKey, timestamps
return fmt.Errorf("empty pks to delete")
}
if s.segmentPtr == nil {
return errors.New("null seg core pointer")
s.mut.RLock()
defer s.mut.RUnlock()
if !s.healthy() {
return fmt.Errorf("%w(segmentID=%d)", errSegmentUnhealthy, s.segmentID)
}
if len(entityIDs) != len(timestamps) {
@ -772,13 +806,15 @@ func (s *Segment) segmentLoadFieldData(fieldID int64, rowCount int64, data *sche
CStatus
LoadFieldData(CSegmentInterface c_segment, CLoadFieldDataInfo load_field_data_info);
*/
if s.segmentPtr == nil {
return errors.New("null seg core pointer")
}
if s.segmentType != segmentTypeSealed {
if s.getType() != segmentTypeSealed {
errMsg := fmt.Sprintln("segmentLoadFieldData failed, illegal segment type ", s.segmentType, "segmentID = ", s.ID())
return errors.New(errMsg)
}
s.mut.RLock()
defer s.mut.RUnlock()
if !s.healthy() {
return fmt.Errorf("%w(segmentID=%d)", errSegmentUnhealthy, s.segmentID)
}
dataBlob, err := proto.Marshal(data)
if err != nil {
@ -811,8 +847,10 @@ func (s *Segment) segmentLoadFieldData(fieldID int64, rowCount int64, data *sche
}
func (s *Segment) segmentLoadDeletedRecord(primaryKeys []primaryKey, timestamps []Timestamp, rowCount int64) error {
if s.segmentPtr == nil {
return errors.New("null seg core pointer")
s.mut.RLock()
defer s.mut.RUnlock()
if !s.healthy() {
return fmt.Errorf("%w(segmentID=%d)", errSegmentUnhealthy, s.segmentID)
}
if len(primaryKeys) <= 0 {
@ -893,15 +931,15 @@ func (s *Segment) segmentLoadIndexData(bytesIndex [][]byte, indexInfo *querypb.F
}
return err
}
if s.segmentPtr == nil {
return errors.New("null seg core pointer")
}
if s.segmentType != segmentTypeSealed {
if s.getType() != segmentTypeSealed {
errMsg := fmt.Sprintln("updateSegmentIndex failed, illegal segment type ", s.segmentType, "segmentID = ", s.ID())
return errors.New(errMsg)
}
s.mut.RLock()
defer s.mut.RUnlock()
if !s.healthy() {
return fmt.Errorf("%w(segmentID=%d)", errSegmentUnhealthy, s.segmentID)
}
var status C.CStatus
s.pool.Submit(func() (interface{}, error) {

View File

@ -27,6 +27,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
"github.com/milvus-io/milvus/api/commonpb"
"github.com/milvus-io/milvus/api/schemapb"
@ -904,11 +905,12 @@ func TestSegmentLoader_getFieldType(t *testing.T) {
loader := &segmentLoader{metaReplica: replica}
// failed to get collection.
segment := &Segment{segmentType: segmentTypeSealed}
segment := &Segment{segmentType: atomic.NewInt32(0)}
segment.setType(segmentTypeSealed)
_, err := loader.getFieldType(segment, 100)
assert.Error(t, err)
segment.segmentType = segmentTypeGrowing
segment.setType(segmentTypeGrowing)
_, err = loader.getFieldType(segment, 100)
assert.Error(t, err)
@ -927,12 +929,12 @@ func TestSegmentLoader_getFieldType(t *testing.T) {
}, nil
}
segment.segmentType = segmentTypeGrowing
segment.setType(segmentTypeGrowing)
fieldType, err := loader.getFieldType(segment, 100)
assert.NoError(t, err)
assert.Equal(t, schemapb.DataType_Int64, fieldType)
segment.segmentType = segmentTypeSealed
segment.setType(segmentTypeSealed)
fieldType, err = loader.getFieldType(segment, 100)
assert.NoError(t, err)
assert.Equal(t, schemapb.DataType_Int64, fieldType)

View File

@ -89,7 +89,7 @@ func TestSegment_deleteSegment(t *testing.T) {
t.Run("test delete nil ptr", func(t *testing.T) {
s, err := genSimpleSealedSegment(defaultMsgLength)
assert.NoError(t, err)
s.segmentPtr = nil
s.setUnhealthy()
deleteSegment(s)
})
}
@ -134,7 +134,7 @@ func TestSegment_getRowCount(t *testing.T) {
t.Run("test getRowCount nil ptr", func(t *testing.T) {
s, err := genSimpleSealedSegment(defaultMsgLength)
assert.NoError(t, err)
s.segmentPtr = nil
s.setUnhealthy()
res := s.getRowCount()
assert.Equal(t, int64(-1), res)
})
@ -273,7 +273,7 @@ func TestSegment_getDeletedCount(t *testing.T) {
t.Run("test getDeletedCount nil ptr", func(t *testing.T) {
s, err := genSimpleSealedSegment(defaultMsgLength)
assert.NoError(t, err)
s.segmentPtr = nil
s.setUnhealthy()
res := s.getDeletedCount()
assert.Equal(t, int64(-1), res)
})