diff --git a/Cargo.lock b/Cargo.lock
index 767f5526d4..e2f1b09b8d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3213,6 +3213,7 @@ dependencies = [
  "influxdb_line_protocol",
  "mutable_buffer",
  "object_store",
+ "pin-project 1.0.5",
  "query",
  "read_buffer",
  "serde",
diff --git a/server/Cargo.toml b/server/Cargo.toml
index c4b39c8748..7667793e20 100644
--- a/server/Cargo.toml
+++ b/server/Cargo.toml
@@ -17,6 +17,7 @@ generated_types = { path = "../generated_types" }
 influxdb_line_protocol = { path = "../influxdb_line_protocol" }
 mutable_buffer = { path = "../mutable_buffer" }
 object_store = { path = "../object_store" }
+pin-project = "1.0"
 query = { path = "../query" }
 read_buffer = { path = "../read_buffer" }
 serde = "1.0"
diff --git a/server/src/buffer.rs b/server/src/buffer.rs
index 19ac5d7031..1aeb468977 100644
--- a/server/src/buffer.rs
+++ b/server/src/buffer.rs
@@ -16,6 +16,7 @@ use std::{
 };
 
 //use byteorder::{ByteOrder, LittleEndian, WriteBytesExt};
+use crate::tracker::{TrackedFutureExt, TrackerRegistry};
 use bytes::Bytes;
 use chrono::{DateTime, Utc};
 use crc32fast::Hasher;
@@ -72,6 +73,12 @@ pub enum Error {
     InvalidFlatbuffersSegment,
 }
 
+#[derive(Debug, Clone)]
+pub struct SegmentPersistenceTask {
+    writer_id: u32,
+    location: object_store::path::Path,
+}
+
 pub type Result<T, E = Error> = std::result::Result<T, E>;
 
 /// An in-memory buffer of a write ahead log. It is split up into segments,
@@ -369,6 +376,7 @@ impl Segment {
     /// the given object store location.
     pub fn persist_bytes_in_background(
         &self,
+        reg: &TrackerRegistry<SegmentPersistenceTask>,
         writer_id: u32,
         db_name: &DatabaseName<'_>,
         store: Arc<ObjectStore>,
@@ -377,29 +385,37 @@ impl Segment {
         let location = database_object_store_path(writer_id, db_name, &store);
         let location = object_store_path_for_segment(&location, self.id)?;
 
+        let task_meta = SegmentPersistenceTask {
+            writer_id,
+            location: location.clone(),
+        };
+
         let len = data.len();
         let mut stream_data = std::io::Result::Ok(data.clone());
 
-        tokio::task::spawn(async move {
-            while let Err(err) = store
-                .put(
-                    &location,
-                    futures::stream::once(async move { stream_data }),
-                    len,
-                )
-                .await
-            {
-                error!("error writing bytes to store: {}", err);
-                tokio::time::sleep(tokio::time::Duration::from_secs(
-                    super::STORE_ERROR_PAUSE_SECONDS,
-                ))
-                .await;
-                stream_data = std::io::Result::Ok(data.clone());
-            }
+        tokio::task::spawn(
+            async move {
+                while let Err(err) = store
+                    .put(
+                        &location,
+                        futures::stream::once(async move { stream_data }),
+                        len,
+                    )
+                    .await
+                {
+                    error!("error writing bytes to store: {}", err);
+                    tokio::time::sleep(tokio::time::Duration::from_secs(
+                        super::STORE_ERROR_PAUSE_SECONDS,
+                    ))
+                    .await;
+                    stream_data = std::io::Result::Ok(data.clone());
+                }
 
-            // TODO: Mark segment as persisted
-            info!("persisted data to {}", location.display());
-        });
+                // TODO: Mark segment as persisted
+                info!("persisted data to {}", location.display());
+            }
+            .track(reg, task_meta),
+        );
 
         Ok(())
     }
diff --git a/server/src/lib.rs b/server/src/lib.rs
index f07cfef82e..3705c2872a 100644
--- a/server/src/lib.rs
+++ b/server/src/lib.rs
@@ -70,6 +70,7 @@ pub mod buffer;
 mod config;
 pub mod db;
 pub mod snapshot;
+mod tracker;
 
 #[cfg(test)]
 mod query_tests;
@@ -80,8 +81,10 @@ use std::sync::{
 };
 
 use crate::{
+    buffer::SegmentPersistenceTask,
     config::{object_store_path_for_database_config, Config, DB_RULES_FILE_NAME},
     db::Db,
+    tracker::TrackerRegistry,
 };
 use data_types::{
     data::{lines_to_replicated_write, ReplicatedWrite},
@@ -154,6 +157,7 @@ pub struct Server<M: ConnectionManager> {
     connection_manager: Arc<M>,
     pub store: Arc<ObjectStore>,
     executor: Arc<Executor>,
+    segment_persistence_registry: TrackerRegistry<SegmentPersistenceTask>,
 }
 
 impl<M: ConnectionManager> Server<M> {
@@ -164,6 +168,7 @@ impl<M: ConnectionManager> Server<M> {
             store,
             connection_manager: Arc::new(connection_manager),
             executor: Arc::new(Executor::new()),
+            segment_persistence_registry: TrackerRegistry::new(),
         }
     }
 
@@ -356,7 +361,12 @@ impl<M: ConnectionManager> Server<M> {
                     let writer_id = self.require_id()?;
                     let store = self.store.clone();
                     segment
-                        .persist_bytes_in_background(writer_id, db_name, store)
+                        .persist_bytes_in_background(
+                            &self.segment_persistence_registry,
+                            writer_id,
+                            db_name,
+                            store,
+                        )
                         .context(WalError)?;
                 }
             }
diff --git a/server/src/tracker.rs b/server/src/tracker.rs
new file mode 100644
index 0000000000..5c4a59ccd2
--- /dev/null
+++ b/server/src/tracker.rs
@@ -0,0 +1,251 @@
+use std::collections::HashMap;
+use std::future::Future;
+use std::pin::Pin;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::{Arc, Mutex};
+use std::task::{Context, Poll};
+
+use futures::prelude::*;
+use pin_project::{pin_project, pinned_drop};
+
+/// Every future registered with a `TrackerRegistry` is assigned a unique
+/// `TrackerId`
+#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
+pub struct TrackerId(usize);
+
+#[derive(Debug)]
+struct Tracker<T> {
+    data: T,
+    abort: future::AbortHandle,
+}
+
+#[derive(Debug)]
+struct TrackerContextInner<T> {
+    id: AtomicUsize,
+    trackers: Mutex<HashMap<TrackerId, Tracker<T>>>,
+}
+
+/// Allows tracking the lifecycle of futures registered by
+/// `TrackedFutureExt::track` with an accompanying metadata payload of type T
+///
+/// Additionally can trigger graceful termination of registered futures
+#[derive(Debug)]
+pub struct TrackerRegistry<T> {
+    inner: Arc<TrackerContextInner<T>>,
+}
+
+// Manual Clone to workaround https://github.com/rust-lang/rust/issues/26925
+impl<T> Clone for TrackerRegistry<T> {
+    fn clone(&self) -> Self {
+        Self {
+            inner: self.inner.clone(),
+        }
+    }
+}
+
+impl<T> Default for TrackerRegistry<T> {
+    fn default() -> Self {
+        Self {
+            inner: Arc::new(TrackerContextInner {
+                id: AtomicUsize::new(0),
+                trackers: Mutex::new(Default::default()),
+            }),
+        }
+    }
+}
+
+impl<T> TrackerRegistry<T> {
+    pub fn new() -> Self {
+        Default::default()
+    }
+
+    /// Trigger graceful termination of a registered future
+    ///
+    /// Returns false if no future found with the provided ID
+    ///
+    /// Note: If the future is currently executing, termination
+    /// will only occur when the future yields (returns from poll)
+    #[allow(dead_code)]
+    pub fn terminate(&self, id: TrackerId) -> bool {
+        if let Some(meta) = self
+            .inner
+            .trackers
+            .lock()
+            .expect("lock poisoned")
+            .get_mut(&id)
+        {
+            meta.abort.abort();
+            true
+        } else {
+            false
+        }
+    }
+
+    fn untrack(&self, id: &TrackerId) {
+        self.inner
+            .trackers
+            .lock()
+            .expect("lock poisoned")
+            .remove(id);
+    }
+
+    fn track(&self, metadata: T) -> (TrackerId, future::AbortRegistration) {
+        let id = TrackerId(self.inner.id.fetch_add(1, Ordering::Relaxed));
+        let (abort_handle, abort_registration) = future::AbortHandle::new_pair();
+
+        self.inner.trackers.lock().expect("lock poisoned").insert(
+            id,
+            Tracker {
+                abort: abort_handle,
+                data: metadata,
+            },
+        );
+
+        (id, abort_registration)
+    }
+}
+
+impl<T: Clone> TrackerRegistry<T> {
+    /// Returns a list of tracked futures, with their accompanying IDs and
+    /// metadata
+    #[allow(dead_code)]
+    pub fn tracked(&self) -> Vec<(TrackerId, T)> {
+        // TODO: Improve this - (#711)
+        self.inner
+            .trackers
+            .lock()
+            .expect("lock poisoned")
+            .iter()
+            .map(|(id, value)| (*id, value.data.clone()))
+            .collect()
+    }
+}
+
+/// An extension trait that provides `self.track(reg, {})` allowing
+/// registering this future with a `TrackerRegistry`
+pub trait TrackedFutureExt: Future {
+    fn track<T>(self, reg: &TrackerRegistry<T>, metadata: T) -> TrackedFuture<Self, T>
+    where
+        Self: Sized,
+    {
+        let (id, registration) = reg.track(metadata);
+
+        TrackedFuture {
+            inner: future::Abortable::new(self, registration),
+            reg: reg.clone(),
+            id,
+        }
+    }
+}
+
+impl<T: ?Sized> TrackedFutureExt for T where T: Future {}
+
+/// The `Future` returned by `TrackedFutureExt::track()`
+/// Unregisters the future from the registered `TrackerRegistry` on drop
+/// and provides the early termination functionality used by
+/// `TrackerRegistry::terminate`
+#[pin_project(PinnedDrop)]
+pub struct TrackedFuture<F: Future, T> {
+    #[pin]
+    inner: future::Abortable<F>,
+
+    reg: TrackerRegistry<T>,
+    id: TrackerId,
+}
+
+impl<F: Future, T> Future for TrackedFuture<F, T> {
+    type Output = Result<F::Output, future::Aborted>;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        self.project().inner.poll(cx)
+    }
+}
+
+#[pinned_drop]
+impl<F: Future, T> PinnedDrop for TrackedFuture<F, T> {
+    fn drop(self: Pin<&mut Self>) {
+        // Note: This could cause a double-panic in an extreme situation where
+        // the internal `TrackerRegistry` lock is poisoned and drop was
+        // called as part of unwinding the stack to handle another panic
+        let this = self.project();
+        this.reg.untrack(this.id)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use tokio::sync::oneshot;
+
+    #[tokio::test]
+    async fn test_lifecycle() {
+        let (sender, receive) = oneshot::channel();
+        let reg = TrackerRegistry::new();
+
+        let task = tokio::spawn(receive.track(&reg, ()));
+
+        assert_eq!(reg.tracked().len(), 1);
+
+        sender.send(()).unwrap();
+        task.await.unwrap().unwrap().unwrap();
+
+        assert_eq!(reg.tracked().len(), 0);
+    }
+
+    #[tokio::test]
+    async fn test_interleaved() {
+        let (sender1, receive1) = oneshot::channel();
+        let (sender2, receive2) = oneshot::channel();
+        let reg = TrackerRegistry::new();
+
+        let task1 = tokio::spawn(receive1.track(&reg, 1));
+        let task2 = tokio::spawn(receive2.track(&reg, 2));
+
+        let mut tracked: Vec<_> = reg.tracked().iter().map(|x| x.1).collect();
+        tracked.sort_unstable();
+        assert_eq!(tracked, vec![1, 2]);
+
+        sender2.send(()).unwrap();
+        task2.await.unwrap().unwrap().unwrap();
+
+        let tracked: Vec<_> = reg.tracked().iter().map(|x| x.1).collect();
+        assert_eq!(tracked, vec![1]);
+
+        sender1.send(42).unwrap();
+        let ret = task1.await.unwrap().unwrap().unwrap();
+
+        assert_eq!(ret, 42);
+        assert_eq!(reg.tracked().len(), 0);
+    }
+
+    #[tokio::test]
+    async fn test_drop() {
+        let reg = TrackerRegistry::new();
+
+        {
+            let f = futures::future::pending::<()>().track(&reg, ());
+
+            assert_eq!(reg.tracked().len(), 1);
+
+            std::mem::drop(f);
+        }
+
+        assert_eq!(reg.tracked().len(), 0);
+    }
+
+    #[tokio::test]
+    async fn test_terminate() {
+        let reg = TrackerRegistry::new();
+
+        let task = tokio::spawn(futures::future::pending::<()>().track(&reg, ()));
+
+        let tracked = reg.tracked();
+        assert_eq!(tracked.len(), 1);
+
+        reg.terminate(tracked[0].0);
+        let result = task.await.unwrap();
+
+        assert!(result.is_err());
+        assert_eq!(reg.tracked().len(), 0);
+    }
+}