mirror of https://github.com/milvus-io/milvus.git
fix: use stateful lock to avoid load and release on LocalSegment concurrently (#31606)
issue: #31605

Signed-off-by: chyezh <chyezh@outlook.com>
parent 2aa6f3d3ec
commit 73adf2a5cc
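The diff below replaces the segment manager's raw read-lock protection (RLock/RUnlock) with a pin/unpin protocol backed by a stateful lock, so a segment can no longer be loaded and released concurrently. As a rough illustration of the caller-side pattern the manager adopts, here is a hedged sketch (the Pinner interface, pinAll helper, and fakeSegment type are invented for the example and are not code from this commit): pinning is best-effort and is rolled back if any segment was already released.

package main

import (
    "errors"
    "fmt"
)

// Pinner captures just the two methods this commit adds to Segment.
// It is an illustrative stand-in, not the real Segment interface.
type Pinner interface {
    PinIfNotReleased() error
    Unpin()
}

// pinAll pins every segment and unpins the already-pinned ones if any pin
// fails, mirroring the rollback in GetAndPinBy / GetAndPin below.
func pinAll(segments []Pinner) ([]Pinner, error) {
    var pinned []Pinner
    var err error
    defer func() {
        if err != nil {
            for _, s := range pinned {
                s.Unpin()
            }
        }
    }()
    for _, s := range segments {
        if err = s.PinIfNotReleased(); err != nil {
            return nil, err
        }
        pinned = append(pinned, s)
    }
    return pinned, nil
}

type fakeSegment struct{ released bool }

func (f *fakeSegment) PinIfNotReleased() error {
    if f.released {
        return errors.New("segment released")
    }
    return nil
}

func (f *fakeSegment) Unpin() {}

func main() {
    _, err := pinAll([]Pinner{&fakeSegment{}, &fakeSegment{released: true}})
    fmt.Println(err) // the first segment is unpinned again before the error is returned
}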
@@ -370,7 +370,7 @@ func (mgr *segmentManager) GetAndPinBy(filters ...SegmentFilter) ([]Segment, err
    defer func() {
        if err != nil {
            for _, segment := range ret {
                segment.RUnlock()
                segment.Unpin()
            }
        }
    }()

@@ -379,7 +379,7 @@ func (mgr *segmentManager) GetAndPinBy(filters ...SegmentFilter) ([]Segment, err
    if segment.Level() == datapb.SegmentLevel_L0 {
        return true
    }
    err = segment.RLock()
    err = segment.PinIfNotReleased()
    if err != nil {
        return false
    }

@@ -399,7 +399,7 @@ func (mgr *segmentManager) GetAndPin(segments []int64, filters ...SegmentFilter)
    defer func() {
        if err != nil {
            for _, segment := range lockedSegments {
                segment.RUnlock()
                segment.Unpin()
            }
        }
    }()

@@ -417,14 +417,14 @@ func (mgr *segmentManager) GetAndPin(segments []int64, filters ...SegmentFilter)
    sealedExist = sealedExist && filter(sealed, filters...)

    if growingExist {
        err = growing.RLock()
        err = growing.PinIfNotReleased()
        if err != nil {
            return nil, err
        }
        lockedSegments = append(lockedSegments, growing)
    }
    if sealedExist {
        err = sealed.RLock()
        err = sealed.PinIfNotReleased()
        if err != nil {
            return nil, err
        }

@@ -442,7 +442,7 @@ func (mgr *segmentManager) GetAndPin(segments []int64, filters ...SegmentFilter)

func (mgr *segmentManager) Unpin(segments []Segment) {
    for _, segment := range segments {
        segment.RUnlock()
        segment.Unpin()
    }
}

@@ -874,8 +874,8 @@ func (_c *MockSegment_Partition_Call) RunAndReturn(run func() int64) *MockSegmen
    return _c
}

// RLock provides a mock function with given fields:
func (_m *MockSegment) RLock() error {
// PinIfNotReleased provides a mock function with given fields:
func (_m *MockSegment) PinIfNotReleased() error {
    ret := _m.Called()

    var r0 error

@@ -915,8 +915,8 @@ func (_c *MockSegment_RLock_Call) RunAndReturn(run func() error) *MockSegment_RL
    return _c
}

// RUnlock provides a mock function with given fields:
func (_m *MockSegment) RUnlock() {
// Unpin provides a mock function with given fields:
func (_m *MockSegment) Unpin() {
    _m.Called()
}

@@ -31,7 +31,6 @@ import (
    "io"
    "strconv"
    "strings"
    "sync"
    "unsafe"

    "github.com/apache/arrow/go/v12/arrow/array"

@@ -50,6 +49,7 @@ import (
    "github.com/milvus-io/milvus/internal/proto/querypb"
    "github.com/milvus-io/milvus/internal/proto/segcorepb"
    "github.com/milvus-io/milvus/internal/querynodev2/pkoracle"
    "github.com/milvus-io/milvus/internal/querynodev2/segments/state"
    "github.com/milvus-io/milvus/internal/storage"
    typeutil_internal "github.com/milvus-io/milvus/internal/util/typeutil"
    "github.com/milvus-io/milvus/pkg/common"

@@ -213,7 +213,7 @@ var _ Segment = (*LocalSegment)(nil)
// Segment is a wrapper of the underlying C-structure segment.
type LocalSegment struct {
    baseSegment
    ptrLock sync.RWMutex // protects segmentPtr
    ptrLock *state.LoadStateLock
    ptr C.CSegmentInterface

    // cached results, to avoid too many CGO calls

@@ -242,10 +242,13 @@ func NewSegment(ctx context.Context,
        return NewL0Segment(collection, segmentType, version, loadInfo)
    }
    var cSegType C.SegmentType
    var locker *state.LoadStateLock
    switch segmentType {
    case SegmentTypeSealed:
        cSegType = C.Sealed
        locker = state.NewLoadStateLock(state.LoadStateOnlyMeta)
    case SegmentTypeGrowing:
        locker = state.NewLoadStateLock(state.LoadStateDataLoaded)
        cSegType = C.Growing
    default:
        return nil, fmt.Errorf("illegal segment type %d when create segment %d", segmentType, loadInfo.GetSegmentID())

@@ -275,6 +278,7 @@ func NewSegment(ctx context.Context,

    segment := &LocalSegment{
        baseSegment: newBaseSegment(collection, segmentType, version, loadInfo),
        ptrLock: locker,
        ptr: newPtr,
        lastDeltaTimestamp: atomic.NewUint64(0),
        fields: typeutil.NewConcurrentMap[int64, *FieldInfo](),

@@ -308,11 +312,14 @@ func NewSegmentV2(
    }
    var segmentPtr C.CSegmentInterface
    var status C.CStatus
    var locker *state.LoadStateLock
    switch segmentType {
    case SegmentTypeSealed:
        status = C.NewSegment(collection.collectionPtr, C.Sealed, C.int64_t(loadInfo.GetSegmentID()), &segmentPtr)
        locker = state.NewLoadStateLock(state.LoadStateOnlyMeta)
    case SegmentTypeGrowing:
        status = C.NewSegment(collection.collectionPtr, C.Growing, C.int64_t(loadInfo.GetSegmentID()), &segmentPtr)
        locker = state.NewLoadStateLock(state.LoadStateDataLoaded)
    default:
        return nil, fmt.Errorf("illegal segment type %d when create segment %d", segmentType, loadInfo.GetSegmentID())
    }

@@ -338,6 +345,7 @@ func NewSegmentV2(

    segment := &LocalSegment{
        baseSegment: newBaseSegment(collection, segmentType, version, loadInfo),
        ptrLock: locker,
        ptr: segmentPtr,
        lastDeltaTimestamp: atomic.NewUint64(0),
        fields: typeutil.NewConcurrentMap[int64, *FieldInfo](),

@@ -355,41 +363,31 @@ func NewSegmentV2(
    return segment, nil
}

func (s *LocalSegment) isValid() bool {
    return s.ptr != nil
}

// RLock acquires the `ptrLock` and returns true if the pointer is valid
// PinIfNotReleased acquires the `ptrLock` and returns true if the pointer is valid
// Provide ONLY the read lock operations,
// don't make `ptrLock` public to avoid abusing of the mutex.
func (s *LocalSegment) RLock() error {
    s.ptrLock.RLock()
    if !s.isValid() {
        s.ptrLock.RUnlock()
func (s *LocalSegment) PinIfNotReleased() error {
    if !s.ptrLock.PinIfNotReleased() {
        return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
    }
    return nil
}

func (s *LocalSegment) RUnlock() {
    s.ptrLock.RUnlock()
func (s *LocalSegment) Unpin() {
    s.ptrLock.Unpin()
}

func (s *LocalSegment) InsertCount() int64 {
    s.ptrLock.RLock()
    defer s.ptrLock.RUnlock()

    return s.insertCount.Load()
}

func (s *LocalSegment) RowNum() int64 {
    s.ptrLock.RLock()
    defer s.ptrLock.RUnlock()

    if !s.isValid() {
    // if segment is not loaded, return 0 (maybe not loaded or release by lru)
    if !s.ptrLock.RLockIf(state.IsDataLoaded) {
        log.Warn("segment is not valid", zap.Int64("segmentID", s.ID()))
        return 0
    }
    defer s.ptrLock.RUnlock()

    rowNum := s.rowNum.Load()
    if rowNum < 0 {

@@ -406,12 +404,10 @@ func (s *LocalSegment) RowNum() int64 {
}

func (s *LocalSegment) MemSize() int64 {
    s.ptrLock.RLock()
    defer s.ptrLock.RUnlock()

    if !s.isValid() {
    if !s.ptrLock.RLockIf(state.IsNotReleased) {
        return 0
    }
    defer s.ptrLock.RUnlock()

    memSize := s.memSize.Load()
    if memSize < 0 {

@@ -449,11 +445,11 @@ func (s *LocalSegment) ExistIndex(fieldID int64) bool {
}

func (s *LocalSegment) HasRawData(fieldID int64) bool {
    s.ptrLock.RLock()
    defer s.ptrLock.RUnlock()
    if !s.isValid() {
    if !s.ptrLock.RLockIf(state.IsNotReleased) {
        return false
    }
    defer s.ptrLock.RUnlock()

    ret := C.HasRawData(s.ptr, C.int64_t(fieldID))
    return bool(ret)
}

@@ -482,12 +478,11 @@ func (s *LocalSegment) Search(ctx context.Context, searchReq *SearchRequest) (*S
        zap.Int64("segmentID", s.ID()),
        zap.String("segmentType", s.segmentType.String()),
    )
    s.ptrLock.RLock()
    defer s.ptrLock.RUnlock()

    if s.ptr == nil {
    if !s.ptrLock.RLockIf(state.IsNotReleased) {
        // TODO: check if the segment is readable but not released. too many related logic need to be refactor.
        return nil, merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
    }
    defer s.ptrLock.RUnlock()

    traceCtx := ParseCTraceContext(ctx)

@@ -520,12 +515,11 @@ func (s *LocalSegment) Search(ctx context.Context, searchReq *SearchRequest) (*S
}

func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) {
    s.ptrLock.RLock()
    defer s.ptrLock.RUnlock()

    if s.ptr == nil {
    if !s.ptrLock.RLockIf(state.IsNotReleased) {
        // TODO: check if the segment is readable but not released. too many related logic need to be refactor.
        return nil, merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
    }
    defer s.ptrLock.RUnlock()

    log := log.Ctx(ctx).With(
        zap.Int64("collectionID", s.Collection()),

@@ -616,13 +610,10 @@ func (s *LocalSegment) Insert(ctx context.Context, rowIDs []int64, timestamps []
    if s.Type() != SegmentTypeGrowing {
        return fmt.Errorf("unexpected segmentType when segmentInsert, segmentType = %s", s.segmentType.String())
    }

    s.ptrLock.RLock()
    defer s.ptrLock.RUnlock()

    if s.ptr == nil {
    if !s.ptrLock.RLockIf(state.IsNotReleased) {
        return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
    }
    defer s.ptrLock.RUnlock()

    offset, err := s.preInsert(ctx, len(rowIDs))
    if err != nil {

@@ -676,13 +667,10 @@ func (s *LocalSegment) Delete(ctx context.Context, primaryKeys []storage.Primary
    if len(primaryKeys) == 0 {
        return nil
    }

    s.ptrLock.RLock()
    defer s.ptrLock.RUnlock()

    if s.ptr == nil {
    if !s.ptrLock.RLockIf(state.IsNotReleased) {
        return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
    }
    defer s.ptrLock.RUnlock()

    cOffset := C.int64_t(0) // depre
    cSize := C.int64_t(len(primaryKeys))

@@ -743,12 +731,10 @@ func (s *LocalSegment) Delete(ctx context.Context, primaryKeys []storage.Primary

// -------------------------------------------------------------------------------------- interfaces for sealed segment
func (s *LocalSegment) LoadMultiFieldData(ctx context.Context, rowCount int64, fields []*datapb.FieldBinlog) error {
    s.ptrLock.RLock()
    defer s.ptrLock.RUnlock()

    if s.ptr == nil {
    if !s.ptrLock.RLockIf(state.IsNotReleased) {
        return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
    }
    defer s.ptrLock.RUnlock()

    log := log.Ctx(ctx).With(
        zap.Int64("collectionID", s.Collection()),

@@ -847,16 +833,14 @@ func (s *LocalSegment) LoadFieldData(ctx context.Context, fieldID int64, rowCoun

    s.loadStatus.Store(string(options.LoadStatus))

    s.ptrLock.RLock()
    if !s.ptrLock.RLockIf(state.IsNotReleased) {
        return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
    }
    defer s.ptrLock.RUnlock()

    ctx, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, fmt.Sprintf("LoadFieldData-%d-%d", s.ID(), fieldID))
    defer sp.End()

    if s.ptr == nil {
        return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
    }

    log := log.Ctx(ctx).With(
        zap.Int64("collectionID", s.Collection()),
        zap.Int64("partitionID", s.Partition()),

@@ -1011,12 +995,10 @@ func (s *LocalSegment) LoadDeltaData2(ctx context.Context, schema *schemapb.Coll
}

func (s *LocalSegment) AddFieldDataInfo(ctx context.Context, rowCount int64, fields []*datapb.FieldBinlog) error {
    s.ptrLock.RLock()
    defer s.ptrLock.RUnlock()

    if s.ptr == nil {
    if !s.ptrLock.RLockIf(state.IsNotReleased) {
        return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
    }
    defer s.ptrLock.RUnlock()

    log := log.Ctx(ctx).With(
        zap.Int64("collectionID", s.Collection()),

@@ -1066,12 +1048,10 @@ func (s *LocalSegment) LoadDeltaData(ctx context.Context, deltaData *storage.Del
    pks, tss := deltaData.Pks, deltaData.Tss
    rowNum := deltaData.RowCount

    s.ptrLock.RLock()
    defer s.ptrLock.RUnlock()

    if s.ptr == nil {
    if !s.ptrLock.RLockIf(state.IsNotReleased) {
        return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
    }
    defer s.ptrLock.RUnlock()

    log := log.Ctx(ctx).With(
        zap.Int64("collectionID", s.Collection()),

@@ -1224,12 +1204,10 @@ func (s *LocalSegment) UpdateIndexInfo(ctx context.Context, indexInfo *querypb.F
        zap.Int64("segmentID", s.ID()),
        zap.Int64("fieldID", indexInfo.FieldID),
    )
    s.ptrLock.RLock()
    defer s.ptrLock.RUnlock()

    if s.ptr == nil {
    if !s.ptrLock.RLockIf(state.IsNotReleased) {
        return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
    }
    defer s.ptrLock.RUnlock()

    var status C.CStatus
    GetDynamicPool().Submit(func() (any, error) {

@@ -1263,12 +1241,10 @@ func (s *LocalSegment) WarmupChunkCache(ctx context.Context, fieldID int64) {
        zap.Int64("segmentID", s.ID()),
        zap.Int64("fieldID", fieldID),
    )
    s.ptrLock.RLock()
    defer s.ptrLock.RUnlock()

    if s.ptr == nil {
    if !s.ptrLock.RLockIf(state.IsNotReleased) {
        return
    }
    defer s.ptrLock.RUnlock()

    var status C.CStatus

@@ -1287,11 +1263,11 @@ func (s *LocalSegment) WarmupChunkCache(ctx context.Context, fieldID int64) {
        }).Await()
    case "async":
        GetLoadPool().Submit(func() (any, error) {
            s.ptrLock.RLock()
            defer s.ptrLock.RUnlock()
            if s.ptr == nil {
            if !s.ptrLock.RLockIf(state.IsNotReleased) {
                return nil, nil
            }
            defer s.ptrLock.RUnlock()

            cFieldID := C.int64_t(fieldID)
            status = C.WarmupChunkCache(s.ptr, cFieldID)
            if err := HandleCStatus(ctx, &status, ""); err != nil {

@@ -1357,26 +1333,17 @@ func (s *LocalSegment) Release(opts ...releaseOption) {
    for _, opt := range opts {
        opt(options)
    }

    /*
        void
        deleteSegment(CSegmentInterface segment);
    */
    var ptr C.CSegmentInterface

    // wait all read ops finished
    s.ptrLock.Lock()
    ptr = s.ptr
    s.ptr = nil
    if options.Scope == ReleaseScopeData {
        s.loadStatus.Store(string(LoadStatusMeta))
    }
    s.ptrLock.Unlock()

    if ptr == nil {
    stateLockGuard := s.startRelease(options.Scope)
    if stateLockGuard == nil { // release is already done.
        return
    }
    // release will never fail
    defer stateLockGuard.Done(nil)

    // wait all read ops finished
    ptr := s.ptr
    if options.Scope == ReleaseScopeData {
        s.loadStatus.Store(string(LoadStatusMeta))
        C.ClearSegmentData(ptr)
        return
    }

@@ -1406,3 +1373,20 @@ func (s *LocalSegment) Release(opts ...releaseOption) {
        zap.Int64("insertCount", s.InsertCount()),
    )
}

// StartLoadData starts the loading process of the segment.
func (s *LocalSegment) StartLoadData() (state.LoadStateLockGuard, error) {
    return s.ptrLock.StartLoadData()
}

// startRelease starts the releasing process of the segment.
func (s *LocalSegment) startRelease(scope ReleaseScope) state.LoadStateLockGuard {
    switch scope {
    case ReleaseScopeData:
        return s.ptrLock.StartReleaseData()
    case ReleaseScopeAll:
        return s.ptrLock.StartReleaseAll()
    default:
        panic(fmt.Sprintf("unexpected release scope %d", scope))
    }
}

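Throughout these LocalSegment methods, the old pattern of taking the read lock, deferring RUnlock, and then checking `s.ptr == nil` is replaced by a conditional read lock: the lock is acquired only if a state predicate holds, and the caller bails out otherwise. A minimal sketch of that idiom in isolation (the guarded type, rLockIf, and notReleased below are invented for illustration and are not the milvus types):

package main

import (
    "fmt"
    "sync"
)

// guarded is an illustrative stand-in for LocalSegment: a resource whose
// liveness is tracked by explicit state instead of a nil-pointer check.
type guarded struct {
    mu       sync.RWMutex
    released bool
    value    int
}

// rLockIf takes the read lock only when pred holds, mirroring the shape of
// LoadStateLock.RLockIf in the new state package.
func (g *guarded) rLockIf(pred func(*guarded) bool) bool {
    g.mu.RLock()
    if !pred(g) {
        g.mu.RUnlock()
        return false
    }
    return true
}

func notReleased(g *guarded) bool { return !g.released }

// read returns the value only while the resource is known not to be released.
func (g *guarded) read() (int, error) {
    if !g.rLockIf(notReleased) {
        return 0, fmt.Errorf("resource released")
    }
    defer g.mu.RUnlock()
    return g.value, nil
}

func main() {
    g := &guarded{value: 42}
    v, err := g.read()
    fmt.Println(v, err)
}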
@@ -62,8 +62,10 @@ type Segment interface {
    Level() datapb.SegmentLevel
    LoadStatus() LoadStatus
    LoadInfo() *querypb.SegmentLoadInfo
    RLock() error
    RUnlock()
    // PinIfNotReleased the segment to prevent it from being released
    PinIfNotReleased() error
    // Unpin the segment to allow it to be released
    Unpin()

    // Stats related
    // InsertCount returns the number of inserted rows, not effected by deletion

@@ -68,11 +68,11 @@ func NewL0Segment(collection *Collection,
    return segment, nil
}

func (s *L0Segment) RLock() error {
func (s *L0Segment) PinIfNotReleased() error {
    return nil
}

func (s *L0Segment) RUnlock() {}
func (s *L0Segment) Unpin() {}

func (s *L0Segment) InsertCount() int64 {
    return 0

@@ -375,7 +375,23 @@ func (loader *segmentLoaderV2) LoadSegment(ctx context.Context,
    segment *LocalSegment,
    loadInfo *querypb.SegmentLoadInfo,
    loadstatus LoadStatus,
) error {
) (err error) {
    // TODO: we should create a transaction-like api to load segment for segment interface,
    // but not do many things in segment loader.
    stateLockGuard, err := segment.StartLoadData()
    // segment can not do load now.
    if err != nil {
        return err
    }
    defer func() {
        // segment is already loaded.
        // TODO: if stateLockGuard is nil, we should not call LoadSegment anymore.
        // but current Load is not clear enough to do an actual state transition, keep previous logic to avoid introduced bug.
        if stateLockGuard != nil {
            stateLockGuard.Done(err)
        }
    }()

    log := log.Ctx(ctx).With(
        zap.Int64("collectionID", segment.Collection()),
        zap.Int64("partitionID", segment.Partition()),

@@ -1008,7 +1024,23 @@ func (loader *segmentLoader) LoadSegment(ctx context.Context,
    segment *LocalSegment,
    loadInfo *querypb.SegmentLoadInfo,
    loadStatus LoadStatus,
) error {
) (err error) {
    // TODO: we should create a transaction-like api to load segment for segment interface,
    // but not do many things in segment loader.
    stateLockGuard, err := segment.StartLoadData()
    // segment can not do load now.
    if err != nil {
        return err
    }
    defer func() {
        // segment is already loaded.
        // TODO: if stateLockGuard is nil, we should not call LoadSegment anymore.
        // but current Load is not clear enough to do an actual state transition, keep previous logic to avoid introduced bug.
        if stateLockGuard != nil {
            stateLockGuard.Done(err)
        }
    }()

    log := log.Ctx(ctx).With(
        zap.Int64("collectionID", segment.Collection()),
        zap.Int64("partitionID", segment.Partition()),

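The loader now brackets the whole load in a guard returned by StartLoadData: Done(err) commits the state transition on success and rolls it back on failure. A compact sketch of the same commit/rollback guard idea in isolation (the loadGuard type and doLoad function are invented for illustration; the real guard lives in the new state package shown further below):

package main

import (
    "errors"
    "fmt"
)

// loadGuard is a toy commit/rollback guard: Done(nil) commits the target
// state, Done(err) restores the original one.
type loadGuard struct {
    state            *string
    original, target string
}

func (g loadGuard) Done(err error) {
    if err != nil {
        *g.state = g.original
        return
    }
    *g.state = g.target
}

// doLoad mimics the loader pattern: acquire the guard first, then defer
// Done(err) so any failure in the body rolls the state back.
func doLoad(state *string, fail bool) (err error) {
    g := loadGuard{state: state, original: *state, target: "loaded"}
    *state = "loading"
    defer func() { g.Done(err) }()

    if fail {
        return errors.New("load failed")
    }
    return nil
}

func main() {
    state := "meta"
    _ = doLoad(&state, true)
    fmt.Println(state) // meta: rolled back after the failed load
    _ = doLoad(&state, false)
    fmt.Println(state) // loaded
}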
@@ -94,10 +94,13 @@ func (suite *SegmentSuite) SetupTest() {
        suite.chunkManager,
    )
    suite.Require().NoError(err)
    g, err := suite.sealed.(*LocalSegment).StartLoadData()
    suite.Require().NoError(err)
    for _, binlog := range binlogs {
        err = suite.sealed.(*LocalSegment).LoadFieldData(ctx, binlog.FieldID, int64(msgLength), binlog)
        suite.Require().NoError(err)
    }
    g.Done(nil)

    suite.growing, err = NewSegment(ctx,
        suite.collection,

@@ -198,9 +201,7 @@ func (suite *SegmentSuite) TestSegmentReleased() {

    sealed := suite.sealed.(*LocalSegment)

    sealed.ptrLock.RLock()
    suite.False(sealed.isValid())
    sealed.ptrLock.RUnlock()
    suite.False(sealed.ptrLock.PinIfNotReleased())
    suite.EqualValues(0, sealed.RowNum())
    suite.EqualValues(0, sealed.MemSize())
    suite.False(sealed.HasRawData(101))

@@ -0,0 +1,205 @@
package state

import (
    "fmt"
    "sync"

    "github.com/cockroachdb/errors"
    "go.uber.org/atomic"
)

type loadStateEnum int

// LoadState represent the state transition of segment.
// LoadStateOnlyMeta: segment is created with meta, but not loaded.
// LoadStateDataLoading: segment is loading data.
// LoadStateDataLoaded: segment is full loaded, ready to be searched or queried.
// LoadStateDataReleasing: segment is releasing data.
// LoadStateReleased: segment is released.
// LoadStateOnlyMeta -> LoadStateDataLoading -> LoadStateDataLoaded -> LoadStateDataReleasing -> (LoadStateReleased or LoadStateOnlyMeta)
const (
    LoadStateOnlyMeta loadStateEnum = iota
    LoadStateDataLoading // There will be only one goroutine access segment when loading.
    LoadStateDataLoaded
    LoadStateDataReleasing // There will be only one goroutine access segment when releasing.
    LoadStateReleased
)

// LoadState is the state of segment loading.
func (ls loadStateEnum) String() string {
    switch ls {
    case LoadStateOnlyMeta:
        return "meta"
    case LoadStateDataLoading:
        return "loading-data"
    case LoadStateDataLoaded:
        return "loaded"
    case LoadStateDataReleasing:
        return "releasing-data"
    case LoadStateReleased:
        return "released"
    default:
        return "unknown"
    }
}

// NewLoadStateLock creates a LoadState.
func NewLoadStateLock(state loadStateEnum) *LoadStateLock {
    if state != LoadStateOnlyMeta && state != LoadStateDataLoaded {
        panic(fmt.Sprintf("invalid state for construction of LoadStateLock, %s", state.String()))
    }

    mu := &sync.RWMutex{}
    return &LoadStateLock{
        mu: mu,
        cv: sync.Cond{L: mu},
        state: state,
        refCnt: atomic.NewInt32(0),
    }
}

// LoadStateLock is the state of segment loading.
type LoadStateLock struct {
    mu *sync.RWMutex
    cv sync.Cond
    state loadStateEnum
    refCnt *atomic.Int32
    // ReleaseAll can be called only when refCnt is 0.
    // We need it to be modified when lock is
}

// RLockIfNotReleased locks the segment if the state is not released.
func (ls *LoadStateLock) RLockIf(pred StatePredicate) bool {
    ls.mu.RLock()
    if !pred(ls.state) {
        ls.mu.RUnlock()
        return false
    }
    return true
}

// RUnlock unlocks the segment.
func (ls *LoadStateLock) RUnlock() {
    ls.mu.RUnlock()
}

// PinIfNotReleased pin the segment into memory, avoid ReleaseAll to release it.
func (ls *LoadStateLock) PinIfNotReleased() bool {
    ls.mu.RLock()
    defer ls.mu.RUnlock()
    if ls.state == LoadStateReleased {
        return false
    }
    ls.refCnt.Inc()
    return true
}

// Unpin unpin the segment, then segment can be released by ReleaseAll.
func (ls *LoadStateLock) Unpin() {
    ls.mu.RLock()
    defer ls.mu.RUnlock()
    newCnt := ls.refCnt.Dec()
    if newCnt < 0 {
        panic("unpin more than pin")
    }
    if newCnt == 0 {
        // notify ReleaseAll to release segment if refcnt is zero.
        ls.cv.Broadcast()
    }
}

// StartLoadData starts load segment data
// Fast fail if segment is not in LoadStateOnlyMeta.
func (ls *LoadStateLock) StartLoadData() (LoadStateLockGuard, error) {
    // only meta can be loaded.
    ls.cv.L.Lock()
    defer ls.cv.L.Unlock()

    if ls.state == LoadStateDataLoaded {
        return nil, nil
    }
    if ls.state != LoadStateOnlyMeta {
        return nil, errors.New("segment is not in LoadStateOnlyMeta, cannot start to loading data")
    }
    ls.state = LoadStateDataLoading
    ls.cv.Broadcast()

    return newLoadStateLockGuard(ls, LoadStateOnlyMeta, LoadStateDataLoaded), nil
}

// StartReleaseData wait until the segment is releasable and starts releasing segment data.
func (ls *LoadStateLock) StartReleaseData() (g LoadStateLockGuard) {
    ls.cv.L.Lock()
    defer ls.cv.L.Unlock()

    ls.waitUntilCanReleaseData()

    switch ls.state {
    case LoadStateDataLoaded:
        ls.state = LoadStateDataReleasing
        ls.cv.Broadcast()
        return newLoadStateLockGuard(ls, LoadStateDataLoaded, LoadStateOnlyMeta)
    case LoadStateOnlyMeta:
        // already transit to target state, do nothing.
        return nil
    case LoadStateReleased:
        // do nothing for empty segment.
        return nil
    default:
        panic(fmt.Sprintf("unreachable code: invalid state when releasing data, %s", ls.state.String()))
    }
}

// StartReleaseAll wait until the segment is releasable and starts releasing all segment.
func (ls *LoadStateLock) StartReleaseAll() (g LoadStateLockGuard) {
    ls.cv.L.Lock()
    defer ls.cv.L.Unlock()

    ls.waitUntilCanReleaseAll()

    switch ls.state {
    case LoadStateDataLoaded:
        ls.state = LoadStateReleased
        ls.cv.Broadcast()
        return newNopLoadStateLockGuard()
    case LoadStateOnlyMeta:
        ls.state = LoadStateReleased
        ls.cv.Broadcast()
        return newNopLoadStateLockGuard()
    case LoadStateReleased:
        // already transit to target state, do nothing.
        return nil
    default:
        panic(fmt.Sprintf("unreachable code: invalid state when releasing data, %s", ls.state.String()))
    }
}

// waitUntilCanReleaseData waits until segment is release data able.
func (ls *LoadStateLock) waitUntilCanReleaseData() {
    state := ls.state
    for state != LoadStateDataLoaded && state != LoadStateOnlyMeta && state != LoadStateReleased {
        ls.cv.Wait()
        state = ls.state
    }
}

// waitUntilCanReleaseAll waits until segment is releasable.
func (ls *LoadStateLock) waitUntilCanReleaseAll() {
    state := ls.state
    for (state != LoadStateDataLoaded && state != LoadStateOnlyMeta && state != LoadStateReleased) || ls.refCnt.Load() != 0 {
        ls.cv.Wait()
        state = ls.state
    }
}

type StatePredicate func(state loadStateEnum) bool

// IsNotReleased checks if the segment is not released.
func IsNotReleased(state loadStateEnum) bool {
    return state != LoadStateReleased
}

// IsDataLoaded checks if the segment is loaded.
func IsDataLoaded(state loadStateEnum) bool {
    return state == LoadStateDataLoaded
}

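Taken together, the new lock serializes loading and releasing while pinned readers hold a full release off. A short usage sketch of the API introduced above, written as an example-style function in the same package (this is illustrative example code, not part of the commit):

package state

import "fmt"

// ExampleLoadStateLock walks a segment lock through its lifecycle:
// load, pin for reads, then release once the pin is dropped.
func ExampleLoadStateLock() {
    lock := NewLoadStateLock(LoadStateOnlyMeta)

    // Load: StartLoadData returns a guard; Done(nil) commits LoadStateDataLoaded.
    guard, err := lock.StartLoadData()
    if err != nil {
        panic(err)
    }
    guard.Done(nil)

    // Readers pin the segment; a pinned segment cannot be fully released.
    if lock.PinIfNotReleased() {
        // ... read from the segment ...
        lock.Unpin()
    }

    // Conditional read lock: only proceed if data is loaded.
    if lock.RLockIf(IsDataLoaded) {
        // ... CGO read path ...
        lock.RUnlock()
    }

    // Release everything; this would block while any pin is held.
    if g := lock.StartReleaseAll(); g != nil {
        g.Done(nil)
    }

    fmt.Println(lock.PinIfNotReleased())
    // Output: false
}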
@@ -0,0 +1,45 @@
package state

type LoadStateLockGuard interface {
    Done(err error)
}

// newLoadStateLockGuard creates a LoadStateGuard.
func newLoadStateLockGuard(ls *LoadStateLock, original loadStateEnum, target loadStateEnum) *loadStateLockGuard {
    return &loadStateLockGuard{
        ls: ls,
        original: original,
        target: target,
    }
}

// loadStateLockGuard is a guard to update the state of LoadState.
type loadStateLockGuard struct {
    ls *LoadStateLock
    original loadStateEnum
    target loadStateEnum
}

// Done updates the state of LoadState to target state.
func (g *loadStateLockGuard) Done(err error) {
    g.ls.cv.L.Lock()
    g.ls.cv.Broadcast()
    defer g.ls.cv.L.Unlock()

    if err != nil {
        g.ls.state = g.original
        return
    }
    g.ls.state = g.target
}

// newNopLoadStateLockGuard creates a LoadStateLockGuard that does nothing.
func newNopLoadStateLockGuard() LoadStateLockGuard {
    return nopLockGuard{}
}

// nopLockGuard is a guard that does nothing.
type nopLockGuard struct{}

// Done does nothing.
func (nopLockGuard) Done(err error) {}

@@ -0,0 +1,224 @@
package state

import (
    "testing"
    "time"

    "github.com/cockroachdb/errors"
    "github.com/stretchr/testify/assert"
)

func TestLoadStateLoadData(t *testing.T) {
    l := NewLoadStateLock(LoadStateOnlyMeta)
    // Test Load Data, roll back
    g, err := l.StartLoadData()
    assert.NoError(t, err)
    assert.NotNil(t, g)
    assert.Equal(t, LoadStateDataLoading, l.state)
    g.Done(errors.New("test"))
    assert.Equal(t, LoadStateOnlyMeta, l.state)

    // Test Load Data, success
    g, err = l.StartLoadData()
    assert.NoError(t, err)
    assert.NotNil(t, g)
    assert.Equal(t, LoadStateDataLoading, l.state)
    g.Done(nil)
    assert.Equal(t, LoadStateDataLoaded, l.state)

    // nothing to do with loaded.
    g, err = l.StartLoadData()
    assert.NoError(t, err)
    assert.Nil(t, g)

    for _, s := range []loadStateEnum{
        LoadStateDataLoading,
        LoadStateDataReleasing,
        LoadStateReleased,
    } {
        l.state = s
        g, err = l.StartLoadData()
        assert.Error(t, err)
        assert.Nil(t, g)
    }
}

func TestStartReleaseData(t *testing.T) {
    l := NewLoadStateLock(LoadStateOnlyMeta)
    // Test Release Data, nothing to do on only meta.
    g := l.StartReleaseData()
    assert.Nil(t, g)
    assert.Equal(t, LoadStateOnlyMeta, l.state)

    // roll back
    // never roll back on current using.
    l.state = LoadStateDataLoaded
    g = l.StartReleaseData()
    assert.Equal(t, LoadStateDataReleasing, l.state)
    assert.NotNil(t, g)
    g.Done(errors.New("test"))
    assert.Equal(t, LoadStateDataLoaded, l.state)

    // success
    l.state = LoadStateDataLoaded
    g = l.StartReleaseData()
    assert.Equal(t, LoadStateDataReleasing, l.state)
    assert.NotNil(t, g)
    g.Done(nil)
    assert.Equal(t, LoadStateOnlyMeta, l.state)

    // nothing to do on released
    l.state = LoadStateReleased
    g = l.StartReleaseData()
    assert.Nil(t, g)

    // test blocking.
    l.state = LoadStateOnlyMeta
    g, err := l.StartLoadData()
    assert.NoError(t, err)

    ch := make(chan struct{})
    go func() {
        g := l.StartReleaseData()
        assert.NotNil(t, g)
        g.Done(nil)
        close(ch)
    }()

    // should be blocked because on loading.
    select {
    case <-ch:
        t.Errorf("should be blocked")
    case <-time.After(500 * time.Millisecond):
    }
    // loaded finished.
    g.Done(nil)

    // release can be started.
    select {
    case <-ch:
    case <-time.After(500 * time.Millisecond):
        t.Errorf("should not be blocked")
    }
    assert.Equal(t, LoadStateOnlyMeta, l.state)
}

func TestStartReleaseAll(t *testing.T) {
    l := NewLoadStateLock(LoadStateOnlyMeta)
    // Test Release All, nothing to do on only meta.
    g := l.StartReleaseAll()
    assert.NotNil(t, g)
    assert.Equal(t, LoadStateReleased, l.state)
    g.Done(nil)
    assert.Equal(t, LoadStateReleased, l.state)

    // roll back
    // never roll back on current using.
    l.state = LoadStateDataLoaded
    g = l.StartReleaseData()
    assert.Equal(t, LoadStateDataReleasing, l.state)
    assert.NotNil(t, g)
    g.Done(errors.New("test"))
    assert.Equal(t, LoadStateDataLoaded, l.state)

    // success
    l.state = LoadStateDataLoaded
    g = l.StartReleaseAll()
    assert.Equal(t, LoadStateReleased, l.state)
    assert.NotNil(t, g)
    g.Done(nil)
    assert.Equal(t, LoadStateReleased, l.state)

    // nothing to do on released
    l.state = LoadStateReleased
    g = l.StartReleaseAll()
    assert.Nil(t, g)

    // test blocking.
    l.state = LoadStateOnlyMeta
    g, err := l.StartLoadData()
    assert.NoError(t, err)

    ch := make(chan struct{})
    go func() {
        g := l.StartReleaseAll()
        assert.NotNil(t, g)
        g.Done(nil)
        close(ch)
    }()

    // should be blocked because on loading.
    select {
    case <-ch:
        t.Errorf("should be blocked")
    case <-time.After(500 * time.Millisecond):
    }
    // loaded finished.
    g.Done(nil)

    // release can be started.
    select {
    case <-ch:
    case <-time.After(500 * time.Millisecond):
        t.Errorf("should not be blocked")
    }
    assert.Equal(t, LoadStateReleased, l.state)
}

func TestRLock(t *testing.T) {
    l := NewLoadStateLock(LoadStateOnlyMeta)
    assert.True(t, l.RLockIf(IsNotReleased))
    l.RUnlock()
    assert.False(t, l.RLockIf(IsDataLoaded))

    l = NewLoadStateLock(LoadStateDataLoaded)
    assert.True(t, l.RLockIf(IsNotReleased))
    l.RUnlock()
    assert.True(t, l.RLockIf(IsDataLoaded))
    l.RUnlock()

    l = NewLoadStateLock(LoadStateOnlyMeta)
    l.StartReleaseAll().Done(nil)
    assert.False(t, l.RLockIf(IsNotReleased))
    assert.False(t, l.RLockIf(IsDataLoaded))
}

func TestPin(t *testing.T) {
    l := NewLoadStateLock(LoadStateOnlyMeta)
    assert.True(t, l.PinIfNotReleased())
    l.Unpin()

    l.StartReleaseAll().Done(nil)
    assert.False(t, l.PinIfNotReleased())

    l = NewLoadStateLock(LoadStateDataLoaded)
    assert.True(t, l.PinIfNotReleased())

    ch := make(chan struct{})
    go func() {
        l.StartReleaseAll().Done(nil)
        close(ch)
    }()

    select {
    case <-ch:
        t.Errorf("should be blocked")
    case <-time.After(500 * time.Millisecond):
    }

    // should be blocked until refcnt is zero.
    assert.True(t, l.PinIfNotReleased())
    l.Unpin()
    select {
    case <-ch:
        t.Errorf("should be blocked")
    case <-time.After(500 * time.Millisecond):
    }
    l.Unpin()
    <-ch

    assert.Panics(t, func() {
        // too much unpin
        l.Unpin()
    })
}