refactor: expose `CacheGetStatus` (and improve tests) (#4804)

* refactor: expose `CacheGetStatus` (and improve tests)

- add a `CacheGetStatus` which tells the user if the request was a hit
  or miss (or something in between)
- adapt some tests to use the status (only the tests where this could be
  relevant)
- move the test suite from using `sleep` to proper barriers (more stable
  under high load, more correct, potentially faster)

* refactor: improve `abort_and_wait` checks

* docs: typos and improve wording

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>

* refactor: `FutureExt2` -> `EnsurePendingExt`

* refactor: `Queried` -> `MissAlreadyLoading`

* docs: explain `abort_and_wait` more

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
pull/24376/head
Marco Neumann 2022-06-09 09:32:46 +02:00 committed by GitHub
parent f34282be2c
commit 2e3ba83795
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 238 additions and 46 deletions

View File

@ -14,7 +14,7 @@ use tokio::{
task::JoinHandle,
};
use super::Cache;
use super::{Cache, CacheGetStatus};
/// Combine a [`CacheBackend`] and a [`Loader`] into a single [`Cache`]
#[derive(Debug)]
@ -61,20 +61,23 @@ where
type V = V;
type Extra = Extra;
async fn get(&self, k: Self::K, extra: Self::Extra) -> Self::V {
async fn get_with_status(&self, k: Self::K, extra: Self::Extra) -> (Self::V, CacheGetStatus) {
// place state locking into its own scope so it doesn't leak into the generator (async
// function)
let receiver = {
let (receiver, status) = {
let mut state = self.state.lock();
// check if the entry has already been cached
if let Some(v) = state.cached_entries.get(&k) {
return v;
return (v, CacheGetStatus::Hit);
}
// check if there is already a query for this key running
if let Some(running_query) = state.running_queries.get(&k) {
running_query.recv.clone()
(
running_query.recv.clone(),
CacheGetStatus::MissAlreadyLoading,
)
} else {
// requires new query
let (tx_main, rx_main) = tokio::sync::oneshot::channel();
@ -176,15 +179,17 @@ where
tag,
},
);
receiver
(receiver, CacheGetStatus::Miss)
}
};
receiver
let v = receiver
.await
.expect("cache loader panicked, see logs")
.lock()
.clone()
.clone();
(v, status)
}
async fn set(&self, k: Self::K, v: Self::V) {

View File

@ -12,6 +12,20 @@ pub mod driver;
#[cfg(test)]
mod test_util;
/// Outcome of a [`Cache`] GET request from the requester's point of view.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheGetStatus {
    /// The requested entry was found in the storage backend and returned directly.
    Hit,

    /// The requested entry was NOT present in the storage backend and no loader query was running for it yet, so a
    /// fresh one was started.
    Miss,

    /// The requested entry was NOT present in the storage backend, but a loader query for this particular key was
    /// already in flight; the caller was attached to that existing query.
    MissAlreadyLoading,
}
/// High-level cache implementation.
///
/// # Concurrency
@ -41,7 +55,16 @@ pub trait Cache: Debug + Send + Sync + 'static {
type Extra: Debug + Send + 'static;
/// Get value from cache.
async fn get(&self, k: Self::K, extra: Self::Extra) -> Self::V;
///
/// Note that `extra` is only used if the key is missing from the storage backend and no loader query is running yet.
async fn get(&self, k: Self::K, extra: Self::Extra) -> Self::V {
self.get_with_status(k, extra).await.0
}
/// Get value from cache and the status.
///
/// Note that `extra` is only used if the key is missing from the storage backend and no loader query is running yet.
async fn get_with_status(&self, k: Self::K, extra: Self::Extra) -> (Self::V, CacheGetStatus);
/// Side-load an entry into the cache.
///

View File

@ -1,10 +1,14 @@
use std::{collections::HashSet, sync::Arc, time::Duration};
use async_trait::async_trait;
use futures::{Future, FutureExt};
use parking_lot::Mutex;
use tokio::sync::Notify;
use tokio::{
sync::{Barrier, Notify},
task::JoinHandle,
};
use crate::loader::Loader;
use crate::{cache::CacheGetStatus, loader::Loader};
use super::Cache;
@ -44,11 +48,26 @@ async fn test_linear_memory<C>(cache: Arc<C>, loader: Arc<TestLoader>)
where
C: Cache<K = u8, V = String, Extra = bool>,
{
assert_eq!(cache.get(1, true).await, String::from("1_true"));
assert_eq!(cache.get(1, false).await, String::from("1_true"));
assert_eq!(cache.get(2, false).await, String::from("2_false"));
assert_eq!(cache.get(2, false).await, String::from("2_false"));
assert_eq!(cache.get(1, true).await, String::from("1_true"));
assert_eq!(
cache.get_with_status(1, true).await,
(String::from("1_true"), CacheGetStatus::Miss),
);
assert_eq!(
cache.get_with_status(1, false).await,
(String::from("1_true"), CacheGetStatus::Hit),
);
assert_eq!(
cache.get_with_status(2, false).await,
(String::from("2_false"), CacheGetStatus::Miss),
);
assert_eq!(
cache.get_with_status(2, false).await,
(String::from("2_false"), CacheGetStatus::Hit),
);
assert_eq!(
cache.get_with_status(1, true).await,
(String::from("1_true"), CacheGetStatus::Hit),
);
assert_eq!(loader.loaded(), vec![1, 2]);
}
@ -60,16 +79,39 @@ where
loader.block();
let cache_captured = Arc::clone(&cache);
let handle_1 = tokio::spawn(async move { cache_captured.get(1, true).await });
let handle_2 = tokio::spawn(async move { cache.get(1, true).await });
let barrier_pending_1 = Arc::new(Barrier::new(2));
let barrier_pending_1_captured = Arc::clone(&barrier_pending_1);
let handle_1 = tokio::spawn(async move {
cache_captured
.get_with_status(1, true)
.ensure_pending(barrier_pending_1_captured)
.await
});
tokio::time::sleep(Duration::from_millis(10)).await;
barrier_pending_1.wait().await;
let barrier_pending_2 = Arc::new(Barrier::new(2));
let barrier_pending_2_captured = Arc::clone(&barrier_pending_2);
let handle_2 = tokio::spawn(async move {
// use a different `extra` here to prove that the first one was used
cache
.get_with_status(1, false)
.ensure_pending(barrier_pending_2_captured)
.await
});
barrier_pending_2.wait().await;
// Shouldn't issue concurrent load requests for the same key
let n_blocked = loader.unblock();
assert_eq!(n_blocked, 1);
assert_eq!(handle_1.await.unwrap(), String::from("1_true"));
assert_eq!(handle_2.await.unwrap(), String::from("1_true"));
assert_eq!(
handle_1.await.unwrap(),
(String::from("1_true"), CacheGetStatus::Miss),
);
assert_eq!(
handle_2.await.unwrap(),
(String::from("1_true"), CacheGetStatus::MissAlreadyLoading),
);
assert_eq!(loader.loaded(), vec![1]);
}
@ -80,13 +122,31 @@ where
{
loader.block();
let cache_captured = Arc::clone(&cache);
let handle_1 = tokio::spawn(async move { cache_captured.get(1, true).await });
let cache_captured = Arc::clone(&cache);
let handle_2 = tokio::spawn(async move { cache_captured.get(1, true).await });
let handle_3 = tokio::spawn(async move { cache.get(2, false).await });
let barrier = Arc::new(Barrier::new(4));
tokio::time::sleep(Duration::from_millis(10)).await;
let cache_captured = Arc::clone(&cache);
let barrier_captured = Arc::clone(&barrier);
let handle_1 = tokio::spawn(async move {
cache_captured
.get(1, true)
.ensure_pending(barrier_captured)
.await
});
let cache_captured = Arc::clone(&cache);
let barrier_captured = Arc::clone(&barrier);
let handle_2 = tokio::spawn(async move {
cache_captured
.get(1, true)
.ensure_pending(barrier_captured)
.await
});
let barrier_captured = Arc::clone(&barrier);
let handle_3 =
tokio::spawn(async move { cache.get(2, false).ensure_pending(barrier_captured).await });
barrier.wait().await;
let n_blocked = loader.unblock();
assert_eq!(n_blocked, 2);
@ -104,16 +164,30 @@ where
{
loader.block();
let barrier_pending_1 = Arc::new(Barrier::new(2));
let barrier_pending_1_captured = Arc::clone(&barrier_pending_1);
let cache_captured = Arc::clone(&cache);
let handle_1 = tokio::spawn(async move { cache_captured.get(1, true).await });
tokio::time::sleep(Duration::from_millis(10)).await;
let handle_2 = tokio::spawn(async move { cache.get(1, false).await });
let handle_1 = tokio::spawn(async move {
cache_captured
.get(1, true)
.ensure_pending(barrier_pending_1_captured)
.await
});
tokio::time::sleep(Duration::from_millis(10)).await;
barrier_pending_1.wait().await;
let barrier_pending_2 = Arc::new(Barrier::new(2));
let barrier_pending_2_captured = Arc::clone(&barrier_pending_2);
let handle_2 = tokio::spawn(async move {
cache
.get(1, false)
.ensure_pending(barrier_pending_2_captured)
.await
});
barrier_pending_2.wait().await;
// abort first handle
handle_1.abort();
tokio::time::sleep(Duration::from_millis(10)).await;
handle_1.abort_and_wait().await;
let n_blocked = loader.unblock();
assert_eq!(n_blocked, 1);
@ -130,14 +204,35 @@ where
loader.panic_once(1);
loader.block();
let barrier_pending_1 = Arc::new(Barrier::new(2));
let barrier_pending_1_captured = Arc::clone(&barrier_pending_1);
let cache_captured = Arc::clone(&cache);
let handle_1 = tokio::spawn(async move { cache_captured.get(1, true).await });
tokio::time::sleep(Duration::from_millis(10)).await;
let cache_captured = Arc::clone(&cache);
let handle_2 = tokio::spawn(async move { cache_captured.get(1, false).await });
let handle_3 = tokio::spawn(async move { cache.get(2, false).await });
let handle_1 = tokio::spawn(async move {
cache_captured
.get(1, true)
.ensure_pending(barrier_pending_1_captured)
.await
});
tokio::time::sleep(Duration::from_millis(10)).await;
barrier_pending_1.wait().await;
let barrier_pending_23 = Arc::new(Barrier::new(3));
let barrier_pending_23_captured = Arc::clone(&barrier_pending_23);
let cache_captured = Arc::clone(&cache);
let handle_2 = tokio::spawn(async move {
cache_captured
.get(1, false)
.ensure_pending(barrier_pending_23_captured)
.await
});
let barrier_pending_23_captured = Arc::clone(&barrier_pending_23);
let handle_3 = tokio::spawn(async move {
cache
.get(2, false)
.ensure_pending(barrier_pending_23_captured)
.await
});
barrier_pending_23.wait().await;
let n_blocked = loader.unblock();
assert_eq!(n_blocked, 2);
@ -161,13 +256,18 @@ where
{
loader.block();
let handle = tokio::spawn(async move { cache.get(1, true).await });
let barrier_pending = Arc::new(Barrier::new(2));
let barrier_pending_captured = Arc::clone(&barrier_pending);
let handle = tokio::spawn(async move {
cache
.get(1, true)
.ensure_pending(barrier_pending_captured)
.await
});
tokio::time::sleep(Duration::from_millis(10)).await;
barrier_pending.wait().await;
handle.abort();
tokio::time::sleep(Duration::from_millis(10)).await;
handle.abort_and_wait().await;
assert_eq!(Arc::strong_count(&loader), 1);
}
@ -195,8 +295,15 @@ where
loader.block();
let cache_captured = Arc::clone(&cache);
let handle = tokio::spawn(async move { cache_captured.get(1, true).await });
tokio::time::sleep(Duration::from_millis(10)).await;
let barrier_pending = Arc::new(Barrier::new(2));
let barrier_pending_captured = Arc::clone(&barrier_pending);
let handle = tokio::spawn(async move {
cache_captured
.get(1, true)
.ensure_pending(barrier_pending_captured)
.await
});
barrier_pending.wait().await;
cache.set(1, String::from("foo")).await;
@ -285,3 +392,60 @@ impl Loader for TestLoader {
format!("{k}_{extra}")
}
}
#[async_trait]
trait EnsurePendingExt {
    /// Output type of the wrapped future (mirrors `Future::Output`).
    type Out;

    /// Ensure that the future is pending. In the pending case, try to pass the given barrier. Afterwards await the
    /// future again.
    ///
    /// This is helpful to ensure a future is in a pending state before continuing with the test setup.
    async fn ensure_pending(self, barrier: Arc<Barrier>) -> Self::Out;
}
#[async_trait]
impl<F> EnsurePendingExt for F
where
    F: Future + Send + Unpin,
{
    type Out = F::Output;

    async fn ensure_pending(self, barrier: Arc<Barrier>) -> Self::Out {
        let mut this = self.fuse();

        // Race the future against the barrier. `select_biased!` polls its arms in
        // source order, so the future is polled first: if it is already complete we
        // panic, otherwise we fall through to the barrier arm and thereby signal
        // "the future is pending now" to the other side of the barrier.
        futures::select_biased! {
            _ = this => panic!("fut should be pending"),
            _ = barrier.wait().fuse() => (),
        }

        // The barrier was passed while the future was still pending; now drive the
        // future to completion.
        this.await
    }
}
#[async_trait]
trait AbortAndWaitExt {
    /// Abort handle and wait for completion.
    ///
    /// Note that this is NOT just a "wait with timeout or panic". This extension is specific to [`JoinHandle`] and will:
    ///
    /// 1. Call [`JoinHandle::abort`].
    /// 2. Await the [`JoinHandle`] with a timeout (or panic if the timeout is reached).
    /// 3. Check that the handle returned a [`JoinError`] that signals that the tracked task was indeed cancelled and
    ///    didn't exit otherwise (either by finishing or by panicking).
    ///
    /// Consumes the handle by value, so it cannot be awaited or aborted again afterwards.
    async fn abort_and_wait(self);
}
#[async_trait]
impl<T> AbortAndWaitExt for JoinHandle<T>
where
    T: std::fmt::Debug + Send,
{
    async fn abort_and_wait(mut self) {
        // Request cancellation first; awaiting the handle below observes the outcome.
        self.abort();

        // Bound the wait so a stuck task fails the test instead of hanging it forever.
        let outcome = tokio::time::timeout(Duration::from_secs(1), self).await;

        // The join result MUST be an error: a successful join would mean the task
        // finished (or panicked) on its own instead of being cancelled.
        let join_err = outcome
            .expect("no timeout")
            .expect_err("handle was aborted and therefore MUST fail");
        assert!(join_err.is_cancelled());
    }
}