refactor: expose `CacheGetStatus` (and improve tests) (#4804)

* refactor: expose `CacheGetStatus` (and improve tests)

  - add a `CacheGetStatus` which tells the user if the request was a hit or a miss (or something in between)
  - adapt some tests to use the status (only the tests where this could be relevant)
  - move the test suite from using `sleep` to proper barriers (more stable under high load, more correct, potentially faster)

* refactor: improve `abort_and_wait` checks

* docs: fix typos and improve wording

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>

* refactor: `FutureExt2` -> `EnsurePendingExt`

* refactor: `Queried` -> `MissAlreadyLoading`

* docs: explain `abort_and_wait` more

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
parent f34282be2c
commit 2e3ba83795
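For context on the `sleep`-to-barrier change: a 10 ms sleep only hopes that a spawned task has reached the interesting state, which breaks under load, whereas a barrier is an explicit rendezvous. A minimal standalone sketch of the pattern (not code from this diff):

use std::sync::Arc;
use tokio::sync::Barrier;

#[tokio::main]
async fn main() {
    // Two parties: the spawned task and the test body.
    let barrier = Arc::new(Barrier::new(2));

    let barrier_captured = Arc::clone(&barrier);
    let handle = tokio::spawn(async move {
        // ... bring the task into the state the test wants to observe ...
        barrier_captured.wait().await;
        // ... proceed once the test body has arrived as well ...
    });

    // Deterministic rendezvous: returns exactly when the task has arrived,
    // with no timing assumptions.
    barrier.wait().await;
    handle.await.unwrap();
}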
@@ -14,7 +14,7 @@ use tokio::{
     task::JoinHandle,
 };
 
-use super::Cache;
+use super::{Cache, CacheGetStatus};
 
 /// Combine a [`CacheBackend`] and a [`Loader`] into a single [`Cache`]
 #[derive(Debug)]
@@ -61,20 +61,23 @@ where
     type V = V;
     type Extra = Extra;
 
-    async fn get(&self, k: Self::K, extra: Self::Extra) -> Self::V {
+    async fn get_with_status(&self, k: Self::K, extra: Self::Extra) -> (Self::V, CacheGetStatus) {
         // place state locking into its own scope so it doesn't leak into the generator (async
         // function)
-        let receiver = {
+        let (receiver, status) = {
             let mut state = self.state.lock();
 
             // check if the entry has already been cached
             if let Some(v) = state.cached_entries.get(&k) {
-                return v;
+                return (v, CacheGetStatus::Hit);
             }
 
             // check if there is already a query for this key running
             if let Some(running_query) = state.running_queries.get(&k) {
-                running_query.recv.clone()
+                (
+                    running_query.recv.clone(),
+                    CacheGetStatus::MissAlreadyLoading,
+                )
             } else {
                 // requires new query
                 let (tx_main, rx_main) = tokio::sync::oneshot::channel();
@@ -176,15 +179,17 @@ where
                         tag,
                     },
                 );
-                receiver
+                (receiver, CacheGetStatus::Miss)
             }
         };
 
-        receiver
+        let v = receiver
             .await
             .expect("cache loader panicked, see logs")
             .lock()
-            .clone()
+            .clone();
+
+        (v, status)
     }
 
     async fn set(&self, k: Self::K, v: Self::V) {
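The `// place state locking into its own scope` comment in the hunk above refers to a general async-Rust pitfall: a `parking_lot::MutexGuard` held across an `.await` is captured into the generated future, which makes the future non-`Send` and risks deadlocks. A minimal illustration of the scoping pattern (not code from this diff):

use parking_lot::Mutex;

// Copy what is needed out of the mutex inside a block, so the guard is
// dropped before any await point.
async fn scoped_lock(state: &Mutex<Vec<u8>>) -> usize {
    let len = {
        let guard = state.lock();
        guard.len()
    }; // guard dropped here, before the future first suspends

    tokio::task::yield_now().await;
    len
}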
@@ -12,6 +12,20 @@ pub mod driver;
 #[cfg(test)]
 mod test_util;
 
+/// Status of a [`Cache`] GET request.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum CacheGetStatus {
+    /// The requested entry was present in the storage backend.
+    Hit,
+
+    /// The requested entry was NOT present in the storage backend and the loader had no previous query running.
+    Miss,
+
+    /// The requested entry was NOT present in the storage backend, but there was already a loader query running for
+    /// this particular key.
+    MissAlreadyLoading,
+}
+
 /// High-level cache implementation.
 ///
 /// # Concurrency
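Because `CacheGetStatus` derives `Copy`, `PartialEq`, and `Eq`, the status is cheap to return by value and can be asserted on directly in tests. As a sketch of how downstream code might consume it (hypothetical helper, not part of this change):

// Hypothetical helper, not part of this change: map each variant to a
// label, e.g. for logging or a metrics dimension.
fn status_label(status: CacheGetStatus) -> &'static str {
    match status {
        CacheGetStatus::Hit => "hit",
        CacheGetStatus::Miss => "miss",
        CacheGetStatus::MissAlreadyLoading => "miss_already_loading",
    }
}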
@@ -41,7 +55,16 @@ pub trait Cache: Debug + Send + Sync + 'static {
     type Extra: Debug + Send + 'static;
 
     /// Get value from cache.
-    async fn get(&self, k: Self::K, extra: Self::Extra) -> Self::V;
+    ///
+    /// Note that `extra` is only used if the key is missing from the storage backend and no loader query is running yet.
+    async fn get(&self, k: Self::K, extra: Self::Extra) -> Self::V {
+        self.get_with_status(k, extra).await.0
+    }
+
+    /// Get value from cache and the status.
+    ///
+    /// Note that `extra` is only used if the key is missing from the storage backend and no loader query is running yet.
+    async fn get_with_status(&self, k: Self::K, extra: Self::Extra) -> (Self::V, CacheGetStatus);
 
     /// Side-load an entry into the cache.
     ///
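Since `get` is now a provided method that delegates to `get_with_status` and discards the status, existing call sites keep compiling unchanged while implementors only have to supply the new method. A toy implementor sketched under that assumption (`MapCache` is hypothetical, and only the trait items visible in this hunk are assumed; real implementations, like the driver above, also involve a loader):

use std::collections::HashMap;

use async_trait::async_trait;
use parking_lot::Mutex;

/// Hypothetical in-memory implementor, for illustration only.
#[derive(Debug, Default)]
struct MapCache {
    entries: Mutex<HashMap<u8, String>>,
}

#[async_trait]
impl Cache for MapCache {
    type K = u8;
    type V = String;
    type Extra = bool;

    // No `get` here: the provided default delegates to `get_with_status`.
    async fn get_with_status(&self, k: u8, extra: bool) -> (String, CacheGetStatus) {
        let mut entries = self.entries.lock();
        if let Some(v) = entries.get(&k) {
            return (v.clone(), CacheGetStatus::Hit);
        }
        // "Load" the value; `extra` only matters on this miss path.
        let v = format!("{k}_{extra}");
        entries.insert(k, v.clone());
        (v, CacheGetStatus::Miss)
    }

    async fn set(&self, k: u8, v: String) {
        self.entries.lock().insert(k, v);
    }
}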
@@ -1,10 +1,14 @@
 use std::{collections::HashSet, sync::Arc, time::Duration};
 
 use async_trait::async_trait;
+use futures::{Future, FutureExt};
 use parking_lot::Mutex;
-use tokio::sync::Notify;
+use tokio::{
+    sync::{Barrier, Notify},
+    task::JoinHandle,
+};
 
-use crate::loader::Loader;
+use crate::{cache::CacheGetStatus, loader::Loader};
 
 use super::Cache;
 
@@ -44,11 +48,26 @@ async fn test_linear_memory<C>(cache: Arc<C>, loader: Arc<TestLoader>)
 where
     C: Cache<K = u8, V = String, Extra = bool>,
 {
-    assert_eq!(cache.get(1, true).await, String::from("1_true"));
-    assert_eq!(cache.get(1, false).await, String::from("1_true"));
-    assert_eq!(cache.get(2, false).await, String::from("2_false"));
-    assert_eq!(cache.get(2, false).await, String::from("2_false"));
-    assert_eq!(cache.get(1, true).await, String::from("1_true"));
+    assert_eq!(
+        cache.get_with_status(1, true).await,
+        (String::from("1_true"), CacheGetStatus::Miss),
+    );
+    assert_eq!(
+        cache.get_with_status(1, false).await,
+        (String::from("1_true"), CacheGetStatus::Hit),
+    );
+    assert_eq!(
+        cache.get_with_status(2, false).await,
+        (String::from("2_false"), CacheGetStatus::Miss),
+    );
+    assert_eq!(
+        cache.get_with_status(2, false).await,
+        (String::from("2_false"), CacheGetStatus::Hit),
+    );
+    assert_eq!(
+        cache.get_with_status(1, true).await,
+        (String::from("1_true"), CacheGetStatus::Hit),
+    );
 
     assert_eq!(loader.loaded(), vec![1, 2]);
 }
@@ -60,16 +79,39 @@ where
     loader.block();
 
     let cache_captured = Arc::clone(&cache);
-    let handle_1 = tokio::spawn(async move { cache_captured.get(1, true).await });
-    let handle_2 = tokio::spawn(async move { cache.get(1, true).await });
+    let barrier_pending_1 = Arc::new(Barrier::new(2));
+    let barrier_pending_1_captured = Arc::clone(&barrier_pending_1);
+    let handle_1 = tokio::spawn(async move {
+        cache_captured
+            .get_with_status(1, true)
+            .ensure_pending(barrier_pending_1_captured)
+            .await
+    });
 
-    tokio::time::sleep(Duration::from_millis(10)).await;
+    barrier_pending_1.wait().await;
+    let barrier_pending_2 = Arc::new(Barrier::new(2));
+    let barrier_pending_2_captured = Arc::clone(&barrier_pending_2);
+    let handle_2 = tokio::spawn(async move {
+        // use a different `extra` here to prove that the first one was used
+        cache
+            .get_with_status(1, false)
+            .ensure_pending(barrier_pending_2_captured)
+            .await
+    });
 
+    barrier_pending_2.wait().await;
     // Shouldn't issue concurrent load requests for the same key
     let n_blocked = loader.unblock();
     assert_eq!(n_blocked, 1);
 
-    assert_eq!(handle_1.await.unwrap(), String::from("1_true"));
-    assert_eq!(handle_2.await.unwrap(), String::from("1_true"));
+    assert_eq!(
+        handle_1.await.unwrap(),
+        (String::from("1_true"), CacheGetStatus::Miss),
+    );
+    assert_eq!(
+        handle_2.await.unwrap(),
+        (String::from("1_true"), CacheGetStatus::MissAlreadyLoading),
+    );
 
     assert_eq!(loader.loaded(), vec![1]);
 }
@@ -80,13 +122,31 @@ where
 {
     loader.block();
 
-    let cache_captured = Arc::clone(&cache);
-    let handle_1 = tokio::spawn(async move { cache_captured.get(1, true).await });
-    let cache_captured = Arc::clone(&cache);
-    let handle_2 = tokio::spawn(async move { cache_captured.get(1, true).await });
-    let handle_3 = tokio::spawn(async move { cache.get(2, false).await });
+    let barrier = Arc::new(Barrier::new(4));
 
-    tokio::time::sleep(Duration::from_millis(10)).await;
+    let cache_captured = Arc::clone(&cache);
+    let barrier_captured = Arc::clone(&barrier);
+    let handle_1 = tokio::spawn(async move {
+        cache_captured
+            .get(1, true)
+            .ensure_pending(barrier_captured)
+            .await
+    });
+
+    let cache_captured = Arc::clone(&cache);
+    let barrier_captured = Arc::clone(&barrier);
+    let handle_2 = tokio::spawn(async move {
+        cache_captured
+            .get(1, true)
+            .ensure_pending(barrier_captured)
+            .await
+    });
+
+    let barrier_captured = Arc::clone(&barrier);
+    let handle_3 =
+        tokio::spawn(async move { cache.get(2, false).ensure_pending(barrier_captured).await });
+
+    barrier.wait().await;
 
     let n_blocked = loader.unblock();
     assert_eq!(n_blocked, 2);
@@ -104,16 +164,30 @@ where
 {
     loader.block();
 
+    let barrier_pending_1 = Arc::new(Barrier::new(2));
+    let barrier_pending_1_captured = Arc::clone(&barrier_pending_1);
     let cache_captured = Arc::clone(&cache);
-    let handle_1 = tokio::spawn(async move { cache_captured.get(1, true).await });
-    tokio::time::sleep(Duration::from_millis(10)).await;
-    let handle_2 = tokio::spawn(async move { cache.get(1, false).await });
+    let handle_1 = tokio::spawn(async move {
+        cache_captured
+            .get(1, true)
+            .ensure_pending(barrier_pending_1_captured)
+            .await
+    });
 
-    tokio::time::sleep(Duration::from_millis(10)).await;
+    barrier_pending_1.wait().await;
+    let barrier_pending_2 = Arc::new(Barrier::new(2));
+    let barrier_pending_2_captured = Arc::clone(&barrier_pending_2);
+    let handle_2 = tokio::spawn(async move {
+        cache
+            .get(1, false)
+            .ensure_pending(barrier_pending_2_captured)
+            .await
+    });
+
+    barrier_pending_2.wait().await;
 
     // abort first handle
-    handle_1.abort();
-    tokio::time::sleep(Duration::from_millis(10)).await;
+    handle_1.abort_and_wait().await;
 
     let n_blocked = loader.unblock();
     assert_eq!(n_blocked, 1);
@@ -130,14 +204,35 @@ where
     loader.panic_once(1);
     loader.block();
 
+    let barrier_pending_1 = Arc::new(Barrier::new(2));
+    let barrier_pending_1_captured = Arc::clone(&barrier_pending_1);
     let cache_captured = Arc::clone(&cache);
-    let handle_1 = tokio::spawn(async move { cache_captured.get(1, true).await });
-    tokio::time::sleep(Duration::from_millis(10)).await;
-    let cache_captured = Arc::clone(&cache);
-    let handle_2 = tokio::spawn(async move { cache_captured.get(1, false).await });
-    let handle_3 = tokio::spawn(async move { cache.get(2, false).await });
+    let handle_1 = tokio::spawn(async move {
+        cache_captured
+            .get(1, true)
+            .ensure_pending(barrier_pending_1_captured)
+            .await
+    });
 
-    tokio::time::sleep(Duration::from_millis(10)).await;
+    barrier_pending_1.wait().await;
+    let barrier_pending_23 = Arc::new(Barrier::new(3));
+    let barrier_pending_23_captured = Arc::clone(&barrier_pending_23);
+    let cache_captured = Arc::clone(&cache);
+    let handle_2 = tokio::spawn(async move {
+        cache_captured
+            .get(1, false)
+            .ensure_pending(barrier_pending_23_captured)
+            .await
+    });
+    let barrier_pending_23_captured = Arc::clone(&barrier_pending_23);
+    let handle_3 = tokio::spawn(async move {
+        cache
+            .get(2, false)
+            .ensure_pending(barrier_pending_23_captured)
+            .await
+    });
+
+    barrier_pending_23.wait().await;
 
     let n_blocked = loader.unblock();
     assert_eq!(n_blocked, 2);
@@ -161,13 +256,18 @@ where
 {
     loader.block();
 
-    let handle = tokio::spawn(async move { cache.get(1, true).await });
+    let barrier_pending = Arc::new(Barrier::new(2));
+    let barrier_pending_captured = Arc::clone(&barrier_pending);
+    let handle = tokio::spawn(async move {
+        cache
+            .get(1, true)
+            .ensure_pending(barrier_pending_captured)
+            .await
+    });
 
-    tokio::time::sleep(Duration::from_millis(10)).await;
+    barrier_pending.wait().await;
 
-    handle.abort();
-
-    tokio::time::sleep(Duration::from_millis(10)).await;
+    handle.abort_and_wait().await;
 
     assert_eq!(Arc::strong_count(&loader), 1);
 }
@@ -195,8 +295,15 @@ where
     loader.block();
 
     let cache_captured = Arc::clone(&cache);
-    let handle = tokio::spawn(async move { cache_captured.get(1, true).await });
-    tokio::time::sleep(Duration::from_millis(10)).await;
+    let barrier_pending = Arc::new(Barrier::new(2));
+    let barrier_pending_captured = Arc::clone(&barrier_pending);
+    let handle = tokio::spawn(async move {
+        cache_captured
+            .get(1, true)
+            .ensure_pending(barrier_pending_captured)
+            .await
+    });
+    barrier_pending.wait().await;
 
     cache.set(1, String::from("foo")).await;
 
@@ -285,3 +392,60 @@ impl Loader for TestLoader {
         format!("{k}_{extra}")
     }
 }
+
+#[async_trait]
+trait EnsurePendingExt {
+    type Out;
+
+    /// Ensure that the future is pending. In the pending case, try to pass the given barrier. Afterwards await the future again.
+    ///
+    /// This is helpful to ensure a future is in a pending state before continuing with the test setup.
+    async fn ensure_pending(self, barrier: Arc<Barrier>) -> Self::Out;
+}
+
+#[async_trait]
+impl<F> EnsurePendingExt for F
+where
+    F: Future + Send + Unpin,
+{
+    type Out = F::Output;
+
+    async fn ensure_pending(self, barrier: Arc<Barrier>) -> Self::Out {
+        let mut fut = self.fuse();
+        futures::select_biased! {
+            _ = fut => panic!("fut should be pending"),
+            _ = barrier.wait().fuse() => (),
+        }
+
+        fut.await
+    }
+}
+
+#[async_trait]
+trait AbortAndWaitExt {
+    /// Abort handle and wait for completion.
+    ///
+    /// Note that this is NOT just a "wait with timeout or panic". This extension is specific to [`JoinHandle`] and will:
+    ///
+    /// 1. Call [`JoinHandle::abort`].
+    /// 2. Await the [`JoinHandle`] with a timeout (or panic if the timeout is reached).
+    /// 3. Check that the handle returned a [`JoinError`] that signals that the tracked task was indeed cancelled and
+    ///    didn't exit otherwise (either by finishing or by panicking).
+    async fn abort_and_wait(self);
+}
+
+#[async_trait]
+impl<T> AbortAndWaitExt for JoinHandle<T>
+where
+    T: std::fmt::Debug + Send,
+{
+    async fn abort_and_wait(mut self) {
+        self.abort();
+
+        let join_err = tokio::time::timeout(Duration::from_secs(1), self)
+            .await
+            .expect("no timeout")
+            .expect_err("handle was aborted and therefore MUST fail");
+        assert!(join_err.is_cancelled());
+    }
+}
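As a sketch of how the two helpers compose in a test (the test name and scenario here are made up for illustration): `ensure_pending` first proves the future does not complete on its first poll, and `abort_and_wait` then verifies the task ends in cancellation rather than completion or panic.

use std::sync::Arc;
use tokio::sync::Barrier;

#[tokio::test]
async fn demo_pending_then_abort() {
    let barrier = Arc::new(Barrier::new(2));
    let barrier_captured = Arc::clone(&barrier);

    let handle = tokio::spawn(async move {
        // A future that can never complete on its own.
        futures::future::pending::<()>()
            .ensure_pending(barrier_captured)
            .await
    });

    // `ensure_pending` has polled the future once (it was pending) and is
    // now waiting at the barrier together with the test body.
    barrier.wait().await;

    // Abort, then assert the task was indeed cancelled (not finished, not panicked).
    handle.abort_and_wait().await;
}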