diff --git a/data_types/src/router.rs b/data_types/src/router.rs index 280d00b46c..5b1ff68a27 100644 --- a/data_types/src/router.rs +++ b/data_types/src/router.rs @@ -19,6 +19,12 @@ impl ShardId { } } +impl std::fmt::Display for ShardId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "ShardId({})", self.get()) + } +} + /// ShardConfig defines rules for assigning a line/row to an individual /// host or a group of hosts. A shard /// is a logical concept, but the usage is meant to split data into diff --git a/influxdb_iox/src/commands/run/router.rs b/influxdb_iox/src/commands/run/router.rs index ba9f1b0f12..b0066976c6 100644 --- a/influxdb_iox/src/commands/run/router.rs +++ b/influxdb_iox/src/commands/run/router.rs @@ -15,6 +15,7 @@ use crate::{ use router::{resolver::RemoteTemplate, server::RouterServer}; use structopt::StructOpt; use thiserror::Error; +use time::SystemProvider; #[derive(Debug, Error)] pub enum Error { @@ -63,10 +64,15 @@ pub async fn command(config: Config) -> Result<()> { let common_state = CommonServerState::from_config(config.run_config.clone())?; let remote_template = config.remote_template.map(RemoteTemplate::new); - let router_server = Arc::new(RouterServer::new( - remote_template, - common_state.trace_collector(), - )); + let time_provider = Arc::new(SystemProvider::new()); + let router_server = Arc::new( + RouterServer::new( + remote_template, + common_state.trace_collector(), + time_provider, + ) + .await, + ); let server_type = Arc::new(RouterServerType::new(router_server, &common_state)); Ok(influxdb_ioxd::main(common_state, server_type).await?) diff --git a/router/src/lib.rs b/router/src/lib.rs index 4e7a0cca78..23c8ef5284 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -14,3 +14,4 @@ pub mod resolver; pub mod router; pub mod server; pub mod sharder; +pub mod write_sink; diff --git a/router/src/router.rs b/router/src/router.rs index af39098b6d..6571f1f830 100644 --- a/router/src/router.rs +++ b/router/src/router.rs @@ -1,16 +1,91 @@ -use data_types::router::Router as RouterConfig; +use std::{ + collections::{BTreeMap, HashMap}, + fmt::Write, + sync::Arc, +}; + +use data_types::router::{Router as RouterConfig, ShardId}; +use mutable_batch::DbWrite; +use snafu::{ResultExt, Snafu}; + +use crate::{ + connection_pool::ConnectionPool, resolver::Resolver, sharder::shard_write, + write_sink::WriteSinkSet, +}; + +#[derive(Debug, Snafu)] +pub enum WriteErrorShard { + #[snafu(display("Did not find sink set for shard ID {}", shard_id.get()))] + NoSinkSetFound { shard_id: ShardId }, + + #[snafu(display("Write to sink set failed: {}", source))] + SinkSetFailure { source: crate::write_sink::Error }, +} + +#[derive(Debug, Snafu)] +pub enum WriteError { + #[snafu(display("One or more writes failed: {}", fmt_write_errors(errors)))] + MultiWriteFailure { + errors: BTreeMap, + }, +} + +fn fmt_write_errors(errors: &BTreeMap) -> String { + const MAX_ERRORS: usize = 2; + + let mut out = String::new(); + + for (shard_id, error) in errors.iter().take(MAX_ERRORS) { + if !out.is_empty() { + write!(&mut out, ", ").expect("write to string failed?!"); + } + write!(&mut out, "{} => \"{}\"", shard_id, error).expect("write to string failed?!"); + } + + if errors.len() > MAX_ERRORS { + write!(&mut out, "...").expect("write to string failed?!"); + } + + out +} /// Router for a single database. #[derive(Debug)] pub struct Router { /// Router config. config: RouterConfig, + + /// We use a [`HashMap`] here for `O(1)` lookups. Do not rely on the iteration order. + write_sink_sets: HashMap, } impl Router { /// Create new router from config. - pub fn new(config: RouterConfig) -> Self { - Self { config } + pub fn new( + config: RouterConfig, + resolver: Arc, + connection_pool: Arc, + ) -> Self { + let write_sink_sets = config + .write_sinks + .iter() + .map(|(shard_id, set_config)| { + ( + *shard_id, + WriteSinkSet::new( + &config.name, + set_config.clone(), + Arc::clone(&resolver), + Arc::clone(&connection_pool), + ), + ) + }) + .collect(); + + Self { + config, + write_sink_sets, + } } /// Router config. @@ -24,22 +99,252 @@ impl Router { pub fn name(&self) -> &str { &self.config.name } + + /// Shard and write data. + pub async fn write(&self, write: DbWrite) -> Result<(), WriteError> { + let mut errors: BTreeMap = Default::default(); + + // The iteration order is stable here due to the [`BTreeMap`], so we ensure deterministic behavior and error order. + let sharded: BTreeMap<_, _> = shard_write(&write, &self.config.write_sharder); + for (shard_id, write) in sharded { + if let Err(e) = self.write_shard(shard_id, &write).await { + errors.insert(shard_id, e); + } + } + + if errors.is_empty() { + Ok(()) + } else { + Err(WriteError::MultiWriteFailure { errors }) + } + } + + async fn write_shard(&self, shard_id: ShardId, write: &DbWrite) -> Result<(), WriteErrorShard> { + match self.write_sink_sets.get(&shard_id) { + Some(sink_set) => sink_set.write(write).await.context(SinkSetFailure), + None => Err(WriteErrorShard::NoSinkSetFound { shard_id }), + } + } } #[cfg(test)] mod tests { + use crate::{grpc_client::MockClient, resolver::RemoteTemplate}; + use super::*; - #[test] - fn test_getters() { + use data_types::{ + router::{ + Matcher, MatcherToShard, ShardConfig, WriteSink as WriteSinkConfig, + WriteSinkSet as WriteSinkSetConfig, WriteSinkVariant as WriteSinkVariantConfig, + }, + sequence::Sequence, + server_id::ServerId, + }; + use mutable_batch::WriteMeta; + use mutable_batch_lp::lines_to_batches; + use regex::Regex; + use time::Time; + + #[tokio::test] + async fn test_getters() { + let resolver = Arc::new(Resolver::new(None)); + let connection_pool = Arc::new(ConnectionPool::new_testing().await); + let cfg = RouterConfig { name: String::from("my_router"), write_sharder: Default::default(), write_sinks: Default::default(), query_sinks: Default::default(), }; - let router = Router::new(cfg.clone()); + let router = Router::new(cfg.clone(), resolver, connection_pool); + assert_eq!(router.config(), &cfg); assert_eq!(router.name(), "my_router"); } + + #[tokio::test] + async fn test_write() { + let server_id_1 = ServerId::try_from(1).unwrap(); + let server_id_2 = ServerId::try_from(2).unwrap(); + let server_id_3 = ServerId::try_from(3).unwrap(); + + let resolver = Arc::new(Resolver::new(Some(RemoteTemplate::new("{id}")))); + let connection_pool = Arc::new(ConnectionPool::new_testing().await); + + let client_1 = connection_pool.grpc_client("1").await.unwrap(); + let client_2 = connection_pool.grpc_client("2").await.unwrap(); + let client_3 = connection_pool.grpc_client("3").await.unwrap(); + let client_1 = client_1.as_any().downcast_ref::().unwrap(); + let client_2 = client_2.as_any().downcast_ref::().unwrap(); + let client_3 = client_3.as_any().downcast_ref::().unwrap(); + + let cfg = RouterConfig { + name: String::from("my_router"), + write_sharder: ShardConfig { + specific_targets: vec![ + MatcherToShard { + matcher: Matcher { + table_name_regex: Some(Regex::new("foo_bar").unwrap()), + }, + shard: ShardId::new(10), + }, + MatcherToShard { + matcher: Matcher { + table_name_regex: Some(Regex::new("foo_three").unwrap()), + }, + shard: ShardId::new(30), + }, + MatcherToShard { + matcher: Matcher { + table_name_regex: Some(Regex::new("foo_.*").unwrap()), + }, + shard: ShardId::new(20), + }, + MatcherToShard { + matcher: Matcher { + table_name_regex: Some(Regex::new("doom").unwrap()), + }, + shard: ShardId::new(40), + }, + MatcherToShard { + matcher: Matcher { + table_name_regex: Some(Regex::new("nooo").unwrap()), + }, + shard: ShardId::new(50), + }, + MatcherToShard { + matcher: Matcher { + table_name_regex: Some(Regex::new(".*").unwrap()), + }, + shard: ShardId::new(20), + }, + ], + hash_ring: None, + }, + write_sinks: BTreeMap::from([ + ( + ShardId::new(10), + WriteSinkSetConfig { + sinks: vec![WriteSinkConfig { + sink: WriteSinkVariantConfig::GrpcRemote(server_id_1), + ignore_errors: false, + }], + }, + ), + ( + ShardId::new(20), + WriteSinkSetConfig { + sinks: vec![WriteSinkConfig { + sink: WriteSinkVariantConfig::GrpcRemote(server_id_2), + ignore_errors: false, + }], + }, + ), + ( + ShardId::new(30), + WriteSinkSetConfig { + sinks: vec![WriteSinkConfig { + sink: WriteSinkVariantConfig::GrpcRemote(server_id_3), + ignore_errors: false, + }], + }, + ), + ]), + query_sinks: Default::default(), + }; + let router = Router::new(cfg.clone(), resolver, connection_pool); + + // clean write + let meta_1 = WriteMeta::sequenced( + Sequence::new(1, 2), + Time::from_timestamp_nanos(1337), + None, + 10, + ); + let write_1 = db_write( + &["foo_x x=2 2", "foo_bar x=1 1", "foo_y x=3 3", "www x=4 4"], + &meta_1, + ); + router.write(write_1).await.unwrap(); + client_1.assert_writes(&[( + String::from("my_router"), + db_write(&["foo_bar x=1 1"], &meta_1), + )]); + client_2.assert_writes(&[( + String::from("my_router"), + db_write(&["foo_x x=2 2", "foo_y x=3 3", "www x=4 4"], &meta_1), + )]); + + // write w/ errors + client_2.poison(); + let meta_2 = WriteMeta::sequenced( + Sequence::new(3, 4), + Time::from_timestamp_nanos(42), + None, + 20, + ); + let write_2 = db_write( + &[ + "foo_bar x=5 5", + "doom x=6 6", + "foo_bar x=7 7", + "www x=8 8", + "foo_bar x=9 9", + "nooo x=10 10", + "foo_bar x=11 11", + "foo_three x=12 12", + "doom x=13 13", + "foo_three x=14 14", + "www x=15 15", + "foo_three x=16 16", + "nooo x=17 17", + "foo_three x=18 18", + ], + &meta_2, + ); + let err = router.write(write_2).await.unwrap_err(); + assert_eq!(err.to_string(), "One or more writes failed: ShardId(20) => \"Write to sink set failed: Cannot write: poisened\", ShardId(40) => \"Did not find sink set for shard ID 40\"..."); + client_1.assert_writes(&[ + ( + String::from("my_router"), + db_write(&["foo_bar x=1 1"], &meta_1), + ), + ( + String::from("my_router"), + db_write( + &[ + "foo_bar x=5 5", + "foo_bar x=7 7", + "foo_bar x=9 9", + "foo_bar x=11 11", + ], + &meta_2, + ), + ), + ]); + client_2.assert_writes(&[( + String::from("my_router"), + db_write(&["foo_x x=2 2", "foo_y x=3 3", "www x=4 4"], &meta_1), + )]); + client_3.assert_writes(&[( + String::from("my_router"), + db_write( + &[ + "foo_three x=12 12", + "foo_three x=14 14", + "foo_three x=16 16", + "foo_three x=18 18", + ], + &meta_2, + ), + )]); + } + + fn db_write(lines: &[&str], meta: &WriteMeta) -> DbWrite { + DbWrite::new( + lines_to_batches(&lines.join("\n"), 0).unwrap(), + meta.clone(), + ) + } } diff --git a/router/src/server.rs b/router/src/server.rs index a47f91ab96..1438c65963 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -4,9 +4,12 @@ use data_types::{router::Router as RouterConfig, server_id::ServerId}; use metric::Registry as MetricRegistry; use parking_lot::RwLock; use snafu::Snafu; +use time::TimeProvider; use trace::TraceCollector; +use write_buffer::config::WriteBufferConfigFactory; use crate::{ + connection_pool::ConnectionPool, resolver::{RemoteTemplate, Resolver}, router::Router, }; @@ -26,12 +29,14 @@ pub struct RouterServer { trace_collector: Option>, routers: RwLock>>, resolver: Arc, + connection_pool: Arc, } impl RouterServer { - pub fn new( + pub async fn new( remote_template: Option, trace_collector: Option>, + time_provider: Arc, ) -> Self { let metric_registry = Arc::new(metric::Registry::new()); @@ -41,6 +46,9 @@ impl RouterServer { trace_collector, routers: Default::default(), resolver: Arc::new(Resolver::new(remote_template)), + connection_pool: Arc::new( + ConnectionPool::new(false, WriteBufferConfigFactory::new(time_provider)).await, + ), } } @@ -86,7 +94,11 @@ impl RouterServer { /// /// Returns `true` if the router already existed. pub fn update_router(&self, config: RouterConfig) -> bool { - let router = Router::new(config); + let router = Router::new( + config, + Arc::clone(&self.resolver), + Arc::clone(&self.connection_pool), + ); self.routers .write() .insert(router.name().to_string(), Arc::new(router)) @@ -96,8 +108,15 @@ impl RouterServer { /// Delete router. /// /// Returns `true` if the router existed. - pub fn delete_router(&self, name: &str) -> bool { - self.routers.write().remove(name).is_some() + pub fn delete_router(&self, router_name: &str) -> bool { + self.routers.write().remove(router_name).is_some() + } + + /// Get registered router, if any. + /// + /// The router name is identical to the database for which this router handles data. + pub fn router(&self, router_name: &str) -> Option> { + self.routers.read().get(router_name).cloned() } /// Resolver associated with this server. @@ -107,10 +126,14 @@ impl RouterServer { } pub mod test_utils { + use std::sync::Arc; + + use time::SystemProvider; + use super::RouterServer; - pub fn make_router_server() -> RouterServer { - RouterServer::new(None, None) + pub async fn make_router_server() -> RouterServer { + RouterServer::new(None, None, Arc::new(SystemProvider::new())).await } } @@ -122,13 +145,13 @@ mod tests { use super::*; - #[test] - fn test_server_id() { + #[tokio::test] + async fn test_server_id() { let id13 = ServerId::try_from(13).unwrap(); let id42 = ServerId::try_from(42).unwrap(); // server starts w/o any ID - let server = make_router_server(); + let server = make_router_server().await; assert_eq!(server.server_id(), None); // setting ID @@ -144,9 +167,9 @@ mod tests { assert!(matches!(err, SetServerIdError::AlreadySet { .. })); } - #[test] - fn test_router_crud() { - let server = make_router_server(); + #[tokio::test] + async fn test_router_crud() { + let server = make_router_server().await; let cfg_foo_1 = RouterConfig { name: String::from("foo"), @@ -180,6 +203,8 @@ mod tests { assert_eq!(routers.len(), 2); assert_eq!(routers[0].config(), &cfg_bar); assert_eq!(routers[1].config(), &cfg_foo_1); + assert_eq!(server.router("bar").unwrap().config(), &cfg_bar); + assert_eq!(server.router("foo").unwrap().config(), &cfg_foo_1); // update router assert!(server.update_router(cfg_foo_2.clone())); @@ -187,12 +212,18 @@ mod tests { assert_eq!(routers.len(), 2); assert_eq!(routers[0].config(), &cfg_bar); assert_eq!(routers[1].config(), &cfg_foo_2); + assert_eq!(server.router("bar").unwrap().config(), &cfg_bar); + assert_eq!(server.router("foo").unwrap().config(), &cfg_foo_2); // delete routers assert!(server.delete_router("foo")); let routers = server.routers(); assert_eq!(routers.len(), 1); assert_eq!(routers[0].config(), &cfg_bar); + assert_eq!(server.router("bar").unwrap().config(), &cfg_bar); + assert!(server.router("foo").is_none()); + + // deleting router a 2nd time works assert!(!server.delete_router("foo")); } } diff --git a/router/src/write_sink.rs b/router/src/write_sink.rs new file mode 100644 index 0000000000..360a824f6f --- /dev/null +++ b/router/src/write_sink.rs @@ -0,0 +1,368 @@ +use std::sync::Arc; + +use data_types::{ + router::{ + WriteSink as WriteSinkConfig, WriteSinkSet as WriteSinkSetConfig, + WriteSinkVariant as WriteSinkVariantConfig, + }, + server_id::ServerId, + write_buffer::WriteBufferConnection, +}; +use mutable_batch::DbWrite; +use snafu::{OptionExt, ResultExt, Snafu}; + +use crate::{ + connection_pool::{ConnectionError, ConnectionPool}, + resolver::Resolver, +}; + +#[derive(Debug, Snafu)] +pub enum Error { + #[snafu(display("No remote for server ID {}", server_id))] + NoRemote { server_id: ServerId }, + + #[snafu(display("Cannot connect: {}", source))] + ConnectionFailure { source: ConnectionError }, + + #[snafu(display("Cannot write: {}", source))] + WriteFailure { + source: Box, + }, +} + +#[derive(Debug)] +struct VariantGrpcRemote { + db_name: String, + server_id: ServerId, + resolver: Arc, + connection_pool: Arc, +} + +impl VariantGrpcRemote { + fn new( + db_name: String, + server_id: ServerId, + resolver: Arc, + connection_pool: Arc, + ) -> Self { + Self { + db_name, + server_id, + resolver, + connection_pool, + } + } + + async fn write(&self, write: &DbWrite) -> Result<(), Error> { + let connection_string = self + .resolver + .resolve_remote(self.server_id) + .context(NoRemote { + server_id: self.server_id, + })?; + let client = self + .connection_pool + .grpc_client(&connection_string) + .await + .context(ConnectionFailure)?; + client + .write(&self.db_name, write) + .await + .context(WriteFailure) + } +} + +#[derive(Debug)] +struct VariantWriteBuffer { + db_name: String, + write_buffer_cfg: WriteBufferConnection, + connection_pool: Arc, +} + +impl VariantWriteBuffer { + fn new( + db_name: String, + write_buffer_cfg: WriteBufferConnection, + connection_pool: Arc, + ) -> Self { + Self { + db_name, + write_buffer_cfg, + connection_pool, + } + } + + async fn write(&self, write: &DbWrite) -> Result<(), Error> { + let write_buffer = self + .connection_pool + .write_buffer_producer(&self.db_name, &self.write_buffer_cfg) + .await + .context(ConnectionFailure)?; + + // TODO(marco): use multiple sequencers + write_buffer + .store_write(0, write) + .await + .context(WriteFailure)?; + + Ok(()) + } +} + +#[derive(Debug)] +enum WriteSinkVariant { + /// Send write to a remote server via gRPC + GrpcRemote(VariantGrpcRemote), + + /// Send write to a write buffer (which may be backed by kafka, local disk, etc) + WriteBuffer(VariantWriteBuffer), +} + +/// Write sink abstraction. +#[derive(Debug)] +pub struct WriteSink { + ignore_errors: bool, + variant: WriteSinkVariant, +} + +impl WriteSink { + pub fn new( + db_name: &str, + config: WriteSinkConfig, + resolver: Arc, + connection_pool: Arc, + ) -> Self { + let variant = match config.sink { + WriteSinkVariantConfig::GrpcRemote(server_id) => WriteSinkVariant::GrpcRemote( + VariantGrpcRemote::new(db_name.to_string(), server_id, resolver, connection_pool), + ), + WriteSinkVariantConfig::WriteBuffer(write_buffer_cfg) => WriteSinkVariant::WriteBuffer( + VariantWriteBuffer::new(db_name.to_string(), write_buffer_cfg, connection_pool), + ), + }; + + Self { + ignore_errors: config.ignore_errors, + variant, + } + } + + pub async fn write(&self, write: &DbWrite) -> Result<(), Error> { + let res = match &self.variant { + WriteSinkVariant::GrpcRemote(v) => v.write(write).await, + WriteSinkVariant::WriteBuffer(v) => v.write(write).await, + }; + + match res { + Ok(()) => Ok(()), + Err(_) if self.ignore_errors => Ok(()), + e => e, + } + } +} + +/// A set of [`WriteSink`]s. +#[derive(Debug)] +pub struct WriteSinkSet { + sinks: Vec, +} + +impl WriteSinkSet { + /// Create new set from config. + pub fn new( + db_name: &str, + config: WriteSinkSetConfig, + resolver: Arc, + connection_pool: Arc, + ) -> Self { + Self { + sinks: config + .sinks + .into_iter() + .map(|sink_config| { + WriteSink::new( + db_name, + sink_config, + Arc::clone(&resolver), + Arc::clone(&connection_pool), + ) + }) + .collect(), + } + } + + /// Write to sinks. Fails on first error. + pub async fn write(&self, write: &DbWrite) -> Result<(), Error> { + for sink in &self.sinks { + sink.write(write).await?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use data_types::write_buffer::WriteBufferDirection; + use mutable_batch_lp::lines_to_batches; + use time::SystemProvider; + use write_buffer::config::WriteBufferConfigFactory; + + use crate::grpc_client::MockClient; + + use super::*; + + #[tokio::test] + async fn test_write_sink_error_handling() { + let server_id = ServerId::try_from(1).unwrap(); + + let resolver = Arc::new(Resolver::new(None)); + resolver.update_remote(server_id, String::from("1.2.3.4")); + + let time_provider = Arc::new(SystemProvider::new()); + let wb_factory = WriteBufferConfigFactory::new(time_provider); + wb_factory.register_always_fail_mock(String::from("failing_wb")); + let connection_pool = Arc::new(ConnectionPool::new(true, wb_factory).await); + + let client_grpc = connection_pool.grpc_client("1.2.3.4").await.unwrap(); + let client_grpc = client_grpc.as_any().downcast_ref::().unwrap(); + client_grpc.poison(); + + let write = DbWrite::new( + lines_to_batches("foo x=1 1", 0).unwrap(), + Default::default(), + ); + + // gRPC, do NOT ignore errors + let config = WriteSinkConfig { + sink: WriteSinkVariantConfig::GrpcRemote(server_id), + ignore_errors: false, + }; + let sink = WriteSink::new( + "my_db", + config, + Arc::clone(&resolver), + Arc::clone(&connection_pool), + ); + sink.write(&write).await.unwrap_err(); + + // gRPC, ignore errors + let config = WriteSinkConfig { + sink: WriteSinkVariantConfig::GrpcRemote(server_id), + ignore_errors: true, + }; + let sink = WriteSink::new( + "my_db", + config, + Arc::clone(&resolver), + Arc::clone(&connection_pool), + ); + sink.write(&write).await.unwrap(); + + // write buffer, do NOT ignore errors + let write_buffer_cfg = WriteBufferConnection { + direction: WriteBufferDirection::Write, + type_: String::from("mock"), + connection: String::from("failing_wb"), + ..Default::default() + }; + let config = WriteSinkConfig { + sink: WriteSinkVariantConfig::WriteBuffer(write_buffer_cfg.clone()), + ignore_errors: false, + }; + let sink = WriteSink::new( + "my_db", + config, + Arc::clone(&resolver), + Arc::clone(&connection_pool), + ); + sink.write(&write).await.unwrap_err(); + + // write buffer, ignore errors + let config = WriteSinkConfig { + sink: WriteSinkVariantConfig::WriteBuffer(write_buffer_cfg), + ignore_errors: true, + }; + let sink = WriteSink::new( + "my_db", + config, + Arc::clone(&resolver), + Arc::clone(&connection_pool), + ); + sink.write(&write).await.unwrap(); + } + + #[tokio::test] + async fn test_write_sink_set() { + let server_id_1 = ServerId::try_from(1).unwrap(); + let server_id_2 = ServerId::try_from(2).unwrap(); + let server_id_3 = ServerId::try_from(3).unwrap(); + + let resolver = Arc::new(Resolver::new(None)); + resolver.update_remote(server_id_1, String::from("1")); + resolver.update_remote(server_id_2, String::from("2")); + resolver.update_remote(server_id_3, String::from("3")); + + let connection_pool = Arc::new(ConnectionPool::new_testing().await); + + let client_1 = connection_pool.grpc_client("1").await.unwrap(); + let client_2 = connection_pool.grpc_client("2").await.unwrap(); + let client_3 = connection_pool.grpc_client("3").await.unwrap(); + let client_1 = client_1.as_any().downcast_ref::().unwrap(); + let client_2 = client_2.as_any().downcast_ref::().unwrap(); + let client_3 = client_3.as_any().downcast_ref::().unwrap(); + + let sink_set = WriteSinkSet::new( + "my_db", + WriteSinkSetConfig { + sinks: vec![ + WriteSinkConfig { + sink: WriteSinkVariantConfig::GrpcRemote(server_id_1), + ignore_errors: false, + }, + WriteSinkConfig { + sink: WriteSinkVariantConfig::GrpcRemote(server_id_2), + ignore_errors: false, + }, + WriteSinkConfig { + sink: WriteSinkVariantConfig::GrpcRemote(server_id_3), + ignore_errors: false, + }, + ], + }, + resolver, + connection_pool, + ); + + let write_1 = DbWrite::new( + lines_to_batches("foo x=1 1", 0).unwrap(), + Default::default(), + ); + sink_set.write(&write_1).await.unwrap(); + + let writes_1 = [(String::from("my_db"), write_1.clone())]; + client_1.assert_writes(&writes_1); + client_2.assert_writes(&writes_1); + client_3.assert_writes(&writes_1); + + client_2.poison(); + + let write_2 = DbWrite::new( + lines_to_batches("foo x=2 2", 0).unwrap(), + Default::default(), + ); + sink_set.write(&write_2).await.unwrap_err(); + + // The sink set stops on first non-ignored error. So + // - client 1 got the new data + // - client 2 failed, but still has the data from the first write + // - client 3 got skipped due to the failure, but still has the data from the first write + let writes_2 = [ + (String::from("my_db"), write_1.clone()), + (String::from("my_db"), write_2.clone()), + ]; + client_1.assert_writes(&writes_2); + client_2.assert_writes(&writes_1); + client_3.assert_writes(&writes_1); + } +}