From c9f8151302afcedcabd28e1fd33c769a6008ec6d Mon Sep 17 00:00:00 2001 From: Stuart Carnie Date: Wed, 17 May 2023 12:31:37 +1000 Subject: [PATCH] feat: Resolve data types for `VarRef` nodes in the `WHERE` condition --- iox_query_influxql/src/plan/rewriter.rs | 98 +++++++++++++++++++++++-- 1 file changed, 91 insertions(+), 7 deletions(-) diff --git a/iox_query_influxql/src/plan/rewriter.rs b/iox_query_influxql/src/plan/rewriter.rs index 5cd0d9b653..9c261c163e 100644 --- a/iox_query_influxql/src/plan/rewriter.rs +++ b/iox_query_influxql/src/plan/rewriter.rs @@ -5,8 +5,10 @@ use crate::plan::ir::{DataSource, Field, Select, SelectQuery, TagSet}; use crate::plan::var_ref::{influx_type_to_var_ref_data_type, var_ref_data_type_to_influx_type}; use crate::plan::{error, util, SchemaProvider}; use datafusion::common::{DataFusionError, Result}; -use influxdb_influxql_parser::common::{MeasurementName, QualifiedMeasurementName}; -use influxdb_influxql_parser::expression::walk::{walk_expr, walk_expr_mut}; +use influxdb_influxql_parser::common::{MeasurementName, QualifiedMeasurementName, WhereClause}; +use influxdb_influxql_parser::expression::walk::{ + walk_expr, walk_expr_mut, walk_expression_mut, ExpressionMut, +}; use influxdb_influxql_parser::expression::{ AsVarRefExpr, Call, Expr, VarRef, VarRefDataType, WildcardType, }; @@ -21,7 +23,7 @@ use itertools::Itertools; use schema::InfluxColumnType; use std::collections::{HashMap, HashSet}; use std::fmt::Debug; -use std::ops::{ControlFlow, Deref}; +use std::ops::{ControlFlow, Deref, DerefMut}; /// Recursively rewrite the specified [`SelectStatement`] by performing a series of passes /// to validate and normalize the statement. 
@@ -100,6 +102,7 @@ impl RewriteSelect {
 
         let from = self.expand_from(s, stmt)?;
         let (fields, group_by) = self.expand_projection(s, stmt, &from)?;
+        let condition = self.rewrite_condition(s, stmt, &from)?;
         let tag_set = select_tag_set(s, &from);
 
         let SelectStatementInfo { projection_type } =
@@ -123,7 +126,7 @@ impl RewriteSelect {
             projection_type,
             fields,
             from,
-            condition: stmt.condition.clone(),
+            condition,
             group_by,
             tag_set,
             fill,
@@ -332,6 +335,40 @@ impl RewriteSelect {
         }
         Ok(new_from)
     }
+
+    /// Resolve the data types of any [`VarRef`] expressions in the `WHERE` condition.
+    fn rewrite_condition(
+        &self,
+        s: &dyn SchemaProvider,
+        stmt: &SelectStatement,
+        from: &Vec<DataSource>,
+    ) -> Result<Option<WhereClause>> {
+        let Some(mut where_clause) = stmt.condition.clone() else { return Ok(None) };
+
+        let tv = TypeEvaluator::new(s, from);
+
+        if let ControlFlow::Break(err) = walk_expression_mut(where_clause.deref_mut(), &mut |e| {
+            match e {
+                ExpressionMut::Arithmetic(e) => walk_expr_mut(e, &mut |e| match e {
+                    // Attempt to rewrite all variable (column) references with their concrete types,
+                    // if one hasn't been specified.
+                    Expr::VarRef(ref mut v) => {
+                        v.data_type = match tv.eval_var_ref(v) {
+                            Ok(v) => v,
+                            Err(e) => ControlFlow::Break(e)?,
+                        };
+                        ControlFlow::Continue(())
+                    }
+                    _ => ControlFlow::Continue(()),
+                }),
+                ExpressionMut::Conditional(_) => ControlFlow::<DataFusionError>::Continue(()),
+            }
+        }) {
+            Err(err)
+        } else {
+            Ok(Some(where_clause))
+        }
+    }
 }
 
 /// Ensure the time field is added to all projections,
@@ -2027,6 +2064,53 @@ mod test {
         );
     }
 
+    /// Validate type resolution of [`VarRef`] nodes in the `WHERE` clause.
+    #[test]
+    fn condition() {
+        let namespace = MockSchemaProvider::default();
+
+        // Resolves a float field: `usage_user` gains the `::float` suffix
+        let stmt = parse_select("SELECT usage_idle FROM cpu WHERE usage_user > 0");
+        let stmt = rewrite_select_statement(&namespace, &stmt).unwrap();
+        assert_eq!(
+            stmt.to_string(),
+            "SELECT time::timestamp AS time, usage_idle::float AS usage_idle FROM cpu WHERE usage_user::float > 0"
+        );
+
+        // Resolves a tag: `cpu` gains the `::tag` suffix
+        let stmt = parse_select("SELECT usage_idle FROM cpu WHERE cpu =~ /foo/");
+        let stmt = rewrite_select_statement(&namespace, &stmt).unwrap();
+        assert_eq!(
+            stmt.to_string(),
+            "SELECT time::timestamp AS time, usage_idle::float AS usage_idle FROM cpu WHERE cpu::tag =~ /foo/"
+        );
+
+        // Leaves an unknown field untyped: `non_existent` gains no type suffix
+        let stmt = parse_select("SELECT usage_idle FROM cpu WHERE non_existent = 'bar'");
+        let stmt = rewrite_select_statement(&namespace, &stmt).unwrap();
+        assert_eq!(
+            stmt.to_string(),
+            "SELECT time::timestamp AS time, usage_idle::float AS usage_idle FROM cpu WHERE non_existent = 'bar'"
+        );
+
+        // Handles multiple measurements; `bytes_free` is from the `disk` measurement
+        let stmt =
+            parse_select("SELECT usage_idle, bytes_free FROM cpu, disk WHERE bytes_free = 3");
+        let stmt = rewrite_select_statement(&namespace, &stmt).unwrap();
+        assert_eq!(
+            stmt.to_string(),
+            "SELECT time::timestamp AS time, usage_idle::float AS usage_idle, bytes_free::integer AS bytes_free FROM cpu, disk WHERE bytes_free::integer = 3"
+        );
+
+        // Resolves recursively through subqueries and aliases
+        let stmt = parse_select("SELECT bytes FROM (SELECT bytes_free AS bytes FROM disk WHERE bytes_free = 3) WHERE bytes > 0");
+        let stmt = rewrite_select_statement(&namespace, &stmt).unwrap();
+        assert_eq!(
+            stmt.to_string(),
+            "SELECT time::timestamp AS time, bytes::integer AS bytes FROM (SELECT time::timestamp AS time, bytes_free::integer AS bytes FROM disk WHERE bytes_free::integer = 3) WHERE bytes::integer > 0"
+        );
+    }
+
     #[test]
     fn group_by() {
         let namespace
= MockSchemaProvider::default(); @@ -2042,7 +2126,7 @@ mod test { let stmt = rewrite_select_statement(&namespace, &stmt).unwrap(); assert_eq!( stmt.to_string(), - "SELECT time::timestamp AS time, usage_idle::float AS usage_idle FROM cpu GROUP BY cpu, host, region" + "SELECT time::timestamp AS time, usage_idle::float AS usage_idle FROM cpu GROUP BY cpu::tag, host::tag, region::tag" ); // Does not include tags in projection when expanded in GROUP BY @@ -2050,7 +2134,7 @@ mod test { let stmt = rewrite_select_statement(&namespace, &stmt).unwrap(); assert_eq!( stmt.to_string(), - "SELECT time::timestamp AS time, usage_idle::float AS usage_idle, usage_system::float AS usage_system, usage_user::float AS usage_user FROM cpu GROUP BY cpu, host, region" + "SELECT time::timestamp AS time, usage_idle::float AS usage_idle, usage_system::float AS usage_system, usage_user::float AS usage_user FROM cpu GROUP BY cpu::tag, host::tag, region::tag" ); // Does include explicitly listed tags in projection @@ -2058,7 +2142,7 @@ mod test { let stmt = rewrite_select_statement(&namespace, &stmt).unwrap(); assert_eq!( stmt.to_string(), - "SELECT time::timestamp AS time, host::tag AS host, usage_idle::float AS usage_idle, usage_system::float AS usage_system, usage_user::float AS usage_user FROM cpu GROUP BY cpu, host, region" + "SELECT time::timestamp AS time, host::tag AS host, usage_idle::float AS usage_idle, usage_system::float AS usage_system, usage_user::float AS usage_user FROM cpu GROUP BY cpu::tag, host::tag, region::tag" ); }