From 2febaff24be388f141b5b511f48a796b97afe5c6 Mon Sep 17 00:00:00 2001
From: Trevor Hilton <thilton@influxdata.com>
Date: Sat, 23 Mar 2024 10:41:00 -0400
Subject: [PATCH] feat: support query parameters (#24804)

feat: support query parameters

This adds support for parameters in the /api/v3/query_sql
and /api/v3/query_influxql APIs.

The new parameter `params` is supported in the URL query string
of a GET request (as a JSON-encoded object), or in the JSON body
of a POST request.

Two new E2E tests were added to check successful GET/POST requests,
as well as the error scenario in which a parameter referenced by the
query string is not provided.
---
 influxdb3/tests/server/query.rs        | 243 +++++++++++++++++++++++++
 influxdb3_server/src/http.rs           | 123 +++++++++----
 influxdb3_server/src/http/v1.rs        |   4 +-
 influxdb3_server/src/lib.rs            |   2 +
 influxdb3_server/src/query_executor.rs |   5 +-
 5 files changed, 333 insertions(+), 44 deletions(-)

diff --git a/influxdb3/tests/server/query.rs b/influxdb3/tests/server/query.rs
index 7081f1803d..116342fd98 100644
--- a/influxdb3/tests/server/query.rs
+++ b/influxdb3/tests/server/query.rs
@@ -3,6 +3,7 @@ use futures::StreamExt;
 use influxdb3_client::Precision;
 use pretty_assertions::assert_eq;
 use serde_json::{json, Value};
+use test_helpers::assert_contains;
 
 #[tokio::test]
 async fn api_v3_query_sql() {
@@ -50,6 +51,125 @@ async fn api_v3_query_sql() {
     );
 }
 
+#[tokio::test]
+async fn api_v3_query_sql_params() {
+    let server = TestServer::spawn().await;
+
+    server
+        .write_lp_to_db(
+            "foo",
+            "cpu,host=a,region=us-east usage=0.9 1
+            cpu,host=b,region=us-east usage=0.50 1
+            cpu,host=a,region=us-east usage=0.80 2
+            cpu,host=b,region=us-east usage=0.60 2
+            cpu,host=a,region=us-east usage=0.70 3
+            cpu,host=b,region=us-east usage=0.70 3
+            cpu,host=a,region=us-east usage=0.50 4
+            cpu,host=b,region=us-east usage=0.80 4",
+            Precision::Second,
+        )
+        .await
+        .unwrap();
+
+    let client = reqwest::Client::new();
+    let url = format!("{base}/api/v3/query_sql", base = server.client_addr());
+
+    // Use a POST request
+    {
+        let resp = client
+            .post(&url)
+            .json(&json!({
+                "db": "foo",
+                "q": "SELECT * FROM cpu WHERE host = $host AND usage > $usage",
+                "params": {
+                    "host": "b",
+                    "usage": 0.60,
+                },
+                "format": "pretty",
+            }))
+            .send()
+            .await
+            .unwrap()
+            .text()
+            .await
+            .unwrap();
+
+        assert_eq!(
+            "+------+---------+---------------------+-------+\n\
+            | host | region  | time                | usage |\n\
+            +------+---------+---------------------+-------+\n\
+            | b    | us-east | 1970-01-01T00:00:03 | 0.7   |\n\
+            | b    | us-east | 1970-01-01T00:00:04 | 0.8   |\n\
+            +------+---------+---------------------+-------+",
+            resp
+        );
+    }
+
+    // Use a GET request
+    {
+        let params = serde_json::to_string(&json!({
+            "host": "b",
+            "usage": 0.60,
+        }))
+        .unwrap();
+        let resp = client
+            .get(&url)
+            .query(&[
+                ("db", "foo"),
+                (
+                    "q",
+                    "SELECT * FROM cpu WHERE host = $host AND usage > $usage",
+                ),
+                ("format", "pretty"),
+                ("params", params.as_str()),
+            ])
+            .send()
+            .await
+            .unwrap()
+            .text()
+            .await
+            .unwrap();
+
+        assert_eq!(
+            "+------+---------+---------------------+-------+\n\
+            | host | region  | time                | usage |\n\
+            +------+---------+---------------------+-------+\n\
+            | b    | us-east | 1970-01-01T00:00:03 | 0.7   |\n\
+            | b    | us-east | 1970-01-01T00:00:04 | 0.8   |\n\
+            +------+---------+---------------------+-------+",
+            resp
+        );
+    }
+
+    // Check for errors
+    {
+        let resp = client
+            .post(&url)
+            .json(&json!({
+                "db": "foo",
+                "q": "SELECT * FROM cpu WHERE host = $host",
+                "params": {
+                    "not_host": "a"
+                },
+                "format": "pretty",
+            }))
+            .send()
+            .await
+            .unwrap();
+        let status = resp.status();
+        let body = resp.text().await.unwrap();
+
+        // TODO - it would be nice if this was a 4xx error, because this is really
+        //   a user error; however, the underlying error that occurs when Logical
+        //   planning is DatafusionError::Internal, and is not so convenient to deal
+        //   with. This may get addressed in:
+        //
+        //   https://github.com/apache/arrow-datafusion/issues/9738
+        assert!(status.is_server_error());
+        assert_contains!(body, "No value found for placeholder with name $host");
+    }
+}
+
 #[tokio::test]
 async fn api_v3_query_influxql() {
     let server = TestServer::spawn().await;
@@ -270,6 +390,129 @@ async fn api_v3_query_influxql() {
     }
 }
 
+#[tokio::test]
+async fn api_v3_query_influxql_params() {
+    let server = TestServer::spawn().await;
+
+    server
+        .write_lp_to_db(
+            "foo",
+            "cpu,host=a,region=us-east usage=0.9 1
+            cpu,host=b,region=us-east usage=0.50 1
+            cpu,host=a,region=us-east usage=0.80 2
+            cpu,host=b,region=us-east usage=0.60 2
+            cpu,host=a,region=us-east usage=0.70 3
+            cpu,host=b,region=us-east usage=0.70 3
+            cpu,host=a,region=us-east usage=0.50 4
+            cpu,host=b,region=us-east usage=0.80 4",
+            Precision::Second,
+        )
+        .await
+        .unwrap();
+
+    let client = reqwest::Client::new();
+    let url = format!("{base}/api/v3/query_influxql", base = server.client_addr());
+
+    // Use a POST request
+    {
+        let resp = client
+            .post(&url)
+            .json(&json!({
+                "db": "foo",
+                "q": "SELECT * FROM cpu WHERE host = $host AND usage > $usage",
+                "params": {
+                    "host": "b",
+                    "usage": 0.60,
+                },
+                "format": "pretty",
+            }))
+            .send()
+            .await
+            .unwrap()
+            .text()
+            .await
+            .unwrap();
+
+        assert_eq!(
+            "+------------------+---------------------+------+---------+-------+\n\
+            | iox::measurement | time                | host | region  | usage |\n\
+            +------------------+---------------------+------+---------+-------+\n\
+            | cpu              | 1970-01-01T00:00:03 | b    | us-east | 0.7   |\n\
+            | cpu              | 1970-01-01T00:00:04 | b    | us-east | 0.8   |\n\
+            +------------------+---------------------+------+---------+-------+",
+            resp
+        );
+    }
+
+    // Use a GET request
+    {
+        let params = serde_json::to_string(&json!({
+            "host": "b",
+            "usage": 0.60,
+        }))
+        .unwrap();
+        let resp = client
+            .get(&url)
+            .query(&[
+                ("db", "foo"),
+                (
+                    "q",
+                    "SELECT * FROM cpu WHERE host = $host AND usage > $usage",
+                ),
+                ("format", "pretty"),
+                ("params", params.as_str()),
+            ])
+            .send()
+            .await
+            .unwrap()
+            .text()
+            .await
+            .unwrap();
+
+        assert_eq!(
+            "+------------------+---------------------+------+---------+-------+\n\
+            | iox::measurement | time                | host | region  | usage |\n\
+            +------------------+---------------------+------+---------+-------+\n\
+            | cpu              | 1970-01-01T00:00:03 | b    | us-east | 0.7   |\n\
+            | cpu              | 1970-01-01T00:00:04 | b    | us-east | 0.8   |\n\
+            +------------------+---------------------+------+---------+-------+",
+            resp
+        );
+    }
+
+    // Check for errors
+    {
+        let resp = client
+            .post(&url)
+            .json(&json!({
+                "db": "foo",
+                "q": "SELECT * FROM cpu WHERE host = $host",
+                "params": {
+                    "not_host": "a"
+                },
+                "format": "pretty",
+            }))
+            .send()
+            .await
+            .unwrap();
+        let status = resp.status();
+        let body = resp.text().await.unwrap();
+
+        // TODO - it would be nice if this was a 4xx error, because this is really
+        //   a user error; however, the underlying error that occurs when Logical
+        //   planning is DatafusionError::Internal, and is not so convenient to deal
+        //   with. This may get addressed in:
+        //
+        //   https://github.com/apache/arrow-datafusion/issues/9738
+        assert!(status.is_server_error());
+        assert_contains!(
+            body,
+            "Bind parameter '$host' was referenced in the InfluxQL \
+            statement but its value is undefined"
+        );
+    }
+}
+
 #[tokio::test]
 async fn api_v1_query() {
     let server = TestServer::spawn().await;
diff --git a/influxdb3_server/src/http.rs b/influxdb3_server/src/http.rs
index 055920f98d..c88cfc809f 100644
--- a/influxdb3_server/src/http.rs
+++ b/influxdb3_server/src/http.rs
@@ -17,6 +17,7 @@ use hyper::header::AUTHORIZATION;
 use hyper::header::CONTENT_ENCODING;
 use hyper::header::CONTENT_TYPE;
 use hyper::http::HeaderValue;
+use hyper::HeaderMap;
 use hyper::{Body, Method, Request, Response, StatusCode};
 use influxdb3_write::catalog::Error as CatalogError;
 use influxdb3_write::persister::TrackedMemoryArrowWriter;
@@ -25,6 +26,7 @@ use influxdb3_write::BufferedWriteRequest;
 use influxdb3_write::Precision;
 use influxdb3_write::WriteBuffer;
 use iox_query_influxql_rewrite as rewrite;
+use iox_query_params::StatementParams;
 use iox_time::TimeProvider;
 use observability_deps::tracing::{debug, error, info};
 use serde::de::DeserializeOwned;
@@ -93,6 +95,10 @@ pub enum Error {
     #[error("access denied")]
     Forbidden,
 
+    /// The HTTP request method is not supported for this resource
+    #[error("unsupported method")]
+    UnsupportedMethod,
+
     /// PProf support is not compiled
     #[error("pprof support is not compiled")]
     PProfIsNotCompiled,
@@ -265,6 +271,18 @@ impl Error {
                     .body(body)
                     .unwrap()
             }
+            Self::UnsupportedMethod => {
+                let err: ErrorMessage<()> = ErrorMessage {
+                    error: self.to_string(),
+                    data: None,
+                };
+                let serialized = serde_json::to_string(&err).unwrap();
+                let body = Body::from(serialized);
+                Response::builder()
+                    .status(StatusCode::METHOD_NOT_ALLOWED)
+                    .body(body)
+                    .unwrap()
+            }
             _ => {
                 let body = Body::from(self.to_string());
                 Response::builder()
@@ -347,17 +365,18 @@ where
     }
 
     async fn query_sql(&self, req: Request<Body>) -> Result<Response<Body>> {
-        let QueryParams {
+        let QueryRequest {
             database,
             query_str,
             format,
-        } = QueryParams::<String, _>::from_request(&req)?;
+            params,
+        } = self.extract_query_request::<String>(req).await?;
 
         info!(%database, %query_str, ?format, "handling query_sql");
 
         let stream = self
             .query_executor
-            .query(&database, &query_str, QueryKind::Sql, None, None)
+            .query(&database, &query_str, params, QueryKind::Sql, None, None)
             .await?;
 
         Response::builder()
@@ -368,15 +387,18 @@ where
     }
 
     async fn query_influxql(&self, req: Request<Body>) -> Result<Response<Body>> {
-        let QueryParams {
+        let QueryRequest {
             database,
             query_str,
             format,
-        } = QueryParams::<Option<String>, _>::from_request(&req)?;
+            params,
+        } = self.extract_query_request::<Option<String>>(req).await?;
 
         info!(?database, %query_str, ?format, "handling query_influxql");
 
-        let stream = self.query_influxql_inner(database, &query_str).await?;
+        let stream = self
+            .query_influxql_inner(database, &query_str, params)
+            .await?;
 
         Response::builder()
             .status(StatusCode::OK)
@@ -477,6 +499,39 @@ where
         Ok(())
     }
 
+    async fn extract_query_request<D: DeserializeOwned>(
+        &self,
+        req: Request<Body>,
+    ) -> Result<QueryRequest<D, QueryFormat, StatementParams>> {
+        let header_format = QueryFormat::try_from_headers(req.headers())?;
+        let request = match *req.method() {
+            Method::GET => {
+                let query = req.uri().query().ok_or(Error::MissingQueryParams)?;
+                let r = serde_urlencoded::from_str::<QueryRequest<D, Option<QueryFormat>, String>>(
+                    query,
+                )?;
+                QueryRequest {
+                    database: r.database,
+                    query_str: r.query_str,
+                    format: r.format,
+                    params: r.params.map(|s| serde_json::from_str(&s)).transpose()?,
+                }
+            }
+            Method::POST => {
+                let body = self.read_body(req).await?;
+                serde_json::from_slice(body.as_ref())?
+            }
+            _ => return Err(Error::UnsupportedMethod),
+        };
+
+        Ok(QueryRequest {
+            database: request.database,
+            query_str: request.query_str,
+            format: request.format.unwrap_or(header_format),
+            params: request.params,
+        })
+    }
+
     /// Inner function for performing InfluxQL queries
     ///
     /// This is used by both the `/api/v3/query_influxql` and `/api/v1/query`
@@ -485,6 +540,7 @@ where
         &self,
         database: Option<String>,
         query_str: &str,
+        params: Option<StatementParams>,
     ) -> Result<SendableRecordBatchStream> {
         let mut statements = rewrite::parse_statements(query_str)?;
 
@@ -525,6 +581,7 @@ where
                     // TODO - implement an interface that takes the statement directly,
                     // so we don't need to double down on the parsing
                     &statement.to_statement().to_string(),
+                    params,
                     QueryKind::InfluxQl,
                     None,
                     None,
@@ -617,45 +674,13 @@ fn validate_db_name(name: &str) -> Result<()> {
 }
 
 #[derive(Debug, Deserialize)]
-pub(crate) struct QueryParams<D, F> {
+pub(crate) struct QueryRequest<D, F, P> {
     #[serde(rename = "db")]
     pub(crate) database: D,
     #[serde(rename = "q")]
     pub(crate) query_str: String,
     pub(crate) format: F,
-}
-
-impl<D> QueryParams<D, QueryFormat>
-where
-    D: DeserializeOwned,
-{
-    fn from_request(req: &Request<Body>) -> Result<Self> {
-        let query = req.uri().query().ok_or(Error::MissingQueryParams)?;
-        let params = serde_urlencoded::from_str::<QueryParams<D, Option<QueryFormat>>>(query)?;
-        let format = match params.format {
-            None => match req.headers().get(ACCEPT).map(HeaderValue::as_bytes) {
-                // Accept Headers use the MIME types maintained by IANA here:
-                // https://www.iana.org/assignments/media-types/media-types.xhtml
-                // Note parquet hasn't been accepted yet just Arrow, but there
-                // is the possibility it will be:
-                // https://issues.apache.org/jira/browse/PARQUET-1889
-                Some(b"application/vnd.apache.parquet") => QueryFormat::Parquet,
-                Some(b"text/csv") => QueryFormat::Csv,
-                Some(b"text/plain") => QueryFormat::Pretty,
-                Some(b"application/json" | b"*/*") | None => QueryFormat::Json,
-                Some(mime_type) => match String::from_utf8(mime_type.to_vec()) {
-                    Ok(s) => return Err(QueryParamsError::InvalidMimeType(s).into()),
-                    Err(e) => return Err(QueryParamsError::NonUtf8MimeType(e).into()),
-                },
-            },
-            Some(f) => f,
-        };
-        Ok(Self {
-            database: params.database,
-            query_str: params.query_str,
-            format,
-        })
-    }
+    pub(crate) params: Option<P>,
 }
 
 #[derive(Debug, thiserror::Error)]
@@ -687,6 +712,24 @@ impl QueryFormat {
             Self::Json => "application/json",
         }
     }
+
+    fn try_from_headers(headers: &HeaderMap) -> Result<Self> {
+        match headers.get(ACCEPT).map(HeaderValue::as_bytes) {
+            // Accept Headers use the MIME types maintained by IANA here:
+            // https://www.iana.org/assignments/media-types/media-types.xhtml
+            // Note parquet hasn't been accepted yet just Arrow, but there
+            // is the possibility it will be:
+            // https://issues.apache.org/jira/browse/PARQUET-1889
+            Some(b"application/vnd.apache.parquet") => Ok(Self::Parquet),
+            Some(b"text/csv") => Ok(Self::Csv),
+            Some(b"text/plain") => Ok(Self::Pretty),
+            Some(b"application/json" | b"*/*") | None => Ok(Self::Json),
+            Some(mime_type) => match String::from_utf8(mime_type.to_vec()) {
+                Ok(s) => Err(QueryParamsError::InvalidMimeType(s).into()),
+                Err(e) => Err(QueryParamsError::NonUtf8MimeType(e).into()),
+            },
+        }
+    }
 }
 
 async fn record_batch_stream_to_body(
diff --git a/influxdb3_server/src/http/v1.rs b/influxdb3_server/src/http/v1.rs
index 13110b5870..c6c13d69a6 100644
--- a/influxdb3_server/src/http/v1.rs
+++ b/influxdb3_server/src/http/v1.rs
@@ -58,7 +58,9 @@ where
 
         let chunk_size = chunked.then(|| chunk_size.unwrap_or(DEFAULT_CHUNK_SIZE));
 
-        let stream = self.query_influxql_inner(database, &query).await?;
+        // TODO - Currently not supporting parameterized queries, see
+        //        https://github.com/influxdata/influxdb/issues/24805
+        let stream = self.query_influxql_inner(database, &query, None).await?;
         let stream =
             QueryResponseStream::new(0, stream, chunk_size, pretty, epoch).map_err(QueryError)?;
         let body = Body::wrap_stream(stream);
diff --git a/influxdb3_server/src/lib.rs b/influxdb3_server/src/lib.rs
index 9035b6807a..8cfeec9dca 100644
--- a/influxdb3_server/src/lib.rs
+++ b/influxdb3_server/src/lib.rs
@@ -27,6 +27,7 @@ use datafusion::execution::SendableRecordBatchStream;
 use hyper::service::service_fn;
 use influxdb3_write::{Persister, WriteBuffer};
 use iox_query::QueryNamespaceProvider;
+use iox_query_params::StatementParams;
 use iox_time::TimeProvider;
 use observability_deps::tracing::{error, info};
 use service::hybrid;
@@ -129,6 +130,7 @@ pub trait QueryExecutor: QueryNamespaceProvider + Debug + Send + Sync + 'static
         &self,
         database: &str,
         q: &str,
+        params: Option<StatementParams>,
         kind: QueryKind,
         span_ctx: Option<SpanContext>,
         external_span_ctx: Option<RequestLogContext>,
diff --git a/influxdb3_server/src/query_executor.rs b/influxdb3_server/src/query_executor.rs
index 68f113a5c8..d5a762d727 100644
--- a/influxdb3_server/src/query_executor.rs
+++ b/influxdb3_server/src/query_executor.rs
@@ -90,6 +90,7 @@ impl<W: WriteBuffer> QueryExecutor for QueryExecutorImpl<W> {
         &self,
         database: &str,
         q: &str,
+        params: Option<StatementParams>,
         kind: QueryKind,
         span_ctx: Option<SpanContext>,
         external_span_ctx: Option<RequestLogContext>,
@@ -117,9 +118,7 @@ impl<W: WriteBuffer> QueryExecutor for QueryExecutorImpl<W> {
         );
 
         info!("plan");
-        // TODO: Figure out if we want to support parameter values in SQL
-        // queries
-        let params = StatementParams::default();
+        let params = params.unwrap_or_default();
         let plan = match kind {
             QueryKind::Sql => {
                 let planner = SqlQueryPlanner::new();