From 874325d9cede900d2a909876e677b82b1cf4133e Mon Sep 17 00:00:00 2001 From: Dom Dwyer Date: Mon, 27 Mar 2023 13:31:31 +0200 Subject: [PATCH] refactor(test): generic return iterator This change allows the MockWriteClient to accept more input types and internalises the box dyn bits; I got tired of writing "Box::new()" everywhere. --- router/src/dml_handlers/rpc_write.rs | 111 ++++++++------------ router/src/dml_handlers/rpc_write/client.rs | 11 +- 2 files changed, 47 insertions(+), 75 deletions(-) diff --git a/router/src/dml_handlers/rpc_write.rs b/router/src/dml_handlers/rpc_write.rs index a0e73fc942..2d4bc2499f 100644 --- a/router/src/dml_handlers/rpc_write.rs +++ b/router/src/dml_handlers/rpc_write.rs @@ -482,9 +482,9 @@ mod tests { let input = Partitioned::new(PartitionKey::from("2022-01-01"), batches.clone()); // Init the write handler with a mock client to capture the rpc calls. - let client1 = Arc::new(MockWriteClient::default().with_ret(Box::new(iter::once(Err( + let client1 = Arc::new(MockWriteClient::default().with_ret(iter::once(Err( RpcWriteError::Upstream(tonic::Status::internal("")), - ))))); + )))); let client2 = Arc::new(MockWriteClient::default()); let client3 = Arc::new(MockWriteClient::default()); let handler = RpcWrite::new( @@ -547,21 +547,14 @@ mod tests { // The first client in line fails the first request, but will succeed // the second try. - let client1 = Arc::new( - MockWriteClient::default().with_ret(Box::new( - [ - Err(RpcWriteError::Upstream(tonic::Status::internal(""))), - Ok(()), - ] - .into_iter(), - )), - ); + let client1 = Arc::new(MockWriteClient::default().with_ret([ + Err(RpcWriteError::Upstream(tonic::Status::internal(""))), + Ok(()), + ])); // This client always errors. - let client2 = Arc::new( - MockWriteClient::default().with_ret(Box::new(iter::repeat_with(|| { - Err(RpcWriteError::Upstream(tonic::Status::internal(""))) - }))), - ); + let client2 = Arc::new(MockWriteClient::default().with_ret(iter::repeat_with(|| { + Err(RpcWriteError::Upstream(tonic::Status::internal(""))) + }))); let handler = RpcWrite::new( [ @@ -632,11 +625,9 @@ mod tests { /// error. #[tokio::test] async fn test_write_upstream_error() { - let client_1 = Arc::new( - MockWriteClient::default().with_ret(Box::new(iter::repeat_with(|| { - Err(RpcWriteError::Upstream(tonic::Status::internal("bananas"))) - }))), - ); + let client_1 = Arc::new(MockWriteClient::default().with_ret(iter::repeat_with(|| { + Err(RpcWriteError::Upstream(tonic::Status::internal("bananas"))) + }))); let circuit_1 = Arc::new(MockCircuitBreaker::default()); circuit_1.set_healthy(true); @@ -656,11 +647,9 @@ mod tests { /// to a user-friendly [`RpcWriteError::NoUpstreams`] for consistency. #[tokio::test] async fn test_write_map_upstream_not_connected_error() { - let client_1 = Arc::new( - MockWriteClient::default().with_ret(Box::new(iter::repeat_with(|| { - Err(RpcWriteError::UpstreamNotConnected("bananas".to_string())) - }))), - ); + let client_1 = Arc::new(MockWriteClient::default().with_ret(iter::repeat_with(|| { + Err(RpcWriteError::UpstreamNotConnected("bananas".to_string())) + }))); let circuit_1 = Arc::new(MockCircuitBreaker::default()); circuit_1.set_healthy(true); @@ -679,19 +668,15 @@ mod tests { #[tokio::test] async fn test_write_not_enough_upstreams_for_replication() { // Initialise two upstreams, 1 healthy, 1 not. - let client_1 = Arc::new( - MockWriteClient::default().with_ret(Box::new(iter::repeat_with(|| { - Err(RpcWriteError::UpstreamNotConnected("bananas".to_string())) - }))), - ); + let client_1 = Arc::new(MockWriteClient::default().with_ret(iter::repeat_with(|| { + Err(RpcWriteError::UpstreamNotConnected("bananas".to_string())) + }))); let circuit_1 = Arc::new(MockCircuitBreaker::default()); circuit_1.set_healthy(true); - let client_2 = Arc::new( - MockWriteClient::default().with_ret(Box::new(iter::repeat_with(|| { - Err(RpcWriteError::UpstreamNotConnected("bananas".to_string())) - }))), - ); + let client_2 = Arc::new(MockWriteClient::default().with_ret(iter::repeat_with(|| { + Err(RpcWriteError::UpstreamNotConnected("bananas".to_string())) + }))); let circuit_2 = Arc::new(MockCircuitBreaker::default()); circuit_2.set_healthy(false); @@ -711,11 +696,11 @@ mod tests { #[tokio::test] async fn test_write_replication_distinct_hosts() { // Initialise two upstreams. - let client_1 = Arc::new(MockWriteClient::default().with_ret(Box::new(iter::once(Ok(()))))); + let client_1 = Arc::new(MockWriteClient::default().with_ret(iter::once(Ok(())))); let circuit_1 = Arc::new(MockCircuitBreaker::default()); circuit_1.set_healthy(true); - let client_2 = Arc::new(MockWriteClient::default().with_ret(Box::new(iter::once(Ok(()))))); + let client_2 = Arc::new(MockWriteClient::default().with_ret(iter::once(Ok(())))); let circuit_2 = Arc::new(MockCircuitBreaker::default()); circuit_2.set_healthy(true); @@ -744,15 +729,13 @@ mod tests { async fn test_write_replication_distinct_hosts_partial_write() { // Initialise two upstreams, 1 willing to ACK a write, and the other // always throwing an error. - let client_1 = Arc::new(MockWriteClient::default().with_ret(Box::new(iter::once(Ok(()))))); + let client_1 = Arc::new(MockWriteClient::default().with_ret(iter::once(Ok(())))); let circuit_1 = Arc::new(MockCircuitBreaker::default()); circuit_1.set_healthy(true); - let client_2 = Arc::new( - MockWriteClient::default().with_ret(Box::new(iter::repeat_with(|| { - Err(RpcWriteError::Upstream(tonic::Status::internal("bananas"))) - }))), - ); + let client_2 = Arc::new(MockWriteClient::default().with_ret(iter::repeat_with(|| { + Err(RpcWriteError::Upstream(tonic::Status::internal("bananas"))) + }))); let circuit_2 = Arc::new(MockCircuitBreaker::default()); circuit_2.set_healthy(true); @@ -788,19 +771,14 @@ mod tests { async fn test_write_replication_tolerates_temporary_error() { // Initialise two upstreams, 1 willing to ACK a write, and the other // always throwing an error. - let client_1 = Arc::new(MockWriteClient::default().with_ret(Box::new(iter::once(Ok(()))))); + let client_1 = Arc::new(MockWriteClient::default().with_ret(iter::once(Ok(())))); let circuit_1 = Arc::new(MockCircuitBreaker::default()); circuit_1.set_healthy(true); - let client_2 = Arc::new( - MockWriteClient::default().with_ret(Box::new( - [ - Err(RpcWriteError::Upstream(tonic::Status::internal("bananas"))), - Ok(()), - ] - .into_iter(), - )), - ); + let client_2 = Arc::new(MockWriteClient::default().with_ret([ + Err(RpcWriteError::Upstream(tonic::Status::internal("bananas"))), + Ok(()), + ])); let circuit_2 = Arc::new(MockCircuitBreaker::default()); circuit_2.set_healthy(true); @@ -835,30 +813,23 @@ mod tests { async fn test_write_replication_tolerates_bad_upstream() { // Initialise three upstreams, 1 willing to ACK a write immediately, the // second will error twice, and the third always errors. - let client_1 = Arc::new(MockWriteClient::default().with_ret(Box::new(iter::once(Ok(()))))); + let client_1 = Arc::new(MockWriteClient::default().with_ret(iter::once(Ok(())))); let circuit_1 = Arc::new(MockCircuitBreaker::default()); circuit_1.set_healthy(true); // This client sometimes errors (2 times) - let client_2 = Arc::new( - MockWriteClient::default().with_ret(Box::new( - [ - Err(RpcWriteError::Upstream(tonic::Status::internal("bananas"))), - Err(RpcWriteError::Upstream(tonic::Status::internal("bananas"))), - Ok(()), - ] - .into_iter(), - )), - ); + let client_2 = Arc::new(MockWriteClient::default().with_ret([ + Err(RpcWriteError::Upstream(tonic::Status::internal("bananas"))), + Err(RpcWriteError::Upstream(tonic::Status::internal("bananas"))), + Ok(()), + ])); let circuit_2 = Arc::new(MockCircuitBreaker::default()); circuit_2.set_healthy(true); // This client always errors - let client_3 = Arc::new( - MockWriteClient::default().with_ret(Box::new(iter::repeat_with(|| { - Err(RpcWriteError::UpstreamNotConnected("bananas".to_string())) - }))), - ); + let client_3 = Arc::new(MockWriteClient::default().with_ret(iter::repeat_with(|| { + Err(RpcWriteError::UpstreamNotConnected("bananas".to_string())) + }))); let circuit_3 = Arc::new(MockCircuitBreaker::default()); circuit_3.set_healthy(true); diff --git a/router/src/dml_handlers/rpc_write/client.rs b/router/src/dml_handlers/rpc_write/client.rs index 33d087cf78..3bdd4155d6 100644 --- a/router/src/dml_handlers/rpc_write/client.rs +++ b/router/src/dml_handlers/rpc_write/client.rs @@ -78,11 +78,12 @@ pub mod mock { /// Read values off of the provided iterator and return them for calls /// to [`Self::write()`]. #[cfg(test)] - pub(crate) fn with_ret( - self, - ret: Box> + Send + Sync>, - ) -> Self { - self.state.lock().ret = ret; + pub(crate) fn with_ret(self, ret: T) -> Self + where + T: IntoIterator, + U: Iterator> + Send + Sync + 'static, + { + self.state.lock().ret = Box::new(ret.into_iter()); self } }