diff --git a/iox_query/src/exec/context.rs b/iox_query/src/exec/context.rs
index d78a593723..b7dcc2bd2c 100644
--- a/iox_query/src/exec/context.rs
+++ b/iox_query/src/exec/context.rs
@@ -32,6 +32,7 @@ use datafusion::{
     execution::{
         context::{QueryPlanner, SessionState, TaskContext},
         runtime_env::RuntimeEnv,
+        MemoryManager,
     },
     logical_expr::{LogicalPlan, UserDefinedLogicalNode},
     physical_plan::{
@@ -411,6 +412,7 @@ impl IOxSessionContext {
     pub async fn to_series_and_groups(
         &self,
         series_set_plans: SeriesSetPlans,
+        memory_manager: Arc<MemoryManager>,
     ) -> Result<impl Stream<Item = Result<Either, DataFusionError>>> {
         let SeriesSetPlans {
             mut plans,
@@ -471,7 +473,7 @@ impl IOxSessionContext {
         // If we have group columns, sort the results, and create the
         // appropriate groups
         if let Some(group_columns) = group_columns {
-            let grouper = GroupGenerator::new(group_columns);
+            let grouper = GroupGenerator::new(group_columns, memory_manager);
             Ok(grouper.group(data).await?.boxed())
         } else {
             Ok(data.map_ok(|series| series.into()).boxed())
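Every caller of `to_series_and_groups` now has to hand over the session's `MemoryManager`. A minimal sketch of the resulting call shape, mirroring the `query_tests` and `service_grpc_influxrpc` call sites further down in this patch (the crate-internal imports for `IOxSessionContext` and `SeriesSetPlans` are elided here):

```rust
use std::sync::Arc;

use datafusion::error::DataFusionError;
use futures::TryStreamExt;

// Sketch only: `IOxSessionContext` and `SeriesSetPlans` come from `iox_query`.
async fn series_as_strings(
    ctx: &IOxSessionContext,
    plans: SeriesSetPlans,
) -> Result<Vec<String>, DataFusionError> {
    // The memory manager lives in the session's DataFusion `RuntimeEnv`; cloning the
    // `Arc` is cheap and lets the series conversion register its allocations there.
    let memory_manager = Arc::clone(&ctx.inner().runtime_env().memory_manager);

    ctx.to_series_and_groups(plans, memory_manager)
        .await?
        .map_ok(|series_or_group| series_or_group.to_string())
        .try_collect()
        .await
}
```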
diff --git a/iox_query/src/exec/seriesset/converter.rs b/iox_query/src/exec/seriesset/converter.rs
index 81a1c7a9a2..b9bea417ec 100644
--- a/iox_query/src/exec/seriesset/converter.rs
+++ b/iox_query/src/exec/seriesset/converter.rs
@@ -9,13 +9,21 @@ use arrow::{
     datatypes::{DataType, Int32Type, SchemaRef},
     record_batch::RecordBatch,
 };
-use datafusion::{error::DataFusionError, physical_plan::SendableRecordBatchStream};
+use datafusion::{
+    error::DataFusionError,
+    execution::{
+        memory_manager::proxy::{MemoryConsumerProxy, VecAllocExt},
+        MemoryConsumerId, MemoryManager,
+    },
+    physical_plan::SendableRecordBatchStream,
+};

-use futures::{ready, Stream, StreamExt, TryStreamExt};
+use futures::{future::BoxFuture, ready, FutureExt, Stream, StreamExt};
 use predicate::rpc_predicate::{GROUP_KEY_SPECIAL_START, GROUP_KEY_SPECIAL_STOP};
 use snafu::{OptionExt, Snafu};
 use std::{
     collections::VecDeque,
+    future::Future,
     pin::Pin,
     sync::Arc,
     task::{Context, Poll},
@@ -199,7 +207,7 @@ impl SeriesSetConverter {
     ) -> Vec<(Arc<str>, Arc<str>)> {
         assert_eq!(tag_column_names.len(), tag_indexes.len());

-        tag_column_names
+        let mut out = tag_column_names
             .iter()
             .zip(tag_indexes)
             .filter_map(|(column_name, column_index)| {
@@ -246,7 +254,10 @@ impl SeriesSetConverter {

                 tag_value.map(|tag_value| (Arc::clone(column_name), Arc::from(tag_value.as_str())))
             })
-            .collect()
+            .collect::<Vec<_>>();
+
+        out.shrink_to_fit();
+        out
     }
 }

@@ -491,32 +502,29 @@ impl Stream for SeriesSetConverterStream {
 #[derive(Debug)]
 pub struct GroupGenerator {
     group_columns: Vec<Arc<str>>,
+    memory_manager: Arc<MemoryManager>,
 }

 impl GroupGenerator {
-    pub fn new(group_columns: Vec<Arc<str>>) -> Self {
-        Self { group_columns }
+    pub fn new(group_columns: Vec<Arc<str>>, memory_manager: Arc<MemoryManager>) -> Self {
+        Self {
+            group_columns,
+            memory_manager,
+        }
     }

     /// groups the set of `series` into SeriesOrGroups
     ///
-    /// TODO: make this truly stream-based
+    /// TODO: make this truly stream-based, see .
     pub async fn group<S>(
-        &self,
+        self,
         series: S,
     ) -> Result<impl Stream<Item = Result<Either, DataFusionError>>, DataFusionError>
     where
         S: Stream<Item = Result<Series, DataFusionError>> + Send,
     {
-        let mut series = series
-            .map(|res| {
-                res.and_then(|series| {
-                    SortableSeries::try_new(series, &self.group_columns)
-                        .map_err(|e| DataFusionError::External(Box::new(e)))
-                })
-            })
-            .try_collect::<Vec<_>>()
-            .await?;
+        let series = Box::pin(series);
+        let mut series = Collector::new(series, self.group_columns, self.memory_manager).await?;

         // Potential optimization is to skip this sort if we are
         // grouping by a prefix of the tags for a single measurement
@@ -658,12 +666,175 @@ impl SortableSeries {
                 use_tag.then(|| Arc::clone(&tag.value))
             }));

+        // save memory
+        tag_vals.shrink_to_fit();
+
         Ok(Self {
             series,
             tag_vals,
             num_partition_keys: group_columns.len(),
         })
     }
+
+    /// Memory usage in bytes, including `self`.
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self) + self.series.size() - std::mem::size_of_val(&self.series)
+            + (std::mem::size_of::<Arc<str>>() * self.tag_vals.capacity())
+            + self.tag_vals.iter().map(|s| s.len()).sum::<usize>()
+    }
+}
+
+/// [`Future`] that collects [`Series`] objects into a [`SortableSeries`] vector while registering/checking memory
+/// allocations with a [`MemoryManager`].
+///
+/// This avoids unbounded memory growth when merging multiple `Series` in memory.
+struct Collector<S>
+where
+    S: Stream<Item = Result<Series, DataFusionError>> + Send + Unpin,
+{
+    /// The inner stream was fully drained.
+    inner_done: bool,
+
+    /// This very future finished.
+    outer_done: bool,
+
+    /// Inner stream.
+    inner: S,
+
+    /// Group columns.
+    ///
+    /// These are required for [`SortableSeries::try_new`].
+    group_columns: Vec<Arc<str>>,
+
+    /// Already collected objects.
+    collected: Vec<SortableSeries>,
+
+    /// Buffered but not-yet-registered allocated size.
+    ///
+    /// We use an additional buffer here because in contrast to the normal DataFusion processing, the input stream is
+    /// NOT batched and we want to avoid costly memory allocation checks with the [`MemoryManager`] for every single element.
+    buffered_size: usize,
+
+    /// Our memory consumer.
+    ///
+    /// This is optional because for [`MemoryConsumerProxy::alloc`], we need to move this into
+    /// [`mem_proxy_alloc_fut`](Self::mem_proxy_alloc_fut) to avoid self-borrowing.
+    mem_proxy: Option<MemoryConsumerProxy>,
+
+    /// A potentially running [`MemoryConsumerProxy::alloc`].
+    ///
+    /// This owns [`mem_proxy`](Self::mem_proxy) to avoid self-borrowing.
+    mem_proxy_alloc_fut:
+        Option<BoxFuture<'static, (MemoryConsumerProxy, Result<(), DataFusionError>)>>,
+}
+
+impl<S> Collector<S>
+where
+    S: Stream<Item = Result<Series, DataFusionError>> + Send + Unpin,
+{
+    /// Maximum [buffered size](Self::buffered_size).
+    const ALLOCATION_BUFFER_SIZE: usize = 1024 * 1024;
+
+    fn new(inner: S, group_columns: Vec<Arc<str>>, memory_manager: Arc<MemoryManager>) -> Self {
+        let mem_proxy =
+            MemoryConsumerProxy::new("Collector stream", MemoryConsumerId::new(0), memory_manager);
+        Self {
+            inner_done: false,
+            outer_done: false,
+            inner,
+            group_columns,
+            collected: Vec::with_capacity(0),
+            buffered_size: 0,
+            mem_proxy: Some(mem_proxy),
+            mem_proxy_alloc_fut: None,
+        }
+    }
+
+    /// Start a [`MemoryConsumerProxy::alloc`] future.
+    ///
+    /// # Panics
+    /// Panics if a future is already running.
+    fn alloc(&mut self) {
+        assert!(self.mem_proxy_alloc_fut.is_none());
+        let mut mem_proxy =
+            std::mem::take(&mut self.mem_proxy).expect("no mem proxy future running");
+        let bytes = std::mem::take(&mut self.buffered_size);
+        self.mem_proxy_alloc_fut = Some(
+            async move {
+                let res = mem_proxy.alloc(bytes).await;
+                (mem_proxy, res)
+            }
+            .boxed(),
+        );
+    }
+}
+
+impl<S> Future for Collector<S>
+where
+    S: Stream<Item = Result<Series, DataFusionError>> + Send + Unpin,
+{
+    type Output = Result<Vec<SortableSeries>, DataFusionError>;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        let this = &mut *self;
+
+        loop {
+            assert!(!this.outer_done);
+
+            // Drive `MemoryConsumerProxy::alloc` to completion.
+            if let Some(fut) = this.mem_proxy_alloc_fut.as_mut() {
+                let (mem_proxy, res) = ready!(fut.poll_unpin(cx));
+                this.mem_proxy_alloc_fut = None; // consumed; must not be polled again
+                assert!(this.mem_proxy.is_none());
+                this.mem_proxy = Some(mem_proxy);
+                if let Err(e) = res {
+                    // poison this future
+                    this.outer_done = true;
+                    return Poll::Ready(Err(e));
+                }
+            }
+
+            // if the underlying stream is drained and the allocation future is ready (see above), we can finalize this future
+            if this.inner_done {
+                this.outer_done = true;
+                return Poll::Ready(Ok(std::mem::take(&mut this.collected)));
+            }
+
+            match ready!(this.inner.poll_next_unpin(cx)) {
+                Some(Ok(series)) => match SortableSeries::try_new(series, &this.group_columns) {
+                    Ok(series) => {
+                        // Note: the size of `SortableSeries` itself is already included in the vector allocation
+                        this.buffered_size += series.size() - std::mem::size_of_val(&series);
+                        this.collected
+                            .push_accounted(series, &mut this.buffered_size);
+
+                        // should we clear our allocation buffer?
+                        if this.buffered_size > Self::ALLOCATION_BUFFER_SIZE {
+                            this.alloc();
+                            continue;
+                        }
+                    }
+                    Err(e) => {
+                        // poison this future
+                        this.outer_done = true;
+                        return Poll::Ready(Err(DataFusionError::External(Box::new(e))));
+                    }
+                },
+                Some(Err(e)) => {
+                    // poison this future
+                    this.outer_done = true;
+                    return Poll::Ready(Err(e));
+                }
+                None => {
+                    // underlying stream drained, now register the final allocation and then we're done
+                    this.inner_done = true;
+                    if this.buffered_size > 0 {
+                        this.alloc();
+                    }
+                    continue;
+                }
+            }
+        }
+    }
 }

 #[cfg(test)]
@@ -677,10 +848,15 @@ mod tests {
         record_batch::RecordBatch,
     };
     use arrow_util::assert_batches_eq;
+    use assert_matches::assert_matches;
+    use datafusion::execution::memory_manager::MemoryManagerConfig;
     use datafusion_util::{stream_from_batch, stream_from_batches, stream_from_schema};
+    use futures::TryStreamExt;
     use itertools::Itertools;
     use test_helpers::str_vec_to_arc_vec;

+    use crate::exec::seriesset::series::{Data, Tag};
+
     use super::*;

     #[tokio::test]
@@ -1431,6 +1607,28 @@ mod tests {
         );
     }

+    #[tokio::test]
+    async fn test_group_generator_mem_limit() {
+        let memory_manager =
+            MemoryManager::new(MemoryManagerConfig::try_new_limit(1, 1.0).unwrap());
+        let ggen = GroupGenerator::new(vec![Arc::from("g")], memory_manager);
+        let input = futures::stream::iter([Ok(Series {
+            tags: vec![Tag {
+                key: Arc::from("g"),
+                value: Arc::from("x"),
+            }],
+            data: Data::FloatPoints {
+                timestamps: vec![],
+                values: vec![],
+            },
+        })]);
+        let err = match ggen.group(input).await {
+            Ok(stream) => stream.try_collect::<Vec<_>>().await.unwrap_err(),
+            Err(e) => e,
+        };
+        assert_matches!(err, DataFusionError::ResourcesExhausted(_));
+    }
+
     fn assert_series_set(
         set: &SeriesSet,
         table_name: &'static str,
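The heart of the converter change is the buffered accounting in `Collector`: per-element sizes are summed locally and only handed to the `MemoryConsumerProxy` once `ALLOCATION_BUFFER_SIZE` is exceeded (or when the stream ends). A standalone sketch of the failure path that the new `test_group_generator_mem_limit` test exercises, using only the DataFusion items imported above (pool size and byte count are made up for illustration; assumes a tokio runtime):

```rust
use datafusion::{
    error::DataFusionError,
    execution::{
        memory_manager::{proxy::MemoryConsumerProxy, MemoryManagerConfig},
        MemoryConsumerId, MemoryManager,
    },
};

#[tokio::main]
async fn main() {
    // A deliberately tiny pool (1 byte), like the new unit test: any buffered size must be rejected.
    let memory_manager = MemoryManager::new(MemoryManagerConfig::try_new_limit(1, 1.0).unwrap());
    let mut proxy =
        MemoryConsumerProxy::new("Collector stream", MemoryConsumerId::new(0), memory_manager);

    // This mirrors what `Collector::alloc` does once `buffered_size` grows past
    // `ALLOCATION_BUFFER_SIZE` (or on finalization): register the buffered bytes with the
    // memory manager and surface the error instead of growing without bounds.
    let err = proxy.alloc(1024 * 1024).await.unwrap_err();
    assert!(matches!(err, DataFusionError::ResourcesExhausted(_)));
}
```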
diff --git a/iox_query/src/exec/seriesset/series.rs b/iox_query/src/exec/seriesset/series.rs
index f467bfa470..a858703b89 100644
--- a/iox_query/src/exec/seriesset/series.rs
+++ b/iox_query/src/exec/seriesset/series.rs
@@ -34,6 +34,13 @@ pub struct Tag {
     pub value: Arc<str>,
 }

+impl Tag {
+    /// Memory usage in bytes, including `self`.
+    pub fn size(&self) -> usize {
+        std::mem::size_of_val(self) + self.key.len() + self.value.len()
+    }
+}
+
 impl fmt::Display for Tag {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         write!(f, "{}={}", self.key, self.value)
@@ -51,6 +58,21 @@ pub struct Series {
     pub data: Data,
 }

+impl Series {
+    /// Memory usage in bytes, including `self`.
+    pub fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+            + (std::mem::size_of::<Tag>() * self.tags.capacity())
+            + self
+                .tags
+                .iter()
+                .map(|tag| tag.size() - std::mem::size_of_val(tag))
+                .sum::<usize>()
+            + self.data.size()
+            - std::mem::size_of_val(&self.data)
+    }
+}
+
 impl fmt::Display for Series {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         write!(f, "Series tags={{")?;
@@ -97,6 +119,35 @@ pub enum Data {
     },
 }

+impl Data {
+    /// Memory usage in bytes, including `self`.
+    pub fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+            + match self {
+                Self::FloatPoints { timestamps, values } => {
+                    primitive_vec_size(timestamps) + primitive_vec_size(values)
+                }
+                Self::IntegerPoints { timestamps, values } => {
+                    primitive_vec_size(timestamps) + primitive_vec_size(values)
+                }
+                Self::UnsignedPoints { timestamps, values } => {
+                    primitive_vec_size(timestamps) + primitive_vec_size(values)
+                }
+                Self::BooleanPoints { timestamps, values } => {
+                    primitive_vec_size(timestamps) + primitive_vec_size(values)
+                }
+                Self::StringPoints { timestamps, values } => {
+                    primitive_vec_size(timestamps) + primitive_vec_size(values)
+                }
+            }
+    }
+}
+
+/// Returns the size of the given vector of primitive types in bytes, EXCLUDING `vec` itself.
+fn primitive_vec_size<T>(vec: &Vec<T>) -> usize {
+    std::mem::size_of::<T>() * vec.capacity()
+}
+
 impl fmt::Display for Data {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         match self {
@@ -174,7 +225,7 @@ impl SeriesSet {

         let tags = self.create_frame_tags(schema.field(index.value_index).name());

-        let timestamps = compute::nullif(
+        let mut timestamps = compute::nullif(
             batch.column(index.timestamp_index),
             &compute::is_null(array).expect("is_null"),
         )
@@ -183,47 +234,57 @@ impl SeriesSet {
         .downcast_ref::<TimestampNanosecondArray>()
         .unwrap()
         .extract_values();
+        timestamps.shrink_to_fit();

         let data = match array.data_type() {
             ArrowDataType::Utf8 => {
-                let values = array
+                let mut values = array
                     .as_any()
                     .downcast_ref::<StringArray>()
                     .unwrap()
                     .extract_values();
+                values.shrink_to_fit();
+
                 Data::StringPoints { timestamps, values }
             }
             ArrowDataType::Float64 => {
-                let values = array
+                let mut values = array
                     .as_any()
                     .downcast_ref::<Float64Array>()
                     .unwrap()
                     .extract_values();
+                values.shrink_to_fit();
                 Data::FloatPoints { timestamps, values }
             }
             ArrowDataType::Int64 => {
-                let values = array
+                let mut values = array
                     .as_any()
                     .downcast_ref::<Int64Array>()
                     .unwrap()
                     .extract_values();
+                values.shrink_to_fit();
+
                 Data::IntegerPoints { timestamps, values }
             }
             ArrowDataType::UInt64 => {
-                let values = array
+                let mut values = array
                     .as_any()
                     .downcast_ref::<UInt64Array>()
                     .unwrap()
                     .extract_values();
+                values.shrink_to_fit();
+
                 Data::UnsignedPoints { timestamps, values }
             }
             ArrowDataType::Boolean => {
-                let values = array
+                let mut values = array
                     .as_any()
                     .downcast_ref::<BooleanArray>()
                     .unwrap()
                     .extract_values();
+                values.shrink_to_fit();
+
                 Data::BooleanPoints { timestamps, values }
             }
             _ => {
diff --git a/query_tests/src/influxrpc/util.rs b/query_tests/src/influxrpc/util.rs
index 80e726dbaf..965cf1f6a0 100644
--- a/query_tests/src/influxrpc/util.rs
+++ b/query_tests/src/influxrpc/util.rs
@@ -22,9 +22,11 @@ pub async fn run_series_set_plan_maybe_error(
     ctx: &IOxSessionContext,
     plans: SeriesSetPlans,
 ) -> Result<Vec<String>, DataFusionError> {
+    use std::sync::Arc;
+
     use futures::TryStreamExt;

-    ctx.to_series_and_groups(plans)
+    ctx.to_series_and_groups(plans, Arc::clone(&ctx.inner().runtime_env().memory_manager))
         .await?
         .map_ok(|series_or_group| series_or_group.to_string())
         .try_collect()
diff --git a/service_grpc_influxrpc/src/service.rs b/service_grpc_influxrpc/src/service.rs
index 0c2448698f..f86cf7abd6 100644
--- a/service_grpc_influxrpc/src/service.rs
+++ b/service_grpc_influxrpc/src/service.rs
@@ -1348,7 +1348,10 @@ where
     // Execute the plans.
     let db_name = db_name.to_owned();
     let series_or_groups = ctx
-        .to_series_and_groups(series_plan)
+        .to_series_and_groups(
+            series_plan,
+            Arc::clone(&ctx.inner().runtime_env().memory_manager),
+        )
         .await
         .context(FilteringSeriesSnafu {
             db_name: db_name.clone(),
@@ -1413,7 +1416,10 @@ where
     // Execute the plans
     let db_name = db_name.to_owned();
     let series_or_groups = ctx
-        .to_series_and_groups(grouped_series_set_plan)
+        .to_series_and_groups(
+            grouped_series_set_plan,
+            Arc::clone(&ctx.inner().runtime_env().memory_manager),
+        )
         .await
         .context(GroupingSeriesSnafu {
             db_name: db_name.clone(),
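A note on the sizing convention introduced in `series.rs` and `converter.rs`: every `size()` includes `self`, heap buffers are measured via `capacity()` (hence the `shrink_to_fit()` calls), and a parent that embeds a child by value subtracts `size_of_val(&child)` again so the child's inline bytes are not counted twice. A toy illustration of that convention (the `Parent`/`Child` types are invented for this sketch and are not part of the patch):

```rust
use std::sync::Arc;

struct Child {
    payload: Vec<u8>,
}

impl Child {
    /// Memory usage in bytes, including `self`.
    fn size(&self) -> usize {
        std::mem::size_of_val(self) + self.payload.capacity()
    }
}

struct Parent {
    name: Arc<str>,
    child: Child,
}

impl Parent {
    /// Memory usage in bytes, including `self`.
    fn size(&self) -> usize {
        // `child.size()` already counts the child's inline bytes, but those are part of
        // `size_of_val(self)` as well, so subtract them once (same pattern as `Series::size`).
        std::mem::size_of_val(self) + self.name.len() + self.child.size()
            - std::mem::size_of_val(&self.child)
    }
}

fn main() {
    let p = Parent {
        name: Arc::from("cpu"),
        child: Child {
            payload: Vec::with_capacity(128),
        },
    };
    // Inline struct bytes + 3 bytes of string payload + the vector's heap capacity.
    assert_eq!(
        p.size(),
        std::mem::size_of::<Parent>() + 3 + p.child.payload.capacity()
    );
}
```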