feat: selector last/min/max w/ other values (#7977)
* fix: size calculation for `selector_first` * test: extract common error tests * feat: additional args for `selector_last` * refactor: de-dup code * fix: break tie for first/last selector * feat: additional args for `selector_min` * feat: additional args for `selector_max` * fix: use same tie-breaker * refactor: de-dup code * refactor: simplify code --------- Co-authored-by: kodiakhq[bot] <49736102+kodiakhq[bot]@users.noreply.github.com>pull/24376/head
parent
7fef809b2a
commit
1e1488aad0
|
@ -97,17 +97,16 @@
|
|||
//! [selector functions]: https://docs.influxdata.com/influxdb/v1.8/query_language/functions/#selectors
|
||||
use std::{fmt::Debug, sync::Arc};
|
||||
|
||||
use arrow::{array::ArrayRef, datatypes::DataType};
|
||||
use arrow::datatypes::DataType;
|
||||
use datafusion::{
|
||||
error::{DataFusionError, Result as DataFusionResult},
|
||||
error::Result as DataFusionResult,
|
||||
logical_expr::{AccumulatorFunctionImplementation, Signature, Volatility},
|
||||
physical_plan::{udaf::AggregateUDF, Accumulator},
|
||||
prelude::SessionContext,
|
||||
scalar::ScalarValue,
|
||||
};
|
||||
|
||||
mod internal;
|
||||
use internal::{FirstSelector, LastSelector, MaxSelector, MinSelector, Selector};
|
||||
use internal::{Comparison, Selector, Target};
|
||||
|
||||
mod type_handling;
|
||||
use type_handling::AggType;
|
||||
|
@ -228,34 +227,30 @@ impl FactoryBuilder {
|
|||
let other_types = agg_type.other_types;
|
||||
|
||||
let accumulator: Box<dyn Accumulator> = match selector_type {
|
||||
SelectorType::First => Box::new(SelectorAccumulator::new(FirstSelector::new(
|
||||
SelectorType::First => Box::new(Selector::new(
|
||||
Comparison::Min,
|
||||
Target::Time,
|
||||
value_type,
|
||||
other_types.iter().cloned(),
|
||||
)?)),
|
||||
SelectorType::Last => {
|
||||
if !other_types.is_empty() {
|
||||
return Err(DataFusionError::NotImplemented(
|
||||
"selector last w/ additional args".to_string(),
|
||||
));
|
||||
}
|
||||
Box::new(SelectorAccumulator::new(LastSelector::new(value_type)?))
|
||||
}
|
||||
SelectorType::Min => {
|
||||
if !other_types.is_empty() {
|
||||
return Err(DataFusionError::NotImplemented(
|
||||
"selector min w/ additional args".to_string(),
|
||||
));
|
||||
}
|
||||
Box::new(SelectorAccumulator::new(MinSelector::new(value_type)?))
|
||||
}
|
||||
SelectorType::Max => {
|
||||
if !other_types.is_empty() {
|
||||
return Err(DataFusionError::NotImplemented(
|
||||
"selector max w/ additional args".to_string(),
|
||||
));
|
||||
}
|
||||
Box::new(SelectorAccumulator::new(MaxSelector::new(value_type)?))
|
||||
}
|
||||
)?),
|
||||
SelectorType::Last => Box::new(Selector::new(
|
||||
Comparison::Max,
|
||||
Target::Time,
|
||||
value_type,
|
||||
other_types.iter().cloned(),
|
||||
)?),
|
||||
SelectorType::Min => Box::new(Selector::new(
|
||||
Comparison::Min,
|
||||
Target::Value,
|
||||
value_type,
|
||||
other_types.iter().cloned(),
|
||||
)?),
|
||||
SelectorType::Max => Box::new(Selector::new(
|
||||
Comparison::Max,
|
||||
Target::Value,
|
||||
value_type,
|
||||
other_types.iter().cloned(),
|
||||
)?),
|
||||
};
|
||||
Ok(accumulator)
|
||||
})
|
||||
|
@ -293,79 +288,6 @@ fn make_uda(name: &str, factory_builder: FactoryBuilder) -> AggregateUDF {
|
|||
)
|
||||
}
|
||||
|
||||
/// Structure that implements the Accumulator trait for DataFusion
|
||||
/// and processes (value, timestamp) pair and computes values
|
||||
#[derive(Debug)]
|
||||
struct SelectorAccumulator<SELECTOR>
|
||||
where
|
||||
SELECTOR: Selector,
|
||||
{
|
||||
// The underlying implementation for the selector
|
||||
selector: SELECTOR,
|
||||
}
|
||||
|
||||
impl<SELECTOR> SelectorAccumulator<SELECTOR>
|
||||
where
|
||||
SELECTOR: Selector,
|
||||
{
|
||||
pub fn new(selector: SELECTOR) -> Self {
|
||||
Self { selector }
|
||||
}
|
||||
}
|
||||
|
||||
impl<SELECTOR> Accumulator for SelectorAccumulator<SELECTOR>
|
||||
where
|
||||
SELECTOR: Selector + 'static,
|
||||
{
|
||||
// this function serializes our state to a vector of
|
||||
// `ScalarValue`s, which DataFusion uses to pass this state
|
||||
// between execution stages.
|
||||
fn state(&self) -> DataFusionResult<Vec<ScalarValue>> {
|
||||
self.selector.datafusion_state()
|
||||
}
|
||||
|
||||
/// Allocated size required for this accumulator, in bytes,
|
||||
/// including `Self`. Allocated means that for internal
|
||||
/// containers such as `Vec`, the `capacity` should be used not
|
||||
/// the `len`
|
||||
fn size(&self) -> usize {
|
||||
std::mem::size_of_val(self) - std::mem::size_of_val(&self.selector) + self.selector.size()
|
||||
}
|
||||
|
||||
// Return the final value of this aggregator.
|
||||
fn evaluate(&self) -> DataFusionResult<ScalarValue> {
|
||||
self.selector.evaluate()
|
||||
}
|
||||
|
||||
// This function receives one entry per argument of this
|
||||
// accumulator and updates the selector state function appropriately
|
||||
fn update_batch(&mut self, values: &[ArrayRef]) -> DataFusionResult<()> {
|
||||
if values.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if values.len() < 2 {
|
||||
return Err(DataFusionError::Internal(format!(
|
||||
"Internal error: Expected at least 2 arguments passed to selector function but got {}",
|
||||
values.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// invoke the actual worker function.
|
||||
self.selector
|
||||
.update_batch(&values[0], &values[1], &values[2..])?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// The input values and accumulator state are the same types for
|
||||
// selectors, and thus we can merge intermediate states with the
|
||||
// same function as inputs
|
||||
fn merge_batch(&mut self, states: &[ArrayRef]) -> DataFusionResult<()> {
|
||||
// merge is the same operation as update for these selectors
|
||||
self.update_batch(states)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use arrow::{
|
||||
|
@ -380,7 +302,7 @@ mod test {
|
|||
use datafusion::{datasource::MemTable, prelude::*};
|
||||
|
||||
use super::*;
|
||||
use utils::{run_case, run_case_err};
|
||||
use utils::{run_case, run_cases_err};
|
||||
|
||||
mod first {
|
||||
use super::*;
|
||||
|
@ -547,37 +469,24 @@ mod test {
|
|||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_time_tie_breaker() {
|
||||
run_case(
|
||||
selector_first().call(vec![col("f64_value"), col("time_dup")]),
|
||||
vec![
|
||||
"+------------------------------------------------+",
|
||||
"| selector_first(t.f64_value,t.time_dup) |",
|
||||
"+------------------------------------------------+",
|
||||
"| {value: 2.0, time: 1970-01-01T00:00:00.000001} |",
|
||||
"+------------------------------------------------+",
|
||||
],
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_err() {
|
||||
run_case_err(
|
||||
selector_first().call(vec![]),
|
||||
"Error during planning: selector_first requires at least 2 arguments, got 0",
|
||||
)
|
||||
.await;
|
||||
|
||||
run_case_err(
|
||||
selector_first().call(vec![col("f64_value")]),
|
||||
"Error during planning: selector_first requires at least 2 arguments, got 1",
|
||||
)
|
||||
.await;
|
||||
|
||||
run_case_err(
|
||||
selector_first().call(vec![col("time"), col("f64_value")]),
|
||||
"Error during planning: selector_first second argument must be a timestamp, but got Float64",
|
||||
)
|
||||
.await;
|
||||
|
||||
run_case_err(
|
||||
selector_first().call(vec![col("time"), col("f64_value"), col("bool_value")]),
|
||||
"Error during planning: selector_first second argument must be a timestamp, but got Float64",
|
||||
)
|
||||
.await;
|
||||
|
||||
run_case_err(
|
||||
selector_first().call(vec![col("f64_value"), col("bool_value"), col("time")]),
|
||||
"Error during planning: selector_first second argument must be a timestamp, but got Boolean",
|
||||
)
|
||||
.await;
|
||||
run_cases_err(selector_first(), "selector_first").await;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -718,6 +627,53 @@ mod test {
|
|||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_with_other() {
|
||||
run_case(
|
||||
selector_last().call(vec![col("f64_value"), col("time"), col("bool_value"), col("f64_not_normal_3_value"), col("i64_2_value")]),
|
||||
vec![
|
||||
"+-------------------------------------------------------------------------------------------+",
|
||||
"| selector_last(t.f64_value,t.time,t.bool_value,t.f64_not_normal_3_value,t.i64_2_value) |",
|
||||
"+-------------------------------------------------------------------------------------------+",
|
||||
"| {value: 3.0, time: 1970-01-01T00:00:00.000006, other_1: false, other_2: NaN, other_3: 30} |",
|
||||
"+-------------------------------------------------------------------------------------------+",
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
run_case(
|
||||
selector_last().call(vec![col("u64_2_value"), col("time"), col("bool_value"), col("f64_not_normal_4_value"), col("i64_2_value")]),
|
||||
vec![
|
||||
"+------------------------------------------------------------------------------------------+",
|
||||
"| selector_last(t.u64_2_value,t.time,t.bool_value,t.f64_not_normal_4_value,t.i64_2_value) |",
|
||||
"+------------------------------------------------------------------------------------------+",
|
||||
"| {value: 50, time: 1970-01-01T00:00:00.000005, other_1: false, other_2: inf, other_3: 50} |",
|
||||
"+------------------------------------------------------------------------------------------+",
|
||||
],
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_time_tie_breaker() {
|
||||
run_case(
|
||||
selector_last().call(vec![col("f64_value"), col("time_dup")]),
|
||||
vec![
|
||||
"+------------------------------------------------+",
|
||||
"| selector_last(t.f64_value,t.time_dup) |",
|
||||
"+------------------------------------------------+",
|
||||
"| {value: 5.0, time: 1970-01-01T00:00:00.000003} |",
|
||||
"+------------------------------------------------+",
|
||||
],
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_err() {
|
||||
run_cases_err(selector_last(), "selector_last").await;
|
||||
}
|
||||
}
|
||||
|
||||
mod min {
|
||||
|
@ -845,6 +801,53 @@ mod test {
|
|||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_with_other() {
|
||||
run_case(
|
||||
selector_min().call(vec![col("u64_value"), col("time"), col("bool_value"), col("f64_not_normal_1_value"), col("i64_2_value")]),
|
||||
vec![
|
||||
"+---------------------------------------------------------------------------------------+",
|
||||
"| selector_min(t.u64_value,t.time,t.bool_value,t.f64_not_normal_1_value,t.i64_2_value) |",
|
||||
"+---------------------------------------------------------------------------------------+",
|
||||
"| {value: 10, time: 1970-01-01T00:00:00.000004, other_1: true, other_2: NaN, other_3: } |",
|
||||
"+---------------------------------------------------------------------------------------+",
|
||||
],
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_time_tie_breaker() {
|
||||
run_case(
|
||||
selector_min().call(vec![col("f64_not_normal_2_value"), col("time_dup")]),
|
||||
vec![
|
||||
"+---------------------------------------------------+",
|
||||
"| selector_min(t.f64_not_normal_2_value,t.time_dup) |",
|
||||
"+---------------------------------------------------+",
|
||||
"| {value: -inf, time: 1970-01-01T00:00:00.000001} |",
|
||||
"+---------------------------------------------------+",
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
run_case(
|
||||
selector_min().call(vec![col("bool_const"), col("time_dup")]),
|
||||
vec![
|
||||
"+-------------------------------------------------+",
|
||||
"| selector_min(t.bool_const,t.time_dup) |",
|
||||
"+-------------------------------------------------+",
|
||||
"| {value: true, time: 1970-01-01T00:00:00.000001} |",
|
||||
"+-------------------------------------------------+",
|
||||
],
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_err() {
|
||||
run_cases_err(selector_min(), "selector_min").await;
|
||||
}
|
||||
}
|
||||
|
||||
mod max {
|
||||
|
@ -972,6 +975,53 @@ mod test {
|
|||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_with_other() {
|
||||
run_case(
|
||||
selector_max().call(vec![col("u64_value"), col("time"), col("bool_value"), col("f64_not_normal_1_value"), col("i64_2_value")]),
|
||||
vec![
|
||||
"+------------------------------------------------------------------------------------------+",
|
||||
"| selector_max(t.u64_value,t.time,t.bool_value,t.f64_not_normal_1_value,t.i64_2_value) |",
|
||||
"+------------------------------------------------------------------------------------------+",
|
||||
"| {value: 50, time: 1970-01-01T00:00:00.000005, other_1: false, other_2: inf, other_3: 50} |",
|
||||
"+------------------------------------------------------------------------------------------+",
|
||||
],
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_time_tie_breaker() {
|
||||
run_case(
|
||||
selector_max().call(vec![col("f64_not_normal_2_value"), col("time_dup")]),
|
||||
vec![
|
||||
"+---------------------------------------------------+",
|
||||
"| selector_max(t.f64_not_normal_2_value,t.time_dup) |",
|
||||
"+---------------------------------------------------+",
|
||||
"| {value: inf, time: 1970-01-01T00:00:00.000002} |",
|
||||
"+---------------------------------------------------+",
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
run_case(
|
||||
selector_max().call(vec![col("bool_const"), col("time_dup")]),
|
||||
vec![
|
||||
"+-------------------------------------------------+",
|
||||
"| selector_max(t.bool_const,t.time_dup) |",
|
||||
"+-------------------------------------------------+",
|
||||
"| {value: true, time: 1970-01-01T00:00:00.000001} |",
|
||||
"+-------------------------------------------------+",
|
||||
],
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_err() {
|
||||
run_cases_err(selector_max(), "selector_max").await;
|
||||
}
|
||||
}
|
||||
|
||||
mod utils {
|
||||
|
@ -991,7 +1041,7 @@ mod test {
|
|||
);
|
||||
}
|
||||
|
||||
pub async fn run_case_err(expr: Expr, expected: &'static str) {
|
||||
pub async fn run_case_err(expr: Expr, expected: &str) {
|
||||
println!("Running error case for {expr}");
|
||||
|
||||
let (schema, input) = input();
|
||||
|
@ -1006,6 +1056,38 @@ mod test {
|
|||
);
|
||||
}
|
||||
|
||||
pub async fn run_cases_err(selector: AggregateUDF, name: &str) {
|
||||
run_case_err(
|
||||
selector.call(vec![]),
|
||||
&format!("Error during planning: {name} requires at least 2 arguments, got 0"),
|
||||
)
|
||||
.await;
|
||||
|
||||
run_case_err(
|
||||
selector.call(vec![col("f64_value")]),
|
||||
&format!("Error during planning: {name} requires at least 2 arguments, got 1"),
|
||||
)
|
||||
.await;
|
||||
|
||||
run_case_err(
|
||||
selector.call(vec![col("time"), col("f64_value")]),
|
||||
&format!("Error during planning: {name} second argument must be a timestamp, but got Float64"),
|
||||
)
|
||||
.await;
|
||||
|
||||
run_case_err(
|
||||
selector.call(vec![col("time"), col("f64_value"), col("bool_value")]),
|
||||
&format!("Error during planning: {name} second argument must be a timestamp, but got Float64"),
|
||||
)
|
||||
.await;
|
||||
|
||||
run_case_err(
|
||||
selector.call(vec![col("f64_value"), col("bool_value"), col("time")]),
|
||||
&format!("Error during planning: {name} second argument must be a timestamp, but got Boolean"),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
fn input() -> (SchemaRef, Vec<RecordBatch>) {
|
||||
// define a schema for input
|
||||
// (value) and timestamp
|
||||
|
@ -1019,9 +1101,12 @@ mod test {
|
|||
Field::new("i64_value", DataType::Int64, true),
|
||||
Field::new("i64_2_value", DataType::Int64, true),
|
||||
Field::new("u64_value", DataType::UInt64, true),
|
||||
Field::new("u64_2_value", DataType::UInt64, true),
|
||||
Field::new("string_value", DataType::Utf8, true),
|
||||
Field::new("bool_value", DataType::Boolean, true),
|
||||
Field::new("bool_const", DataType::Boolean, true),
|
||||
Field::new("time", TIME_DATA_TYPE(), true),
|
||||
Field::new("time_dup", TIME_DATA_TYPE(), true),
|
||||
]));
|
||||
|
||||
// define data in two partitions
|
||||
|
@ -1057,9 +1142,12 @@ mod test {
|
|||
Arc::new(Int64Array::from(vec![Some(20), Some(40), None])),
|
||||
Arc::new(Int64Array::from(vec![None, None, None])),
|
||||
Arc::new(UInt64Array::from(vec![Some(20), Some(40), None])),
|
||||
Arc::new(UInt64Array::from(vec![Some(20), Some(40), None])),
|
||||
Arc::new(StringArray::from(vec![Some("two"), Some("four"), None])),
|
||||
Arc::new(BooleanArray::from(vec![Some(true), Some(false), None])),
|
||||
Arc::new(BooleanArray::from(vec![Some(true), Some(true), Some(true)])),
|
||||
Arc::new(TimestampNanosecondArray::from(vec![1000, 2000, 3000])),
|
||||
Arc::new(TimestampNanosecondArray::from(vec![1000, 1000, 2000])),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
|
@ -1077,8 +1165,11 @@ mod test {
|
|||
Arc::new(Int64Array::from(vec![] as Vec<Option<i64>>)),
|
||||
Arc::new(Int64Array::from(vec![] as Vec<Option<i64>>)),
|
||||
Arc::new(UInt64Array::from(vec![] as Vec<Option<u64>>)),
|
||||
Arc::new(UInt64Array::from(vec![] as Vec<Option<u64>>)),
|
||||
Arc::new(StringArray::from(vec![] as Vec<Option<&str>>)),
|
||||
Arc::new(BooleanArray::from(vec![] as Vec<Option<bool>>)),
|
||||
Arc::new(BooleanArray::from(vec![] as Vec<Option<bool>>)),
|
||||
Arc::new(TimestampNanosecondArray::from(vec![] as Vec<i64>)),
|
||||
Arc::new(TimestampNanosecondArray::from(vec![] as Vec<i64>)),
|
||||
],
|
||||
) {
|
||||
|
@ -1118,6 +1209,7 @@ mod test {
|
|||
Arc::new(Int64Array::from(vec![Some(10), Some(50), Some(30)])),
|
||||
Arc::new(Int64Array::from(vec![None, Some(50), Some(30)])),
|
||||
Arc::new(UInt64Array::from(vec![Some(10), Some(50), Some(30)])),
|
||||
Arc::new(UInt64Array::from(vec![Some(10), Some(50), None])),
|
||||
Arc::new(StringArray::from(vec![
|
||||
Some("a_one"),
|
||||
Some("z_five"),
|
||||
|
@ -1128,7 +1220,9 @@ mod test {
|
|||
Some(false),
|
||||
Some(false),
|
||||
])),
|
||||
Arc::new(BooleanArray::from(vec![Some(true), Some(true), Some(true)])),
|
||||
Arc::new(TimestampNanosecondArray::from(vec![4000, 5000, 6000])),
|
||||
Arc::new(TimestampNanosecondArray::from(vec![2000, 3000, 3000])),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
|
|
|
@ -24,234 +24,30 @@ use datafusion::{
|
|||
|
||||
use super::type_handling::make_struct_scalar;
|
||||
|
||||
/// Implements the logic of the specific selector function (this is a
|
||||
/// cutdown version of the Accumulator DataFusion trait, to allow
|
||||
/// sharing between implementations)
|
||||
pub trait Selector: Debug + Send + Sync {
|
||||
/// return state in a form that DataFusion can store during execution
|
||||
fn datafusion_state(&self) -> DataFusionResult<Vec<ScalarValue>>;
|
||||
|
||||
/// produces the final value of this selector for the specified output type
|
||||
fn evaluate(&self) -> DataFusionResult<ScalarValue>;
|
||||
|
||||
/// Update this selector's state based on values in value_arr and time_arr
|
||||
fn update_batch(
|
||||
&mut self,
|
||||
value_arr: &ArrayRef,
|
||||
time_arr: &ArrayRef,
|
||||
other_arrs: &[ArrayRef],
|
||||
) -> DataFusionResult<()>;
|
||||
|
||||
/// Allocated size required for this selector, in bytes,
|
||||
/// including `Self`. Allocated means that for internal
|
||||
/// containers such as `Vec`, the `capacity` should be used not
|
||||
/// the `len`
|
||||
fn size(&self) -> usize;
|
||||
/// How to compare values/time.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Comparison {
|
||||
Min,
|
||||
Max,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FirstSelector {
|
||||
value: ScalarValue,
|
||||
time: Option<i64>,
|
||||
other: Box<[ScalarValue]>,
|
||||
}
|
||||
|
||||
impl FirstSelector {
|
||||
pub fn new<'a>(
|
||||
data_type: &'a DataType,
|
||||
other_types: impl IntoIterator<Item = &'a DataType>,
|
||||
) -> DataFusionResult<Self> {
|
||||
Ok(Self {
|
||||
value: ScalarValue::try_from(data_type)?,
|
||||
time: None,
|
||||
other: other_types
|
||||
.into_iter()
|
||||
.map(ScalarValue::try_from)
|
||||
.collect::<DataFusionResult<_>>()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Selector for FirstSelector {
|
||||
fn datafusion_state(&self) -> DataFusionResult<Vec<ScalarValue>> {
|
||||
Ok([
|
||||
self.value.clone(),
|
||||
ScalarValue::TimestampNanosecond(self.time, None),
|
||||
]
|
||||
.into_iter()
|
||||
.chain(self.other.iter().cloned())
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn evaluate(&self) -> DataFusionResult<ScalarValue> {
|
||||
Ok(make_struct_scalar(
|
||||
&self.value,
|
||||
&ScalarValue::TimestampNanosecond(self.time, None),
|
||||
self.other.iter(),
|
||||
))
|
||||
}
|
||||
|
||||
fn update_batch(
|
||||
&mut self,
|
||||
value_arr: &ArrayRef,
|
||||
time_arr: &ArrayRef,
|
||||
other_arrs: &[ArrayRef],
|
||||
) -> DataFusionResult<()> {
|
||||
// Only look for times where the array also has a non
|
||||
// null value (the time array should have no nulls itself)
|
||||
//
|
||||
// For example, for the following input, the correct
|
||||
// current min time is 200 (not 100)
|
||||
//
|
||||
// value | time
|
||||
// --------------
|
||||
// NULL | 100
|
||||
// A | 200
|
||||
// B | 300
|
||||
//
|
||||
let time_arr = arrow::compute::nullif(time_arr, &arrow::compute::is_null(&value_arr)?)?;
|
||||
|
||||
let time_arr = time_arr
|
||||
.as_any()
|
||||
.downcast_ref::<TimestampNanosecondArray>()
|
||||
// the input type arguments should be ensured by datafusion
|
||||
.expect("Second argument was time");
|
||||
let cur_min_time = array_min(time_arr);
|
||||
|
||||
let need_update = match (&self.time, &cur_min_time) {
|
||||
(Some(time), Some(cur_min_time)) => cur_min_time < time,
|
||||
// No existing minimum, so update needed
|
||||
(None, Some(_)) => true,
|
||||
// No actual minimum time found, so no update needed
|
||||
(_, None) => false,
|
||||
};
|
||||
|
||||
if need_update {
|
||||
let index = time_arr
|
||||
.iter()
|
||||
// arrow doesn't tell us what index had the
|
||||
// minimum, so need to find it ourselves see also
|
||||
// https://github.com/apache/arrow-datafusion/issues/600
|
||||
.enumerate()
|
||||
.find(|(_, time)| cur_min_time == *time)
|
||||
.map(|(idx, _)| idx)
|
||||
.unwrap(); // value always exists
|
||||
|
||||
// update all or nothing in case of an error
|
||||
let value_new = ScalarValue::try_from_array(&value_arr, index)?;
|
||||
let other_new = other_arrs
|
||||
.iter()
|
||||
.map(|arr| ScalarValue::try_from_array(arr, index))
|
||||
.collect::<DataFusionResult<_>>()?;
|
||||
|
||||
self.time = cur_min_time;
|
||||
self.value = value_new;
|
||||
self.other = other_new;
|
||||
impl Comparison {
|
||||
fn is_update<T>(&self, old: &T, new: &T) -> bool
|
||||
where
|
||||
T: PartialOrd,
|
||||
{
|
||||
match self {
|
||||
Self::Min => new < old,
|
||||
Self::Max => old < new,
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
std::mem::size_of_val(self) - std::mem::size_of_val(&self.value) + self.value.size()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LastSelector {
|
||||
value: ScalarValue,
|
||||
time: Option<i64>,
|
||||
}
|
||||
|
||||
impl LastSelector {
|
||||
pub fn new(data_type: &DataType) -> DataFusionResult<Self> {
|
||||
Ok(Self {
|
||||
value: ScalarValue::try_from(data_type)?,
|
||||
time: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Selector for LastSelector {
|
||||
fn datafusion_state(&self) -> DataFusionResult<Vec<ScalarValue>> {
|
||||
Ok(vec![
|
||||
self.value.clone(),
|
||||
ScalarValue::TimestampNanosecond(self.time, None),
|
||||
])
|
||||
}
|
||||
|
||||
fn evaluate(&self) -> DataFusionResult<ScalarValue> {
|
||||
Ok(make_struct_scalar(
|
||||
&self.value,
|
||||
&ScalarValue::TimestampNanosecond(self.time, None),
|
||||
[],
|
||||
))
|
||||
}
|
||||
|
||||
fn update_batch(
|
||||
&mut self,
|
||||
value_arr: &ArrayRef,
|
||||
time_arr: &ArrayRef,
|
||||
other_arrs: &[ArrayRef],
|
||||
) -> DataFusionResult<()> {
|
||||
if !other_arrs.is_empty() {
|
||||
return Err(DataFusionError::NotImplemented(
|
||||
"selector last w/ additional args".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Only look for times where the array also has a non
|
||||
// null value (the time array should have no nulls itself)
|
||||
//
|
||||
// For example, for the following input, the correct
|
||||
// current max time is 200 (not 300)
|
||||
//
|
||||
// value | time
|
||||
// --------------
|
||||
// A | 100
|
||||
// B | 200
|
||||
// NULL | 300
|
||||
//
|
||||
let time_arr = arrow::compute::nullif(time_arr, &arrow::compute::is_null(&value_arr)?)?;
|
||||
|
||||
let time_arr = time_arr
|
||||
.as_any()
|
||||
.downcast_ref::<TimestampNanosecondArray>()
|
||||
// the input type arguments should be ensured by datafusion
|
||||
.expect("Second argument was time");
|
||||
let cur_max_time = array_max(time_arr);
|
||||
|
||||
let need_update = match (&self.time, &cur_max_time) {
|
||||
(Some(time), Some(cur_max_time)) => time < cur_max_time,
|
||||
// No existing maximum, so update needed
|
||||
(None, Some(_)) => true,
|
||||
// No actual maximum value found, so no update needed
|
||||
(_, None) => false,
|
||||
};
|
||||
|
||||
if need_update {
|
||||
let index = time_arr
|
||||
.iter()
|
||||
// arrow doesn't tell us what index had the
|
||||
// maximum, so need to find it ourselves
|
||||
.enumerate()
|
||||
.find(|(_, time)| cur_max_time == *time)
|
||||
.map(|(idx, _)| idx)
|
||||
.unwrap(); // value always exists
|
||||
|
||||
// update all or nothing in case of an error
|
||||
let value_new = ScalarValue::try_from_array(&value_arr, index)?;
|
||||
|
||||
self.time = cur_max_time;
|
||||
self.value = value_new;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
std::mem::size_of_val(self) - std::mem::size_of_val(&self.value) + self.value.size()
|
||||
}
|
||||
/// What to compare?
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Target {
|
||||
Time,
|
||||
Value,
|
||||
}
|
||||
|
||||
/// Did we find a new min/max
|
||||
|
@ -270,6 +66,7 @@ impl ActionNeeded {
|
|||
Self::Nothing => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn update_time(&self) -> bool {
|
||||
match self {
|
||||
Self::UpdateValueAndTime => true,
|
||||
|
@ -279,184 +76,129 @@ impl ActionNeeded {
|
|||
}
|
||||
}
|
||||
|
||||
/// Common state implementation for different selectors.
|
||||
#[derive(Debug)]
|
||||
pub struct MinSelector {
|
||||
pub struct Selector {
|
||||
comp: Comparison,
|
||||
target: Target,
|
||||
value: ScalarValue,
|
||||
time: Option<i64>,
|
||||
other: Box<[ScalarValue]>,
|
||||
}
|
||||
|
||||
impl MinSelector {
|
||||
pub fn new(data_type: &DataType) -> DataFusionResult<Self> {
|
||||
impl Selector {
|
||||
pub fn new<'a>(
|
||||
comp: Comparison,
|
||||
target: Target,
|
||||
data_type: &'a DataType,
|
||||
other_types: impl IntoIterator<Item = &'a DataType>,
|
||||
) -> DataFusionResult<Self> {
|
||||
Ok(Self {
|
||||
comp,
|
||||
target,
|
||||
value: ScalarValue::try_from(data_type)?,
|
||||
time: None,
|
||||
other: other_types
|
||||
.into_iter()
|
||||
.map(ScalarValue::try_from)
|
||||
.collect::<DataFusionResult<_>>()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Selector for MinSelector {
|
||||
fn datafusion_state(&self) -> DataFusionResult<Vec<ScalarValue>> {
|
||||
Ok(vec![
|
||||
self.value.clone(),
|
||||
ScalarValue::TimestampNanosecond(self.time, None),
|
||||
])
|
||||
}
|
||||
|
||||
fn evaluate(&self) -> DataFusionResult<ScalarValue> {
|
||||
Ok(make_struct_scalar(
|
||||
&self.value,
|
||||
&ScalarValue::TimestampNanosecond(self.time, None),
|
||||
[],
|
||||
))
|
||||
}
|
||||
|
||||
fn update_batch(
|
||||
fn update_time_based(
|
||||
&mut self,
|
||||
value_arr: &ArrayRef,
|
||||
time_arr: &ArrayRef,
|
||||
other_arrs: &[ArrayRef],
|
||||
) -> DataFusionResult<()> {
|
||||
use ActionNeeded::*;
|
||||
|
||||
if !other_arrs.is_empty() {
|
||||
return Err(DataFusionError::NotImplemented(
|
||||
"selector min w/ additional args".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut min_accu = MinAccumulator::try_new(value_arr.data_type())?;
|
||||
min_accu.update_batch(&[Arc::clone(value_arr)])?;
|
||||
let cur_min_value = min_accu.evaluate()?;
|
||||
|
||||
let action_needed = match (self.value.is_null(), cur_min_value.is_null()) {
|
||||
(false, false) => {
|
||||
if cur_min_value < self.value {
|
||||
// new minimim found
|
||||
UpdateValueAndTime
|
||||
} else if cur_min_value == self.value {
|
||||
// same minimum found, time might need update
|
||||
UpdateTime
|
||||
} else {
|
||||
Nothing
|
||||
}
|
||||
}
|
||||
// No existing minimum time, so update needed
|
||||
(true, false) => UpdateValueAndTime,
|
||||
// No actual minimum time found, so no update needed
|
||||
(_, true) => Nothing,
|
||||
};
|
||||
|
||||
if action_needed.update_value() {
|
||||
self.value = cur_min_value;
|
||||
self.time = None; // ignore time associated with old value
|
||||
}
|
||||
|
||||
if action_needed.update_time() {
|
||||
// only keep values where we've found our current value.
|
||||
// Note: We MUST also mask-out NULLs in `value_arr`, otherwise we may easily select that!
|
||||
let time_arr = arrow::compute::nullif(
|
||||
time_arr,
|
||||
&arrow::compute::neq_dyn(&self.value.to_array_of_size(time_arr.len()), &value_arr)?,
|
||||
)?;
|
||||
let time_arr =
|
||||
arrow::compute::nullif(&time_arr, &arrow::compute::is_null(&value_arr)?)?;
|
||||
|
||||
let time_arr = time_arr
|
||||
.as_any()
|
||||
.downcast_ref::<TimestampNanosecondArray>()
|
||||
// the input type arguments should be ensured by datafusion
|
||||
.expect("Second argument was time");
|
||||
self.time = match (array_min(time_arr), self.time) {
|
||||
(Some(x), Some(y)) if x < y => Some(x),
|
||||
(Some(_), Some(x)) => Some(x),
|
||||
(None, Some(x)) => Some(x),
|
||||
(Some(x), None) => Some(x),
|
||||
(None, None) => None,
|
||||
};
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
std::mem::size_of_val(self) - std::mem::size_of_val(&self.value) + self.value.size()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MaxSelector {
|
||||
value: ScalarValue,
|
||||
time: Option<i64>,
|
||||
}
|
||||
|
||||
impl MaxSelector {
|
||||
pub fn new(data_type: &DataType) -> DataFusionResult<Self> {
|
||||
Ok(Self {
|
||||
value: ScalarValue::try_from(data_type)?,
|
||||
time: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Selector for MaxSelector {
|
||||
fn datafusion_state(&self) -> DataFusionResult<Vec<ScalarValue>> {
|
||||
Ok(vec![
|
||||
self.value.clone(),
|
||||
ScalarValue::TimestampNanosecond(self.time, None),
|
||||
])
|
||||
}
|
||||
|
||||
fn evaluate(&self) -> DataFusionResult<ScalarValue> {
|
||||
Ok(make_struct_scalar(
|
||||
&self.value,
|
||||
&ScalarValue::TimestampNanosecond(self.time, None),
|
||||
[],
|
||||
))
|
||||
}
|
||||
|
||||
fn update_batch(
|
||||
&mut self,
|
||||
value_arr: &ArrayRef,
|
||||
time_arr: &ArrayRef,
|
||||
other_arrs: &[ArrayRef],
|
||||
) -> DataFusionResult<()> {
|
||||
use ActionNeeded::*;
|
||||
|
||||
if !other_arrs.is_empty() {
|
||||
return Err(DataFusionError::NotImplemented(
|
||||
"selector max w/ additional args".to_string(),
|
||||
));
|
||||
}
|
||||
let time_arr = arrow::compute::nullif(time_arr, &arrow::compute::is_null(&value_arr)?)?;
|
||||
|
||||
let time_arr = time_arr
|
||||
.as_any()
|
||||
.downcast_ref::<TimestampNanosecondArray>()
|
||||
// the input type arguments should be ensured by datafusion
|
||||
.expect("Second argument was time");
|
||||
let cur_time = match self.comp {
|
||||
Comparison::Min => array_min(time_arr),
|
||||
Comparison::Max => array_max(time_arr),
|
||||
};
|
||||
|
||||
let mut max_accu = MaxAccumulator::try_new(value_arr.data_type())?;
|
||||
max_accu.update_batch(&[Arc::clone(value_arr)])?;
|
||||
let cur_max_value = max_accu.evaluate()?;
|
||||
let need_update = match (&self.time, &cur_time) {
|
||||
(Some(time), Some(cur_time)) => self.comp.is_update(time, cur_time),
|
||||
// No existing min/max, so update needed
|
||||
(None, Some(_)) => true,
|
||||
// No actual min/max time found, so no update needed
|
||||
(_, None) => false,
|
||||
};
|
||||
|
||||
let action_needed = match (&self.value.is_null(), &cur_max_value.is_null()) {
|
||||
if need_update {
|
||||
let index = time_arr
|
||||
.iter()
|
||||
// arrow doesn't tell us what index had the
|
||||
// min/max, so need to find it ourselves
|
||||
.enumerate()
|
||||
.filter(|(_, time)| cur_time == *time)
|
||||
.map(|(idx, _)| idx)
|
||||
// break tie: favor first value
|
||||
.next()
|
||||
.unwrap(); // value always exists
|
||||
|
||||
// update all or nothing in case of an error
|
||||
let value_new = ScalarValue::try_from_array(&value_arr, index)?;
|
||||
let other_new = other_arrs
|
||||
.iter()
|
||||
.map(|arr| ScalarValue::try_from_array(arr, index))
|
||||
.collect::<DataFusionResult<_>>()?;
|
||||
|
||||
self.time = cur_time;
|
||||
self.value = value_new;
|
||||
self.other = other_new;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn update_value_based(
|
||||
&mut self,
|
||||
value_arr: &ArrayRef,
|
||||
time_arr: &ArrayRef,
|
||||
other_arrs: &[ArrayRef],
|
||||
) -> DataFusionResult<()> {
|
||||
use ActionNeeded::*;
|
||||
|
||||
let cur_value = match self.comp {
|
||||
Comparison::Min => {
|
||||
let mut min_accu = MinAccumulator::try_new(value_arr.data_type())?;
|
||||
min_accu.update_batch(&[Arc::clone(value_arr)])?;
|
||||
min_accu.evaluate()?
|
||||
}
|
||||
Comparison::Max => {
|
||||
let mut max_accu = MaxAccumulator::try_new(value_arr.data_type())?;
|
||||
max_accu.update_batch(&[Arc::clone(value_arr)])?;
|
||||
max_accu.evaluate()?
|
||||
}
|
||||
};
|
||||
|
||||
let action_needed = match (&self.value.is_null(), &cur_value.is_null()) {
|
||||
(false, false) => {
|
||||
if self.value < cur_max_value {
|
||||
// new maximum found
|
||||
if self.comp.is_update(&self.value, &cur_value) {
|
||||
// new min/max found
|
||||
UpdateValueAndTime
|
||||
} else if cur_max_value == self.value {
|
||||
} else if cur_value == self.value {
|
||||
// same maximum found, time might need update
|
||||
UpdateTime
|
||||
} else {
|
||||
Nothing
|
||||
}
|
||||
}
|
||||
// No existing maxmimum value, so update needed
|
||||
// No existing min/max value, so update needed
|
||||
(true, false) => UpdateValueAndTime,
|
||||
// No actual maximum value found, so no update needed
|
||||
// No actual min/max value found, so no update needed
|
||||
(_, true) => Nothing,
|
||||
};
|
||||
|
||||
if action_needed.update_value() {
|
||||
self.value = cur_max_value;
|
||||
self.value = cur_value;
|
||||
self.time = None; // ignore time associated with old value
|
||||
}
|
||||
|
||||
|
@ -479,7 +221,7 @@ impl Selector for MaxSelector {
|
|||
// the input type arguments should be ensured by datafusion
|
||||
.expect("Second argument was time");
|
||||
|
||||
// Note: we still use the MINIMUM timestamp here even though this is the max VALUE aggregator.
|
||||
// Note: we still use the MINIMUM timestamp here even if this is the max VALUE aggregator.
|
||||
self.time = match (array_min(time_arr), self.time) {
|
||||
(Some(x), Some(y)) if x < y => Some(x),
|
||||
(Some(_), Some(x)) => Some(x),
|
||||
|
@ -487,11 +229,82 @@ impl Selector for MaxSelector {
|
|||
(Some(x), None) => Some(x),
|
||||
(None, None) => None,
|
||||
};
|
||||
|
||||
// update other if required
|
||||
if !self.other.is_empty() {
|
||||
let index = time_arr
|
||||
.iter()
|
||||
// arrow doesn't tell us what index had the
|
||||
// minimum, so need to find it ourselves
|
||||
.enumerate()
|
||||
.filter(|(_, time)| self.time == *time)
|
||||
.map(|(idx, _)| idx)
|
||||
// break tie: favor first value
|
||||
.next()
|
||||
.unwrap(); // value always exists
|
||||
|
||||
self.other = other_arrs
|
||||
.iter()
|
||||
.map(|arr| ScalarValue::try_from_array(arr, index))
|
||||
.collect::<DataFusionResult<_>>()?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Accumulator for Selector {
|
||||
fn state(&self) -> DataFusionResult<Vec<ScalarValue>> {
|
||||
Ok([
|
||||
self.value.clone(),
|
||||
ScalarValue::TimestampNanosecond(self.time, None),
|
||||
]
|
||||
.into_iter()
|
||||
.chain(self.other.iter().cloned())
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn update_batch(&mut self, values: &[ArrayRef]) -> DataFusionResult<()> {
|
||||
if values.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if values.len() < 2 {
|
||||
return Err(DataFusionError::Internal(format!(
|
||||
"Internal error: Expected at least 2 arguments passed to selector function but got {}",
|
||||
values.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let value_arr = &values[0];
|
||||
let time_arr = &values[1];
|
||||
let other_arrs = &values[2..];
|
||||
|
||||
match self.target {
|
||||
Target::Time => self.update_time_based(value_arr, time_arr, other_arrs)?,
|
||||
Target::Value => self.update_value_based(value_arr, time_arr, other_arrs)?,
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn merge_batch(&mut self, states: &[ArrayRef]) -> DataFusionResult<()> {
|
||||
// merge is the same operation as update for these selectors
|
||||
self.update_batch(states)
|
||||
}
|
||||
|
||||
fn evaluate(&self) -> DataFusionResult<ScalarValue> {
|
||||
Ok(make_struct_scalar(
|
||||
&self.value,
|
||||
&ScalarValue::TimestampNanosecond(self.time, None),
|
||||
self.other.iter(),
|
||||
))
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
std::mem::size_of_val(self) - std::mem::size_of_val(&self.value) + self.value.size()
|
||||
std::mem::size_of_val(self) - std::mem::size_of_val(&self.value)
|
||||
+ self.value.size()
|
||||
+ self.other.iter().map(|s| s.size()).sum::<usize>()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue