Support an analogous Java Guava cache implementation (#20831)

Signed-off-by: yun.zhang <yun.zhang@zilliz.com>

Branch: pull/20845/head
Author: jaime, 2022-12-05 20:29:18 +08:00 (committed by GitHub)
Parent: 243d8cff82
Commit: 548e90ec68
15 changed files with 2266 additions and 824 deletions


@ -40,7 +40,7 @@ var (
type VectorChunkManager struct {
cacheStorage ChunkManager
vectorStorage ChunkManager
cache *cache.LRU[string, *mmap.ReaderAt]
cache cache.LoadingCache[string, *mmap.ReaderAt]
insertCodec *InsertCodec
@ -48,7 +48,6 @@ type VectorChunkManager struct {
cacheLimit int64
cacheSize int64
cacheSizeMutex sync.Mutex
fixSize bool // Prevent cache capacity from changing too frequently
}
var _ ChunkManager = (*VectorChunkManager)(nil)
@ -64,33 +63,52 @@ func NewVectorChunkManager(ctx context.Context, cacheStorage ChunkManager, vecto
cacheEnable: cacheEnable,
cacheLimit: cacheLimit,
}
if cacheEnable {
if cacheLimit <= 0 {
return nil, errors.New("cache limit must be positive if cacheEnable")
}
c, err := cache.NewLRU(defaultLocalCacheSize, func(k string, v *mmap.ReaderAt) {
size := v.Len()
err := v.Close()
if err != nil {
log.Error("Unmmap file failed", zap.Any("file", k))
}
err = cacheStorage.Remove(ctx, k)
if err != nil {
log.Error("cache storage remove file failed", zap.Any("file", k))
}
vcm.cacheSizeMutex.Lock()
vcm.cacheSize -= int64(size)
vcm.cacheSizeMutex.Unlock()
})
if err != nil {
return nil, err
}
vcm.cache = c
err := vcm.initCache(ctx)
if err != nil {
return nil, err
}
return vcm, nil
}
func (vcm *VectorChunkManager) initCache(ctx context.Context) error {
if !vcm.cacheEnable {
return nil
}
if vcm.cacheLimit <= 0 {
return errors.New("cache limit must be positive if cacheEnable")
}
loader := func(filePath string) (*mmap.ReaderAt, error) {
return vcm.readFile(ctx, filePath)
}
onRemoveFn := func(filePath string, v *mmap.ReaderAt) {
size := v.Len()
err := v.Close()
if err != nil {
log.Error("close mmap file failed", zap.Any("file", filePath))
}
err = vcm.cacheStorage.Remove(ctx, filePath)
if err != nil {
log.Error("cache storage remove file failed", zap.Any("file", filePath))
}
vcm.cacheSizeMutex.Lock()
vcm.cacheSize -= int64(size)
vcm.cacheSizeMutex.Unlock()
}
vcm.cache = cache.NewLoadingCache(loader,
cache.WithRemovalListener[string, *mmap.ReaderAt](onRemoveFn),
cache.WithMaximumSize[string, *mmap.ReaderAt](vcm.cacheLimit),
)
return nil
}
// For vector data, we download the vector file from storage and deserialize it,
// since it is stored in binlog format. Finally, we store the pure vector data in
// local storage as a cache.
@ -146,7 +164,7 @@ func (vcm *VectorChunkManager) Exist(ctx context.Context, filePath string) (bool
return vcm.vectorStorage.Exist(ctx, filePath)
}
func (vcm *VectorChunkManager) readWithCache(ctx context.Context, filePath string) ([]byte, error) {
func (vcm *VectorChunkManager) readFile(ctx context.Context, filePath string) (*mmap.ReaderAt, error) {
contents, err := vcm.vectorStorage.Read(ctx, filePath)
if err != nil {
return nil, err
@ -159,45 +177,31 @@ func (vcm *VectorChunkManager) readWithCache(ctx context.Context, filePath strin
if err != nil {
return nil, err
}
r, err := vcm.cacheStorage.Mmap(ctx, filePath)
if err != nil {
return nil, err
}
size, err := vcm.cacheStorage.Size(ctx, filePath)
if err != nil {
return nil, err
}
vcm.cacheSizeMutex.Lock()
vcm.cacheSize += size
vcm.cacheSize += int64(r.Len())
vcm.cacheSizeMutex.Unlock()
if !vcm.fixSize {
if vcm.cacheSize < vcm.cacheLimit {
if vcm.cache.Len() == vcm.cache.Capacity() {
newSize := float32(vcm.cache.Capacity()) * 1.25
vcm.cache.Resize(int(newSize))
}
} else {
// +1 is for add current value
vcm.cache.Resize(vcm.cache.Len() + 1)
vcm.fixSize = true
}
}
vcm.cache.Add(filePath, r)
return results, nil
return r, nil
}
// Read reads the pure vector data. If cached, it reads from local.
func (vcm *VectorChunkManager) Read(ctx context.Context, filePath string) ([]byte, error) {
if vcm.cacheEnable {
if r, ok := vcm.cache.Get(filePath); ok {
p := make([]byte, r.Len())
_, err := r.ReadAt(p, 0)
if err != nil {
return p, err
}
return p, nil
r, err := vcm.cache.Get(filePath)
if err != nil {
return nil, err
}
return vcm.readWithCache(ctx, filePath)
p := make([]byte, r.Len())
_, err = r.ReadAt(p, 0)
if err != nil {
return nil, err
}
return p, nil
}
contents, err := vcm.vectorStorage.Read(ctx, filePath)
if err != nil {
@ -238,7 +242,7 @@ func (vcm *VectorChunkManager) ListWithPrefix(ctx context.Context, prefix string
func (vcm *VectorChunkManager) Mmap(ctx context.Context, filePath string) (*mmap.ReaderAt, error) {
if vcm.cacheEnable && vcm.cache != nil {
if r, ok := vcm.cache.Get(filePath); ok {
if r, err := vcm.cache.Get(filePath); err == nil {
return r, nil
}
}
@ -252,19 +256,17 @@ func (vcm *VectorChunkManager) Reader(ctx context.Context, filePath string) (Fil
// ReadAt reads vector data at the given offset. If cached, it reads from local storage.
func (vcm *VectorChunkManager) ReadAt(ctx context.Context, filePath string, off int64, length int64) ([]byte, error) {
if vcm.cacheEnable {
if r, ok := vcm.cache.Get(filePath); ok {
p := make([]byte, length)
_, err := r.ReadAt(p, off)
if err != nil {
return nil, err
}
return p, nil
}
results, err := vcm.readWithCache(ctx, filePath)
r, err := vcm.cache.Get(filePath)
if err != nil {
return nil, err
}
return results[off : off+length], nil
p := make([]byte, length)
_, err = r.ReadAt(p, off)
if err != nil {
return nil, err
}
return p, nil
}
contents, err := vcm.vectorStorage.Read(ctx, filePath)
if err != nil {
@ -292,7 +294,7 @@ func (vcm *VectorChunkManager) Remove(ctx context.Context, filePath string) erro
return err
}
if vcm.cacheEnable {
vcm.cache.Remove(filePath)
vcm.cache.Invalidate(filePath)
}
return nil
}
@ -304,7 +306,7 @@ func (vcm *VectorChunkManager) MultiRemove(ctx context.Context, filePaths []stri
}
if vcm.cacheEnable {
for _, p := range filePaths {
vcm.cache.Remove(p)
vcm.cache.Invalidate(p)
}
}
return nil
@ -321,7 +323,7 @@ func (vcm *VectorChunkManager) RemoveWithPrefix(ctx context.Context, prefix stri
return err
}
for _, p := range filePaths {
vcm.cache.Remove(p)
vcm.cache.Invalidate(p)
}
}
return nil

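
The pattern above replaces the old probe-then-fill flow (LRU Get returning (value, ok), plus readWithCache on a miss) with a loader-backed cache: Get either hits or invokes the loader registered in initCache. A minimal, self-contained sketch of that pattern using the internal/util/cache package added by this PR; the file-reading loader and the path are illustrative stand-ins for readFile, not code from the PR:

package main

import (
	"fmt"
	"os"

	"github.com/milvus-io/milvus/internal/util/cache"
)

func main() {
	// The loader is invoked on a cache miss, mirroring how initCache wires readFile.
	loader := func(path string) ([]byte, error) {
		return os.ReadFile(path)
	}
	// The removal listener plays the role of the eviction callback that used to
	// close the mmap reader and decrement cacheSize.
	onRemove := func(path string, data []byte) {
		fmt.Printf("evicted %s (%d bytes)\n", path, len(data))
	}
	c := cache.NewLoadingCache(loader,
		cache.WithRemovalListener[string, []byte](onRemove),
		cache.WithMaximumSize[string, []byte](8),
	)
	defer c.Close()

	// Get either returns the cached bytes or loads them through the loader.
	data, err := c.Get("/tmp/example.bin") // hypothetical path
	if err != nil {
		fmt.Println("load failed:", err)
		return
	}
	fmt.Println("read", len(data), "bytes")

	// Invalidate replaces the old cache.Remove call used by Remove/MultiRemove.
	c.Invalidate("/tmp/example.bin")
}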
internal/util/cache/cache_interface.go (new file, 72 lines)

@ -0,0 +1,72 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cache
// The cache implementation is based on https://github.com/goburrow/cache, which
// provides a partial implementation of Guava Cache, mainly supporting LRU.
// Cache is a key-value cache in which entries are added and stay in the
// cache until they are either evicted or manually invalidated.
// TODO: support asynchronous clean-up of expired data
type Cache[K comparable, V any] interface {
// GetIfPresent returns the value associated with Key, or the zero value and
// false if there is no cached value for Key.
GetIfPresent(K) (V, bool)
// Put associates value with Key. If a value is already associated
// with Key, the old one will be replaced with Value.
Put(K, V)
// Invalidate discards cached value of the given Key.
Invalidate(K)
// InvalidateAll discards all entries.
InvalidateAll()
// Scan walks the cache and applies the filter func to each element, returning the entries for which it returns true.
Scan(func(K, V) bool) map[K]V
// Stats returns cache statistics.
Stats() *Stats
// Close implements io.Closer for cleaning up all resources.
// Users must ensure the cache is no longer being used before closing it,
// and must not use it after it has been closed.
Close() error
}
// Func is a generic callback for entry events in the cache.
type Func[K comparable, V any] func(K, V)
// LoadingCache is a cache with values are loaded automatically and stored
// in the cache until either evicted or manually invalidated.
type LoadingCache[K comparable, V any] interface {
Cache[K, V]
// Get returns the value associated with Key, or calls the underlying LoaderFunc
// to load the value if it is not present.
Get(K) (V, error)
// Refresh loads a new value for Key. If the Key already exists, it is refreshed
// synchronously; otherwise this call blocks until the value is loaded.
Refresh(K) error
}
// LoaderFunc retrieves the value corresponding to given Key.
type LoaderFunc[K comparable, V any] func(K) (V, error)
type GetPreLoadDataFunc[K comparable, V any] func() (map[K]V, error)
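
A short usage sketch of the plain Cache interface above; the constructors and options come from local_cache.go later in this diff, and the key/value types and values here are arbitrary examples:

package main

import (
	"fmt"

	"github.com/milvus-io/milvus/internal/util/cache"
)

func main() {
	// NewCache builds a Cache without a loader; values must be Put explicitly.
	c := cache.NewCache(cache.WithMaximumSize[string, int](100))
	defer c.Close()

	c.Put("a", 1)
	c.Put("b", 2)

	if v, ok := c.GetIfPresent("a"); ok {
		fmt.Println("hit:", v)
	}

	// Scan collects all entries for which the filter returns true.
	even := c.Scan(func(k string, v int) bool { return v%2 == 0 })
	fmt.Println("even values:", even)

	fmt.Println("stats:", c.Stats())

	c.Invalidate("a")
	c.InvalidateAll()
}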

internal/util/cache/hash.go (new file, 140 lines)

@ -0,0 +1,140 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cache
import (
"math"
"reflect"
)
// Hash is an interface implemented by cache keys to
// override default hash function.
type Hash interface {
Sum64() uint64
}
// sum calculates hash value of the given key.
func sum(k interface{}) uint64 {
switch h := k.(type) {
case Hash:
return h.Sum64()
case int:
return hashU64(uint64(h))
case int8:
return hashU32(uint32(h))
case int16:
return hashU32(uint32(h))
case int32:
return hashU32(uint32(h))
case int64:
return hashU64(uint64(h))
case uint:
return hashU64(uint64(h))
case uint8:
return hashU32(uint32(h))
case uint16:
return hashU32(uint32(h))
case uint32:
return hashU32(h)
case uint64:
return hashU64(h)
case uintptr:
return hashU64(uint64(h))
case float32:
return hashU32(math.Float32bits(h))
case float64:
return hashU64(math.Float64bits(h))
case bool:
if h {
return 1
}
return 0
case string:
return hashString(h)
}
// TODO: complex64 and complex128
if h, ok := hashPointer(k); ok {
return h
}
// TODO: use gob to encode k to bytes then hash.
return 0
}
const (
fnvOffset uint64 = 14695981039346656037
fnvPrime uint64 = 1099511628211
)
func hashU64(v uint64) uint64 {
// Inline code from hash/fnv to reduce memory allocations
h := fnvOffset
// for i := uint(0); i < 64; i += 8 {
// h ^= (v >> i) & 0xFF
// h *= fnvPrime
// }
h ^= (v >> 0) & 0xFF
h *= fnvPrime
h ^= (v >> 8) & 0xFF
h *= fnvPrime
h ^= (v >> 16) & 0xFF
h *= fnvPrime
h ^= (v >> 24) & 0xFF
h *= fnvPrime
h ^= (v >> 32) & 0xFF
h *= fnvPrime
h ^= (v >> 40) & 0xFF
h *= fnvPrime
h ^= (v >> 48) & 0xFF
h *= fnvPrime
h ^= (v >> 56) & 0xFF
h *= fnvPrime
return h
}
func hashU32(v uint32) uint64 {
h := fnvOffset
h ^= uint64(v>>0) & 0xFF
h *= fnvPrime
h ^= uint64(v>>8) & 0xFF
h *= fnvPrime
h ^= uint64(v>>16) & 0xFF
h *= fnvPrime
h ^= uint64(v>>24) & 0xFF
h *= fnvPrime
return h
}
// hashString calculates hash value using FNV-1a algorithm.
func hashString(data string) uint64 {
// Inline code from hash/fnv to reduce memory allocations
h := fnvOffset
for _, b := range data {
h ^= uint64(b)
h *= fnvPrime
}
return h
}
func hashPointer(k interface{}) (uint64, bool) {
v := reflect.ValueOf(k)
switch v.Kind() {
case reflect.Ptr, reflect.UnsafePointer, reflect.Func, reflect.Slice, reflect.Map, reflect.Chan:
return hashU64(uint64(v.Pointer())), true
default:
return 0, false
}
}
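
Keys can implement the Hash interface above to override the default hashing; this matters for struct keys, which otherwise fall through the switch and hashPointer and hash to 0. A sketch with a hypothetical composite key (the key type and FNV mixing here are illustrative, not part of the PR):

package main

import (
	"fmt"
	"hash/fnv"
)

// segmentKey is a hypothetical composite cache key.
type segmentKey struct {
	CollectionID int64
	SegmentID    int64
}

// Sum64 satisfies the cache package's Hash interface, so sum() uses it
// directly instead of the reflection-based fallback.
func (k segmentKey) Sum64() uint64 {
	h := fnv.New64a()
	fmt.Fprintf(h, "%d/%d", k.CollectionID, k.SegmentID)
	return h.Sum64()
}

func main() {
	k := segmentKey{CollectionID: 1, SegmentID: 42}
	fmt.Printf("hash: %#x\n", k.Sum64())
}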

internal/util/cache/hash_test.go (new file, 100 lines)

@ -0,0 +1,100 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cache
import (
"encoding/binary"
"fmt"
"hash/fnv"
"testing"
"unsafe"
"github.com/stretchr/testify/assert"
)
func sumFNV(data []byte) uint64 {
h := fnv.New64a()
h.Write(data)
return h.Sum64()
}
func sumFNVu64(v uint64) uint64 {
b := make([]byte, 8)
binary.LittleEndian.PutUint64(b, v)
return sumFNV(b)
}
func sumFNVu32(v uint32) uint64 {
b := make([]byte, 4)
binary.LittleEndian.PutUint32(b, v)
return sumFNV(b)
}
func TestSum(t *testing.T) {
var tests = []struct {
k interface{}
h uint64
}{
{int(-1), sumFNVu64(^uint64(1) + 1)},
{int8(-8), sumFNVu32(^uint32(8) + 1)},
{int16(-16), sumFNVu32(^uint32(16) + 1)},
{int32(-32), sumFNVu32(^uint32(32) + 1)},
{int64(-64), sumFNVu64(^uint64(64) + 1)},
{uint(1), sumFNVu64(1)},
{uint8(8), sumFNVu32(8)},
{uint16(16), sumFNVu32(16)},
{uint32(32), sumFNVu32(32)},
{uint64(64), sumFNVu64(64)},
{byte(255), sumFNVu32(255)},
{rune(1024), sumFNVu32(1024)},
{true, 1},
{false, 0},
{float32(2.5), sumFNVu32(0x40200000)},
{float64(2.5), sumFNVu64(0x4004000000000000)},
/* #nosec G103 */
{uintptr(unsafe.Pointer(t)), sumFNVu64(uint64(uintptr(unsafe.Pointer(t))))},
{"", sumFNV(nil)},
{"string", sumFNV([]byte("string"))},
/* #nosec G103 */
{t, sumFNVu64(uint64(uintptr(unsafe.Pointer(t))))},
{(*testing.T)(nil), sumFNVu64(0)},
}
for _, tt := range tests {
h := sum(tt.k)
assert.Equal(t, h, tt.h, fmt.Sprintf("unexpected hash: %v (0x%x), key: %+v (%T), want: %v",
h, h, tt.k, tt.k, tt.h))
}
}
func BenchmarkSumInt(b *testing.B) {
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
sum(0x0105)
}
})
}
func BenchmarkSumString(b *testing.B) {
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
sum("09130105060103210913010506010321091301050601032109130105060103210913010506010321")
}
})
}

internal/util/cache/local_cache.go (new file, 589 lines)

@ -0,0 +1,589 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cache
import (
"errors"
"sync"
"sync/atomic"
"time"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
)
const (
// Default maximum number of cache entries.
maximumCapacity = 1 << 30
// Buffer size of entry channels
chanBufSize = 64
// Maximum number of entries to be drained in a single clean up.
drainMax = 16
// Number of cache access operations that will trigger clean up.
drainThreshold = 64
)
// currentTime is an alias for time.Now, used for testing.
var currentTime = time.Now
// localCache is an asynchronous LRU cache.
type localCache[K comparable, V any] struct {
// internal data structure
cache cache // Must be aligned on 32-bit
// user configurations
expireAfterAccess time.Duration
expireAfterWrite time.Duration
refreshAfterWrite time.Duration
policyName string
onInsertion Func[K, V]
onRemoval Func[K, V]
loader LoaderFunc[K, V]
getPreLoadData GetPreLoadDataFunc[K, V]
stats StatsCounter
// cap is the cache capacity.
cap int64
// accessQueue is the cache retention policy, which manages entries by access time.
accessQueue policy
// writeQueue is for managing entries by write time.
// It is only fulfilled when expireAfterWrite or refreshAfterWrite is set.
writeQueue policy
// events is the cache event queue for processEntries
events chan entryEvent
// readCount is a counter of the number of reads since the last write.
readCount int32
// for closing routines created by this cache.
closing int32
closeWG sync.WaitGroup
}
// newLocalCache returns a default localCache.
// init must be called before this cache can be used.
func newLocalCache[K comparable, V any]() *localCache[K, V] {
return &localCache[K, V]{
cap: maximumCapacity,
cache: cache{},
stats: &statsCounter{},
}
}
// init initializes cache replacement policy after all user configuration properties are set.
func (c *localCache[K, V]) init() {
c.accessQueue = newPolicy()
c.accessQueue.init(&c.cache, c.cap)
if c.expireAfterWrite > 0 || c.refreshAfterWrite > 0 {
c.writeQueue = &recencyQueue{}
} else {
c.writeQueue = discardingQueue{}
}
c.writeQueue.init(&c.cache, c.cap)
c.events = make(chan entryEvent, chanBufSize)
c.closeWG.Add(1)
go c.processEntries()
if c.getPreLoadData != nil {
c.asyncPreload()
}
}
// Close implements io.Closer and always returns a nil error.
// Callers must ensure the cache is not being used (reading or writing) before closing.
func (c *localCache[K, V]) Close() error {
if atomic.CompareAndSwapInt32(&c.closing, 0, 1) {
// Do not close events channel to avoid panic when cache is still being used.
c.events <- entryEvent{nil, eventClose}
// Wait for the goroutine to close this channel
c.closeWG.Wait()
}
return nil
}
// GetIfPresent gets cached value from entries list and updates
// last access time for the entry if it is found.
func (c *localCache[K, V]) GetIfPresent(k K) (v V, exist bool) {
en := c.cache.get(k, sum(k))
if en == nil {
c.stats.RecordMisses(1)
return v, false
}
now := currentTime()
if c.isExpired(en, now) {
c.stats.RecordMisses(1)
c.sendEvent(eventDelete, en)
return v, false
}
c.stats.RecordHits(1)
c.setEntryAccessTime(en, now)
c.sendEvent(eventAccess, en)
return en.getValue().(V), true
}
// Put adds new entry to entries list.
func (c *localCache[K, V]) Put(k K, v V) {
h := sum(k)
en := c.cache.get(k, h)
now := currentTime()
if en == nil {
en = newEntry(k, v, h)
c.setEntryWriteTime(en, now)
c.setEntryAccessTime(en, now)
// Add to the cache directly so the new value is available immediately.
// However, only do this within the cache capacity (approximately).
if c.cap == 0 || int64(c.cache.len()) < c.cap {
cen := c.cache.getOrSet(en)
if cen != nil {
cen.setValue(v)
en = cen
}
}
} else {
// Update value and send notice
en.setValue(v)
en.setWriteTime(now.UnixNano())
}
c.sendEvent(eventWrite, en)
}
// Invalidate removes the entry associated with key k.
func (c *localCache[K, V]) Invalidate(k K) {
en := c.cache.get(k, sum(k))
if en != nil {
en.setInvalidated(true)
c.sendEvent(eventDelete, en)
}
}
// InvalidateAll resets entries list.
func (c *localCache[K, V]) InvalidateAll() {
c.cache.walk(func(en *entry) {
en.setInvalidated(true)
})
c.sendEvent(eventDelete, nil)
}
// Scan entries list with a filter function
func (c *localCache[K, V]) Scan(filter func(K, V) bool) map[K]V {
ret := make(map[K]V)
c.cache.walk(func(en *entry) {
k := en.key.(K)
v := en.getValue().(V)
if filter(k, v) {
ret[k] = v
}
})
return ret
}
// Get returns the value associated with k, or calls the underlying loader to
// retrieve the value if it is not in the cache. The returned value is only
// cached when the loader returns a nil error.
func (c *localCache[K, V]) Get(k K) (V, error) {
en := c.cache.get(k, sum(k))
if en == nil {
c.stats.RecordMisses(1)
return c.load(k)
}
// Check if this entry needs to be refreshed
now := currentTime()
if c.isExpired(en, now) {
c.stats.RecordMisses(1)
if c.loader == nil {
c.sendEvent(eventDelete, en)
} else {
// Update value if expired
c.setEntryAccessTime(en, now)
c.refresh(en)
}
} else {
c.stats.RecordHits(1)
c.setEntryAccessTime(en, now)
c.sendEvent(eventAccess, en)
}
return en.getValue().(V), nil
}
// Refresh synchronously loads and blocks until the value is loaded.
func (c *localCache[K, V]) Refresh(k K) error {
if c.loader == nil {
return errors.New("cache loader should be set")
}
en := c.cache.get(k, sum(k))
var err error
if en == nil {
_, err = c.load(k)
} else {
err = c.refresh(en)
}
return err
}
// Stats copies cache stats to t.
func (c *localCache[K, V]) Stats() *Stats {
t := &Stats{}
c.stats.Snapshot(t)
return t
}
// asyncPreload asynchronously preloads the cache via Put.
func (c *localCache[K, V]) asyncPreload() error {
var err error
go func() {
var data map[K]V
data, err = c.getPreLoadData()
if err != nil {
return
}
for k, v := range data {
c.Put(k, v)
}
}()
return nil
}
func (c *localCache[K, V]) processEntries() {
defer c.closeWG.Done()
for e := range c.events {
switch e.event {
case eventWrite:
c.write(e.entry)
c.postWriteCleanup()
case eventAccess:
c.access(e.entry)
c.postReadCleanup()
case eventDelete:
if e.entry == nil {
c.removeAll()
} else {
c.remove(e.entry)
}
c.postReadCleanup()
case eventClose:
c.removeAll()
return
}
}
}
// sendEvent sends event only when the cache is not closing/closed.
func (c *localCache[K, V]) sendEvent(typ event, en *entry) {
if atomic.LoadInt32(&c.closing) == 0 {
c.events <- entryEvent{en, typ}
}
}
// This function must only be called from processEntries goroutine.
func (c *localCache[K, V]) write(en *entry) {
ren := c.accessQueue.write(en)
c.writeQueue.write(en)
if c.onInsertion != nil {
c.onInsertion(en.key.(K), en.getValue().(V))
}
if ren != nil {
c.writeQueue.remove(ren)
// An entry has been evicted
c.stats.RecordEviction()
if c.onRemoval != nil {
c.onRemoval(ren.key.(K), ren.getValue().(V))
}
}
}
// removeAll removes all entries in the cache.
// This function must only be called from processEntries goroutine.
func (c *localCache[K, V]) removeAll() {
c.accessQueue.iterate(func(en *entry) bool {
c.remove(en)
return true
})
}
// remove removes the given element from the cache and entries list.
// It also calls onRemoval callback if it is set.
func (c *localCache[K, V]) remove(en *entry) {
ren := c.accessQueue.remove(en)
c.writeQueue.remove(en)
if ren != nil && c.onRemoval != nil {
c.onRemoval(ren.key.(K), ren.getValue().(V))
}
}
// access moves the given element to the top of the entries list.
// This function must only be called from processEntries goroutine.
func (c *localCache[K, V]) access(en *entry) {
c.accessQueue.access(en)
}
// load uses current loader to synchronously retrieve value for k and adds new
// entry to the cache only if loader returns a nil error.
func (c *localCache[K, V]) load(k K) (v V, err error) {
if c.loader == nil {
var ret V
return ret, errors.New("cache loader function must be set")
}
// TODO: Poll the value instead when the entry is loading.
start := currentTime()
v, err = c.loader(k)
now := currentTime()
loadTime := now.Sub(start)
if err != nil {
c.stats.RecordLoadError(loadTime)
return v, err
}
c.stats.RecordLoadSuccess(loadTime)
en := newEntry(k, v, sum(k))
c.setEntryWriteTime(en, now)
c.setEntryAccessTime(en, now)
c.sendEvent(eventWrite, en)
return v, nil
}
// refresh reloads the value for the given key. If the loader returns an error, the
// entry value is left unchanged and the error is returned; otherwise the value is updated.
func (c *localCache[K, V]) refresh(en *entry) error {
defer en.setLoading(false)
start := currentTime()
v, err := c.loader(en.key.(K))
now := currentTime()
loadTime := now.Sub(start)
if err == nil {
c.stats.RecordLoadSuccess(loadTime)
en.setValue(v)
en.setWriteTime(now.UnixNano())
c.sendEvent(eventWrite, en)
} else {
log.Warn("refresh cache fail", zap.Any("key", en.key), zap.Error(err))
c.stats.RecordLoadError(loadTime)
}
return err
}
// postReadCleanup is run after entry access/delete event.
// This function must only be called from processEntries goroutine.
func (c *localCache[K, V]) postReadCleanup() {
if atomic.AddInt32(&c.readCount, 1) > drainThreshold {
atomic.StoreInt32(&c.readCount, 0)
c.expireEntries()
}
}
// postWriteCleanup is run after entry add event.
// This function must only be called from processEntries goroutine.
func (c *localCache[K, V]) postWriteCleanup() {
atomic.StoreInt32(&c.readCount, 0)
c.expireEntries()
}
// expireEntries removes expired entries.
func (c *localCache[K, V]) expireEntries() {
remain := drainMax
now := currentTime()
if c.expireAfterAccess > 0 {
expiry := now.Add(-c.expireAfterAccess).UnixNano()
c.accessQueue.iterate(func(en *entry) bool {
if remain == 0 || en.getAccessTime() >= expiry {
// Can stop as the entries are sorted by access time.
// (the next entry is accessed more recently.)
return false
}
// accessTime + expiry passed
c.remove(en)
c.stats.RecordEviction()
remain--
return remain > 0
})
}
if remain > 0 && c.expireAfterWrite > 0 {
expiry := now.Add(-c.expireAfterWrite).UnixNano()
c.writeQueue.iterate(func(en *entry) bool {
if remain == 0 || en.getWriteTime() >= expiry {
return false
}
// writeTime + expiry passed
c.remove(en)
c.stats.RecordEviction()
remain--
return remain > 0
})
}
if remain > 0 && c.loader != nil && c.refreshAfterWrite > 0 {
expiry := now.Add(-c.refreshAfterWrite).UnixNano()
c.writeQueue.iterate(func(en *entry) bool {
if remain == 0 || en.getWriteTime() >= expiry {
return false
}
err := c.refresh(en)
if err == nil {
remain--
}
return remain > 0
})
}
}
func (c *localCache[K, V]) isExpired(en *entry, now time.Time) bool {
if en.getInvalidated() {
return true
}
if c.expireAfterAccess > 0 && en.getAccessTime() < now.Add(-c.expireAfterAccess).UnixNano() {
// accessTime + expiry passed
return true
}
if c.expireAfterWrite > 0 && en.getWriteTime() < now.Add(-c.expireAfterWrite).UnixNano() {
// writeTime + expiry passed
return true
}
return false
}
func (c *localCache[K, V]) needRefresh(en *entry, now time.Time) bool {
if en.getLoading() {
return false
}
if c.refreshAfterWrite > 0 {
tm := en.getWriteTime()
if tm > 0 && tm < now.Add(-c.refreshAfterWrite).UnixNano() {
// writeTime + refresh passed
return true
}
}
return false
}
// setEntryAccessTime sets access time if needed.
func (c *localCache[K, V]) setEntryAccessTime(en *entry, now time.Time) {
if c.expireAfterAccess > 0 {
en.setAccessTime(now.UnixNano())
}
}
// setEntryWriteTime sets write time if needed.
func (c *localCache[K, V]) setEntryWriteTime(en *entry, now time.Time) {
if c.expireAfterWrite > 0 || c.refreshAfterWrite > 0 {
en.setWriteTime(now.UnixNano())
}
}
// NewCache returns a local in-memory Cache.
func NewCache[K comparable, V any](options ...Option[K, V]) Cache[K, V] {
c := newLocalCache[K, V]()
for _, opt := range options {
opt(c)
}
c.init()
return c
}
// NewLoadingCache returns a new LoadingCache with given loader function
// and cache options.
func NewLoadingCache[K comparable, V any](loader LoaderFunc[K, V], options ...Option[K, V]) LoadingCache[K, V] {
c := newLocalCache[K, V]()
c.loader = loader
for _, opt := range options {
opt(c)
}
c.init()
return c
}
// Option is a configuration option for the default Cache.
type Option[K comparable, V any] func(c *localCache[K, V])
// WithMaximumSize returns an Option which sets maximum size for the cache.
// Any non-positive number is treated as unlimited.
func WithMaximumSize[K comparable, V any](size int64) Option[K, V] {
if size < 0 {
size = 0
}
if size > maximumCapacity {
size = maximumCapacity
}
return func(c *localCache[K, V]) {
c.cap = size
}
}
// WithRemovalListener returns an Option to set cache to call onRemoval for each
// entry evicted from the cache.
func WithRemovalListener[K comparable, V any](onRemoval Func[K, V]) Option[K, V] {
return func(c *localCache[K, V]) {
c.onRemoval = onRemoval
}
}
// WithExpireAfterAccess returns an option to expire a cache entry after the
// given duration without being accessed.
func WithExpireAfterAccess[K comparable, V any](d time.Duration) Option[K, V] {
return func(c *localCache[K, V]) {
c.expireAfterAccess = d
}
}
// WithExpireAfterWrite returns an option to expire a cache entry after the
// given duration from creation.
func WithExpireAfterWrite[K comparable, V any](d time.Duration) Option[K, V] {
return func(c *localCache[K, V]) {
c.expireAfterWrite = d
}
}
// WithRefreshAfterWrite returns an option to refresh a cache entry after the
// given duration. This option is only applicable for LoadingCache.
func WithRefreshAfterWrite[K comparable, V any](d time.Duration) Option[K, V] {
return func(c *localCache[K, V]) {
c.refreshAfterWrite = d
}
}
// WithStatsCounter returns an option which overrides default cache stats counter.
func WithStatsCounter[K comparable, V any](st StatsCounter) Option[K, V] {
return func(c *localCache[K, V]) {
c.stats = st
}
}
// WithPolicy returns an option which sets cache policy associated to the given name.
// Supported policies are: lru, slru.
func WithPolicy[K comparable, V any](name string) Option[K, V] {
return func(c *localCache[K, V]) {
c.policyName = name
}
}
// WithAsyncInitPreLoader returns an Option that asynchronously preloads data during initialization.
func WithAsyncInitPreLoader[K comparable, V any](fn GetPreLoadDataFunc[K, V]) Option[K, V] {
return func(c *localCache[K, V]) {
c.getPreLoadData = fn
}
}
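// WithInsertionListener returns an Option to set cache to call onInsertion for each
// entry added to the cache.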
func WithInsertionListener[K comparable, V any](onInsertion Func[K, V]) Option[K, V] {
return func(c *localCache[K, V]) {
c.onInsertion = onInsertion
}
}
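
Putting the expiry and refresh options together, a sketch of a self-refreshing loading cache; the durations, key names, and the loader are illustrative assumptions, while the constructor and options are the ones defined above:

package main

import (
	"fmt"
	"time"

	"github.com/milvus-io/milvus/internal/util/cache"
)

func main() {
	// The loader simulates fetching fresh data for a key.
	loader := func(key string) (string, error) {
		return key + "@" + time.Now().Format(time.RFC3339Nano), nil
	}
	c := cache.NewLoadingCache(loader,
		// Drop entries that have not been read for 10 minutes.
		cache.WithExpireAfterAccess[string, string](10*time.Minute),
		// Reload entries in place once they are older than 1 minute
		// (only meaningful for a LoadingCache).
		cache.WithRefreshAfterWrite[string, string](time.Minute),
		// Log evictions.
		cache.WithRemovalListener[string, string](func(k, v string) {
			fmt.Println("evicted:", k)
		}),
	)
	defer c.Close()

	v, err := c.Get("segment-1")
	if err != nil {
		fmt.Println("load failed:", err)
		return
	}
	fmt.Println(v)

	// Refresh forces a synchronous reload regardless of age.
	if err := c.Refresh("segment-1"); err != nil {
		fmt.Println("refresh failed:", err)
	}
}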

internal/util/cache/local_cache_test.go (new file, 502 lines)

@ -0,0 +1,502 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cache
import (
"errors"
"fmt"
"math/rand"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestCache(t *testing.T) {
data := map[string]int{
"1": 1,
"2": 2,
}
wg := sync.WaitGroup{}
c := NewCache(WithInsertionListener(func(string, int) {
wg.Done()
}))
defer c.Close()
wg.Add(len(data))
for k, v := range data {
c.Put(k, v)
}
wg.Wait()
for k, dv := range data {
v, ok := c.GetIfPresent(k)
assert.True(t, ok)
assert.Equal(t, v, dv)
}
ret := c.Scan(
func(k string, v int) bool {
return true
},
)
for k, v := range ret {
dv, ok := data[k]
assert.True(t, ok)
assert.Equal(t, dv, v)
}
}
func TestMaximumSize(t *testing.T) {
max := 10
wg := sync.WaitGroup{}
insFunc := func(k int, v int) {
wg.Done()
}
c := NewCache(WithMaximumSize[int, int](int64(max)), WithInsertionListener(insFunc)).(*localCache[int, int])
defer c.Close()
wg.Add(max)
for i := 0; i < max; i++ {
c.Put(i, i)
}
wg.Wait()
n := cacheSize(&c.cache)
assert.Equal(t, n, max)
c.onInsertion = nil
for i := 0; i < 2*max; i++ {
k := rand.Intn(2 * max)
c.Put(k, k)
time.Sleep(time.Duration(i+1) * time.Millisecond)
n = cacheSize(&c.cache)
assert.Equal(t, n, max)
}
}
func TestRemovalListener(t *testing.T) {
removed := make(map[int]int)
wg := sync.WaitGroup{}
remFunc := func(k int, v int) {
removed[k] = v
wg.Done()
}
insFunc := func(k int, v int) {
wg.Done()
}
max := 3
c := NewCache(WithMaximumSize[int, int](int64(max)), WithRemovalListener(remFunc),
WithInsertionListener(insFunc))
defer c.Close()
wg.Add(max + 2)
for i := 1; i < max+2; i++ {
c.Put(i, i)
}
wg.Wait()
assert.Equal(t, 1, len(removed))
assert.Equal(t, 1, removed[1])
wg.Add(1)
c.Invalidate(3)
wg.Wait()
assert.Equal(t, 2, len(removed))
assert.Equal(t, 3, removed[3])
wg.Add(2)
c.InvalidateAll()
wg.Wait()
assert.Equal(t, 4, len(removed))
assert.Equal(t, 2, removed[2])
assert.Equal(t, 4, removed[4])
}
func TestClose(t *testing.T) {
removed := 0
wg := sync.WaitGroup{}
remFunc := func(k int, v int) {
removed++
wg.Done()
}
insFunc := func(k int, v int) {
wg.Done()
}
c := NewCache(WithRemovalListener(remFunc), WithInsertionListener(insFunc))
n := 10
wg.Add(n)
for i := 0; i < n; i++ {
c.Put(i, i)
}
wg.Wait()
wg.Add(n)
c.Close()
wg.Wait()
assert.Equal(t, n, removed)
}
func TestLoadingCache(t *testing.T) {
loadCount := 0
loader := func(k int) (int, error) {
loadCount++
if k%2 != 0 {
return 0, errors.New("odd")
}
return k, nil
}
wg := sync.WaitGroup{}
insFunc := func(int, int) {
wg.Done()
}
c := NewLoadingCache(loader, WithInsertionListener(insFunc))
defer c.Close()
wg.Add(1)
v, err := c.Get(2)
assert.NoError(t, err)
assert.Equal(t, 2, v)
assert.Equal(t, 1, loadCount)
wg.Wait()
v, err = c.Get(2)
assert.NoError(t, err)
assert.Equal(t, 2, v)
assert.Equal(t, 1, loadCount)
_, err = c.Get(1)
assert.Error(t, err)
// Should not insert
wg.Wait()
}
func TestCacheStats(t *testing.T) {
wg := sync.WaitGroup{}
loader := func(k string) (string, error) {
return k, nil
}
insFunc := func(string, string) {
wg.Done()
}
c := NewLoadingCache(loader, WithInsertionListener(insFunc))
defer c.Close()
wg.Add(1)
_, err := c.Get("x")
assert.NoError(t, err)
st := c.Stats()
assert.Equal(t, uint64(1), st.MissCount)
assert.Equal(t, uint64(1), st.LoadSuccessCount)
assert.True(t, st.TotalLoadTime > 0)
wg.Wait()
_, err = c.Get("x")
assert.NoError(t, err)
st = c.Stats()
assert.Equal(t, uint64(1), st.HitCount)
}
func TestExpireAfterAccess(t *testing.T) {
wg := sync.WaitGroup{}
fn := func(k uint, v uint) {
wg.Done()
}
mockTime := newMockTime()
currentTime = mockTime.now
c := NewCache(WithExpireAfterAccess[uint, uint](1*time.Second), WithRemovalListener(fn),
WithInsertionListener(fn)).(*localCache[uint, uint])
defer c.Close()
wg.Add(1)
c.Put(1, 1)
wg.Wait()
mockTime.add(1 * time.Second)
wg.Add(2)
c.Put(2, 2)
c.Put(3, 3)
wg.Wait()
n := cacheSize(&c.cache)
if n != 3 {
wg.Add(n)
assert.Fail(t, fmt.Sprintf("unexpected cache size: %d, want: %d", n, 3))
}
mockTime.add(1 * time.Nanosecond)
wg.Add(2)
c.Put(4, 4)
wg.Wait()
n = cacheSize(&c.cache)
wg.Add(n)
assert.Equal(t, 3, n)
_, ok := c.GetIfPresent(1)
assert.False(t, ok)
}
func TestExpireAfterWrite(t *testing.T) {
loadCount := 0
loader := func(k string) (int, error) {
loadCount++
return loadCount, nil
}
mockTime := newMockTime()
currentTime = mockTime.now
c := NewLoadingCache(loader, WithExpireAfterWrite[string, int](1*time.Second))
defer c.Close()
// New value
v, err := c.Get("refresh")
assert.NoError(t, err)
assert.Equal(t, 1, v)
assert.Equal(t, 1, loadCount)
time.Sleep(200 * time.Millisecond)
// Within 1s, the value should not have expired yet.
mockTime.add(1 * time.Second)
v, err = c.Get("refresh")
assert.NoError(t, err)
assert.Equal(t, 1, v)
assert.Equal(t, 1, loadCount)
// After 1s, the value should be expired and refresh triggered.
mockTime.add(1 * time.Nanosecond)
v, err = c.Get("refresh")
assert.NoError(t, err)
assert.Equal(t, 2, v)
assert.Equal(t, 2, loadCount)
// value has already been loaded.
v, err = c.Get("refresh")
assert.NoError(t, err)
assert.Equal(t, 2, v)
assert.Equal(t, 2, loadCount)
}
func TestRefreshAfterWrite(t *testing.T) {
var mutex sync.Mutex
loaded := make(map[int]int)
loader := func(k int) (int, error) {
mutex.Lock()
n := loaded[k]
n++
loaded[k] = n
mutex.Unlock()
return n, nil
}
wg := sync.WaitGroup{}
insFunc := func(int, int) {
wg.Done()
}
mockTime := newMockTime()
currentTime = mockTime.now
c := NewLoadingCache(loader,
WithExpireAfterAccess[int, int](4*time.Second),
WithRefreshAfterWrite[int, int](2*time.Second),
WithInsertionListener(insFunc))
defer c.Close()
wg.Add(3)
v, err := c.Get(1)
assert.NoError(t, err)
assert.Equal(t, 1, v)
// 3s
mockTime.add(3 * time.Second)
v, err = c.Get(2)
assert.NoError(t, err)
assert.Equal(t, 1, v)
wg.Wait()
assert.Equal(t, 2, loaded[1])
assert.Equal(t, 1, loaded[2])
v, err = c.Get(1)
assert.NoError(t, err)
assert.Equal(t, 2, v)
// 8s
mockTime.add(5 * time.Second)
wg.Add(1)
v, err = c.Get(1)
assert.NoError(t, err)
assert.Equal(t, 3, v)
}
func TestGetIfPresentExpired(t *testing.T) {
wg := sync.WaitGroup{}
insFunc := func(int, string) {
wg.Done()
}
c := NewCache(WithExpireAfterWrite[int, string](1*time.Second), WithInsertionListener(insFunc))
mockTime := newMockTime()
currentTime = mockTime.now
v, ok := c.GetIfPresent(0)
assert.False(t, ok)
wg.Add(1)
c.Put(0, "0")
v, ok = c.GetIfPresent(0)
assert.True(t, ok)
assert.Equal(t, "0", v)
wg.Wait()
mockTime.add(2 * time.Second)
_, ok = c.GetIfPresent(0)
assert.False(t, ok)
}
func TestWithAsyncInitPreLoader(t *testing.T) {
wg := sync.WaitGroup{}
data := map[string]string{
"1": "1",
"2": "1",
"3": "1",
}
wg.Add(1)
cnt := len(data)
i := 0
insFunc := func(k string, v string) {
r, ok := data[k]
assert.True(t, ok)
assert.Equal(t, v, r)
i++
if i == cnt {
wg.Done()
}
}
loader := func(k string) (string, error) {
assert.Fail(t, "should not reach here!")
return "", nil
}
preLoaderFunc := func() (map[string]string, error) {
return data, nil
}
c := NewLoadingCache(loader, WithMaximumSize[string, string](3),
WithInsertionListener(insFunc), WithAsyncInitPreLoader(preLoaderFunc))
defer c.Close()
wg.Wait()
_, ok := c.GetIfPresent("1")
assert.True(t, ok)
_, ok = c.GetIfPresent("2")
assert.True(t, ok)
_, ok = c.GetIfPresent("3")
assert.True(t, ok)
}
func TestSynchronousReload(t *testing.T) {
var val string
loader := func(k int) (string, error) {
time.Sleep(1 * time.Millisecond)
if val == "" {
return "", errors.New("empty")
}
return val, nil
}
c := NewLoadingCache(loader, WithExpireAfterWrite[int, string](1*time.Second))
val = "a"
v, err := c.Get(1)
assert.NoError(t, err)
assert.Equal(t, val, v)
val = "b"
v, err = c.Get(1)
assert.NoError(t, err)
assert.Equal(t, val, v)
val = ""
_, err = c.Get(2)
assert.Error(t, err)
}
func TestCloseMultiple(t *testing.T) {
c := NewCache[int, int]()
start := make(chan bool)
const n = 10
var wg sync.WaitGroup
wg.Add(n)
for i := 0; i < n; i++ {
go func() {
defer wg.Done()
<-start
c.Close()
}()
}
close(start)
wg.Wait()
// Should not panic
assert.NotPanics(t, func() {
c.GetIfPresent(0)
})
assert.NotPanics(t, func() {
c.Put(1, 1)
})
assert.NotPanics(t, func() {
c.Invalidate(0)
})
assert.NotPanics(t, func() {
c.InvalidateAll()
})
assert.NotPanics(t, func() {
c.Close()
})
}
func BenchmarkGetSame(b *testing.B) {
c := NewCache[string, string]()
c.Put("*", "*")
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
c.GetIfPresent("*")
}
})
}
// mockTime is used by tests that need to control the current system time.
type mockTime struct {
mu sync.RWMutex
value time.Time
}
func newMockTime() *mockTime {
return &mockTime{
value: time.Now(),
}
}
func (t *mockTime) add(d time.Duration) {
t.mu.Lock()
defer t.mu.Unlock()
t.value = t.value.Add(d)
}
func (t *mockTime) now() time.Time {
t.mu.RLock()
defer t.mu.RUnlock()
return t.value
}


@ -1,311 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cache
import (
"container/list"
"errors"
"fmt"
"sync"
)
// LRU generic utility for lru cache.
type LRU[K comparable, V any] struct {
evictList *list.List
items map[K]*list.Element
capacity int
onEvicted func(k K, v V)
m sync.RWMutex
evictedCh chan *entry[K, V]
closeCh chan struct{}
closeOnce sync.Once
stats *Stats
}
// Stats is the model for cache statistics.
type Stats struct {
hitCount float32
evictedCount float32
readCount float32
writeCount float32
}
// String implement stringer for printing.
func (s *Stats) String() string {
var hitRatio float32
var evictedRatio float32
if s.readCount != 0 {
hitRatio = s.hitCount / s.readCount
evictedRatio = s.evictedCount / s.writeCount
}
return fmt.Sprintf("lru cache hit ratio = %f, evictedRatio = %f", hitRatio, evictedRatio)
}
type entry[K comparable, V any] struct {
key K
value V
}
// NewLRU creates a LRU cache with provided capacity and `onEvicted` function.
// `onEvicted` will be executed when an item is chosen to be evicted.
func NewLRU[K comparable, V any](capacity int, onEvicted func(k K, v V)) (*LRU[K, V], error) {
if capacity <= 0 {
return nil, errors.New("cache size must be positive")
}
c := &LRU[K, V]{
capacity: capacity,
evictList: list.New(),
items: make(map[K]*list.Element),
onEvicted: onEvicted,
evictedCh: make(chan *entry[K, V], 16),
closeCh: make(chan struct{}),
stats: &Stats{},
}
go c.evictedWorker()
return c, nil
}
// evictedWorker executes the onEvicted function for each evicted item.
func (c *LRU[K, V]) evictedWorker() {
for {
select {
case <-c.closeCh:
return
case e, ok := <-c.evictedCh:
if ok {
if c.onEvicted != nil {
c.onEvicted(e.key, e.value)
}
}
}
}
}
// closed returns whether cache is closed.
func (c *LRU[K, V]) closed() bool {
select {
case <-c.closeCh:
return true
default:
return false
}
}
// Add puts an item into cache.
func (c *LRU[K, V]) Add(key K, value V) {
c.m.Lock()
defer c.m.Unlock()
if c.closed() {
// evict since cache closed
c.onEvicted(key, value)
return
}
c.stats.writeCount++
if e, ok := c.items[key]; ok {
c.evictList.MoveToFront(e)
e.Value.(*entry[K, V]).value = value
return
}
e := &entry[K, V]{key: key, value: value}
listE := c.evictList.PushFront(e)
c.items[key] = listE
if c.evictList.Len() > c.capacity {
c.stats.evictedCount++
oldestE := c.evictList.Back()
if oldestE != nil {
c.evictList.Remove(oldestE)
kv := oldestE.Value.(*entry[K, V])
delete(c.items, kv.key)
if c.onEvicted != nil {
c.evictedCh <- kv
}
}
}
}
// Get returns value for provided key.
func (c *LRU[K, V]) Get(key K) (value V, ok bool) {
c.m.RLock()
defer c.m.RUnlock()
var zeroV V
if c.closed() {
// cache closed, returns nothing
return zeroV, false
}
c.stats.readCount++
if e, ok := c.items[key]; ok {
c.stats.hitCount++
c.evictList.MoveToFront(e)
kv := e.Value.(*entry[K, V])
return kv.value, true
}
return zeroV, false
}
// Remove removes item associated with provided key.
func (c *LRU[K, V]) Remove(key K) {
c.m.Lock()
defer c.m.Unlock()
if c.closed() {
return
}
if e, ok := c.items[key]; ok {
c.evictList.Remove(e)
kv := e.Value.(*entry[K, V])
delete(c.items, kv.key)
if c.onEvicted != nil {
c.evictedCh <- kv
}
}
}
// Contains returns whether an item with the provided key exists in the cache.
func (c *LRU[K, V]) Contains(key K) bool {
c.m.RLock()
defer c.m.RUnlock()
if c.closed() {
return false
}
_, ok := c.items[key]
return ok
}
// Keys returns all the keys that exist in the cache.
func (c *LRU[K, V]) Keys() []K {
c.m.RLock()
defer c.m.RUnlock()
if c.closed() {
return nil
}
keys := make([]K, len(c.items))
i := 0
for ent := c.evictList.Back(); ent != nil; ent = ent.Prev() {
keys[i] = ent.Value.(*entry[K, V]).key
i++
}
return keys
}
// Len returns items count in cache.
func (c *LRU[K, V]) Len() int {
c.m.RLock()
defer c.m.RUnlock()
if c.closed() {
return 0
}
return c.evictList.Len()
}
// Capacity returns cache capacity.
func (c *LRU[K, V]) Capacity() int {
return c.capacity
}
// Purge removes all items and puts them into evictedCh.
func (c *LRU[K, V]) Purge() {
c.m.Lock()
defer c.m.Unlock()
for k, v := range c.items {
if c.onEvicted != nil {
c.evictedCh <- v.Value.(*entry[K, V])
}
delete(c.items, k)
}
c.evictList.Init()
}
// Resize changes the capacity of cache.
func (c *LRU[K, V]) Resize(capacity int) int {
c.m.Lock()
defer c.m.Unlock()
if c.closed() {
return 0
}
c.capacity = capacity
if capacity >= c.evictList.Len() {
return 0
}
diff := c.evictList.Len() - c.capacity
for i := 0; i < diff; i++ {
oldestE := c.evictList.Back()
if oldestE != nil {
c.evictList.Remove(oldestE)
kv := oldestE.Value.(*entry[K, V])
delete(c.items, kv.key)
if c.onEvicted != nil {
c.evictedCh <- kv
}
}
}
return diff
}
// GetOldest returns the oldest item in cache.
func (c *LRU[K, V]) GetOldest() (K, V, bool) {
c.m.RLock()
defer c.m.RUnlock()
var (
zeroK K
zeroV V
)
if c.closed() {
return zeroK, zeroV, false
}
ent := c.evictList.Back()
if ent != nil {
kv := ent.Value.(*entry[K, V])
return kv.key, kv.value, true
}
return zeroK, zeroV, false
}
// Close cleans up the cache resources.
func (c *LRU[K, V]) Close() {
c.closeOnce.Do(func() {
// acquire the lock to
// - wait on-going operations done
// - block incoming operations
c.m.Lock()
close(c.closeCh)
c.m.Unlock()
// execute Purge in a goroutine, otherwise Purge may block forever when putting entries into evictedCh
go func() {
c.Purge()
close(c.evictedCh)
}()
for e := range c.evictedCh {
c.onEvicted(e.key, e.value)
}
})
}
// Stats returns cache statistics.
func (c *LRU[K, V]) Stats() *Stats {
return c.stats
}


@ -1,440 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cache
import (
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewLRU(t *testing.T) {
c, err := NewLRU[int, int](1, nil)
assert.Nil(t, err)
assert.NotNil(t, c)
c, err = NewLRU[int, int](0, nil)
assert.NotNil(t, err)
assert.Nil(t, c)
c, err = NewLRU[int, int](-1, nil)
assert.NotNil(t, err)
assert.Nil(t, c)
}
func TestLRU_Add(t *testing.T) {
evicted := int32(0)
c, err := NewLRU(2, func(string, string) { atomic.AddInt32(&evicted, 1) })
assert.Nil(t, err)
testKey1 := "test_key_1"
testValue1 := "test_value_1"
testValueExtra := "test_value_extra"
testKey2 := "test_key_2"
testValue2 := "test_value_2"
testKey3 := "test_key_3"
testValue3 := "test_value_3"
c.Add(testKey1, testValue1)
c.Add(testKey2, testValue2)
v, ok := c.Get(testKey1)
assert.True(t, ok)
assert.EqualValues(t, testValue1, v)
v, ok = c.Get(testKey2)
assert.True(t, ok)
assert.EqualValues(t, testValue2, v)
c.Add(testKey1, testValueExtra)
k, v, ok := c.GetOldest()
assert.True(t, ok)
assert.EqualValues(t, testKey2, k)
assert.EqualValues(t, testValue2, v)
c.Add(testKey3, testValue3)
v, ok = c.Get(testKey3)
assert.True(t, ok)
assert.EqualValues(t, testValue3, v)
v, ok = c.Get(testKey2)
assert.False(t, ok)
assert.Empty(t, v)
assert.Eventually(t, func() bool {
return atomic.LoadInt32(&evicted) == 1
}, 1*time.Second, 100*time.Millisecond)
}
func TestLRU_Contains(t *testing.T) {
evicted := int32(0)
c, err := NewLRU(1, func(string, string) { atomic.AddInt32(&evicted, 1) })
assert.Nil(t, err)
testKey1 := "test_key_1"
testValue1 := "test_value_1"
testKey2 := "test_key_2"
testValue2 := "test_value_2"
c.Add(testKey1, testValue1)
ok := c.Contains(testKey1)
assert.True(t, ok)
c.Add(testKey2, testValue2)
ok = c.Contains(testKey2)
assert.True(t, ok)
ok = c.Contains(testKey1)
assert.False(t, ok)
assert.Eventually(t, func() bool {
return atomic.LoadInt32(&evicted) == 1
}, 1*time.Second, 100*time.Millisecond)
}
func TestLRU_Get(t *testing.T) {
evicted := int32(0)
c, err := NewLRU(1, func(string, string) { atomic.AddInt32(&evicted, 1) })
assert.Nil(t, err)
testKey1 := "test_key_1"
testValue1 := "test_value_1"
testKey2 := "test_key_2"
testValue2 := "test_value_2"
c.Add(testKey1, testValue1)
v, ok := c.Get(testKey1)
assert.True(t, ok)
assert.EqualValues(t, testValue1, v)
c.Add(testKey2, testValue2)
v, ok = c.Get(testKey2)
assert.True(t, ok)
assert.EqualValues(t, testValue2, v)
v, ok = c.Get(testKey1)
assert.False(t, ok)
assert.Empty(t, v)
assert.Eventually(t, func() bool {
return atomic.LoadInt32(&evicted) == 1
}, 1*time.Second, 100*time.Millisecond)
}
func TestLRU_GetOldest(t *testing.T) {
evicted := int32(0)
c, err := NewLRU(2, func(string, string) { atomic.AddInt32(&evicted, 1) })
assert.Nil(t, err)
testKey1 := "test_key_1"
testValue1 := "test_value_1"
testKey2 := "test_key_2"
testValue2 := "test_value_2"
testKey3 := "test_key_3"
testValue3 := "test_value_3"
k, v, ok := c.GetOldest()
assert.False(t, ok)
assert.Empty(t, k)
assert.Empty(t, v)
c.Add(testKey1, testValue1)
k, v, ok = c.GetOldest()
assert.True(t, ok)
assert.EqualValues(t, testKey1, k)
assert.EqualValues(t, testValue1, v)
c.Add(testKey2, testValue2)
k, v, ok = c.GetOldest()
assert.True(t, ok)
assert.EqualValues(t, testKey1, k)
assert.EqualValues(t, testValue1, v)
v, ok = c.Get(testKey1)
assert.True(t, ok)
assert.EqualValues(t, testValue1, v)
k, v, ok = c.GetOldest()
assert.True(t, ok)
assert.EqualValues(t, testKey2, k)
assert.EqualValues(t, testValue2, v)
c.Add(testKey3, testValue3)
k, v, ok = c.GetOldest()
assert.True(t, ok)
assert.EqualValues(t, testKey1, k)
assert.EqualValues(t, testValue1, v)
assert.Eventually(t, func() bool {
return atomic.LoadInt32(&evicted) == 1
}, 1*time.Second, 100*time.Millisecond)
}
func TestLRU_Keys(t *testing.T) {
evicted := int32(0)
c, err := NewLRU(2, func(string, string) { atomic.AddInt32(&evicted, 1) })
assert.Nil(t, err)
testKey1 := "test_key_1"
testValue1 := "test_value_1"
testKey2 := "test_key_2"
testValue2 := "test_value_2"
testKey3 := "test_key_3"
testValue3 := "test_value_3"
c.Add(testKey1, testValue1)
c.Add(testKey2, testValue2)
keys := c.Keys()
assert.ElementsMatch(t, []string{testKey1, testKey2}, keys)
v, ok := c.Get(testKey1)
assert.True(t, ok)
assert.EqualValues(t, testValue1, v)
keys = c.Keys()
assert.ElementsMatch(t, []string{testKey2, testKey1}, keys)
c.Add(testKey3, testValue3)
keys = c.Keys()
assert.ElementsMatch(t, []string{testKey3, testKey1}, keys)
assert.Eventually(t, func() bool {
return atomic.LoadInt32(&evicted) == 1
}, 1*time.Second, 100*time.Millisecond)
}
func TestLRU_Len(t *testing.T) {
c, err := NewLRU[string, string](2, nil)
assert.Nil(t, err)
assert.EqualValues(t, c.Len(), 0)
testKey1 := "test_key_1"
testValue1 := "test_value_1"
testKey2 := "test_key_2"
testValue2 := "test_value_2"
testKey3 := "test_key_3"
testValue3 := "test_value_3"
c.Add(testKey1, testValue1)
c.Add(testKey2, testValue2)
assert.EqualValues(t, c.Len(), 2)
c.Add(testKey3, testValue3)
assert.EqualValues(t, c.Len(), 2)
}
func TestLRU_Capacity(t *testing.T) {
c, err := NewLRU[string, string](5, nil)
assert.Nil(t, err)
assert.EqualValues(t, c.Len(), 0)
testKey1 := "test_key_1"
testValue1 := "test_value_1"
testKey2 := "test_key_2"
testValue2 := "test_value_2"
testKey3 := "test_key_3"
testValue3 := "test_value_3"
c.Add(testKey1, testValue1)
assert.EqualValues(t, c.Capacity(), 5)
c.Add(testKey2, testValue2)
assert.EqualValues(t, c.Capacity(), 5)
c.Add(testKey3, testValue3)
assert.EqualValues(t, c.Capacity(), 5)
}
func TestLRU_Purge(t *testing.T) {
evicted := int32(0)
c, err := NewLRU(2, func(string, string) { atomic.AddInt32(&evicted, 1) })
assert.Nil(t, err)
assert.EqualValues(t, c.Len(), 0)
testKey1 := "test_key_1"
testValue1 := "test_value_1"
testKey2 := "test_key_2"
testValue2 := "test_value_2"
testKey3 := "test_key_3"
testValue3 := "test_value_3"
c.Add(testKey1, testValue1)
c.Add(testKey2, testValue2)
assert.EqualValues(t, c.Len(), 2)
c.Add(testKey3, testValue3)
assert.EqualValues(t, c.Len(), 2)
c.Purge()
assert.EqualValues(t, c.Len(), 0)
assert.Eventually(t, func() bool {
return atomic.LoadInt32(&evicted) == 3
}, 1*time.Second, 100*time.Millisecond)
}
func TestLRU_Remove(t *testing.T) {
evicted := int32(0)
c, err := NewLRU(2, func(string, string) { atomic.AddInt32(&evicted, 1) })
assert.Nil(t, err)
assert.EqualValues(t, c.Len(), 0)
testKey1 := "test_key_1"
testValue1 := "test_value_1"
testKey2 := "test_key_2"
testValue2 := "test_value_2"
c.Add(testKey1, testValue1)
c.Add(testKey2, testValue2)
assert.EqualValues(t, c.Len(), 2)
c.Remove(testKey1)
c.Remove(testKey2)
assert.EqualValues(t, c.Len(), 0)
assert.Eventually(t, func() bool {
return atomic.LoadInt32(&evicted) == 2
}, 1*time.Second, 100*time.Millisecond)
}
func TestLRU_RemoveOldest(t *testing.T) {
evicted := int32(0)
c, err := NewLRU(2, func(string, string) { atomic.AddInt32(&evicted, 1) })
assert.Nil(t, err)
assert.EqualValues(t, c.Len(), 0)
testKey1 := "test_key_1"
testValue1 := "test_value_1"
testKey2 := "test_key_2"
testValue2 := "test_value_2"
c.Add(testKey1, testValue1)
c.Add(testKey2, testValue2)
assert.EqualValues(t, c.Len(), 2)
v, ok := c.Get(testKey1)
assert.True(t, ok)
assert.EqualValues(t, v, testValue1)
v, ok = c.Get(testKey2)
assert.True(t, ok)
assert.EqualValues(t, v, testValue2)
c.Remove(testKey1)
c.Remove(testKey2)
v, ok = c.Get(testKey1)
assert.False(t, ok)
assert.Empty(t, v)
v, ok = c.Get(testKey2)
assert.False(t, ok)
assert.Empty(t, v)
assert.EqualValues(t, c.Len(), 0)
assert.Eventually(t, func() bool {
return atomic.LoadInt32(&evicted) == 2
}, 1*time.Second, 100*time.Millisecond)
}
func TestLRU_Resize(t *testing.T) {
evicted := int32(0)
c, err := NewLRU(2, func(string, string) { atomic.AddInt32(&evicted, 1) })
assert.Nil(t, err)
assert.EqualValues(t, c.Len(), 0)
testKey1 := "test_key_1"
testValue1 := "test_value_1"
testKey2 := "test_key_2"
testValue2 := "test_value_2"
c.Add(testKey1, testValue1)
c.Add(testKey2, testValue2)
assert.EqualValues(t, c.Len(), 2)
v, ok := c.Get(testKey1)
assert.True(t, ok)
assert.EqualValues(t, v, testValue1)
v, ok = c.Get(testKey2)
assert.True(t, ok)
assert.EqualValues(t, v, testValue2)
c.Resize(1)
v, ok = c.Get(testKey1)
assert.False(t, ok)
assert.Empty(t, v)
v, ok = c.Get(testKey2)
assert.True(t, ok)
assert.EqualValues(t, v, testValue2)
assert.EqualValues(t, c.Len(), 1)
assert.Eventually(t, func() bool {
return atomic.LoadInt32(&evicted) == 1
}, 1*time.Second, 100*time.Millisecond)
c.Resize(3)
assert.EqualValues(t, c.Len(), 1)
assert.Eventually(t, func() bool {
return atomic.LoadInt32(&evicted) == 1
}, 1*time.Second, 100*time.Millisecond)
}
func TestLRU_closed(t *testing.T) {
evicted := int32(0)
c, err := NewLRU(2, func(string, string) { atomic.AddInt32(&evicted, 1) })
require.NoError(t, err)
c.Close()
c.Add("testKey", "testValue")
assert.Equal(t, int32(1), evicted)
_, ok := c.Get("testKey")
assert.False(t, ok)
assert.NotPanics(t, func() {
c.Remove("testKey")
})
contains := c.Contains("testKey")
assert.False(t, contains)
keys := c.Keys()
assert.Nil(t, keys)
l := c.Len()
assert.Equal(t, 0, l)
diff := c.Resize(1)
assert.Equal(t, 0, diff)
assert.Equal(t, 2, c.Capacity())
_, _, ok = c.GetOldest()
assert.False(t, ok)
assert.NotPanics(t, func() {
c.Close()
}, "double close")
}

internal/util/cache/lru_impl.go (new file, 98 lines)

@ -0,0 +1,98 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cache
import (
"container/list"
)
// lruCache is a LRU cache.
type lruCache struct {
cache *cache
cap int64
ls list.List
}
// init initializes cache list.
func (l *lruCache) init(c *cache, cap int64) {
l.cache = c
l.cap = cap
l.ls.Init()
}
// write adds a new entry to the cache and returns the evicted entry, if any.
func (l *lruCache) write(en *entry) *entry {
// Fast path
if en.accessList != nil {
// Entry existed, update its status instead.
l.markAccess(en)
return nil
}
// Try to add the new entry to the cache.
cen := l.cache.getOrSet(en)
if cen == nil {
// Brand new entry, add to the LRU list.
en.accessList = l.ls.PushFront(en)
} else {
// Entry has already been added, update its value instead.
cen.setValue(en.getValue())
cen.setWriteTime(en.getWriteTime())
if cen.accessList == nil {
// Entry was loaded into the cache but not yet registered in the LRU list.
cen.accessList = l.ls.PushFront(cen)
} else {
l.markAccess(cen)
}
}
if l.cap > 0 && int64(l.ls.Len()) > l.cap {
// Evict the least recently used entry when the capacity is exceeded.
en = getEntry(l.ls.Back())
return l.remove(en)
}
return nil
}
// access updates cache entry for a get.
func (l *lruCache) access(en *entry) {
if en.accessList != nil {
l.markAccess(en)
}
}
// markAccess marks the entry as just accessed.
// en.accessList must not be nil.
func (l *lruCache) markAccess(en *entry) {
l.ls.MoveToFront(en.accessList)
}
// remove an entry from the cache.
func (l *lruCache) remove(en *entry) *entry {
if en.accessList == nil {
// Already deleted
return nil
}
l.cache.delete(en)
l.ls.Remove(en.accessList)
en.accessList = nil
return en
}
// iterate walks through all entries ordered by access time.
func (l *lruCache) iterate(fn func(en *entry) bool) {
iterateListFromBack(&l.ls, fn)
}
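
For orientation, a minimal sketch of how this policy is driven (hypothetical code, not part of the change; sketchLRUEviction is an invented name, and since it uses unexported types it assumes placement in the same cache package):

package cache

import "fmt"

// sketchLRUEviction is a hypothetical illustration of the write/access/evict
// flow implemented by lruCache.
func sketchLRUEviction() {
	var c cache
	var lru lruCache
	lru.init(&c, 2) // keep at most two entries

	a := newEntry("a", 1, 0)
	b := newEntry("b", 2, 0)
	lru.write(a)  // access order: a
	lru.write(b)  // access order: b a
	lru.access(a) // access order: a b

	// Writing a third entry exceeds the capacity, so the least recently
	// used entry ("b") is evicted and returned.
	victim := lru.write(newEntry("c", 3, 0))
	fmt.Println(victim.key, lru.ls.Len()) // b 2
}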

internal/util/cache/lru_impl_test.go vendored Normal file

@ -0,0 +1,159 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cache
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
type lruTest struct {
c cache
lru lruCache
t *testing.T
}
func (t *lruTest) assertLRULen(n int) {
sz := cacheSize(&t.c)
lz := t.lru.ls.Len()
assert.Equal(t.t, n, sz)
assert.Equal(t.t, n, lz)
}
func (t *lruTest) assertEntry(en *entry, k int, v string, id uint8) {
if en == nil {
t.t.Helper()
t.t.Fatalf("unexpected entry: %v", en)
}
ak := en.key.(int)
av := en.getValue().(string)
assert.Equal(t.t, k, ak)
assert.Equal(t.t, v, av)
assert.Equal(t.t, id, en.listID)
}
func (t *lruTest) assertLRUEntry(k int) {
en := t.c.get(k, 0)
assert.NotNil(t.t, en)
ak := en.key.(int)
av := en.getValue().(string)
v := fmt.Sprintf("%d", k)
assert.Equal(t.t, k, ak)
assert.Equal(t.t, v, av)
assert.Equal(t.t, uint8(0), en.listID)
}
func (t *lruTest) assertSLRUEntry(k int, id uint8) {
en := t.c.get(k, 0)
assert.NotNil(t.t, en)
ak := en.key.(int)
av := en.getValue().(string)
v := fmt.Sprintf("%d", k)
assert.Equal(t.t, k, ak)
assert.Equal(t.t, v, av)
assert.Equal(t.t, id, en.listID)
}
func TestLRU(t *testing.T) {
s := lruTest{t: t}
s.lru.init(&s.c, 3)
en := createLRUEntries(4)
remEn := s.lru.write(en[0])
assert.Nil(t, remEn)
// 0
s.assertLRULen(1)
s.assertLRUEntry(0)
remEn = s.lru.write(en[1])
// 1 0
assert.Nil(t, remEn)
s.assertLRULen(2)
s.assertLRUEntry(1)
s.assertLRUEntry(0)
s.lru.access(en[0])
// 0 1
remEn = s.lru.write(en[2])
// 2 0 1
assert.Nil(t, remEn)
s.assertLRULen(3)
remEn = s.lru.write(en[3])
// 3 2 0
s.assertEntry(remEn, 1, "1", 0)
s.assertLRULen(3)
s.assertLRUEntry(3)
s.assertLRUEntry(2)
s.assertLRUEntry(0)
remEn = s.lru.remove(en[2])
// 3 0
s.assertEntry(remEn, 2, "2", 0)
s.assertLRULen(2)
s.assertLRUEntry(3)
s.assertLRUEntry(0)
}
func TestLRUWalk(t *testing.T) {
s := lruTest{t: t}
s.lru.init(&s.c, 5)
entries := createLRUEntries(6)
for _, e := range entries {
s.lru.write(e)
}
// 5 4 3 2 1
found := ""
s.lru.iterate(func(en *entry) bool {
found += en.getValue().(string) + " "
return true
})
assert.Equal(t, "1 2 3 4 5 ", found)
s.lru.access(entries[1])
s.lru.access(entries[5])
s.lru.access(entries[3])
// 3 5 1 4 2
found = ""
s.lru.iterate(func(en *entry) bool {
found += en.getValue().(string) + " "
if en.key.(int)%2 == 0 {
s.lru.remove(en)
}
return en.key.(int) != 5
})
assert.Equal(t, "2 4 1 5 ", found)
s.assertLRULen(3)
s.assertLRUEntry(3)
s.assertLRUEntry(5)
s.assertLRUEntry(1)
}
func createLRUEntries(n int) []*entry {
en := make([]*entry, n)
for i := range en {
en[i] = newEntry(i, fmt.Sprintf("%d", i), 0 /* unused */)
}
return en
}

internal/util/cache/policy.go vendored Normal file

@ -0,0 +1,267 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cache
import (
"container/list"
"sync"
"sync/atomic"
)
const (
// Number of cache segments will be 2 ^ concurrencyLevel.
concurrencyLevel = 2
segmentCount = 1 << concurrencyLevel
segmentMask = segmentCount - 1
)
// entry stores cached entry key and value.
type entry struct {
// The first word of an allocated struct is guaranteed to be 64-bit aligned,
// so the 64-bit fields that are accessed atomically are placed first.
// https://golang.org/pkg/sync/atomic/#pkg-note-BUG
// hash is the hash value of this entry's key.
hash uint64
// accessTime is the last time this entry was accessed.
accessTime int64 // Accessed atomically - must be 64-bit aligned on 32-bit platforms.
// writeTime is the last time this entry was updated.
writeTime int64 // Accessed atomically - must be 64-bit aligned on 32-bit platforms.
// FIXME: More efficient way to store boolean flags
invalidated int32
loading int32
key interface{}
value atomic.Value // Store value
// These fields are managed only by the cache policy, so they do not need atomic access.
// accessList is the list (ordered by access time) this entry is currently in.
accessList *list.Element
// writeList is the list (ordered by write time) this entry is currently in.
writeList *list.Element
// listID is the ID of the list this entry is currently in.
listID uint8
}
func newEntry(k interface{}, v interface{}, h uint64) *entry {
en := &entry{
key: k,
hash: h,
}
en.setValue(v)
return en
}
func (e *entry) getValue() interface{} {
return e.value.Load()
}
func (e *entry) setValue(v interface{}) {
e.value.Store(v)
}
func (e *entry) getAccessTime() int64 {
return atomic.LoadInt64(&e.accessTime)
}
func (e *entry) setAccessTime(v int64) {
atomic.StoreInt64(&e.accessTime, v)
}
func (e *entry) getWriteTime() int64 {
return atomic.LoadInt64(&e.writeTime)
}
func (e *entry) setWriteTime(v int64) {
atomic.StoreInt64(&e.writeTime, v)
}
func (e *entry) getLoading() bool {
return atomic.LoadInt32(&e.loading) != 0
}
func (e *entry) setLoading(v bool) bool {
if v {
return atomic.CompareAndSwapInt32(&e.loading, 0, 1)
}
return atomic.CompareAndSwapInt32(&e.loading, 1, 0)
}
func (e *entry) getInvalidated() bool {
return atomic.LoadInt32(&e.invalidated) != 0
}
func (e *entry) setInvalidated(v bool) {
if v {
atomic.StoreInt32(&e.invalidated, 1)
} else {
atomic.StoreInt32(&e.invalidated, 0)
}
}
// getEntry returns the entry attached to the given list element.
func getEntry(el *list.Element) *entry {
return el.Value.(*entry)
}
// event is the type of a cache event (write, access, delete or close).
type event uint8
const (
eventWrite event = iota
eventAccess
eventDelete
eventClose
)
type entryEvent struct {
entry *entry
event event
}
// cache is a segmented store for cache entries.
type cache struct {
size int64 // Accessed atomically - must be 64-bit aligned on 32-bit platforms.
segs [segmentCount]sync.Map // map[Key]*entry
}
func (c *cache) get(k interface{}, h uint64) *entry {
seg := c.segment(h)
v, ok := seg.Load(k)
if ok {
return v.(*entry)
}
return nil
}
func (c *cache) getOrSet(v *entry) *entry {
seg := c.segment(v.hash)
en, ok := seg.LoadOrStore(v.key, v)
if ok {
return en.(*entry)
}
atomic.AddInt64(&c.size, 1)
return nil
}
func (c *cache) delete(v *entry) {
seg := c.segment(v.hash)
seg.Delete(v.key)
atomic.AddInt64(&c.size, -1)
}
func (c *cache) len() int {
return int(atomic.LoadInt64(&c.size))
}
func (c *cache) walk(fn func(*entry)) {
for i := range c.segs {
c.segs[i].Range(func(k, v interface{}) bool {
fn(v.(*entry))
return true
})
}
}
func (c *cache) segment(h uint64) *sync.Map {
return &c.segs[h&segmentMask]
}
// policy is a cache policy.
type policy interface {
// init initializes the policy.
init(cache *cache, maximumSize int64)
// write handles a Write event for the entry.
// It adds the new entry and returns the evicted entry, if any.
write(entry *entry) *entry
// access handles an Access event for the entry.
// It marks the entry as recently accessed.
access(entry *entry)
// remove removes the entry.
remove(entry *entry) *entry
// iterate iterates over all entries ordered by access time.
iterate(func(entry *entry) bool)
}
func newPolicy() policy {
return &lruCache{}
}
// recencyQueue manages cache entries by write time.
type recencyQueue struct {
ls list.List
}
func (w *recencyQueue) init(cache *cache, maximumSize int64) {
w.ls.Init()
}
func (w *recencyQueue) write(en *entry) *entry {
if en.writeList == nil {
en.writeList = w.ls.PushFront(en)
} else {
w.ls.MoveToFront(en.writeList)
}
return nil
}
func (w *recencyQueue) access(en *entry) {
}
func (w *recencyQueue) remove(en *entry) *entry {
if en.writeList == nil {
return en
}
w.ls.Remove(en.writeList)
en.writeList = nil
return en
}
func (w *recencyQueue) iterate(fn func(en *entry) bool) {
iterateListFromBack(&w.ls, fn)
}
type discardingQueue struct{}
func (discardingQueue) init(cache *cache, maximumSize int64) {
}
func (discardingQueue) write(en *entry) *entry {
return nil
}
func (discardingQueue) access(en *entry) {
}
func (discardingQueue) remove(en *entry) *entry {
return en
}
func (discardingQueue) iterate(fn func(en *entry) bool) {
}
func iterateListFromBack(ls *list.List, fn func(en *entry) bool) {
for el := ls.Back(); el != nil; {
en := getEntry(el)
prev := el.Prev() // Get Prev as fn can delete the entry.
if !fn(en) {
return
}
el = prev
}
}
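
A rough sketch of how the segmented store behaves (again hypothetical, in-package code; sketchSegments is an invented name): the segment for an entry is chosen from the low bits of its hash, so operations on different segments avoid contending on a single map.

package cache

import "fmt"

// sketchSegments is a hypothetical illustration of the segmented cache store.
func sketchSegments() {
	var c cache

	// Hashes 0..7 select segments 0..3 and then wrap around, because
	// segment(h) returns &segs[h&segmentMask].
	for h := uint64(0); h < 8; h++ {
		c.getOrSet(newEntry(int(h), fmt.Sprintf("value-%d", h), h))
	}

	fmt.Println(c.len())                // 8
	fmt.Println(c.get(3, 3).getValue()) // value-3
	c.delete(c.get(3, 3))               // removes the entry and decrements the size
	fmt.Println(c.len())                // 7
}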

internal/util/cache/policy_test.go vendored Normal file

@ -0,0 +1,51 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cache
import (
"sync/atomic"
"testing"
)
func cacheSize(c *cache) int {
length := 0
c.walk(func(*entry) {
length++
})
return length
}
func BenchmarkCacheSegment(b *testing.B) {
c := cache{}
const count = 1 << 10
entries := make([]*entry, count)
for i := range entries {
entries[i] = newEntry(i, i, uint64(i))
}
var n int32
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
i := atomic.AddInt32(&n, 1)
c.getOrSet(entries[i&(count-1)])
if i > 0 && i&0xf == 0 {
c.delete(entries[(i-1)&(count-1)])
}
}
})
}

internal/util/cache/stats.go vendored Normal file

@ -0,0 +1,143 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cache
import (
"fmt"
"sync/atomic"
"time"
)
// Stats is statistics about performance of a cache.
type Stats struct {
HitCount uint64
MissCount uint64
LoadSuccessCount uint64
LoadErrorCount uint64
TotalLoadTime time.Duration
EvictionCount uint64
}
// RequestCount returns the total of HitCount and MissCount.
func (s *Stats) RequestCount() uint64 {
return s.HitCount + s.MissCount
}
// HitRate returns the ratio of cache requests which were hits.
func (s *Stats) HitRate() float64 {
total := s.RequestCount()
if total == 0 {
return 1.0
}
return float64(s.HitCount) / float64(total)
}
// MissRate returns the ratio of cache requests which were misses.
func (s *Stats) MissRate() float64 {
total := s.RequestCount()
if total == 0 {
return 0.0
}
return float64(s.MissCount) / float64(total)
}
// LoadErrorRate returns the ratio of cache loading attempts which returned errors.
func (s *Stats) LoadErrorRate() float64 {
total := s.LoadSuccessCount + s.LoadErrorCount
if total == 0 {
return 0.0
}
return float64(s.LoadErrorCount) / float64(total)
}
// AverageLoadPenalty returns the average time spent loading new values.
func (s *Stats) AverageLoadPenalty() time.Duration {
total := s.LoadSuccessCount + s.LoadErrorCount
if total == 0 {
return 0.0
}
return s.TotalLoadTime / time.Duration(total)
}
// String returns a string representation of these statistics.
func (s *Stats) String() string {
return fmt.Sprintf("hits: %d, misses: %d, successes: %d, errors: %d, time: %s, evictions: %d",
s.HitCount, s.MissCount, s.LoadSuccessCount, s.LoadErrorCount, s.TotalLoadTime, s.EvictionCount)
}
// StatsCounter accumulates statistics of a cache.
type StatsCounter interface {
// RecordHits records cache hits.
RecordHits(count uint64)
// RecordMisses records cache misses.
RecordMisses(count uint64)
// RecordLoadSuccess records successful load of a new entry.
RecordLoadSuccess(loadTime time.Duration)
// RecordLoadError records failed load of a new entry.
RecordLoadError(loadTime time.Duration)
// RecordEviction records eviction of an entry from the cache.
RecordEviction()
// Snapshot writes a snapshot of this counter's values to the given Stats pointer.
Snapshot(*Stats)
}
// statsCounter is a simple implementation of StatsCounter.
type statsCounter struct {
Stats
}
// RecordHits increases HitCount atomically.
func (s *statsCounter) RecordHits(count uint64) {
atomic.AddUint64(&s.Stats.HitCount, count)
}
// RecordMisses increases MissCount atomically.
func (s *statsCounter) RecordMisses(count uint64) {
atomic.AddUint64(&s.Stats.MissCount, count)
}
// RecordLoadSuccess increases LoadSuccessCount atomically.
func (s *statsCounter) RecordLoadSuccess(loadTime time.Duration) {
atomic.AddUint64(&s.Stats.LoadSuccessCount, 1)
atomic.AddInt64((*int64)(&s.Stats.TotalLoadTime), int64(loadTime))
}
// RecordLoadError increases LoadErrorCount atomically.
func (s *statsCounter) RecordLoadError(loadTime time.Duration) {
atomic.AddUint64(&s.Stats.LoadErrorCount, 1)
atomic.AddInt64((*int64)(&s.Stats.TotalLoadTime), int64(loadTime))
}
// RecordEviction increases EvictionCount atomically.
func (s *statsCounter) RecordEviction() {
atomic.AddUint64(&s.Stats.EvictionCount, 1)
}
// Snapshot copies current stats to t.
func (s *statsCounter) Snapshot(t *Stats) {
t.HitCount = atomic.LoadUint64(&s.HitCount)
t.MissCount = atomic.LoadUint64(&s.MissCount)
t.LoadSuccessCount = atomic.LoadUint64(&s.LoadSuccessCount)
t.LoadErrorCount = atomic.LoadUint64(&s.LoadErrorCount)
t.TotalLoadTime = time.Duration(atomic.LoadInt64((*int64)(&s.TotalLoadTime)))
t.EvictionCount = atomic.LoadUint64(&s.EvictionCount)
}
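
A small worked example of the counter and its derived rates (hypothetical; sketchStats is an invented name): with 8 hits and 2 misses the hit rate is 8/10 = 0.8, and three loads totalling 200ms average out to roughly 66.7ms.

package cache

import (
	"fmt"
	"time"
)

// sketchStats is a hypothetical illustration of recording cache activity
// and reading back the derived rates.
func sketchStats() {
	var counter statsCounter
	counter.RecordHits(8)
	counter.RecordMisses(2)
	counter.RecordLoadSuccess(40 * time.Millisecond)
	counter.RecordLoadSuccess(60 * time.Millisecond)
	counter.RecordLoadError(100 * time.Millisecond)
	counter.RecordEviction()

	var snap Stats
	counter.Snapshot(&snap)
	fmt.Println(snap.RequestCount())       // 10
	fmt.Println(snap.HitRate())            // 0.8
	fmt.Println(snap.LoadErrorRate())      // 0.333...
	fmt.Println(snap.AverageLoadPenalty()) // 200ms / 3 loads ≈ 66.67ms
	fmt.Println(snap.String())
}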

internal/util/cache/stats_test.go vendored Normal file

@ -0,0 +1,69 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cache
import (
"testing"
"time"
)
func TestStatsCounter(t *testing.T) {
c := statsCounter{}
c.RecordHits(3)
c.RecordMisses(2)
c.RecordLoadSuccess(2 * time.Second)
c.RecordLoadError(1 * time.Second)
c.RecordEviction()
var st Stats
c.Snapshot(&st)
if st.HitCount != 3 {
t.Fatalf("unexpected hit count: %v", st)
}
if st.MissCount != 2 {
t.Fatalf("unexpected miss count: %v", st)
}
if st.LoadSuccessCount != 1 {
t.Fatalf("unexpected success count: %v", st)
}
if st.LoadErrorCount != 1 {
t.Fatalf("unexpected error count: %v", st)
}
if st.TotalLoadTime != 3*time.Second {
t.Fatalf("unexpected load time: %v", st)
}
if st.EvictionCount != 1 {
t.Fatalf("unexpected eviction count: %v", st)
}
if st.RequestCount() != 5 {
t.Fatalf("unexpected request count: %v", st.RequestCount())
}
if st.HitRate() != 0.6 {
t.Fatalf("unexpected hit rate: %v", st.HitRate())
}
if st.MissRate() != 0.4 {
t.Fatalf("unexpected miss rate: %v", st.MissRate())
}
if st.LoadErrorRate() != 0.5 {
t.Fatalf("unexpected error rate: %v", st.LoadErrorRate())
}
if st.AverageLoadPenalty() != (1500 * time.Millisecond) {
t.Fatalf("unexpected load penalty: %v", st.AverageLoadPenalty())
}
}


@ -295,13 +295,14 @@ func sprinterr(m dsl.Matcher) {
}
func largeloopcopy(m dsl.Matcher) {
m.Match(
`for $_, $v := range $_ { $*_ }`,
).
Where(m["v"].Type.Size > 1024).
Report(`loop copies large value each iteration`)
}
// This check is disabled because it cannot be applied to generic types.
//func largeloopcopy(m dsl.Matcher) {
// m.Match(
// `for $_, $v := range $_ { $*_ }`,
// ).
// Where(m["v"].Type.Size > 1024).
// Report(`loop copies large value each iteration`)
//}
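// Hypothetical illustration (not part of the rules): with a type parameter T
// the element size is unknown until instantiation, so a size-based match such
// as m["v"].Type.Size > 1024 has nothing concrete to compare against.
//
//	func sum[T any](items []T, add func(T, T) T, zero T) T {
//		total := zero
//		for _, v := range items { // v copies an element of unknown size
//			total = add(total, v)
//		}
//		return total
//	}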
func joinpath(m dsl.Matcher) {
m.Match(