diff --git a/cache_system/src/cache/driver.rs b/cache_system/src/cache/driver.rs index d336a7d920..ed33643c4b 100644 --- a/cache_system/src/cache/driver.rs +++ b/cache_system/src/cache/driver.rs @@ -14,7 +14,7 @@ use tokio::{ task::JoinHandle, }; -use super::Cache; +use super::{Cache, CacheGetStatus}; /// Combine a [`CacheBackend`] and a [`Loader`] into a single [`Cache`] #[derive(Debug)] @@ -61,20 +61,23 @@ where type V = V; type Extra = Extra; - async fn get(&self, k: Self::K, extra: Self::Extra) -> Self::V { + async fn get_with_status(&self, k: Self::K, extra: Self::Extra) -> (Self::V, CacheGetStatus) { // place state locking into its own scope so it doesn't leak into the generator (async // function) - let receiver = { + let (receiver, status) = { let mut state = self.state.lock(); // check if the entry has already been cached if let Some(v) = state.cached_entries.get(&k) { - return v; + return (v, CacheGetStatus::Hit); } // check if there is already a query for this key running if let Some(running_query) = state.running_queries.get(&k) { - running_query.recv.clone() + ( + running_query.recv.clone(), + CacheGetStatus::MissAlreadyLoading, + ) } else { // requires new query let (tx_main, rx_main) = tokio::sync::oneshot::channel(); @@ -176,15 +179,17 @@ where tag, }, ); - receiver + (receiver, CacheGetStatus::Miss) } }; - receiver + let v = receiver .await .expect("cache loader panicked, see logs") .lock() - .clone() + .clone(); + + (v, status) } async fn set(&self, k: Self::K, v: Self::V) { diff --git a/cache_system/src/cache/mod.rs b/cache_system/src/cache/mod.rs index 67eecd21e1..7a7e543ed6 100644 --- a/cache_system/src/cache/mod.rs +++ b/cache_system/src/cache/mod.rs @@ -12,6 +12,20 @@ pub mod driver; #[cfg(test)] mod test_util; +/// Status of a [`Cache`] GET request. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CacheGetStatus { + /// The requested entry was present in the storage backend. + Hit, + + /// The requested entry was NOT present in the storage backend and the loader had no previous query running. + Miss, + + /// The requested entry was NOT present in the storage backend, but there was already a loader query running for + /// this particular key. + MissAlreadyLoading, +} + /// High-level cache implementation. /// /// # Concurrency @@ -41,7 +55,16 @@ pub trait Cache: Debug + Send + Sync + 'static { type Extra: Debug + Send + 'static; /// Get value from cache. - async fn get(&self, k: Self::K, extra: Self::Extra) -> Self::V; + /// + /// Note that `extra` is only used if the key is missing from the storage backend and no loader query is running yet. + async fn get(&self, k: Self::K, extra: Self::Extra) -> Self::V { + self.get_with_status(k, extra).await.0 + } + + /// Get value from cache and the status. + /// + /// Note that `extra` is only used if the key is missing from the storage backend and no loader query is running yet. + async fn get_with_status(&self, k: Self::K, extra: Self::Extra) -> (Self::V, CacheGetStatus); /// Side-load an entry into the cache. /// diff --git a/cache_system/src/cache/test_util.rs b/cache_system/src/cache/test_util.rs index cd69b868af..6a6cd2fffe 100644 --- a/cache_system/src/cache/test_util.rs +++ b/cache_system/src/cache/test_util.rs @@ -1,10 +1,14 @@ use std::{collections::HashSet, sync::Arc, time::Duration}; use async_trait::async_trait; +use futures::{Future, FutureExt}; use parking_lot::Mutex; -use tokio::sync::Notify; +use tokio::{ + sync::{Barrier, Notify}, + task::JoinHandle, +}; -use crate::loader::Loader; +use crate::{cache::CacheGetStatus, loader::Loader}; use super::Cache; @@ -44,11 +48,26 @@ async fn test_linear_memory(cache: Arc, loader: Arc) where C: Cache, { - assert_eq!(cache.get(1, true).await, String::from("1_true")); - assert_eq!(cache.get(1, false).await, String::from("1_true")); - assert_eq!(cache.get(2, false).await, String::from("2_false")); - assert_eq!(cache.get(2, false).await, String::from("2_false")); - assert_eq!(cache.get(1, true).await, String::from("1_true")); + assert_eq!( + cache.get_with_status(1, true).await, + (String::from("1_true"), CacheGetStatus::Miss), + ); + assert_eq!( + cache.get_with_status(1, false).await, + (String::from("1_true"), CacheGetStatus::Hit), + ); + assert_eq!( + cache.get_with_status(2, false).await, + (String::from("2_false"), CacheGetStatus::Miss), + ); + assert_eq!( + cache.get_with_status(2, false).await, + (String::from("2_false"), CacheGetStatus::Hit), + ); + assert_eq!( + cache.get_with_status(1, true).await, + (String::from("1_true"), CacheGetStatus::Hit), + ); assert_eq!(loader.loaded(), vec![1, 2]); } @@ -60,16 +79,39 @@ where loader.block(); let cache_captured = Arc::clone(&cache); - let handle_1 = tokio::spawn(async move { cache_captured.get(1, true).await }); - let handle_2 = tokio::spawn(async move { cache.get(1, true).await }); + let barrier_pending_1 = Arc::new(Barrier::new(2)); + let barrier_pending_1_captured = Arc::clone(&barrier_pending_1); + let handle_1 = tokio::spawn(async move { + cache_captured + .get_with_status(1, true) + .ensure_pending(barrier_pending_1_captured) + .await + }); - tokio::time::sleep(Duration::from_millis(10)).await; + barrier_pending_1.wait().await; + let barrier_pending_2 = Arc::new(Barrier::new(2)); + let barrier_pending_2_captured = Arc::clone(&barrier_pending_2); + let handle_2 = tokio::spawn(async move { + // use a different `extra` here to proof that the first one was used + cache + .get_with_status(1, false) + .ensure_pending(barrier_pending_2_captured) + .await + }); + + barrier_pending_2.wait().await; // Shouldn't issue concurrent load requests for the same key let n_blocked = loader.unblock(); assert_eq!(n_blocked, 1); - assert_eq!(handle_1.await.unwrap(), String::from("1_true")); - assert_eq!(handle_2.await.unwrap(), String::from("1_true")); + assert_eq!( + handle_1.await.unwrap(), + (String::from("1_true"), CacheGetStatus::Miss), + ); + assert_eq!( + handle_2.await.unwrap(), + (String::from("1_true"), CacheGetStatus::MissAlreadyLoading), + ); assert_eq!(loader.loaded(), vec![1]); } @@ -80,13 +122,31 @@ where { loader.block(); - let cache_captured = Arc::clone(&cache); - let handle_1 = tokio::spawn(async move { cache_captured.get(1, true).await }); - let cache_captured = Arc::clone(&cache); - let handle_2 = tokio::spawn(async move { cache_captured.get(1, true).await }); - let handle_3 = tokio::spawn(async move { cache.get(2, false).await }); + let barrier = Arc::new(Barrier::new(4)); - tokio::time::sleep(Duration::from_millis(10)).await; + let cache_captured = Arc::clone(&cache); + let barrier_captured = Arc::clone(&barrier); + let handle_1 = tokio::spawn(async move { + cache_captured + .get(1, true) + .ensure_pending(barrier_captured) + .await + }); + + let cache_captured = Arc::clone(&cache); + let barrier_captured = Arc::clone(&barrier); + let handle_2 = tokio::spawn(async move { + cache_captured + .get(1, true) + .ensure_pending(barrier_captured) + .await + }); + + let barrier_captured = Arc::clone(&barrier); + let handle_3 = + tokio::spawn(async move { cache.get(2, false).ensure_pending(barrier_captured).await }); + + barrier.wait().await; let n_blocked = loader.unblock(); assert_eq!(n_blocked, 2); @@ -104,16 +164,30 @@ where { loader.block(); + let barrier_pending_1 = Arc::new(Barrier::new(2)); + let barrier_pending_1_captured = Arc::clone(&barrier_pending_1); let cache_captured = Arc::clone(&cache); - let handle_1 = tokio::spawn(async move { cache_captured.get(1, true).await }); - tokio::time::sleep(Duration::from_millis(10)).await; - let handle_2 = tokio::spawn(async move { cache.get(1, false).await }); + let handle_1 = tokio::spawn(async move { + cache_captured + .get(1, true) + .ensure_pending(barrier_pending_1_captured) + .await + }); - tokio::time::sleep(Duration::from_millis(10)).await; + barrier_pending_1.wait().await; + let barrier_pending_2 = Arc::new(Barrier::new(2)); + let barrier_pending_2_captured = Arc::clone(&barrier_pending_2); + let handle_2 = tokio::spawn(async move { + cache + .get(1, false) + .ensure_pending(barrier_pending_2_captured) + .await + }); + + barrier_pending_2.wait().await; // abort first handle - handle_1.abort(); - tokio::time::sleep(Duration::from_millis(10)).await; + handle_1.abort_and_wait().await; let n_blocked = loader.unblock(); assert_eq!(n_blocked, 1); @@ -130,14 +204,35 @@ where loader.panic_once(1); loader.block(); + let barrier_pending_1 = Arc::new(Barrier::new(2)); + let barrier_pending_1_captured = Arc::clone(&barrier_pending_1); let cache_captured = Arc::clone(&cache); - let handle_1 = tokio::spawn(async move { cache_captured.get(1, true).await }); - tokio::time::sleep(Duration::from_millis(10)).await; - let cache_captured = Arc::clone(&cache); - let handle_2 = tokio::spawn(async move { cache_captured.get(1, false).await }); - let handle_3 = tokio::spawn(async move { cache.get(2, false).await }); + let handle_1 = tokio::spawn(async move { + cache_captured + .get(1, true) + .ensure_pending(barrier_pending_1_captured) + .await + }); - tokio::time::sleep(Duration::from_millis(10)).await; + barrier_pending_1.wait().await; + let barrier_pending_23 = Arc::new(Barrier::new(3)); + let barrier_pending_23_captured = Arc::clone(&barrier_pending_23); + let cache_captured = Arc::clone(&cache); + let handle_2 = tokio::spawn(async move { + cache_captured + .get(1, false) + .ensure_pending(barrier_pending_23_captured) + .await + }); + let barrier_pending_23_captured = Arc::clone(&barrier_pending_23); + let handle_3 = tokio::spawn(async move { + cache + .get(2, false) + .ensure_pending(barrier_pending_23_captured) + .await + }); + + barrier_pending_23.wait().await; let n_blocked = loader.unblock(); assert_eq!(n_blocked, 2); @@ -161,13 +256,18 @@ where { loader.block(); - let handle = tokio::spawn(async move { cache.get(1, true).await }); + let barrier_pending = Arc::new(Barrier::new(2)); + let barrier_pending_captured = Arc::clone(&barrier_pending); + let handle = tokio::spawn(async move { + cache + .get(1, true) + .ensure_pending(barrier_pending_captured) + .await + }); - tokio::time::sleep(Duration::from_millis(10)).await; + barrier_pending.wait().await; - handle.abort(); - - tokio::time::sleep(Duration::from_millis(10)).await; + handle.abort_and_wait().await; assert_eq!(Arc::strong_count(&loader), 1); } @@ -195,8 +295,15 @@ where loader.block(); let cache_captured = Arc::clone(&cache); - let handle = tokio::spawn(async move { cache_captured.get(1, true).await }); - tokio::time::sleep(Duration::from_millis(10)).await; + let barrier_pending = Arc::new(Barrier::new(2)); + let barrier_pending_captured = Arc::clone(&barrier_pending); + let handle = tokio::spawn(async move { + cache_captured + .get(1, true) + .ensure_pending(barrier_pending_captured) + .await + }); + barrier_pending.wait().await; cache.set(1, String::from("foo")).await; @@ -285,3 +392,60 @@ impl Loader for TestLoader { format!("{k}_{extra}") } } + +#[async_trait] +trait EnsurePendingExt { + type Out; + + /// Ensure that the future is pending. In the pending case, try to pass the given barrier. Afterwards await the future again. + /// + /// This is helpful to ensure a future is in a pending state before continuing with the test setup. + async fn ensure_pending(self, barrier: Arc) -> Self::Out; +} + +#[async_trait] +impl EnsurePendingExt for F +where + F: Future + Send + Unpin, +{ + type Out = F::Output; + + async fn ensure_pending(self, barrier: Arc) -> Self::Out { + let mut fut = self.fuse(); + futures::select_biased! { + _ = fut => panic!("fut should be pending"), + _ = barrier.wait().fuse() => (), + } + + fut.await + } +} + +#[async_trait] +trait AbortAndWaitExt { + /// Abort handle and wait for completion. + /// + /// Note that this is NOT just a "wait with timeout or panic". This extension is specific to [`JoinHandle`] and will: + /// + /// 1. Call [`JoinHandle::abort`]. + /// 2. Await the [`JoinHandle`] with a timeout (or panic if the timeout is reached). + /// 3. Check that the handle returned a [`JoinError`] that signals that the tracked task was indeed cancelled and + /// didn't exit otherwise (either by finishing or by panicking). + async fn abort_and_wait(self); +} + +#[async_trait] +impl AbortAndWaitExt for JoinHandle +where + T: std::fmt::Debug + Send, +{ + async fn abort_and_wait(mut self) { + self.abort(); + + let join_err = tokio::time::timeout(Duration::from_secs(1), self) + .await + .expect("no timeout") + .expect_err("handle was aborted and therefore MUST fail"); + assert!(join_err.is_cancelled()); + } +}