diff --git a/Cargo.lock b/Cargo.lock index 6a4909737d..e589e37384 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4921,6 +4921,7 @@ dependencies = [ "service_grpc_object_store", "service_grpc_schema", "sharder", + "smallvec", "snafu", "test_helpers", "thiserror", diff --git a/router/Cargo.toml b/router/Cargo.toml index 45cd273c24..6b65a95e80 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -46,6 +46,7 @@ write_buffer = { path = "../write_buffer" } write_summary = { path = "../write_summary" } tower = { version = "0.4.13", features = ["balance"] } crossbeam-utils = "0.8.15" +smallvec = "1.10.0" [dev-dependencies] assert_matches = "1.5" diff --git a/router/src/dml_handlers/rpc_write.rs b/router/src/dml_handlers/rpc_write.rs index bb9b048f95..589b5208f9 100644 --- a/router/src/dml_handlers/rpc_write.rs +++ b/router/src/dml_handlers/rpc_write.rs @@ -3,10 +3,14 @@ mod circuit_breaker; mod circuit_breaking_client; pub mod client; pub mod lazy_connector; +mod upstream_snapshot; use crate::dml_handlers::rpc_write::client::WriteClient; -use self::{balancer::Balancer, circuit_breaking_client::CircuitBreakingClient}; +use self::{ + balancer::Balancer, circuit_breaking_client::CircuitBreakingClient, + upstream_snapshot::UpstreamSnapshot, +}; use super::{DmlHandler, Partitioned}; use async_trait::async_trait; @@ -90,7 +94,7 @@ impl RpcWrite { #[async_trait] impl DmlHandler for RpcWrite where - C: client::WriteClient + 'static, + C: WriteClient + 'static, { type WriteInput = Partitioned>; type WriteOutput = Vec; @@ -162,12 +166,14 @@ where } async fn write_loop( - mut endpoints: impl Iterator + Send, + mut endpoints: UpstreamSnapshot<'_, T>, req: WriteRequest, ) -> Result<(), RpcWriteError> where T: WriteClient, { + // Infinitely cycle through the snapshot, trying each node in turn until the + // request succeeds or this async call times out. let mut delay = Duration::from_millis(50); loop { match endpoints diff --git a/router/src/dml_handlers/rpc_write/balancer.rs b/router/src/dml_handlers/rpc_write/balancer.rs index 12b7aa26ae..671224f726 100644 --- a/router/src/dml_handlers/rpc_write/balancer.rs +++ b/router/src/dml_handlers/rpc_write/balancer.rs @@ -8,6 +8,7 @@ use tokio::task::JoinHandle; use super::{ circuit_breaker::CircuitBreaker, circuit_breaking_client::{CircuitBreakerState, CircuitBreakingClient}, + upstream_snapshot::UpstreamSnapshot, }; thread_local! { @@ -67,7 +68,7 @@ where /// evaluated at this point and the result is returned to the caller as an /// infinite / cycling iterator. A node that becomes unavailable after the /// snapshot was taken will continue to be returned by the iterator. - pub(super) fn endpoints(&self) -> impl Iterator> { + pub(super) fn endpoints(&self) -> UpstreamSnapshot<'_, CircuitBreakingClient> { // Grab and increment the current counter. let counter = COUNTER.with(|cell| { let mut cell = cell.borrow_mut(); @@ -114,7 +115,7 @@ where } }; - probe.into_iter().chain(healthy).cycle().skip(idx) + UpstreamSnapshot::new(probe.into_iter().chain(healthy), idx) } } @@ -310,7 +311,8 @@ mod tests { } /// A test that ensures only healthy clients are returned by the balancer, - /// and that they are polled exactly once per request. + /// and that they are polled exactly once per call to + /// [`Balancer::endpoints()`]. #[tokio::test] async fn test_balancer_yield_healthy_polled_once() { const BALANCER_CALLS: usize = 10; @@ -377,8 +379,8 @@ mod tests { async fn test_balancer_upstream_recovery() { const BALANCER_CALLS: usize = 10; - // Initialise 3 RPC clients and configure their mock circuit breakers; - // two returns a unhealthy state, one is healthy. + // Initialise a single client and configure its mock circuit breaker to + // return unhealthy. let circuit = Arc::new(MockCircuitBreaker::default()); circuit.set_healthy(false); circuit.set_should_probe(false); @@ -389,12 +391,15 @@ mod tests { let balancer = Balancer::new([client], None); + // The balancer should yield no candidates. let mut endpoints = balancer.endpoints(); assert_matches!(endpoints.next(), None); assert_eq!(circuit.is_healthy_count(), 1); + // Mark the client as healthy. circuit.set_healthy(true); + // A single client should be yielded let mut endpoints = balancer.endpoints(); assert_matches!(endpoints.next(), Some(_)); assert_eq!(circuit.is_healthy_count(), 2); diff --git a/router/src/dml_handlers/rpc_write/circuit_breaking_client.rs b/router/src/dml_handlers/rpc_write/circuit_breaking_client.rs index 8fbbfdd772..04464245d4 100644 --- a/router/src/dml_handlers/rpc_write/circuit_breaking_client.rs +++ b/router/src/dml_handlers/rpc_write/circuit_breaking_client.rs @@ -90,7 +90,7 @@ where } #[async_trait] -impl WriteClient for &CircuitBreakingClient +impl WriteClient for CircuitBreakingClient where T: WriteClient, C: CircuitBreakerState, diff --git a/router/src/dml_handlers/rpc_write/upstream_snapshot.rs b/router/src/dml_handlers/rpc_write/upstream_snapshot.rs new file mode 100644 index 0000000000..d388995624 --- /dev/null +++ b/router/src/dml_handlers/rpc_write/upstream_snapshot.rs @@ -0,0 +1,199 @@ +use smallvec::SmallVec; + +/// An infinite cycling iterator, yielding the 0-indexed `i`-th element first +/// (modulo wrapping). +/// +/// The last yielded element can be removed from the iterator by calling +/// [`UpstreamSnapshot::remove_last_unstable()`]. +#[derive(Debug)] +pub(super) struct UpstreamSnapshot<'a, C> { + clients: SmallVec<[&'a C; 3]>, + idx: usize, +} + +impl<'a, C> UpstreamSnapshot<'a, C> { + /// Initialise a new snapshot, yielding the 0-indexed `i`-th element of + /// `clients` next (or wrapping around if `i` is out-of-bounds). + /// + /// Holds up to 3 elements on the stack; ore than 3 elements will cause an + /// allocation during construction. + pub(super) fn new(clients: impl Iterator, i: usize) -> Self { + Self { + clients: clients.collect(), + // So first call is the ith element even after the inc in next(). + idx: i.wrapping_sub(1), + } + } + + /// Remove the last yielded upstream from this snapshot. + /// + /// # Ordering + /// + /// Calling this method MAY change the order of the yielded elements but + /// MUST maintain equal visit counts across all elements. + /// + /// # Correctness + /// + /// If called before [`UpstreamSnapshot`] has yielded any elements, this MAY + /// remove an arbitrary element from the snapshot. + #[allow(unused)] + pub(super) fn remove_last_unstable(&mut self) { + self.clients.swap_remove(self.idx()); + // Try the element now in the idx position next. + self.idx = self.idx.wrapping_sub(1); + } + + #[inline(always)] + fn idx(&self) -> usize { + self.idx % self.clients.len() + } +} + +impl<'a, C> Iterator for UpstreamSnapshot<'a, C> { + type Item = &'a C; + + fn next(&mut self) -> Option { + if self.clients.is_empty() { + return None; + } + self.idx = self.idx.wrapping_add(1); + Some(self.clients[self.idx()]) + } + + fn size_hint(&self) -> (usize, Option) { + (0, Some(self.clients.len())) + } +} + +#[cfg(test)] +mod tests { + use std::sync::atomic::{AtomicUsize, Ordering}; + + use super::*; + + #[test] + fn test_size_hint() { + let elements = [ + AtomicUsize::new(0), + AtomicUsize::new(0), + AtomicUsize::new(0), + ]; + + let mut snap = UpstreamSnapshot::new(elements.iter(), 0); + + let (min, max) = snap.size_hint(); + assert_eq!(min, 0); + assert_eq!(max, Some(3)); + + snap.remove_last_unstable(); // Arbitrary element removed + + let (min, max) = snap.size_hint(); + assert_eq!(min, 0); + assert_eq!(max, Some(2)); + } + + #[test] + fn test_start_index() { + let elements = [1, 2, 3]; + + assert_eq!( + *UpstreamSnapshot::new(elements.iter(), 0) + .next() + .expect("should yield value"), + 1 + ); + assert_eq!( + *UpstreamSnapshot::new(elements.iter(), 1) + .next() + .expect("should yield value"), + 2 + ); + assert_eq!( + *UpstreamSnapshot::new(elements.iter(), 2) + .next() + .expect("should yield value"), + 3 + ); + + // Wraparound + assert_eq!( + *UpstreamSnapshot::new(elements.iter(), 3) + .next() + .expect("should yield value"), + 1 + ); + } + + #[test] + fn test_cycles() { + let elements = [ + AtomicUsize::new(0), + AtomicUsize::new(0), + AtomicUsize::new(0), + ]; + + // Create a snapshot and iterate over it twice. + { + let mut snap = UpstreamSnapshot::new(elements.iter(), 0); + for _ in 0..(elements.len() * 2) { + snap.next() + .expect("should cycle forever") + .fetch_add(1, Ordering::Relaxed); + } + } + + // Assert all elements were visited twice. + elements + .into_iter() + .for_each(|v| assert_eq!(v.load(Ordering::Relaxed), 2)); + } + + #[test] + fn test_remove_element() { + let elements = [1, 2, 3]; + + // First element removed + { + let mut snap = UpstreamSnapshot::new(elements.iter(), 0); + assert_eq!(snap.next(), Some(&1)); + snap.remove_last_unstable(); + assert_eq!(snap.next(), Some(&3)); // Not 2 - unstable remove! + assert_eq!(snap.next(), Some(&2)); + assert_eq!(snap.next(), Some(&3)); + } + + // Second element removed + { + let mut snap = UpstreamSnapshot::new(elements.iter(), 0); + assert_eq!(snap.next(), Some(&1)); + assert_eq!(snap.next(), Some(&2)); + snap.remove_last_unstable(); + assert_eq!(snap.next(), Some(&3)); + assert_eq!(snap.next(), Some(&1)); + assert_eq!(snap.next(), Some(&3)); + } + + // Last element removed + { + let mut snap = UpstreamSnapshot::new(elements.iter(), 0); + assert_eq!(snap.next(), Some(&1)); + assert_eq!(snap.next(), Some(&2)); + assert_eq!(snap.next(), Some(&3)); + snap.remove_last_unstable(); + assert_eq!(snap.next(), Some(&1)); + assert_eq!(snap.next(), Some(&2)); + assert_eq!(snap.next(), Some(&1)); + } + } + + #[test] + fn test_remove_last_element() { + let elements = [42]; + let mut snap = UpstreamSnapshot::new(elements.iter(), 0); + assert_eq!(snap.next(), Some(&42)); + assert_eq!(snap.next(), Some(&42)); + snap.remove_last_unstable(); + assert_eq!(snap.next(), None); + assert_eq!(snap.next(), None); + } +}