Fix LRU cache may panic on double close (#20854)

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/20861/head
congqixia 2022-11-29 09:47:14 +08:00 committed by GitHub
parent b5f178f22b
commit 938c09679c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 125 additions and 20 deletions

View File

@ -18,24 +18,25 @@ package cache
import (
"container/list"
"context"
"errors"
"fmt"
"sync"
)
// LRU generic utility for lru cache.
type LRU[K comparable, V any] struct {
	ctx       context.Context     // lifetime context consumed by the evicted worker goroutine
	cancel    context.CancelFunc  // cancels ctx
	evictList *list.List          // recency-ordered list; Back() holds the oldest entry (see GetOldest)
	items     map[K]*list.Element // key -> list element; element Value is *entry[K, V]
	capacity  int                 // maximum item count; must be positive (enforced by NewLRU)
	onEvicted func(k K, v V)      // callback executed for each evicted item
	m         sync.RWMutex        // guards items, evictList, capacity and stats
	evictedCh chan *entry[K, V]   // buffered (16) queue of evicted entries drained by evictedWorker
	closeCh   chan struct{}       // closed on shutdown; closed() probes it to reject further operations
	closeOnce sync.Once           // makes Close idempotent — the double-close panic fix
	stats     *Stats              // read/write/hit/eviction counters
}
// Stats is the model for cache statistics.
type Stats struct {
hitCount float32
evictedCount float32
@ -43,6 +44,7 @@ type Stats struct {
writeCount float32
}
// String implement stringer for printing.
func (s *Stats) String() string {
var hitRatio float32
var evictedRatio float32
@ -59,29 +61,30 @@ type entry[K comparable, V any] struct {
value V
}
// NewLRU creates a LRU cache with the provided capacity and `onEvicted` function.
// `onEvicted` is executed for every item that is chosen to be evicted.
// A background worker goroutine is started to deliver eviction callbacks;
// it is stopped by Close. An error is returned for a non-positive capacity.
func NewLRU[K comparable, V any](capacity int, onEvicted func(k K, v V)) (*LRU[K, V], error) {
	if capacity <= 0 {
		return nil, errors.New("cache size must be positive")
	}

	workerCtx, stopWorker := context.WithCancel(context.Background())
	cache := &LRU[K, V]{
		ctx:       workerCtx,
		cancel:    stopWorker,
		capacity:  capacity,
		evictList: list.New(),
		items:     make(map[K]*list.Element),
		onEvicted: onEvicted,
		evictedCh: make(chan *entry[K, V], 16),
		closeCh:   make(chan struct{}),
		stats:     &Stats{},
	}

	go cache.evictedWorker()
	return cache, nil
}
// evictedWorker executes onEvicted function for each evicted items.
func (c *LRU[K, V]) evictedWorker() {
for {
select {
case <-c.ctx.Done():
case <-c.closeCh:
return
case e, ok := <-c.evictedCh:
if ok {
@ -93,9 +96,27 @@ func (c *LRU[K, V]) evictedWorker() {
}
}
// closed reports whether the cache has been shut down via Close.
func (c *LRU[K, V]) closed() bool {
	// Non-blocking probe: receive succeeds immediately once closeCh is closed.
	select {
	case <-c.closeCh:
		return true
	default:
	}
	return false
}
// Add puts an item into cache.
func (c *LRU[K, V]) Add(key K, value V) {
c.m.Lock()
defer c.m.Unlock()
if c.closed() {
// evict since cache closed
c.onEvicted(key, value)
return
}
c.stats.writeCount++
if e, ok := c.items[key]; ok {
c.evictList.MoveToFront(e)
@ -120,9 +141,17 @@ func (c *LRU[K, V]) Add(key K, value V) {
}
}
// Get returns value for provided key.
func (c *LRU[K, V]) Get(key K) (value V, ok bool) {
c.m.RLock()
defer c.m.RUnlock()
var zeroV V
if c.closed() {
// cache closed, returns nothing
return zeroV, false
}
c.stats.readCount++
if e, ok := c.items[key]; ok {
c.stats.hitCount++
@ -131,13 +160,18 @@ func (c *LRU[K, V]) Get(key K) (value V, ok bool) {
return kv.value, true
}
var zeroV V
return zeroV, false
}
// Remove removes item associated with provided key.
func (c *LRU[K, V]) Remove(key K) {
c.m.Lock()
defer c.m.Unlock()
if c.closed() {
return
}
if e, ok := c.items[key]; ok {
c.evictList.Remove(e)
kv := e.Value.(*entry[K, V])
@ -148,16 +182,24 @@ func (c *LRU[K, V]) Remove(key K) {
}
}
// Contains reports whether an item with the provided key exists in the cache.
// A closed cache always reports false.
func (c *LRU[K, V]) Contains(key K) bool {
	c.m.RLock()
	defer c.m.RUnlock()

	if c.closed() {
		return false
	}
	_, found := c.items[key]
	return found
}
// Keys returns all the keys exist in cache.
func (c *LRU[K, V]) Keys() []K {
c.m.RLock()
defer c.m.RUnlock()
if c.closed() {
return nil
}
keys := make([]K, len(c.items))
i := 0
for ent := c.evictList.Back(); ent != nil; ent = ent.Prev() {
@ -167,16 +209,22 @@ func (c *LRU[K, V]) Keys() []K {
return keys
}
// Len returns the number of items currently held in the cache.
// A closed cache reports zero.
func (c *LRU[K, V]) Len() int {
	c.m.RLock()
	defer c.m.RUnlock()

	if !c.closed() {
		return c.evictList.Len()
	}
	return 0
}
// Capacity returns the cache capacity.
// Reads under the lock: Resize mutates c.capacity while holding c.m, so an
// unlocked read here would be a data race (flagged by -race).
func (c *LRU[K, V]) Capacity() int {
	c.m.RLock()
	defer c.m.RUnlock()
	return c.capacity
}
// Purge removes all items and put them into evictedCh.
func (c *LRU[K, V]) Purge() {
c.m.Lock()
defer c.m.Unlock()
@ -189,9 +237,14 @@ func (c *LRU[K, V]) Purge() {
c.evictList.Init()
}
// Resize changes the capacity of cache.
func (c *LRU[K, V]) Resize(capacity int) int {
c.m.Lock()
defer c.m.Unlock()
if c.closed() {
return 0
}
c.capacity = capacity
if capacity >= c.evictList.Len() {
return 0
@ -211,35 +264,48 @@ func (c *LRU[K, V]) Resize(capacity int) int {
return diff
}
// GetOldest returns the least-recently-used item in the cache.
// The boolean result is false when the cache is empty or already closed,
// in which case the key/value results are zero values.
// Bug fix: the zero-value variables were declared twice in this function
// (once before the closed() guard and again before the final return), which
// is a compile error ("zeroK redeclared"); a single declaration serves both
// return paths.
func (c *LRU[K, V]) GetOldest() (K, V, bool) {
	c.m.RLock()
	defer c.m.RUnlock()
	var (
		zeroK K
		zeroV V
	)
	if c.closed() {
		return zeroK, zeroV, false
	}
	if ent := c.evictList.Back(); ent != nil {
		kv := ent.Value.(*entry[K, V])
		return kv.key, kv.value, true
	}
	return zeroK, zeroV, false
}
// Close cleans up the cache resources.
func (c *LRU[K, V]) Close() {
c.Purge()
c.cancel()
remain := len(c.evictedCh)
for i := 0; i < remain; i++ {
e, ok := <-c.evictedCh
if ok {
c.closeOnce.Do(func() {
// fetch lock to
// - wait on-going operations done
// - block incoming operations
c.m.Lock()
close(c.closeCh)
c.m.Unlock()
// execute purge in a goroutine, otherwise Purge may block forever putting evictedCh
go func() {
c.Purge()
close(c.evictedCh)
}()
for e := range c.evictedCh {
c.onEvicted(e.key, e.value)
}
}
close(c.evictedCh)
})
}
// Stats returns cache statistics.
// NOTE(review): this hands out the shared *Stats pointer without holding c.m,
// while other methods increment the counters under the lock — presumably an
// accepted benign race for metrics reporting; confirm if exact reads matter.
func (c *LRU[K, V]) Stats() *Stats {
	return c.stats
}

View File

@ -22,6 +22,7 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewLRU(t *testing.T) {
@ -399,3 +400,41 @@ func TestLRU_Resize(t *testing.T) {
return atomic.LoadInt32(&evicted) == 1
}, 1*time.Second, 100*time.Millisecond)
}
// TestLRU_closed verifies that every public method degrades gracefully after
// Close, and that calling Close a second time does not panic.
func TestLRU_closed(t *testing.T) {
	var evictCount int32
	cache, err := NewLRU(2, func(string, string) { atomic.AddInt32(&evictCount, 1) })
	require.NoError(t, err)

	cache.Close()

	// Adding to a closed cache evicts the value immediately.
	cache.Add("testKey", "testValue")
	assert.Equal(t, int32(1), evictCount)

	// Lookups find nothing.
	_, found := cache.Get("testKey")
	assert.False(t, found)

	// Remove is a no-op and must not panic.
	assert.NotPanics(t, func() {
		cache.Remove("testKey")
	})

	assert.False(t, cache.Contains("testKey"))
	assert.Nil(t, cache.Keys())
	assert.Equal(t, 0, cache.Len())
	assert.Equal(t, 0, cache.Resize(1))
	assert.Equal(t, 2, cache.Capacity())

	_, _, found = cache.GetOldest()
	assert.False(t, found)

	assert.NotPanics(t, func() {
		cache.Close()
	}, "double close")
}