From c957d8154fe3b2adfaca4f01131bdc98e37f2b2c Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Sun, 8 Aug 2021 20:26:11 +0100 Subject: [PATCH] feat: blocking Freezable (#2224) * feat: blocking Freezable * chore: test --- Cargo.lock | 2 + internal_types/Cargo.toml | 3 + internal_types/src/freezable.rs | 213 ++++++++++++++++-- .../src/persistence_windows.rs | 4 +- server/src/database.rs | 4 +- server/src/lib.rs | 6 +- 6 files changed, 206 insertions(+), 26 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 84daf2e04d..5b137d9dab 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1765,11 +1765,13 @@ version = "0.1.0" dependencies = [ "arrow", "arrow_util", + "futures", "hashbrown", "indexmap", "itertools 0.10.1", "observability_deps", "snafu", + "tokio", ] [[package]] diff --git a/internal_types/Cargo.toml b/internal_types/Cargo.toml index 2eee530504..ae2d66c438 100644 --- a/internal_types/Cargo.toml +++ b/internal_types/Cargo.toml @@ -13,6 +13,9 @@ indexmap = "1.6" itertools = "0.10.1" observability_deps = { path = "../observability_deps" } snafu = "0.6" +tokio = { version = "1.0", features = ["sync"] } [dev-dependencies] arrow_util = { path = "../arrow_util" } +futures = "0.3" +tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "time"] } diff --git a/internal_types/src/freezable.rs b/internal_types/src/freezable.rs index 88cc046ecb..5335c806fd 100644 --- a/internal_types/src/freezable.rs +++ b/internal_types/src/freezable.rs @@ -1,5 +1,5 @@ -use std::ops::Deref; -use std::sync::{Arc, Weak}; +use std::ops::{Deref, DerefMut}; +use std::sync::Arc; /// A wrapper around a type `T` that can be frozen with `Freezable::try_freeze`, preventing /// modification of the contained `T` until the returned `FreezeHandle` is dropped @@ -44,7 +44,7 @@ use std::sync::{Arc, Weak}; /// /// // Start transaction /// let handle = { -/// let mut locked = lockable.write().unwrap(); +/// let mut locked = lockable.read().unwrap(); /// locked.try_freeze().expect("other transaction in progress") /// }; /// @@ -64,41 +64,90 @@ use std::sync::{Arc, Weak}; /// } /// ``` /// +/// +/// A freeze handle can also be acquired asynchronously +/// +/// ``` +/// use internal_types::freezable::Freezable; +/// use std::sync::RwLock; +/// +/// let rt = tokio::runtime::Builder::new_current_thread().build().unwrap(); +/// +/// rt.block_on(async move { +/// let lockable = RwLock::new(Freezable::new(23)); +/// let fut_handle = lockable.read().unwrap().freeze(); +/// +/// // NB: Only frozen once future resolved to FreezeHandle +/// *lockable.write().unwrap().get_mut().unwrap() = 56; +/// +/// let handle = fut_handle.await; +/// +/// // The contained data now cannot be modified +/// assert!(lockable.write().unwrap().get_mut().is_none()); +/// // But it can still be read +/// assert_eq!(**lockable.read().unwrap(), 56); +/// +/// // -------------- +/// // Do async work +/// // -------------- +/// +/// // Finish transaction +/// *lockable.write().unwrap().unfreeze(handle) = 57; +/// }); +/// ``` +/// #[derive(Debug)] -pub struct Freezable(Arc); +pub struct Freezable { + lock: Arc>, + payload: T, +} impl Freezable { pub fn new(payload: T) -> Self { - Self(Arc::new(payload)) + Self { + lock: Default::default(), + payload, + } } /// Returns a `FreezeHandle` that prevents modification /// of the contents of `Freezable` until it is dropped /// /// Returns None if the object is already frozen - pub fn try_freeze(&mut self) -> Option> { - // Verify exclusive - self.get_mut()?; - Some(FreezeHandle(Arc::downgrade(&self.0))) + pub fn try_freeze(&self) -> Option { + let guard = Arc::clone(&self.lock).try_lock_owned().ok()?; + Some(FreezeHandle(guard)) + } + + /// Returns a future that resolves to a FreezeHandle + pub fn freeze(&self) -> impl std::future::Future { + let captured = Arc::clone(&self.lock); + async move { FreezeHandle(captured.lock_owned().await) } } /// Unfreezes this instance, returning a mutable reference to /// its contained data - pub fn unfreeze(&mut self, handle: FreezeHandle) -> &mut T { + pub fn unfreeze(&mut self, handle: FreezeHandle) -> WriteGuard<'_, T> { assert!( - std::ptr::eq(&*self.0, handle.0.as_ptr()), + Arc::ptr_eq(&self.lock, tokio::sync::OwnedMutexGuard::mutex(&handle.0)), "provided FreezeHandle is not for this instance" ); - std::mem::drop(handle); - // Just dropped `FreezeHandle` so should be valid - self.get_mut().unwrap() + + WriteGuard { + freezable: self, + guard: handle.0, + } } /// Try to get mutable access to the data /// /// Returns `None` if this instance is frozen - pub fn get_mut(&mut self) -> Option<&mut T> { - Arc::get_mut(&mut self.0) + pub fn get_mut(&mut self) -> Option> { + let guard = Arc::clone(&self.lock).try_lock_owned().ok()?; + Some(WriteGuard { + freezable: self, + guard, + }) } } @@ -106,13 +155,47 @@ impl Deref for Freezable { type Target = T; fn deref(&self) -> &Self::Target { - &self.0 + &self.payload } } impl std::fmt::Display for Freezable { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.0.fmt(f) + self.payload.fmt(f) + } +} + +/// The `WriteGuard` provides mutable access to the `Freezable` data whilst ensuring +/// that it remains unfrozen for the duration of the mutable access +#[derive(Debug)] +pub struct WriteGuard<'a, T> { + freezable: &'a mut Freezable, + + /// A locked guard from `Freezable` + /// + /// Ensures nothing can freeze the Freezable whilst this WriteGuard exists + guard: tokio::sync::OwnedMutexGuard<()>, +} + +impl<'a, T> WriteGuard<'a, T> { + /// Converts this `WriteGuard` into a `FreezeHandle` + pub fn freeze(self) -> FreezeHandle { + FreezeHandle(self.guard) + } +} + +impl<'a, T> Deref for WriteGuard<'a, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.freezable.payload + } +} + +impl<'a, T> DerefMut for WriteGuard<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + // This is valid as holding mutex guard + &mut self.freezable.payload } } @@ -127,4 +210,96 @@ impl std::fmt::Display for Freezable { /// to outlive the `&mut Freezable` from which it was created /// #[derive(Debug)] -pub struct FreezeHandle(Weak); +pub struct FreezeHandle(tokio::sync::OwnedMutexGuard<()>); + +#[cfg(test)] +mod tests { + use super::*; + use futures::{future::FutureExt, pin_mut}; + use std::sync::RwLock; + + #[tokio::test] + async fn test_freeze() { + let waker = futures::task::noop_waker(); + let mut cx = std::task::Context::from_waker(&waker); + + let mut f = Freezable::new(1); + + let freeze_fut = f.freeze(); + pin_mut!(freeze_fut); + + let write_guard = f.get_mut().unwrap(); + + // Shouldn't resolve whilst write guard active + assert!(freeze_fut.poll_unpin(&mut cx).is_pending()); + + std::mem::drop(write_guard); + + // Should resolve once write guard removed + let handle = freeze_fut.now_or_never().unwrap(); + + // Should prevent freezing + assert!(f.try_freeze().is_none()); + assert!(f.get_mut().is_none()); + + // But not acquiring a new future + let freeze_fut = f.freeze(); + pin_mut!(freeze_fut); + + // Future shouldn't complete whilst handle active + assert!(freeze_fut.poll_unpin(&mut cx).is_pending()); + + std::mem::drop(handle); + + // Future should complete once handle dropped + freeze_fut.now_or_never().unwrap(); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn fuzz() { + let count = 1000; + let shared = Arc::new(( + RwLock::new(Freezable::new(0_usize)), + tokio::sync::Barrier::new(count), + )); + + let futures = (0..count).into_iter().map(|i| { + let captured = Arc::clone(&shared); + tokio::spawn(async move { + let (lockable, barrier) = captured.as_ref(); + + // Wait for all tasks to start + barrier.wait().await; + + // Get handle + let fut = lockable.read().unwrap().freeze(); + let handle = fut.await; + + // Start transaction + let handle = { + let mut locked = lockable.write().unwrap(); + let mut guard = locked.unfreeze(handle); + + assert_eq!(*guard, 0); + *guard = i; + + guard.freeze() + }; + + // Do async work + tokio::time::sleep(tokio::time::Duration::from_nanos(10)).await; + + // Commit transaction + { + let mut locked = lockable.write().unwrap(); + let mut guard = locked.unfreeze(handle); + + assert_eq!(*guard, i); + *guard = 0; + } + }) + }); + + futures::future::try_join_all(futures).await.unwrap(); + } +} diff --git a/persistence_windows/src/persistence_windows.rs b/persistence_windows/src/persistence_windows.rs index a19ebf85c7..dd18413f41 100644 --- a/persistence_windows/src/persistence_windows.rs +++ b/persistence_windows/src/persistence_windows.rs @@ -66,7 +66,7 @@ pub struct PersistenceWindows { /// #[derive(Debug)] pub struct FlushHandle { - handle: FreezeHandle>, + handle: FreezeHandle, /// The number of closed windows at the time of the handle's creation /// /// This identifies the windows that can have their @@ -211,7 +211,7 @@ impl PersistenceWindows { // if there is no ongoing persistence operation, try and // add closed windows to the `persistable` window - if let Some(persistable) = self.persistable.get_mut() { + if let Some(mut persistable) = self.persistable.get_mut() { while self .closed .front() diff --git a/server/src/database.rs b/server/src/database.rs index 27a0281057..91babde556 100644 --- a/server/src/database.rs +++ b/server/src/database.rs @@ -179,7 +179,7 @@ impl Database { pub fn wipe_preserved_catalog(&self) -> Result>, Error> { let db_name = &self.shared.config.name; let (current_state, handle) = { - let mut state = self.shared.state.write(); + let state = self.shared.state.read(); let current_state = match &**state { DatabaseState::CatalogLoadError(rules_loaded, _) => rules_loaded.clone(), _ => { @@ -278,7 +278,7 @@ async fn initialize_database(shared: &DatabaseShared) { while !shared.shutdown.is_cancelled() { // Acquire locks and determine if work to be done let maybe_transaction = { - let mut state = shared.state.write(); + let state = shared.state.read(); match &**state { // Already initialized diff --git a/server/src/lib.rs b/server/src/lib.rs index be81a48299..3af797c12e 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -589,8 +589,8 @@ where // Have exclusive lock on state - can drop database creation lock std::mem::drop(guard); - let state = state.get_mut().expect("no transaction in progress"); - let database = match state { + let mut state = state.get_mut().expect("no transaction in progress"); + let database = match &mut *state { ServerState::Initialized(initialized) => { match initialized.databases.entry(db_name.clone()) { hashbrown::hash_map::Entry::Vacant(vacant) => { @@ -633,7 +633,7 @@ where } let (init_ready, handle) = { - let mut state = self.shared.state.write(); + let state = self.shared.state.read(); let init_ready = match &**state { ServerState::Startup(_) => {