refactor: only use struct-style `select` in InfluxQL planner (#7636)
* refactor: only use struct-style `select` in InfluxQL planner For #7533 we need to track more columns apart from `time` and `value` and having a simple variant and multiple complex ones gets overly complicated soon. The aggregator is internally identical anyways, so let's only use one and then pull out the struct fields that we need. I'll also change the InfluxRPC planner to use the struct variant next, so we have a single `select` system both in the planners and in `query_functions`. * docs: improve * docs: explain test Co-authored-by: Andrew Lamb <alamb@influxdata.com> --------- Co-authored-by: Andrew Lamb <alamb@influxdata.com>pull/24376/head
parent
e67ed1fd0b
commit
949d131e77
|
@ -68,9 +68,6 @@ use iox_query::exec::IOxSessionContext;
|
|||
use iox_query::logical_optimizer::range_predicate::find_time_range;
|
||||
use itertools::Itertools;
|
||||
use observability_deps::tracing::debug;
|
||||
use query_functions::selectors::{
|
||||
selector_first, selector_last, selector_max, selector_min, SelectorOutput,
|
||||
};
|
||||
use query_functions::{
|
||||
clean_non_meta_escapes,
|
||||
selectors::{
|
||||
|
@ -1188,36 +1185,19 @@ impl<'a> InfluxQLToLogicalPlan<'a> {
|
|||
return Ok(expr);
|
||||
}
|
||||
|
||||
Ok(
|
||||
if let ProjectionType::Selector { .. } = ctx.info.projection_type {
|
||||
// Selector queries use the `struct_selector_<name>`, as they
|
||||
// will project the value and the time fields of the struct
|
||||
Expr::GetIndexedField(GetIndexedField {
|
||||
expr: Box::new(
|
||||
match name {
|
||||
"first" => struct_selector_first(),
|
||||
"last" => struct_selector_last(),
|
||||
"max" => struct_selector_max(),
|
||||
"min" => struct_selector_min(),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
.call(vec![expr, "time".as_expr()]),
|
||||
),
|
||||
key: ScalarValue::Utf8(Some("value".to_owned())),
|
||||
})
|
||||
} else {
|
||||
// All other queries only require the value of the selector
|
||||
let data_type = &expr.get_type(&schemas.df_schema)?;
|
||||
match name {
|
||||
"first" => selector_first(data_type, SelectorOutput::Value),
|
||||
"last" => selector_last(data_type, SelectorOutput::Value),
|
||||
"max" => selector_max(data_type, SelectorOutput::Value),
|
||||
"min" => selector_min(data_type, SelectorOutput::Value),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
.call(vec![expr, "time".as_expr()])
|
||||
},
|
||||
)
|
||||
let selector_udf = match name {
|
||||
"first" => struct_selector_first(),
|
||||
"last" => struct_selector_last(),
|
||||
"max" => struct_selector_max(),
|
||||
"min" => struct_selector_min(),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
.call(vec![expr, "time".as_expr()]);
|
||||
|
||||
Ok(Expr::GetIndexedField(GetIndexedField {
|
||||
expr: Box::new(selector_udf),
|
||||
key: ScalarValue::Utf8(Some("value".to_owned())),
|
||||
}))
|
||||
}
|
||||
_ => error::query(format!("Invalid function '{name}'")),
|
||||
}
|
||||
|
@ -2752,32 +2732,32 @@ mod test {
|
|||
// aggregate query, as we're grouping by time
|
||||
assert_snapshot!(plan("SELECT LAST(usage_idle) FROM cpu GROUP BY TIME(5s)"), @r###"
|
||||
Sort: time ASC NULLS LAST [iox::measurement:Dictionary(Int32, Utf8), time:Timestamp(Nanosecond, None);N, last:Float64;N]
|
||||
Projection: Dictionary(Int32, Utf8("cpu")) AS iox::measurement, time, selector_last_value(cpu.usage_idle,cpu.time) AS last [iox::measurement:Dictionary(Int32, Utf8), time:Timestamp(Nanosecond, None);N, last:Float64;N]
|
||||
GapFill: groupBy=[[time]], aggr=[[selector_last_value(cpu.usage_idle,cpu.time)]], time_column=time, stride=IntervalMonthDayNano("5000000000"), range=Unbounded..Excluded(now()) [time:Timestamp(Nanosecond, None);N, selector_last_value(cpu.usage_idle,cpu.time):Float64;N]
|
||||
Aggregate: groupBy=[[datebin(IntervalMonthDayNano("5000000000"), cpu.time, TimestampNanosecond(0, None)) AS time]], aggr=[[selector_last_value(cpu.usage_idle, cpu.time)]] [time:Timestamp(Nanosecond, None);N, selector_last_value(cpu.usage_idle,cpu.time):Float64;N]
|
||||
Projection: Dictionary(Int32, Utf8("cpu")) AS iox::measurement, time, (selector_last(cpu.usage_idle,cpu.time))[value] AS last [iox::measurement:Dictionary(Int32, Utf8), time:Timestamp(Nanosecond, None);N, last:Float64;N]
|
||||
GapFill: groupBy=[[time]], aggr=[[selector_last(cpu.usage_idle,cpu.time)]], time_column=time, stride=IntervalMonthDayNano("5000000000"), range=Unbounded..Excluded(now()) [time:Timestamp(Nanosecond, None);N, selector_last(cpu.usage_idle,cpu.time):Struct([Field { name: "value", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "time", data_type: Timestamp(Nanosecond, None), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]);N]
|
||||
Aggregate: groupBy=[[datebin(IntervalMonthDayNano("5000000000"), cpu.time, TimestampNanosecond(0, None)) AS time]], aggr=[[selector_last(cpu.usage_idle, cpu.time)]] [time:Timestamp(Nanosecond, None);N, selector_last(cpu.usage_idle,cpu.time):Struct([Field { name: "value", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "time", data_type: Timestamp(Nanosecond, None), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]);N]
|
||||
TableScan: cpu [cpu:Dictionary(Int32, Utf8);N, host:Dictionary(Int32, Utf8);N, region:Dictionary(Int32, Utf8);N, time:Timestamp(Nanosecond, None), usage_idle:Float64;N, usage_system:Float64;N, usage_user:Float64;N]
|
||||
"###);
|
||||
|
||||
// aggregate query, grouping by time with gap filling
|
||||
assert_snapshot!(plan("SELECT FIRST(usage_idle) FROM cpu GROUP BY TIME(5s) FILL(0)"), @r###"
|
||||
Sort: time ASC NULLS LAST [iox::measurement:Dictionary(Int32, Utf8), time:Timestamp(Nanosecond, None);N, first:Float64;N]
|
||||
Projection: Dictionary(Int32, Utf8("cpu")) AS iox::measurement, time, coalesce(selector_first_value(cpu.usage_idle,cpu.time), Float64(0)) AS first [iox::measurement:Dictionary(Int32, Utf8), time:Timestamp(Nanosecond, None);N, first:Float64;N]
|
||||
GapFill: groupBy=[[time]], aggr=[[selector_first_value(cpu.usage_idle,cpu.time)]], time_column=time, stride=IntervalMonthDayNano("5000000000"), range=Unbounded..Excluded(now()) [time:Timestamp(Nanosecond, None);N, selector_first_value(cpu.usage_idle,cpu.time):Float64;N]
|
||||
Aggregate: groupBy=[[datebin(IntervalMonthDayNano("5000000000"), cpu.time, TimestampNanosecond(0, None)) AS time]], aggr=[[selector_first_value(cpu.usage_idle, cpu.time)]] [time:Timestamp(Nanosecond, None);N, selector_first_value(cpu.usage_idle,cpu.time):Float64;N]
|
||||
Projection: Dictionary(Int32, Utf8("cpu")) AS iox::measurement, time, (coalesce_struct(selector_first(cpu.usage_idle,cpu.time), Struct({value:Float64(0),time:TimestampNanosecond(0, None)})))[value] AS first [iox::measurement:Dictionary(Int32, Utf8), time:Timestamp(Nanosecond, None);N, first:Float64;N]
|
||||
GapFill: groupBy=[[time]], aggr=[[selector_first(cpu.usage_idle,cpu.time)]], time_column=time, stride=IntervalMonthDayNano("5000000000"), range=Unbounded..Excluded(now()) [time:Timestamp(Nanosecond, None);N, selector_first(cpu.usage_idle,cpu.time):Struct([Field { name: "value", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "time", data_type: Timestamp(Nanosecond, None), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]);N]
|
||||
Aggregate: groupBy=[[datebin(IntervalMonthDayNano("5000000000"), cpu.time, TimestampNanosecond(0, None)) AS time]], aggr=[[selector_first(cpu.usage_idle, cpu.time)]] [time:Timestamp(Nanosecond, None);N, selector_first(cpu.usage_idle,cpu.time):Struct([Field { name: "value", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "time", data_type: Timestamp(Nanosecond, None), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]);N]
|
||||
TableScan: cpu [cpu:Dictionary(Int32, Utf8);N, host:Dictionary(Int32, Utf8);N, region:Dictionary(Int32, Utf8);N, time:Timestamp(Nanosecond, None), usage_idle:Float64;N, usage_system:Float64;N, usage_user:Float64;N]
|
||||
"###);
|
||||
|
||||
// aggregate query, as we're specifying multiple selectors or aggregates
|
||||
assert_snapshot!(plan("SELECT LAST(usage_idle), FIRST(usage_idle) FROM cpu"), @r###"
|
||||
Sort: time ASC NULLS LAST [iox::measurement:Dictionary(Int32, Utf8), time:Timestamp(Nanosecond, None), last:Float64;N, first:Float64;N]
|
||||
Projection: Dictionary(Int32, Utf8("cpu")) AS iox::measurement, TimestampNanosecond(0, None) AS time, selector_last_value(cpu.usage_idle,cpu.time) AS last, selector_first_value(cpu.usage_idle,cpu.time) AS first [iox::measurement:Dictionary(Int32, Utf8), time:Timestamp(Nanosecond, None), last:Float64;N, first:Float64;N]
|
||||
Aggregate: groupBy=[[]], aggr=[[selector_last_value(cpu.usage_idle, cpu.time), selector_first_value(cpu.usage_idle, cpu.time)]] [selector_last_value(cpu.usage_idle,cpu.time):Float64;N, selector_first_value(cpu.usage_idle,cpu.time):Float64;N]
|
||||
Projection: Dictionary(Int32, Utf8("cpu")) AS iox::measurement, TimestampNanosecond(0, None) AS time, (selector_last(cpu.usage_idle,cpu.time))[value] AS last, (selector_first(cpu.usage_idle,cpu.time))[value] AS first [iox::measurement:Dictionary(Int32, Utf8), time:Timestamp(Nanosecond, None), last:Float64;N, first:Float64;N]
|
||||
Aggregate: groupBy=[[]], aggr=[[selector_last(cpu.usage_idle, cpu.time), selector_first(cpu.usage_idle, cpu.time)]] [selector_last(cpu.usage_idle,cpu.time):Struct([Field { name: "value", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "time", data_type: Timestamp(Nanosecond, None), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]);N, selector_first(cpu.usage_idle,cpu.time):Struct([Field { name: "value", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "time", data_type: Timestamp(Nanosecond, None), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]);N]
|
||||
TableScan: cpu [cpu:Dictionary(Int32, Utf8);N, host:Dictionary(Int32, Utf8);N, region:Dictionary(Int32, Utf8);N, time:Timestamp(Nanosecond, None), usage_idle:Float64;N, usage_system:Float64;N, usage_user:Float64;N]
|
||||
"###);
|
||||
assert_snapshot!(plan("SELECT LAST(usage_idle), COUNT(usage_idle) FROM cpu"), @r###"
|
||||
Sort: time ASC NULLS LAST [iox::measurement:Dictionary(Int32, Utf8), time:Timestamp(Nanosecond, None), last:Float64;N, count:Int64;N]
|
||||
Projection: Dictionary(Int32, Utf8("cpu")) AS iox::measurement, TimestampNanosecond(0, None) AS time, selector_last_value(cpu.usage_idle,cpu.time) AS last, COUNT(cpu.usage_idle) AS count [iox::measurement:Dictionary(Int32, Utf8), time:Timestamp(Nanosecond, None), last:Float64;N, count:Int64;N]
|
||||
Aggregate: groupBy=[[]], aggr=[[selector_last_value(cpu.usage_idle, cpu.time), COUNT(cpu.usage_idle)]] [selector_last_value(cpu.usage_idle,cpu.time):Float64;N, COUNT(cpu.usage_idle):Int64;N]
|
||||
Projection: Dictionary(Int32, Utf8("cpu")) AS iox::measurement, TimestampNanosecond(0, None) AS time, (selector_last(cpu.usage_idle,cpu.time))[value] AS last, COUNT(cpu.usage_idle) AS count [iox::measurement:Dictionary(Int32, Utf8), time:Timestamp(Nanosecond, None), last:Float64;N, count:Int64;N]
|
||||
Aggregate: groupBy=[[]], aggr=[[selector_last(cpu.usage_idle, cpu.time), COUNT(cpu.usage_idle)]] [selector_last(cpu.usage_idle,cpu.time):Struct([Field { name: "value", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "time", data_type: Timestamp(Nanosecond, None), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]);N, COUNT(cpu.usage_idle):Int64;N]
|
||||
TableScan: cpu [cpu:Dictionary(Int32, Utf8);N, host:Dictionary(Int32, Utf8);N, region:Dictionary(Int32, Utf8);N, time:Timestamp(Nanosecond, None), usage_idle:Float64;N, usage_system:Float64;N, usage_user:Float64;N]
|
||||
"###);
|
||||
|
||||
|
@ -3413,7 +3393,7 @@ mod test {
|
|||
"###);
|
||||
assert_snapshot!(plan("SELECT COUNT(f64_field) FROM data GROUP BY TIME(10s) FILL(0)"), @r###"
|
||||
Sort: time ASC NULLS LAST [iox::measurement:Dictionary(Int32, Utf8), time:Timestamp(Nanosecond, None);N, count:Int64;N]
|
||||
Projection: Dictionary(Int32, Utf8("data")) AS iox::measurement, time, coalesce(COUNT(data.f64_field), Int64(0)) AS count [iox::measurement:Dictionary(Int32, Utf8), time:Timestamp(Nanosecond, None);N, count:Int64;N]
|
||||
Projection: Dictionary(Int32, Utf8("data")) AS iox::measurement, time, coalesce_struct(COUNT(data.f64_field), Int64(0)) AS count [iox::measurement:Dictionary(Int32, Utf8), time:Timestamp(Nanosecond, None);N, count:Int64;N]
|
||||
GapFill: groupBy=[[time]], aggr=[[COUNT(data.f64_field)]], time_column=time, stride=IntervalMonthDayNano("10000000000"), range=Unbounded..Excluded(now()) [time:Timestamp(Nanosecond, None);N, COUNT(data.f64_field):Int64;N]
|
||||
Aggregate: groupBy=[[datebin(IntervalMonthDayNano("10000000000"), data.time, TimestampNanosecond(0, None)) AS time]], aggr=[[COUNT(data.f64_field)]] [time:Timestamp(Nanosecond, None);N, COUNT(data.f64_field):Int64;N]
|
||||
TableScan: data [TIME:Boolean;N, bar:Dictionary(Int32, Utf8);N, bool_field:Boolean;N, f64_field:Float64;N, foo:Dictionary(Int32, Utf8);N, i64_field:Int64;N, mixedCase:Float64;N, str_field:Utf8;N, time:Timestamp(Nanosecond, None), with space:Float64;N]
|
||||
|
@ -3429,7 +3409,7 @@ mod test {
|
|||
// Coalesces the fill value, which is a float, to the matching type of a `COUNT` aggregate.
|
||||
assert_snapshot!(plan("SELECT COUNT(f64_field) FROM data GROUP BY TIME(10s) FILL(3.2)"), @r###"
|
||||
Sort: time ASC NULLS LAST [iox::measurement:Dictionary(Int32, Utf8), time:Timestamp(Nanosecond, None);N, count:Int64;N]
|
||||
Projection: Dictionary(Int32, Utf8("data")) AS iox::measurement, time, coalesce(COUNT(data.f64_field), Int64(3)) AS count [iox::measurement:Dictionary(Int32, Utf8), time:Timestamp(Nanosecond, None);N, count:Int64;N]
|
||||
Projection: Dictionary(Int32, Utf8("data")) AS iox::measurement, time, coalesce_struct(COUNT(data.f64_field), Int64(3)) AS count [iox::measurement:Dictionary(Int32, Utf8), time:Timestamp(Nanosecond, None);N, count:Int64;N]
|
||||
GapFill: groupBy=[[time]], aggr=[[COUNT(data.f64_field)]], time_column=time, stride=IntervalMonthDayNano("10000000000"), range=Unbounded..Excluded(now()) [time:Timestamp(Nanosecond, None);N, COUNT(data.f64_field):Int64;N]
|
||||
Aggregate: groupBy=[[datebin(IntervalMonthDayNano("10000000000"), data.time, TimestampNanosecond(0, None)) AS time]], aggr=[[COUNT(data.f64_field)]] [time:Timestamp(Nanosecond, None);N, COUNT(data.f64_field):Int64;N]
|
||||
TableScan: data [TIME:Boolean;N, bar:Dictionary(Int32, Utf8);N, bool_field:Boolean;N, f64_field:Float64;N, foo:Dictionary(Int32, Utf8);N, i64_field:Int64;N, mixedCase:Float64;N, str_field:Utf8;N, time:Timestamp(Nanosecond, None), with space:Float64;N]
|
||||
|
@ -3438,7 +3418,7 @@ mod test {
|
|||
// Aggregates as part of a binary expression
|
||||
assert_snapshot!(plan("SELECT COUNT(f64_field) + MEAN(f64_field) FROM data GROUP BY TIME(10s) FILL(3.2)"), @r###"
|
||||
Sort: time ASC NULLS LAST [iox::measurement:Dictionary(Int32, Utf8), time:Timestamp(Nanosecond, None);N, count_f64_field_mean_f64_field:Float64;N]
|
||||
Projection: Dictionary(Int32, Utf8("data")) AS iox::measurement, time, coalesce(COUNT(data.f64_field), Int64(3)) + coalesce(AVG(data.f64_field), Float64(3.2)) AS count_f64_field_mean_f64_field [iox::measurement:Dictionary(Int32, Utf8), time:Timestamp(Nanosecond, None);N, count_f64_field_mean_f64_field:Float64;N]
|
||||
Projection: Dictionary(Int32, Utf8("data")) AS iox::measurement, time, coalesce_struct(COUNT(data.f64_field), Int64(3)) + coalesce_struct(AVG(data.f64_field), Float64(3.2)) AS count_f64_field_mean_f64_field [iox::measurement:Dictionary(Int32, Utf8), time:Timestamp(Nanosecond, None);N, count_f64_field_mean_f64_field:Float64;N]
|
||||
GapFill: groupBy=[[time]], aggr=[[COUNT(data.f64_field), AVG(data.f64_field)]], time_column=time, stride=IntervalMonthDayNano("10000000000"), range=Unbounded..Excluded(now()) [time:Timestamp(Nanosecond, None);N, COUNT(data.f64_field):Int64;N, AVG(data.f64_field):Float64;N]
|
||||
Aggregate: groupBy=[[datebin(IntervalMonthDayNano("10000000000"), data.time, TimestampNanosecond(0, None)) AS time]], aggr=[[COUNT(data.f64_field), AVG(data.f64_field)]] [time:Timestamp(Nanosecond, None);N, COUNT(data.f64_field):Int64;N, AVG(data.f64_field):Float64;N]
|
||||
TableScan: data [TIME:Boolean;N, bar:Dictionary(Int32, Utf8);N, bool_field:Boolean;N, f64_field:Float64;N, foo:Dictionary(Int32, Utf8);N, i64_field:Int64;N, mixedCase:Float64;N, str_field:Utf8;N, time:Timestamp(Nanosecond, None), with space:Float64;N]
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
use crate::plan::{error, util_copy};
|
||||
use arrow::datatypes::DataType;
|
||||
use arrow::datatypes::{DataType, TimeUnit};
|
||||
use datafusion::common::{DFSchema, DFSchemaRef, Result};
|
||||
use datafusion::logical_expr::utils::expr_as_column_expr;
|
||||
use datafusion::logical_expr::{coalesce, lit, Expr, ExprSchemable, LogicalPlan, Operator};
|
||||
use datafusion::logical_expr::{lit, Expr, ExprSchemable, LogicalPlan, Operator};
|
||||
use datafusion::scalar::ScalarValue;
|
||||
use influxdb_influxql_parser::expression::BinaryOperator;
|
||||
use influxdb_influxql_parser::literal::Number;
|
||||
use influxdb_influxql_parser::string::Regex;
|
||||
use query_functions::clean_non_meta_escapes;
|
||||
use query_functions::coalesce_struct::coalesce_struct;
|
||||
use schema::Schema;
|
||||
use std::sync::Arc;
|
||||
|
||||
|
@ -55,15 +57,30 @@ pub(crate) fn parse_regex(re: &Regex) -> Result<regex::Regex> {
|
|||
.map_err(|e| error::map::query(format!("invalid regular expression '{re}': {e}")))
|
||||
}
|
||||
|
||||
/// Returns `n` as a literal expression of the specified `data_type`.
|
||||
fn number_to_expr(n: &Number, data_type: DataType) -> Result<Expr> {
|
||||
/// Returns `n` as a scalar value of the specified `data_type`.
|
||||
fn number_to_scalar(n: &Number, data_type: &DataType) -> Result<ScalarValue> {
|
||||
Ok(match (n, data_type) {
|
||||
(Number::Integer(v), DataType::Int64) => lit(*v),
|
||||
(Number::Integer(v), DataType::Float64) => lit(*v as f64),
|
||||
(Number::Integer(v), DataType::UInt64) => lit(*v as u64),
|
||||
(Number::Float(v), DataType::Int64) => lit(*v as i64),
|
||||
(Number::Float(v), DataType::Float64) => lit(*v),
|
||||
(Number::Float(v), DataType::UInt64) => lit(*v as u64),
|
||||
(Number::Integer(v), DataType::Int64) => ScalarValue::from(*v),
|
||||
(Number::Integer(v), DataType::Float64) => ScalarValue::from(*v as f64),
|
||||
(Number::Integer(v), DataType::UInt64) => ScalarValue::from(*v as u64),
|
||||
(Number::Integer(v), DataType::Timestamp(TimeUnit::Nanosecond, tz)) => {
|
||||
ScalarValue::TimestampNanosecond(Some(*v), tz.clone())
|
||||
}
|
||||
(Number::Float(v), DataType::Int64) => ScalarValue::from(*v as i64),
|
||||
(Number::Float(v), DataType::Float64) => ScalarValue::from(*v),
|
||||
(Number::Float(v), DataType::UInt64) => ScalarValue::from(*v as u64),
|
||||
(Number::Float(v), DataType::Timestamp(TimeUnit::Nanosecond, tz)) => {
|
||||
ScalarValue::TimestampNanosecond(Some(*v as i64), tz.clone())
|
||||
}
|
||||
(n, DataType::Struct(fields)) => ScalarValue::Struct(
|
||||
Some(
|
||||
fields
|
||||
.iter()
|
||||
.map(|f| number_to_scalar(n, f.data_type()))
|
||||
.collect::<Result<Vec<_>>>()?,
|
||||
),
|
||||
fields.clone(),
|
||||
),
|
||||
(n, data_type) => {
|
||||
// The only output data types expected are Int64, Float64 or UInt64
|
||||
return error::internal(format!("no conversion from {n} to {data_type}"));
|
||||
|
@ -99,7 +116,10 @@ pub(crate) fn rebase_expr(
|
|||
Ok(if base_exprs.contains(nested_expr) {
|
||||
let col_expr = expr_as_column_expr(nested_expr, plan)?;
|
||||
let data_type = col_expr.get_type(plan.schema())?;
|
||||
Some(coalesce(vec![col_expr, number_to_expr(value, data_type)?]))
|
||||
Some(coalesce_struct(vec![
|
||||
col_expr,
|
||||
lit(number_to_scalar(value, &data_type)?),
|
||||
]))
|
||||
} else {
|
||||
None
|
||||
})
|
||||
|
|
|
@ -0,0 +1,419 @@
|
|||
//! `COALESCE`, but works for structs.
|
||||
//!
|
||||
//! Candidate for upstreaming as per <https://github.com/apache/arrow-datafusion/issues/6074>.
|
||||
//!
|
||||
//! For struct types, this preforms a recursive "first none-null" filling.
|
||||
//!
|
||||
//! For non-struct types (like uint32) this works like the normal `coalesce` function.
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```sql
|
||||
//! coalesce_nested(
|
||||
//! NULL,
|
||||
//! {
|
||||
//! a: 1,
|
||||
//! b: NULL,
|
||||
//! c: NULL,
|
||||
//! d: NULL,
|
||||
//! },
|
||||
//! {
|
||||
//! a: 2,
|
||||
//! b: NULL,
|
||||
//! c: {a: NULL},
|
||||
//! d: {a: 2, b: NULL},
|
||||
//! },
|
||||
//! {
|
||||
//! a: 3,
|
||||
//! b: NULL,
|
||||
//! c: NULL,
|
||||
//! d: {a: 3, b: 3},
|
||||
//! },
|
||||
//! )
|
||||
//!
|
||||
//! =
|
||||
//!
|
||||
//! {
|
||||
//! a: 1,
|
||||
//! b: NULL,
|
||||
//! c: {a: NULL},
|
||||
//! d: {a: 2, b: 3},
|
||||
//! }
|
||||
//! ```
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow::{
|
||||
array::{Array, StructArray},
|
||||
compute::{is_null, kernels::zip::zip},
|
||||
datatypes::DataType,
|
||||
};
|
||||
use datafusion::{
|
||||
common::cast::as_struct_array,
|
||||
error::DataFusionError,
|
||||
logical_expr::{
|
||||
ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF, Signature, Volatility,
|
||||
},
|
||||
physical_plan::ColumnarValue,
|
||||
prelude::Expr,
|
||||
scalar::ScalarValue,
|
||||
};
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
/// The name of the `coalesce_struct` UDF given to DataFusion.
|
||||
pub const COALESCE_STRUCT_UDF_NAME: &str = "coalesce_struct";
|
||||
|
||||
/// Implementation of `coalesce_struct`.
|
||||
///
|
||||
/// See [module-level docs](self) for more information.
|
||||
pub static COALESCE_STRUCT_UDF: Lazy<Arc<ScalarUDF>> = Lazy::new(|| {
|
||||
let return_type: ReturnTypeFunction = Arc::new(move |arg_types| {
|
||||
if arg_types.is_empty() {
|
||||
return Err(DataFusionError::Plan(format!(
|
||||
"{COALESCE_STRUCT_UDF_NAME} expects at least 1 argument"
|
||||
)));
|
||||
}
|
||||
let first_dt = &arg_types[0];
|
||||
|
||||
for (idx, dt) in arg_types.iter().enumerate() {
|
||||
if dt != first_dt {
|
||||
let idx = idx + 1;
|
||||
return Err(DataFusionError::Plan(format!(
|
||||
"{COALESCE_STRUCT_UDF_NAME} expects all arguments to have the same type, but first arg is '{first_dt}' and arg {idx} (1-based) is '{dt}'",
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Arc::new(first_dt.clone()))
|
||||
});
|
||||
|
||||
let fun: ScalarFunctionImplementation = Arc::new(move |args: &[ColumnarValue]| {
|
||||
args.iter().enumerate().fold(Ok(None), |accu, (pos, arg)| {
|
||||
let Some(accu) = accu? else {return Ok(Some(arg.clone()))};
|
||||
|
||||
if accu.data_type() != arg.data_type() {
|
||||
return Err(DataFusionError::Plan(format!(
|
||||
"{} expects all arguments to have the same type, but first arg is '{}' and arg {} (1-based) is '{}'",
|
||||
COALESCE_STRUCT_UDF_NAME,
|
||||
accu.data_type(),
|
||||
pos + 1,
|
||||
arg.data_type(),
|
||||
)));
|
||||
}
|
||||
|
||||
let (array1, array2) = match (accu, arg) {
|
||||
(ColumnarValue::Scalar(scalar1), ColumnarValue::Scalar(scalar2)) => {
|
||||
return Ok(Some(ColumnarValue::Scalar(scalar_coalesce_struct(scalar1, scalar2))));
|
||||
}
|
||||
(ColumnarValue::Scalar(s), ColumnarValue::Array(array2)) => {
|
||||
let array1 = s.to_array_of_size(array2.len());
|
||||
(array1, Arc::clone(array2))
|
||||
}
|
||||
(ColumnarValue::Array(array1), ColumnarValue::Scalar(s)) => {
|
||||
let array2 = s.to_array_of_size(array1.len());
|
||||
(array1, array2)
|
||||
}
|
||||
(ColumnarValue::Array(array1), ColumnarValue::Array(array2)) => {
|
||||
(array1, Arc::clone(array2))
|
||||
}
|
||||
};
|
||||
|
||||
let array = arrow_coalesce_struct(&array1, &array2)?;
|
||||
Ok(Some(ColumnarValue::Array(array)))
|
||||
})?.ok_or_else(|| DataFusionError::Plan(format!(
|
||||
"{COALESCE_STRUCT_UDF_NAME} expects at least 1 argument"
|
||||
)))
|
||||
});
|
||||
|
||||
Arc::new(ScalarUDF::new(
|
||||
COALESCE_STRUCT_UDF_NAME,
|
||||
&Signature::variadic_any(Volatility::Immutable),
|
||||
&return_type,
|
||||
&fun,
|
||||
))
|
||||
});
|
||||
|
||||
/// Recursively fold [`Array`]s.
|
||||
fn arrow_coalesce_struct(
|
||||
array1: &dyn Array,
|
||||
array2: &dyn Array,
|
||||
) -> Result<Arc<dyn Array>, DataFusionError> {
|
||||
if matches!(array1.data_type(), DataType::Struct(_)) {
|
||||
let array1 = as_struct_array(array1)?;
|
||||
let array2 = as_struct_array(array2)?;
|
||||
|
||||
let cols = array1
|
||||
.columns()
|
||||
.iter()
|
||||
.zip(array2.columns())
|
||||
.zip(array1.fields())
|
||||
.map(|((col1, col2), field)| {
|
||||
let out = arrow_coalesce_struct(&col1, &col2)?;
|
||||
// TODO: avoid field clone once https://github.com/apache/arrow-rs/pull/4116 is available
|
||||
Ok((field.as_ref().clone(), out)) as Result<_, DataFusionError>
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let array = StructArray::from(cols);
|
||||
Ok(Arc::new(array))
|
||||
} else {
|
||||
let array = zip(&is_null(array1)?, array2, array1)?;
|
||||
Ok(array)
|
||||
}
|
||||
}
|
||||
|
||||
/// Recursively fold [`ScalarValue`]s.
|
||||
fn scalar_coalesce_struct(scalar1: ScalarValue, scalar2: &ScalarValue) -> ScalarValue {
|
||||
match (scalar1, scalar2) {
|
||||
(ScalarValue::Struct(Some(vals1), fields1), ScalarValue::Struct(Some(vals2), _)) => {
|
||||
let vals = vals1
|
||||
.into_iter()
|
||||
.zip(vals2)
|
||||
.map(|(v1, v2)| scalar_coalesce_struct(v1, v2))
|
||||
.collect();
|
||||
ScalarValue::Struct(Some(vals), fields1)
|
||||
}
|
||||
(scalar1, scalar2) if scalar1.is_null() => scalar2.clone(),
|
||||
(scalar1, _) => scalar1,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create logical `coalesce_struct` expression.
|
||||
///
|
||||
/// See [module-level docs](self) for more information.
|
||||
pub fn coalesce_struct(args: Vec<Expr>) -> Expr {
|
||||
Expr::ScalarUDF {
|
||||
fun: Arc::clone(&COALESCE_STRUCT_UDF),
|
||||
args,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use arrow::{
|
||||
datatypes::{Field, Fields, Schema},
|
||||
record_batch::RecordBatch,
|
||||
};
|
||||
use datafusion::{
|
||||
assert_batches_eq,
|
||||
common::assert_contains,
|
||||
prelude::{col, lit},
|
||||
scalar::ScalarValue,
|
||||
};
|
||||
use datafusion_util::context_with_table;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test() {
|
||||
let fields_b = Fields::from(vec![
|
||||
Field::new("ba", DataType::UInt64, true),
|
||||
Field::new("bb", DataType::UInt64, true),
|
||||
]);
|
||||
let fields = Fields::from(vec![
|
||||
Field::new("a", DataType::UInt64, true),
|
||||
Field::new("b", DataType::Struct(fields_b.clone()), true),
|
||||
]);
|
||||
let dt = DataType::Struct(fields.clone());
|
||||
|
||||
assert_case_ok(
|
||||
[
|
||||
ColumnarValue::Array(ScalarValue::UInt64(None).to_array()),
|
||||
ColumnarValue::Array(ScalarValue::UInt64(Some(1)).to_array()),
|
||||
ColumnarValue::Array(ScalarValue::UInt64(Some(2)).to_array()),
|
||||
],
|
||||
&DataType::UInt64,
|
||||
["+-----+", "| out |", "+-----+", "| 1 |", "+-----+"],
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_case_ok(
|
||||
[ColumnarValue::Array(
|
||||
ScalarValue::Struct(None, fields.clone()).to_array(),
|
||||
)],
|
||||
&dt,
|
||||
["+-----+", "| out |", "+-----+", "| |", "+-----+"],
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_case_ok(
|
||||
[
|
||||
ColumnarValue::Array(ScalarValue::Struct(None, fields.clone()).to_array()),
|
||||
ColumnarValue::Array(
|
||||
ScalarValue::Struct(
|
||||
Some(vec![
|
||||
ScalarValue::UInt64(Some(1)),
|
||||
ScalarValue::Struct(None, fields_b.clone()),
|
||||
]),
|
||||
fields.clone(),
|
||||
)
|
||||
.to_array(),
|
||||
),
|
||||
ColumnarValue::Array(ScalarValue::Struct(None, fields.clone()).to_array()),
|
||||
ColumnarValue::Array(
|
||||
ScalarValue::Struct(
|
||||
Some(vec![
|
||||
ScalarValue::UInt64(Some(2)),
|
||||
ScalarValue::Struct(
|
||||
Some(vec![
|
||||
ScalarValue::UInt64(Some(3)),
|
||||
ScalarValue::UInt64(None),
|
||||
]),
|
||||
fields_b.clone(),
|
||||
),
|
||||
]),
|
||||
fields.clone(),
|
||||
)
|
||||
.to_array(),
|
||||
),
|
||||
],
|
||||
&dt,
|
||||
[
|
||||
"+--------------------------+",
|
||||
"| out |",
|
||||
"+--------------------------+",
|
||||
"| {a: 1, b: {ba: 3, bb: }} |",
|
||||
"+--------------------------+",
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
// same case as above, but with ColumnarValue::Scalar
|
||||
assert_case_ok(
|
||||
[
|
||||
ColumnarValue::Scalar(ScalarValue::Struct(None, fields.clone())),
|
||||
ColumnarValue::Scalar(ScalarValue::Struct(
|
||||
Some(vec![
|
||||
ScalarValue::UInt64(Some(1)),
|
||||
ScalarValue::Struct(None, fields_b.clone()),
|
||||
]),
|
||||
fields.clone(),
|
||||
)),
|
||||
ColumnarValue::Scalar(ScalarValue::Struct(None, fields.clone())),
|
||||
ColumnarValue::Scalar(ScalarValue::Struct(
|
||||
Some(vec![
|
||||
ScalarValue::UInt64(Some(2)),
|
||||
ScalarValue::Struct(
|
||||
Some(vec![
|
||||
ScalarValue::UInt64(Some(3)),
|
||||
ScalarValue::UInt64(None),
|
||||
]),
|
||||
fields_b.clone(),
|
||||
),
|
||||
]),
|
||||
fields.clone(),
|
||||
)),
|
||||
ColumnarValue::Array(ScalarValue::Struct(None, fields.clone()).to_array()),
|
||||
],
|
||||
&dt,
|
||||
[
|
||||
"+--------------------------+",
|
||||
"| out |",
|
||||
"+--------------------------+",
|
||||
"| {a: 1, b: {ba: 3, bb: }} |",
|
||||
"+--------------------------+",
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_case_err(
|
||||
[],
|
||||
&dt,
|
||||
"Error during planning: coalesce_struct expects at least 1 argument",
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_case_err(
|
||||
[ColumnarValue::Array(ScalarValue::Struct(None, fields.clone()).to_array()), ColumnarValue::Array(ScalarValue::Struct(None, fields_b.clone()).to_array())],
|
||||
&dt,
|
||||
"Error during planning: coalesce_struct expects all arguments to have the same type, but first arg is"
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_case_err(
|
||||
[ColumnarValue::Array(ScalarValue::Struct(None, fields.clone()).to_array()), ColumnarValue::Scalar(ScalarValue::Struct(None, fields_b.clone()))],
|
||||
&dt,
|
||||
"Error during planning: coalesce_struct expects all arguments to have the same type, but first arg is"
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_case_err(
|
||||
[ColumnarValue::Scalar(ScalarValue::Struct(None, fields.clone())), ColumnarValue::Array(ScalarValue::Struct(None, fields_b.clone()).to_array())],
|
||||
&dt,
|
||||
"Error during planning: coalesce_struct expects all arguments to have the same type, but first arg is"
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_case_err(
|
||||
[ColumnarValue::Scalar(ScalarValue::Struct(None, fields.clone())), ColumnarValue::Scalar(ScalarValue::Struct(None, fields_b.clone()))],
|
||||
&dt,
|
||||
"Error during planning: coalesce_struct expects all arguments to have the same type, but first arg is"
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn assert_case_ok<const N: usize, const M: usize>(
|
||||
vals: [ColumnarValue; N],
|
||||
dt: &DataType,
|
||||
expected: [&'static str; M],
|
||||
) {
|
||||
let actual = run_plan(vals.to_vec(), dt).await.unwrap();
|
||||
assert_batches_eq!(expected, &actual);
|
||||
}
|
||||
|
||||
async fn assert_case_err<const N: usize>(
|
||||
vals: [ColumnarValue; N],
|
||||
dt: &DataType,
|
||||
expected: &'static str,
|
||||
) {
|
||||
let actual = run_plan(vals.to_vec(), dt).await.unwrap_err();
|
||||
assert_contains!(actual.to_string(), expected);
|
||||
}
|
||||
|
||||
async fn run_plan(
|
||||
vals: Vec<ColumnarValue>,
|
||||
dt: &DataType,
|
||||
) -> Result<Vec<RecordBatch>, DataFusionError> {
|
||||
let col_names = (0..vals.len())
|
||||
.map(|idx| format!("col{idx}"))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let cols = vals
|
||||
.iter()
|
||||
.zip(&col_names)
|
||||
.filter_map(|(val, col_name)| match val {
|
||||
ColumnarValue::Array(a) => Some((col_name.as_str(), Arc::clone(a))),
|
||||
ColumnarValue::Scalar(_) => None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let rb = if cols.is_empty() {
|
||||
RecordBatch::new_empty(Arc::new(Schema::new([])))
|
||||
} else {
|
||||
RecordBatch::try_from_iter(cols.into_iter())?
|
||||
};
|
||||
|
||||
let ctx = context_with_table(rb);
|
||||
let df = ctx.table("t").await?;
|
||||
let df = df.select(vec![coalesce_struct(
|
||||
vals.iter()
|
||||
.zip(col_names)
|
||||
.map(|(val, col_name)| match val {
|
||||
ColumnarValue::Array(_) => col(col_name),
|
||||
ColumnarValue::Scalar(s) => lit(s.clone()),
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
.alias("out")])?;
|
||||
|
||||
// execute the query
|
||||
let batches: Vec<RecordBatch> = df.collect().await?;
|
||||
assert_eq!(batches.len(), 1);
|
||||
assert_eq!(batches[0].num_rows(), 1);
|
||||
|
||||
for batch in &batches {
|
||||
assert_eq!(batch.num_columns(), 1);
|
||||
assert_eq!(batch.column(0).data_type(), dt);
|
||||
}
|
||||
|
||||
Ok(batches)
|
||||
}
|
||||
}
|
|
@ -18,6 +18,8 @@ use datafusion::{
|
|||
use group_by::WindowDuration;
|
||||
use window::EncodedWindowDuration;
|
||||
|
||||
pub mod coalesce_struct;
|
||||
|
||||
/// Grouping by structs
|
||||
pub mod group_by;
|
||||
|
||||
|
|
Loading…
Reference in New Issue