From 9567acd62186a0789b3d49a4d6cc4bbe37568b3c Mon Sep 17 00:00:00 2001
From: Marco Neumann <marco@crepererum.net>
Date: Wed, 2 Feb 2022 09:35:54 +0000
Subject: [PATCH] feat: expose all relevant configs for rskafka write buffers
 (#3599)

* feat: expose all relevant configs for rskafka write buffers

* refactor: `CreationConfig` => `TopicCreationConfig`
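
Example (mirroring the new unit tests in write_buffer/src/rskafka/config.rs; the
option values below are illustrative only, not recommended defaults) of how a
connection config map is parsed into the new config structs:

    use std::collections::BTreeMap;

    // Connection options arrive as plain string key/value pairs.
    let opts = BTreeMap::from([
        (String::from("max_message_size"), String::from("10485760")),
        (String::from("consumer_max_wait_ms"), String::from("100")),
        (String::from("producer_linger_ms"), String::from("100")),
        (String::from("producer_max_batch_size"), String::from("512000")),
    ]);

    // Each config struct only picks out the keys it knows about; unknown keys
    // are ignored and unparsable values surface as `WriteBufferError`.
    let client_config = ClientConfig::try_from(&opts)?;
    let consumer_config = ConsumerConfig::try_from(&opts)?;
    let producer_config = ProducerConfig::try_from(&opts)?;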
---
 write_buffer/src/config.rs         |   2 +
 write_buffer/src/rskafka/config.rs | 256 +++++++++++++++++++++++++++++
 write_buffer/src/rskafka/mod.rs    |  87 +++++++---
 3 files changed, 324 insertions(+), 21 deletions(-)
 create mode 100644 write_buffer/src/rskafka/config.rs

diff --git a/write_buffer/src/config.rs b/write_buffer/src/config.rs
index 5a2fb88a3a..0f44e1b256 100644
--- a/write_buffer/src/config.rs
+++ b/write_buffer/src/config.rs
@@ -121,6 +121,7 @@ impl WriteBufferConfigFactory {
                 let rskafa_buffer = RSKafkaProducer::new(
                     cfg.connection.clone(),
                     db_name.to_owned(),
+                    &cfg.connection_config,
                     cfg.creation_config.as_ref(),
                     Arc::clone(&self.time_provider),
                     trace_collector.map(Arc::clone),
@@ -207,6 +208,7 @@ impl WriteBufferConfigFactory {
                 let rskafka_buffer = RSKafkaConsumer::new(
                     cfg.connection.clone(),
                     db_name.to_owned(),
+                    &cfg.connection_config,
                     cfg.creation_config.as_ref(),
                     trace_collector.map(Arc::clone),
                 )
diff --git a/write_buffer/src/rskafka/config.rs b/write_buffer/src/rskafka/config.rs
new file mode 100644
index 0000000000..2d862e73fa
--- /dev/null
+++ b/write_buffer/src/rskafka/config.rs
@@ -0,0 +1,256 @@
+use std::{collections::BTreeMap, time::Duration};
+
+use data_types::write_buffer::WriteBufferCreationConfig;
+
+use crate::core::WriteBufferError;
+
+/// Generic client config that is used for consumers and producers as well as for admin operations (like "create topic").
+#[derive(Debug, PartialEq, Eq)]
+pub struct ClientConfig {
+    /// Maximum message size in bytes.
+    ///
+    /// Extracted from `max_message_size`. Defaults to `None` (rskafka default).
+    pub max_message_size: Option<usize>,
+}
+
+impl TryFrom<&BTreeMap<String, String>> for ClientConfig {
+    type Error = WriteBufferError;
+
+    fn try_from(cfg: &BTreeMap<String, String>) -> Result<Self, Self::Error> {
+        Ok(Self {
+            max_message_size: cfg.get("max_message_size").map(|s| s.parse()).transpose()?,
+        })
+    }
+}
+
+/// Config for topic creation.
+#[derive(Debug, PartialEq, Eq)]
+pub struct TopicCreationConfig {
+    /// Number of partitions.
+    pub num_partitions: i32,
+
+    /// Replication factor.
+    ///
+    /// Extracted from `replication_factor` option. Defaults to `1`.
+    pub replication_factor: i16,
+
+    /// Timeout in ms.
+    ///
+    /// Extracted from `timeout_ms` option. Defaults to `5_000`.
+    pub timeout_ms: i32,
+}
+
+impl TryFrom<&WriteBufferCreationConfig> for TopicCreationConfig {
+    type Error = WriteBufferError;
+
+    fn try_from(cfg: &WriteBufferCreationConfig) -> Result<Self, Self::Error> {
+        Ok(Self {
+            num_partitions: i32::try_from(cfg.n_sequencers.get())?,
+            replication_factor: cfg
+                .options
+                .get("replication_factor")
+                .map(|s| s.parse())
+                .transpose()?
+                .unwrap_or(1),
+            timeout_ms: cfg
+                .options
+                .get("timeout_ms")
+                .map(|s| s.parse())
+                .transpose()?
+                .unwrap_or(5_000),
+        })
+    }
+}
+
+/// Config for consumers.
+#[derive(Debug, PartialEq, Eq)]
+pub struct ConsumerConfig {
+    /// The maximum amount of time (in ms) to wait for data before returning.
+    ///
+    /// Extracted from `consumer_max_wait_ms`. Defaults to `None` (rskafka default).
+    pub max_wait_ms: Option<i32>,
+
+    /// Will wait for at least `min_batch_size` bytes of data.
+    ///
+    /// Extracted from `consumer_min_batch_size`. Defaults to `None` (rskafka default).
+    pub min_batch_size: Option<i32>,
+
+    /// The maximum amount of data to fetch in a single batch.
+    ///
+    /// Extracted from `consumer_max_batch_size`. Defaults to `None` (rskafka default).
+    pub max_batch_size: Option<i32>,
+}
+
+impl TryFrom<&BTreeMap<String, String>> for ConsumerConfig {
+    type Error = WriteBufferError;
+
+    fn try_from(cfg: &BTreeMap<String, String>) -> Result<Self, Self::Error> {
+        Ok(Self {
+            max_wait_ms: cfg
+                .get("consumer_max_wait_ms")
+                .map(|s| s.parse())
+                .transpose()?,
+            min_batch_size: cfg
+                .get("consumer_min_batch_size")
+                .map(|s| s.parse())
+                .transpose()?,
+            max_batch_size: cfg
+                .get("consumer_max_batch_size")
+                .map(|s| s.parse())
+                .transpose()?,
+        })
+    }
+}
+
+/// Config for producers.
+#[derive(Debug, PartialEq, Eq)]
+pub struct ProducerConfig {
+    /// Linger time.
+    ///
+    /// Extracted from `producer_linger_ms`. Defaults to `None` (rskafka default).
+    pub linger: Option<Duration>,
+
+    /// Maximum batch size in bytes.
+    ///
+    /// Extracted from `producer_max_batch_size`. Defaults to `100 * 1024`.
+    pub max_batch_size: usize,
+}
+
+impl TryFrom<&BTreeMap<String, String>> for ProducerConfig {
+    type Error = WriteBufferError;
+
+    fn try_from(cfg: &BTreeMap<String, String>) -> Result<Self, Self::Error> {
+        let linger_ms: Option<u64> = cfg
+            .get("producer_linger_ms")
+            .map(|s| s.parse())
+            .transpose()?;
+
+        Ok(Self {
+            linger: linger_ms.map(Duration::from_millis),
+            max_batch_size: cfg
+                .get("producer_max_batch_size")
+                .map(|s| s.parse())
+                .transpose()?
+                .unwrap_or(100 * 1024),
+        })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::{collections::BTreeMap, num::NonZeroU32};
+
+    use super::*;
+
+    #[test]
+    fn test_client_config_default() {
+        let actual = ClientConfig::try_from(&BTreeMap::default()).unwrap();
+        let expected = ClientConfig {
+            max_message_size: None,
+        };
+        assert_eq!(actual, expected);
+    }
+
+    #[test]
+    fn test_client_config_parse() {
+        let actual = ClientConfig::try_from(&BTreeMap::from([
+            (String::from("max_message_size"), String::from("1024")),
+            (String::from("foo"), String::from("bar")),
+        ]))
+        .unwrap();
+        let expected = ClientConfig {
+            max_message_size: Some(1024),
+        };
+        assert_eq!(actual, expected);
+    }
+
+    #[test]
+    fn test_topic_creation_config_default() {
+        let actual = TopicCreationConfig::try_from(&WriteBufferCreationConfig {
+            n_sequencers: NonZeroU32::new(2).unwrap(),
+            options: BTreeMap::default(),
+        })
+        .unwrap();
+        let expected = TopicCreationConfig {
+            num_partitions: 2,
+            replication_factor: 1,
+            timeout_ms: 5_000,
+        };
+        assert_eq!(actual, expected);
+    }
+
+    #[test]
+    fn test_topic_creation_config_parse() {
+        let actual = TopicCreationConfig::try_from(&WriteBufferCreationConfig {
+            n_sequencers: NonZeroU32::new(2).unwrap(),
+            options: BTreeMap::from([
+                (String::from("replication_factor"), String::from("3")),
+                (String::from("timeout_ms"), String::from("100")),
+                (String::from("foo"), String::from("bar")),
+            ]),
+        })
+        .unwrap();
+        let expected = TopicCreationConfig {
+            num_partitions: 2,
+            replication_factor: 3,
+            timeout_ms: 100,
+        };
+        assert_eq!(actual, expected);
+    }
+
+    #[test]
+    fn test_consumer_config_default() {
+        let actual = ConsumerConfig::try_from(&BTreeMap::default()).unwrap();
+        let expected = ConsumerConfig {
+            max_wait_ms: None,
+            min_batch_size: None,
+            max_batch_size: None,
+        };
+        assert_eq!(actual, expected);
+    }
+
+    #[test]
+    fn test_consumer_config_parse() {
+        let actual = ConsumerConfig::try_from(&BTreeMap::from([
+            (String::from("consumer_max_wait_ms"), String::from("11")),
+            (String::from("consumer_min_batch_size"), String::from("22")),
+            (String::from("consumer_max_batch_size"), String::from("33")),
+            (String::from("foo"), String::from("bar")),
+        ]))
+        .unwrap();
+        let expected = ConsumerConfig {
+            max_wait_ms: Some(11),
+            min_batch_size: Some(22),
+            max_batch_size: Some(33),
+        };
+        assert_eq!(actual, expected);
+    }
+
+    #[test]
+    fn test_producer_config_default() {
+        let actual = ProducerConfig::try_from(&BTreeMap::default()).unwrap();
+        let expected = ProducerConfig {
+            linger: None,
+            max_batch_size: 100 * 1024,
+        };
+        assert_eq!(actual, expected);
+    }
+
+    #[test]
+    fn test_producer_config_parse() {
+        let actual = ProducerConfig::try_from(&BTreeMap::from([
+            (String::from("producer_linger_ms"), String::from("42")),
+            (
+                String::from("producer_max_batch_size"),
+                String::from("1337"),
+            ),
+            (String::from("foo"), String::from("bar")),
+        ]))
+        .unwrap();
+        let expected = ProducerConfig {
+            linger: Some(Duration::from_millis(42)),
+            max_batch_size: 1337,
+        };
+        assert_eq!(actual, expected);
+    }
+}
diff --git a/write_buffer/src/rskafka/mod.rs b/write_buffer/src/rskafka/mod.rs
index 00b2bbda91..e8269e3708 100644
--- a/write_buffer/src/rskafka/mod.rs
+++ b/write_buffer/src/rskafka/mod.rs
@@ -4,7 +4,6 @@ use std::{
         atomic::{AtomicI64, Ordering},
         Arc,
     },
-    time::Duration,
 };
 
 use async_trait::async_trait;
@@ -29,9 +28,13 @@ use crate::{
     },
 };
 
-use self::aggregator::DmlAggregator;
+use self::{
+    aggregator::DmlAggregator,
+    config::{ClientConfig, ConsumerConfig, ProducerConfig, TopicCreationConfig},
+};
 
 mod aggregator;
+mod config;
 
 type Result<T, E = WriteBufferError> = std::result::Result<T, E>;
 
@@ -44,23 +47,35 @@ impl RSKafkaProducer {
     pub async fn new(
         conn: String,
         database_name: String,
+        connection_config: &BTreeMap<String, String>,
         creation_config: Option<&WriteBufferCreationConfig>,
         time_provider: Arc<dyn TimeProvider>,
         trace_collector: Option<Arc<dyn TraceCollector>>,
     ) -> Result<Self> {
-        let partition_clients = setup_topic(conn, database_name.clone(), creation_config).await?;
+        let partition_clients = setup_topic(
+            conn,
+            database_name.clone(),
+            connection_config,
+            creation_config,
+        )
+        .await?;
+
+        let producer_config = ProducerConfig::try_from(connection_config)?;
+
         let producers = partition_clients
             .into_iter()
             .map(|(sequencer_id, partition_client)| {
-                let producer = BatchProducerBuilder::new(Arc::new(partition_client))
-                    .with_linger(Duration::from_millis(100))
-                    .build(DmlAggregator::new(
-                        trace_collector.clone(),
-                        database_name.clone(),
-                        1024 * 500,
-                        sequencer_id,
-                        Arc::clone(&time_provider),
-                    ));
+                let mut producer_builder = BatchProducerBuilder::new(Arc::new(partition_client));
+                if let Some(linger) = producer_config.linger {
+                    producer_builder = producer_builder.with_linger(linger);
+                }
+                let producer = producer_builder.build(DmlAggregator::new(
+                    trace_collector.clone(),
+                    database_name.clone(),
+                    producer_config.max_batch_size,
+                    sequencer_id,
+                    Arc::clone(&time_provider),
+                ));
 
                 (sequencer_id, producer)
             })
@@ -113,16 +128,24 @@ struct ConsumerPartition {
 pub struct RSKafkaConsumer {
     partitions: BTreeMap<u32, ConsumerPartition>,
     trace_collector: Option<Arc<dyn TraceCollector>>,
+    consumer_config: ConsumerConfig,
 }
 
 impl RSKafkaConsumer {
     pub async fn new(
         conn: String,
         database_name: String,
+        connection_config: &BTreeMap<String, String>,
         creation_config: Option<&WriteBufferCreationConfig>,
         trace_collector: Option<Arc<dyn TraceCollector>>,
     ) -> Result<Self> {
-        let partition_clients = setup_topic(conn, database_name.clone(), creation_config).await?;
+        let partition_clients = setup_topic(
+            conn,
+            database_name.clone(),
+            connection_config,
+            creation_config,
+        )
+        .await?;
 
         let partitions = partition_clients
             .into_iter()
@@ -143,6 +166,7 @@ impl RSKafkaConsumer {
         Ok(Self {
             partitions,
             trace_collector,
+            consumer_config: ConsumerConfig::try_from(connection_config)?,
         })
     }
 }
@@ -155,12 +179,22 @@ impl WriteBufferReading for RSKafkaConsumer {
         for (sequencer_id, partition) in &self.partitions {
             let trace_collector = self.trace_collector.clone();
             let next_offset = Arc::clone(&partition.next_offset);
-            let stream = StreamConsumerBuilder::new(
+
+            let mut stream_builder = StreamConsumerBuilder::new(
                 Arc::clone(&partition.partition_client),
                 next_offset.load(Ordering::SeqCst),
-            )
-            .with_max_wait_ms(100)
-            .build();
+            );
+            if let Some(max_wait_ms) = self.consumer_config.max_wait_ms {
+                stream_builder = stream_builder.with_max_wait_ms(max_wait_ms);
+            }
+            if let Some(min_batch_size) = self.consumer_config.min_batch_size {
+                stream_builder = stream_builder.with_min_batch_size(min_batch_size);
+            }
+            if let Some(max_batch_size) = self.consumer_config.max_batch_size {
+                stream_builder = stream_builder.with_max_batch_size(max_batch_size);
+            }
+            let stream = stream_builder.build();
+
             let stream = stream.map(move |res| {
                 let (record, _watermark) = res?;
 
@@ -247,9 +281,15 @@ impl WriteBufferReading for RSKafkaConsumer {
 async fn setup_topic(
     conn: String,
     database_name: String,
+    connection_config: &BTreeMap<String, String>,
     creation_config: Option<&WriteBufferCreationConfig>,
 ) -> Result<BTreeMap<u32, PartitionClient>> {
-    let client = ClientBuilder::new(vec![conn]).build().await?;
+    let client_config = ClientConfig::try_from(connection_config)?;
+    let mut client_builder = ClientBuilder::new(vec![conn]);
+    if let Some(max_message_size) = client_config.max_message_size {
+        client_builder = client_builder.max_message_size(max_message_size);
+    }
+    let client = client_builder.build().await?;
     let controller_client = client.controller_client().await?;
 
     loop {
@@ -267,12 +307,14 @@ async fn setup_topic(
 
         // create topic
         if let Some(creation_config) = creation_config {
+            let topic_creation_config = TopicCreationConfig::try_from(creation_config)?;
+
             match controller_client
                 .create_topic(
                     &database_name,
-                    creation_config.n_sequencers.get() as i32,
-                    1,
-                    5_000,
+                    topic_creation_config.num_partitions,
+                    topic_creation_config.replication_factor,
+                    topic_creation_config.timeout_ms,
                 )
                 .await
             {
@@ -366,6 +408,7 @@ mod tests {
             RSKafkaProducer::new(
                 self.conn.clone(),
                 self.database_name.clone(),
+                &BTreeMap::default(),
                 self.creation_config(creation_config).as_ref(),
                 Arc::clone(&self.time_provider),
                 Some(self.trace_collector() as Arc<_>),
@@ -377,6 +420,7 @@ mod tests {
             RSKafkaConsumer::new(
                 self.conn.clone(),
                 self.database_name.clone(),
+                &BTreeMap::default(),
                 self.creation_config(creation_config).as_ref(),
                 Some(self.trace_collector() as Arc<_>),
             )
@@ -410,6 +454,7 @@ mod tests {
                     setup_topic(
                         conn,
                         topic_name,
+                        &BTreeMap::default(),
                         Some(&WriteBufferCreationConfig {
                             n_sequencers: n_partitions,
                             ..Default::default()