influxdb/datafusion_util/src/lib.rs

452 lines
14 KiB
Rust

#![deny(rustdoc::broken_intra_doc_links, rustdoc::bare_urls, rust_2018_idioms)]
#![allow(clippy::clone_on_ref_ptr)]
//! This module contains various DataFusion utility functions.
//!
//! Almost everything IOx needs for manipulating DataFusion `Expr`s should already be in
//! DataFusion (and anything that is missing should be upstreamed).
//!
//! For example, check out
//! [datafusion_optimizer::utils](https://docs.rs/datafusion-optimizer/13.0.0/datafusion_optimizer/utils/index.html)
//! for expression manipulation functions.
pub mod sender;
pub mod watch;
use std::sync::Arc;
use std::task::{Context, Poll};
use datafusion::arrow::array::BooleanArray;
use datafusion::arrow::compute::filter_record_batch;
use datafusion::arrow::datatypes::DataType;
use datafusion::common::DataFusionError;
use datafusion::datasource::MemTable;
use datafusion::execution::context::TaskContext;
use datafusion::logical_expr::Operator;
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_plan::common::SizedRecordBatchStream;
use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MemTrackingMetrics};
use datafusion::physical_plan::{collect, EmptyRecordBatchStream, ExecutionPlan};
use datafusion::prelude::{col, lit, Expr, SessionContext};
use datafusion::{
arrow::{
datatypes::{Schema, SchemaRef},
error::Result as ArrowResult,
record_batch::RecordBatch,
},
physical_plan::{RecordBatchStream, SendableRecordBatchStream},
scalar::ScalarValue,
};
use futures::{Stream, StreamExt};
use tokio::sync::mpsc::{Receiver, UnboundedReceiver};
use tokio_stream::wrappers::{ReceiverStream, UnboundedReceiverStream};
use watch::WatchedTask;
/// Breaks a conjunction into the individual expressions it joins.
///
/// ```text
/// A ==> Vec[A]
/// A AND B AND C ==> Vec[A, B, C]
/// ```
///
/// An expression that is not an `AND` comes back as the single element of
/// the returned vector.
///
/// TODO rewrite using datafusion `split_conjunction`
pub fn disassemble_conjuct(expr: Expr) -> Vec<Expr> {
    let mut conjuncts = Vec::new();
    disassemble_conjuct_impl(expr, &mut conjuncts);
    conjuncts
}
/// Appends the operands of every nested `AND` in `expr` to `exprs`, in
/// left-to-right order. Anything that is not an `AND` is appended as-is.
fn disassemble_conjuct_impl(expr: Expr, exprs: &mut Vec<Expr>) {
    // Explicit work stack instead of recursion; the most recently pushed
    // expression is processed next.
    let mut pending = vec![expr];
    while let Some(current) = pending.pop() {
        match current {
            Expr::BinaryExpr {
                left,
                op: Operator::And,
                right,
            } => {
                // Push right first so that the left operand is popped (and
                // therefore emitted) before the right one.
                pending.push(*right);
                pending.push(*left);
            }
            leaf => exprs.push(leaf),
        }
    }
}
/// Helper trait for turning a value into a DataFusion [`Expr`].
pub trait AsExpr {
    /// Builds the DataFusion expression for this value.
    fn as_expr(&self) -> Expr;

    /// Builds a DataFusion sort expression over `as_expr()`:
    /// ascending order, with nulls placed first.
    fn as_sort_expr(&self) -> Expr {
        let expr = Box::new(self.as_expr());
        Expr::Sort {
            expr,
            nulls_first: true,
            asc: true,
        }
    }
}
impl AsExpr for Arc<str> {
    /// Uses the shared string as a column name.
    fn as_expr(&self) -> Expr {
        let column_name: &str = self;
        col(column_name)
    }
}
impl AsExpr for str {
/// Uses the string as a column name by handing it to `col`.
fn as_expr(&self) -> Expr {
col(self)
}
}
impl AsExpr for Expr {
    /// An [`Expr`] is already an expression, so this returns a copy of it.
    fn as_expr(&self) -> Expr {
        Clone::clone(self)
    }
}
/// Builds a literal `Expr` holding a dictionary-encoded string (e.g. the
/// kind of constant that a tag would be compared to).
///
/// The dictionary uses `Int32` keys and a UTF-8 value.
pub fn lit_dict(value: &str) -> Expr {
    // expr has been type coerced
    let key_type = Box::new(DataType::Int32);
    let dict_value = Box::new(ScalarValue::new_utf8(value));
    lit(ScalarValue::Dictionary(key_type, dict_value))
}
/// Builds the half-open time range predicate
/// `start <= time AND time < end`,
/// where `time` names the column being compared.
pub fn make_range_expr(start: i64, end: i64, time: impl AsRef<str>) -> Expr {
    let time_column: &str = time.as_ref();

    // The bounds are wrapped as nanosecond timestamp literals (no timezone)
    // rather than plain integers.
    let lower_bound = lit(ScalarValue::TimestampNanosecond(Some(start), None));
    let upper_bound = lit(ScalarValue::TimestampNanosecond(Some(end), None));

    let at_or_after_start = lower_bound.lt_eq(col(time_column));
    let before_end = col(time_column).lt(upper_bound);
    at_or_after_start.and(before_end)
}
/// A RecordBatchStream created from in-memory RecordBatches.
///
/// Batches are yielded in the order in which they were passed to the
/// constructor; every poll is immediately ready.
#[derive(Debug)]
pub struct MemoryStream {
    /// Schema reported by [`RecordBatchStream::schema`].
    schema: SchemaRef,
    /// Batches that have not been yielded yet; the front is the next one.
    ///
    /// Kept in a `VecDeque` so that taking the next batch is O(1) instead of
    /// shifting every remaining element, as `Vec::remove(0)` would.
    batches: std::collections::VecDeque<RecordBatch>,
}

impl MemoryStream {
    /// Create new stream, taking the schema from the first batch.
    ///
    /// # Panics
    /// Panics if `batches` is empty: must at least pass one record batch!
    pub fn new(batches: Vec<RecordBatch>) -> Self {
        assert!(!batches.is_empty(), "must at least pass one record batch");
        let schema = batches[0].schema();
        Self {
            schema,
            batches: batches.into(),
        }
    }

    /// Create new stream with provided schema.
    ///
    /// Unlike [`MemoryStream::new`], `batches` may be empty.
    pub fn new_with_schema(batches: Vec<RecordBatch>, schema: SchemaRef) -> Self {
        Self {
            schema,
            batches: batches.into(),
        }
    }
}

impl RecordBatchStream for MemoryStream {
    /// Returns the schema this stream was created with.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

impl futures::Stream for MemoryStream {
    type Item = ArrowResult<RecordBatch>;

    /// Yields the next stored batch, or `None` once all have been returned.
    /// Never returns `Poll::Pending`.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        _: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        Poll::Ready(self.batches.pop_front().map(Ok))
    }

    /// The number of remaining batches is known exactly.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.batches.len();
        (remaining, Some(remaining))
    }
}
#[derive(Debug)]
/// Implements a [`SendableRecordBatchStream`] to help create DataFusion outputs
/// from tokio channels.
///
/// It yields the RecordBatches received from a tokio channel *and*, crucially,
/// knows up front the schema that each batch will have.
pub struct AdapterStream<T> {
/// Schema reported for every batch produced by this stream
schema: SchemaRef,
/// Inner stream (the wrapped channel receiver) that the batches are read from
inner: T,
/// Handle to the watched task associated with this stream.
///
/// Never read (hence `allow(dead_code)`); presumably held so the task is
/// kept for as long as the stream exists -- confirm against `WatchedTask`.
#[allow(dead_code)]
task: Arc<WatchedTask>,
}
impl AdapterStream<ReceiverStream<ArrowResult<RecordBatch>>> {
    /// Wraps the receiving half `rx` of a bounded channel in a stream whose
    /// [`RecordBatch`]es are declared to have the given `schema`.
    ///
    /// This is named `adapt` rather than `new` because the result is a
    /// pinned, boxed stream instead of the `AdapterStream` value itself.
    pub fn adapt(
        schema: SchemaRef,
        rx: Receiver<ArrowResult<RecordBatch>>,
        task: Arc<WatchedTask>,
    ) -> SendableRecordBatchStream {
        let adapter = Self {
            schema,
            inner: ReceiverStream::new(rx),
            task,
        };
        Box::pin(adapter)
    }
}
impl AdapterStream<UnboundedReceiverStream<ArrowResult<RecordBatch>>> {
    /// Wraps the receiving half `rx` of an unbounded channel in a stream
    /// whose [`RecordBatch`]es are declared to have the given `schema`.
    ///
    /// This is named `adapt_unbounded` rather than `new` because the result
    /// is a pinned, boxed stream instead of the `AdapterStream` value itself.
    pub fn adapt_unbounded(
        schema: SchemaRef,
        rx: UnboundedReceiver<ArrowResult<RecordBatch>>,
        task: Arc<WatchedTask>,
    ) -> SendableRecordBatchStream {
        let adapter = Self {
            schema,
            inner: UnboundedReceiverStream::new(rx),
            task,
        };
        Box::pin(adapter)
    }
}
impl<T> Stream for AdapterStream<T>
where
T: Stream<Item = ArrowResult<RecordBatch>> + Unpin,
{
type Item = ArrowResult<RecordBatch>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
self.inner.poll_next_unpin(cx)
}
}
impl<T> RecordBatchStream for AdapterStream<T>
where
T: Stream<Item = ArrowResult<RecordBatch>> + Unpin,
{
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
}
/// Create a SendableRecordBatchStream from a single RecordBatch
pub fn stream_from_batch(schema: Arc<Schema>, batch: RecordBatch) -> SendableRecordBatchStream {
    let batches = vec![Arc::new(batch)];
    stream_from_batches(schema, batches)
}
/// Create a SendableRecordBatchStream from Vec of RecordBatches with the same schema
///
/// When `batches` is empty the result is an empty stream carrying `schema`.
/// Otherwise the stream reports the schema of the first batch; `schema` is
/// not consulted in that case.
pub fn stream_from_batches(
    schema: Arc<Schema>,
    batches: Vec<Arc<RecordBatch>>,
) -> SendableRecordBatchStream {
    let first_schema = match batches.first() {
        Some(first) => first.schema(),
        None => return Box::pin(EmptyRecordBatchStream::new(schema)),
    };

    // Metrics are required by the stream constructor but are not reported
    // anywhere, so a throwaway set is used.
    let throwaway_metrics = ExecutionPlanMetricsSet::new();
    let tracker = MemTrackingMetrics::new(&throwaway_metrics, 0);
    Box::pin(SizedRecordBatchStream::new(first_schema, batches, tracker))
}
/// Create a SendableRecordBatchStream with the given schema that yields no RecordBatches
pub fn stream_from_schema(schema: SchemaRef) -> SendableRecordBatchStream {
    // Metrics are required by the stream constructor but are not reported
    // anywhere, so a throwaway set is used.
    let throwaway_metrics = ExecutionPlanMetricsSet::new();
    let tracker = MemTrackingMetrics::new(&throwaway_metrics, 0);
    Box::pin(SizedRecordBatchStream::new(schema, Vec::new(), tracker))
}
/// Runs the [ExecutionPlan] using a default [SessionContext] and gathers
/// all of its output in memory.
///
/// # Panics
/// If an error occurs
pub async fn test_collect(plan: Arc<dyn ExecutionPlan>) -> Vec<RecordBatch> {
    let task_ctx = Arc::new(TaskContext::from(&SessionContext::new()));
    let collected = collect(plan, task_ctx).await;
    collected.unwrap()
}
/// Starts the given partition of the [ExecutionPlan] using a default
/// [SessionContext] and hands back the resulting stream.
///
/// # Panics
/// If an error occurs
pub async fn test_execute_partition(
    plan: Arc<dyn ExecutionPlan>,
    partition: usize,
) -> SendableRecordBatchStream {
    let task_ctx = Arc::new(TaskContext::from(&SessionContext::new()));
    let stream = plan.execute(partition, task_ctx);
    stream.unwrap()
}
/// Runs the given partition of the [ExecutionPlan] using a default
/// [SessionContext] and gathers that partition's output in memory.
///
/// # Panics
/// If an error occurs
pub async fn test_collect_partition(
    plan: Arc<dyn ExecutionPlan>,
    partition: usize,
) -> Vec<RecordBatch> {
    let partition_stream = test_execute_partition(plan, partition).await;
    let collected = datafusion::physical_plan::common::collect(partition_stream).await;
    collected.unwrap()
}
/// Keeps only the rows of `batch` for which `predicate` evaluates to true.
///
/// Returns an error if evaluating the predicate fails, or if it evaluates to
/// something other than a boolean array.
///
/// Borrowed from DF's <https://github.com/apache/arrow-datafusion/blob/ecd0081bde98e9031b81aa6e9ae2a4f309fcec12/datafusion/src/physical_plan/filter.rs#L186>.
// TODO: if we make DF batch_filter public, we can call that function directly
pub fn batch_filter(
    batch: &RecordBatch,
    predicate: &Arc<dyn PhysicalExpr>,
) -> ArrowResult<RecordBatch> {
    // Evaluate the predicate and expand the result to one value per row.
    let evaluated = predicate.evaluate(batch)?.into_array(batch.num_rows());

    let mask = evaluated
        .as_any()
        .downcast_ref::<BooleanArray>()
        .ok_or_else(|| {
            DataFusionError::Internal(
                "Filter predicate evaluated to non-boolean value".to_string(),
            )
        })?;

    // apply filter array to record batch
    filter_record_batch(batch, mask)
}
/// Builds a DataFusion [`SessionContext`] in which the given RecordBatch is
/// registered as the table `t`.
///
/// # Panics
/// If the in-memory table cannot be created or registered.
pub fn context_with_table(batch: RecordBatch) -> SessionContext {
    // A single partition holding just this one batch.
    let table = MemTable::try_new(batch.schema(), vec![vec![batch]]).unwrap();

    let session = SessionContext::new();
    session.register_table("t", Arc::new(table)).unwrap();
    session
}
/// Returns a schema with the same fields and metadata as `schema`, except
/// that every field is nullable.
///
/// If all fields are nullable already, the input is returned unchanged.
pub fn nullable_schema(schema: SchemaRef) -> SchemaRef {
    let has_non_nullable = schema.fields().iter().any(|field| !field.is_nullable());
    if !has_non_nullable {
        // nothing to change
        return schema;
    }

    // Copy each field, allowing nulls in the copy.
    let nullable_fields = schema
        .fields()
        .iter()
        .map(|field| field.clone().with_nullable(true))
        .collect();
    let metadata = schema.metadata().clone();
    Arc::new(Schema::new_with_metadata(nullable_fields, metadata))
}
#[cfg(test)]
mod tests {
    use datafusion::arrow::datatypes::{DataType, Field};
    use schema::builder::SchemaBuilder;

    use super::*;

    /// The range predicate renders as `start <= time AND time < end`.
    #[test]
    fn test_make_range_expr() {
        let expected =
            "TimestampNanosecond(101, None) <= time AND time < TimestampNanosecond(202, None)";

        let predicate = make_range_expr(101, 202, "time");
        let rendered = format!("{:?}", predicate);

        assert_eq!(rendered, expected);
    }

    /// A schema whose fields are all nullable is passed through unchanged.
    #[test]
    fn test_nullable_schema_nullable() {
        let fields = vec![
            Field::new("foo", DataType::Int32, true),
            Field::new("bar", DataType::Utf8, true),
        ];
        let schema = Arc::new(Schema::new(fields));

        let converted = nullable_schema(Arc::clone(&schema));

        assert_eq!(schema, converted)
    }

    /// A non-nullable field is made nullable; nullable fields are untouched.
    #[test]
    fn test_nullable_schema_non_nullable() {
        // schema has one non-nullable column
        let input = Arc::new(Schema::new(vec![
            Field::new("foo", DataType::Int32, false),
            Field::new("bar", DataType::Utf8, true),
        ]));
        let all_nullable = Arc::new(Schema::new(vec![
            Field::new("foo", DataType::Int32, true),
            Field::new("bar", DataType::Utf8, true),
        ]));

        assert_eq!(all_nullable, nullable_schema(input))
    }

    /// A panic in the watched task surfaces as an error when the adapted
    /// stream is collected.
    #[tokio::test]
    async fn test_adapter_stream_panic_handling() {
        let schema = SchemaBuilder::new().timestamp().build().unwrap().as_arrow();
        let (tx, rx) = tokio::sync::mpsc::channel(2);

        // The task owns its own sender clone and panics before sending.
        let task_tx = tx.clone();
        let panicking_task = async move {
            let _tx = task_tx;
            // The always-true condition keeps the trailing `Ok(())` formally
            // reachable, which gives the future its `Result` output type.
            if true {
                panic!("epic fail");
            }
            Ok(())
        };

        let watched = WatchedTask::new(panicking_task, vec![tx], "test");
        let stream = AdapterStream::adapt(schema, rx, watched);

        let outcome = datafusion::physical_plan::common::collect(stream).await;
        outcome.unwrap_err();
    }
}