fix: use stateful lock to avoid load and release on LocalSegment concurrently (#31606)

issue: #31605
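In short, the bare `sync.RWMutex` guarding `segmentPtr` is replaced by a `state.LoadStateLock` that models the segment's load state: readers pin/unpin the segment instead of read-locking it, and load/release become explicit, mutually exclusive state transitions. Below is a minimal sketch of the intended usage of the new lock, not part of this PR; it assumes the package path of the new files in this diff (`internal/querynodev2/segments/state`) and stands in for the real loader/reader code paths:

```go
package main

import (
	"fmt"

	// Import path taken from the new files added in this PR.
	"github.com/milvus-io/milvus/internal/querynodev2/segments/state"
)

func main() {
	// A sealed segment starts with only its meta loaded.
	lock := state.NewLoadStateLock(state.LoadStateOnlyMeta)

	// Loading is an explicit state transition: only one goroutine can move
	// the segment from "meta" to "loading-data" at a time.
	guard, err := lock.StartLoadData()
	if err != nil {
		fmt.Println("cannot start loading:", err)
		return
	}
	if guard != nil {
		// ... load field data here ...
		guard.Done(nil) // nil commits to "loaded"; a non-nil error rolls back to "meta"
	}

	// Readers pin the segment instead of holding a read lock for the whole
	// operation; a pinned segment cannot be fully released underneath them.
	if lock.PinIfNotReleased() {
		// ... search / retrieve against the segment here ...
		lock.Unpin()
	}

	// Releasing the whole segment blocks until every pin is gone, then
	// moves the state to "released"; later pins fail fast.
	if g := lock.StartReleaseAll(); g != nil {
		g.Done(nil)
	}
	fmt.Println("pin after release:", lock.PinIfNotReleased()) // false
}
```

The guard returned by `StartLoadData`/`startRelease` is what makes the transition safe: `Done(err)` either commits to the target state or rolls back to the original one, and waiting releasers are woken through the lock's condition variable.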

---------

Signed-off-by: chyezh <chyezh@outlook.com>
chyezh 2024-04-08 17:09:16 +08:00 committed by GitHub
parent 2aa6f3d3ec
commit 73adf2a5cc
10 changed files with 602 additions and 109 deletions

View File

@ -370,7 +370,7 @@ func (mgr *segmentManager) GetAndPinBy(filters ...SegmentFilter) ([]Segment, err
defer func() {
if err != nil {
for _, segment := range ret {
segment.RUnlock()
segment.Unpin()
}
}
}()
@ -379,7 +379,7 @@ func (mgr *segmentManager) GetAndPinBy(filters ...SegmentFilter) ([]Segment, err
if segment.Level() == datapb.SegmentLevel_L0 {
return true
}
err = segment.RLock()
err = segment.PinIfNotReleased()
if err != nil {
return false
}
@ -399,7 +399,7 @@ func (mgr *segmentManager) GetAndPin(segments []int64, filters ...SegmentFilter)
defer func() {
if err != nil {
for _, segment := range lockedSegments {
segment.RUnlock()
segment.Unpin()
}
}
}()
@ -417,14 +417,14 @@ func (mgr *segmentManager) GetAndPin(segments []int64, filters ...SegmentFilter)
sealedExist = sealedExist && filter(sealed, filters...)
if growingExist {
err = growing.RLock()
err = growing.PinIfNotReleased()
if err != nil {
return nil, err
}
lockedSegments = append(lockedSegments, growing)
}
if sealedExist {
err = sealed.RLock()
err = sealed.PinIfNotReleased()
if err != nil {
return nil, err
}
@ -442,7 +442,7 @@ func (mgr *segmentManager) GetAndPin(segments []int64, filters ...SegmentFilter)
func (mgr *segmentManager) Unpin(segments []Segment) {
for _, segment := range segments {
segment.RUnlock()
segment.Unpin()
}
}

View File

@ -874,8 +874,8 @@ func (_c *MockSegment_Partition_Call) RunAndReturn(run func() int64) *MockSegmen
return _c
}
// RLock provides a mock function with given fields:
func (_m *MockSegment) RLock() error {
// PinIfNotReleased provides a mock function with given fields:
func (_m *MockSegment) PinIfNotReleased() error {
ret := _m.Called()
var r0 error
@ -915,8 +915,8 @@ func (_c *MockSegment_RLock_Call) RunAndReturn(run func() error) *MockSegment_RL
return _c
}
// RUnlock provides a mock function with given fields:
func (_m *MockSegment) RUnlock() {
// Unpin provides a mock function with given fields:
func (_m *MockSegment) Unpin() {
_m.Called()
}

View File

@ -31,7 +31,6 @@ import (
"io"
"strconv"
"strings"
"sync"
"unsafe"
"github.com/apache/arrow/go/v12/arrow/array"
@ -50,6 +49,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/querynodev2/pkoracle"
"github.com/milvus-io/milvus/internal/querynodev2/segments/state"
"github.com/milvus-io/milvus/internal/storage"
typeutil_internal "github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/pkg/common"
@ -213,7 +213,7 @@ var _ Segment = (*LocalSegment)(nil)
// Segment is a wrapper of the underlying C-structure segment.
type LocalSegment struct {
baseSegment
ptrLock sync.RWMutex // protects segmentPtr
ptrLock *state.LoadStateLock
ptr C.CSegmentInterface
// cached results, to avoid too many CGO calls
@ -242,10 +242,13 @@ func NewSegment(ctx context.Context,
return NewL0Segment(collection, segmentType, version, loadInfo)
}
var cSegType C.SegmentType
var locker *state.LoadStateLock
switch segmentType {
case SegmentTypeSealed:
cSegType = C.Sealed
locker = state.NewLoadStateLock(state.LoadStateOnlyMeta)
case SegmentTypeGrowing:
locker = state.NewLoadStateLock(state.LoadStateDataLoaded)
cSegType = C.Growing
default:
return nil, fmt.Errorf("illegal segment type %d when create segment %d", segmentType, loadInfo.GetSegmentID())
@ -275,6 +278,7 @@ func NewSegment(ctx context.Context,
segment := &LocalSegment{
baseSegment: newBaseSegment(collection, segmentType, version, loadInfo),
ptrLock: locker,
ptr: newPtr,
lastDeltaTimestamp: atomic.NewUint64(0),
fields: typeutil.NewConcurrentMap[int64, *FieldInfo](),
@ -308,11 +312,14 @@ func NewSegmentV2(
}
var segmentPtr C.CSegmentInterface
var status C.CStatus
var locker *state.LoadStateLock
switch segmentType {
case SegmentTypeSealed:
status = C.NewSegment(collection.collectionPtr, C.Sealed, C.int64_t(loadInfo.GetSegmentID()), &segmentPtr)
locker = state.NewLoadStateLock(state.LoadStateOnlyMeta)
case SegmentTypeGrowing:
status = C.NewSegment(collection.collectionPtr, C.Growing, C.int64_t(loadInfo.GetSegmentID()), &segmentPtr)
locker = state.NewLoadStateLock(state.LoadStateDataLoaded)
default:
return nil, fmt.Errorf("illegal segment type %d when create segment %d", segmentType, loadInfo.GetSegmentID())
}
@ -338,6 +345,7 @@ func NewSegmentV2(
segment := &LocalSegment{
baseSegment: newBaseSegment(collection, segmentType, version, loadInfo),
ptrLock: locker,
ptr: segmentPtr,
lastDeltaTimestamp: atomic.NewUint64(0),
fields: typeutil.NewConcurrentMap[int64, *FieldInfo](),
@ -355,41 +363,31 @@ func NewSegmentV2(
return segment, nil
}
func (s *LocalSegment) isValid() bool {
return s.ptr != nil
}
// RLock acquires the `ptrLock` and returns true if the pointer is valid
// PinIfNotReleased pins the segment, preventing it from being released; it returns an error if the segment is already released.
// Provide ONLY the pin/unpin operations,
// don't make `ptrLock` public to avoid abuse of the mutex.
func (s *LocalSegment) RLock() error {
s.ptrLock.RLock()
if !s.isValid() {
s.ptrLock.RUnlock()
func (s *LocalSegment) PinIfNotReleased() error {
if !s.ptrLock.PinIfNotReleased() {
return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
}
return nil
}
func (s *LocalSegment) RUnlock() {
s.ptrLock.RUnlock()
func (s *LocalSegment) Unpin() {
s.ptrLock.Unpin()
}
func (s *LocalSegment) InsertCount() int64 {
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
return s.insertCount.Load()
}
func (s *LocalSegment) RowNum() int64 {
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
if !s.isValid() {
// if segment data is not loaded, return 0 (it may not be loaded yet, or was released by lru)
if !s.ptrLock.RLockIf(state.IsDataLoaded) {
log.Warn("segment is not valid", zap.Int64("segmentID", s.ID()))
return 0
}
defer s.ptrLock.RUnlock()
rowNum := s.rowNum.Load()
if rowNum < 0 {
@ -406,12 +404,10 @@ func (s *LocalSegment) RowNum() int64 {
}
func (s *LocalSegment) MemSize() int64 {
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
if !s.isValid() {
if !s.ptrLock.RLockIf(state.IsNotReleased) {
return 0
}
defer s.ptrLock.RUnlock()
memSize := s.memSize.Load()
if memSize < 0 {
@ -449,11 +445,11 @@ func (s *LocalSegment) ExistIndex(fieldID int64) bool {
}
func (s *LocalSegment) HasRawData(fieldID int64) bool {
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
if !s.isValid() {
if !s.ptrLock.RLockIf(state.IsNotReleased) {
return false
}
defer s.ptrLock.RUnlock()
ret := C.HasRawData(s.ptr, C.int64_t(fieldID))
return bool(ret)
}
@ -482,12 +478,11 @@ func (s *LocalSegment) Search(ctx context.Context, searchReq *SearchRequest) (*S
zap.Int64("segmentID", s.ID()),
zap.String("segmentType", s.segmentType.String()),
)
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
if s.ptr == nil {
if !s.ptrLock.RLockIf(state.IsNotReleased) {
// TODO: check if the segment is readable but not released; too much related logic needs to be refactored.
return nil, merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
}
defer s.ptrLock.RUnlock()
traceCtx := ParseCTraceContext(ctx)
@ -520,12 +515,11 @@ func (s *LocalSegment) Search(ctx context.Context, searchReq *SearchRequest) (*S
}
func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) {
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
if s.ptr == nil {
if !s.ptrLock.RLockIf(state.IsNotReleased) {
// TODO: check if the segment is readable but not released; too much related logic needs to be refactored.
return nil, merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
}
defer s.ptrLock.RUnlock()
log := log.Ctx(ctx).With(
zap.Int64("collectionID", s.Collection()),
@ -616,13 +610,10 @@ func (s *LocalSegment) Insert(ctx context.Context, rowIDs []int64, timestamps []
if s.Type() != SegmentTypeGrowing {
return fmt.Errorf("unexpected segmentType when segmentInsert, segmentType = %s", s.segmentType.String())
}
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
if s.ptr == nil {
if !s.ptrLock.RLockIf(state.IsNotReleased) {
return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
}
defer s.ptrLock.RUnlock()
offset, err := s.preInsert(ctx, len(rowIDs))
if err != nil {
@ -676,13 +667,10 @@ func (s *LocalSegment) Delete(ctx context.Context, primaryKeys []storage.Primary
if len(primaryKeys) == 0 {
return nil
}
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
if s.ptr == nil {
if !s.ptrLock.RLockIf(state.IsNotReleased) {
return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
}
defer s.ptrLock.RUnlock()
cOffset := C.int64_t(0) // depre
cSize := C.int64_t(len(primaryKeys))
@ -743,12 +731,10 @@ func (s *LocalSegment) Delete(ctx context.Context, primaryKeys []storage.Primary
// -------------------------------------------------------------------------------------- interfaces for sealed segment
func (s *LocalSegment) LoadMultiFieldData(ctx context.Context, rowCount int64, fields []*datapb.FieldBinlog) error {
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
if s.ptr == nil {
if !s.ptrLock.RLockIf(state.IsNotReleased) {
return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
}
defer s.ptrLock.RUnlock()
log := log.Ctx(ctx).With(
zap.Int64("collectionID", s.Collection()),
@ -847,16 +833,14 @@ func (s *LocalSegment) LoadFieldData(ctx context.Context, fieldID int64, rowCoun
s.loadStatus.Store(string(options.LoadStatus))
s.ptrLock.RLock()
if !s.ptrLock.RLockIf(state.IsNotReleased) {
return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
}
defer s.ptrLock.RUnlock()
ctx, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, fmt.Sprintf("LoadFieldData-%d-%d", s.ID(), fieldID))
defer sp.End()
if s.ptr == nil {
return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
}
log := log.Ctx(ctx).With(
zap.Int64("collectionID", s.Collection()),
zap.Int64("partitionID", s.Partition()),
@ -1011,12 +995,10 @@ func (s *LocalSegment) LoadDeltaData2(ctx context.Context, schema *schemapb.Coll
}
func (s *LocalSegment) AddFieldDataInfo(ctx context.Context, rowCount int64, fields []*datapb.FieldBinlog) error {
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
if s.ptr == nil {
if !s.ptrLock.RLockIf(state.IsNotReleased) {
return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
}
defer s.ptrLock.RUnlock()
log := log.Ctx(ctx).With(
zap.Int64("collectionID", s.Collection()),
@ -1066,12 +1048,10 @@ func (s *LocalSegment) LoadDeltaData(ctx context.Context, deltaData *storage.Del
pks, tss := deltaData.Pks, deltaData.Tss
rowNum := deltaData.RowCount
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
if s.ptr == nil {
if !s.ptrLock.RLockIf(state.IsNotReleased) {
return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
}
defer s.ptrLock.RUnlock()
log := log.Ctx(ctx).With(
zap.Int64("collectionID", s.Collection()),
@ -1224,12 +1204,10 @@ func (s *LocalSegment) UpdateIndexInfo(ctx context.Context, indexInfo *querypb.F
zap.Int64("segmentID", s.ID()),
zap.Int64("fieldID", indexInfo.FieldID),
)
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
if s.ptr == nil {
if !s.ptrLock.RLockIf(state.IsNotReleased) {
return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
}
defer s.ptrLock.RUnlock()
var status C.CStatus
GetDynamicPool().Submit(func() (any, error) {
@ -1263,12 +1241,10 @@ func (s *LocalSegment) WarmupChunkCache(ctx context.Context, fieldID int64) {
zap.Int64("segmentID", s.ID()),
zap.Int64("fieldID", fieldID),
)
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
if s.ptr == nil {
if !s.ptrLock.RLockIf(state.IsNotReleased) {
return
}
defer s.ptrLock.RUnlock()
var status C.CStatus
@ -1287,11 +1263,11 @@ func (s *LocalSegment) WarmupChunkCache(ctx context.Context, fieldID int64) {
}).Await()
case "async":
GetLoadPool().Submit(func() (any, error) {
s.ptrLock.RLock()
defer s.ptrLock.RUnlock()
if s.ptr == nil {
if !s.ptrLock.RLockIf(state.IsNotReleased) {
return nil, nil
}
defer s.ptrLock.RUnlock()
cFieldID := C.int64_t(fieldID)
status = C.WarmupChunkCache(s.ptr, cFieldID)
if err := HandleCStatus(ctx, &status, ""); err != nil {
@ -1357,26 +1333,17 @@ func (s *LocalSegment) Release(opts ...releaseOption) {
for _, opt := range opts {
opt(options)
}
/*
void
deleteSegment(CSegmentInterface segment);
*/
var ptr C.CSegmentInterface
// wait all read ops finished
s.ptrLock.Lock()
ptr = s.ptr
s.ptr = nil
if options.Scope == ReleaseScopeData {
s.loadStatus.Store(string(LoadStatusMeta))
}
s.ptrLock.Unlock()
if ptr == nil {
stateLockGuard := s.startRelease(options.Scope)
if stateLockGuard == nil { // release is already done.
return
}
// release will never fail
defer stateLockGuard.Done(nil)
// wait all read ops finished
ptr := s.ptr
if options.Scope == ReleaseScopeData {
s.loadStatus.Store(string(LoadStatusMeta))
C.ClearSegmentData(ptr)
return
}
@ -1406,3 +1373,20 @@ func (s *LocalSegment) Release(opts ...releaseOption) {
zap.Int64("insertCount", s.InsertCount()),
)
}
// StartLoadData starts the loading process of the segment.
func (s *LocalSegment) StartLoadData() (state.LoadStateLockGuard, error) {
return s.ptrLock.StartLoadData()
}
// startRelease starts the releasing process of the segment.
func (s *LocalSegment) startRelease(scope ReleaseScope) state.LoadStateLockGuard {
switch scope {
case ReleaseScopeData:
return s.ptrLock.StartReleaseData()
case ReleaseScopeAll:
return s.ptrLock.StartReleaseAll()
default:
panic(fmt.Sprintf("unexpected release scope %d", scope))
}
}

View File

@ -62,8 +62,10 @@ type Segment interface {
Level() datapb.SegmentLevel
LoadStatus() LoadStatus
LoadInfo() *querypb.SegmentLoadInfo
RLock() error
RUnlock()
// PinIfNotReleased pins the segment to prevent it from being released
PinIfNotReleased() error
// Unpin unpins the segment, allowing it to be released
Unpin()
// Stats related
// InsertCount returns the number of inserted rows, not effected by deletion

View File

@ -68,11 +68,11 @@ func NewL0Segment(collection *Collection,
return segment, nil
}
func (s *L0Segment) RLock() error {
func (s *L0Segment) PinIfNotReleased() error {
return nil
}
func (s *L0Segment) RUnlock() {}
func (s *L0Segment) Unpin() {}
func (s *L0Segment) InsertCount() int64 {
return 0

View File

@ -375,7 +375,23 @@ func (loader *segmentLoaderV2) LoadSegment(ctx context.Context,
segment *LocalSegment,
loadInfo *querypb.SegmentLoadInfo,
loadstatus LoadStatus,
) error {
) (err error) {
// TODO: we should create a transaction-like API for loading a segment on the segment interface,
// instead of doing so much work in the segment loader.
stateLockGuard, err := segment.StartLoadData()
// the segment cannot start loading now.
if err != nil {
return err
}
defer func() {
// a nil stateLockGuard means the segment data is already loaded.
// TODO: in that case we should not call LoadSegment at all,
// but the current Load path is not clear enough to do an actual state transition; keep the previous logic to avoid introducing bugs.
if stateLockGuard != nil {
stateLockGuard.Done(err)
}
}()
log := log.Ctx(ctx).With(
zap.Int64("collectionID", segment.Collection()),
zap.Int64("partitionID", segment.Partition()),
@ -1008,7 +1024,23 @@ func (loader *segmentLoader) LoadSegment(ctx context.Context,
segment *LocalSegment,
loadInfo *querypb.SegmentLoadInfo,
loadStatus LoadStatus,
) error {
) (err error) {
// TODO: we should create a transaction-like API for loading a segment on the segment interface,
// instead of doing so much work in the segment loader.
stateLockGuard, err := segment.StartLoadData()
// the segment cannot start loading now.
if err != nil {
return err
}
defer func() {
// a nil stateLockGuard means the segment data is already loaded.
// TODO: in that case we should not call LoadSegment at all,
// but the current Load path is not clear enough to do an actual state transition; keep the previous logic to avoid introducing bugs.
if stateLockGuard != nil {
stateLockGuard.Done(err)
}
}()
log := log.Ctx(ctx).With(
zap.Int64("collectionID", segment.Collection()),
zap.Int64("partitionID", segment.Partition()),

View File

@ -94,10 +94,13 @@ func (suite *SegmentSuite) SetupTest() {
suite.chunkManager,
)
suite.Require().NoError(err)
g, err := suite.sealed.(*LocalSegment).StartLoadData()
suite.Require().NoError(err)
for _, binlog := range binlogs {
err = suite.sealed.(*LocalSegment).LoadFieldData(ctx, binlog.FieldID, int64(msgLength), binlog)
suite.Require().NoError(err)
}
g.Done(nil)
suite.growing, err = NewSegment(ctx,
suite.collection,
@ -198,9 +201,7 @@ func (suite *SegmentSuite) TestSegmentReleased() {
sealed := suite.sealed.(*LocalSegment)
sealed.ptrLock.RLock()
suite.False(sealed.isValid())
sealed.ptrLock.RUnlock()
suite.False(sealed.ptrLock.PinIfNotReleased())
suite.EqualValues(0, sealed.RowNum())
suite.EqualValues(0, sealed.MemSize())
suite.False(sealed.HasRawData(101))

View File

@ -0,0 +1,205 @@
package state
import (
"fmt"
"sync"
"github.com/cockroachdb/errors"
"go.uber.org/atomic"
)
type loadStateEnum int
// LoadState represents the state transitions of a segment.
// LoadStateOnlyMeta: segment is created with meta, but not loaded.
// LoadStateDataLoading: segment is loading data.
// LoadStateDataLoaded: segment is fully loaded, ready to be searched or queried.
// LoadStateDataReleasing: segment is releasing data.
// LoadStateReleased: segment is released.
// LoadStateOnlyMeta -> LoadStateDataLoading -> LoadStateDataLoaded -> LoadStateDataReleasing -> (LoadStateReleased or LoadStateOnlyMeta)
const (
LoadStateOnlyMeta loadStateEnum = iota
LoadStateDataLoading // Only one goroutine will access the segment while it is loading.
LoadStateDataLoaded
LoadStateDataReleasing // Only one goroutine will access the segment while it is releasing.
LoadStateReleased
)
// String returns the human-readable name of the load state.
func (ls loadStateEnum) String() string {
switch ls {
case LoadStateOnlyMeta:
return "meta"
case LoadStateDataLoading:
return "loading-data"
case LoadStateDataLoaded:
return "loaded"
case LoadStateDataReleasing:
return "releasing-data"
case LoadStateReleased:
return "released"
default:
return "unknown"
}
}
// NewLoadStateLock creates a LoadStateLock with the given initial state.
func NewLoadStateLock(state loadStateEnum) *LoadStateLock {
if state != LoadStateOnlyMeta && state != LoadStateDataLoaded {
panic(fmt.Sprintf("invalid state for construction of LoadStateLock, %s", state.String()))
}
mu := &sync.RWMutex{}
return &LoadStateLock{
mu: mu,
cv: sync.Cond{L: mu},
state: state,
refCnt: atomic.NewInt32(0),
}
}
// LoadStateLock guards the load state of a segment.
type LoadStateLock struct {
mu *sync.RWMutex
cv sync.Cond
state loadStateEnum
refCnt *atomic.Int32
// ReleaseAll can be called only when refCnt is 0.
// We need it to be modified while the lock is held.
}
// RLockIf read-locks the segment if the given predicate holds for the current state.
func (ls *LoadStateLock) RLockIf(pred StatePredicate) bool {
ls.mu.RLock()
if !pred(ls.state) {
ls.mu.RUnlock()
return false
}
return true
}
// RUnlock unlocks the segment.
func (ls *LoadStateLock) RUnlock() {
ls.mu.RUnlock()
}
// PinIfNotReleased pins the segment in memory, preventing ReleaseAll from releasing it.
func (ls *LoadStateLock) PinIfNotReleased() bool {
ls.mu.RLock()
defer ls.mu.RUnlock()
if ls.state == LoadStateReleased {
return false
}
ls.refCnt.Inc()
return true
}
// Unpin unpins the segment so that it can be released by ReleaseAll.
func (ls *LoadStateLock) Unpin() {
ls.mu.RLock()
defer ls.mu.RUnlock()
newCnt := ls.refCnt.Dec()
if newCnt < 0 {
panic("unpin more than pin")
}
if newCnt == 0 {
// notify ReleaseAll to release segment if refcnt is zero.
ls.cv.Broadcast()
}
}
// StartLoadData starts loading segment data.
// It fails fast if the segment is not in LoadStateOnlyMeta.
func (ls *LoadStateLock) StartLoadData() (LoadStateLockGuard, error) {
// only meta can be loaded.
ls.cv.L.Lock()
defer ls.cv.L.Unlock()
if ls.state == LoadStateDataLoaded {
return nil, nil
}
if ls.state != LoadStateOnlyMeta {
return nil, errors.New("segment is not in LoadStateOnlyMeta, cannot start loading data")
}
ls.state = LoadStateDataLoading
ls.cv.Broadcast()
return newLoadStateLockGuard(ls, LoadStateOnlyMeta, LoadStateDataLoaded), nil
}
// StartReleaseData waits until the segment is releasable and then starts releasing the segment data.
func (ls *LoadStateLock) StartReleaseData() (g LoadStateLockGuard) {
ls.cv.L.Lock()
defer ls.cv.L.Unlock()
ls.waitUntilCanReleaseData()
switch ls.state {
case LoadStateDataLoaded:
ls.state = LoadStateDataReleasing
ls.cv.Broadcast()
return newLoadStateLockGuard(ls, LoadStateDataLoaded, LoadStateOnlyMeta)
case LoadStateOnlyMeta:
// already transit to target state, do nothing.
return nil
case LoadStateReleased:
// do nothing for empty segment.
return nil
default:
panic(fmt.Sprintf("unreachable code: invalid state when releasing data, %s", ls.state.String()))
}
}
// StartReleaseAll waits until the segment is releasable and then starts releasing the whole segment.
func (ls *LoadStateLock) StartReleaseAll() (g LoadStateLockGuard) {
ls.cv.L.Lock()
defer ls.cv.L.Unlock()
ls.waitUntilCanReleaseAll()
switch ls.state {
case LoadStateDataLoaded:
ls.state = LoadStateReleased
ls.cv.Broadcast()
return newNopLoadStateLockGuard()
case LoadStateOnlyMeta:
ls.state = LoadStateReleased
ls.cv.Broadcast()
return newNopLoadStateLockGuard()
case LoadStateReleased:
// already transit to target state, do nothing.
return nil
default:
panic(fmt.Sprintf("unreachable code: invalid state when releasing data, %s", ls.state.String()))
}
}
// waitUntilCanReleaseData waits until the segment's data can be released.
func (ls *LoadStateLock) waitUntilCanReleaseData() {
state := ls.state
for state != LoadStateDataLoaded && state != LoadStateOnlyMeta && state != LoadStateReleased {
ls.cv.Wait()
state = ls.state
}
}
// waitUntilCanReleaseAll waits until the whole segment can be released (no pins held).
func (ls *LoadStateLock) waitUntilCanReleaseAll() {
state := ls.state
for (state != LoadStateDataLoaded && state != LoadStateOnlyMeta && state != LoadStateReleased) || ls.refCnt.Load() != 0 {
ls.cv.Wait()
state = ls.state
}
}
type StatePredicate func(state loadStateEnum) bool
// IsNotReleased checks if the segment is not released.
func IsNotReleased(state loadStateEnum) bool {
return state != LoadStateReleased
}
// IsDataLoaded checks if the segment data is fully loaded.
func IsDataLoaded(state loadStateEnum) bool {
return state == LoadStateDataLoaded
}

View File

@ -0,0 +1,45 @@
package state
type LoadStateLockGuard interface {
Done(err error)
}
// newLoadStateLockGuard creates a loadStateLockGuard.
func newLoadStateLockGuard(ls *LoadStateLock, original loadStateEnum, target loadStateEnum) *loadStateLockGuard {
return &loadStateLockGuard{
ls: ls,
original: original,
target: target,
}
}
// loadStateLockGuard is a guard to update the state of LoadState.
type loadStateLockGuard struct {
ls *LoadStateLock
original loadStateEnum
target loadStateEnum
}
// Done moves the state to the target state, or rolls it back to the original state if err is not nil.
func (g *loadStateLockGuard) Done(err error) {
g.ls.cv.L.Lock()
g.ls.cv.Broadcast()
defer g.ls.cv.L.Unlock()
if err != nil {
g.ls.state = g.original
return
}
g.ls.state = g.target
}
// newNopLoadStateLockGuard creates a LoadStateLockGuard that does nothing.
func newNopLoadStateLockGuard() LoadStateLockGuard {
return nopLockGuard{}
}
// nopLockGuard is a guard that does nothing.
type nopLockGuard struct{}
// Done does nothing.
func (nopLockGuard) Done(err error) {}

View File

@ -0,0 +1,224 @@
package state
import (
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
)
func TestLoadStateLoadData(t *testing.T) {
l := NewLoadStateLock(LoadStateOnlyMeta)
// Test Load Data, roll back
g, err := l.StartLoadData()
assert.NoError(t, err)
assert.NotNil(t, g)
assert.Equal(t, LoadStateDataLoading, l.state)
g.Done(errors.New("test"))
assert.Equal(t, LoadStateOnlyMeta, l.state)
// Test Load Data, success
g, err = l.StartLoadData()
assert.NoError(t, err)
assert.NotNil(t, g)
assert.Equal(t, LoadStateDataLoading, l.state)
g.Done(nil)
assert.Equal(t, LoadStateDataLoaded, l.state)
// nothing to do with loaded.
g, err = l.StartLoadData()
assert.NoError(t, err)
assert.Nil(t, g)
for _, s := range []loadStateEnum{
LoadStateDataLoading,
LoadStateDataReleasing,
LoadStateReleased,
} {
l.state = s
g, err = l.StartLoadData()
assert.Error(t, err)
assert.Nil(t, g)
}
}
func TestStartReleaseData(t *testing.T) {
l := NewLoadStateLock(LoadStateOnlyMeta)
// Test Release Data, nothing to do on only meta.
g := l.StartReleaseData()
assert.Nil(t, g)
assert.Equal(t, LoadStateOnlyMeta, l.state)
// roll back: a failed release restores the loaded state.
l.state = LoadStateDataLoaded
g = l.StartReleaseData()
assert.Equal(t, LoadStateDataReleasing, l.state)
assert.NotNil(t, g)
g.Done(errors.New("test"))
assert.Equal(t, LoadStateDataLoaded, l.state)
// success
l.state = LoadStateDataLoaded
g = l.StartReleaseData()
assert.Equal(t, LoadStateDataReleasing, l.state)
assert.NotNil(t, g)
g.Done(nil)
assert.Equal(t, LoadStateOnlyMeta, l.state)
// nothing to do on released
l.state = LoadStateReleased
g = l.StartReleaseData()
assert.Nil(t, g)
// test blocking.
l.state = LoadStateOnlyMeta
g, err := l.StartLoadData()
assert.NoError(t, err)
ch := make(chan struct{})
go func() {
g := l.StartReleaseData()
assert.NotNil(t, g)
g.Done(nil)
close(ch)
}()
// should be blocked because loading is still in progress.
select {
case <-ch:
t.Errorf("should be blocked")
case <-time.After(500 * time.Millisecond):
}
// loading finished.
g.Done(nil)
// release can be started.
select {
case <-ch:
case <-time.After(500 * time.Millisecond):
t.Errorf("should not be blocked")
}
assert.Equal(t, LoadStateOnlyMeta, l.state)
}
func TestStartReleaseAll(t *testing.T) {
l := NewLoadStateLock(LoadStateOnlyMeta)
// Test Release All, nothing to do on only meta.
g := l.StartReleaseAll()
assert.NotNil(t, g)
assert.Equal(t, LoadStateReleased, l.state)
g.Done(nil)
assert.Equal(t, LoadStateReleased, l.state)
// roll back: a failed release restores the loaded state.
l.state = LoadStateDataLoaded
g = l.StartReleaseData()
assert.Equal(t, LoadStateDataReleasing, l.state)
assert.NotNil(t, g)
g.Done(errors.New("test"))
assert.Equal(t, LoadStateDataLoaded, l.state)
// success
l.state = LoadStateDataLoaded
g = l.StartReleaseAll()
assert.Equal(t, LoadStateReleased, l.state)
assert.NotNil(t, g)
g.Done(nil)
assert.Equal(t, LoadStateReleased, l.state)
// nothing to do on released
l.state = LoadStateReleased
g = l.StartReleaseAll()
assert.Nil(t, g)
// test blocking.
l.state = LoadStateOnlyMeta
g, err := l.StartLoadData()
assert.NoError(t, err)
ch := make(chan struct{})
go func() {
g := l.StartReleaseAll()
assert.NotNil(t, g)
g.Done(nil)
close(ch)
}()
// should be blocked because loading is still in progress.
select {
case <-ch:
t.Errorf("should be blocked")
case <-time.After(500 * time.Millisecond):
}
// loading finished.
g.Done(nil)
// release can be started.
select {
case <-ch:
case <-time.After(500 * time.Millisecond):
t.Errorf("should not be blocked")
}
assert.Equal(t, LoadStateReleased, l.state)
}
func TestRLock(t *testing.T) {
l := NewLoadStateLock(LoadStateOnlyMeta)
assert.True(t, l.RLockIf(IsNotReleased))
l.RUnlock()
assert.False(t, l.RLockIf(IsDataLoaded))
l = NewLoadStateLock(LoadStateDataLoaded)
assert.True(t, l.RLockIf(IsNotReleased))
l.RUnlock()
assert.True(t, l.RLockIf(IsDataLoaded))
l.RUnlock()
l = NewLoadStateLock(LoadStateOnlyMeta)
l.StartReleaseAll().Done(nil)
assert.False(t, l.RLockIf(IsNotReleased))
assert.False(t, l.RLockIf(IsDataLoaded))
}
func TestPin(t *testing.T) {
l := NewLoadStateLock(LoadStateOnlyMeta)
assert.True(t, l.PinIfNotReleased())
l.Unpin()
l.StartReleaseAll().Done(nil)
assert.False(t, l.PinIfNotReleased())
l = NewLoadStateLock(LoadStateDataLoaded)
assert.True(t, l.PinIfNotReleased())
ch := make(chan struct{})
go func() {
l.StartReleaseAll().Done(nil)
close(ch)
}()
select {
case <-ch:
t.Errorf("should be blocked")
case <-time.After(500 * time.Millisecond):
}
// should be blocked until refcnt is zero.
assert.True(t, l.PinIfNotReleased())
l.Unpin()
select {
case <-ch:
t.Errorf("should be blocked")
case <-time.After(500 * time.Millisecond):
}
l.Unpin()
<-ch
assert.Panics(t, func() {
// unpinning more times than pinned should panic
l.Unpin()
})
}