diff --git a/router2/src/dml_handlers/mock.rs b/router2/src/dml_handlers/mock.rs index b4f3c0a3fd..5a3e520fc9 100644 --- a/router2/src/dml_handlers/mock.rs +++ b/router2/src/dml_handlers/mock.rs @@ -1,20 +1,20 @@ -use std::{collections::VecDeque, sync::Arc}; +use std::{collections::VecDeque, fmt::Debug, sync::Arc}; use async_trait::async_trait; use data_types::{delete_predicate::DeletePredicate, DatabaseName}; -use hashbrown::HashMap; -use mutable_batch::MutableBatch; use parking_lot::Mutex; use trace::ctx::SpanContext; use super::{DmlError, DmlHandler}; +/// A captured call to a [`MockDmlHandler`], generic over `W`, the captured +/// [`DmlHandler::WriteInput`] type. #[derive(Debug, Clone)] -pub enum MockDmlHandlerCall { +pub enum MockDmlHandlerCall { Write { namespace: String, - batches: HashMap, + write_input: W, }, Delete { namespace: String, @@ -23,23 +23,42 @@ pub enum MockDmlHandlerCall { }, } -#[derive(Debug, Default)] -struct Inner { - calls: Vec, +#[derive(Debug)] +struct Inner { + calls: Vec>, write_return: VecDeque>, delete_return: VecDeque>, } -impl Inner { - fn record_call(&mut self, call: MockDmlHandlerCall) { +impl Default for Inner { + fn default() -> Self { + Self { + calls: Default::default(), + write_return: Default::default(), + delete_return: Default::default(), + } + } +} + +impl Inner { + fn record_call(&mut self, call: MockDmlHandlerCall) { self.calls.push(call); } } -#[derive(Debug, Default)] -pub struct MockDmlHandler(Mutex); +#[derive(Debug)] +pub struct MockDmlHandler(Mutex>); -impl MockDmlHandler { +impl Default for MockDmlHandler { + fn default() -> Self { + Self(Default::default()) + } +} + +impl MockDmlHandler +where + W: Clone, +{ pub fn with_write_return(self, ret: impl Into>>) -> Self { self.0.lock().write_return = ret.into(); self @@ -50,7 +69,7 @@ impl MockDmlHandler { self } - pub fn calls(&self) -> Vec { + pub fn calls(&self) -> Vec> { self.0.lock().calls.clone() } } @@ -68,22 +87,25 @@ macro_rules! record_and_return { } #[async_trait] -impl DmlHandler for Arc { +impl DmlHandler for Arc> +where + W: Debug + Send + Sync, +{ type WriteError = DmlError; type DeleteError = DmlError; - type WriteInput = HashMap; + type WriteInput = W; async fn write( &self, namespace: DatabaseName<'static>, - batches: Self::WriteInput, + write_input: Self::WriteInput, _span_ctx: Option, ) -> Result<(), Self::WriteError> { record_and_return!( self, MockDmlHandlerCall::Write { namespace: namespace.into(), - batches, + write_input, }, write_return ) diff --git a/router2/src/dml_handlers/schema_validation.rs b/router2/src/dml_handlers/schema_validation.rs index 1a08cef478..e760fed537 100644 --- a/router2/src/dml_handlers/schema_validation.rs +++ b/router2/src/dml_handlers/schema_validation.rs @@ -363,9 +363,9 @@ mod tests { assert_matches!(err, SchemaError::Validate(_)); // THe mock should observe exactly one write from the first call. - assert_matches!(mock.calls().as_slice(), [MockDmlHandlerCall::Write{namespace, batches}] => { + assert_matches!(mock.calls().as_slice(), [MockDmlHandlerCall::Write{namespace, write_input}] => { assert_eq!(namespace, NAMESPACE); - let batch = batches.get("bananas").expect("table not found in write"); + let batch = write_input.get("bananas").expect("table not found in write"); assert_eq!(batch.rows(), 1); let col = batch.column("val").expect("column not found in write"); assert_matches!(col.influx_type(), InfluxColumnType::Field(InfluxFieldType::Integer));