diff --git a/flightsql/src/planner.rs b/flightsql/src/planner.rs index e615eae9a6..0bde2ed696 100644 --- a/flightsql/src/planner.rs +++ b/flightsql/src/planner.rs @@ -4,7 +4,11 @@ use std::{string::FromUtf8Error, sync::Arc}; use arrow::{error::ArrowError, ipc::writer::IpcWriteOptions}; use arrow_flight::{ error::FlightError, - sql::{Any, CommandPreparedStatementQuery, CommandStatementQuery, ProstMessageExt}, + sql::{ + ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest, + ActionCreatePreparedStatementResult, Any, CommandPreparedStatementQuery, + CommandStatementQuery, + }, IpcMessage, SchemaAsIpc, }; use bytes::Bytes; @@ -23,8 +27,8 @@ pub enum Error { source: prost::DecodeError, }, - #[snafu(display("Query was not valid UTF-8: {}", source))] - InvalidUtf8 { source: FromUtf8Error }, + #[snafu(display("Invalid PreparedStatement handle (invalid UTF-8:) {}", source))] + InvalidHandle { source: FromUtf8Error }, #[snafu(display("{}", source))] Flight { source: FlightError }, @@ -37,13 +41,22 @@ pub enum Error { #[snafu(display("Unsupported FlightSQL message type: {}", description))] UnsupportedMessageType { description: String }, + + #[snafu(display("Protocol error. Method {} does not expect '{:?}'", method, cmd))] + Protocol { cmd: String, method: &'static str }, } pub type Result<T, E = Error> = std::result::Result<T, E>; impl From<FlightError> for Error { - fn from(value: FlightError) -> Self { - Self::Flight { source: value } + fn from(source: FlightError) -> Self { + Self::Flight { source } + } +} + +impl From<ArrowError> for Error { + fn from(source: ArrowError) -> Self { + Self::Arrow { source } } } @@ -75,11 +88,19 @@ impl FlightSQLPlanner { let namespace_name = namespace_name.into(); debug!(%namespace_name, type_url=%msg.type_url, "Handling flightsql get_flight_info"); - match FlightSQLCommand::try_new(&msg)? 
{ - FlightSQLCommand::CommandStatementQuery(query) - | FlightSQLCommand::CommandPreparedStatementQuery(query) => { + let cmd = FlightSQLCommand::try_new(&msg)?; + match cmd { + FlightSQLCommand::CommandStatementQuery(query) => { Self::get_schema_for_query(&query, ctx).await } + FlightSQLCommand::CommandPreparedStatementQuery(handle) => { + Self::get_schema_for_query(&handle.query, ctx).await + } + _ => ProtocolSnafu { + cmd: format!("{cmd:?}"), + method: "GetFlightInfo", + } + .fail(), } } @@ -108,43 +129,144 @@ impl FlightSQLPlanner { ctx: &IOxSessionContext, ) -> Result<Arc<dyn ExecutionPlan>> { let namespace_name = namespace_name.into(); - debug!(%namespace_name, type_url=%msg.type_url, "Handling flightsql plan to run an actual query"); + debug!(%namespace_name, type_url=%msg.type_url, "Handling flightsql do_get"); - match FlightSQLCommand::try_new(&msg)? { + let cmd = FlightSQLCommand::try_new(&msg)?; + match cmd { FlightSQLCommand::CommandStatementQuery(query) => { debug!(%query, "Planning FlightSQL query"); ctx.prepare_sql(&query).await.context(DataFusionSnafu) } - FlightSQLCommand::CommandPreparedStatementQuery(query) => { + FlightSQLCommand::CommandPreparedStatementQuery(handle) => { + let query = &handle.query; debug!(%query, "Planning FlightSQL prepared query"); - ctx.prepare_sql(&query).await.context(DataFusionSnafu) + ctx.prepare_sql(query).await.context(DataFusionSnafu) } + _ => ProtocolSnafu { + cmd: format!("{cmd:?}"), + method: "DoGet", + } + .fail(), + } + } + + /// Handles the action specified in `msg` and returns bytes for + /// the [`arrow_flight::Result`] (not the same as a rust + /// [`Result`]!) 
+ pub async fn do_action( + namespace_name: impl Into<String>, + _database: Arc<dyn QueryNamespace>, + msg: Any, + ctx: &IOxSessionContext, + ) -> Result<Bytes> { + let namespace_name = namespace_name.into(); + debug!(%namespace_name, type_url=%msg.type_url, "Handling flightsql do_action"); + + let cmd = FlightSQLCommand::try_new(&msg)?; + match cmd { + FlightSQLCommand::ActionCreatePreparedStatementRequest(query) => { + debug!(%query, "Creating prepared statement"); + + // todo run the planner here and actually figure out parameter schemas + // see https://github.com/apache/arrow-datafusion/pull/4701 + let parameter_schema = vec![]; + + let dataset_schema = Self::get_schema_for_query(&query, ctx).await?; + let handle = PreparedStatementHandle::new(query); + + let result = ActionCreatePreparedStatementResult { + prepared_statement_handle: Bytes::from(handle), + dataset_schema, + parameter_schema: Bytes::from(parameter_schema), + }; + + let msg = Any::pack(&result)?; + Ok(msg.encode_to_vec().into()) + } + FlightSQLCommand::ActionClosePreparedStatementRequest(handle) => { + let query = &handle.query; + debug!(%query, "Closing prepared statement"); + + // Nothing really to do + Ok(Bytes::new()) + } + _ => ProtocolSnafu { + cmd: format!("{cmd:?}"), + method: "DoAction", + } + .fail(), } } } -/// Decoded and validated FlightSQL command +/// Represents a prepared statement "handle". 
IOx passes all state +/// required to run the prepared statement back and forth to the +/// client so any querier instance can run it +#[derive(Debug, Clone)] +struct PreparedStatementHandle { + /// The raw SQL query text + query: String, +} + +impl PreparedStatementHandle { + fn new(query: String) -> Self { + Self { query } + } +} + +/// Decode bytes to a PreparedStatementHandle +impl TryFrom<Bytes> for PreparedStatementHandle { + type Error = Error; + + fn try_from(handle: Bytes) -> Result<Self, Self::Error> { + // Note: in IOx handles are the entire decoded query + let query = String::from_utf8(handle.to_vec()).context(InvalidHandleSnafu)?; + Ok(Self { query }) + } +} + +/// Encode a PreparedStatementHandle as Bytes +impl From<PreparedStatementHandle> for Bytes { + fn from(value: PreparedStatementHandle) -> Self { + Bytes::from(value.query.into_bytes()) + } +} + +/// Decoded / validated FlightSQL command messages #[derive(Debug, Clone)] enum FlightSQLCommand { CommandStatementQuery(String), - CommandPreparedStatementQuery(String), + /// Run a prepared statement + CommandPreparedStatementQuery(PreparedStatementHandle), + /// Create a prepared statement + ActionCreatePreparedStatementRequest(String), + /// Close a prepared statement + ActionClosePreparedStatementRequest(PreparedStatementHandle), } impl FlightSQLCommand { - /// Figure out and decode the specific FlightSQL command in `msg` + /// Figure out and decode the specific FlightSQL command in `msg` and decode it to a native IOx / Rust struct fn try_new(msg: &Any) -> Result<Self> { - if let Some(decoded_cmd) = try_unpack::<CommandStatementQuery>(msg)? { + if let Some(decoded_cmd) = Any::unpack::<CommandStatementQuery>(msg)? { let CommandStatementQuery { query } = decoded_cmd; Ok(Self::CommandStatementQuery(query)) - } else if let Some(decoded_cmd) = try_unpack::<CommandPreparedStatementQuery>(msg)? { + } else if let Some(decoded_cmd) = Any::unpack::<CommandPreparedStatementQuery>(msg)? 
{ let CommandPreparedStatementQuery { prepared_statement_handle, } = decoded_cmd; - // handle should be a decoded query - let query = - String::from_utf8(prepared_statement_handle.to_vec()).context(InvalidUtf8Snafu)?; - Ok(Self::CommandPreparedStatementQuery(query)) + let handle = PreparedStatementHandle::try_from(prepared_statement_handle)?; + Ok(Self::CommandPreparedStatementQuery(handle)) + } else if let Some(decoded_cmd) = Any::unpack::<ActionCreatePreparedStatementRequest>(msg)? + { + let ActionCreatePreparedStatementRequest { query } = decoded_cmd; + Ok(Self::ActionCreatePreparedStatementRequest(query)) + } else if let Some(decoded_cmd) = Any::unpack::<ActionClosePreparedStatementRequest>(msg)? { + let ActionClosePreparedStatementRequest { + prepared_statement_handle, + } = decoded_cmd; + let handle = PreparedStatementHandle::try_from(prepared_statement_handle)?; + Ok(Self::ActionClosePreparedStatementRequest(handle)) } else { UnsupportedMessageTypeSnafu { description: &msg.type_url, @@ -153,17 +275,3 @@ impl FlightSQLCommand { } } } - -/// try to unpack the [`arrow_flight::sql::Any`] as type `T`, returning Ok(None) if -/// the type is wrong or Err if an error occurs -fn try_unpack<T: ProstMessageExt>(msg: &Any) -> Result<Option<T>> { - // Does the type URL match? 
- if T::type_url() != msg.type_url { - return Ok(None); - } - // type matched, so try and decode - let m = Message::decode(&*msg.value).context(DeserializationTypeKnownSnafu { - type_url: &msg.type_url, - })?; - Ok(Some(m)) -} diff --git a/influxdb_iox/tests/end_to_end_cases/cli.rs b/influxdb_iox/tests/end_to_end_cases/cli.rs index 2956ecea07..4492ed4dbe 100644 --- a/influxdb_iox/tests/end_to_end_cases/cli.rs +++ b/influxdb_iox/tests/end_to_end_cases/cli.rs @@ -887,7 +887,7 @@ async fn query_ingester() { // something like "wrong query protocol" or // "invalid message" as the querier requires a // different message format Ticket in the flight protocol - let expected = "Unknown namespace: "; + let expected = "Namespace '' not found"; // Validate that the error message contains a reasonable error Command::cargo_bin("influxdb_iox") diff --git a/influxdb_iox/tests/end_to_end_cases/flightsql.rs b/influxdb_iox/tests/end_to_end_cases/flightsql.rs index a169a967b0..35ab9d3f45 100644 --- a/influxdb_iox/tests/end_to_end_cases/flightsql.rs +++ b/influxdb_iox/tests/end_to_end_cases/flightsql.rs @@ -1,10 +1,11 @@ use arrow_util::assert_batches_sorted_eq; +use datafusion::common::assert_contains; use futures::{FutureExt, TryStreamExt}; use influxdb_iox_client::flightsql::FlightSqlClient; use test_helpers_end_to_end::{maybe_skip_integration, MiniCluster, Step, StepTest, StepTestState}; #[tokio::test] -async fn flightsql_query() { +async fn flightsql_adhoc_query() { test_helpers::maybe_start_logging(); let database_url = maybe_skip_integration!(); @@ -62,6 +63,109 @@ async fn flightsql_query() { .await } -// TODO other tests: -// 1. Errors -// 2. 
Prepared statements +#[tokio::test] +async fn flightsql_adhoc_query_error() { + test_helpers::maybe_start_logging(); + let database_url = maybe_skip_integration!(); + + // Set up the cluster ==================================== + let mut cluster = MiniCluster::create_shared(database_url).await; + + StepTest::new( + &mut cluster, + vec![ + Step::WriteLineProtocol( + "foo,tag1=A,tag2=B val=42i 123456\n\ + foo,tag1=A,tag2=C val=43i 123457" + .to_string(), + ), + Step::WaitForReadable, + Step::Custom(Box::new(move |state: &mut StepTestState| { + async move { + let sql = String::from("select * from incorrect_table"); + + let connection = state.cluster().querier().querier_grpc_connection(); + let (channel, _headers) = connection.into_grpc_connection().into_parts(); + + let mut client = FlightSqlClient::new(channel); + + // Add namespace to client headers until it is fully supported by FlightSQL + let namespace = state.cluster().namespace(); + client.add_header("iox-namespace-name", namespace).unwrap(); + + let err = client.query(sql).await.unwrap_err(); + + // namespaces are created on write + assert_contains!( + err.to_string(), + "table 'public.iox.incorrect_table' not found" + ); + } + .boxed() + })), + ], + ) + .run() + .await +} + +#[tokio::test] +async fn flightsql_prepared_query() { + test_helpers::maybe_start_logging(); + let database_url = maybe_skip_integration!(); + + let table_name = "the_table"; + + // Set up the cluster ==================================== + let mut cluster = MiniCluster::create_shared(database_url).await; + + StepTest::new( + &mut cluster, + vec![ + Step::WriteLineProtocol(format!( + "{},tag1=A,tag2=B val=42i 123456\n\ + {},tag1=A,tag2=C val=43i 123457", + table_name, table_name + )), + Step::WaitForReadable, + Step::AssertNotPersisted, + Step::Custom(Box::new(move |state: &mut StepTestState| { + async move { + let sql = format!("select * from {}", table_name); + let expected = vec![ + 
"+------+------+--------------------------------+-----+", + "| tag1 | tag2 | time | val |", + "+------+------+--------------------------------+-----+", + "| A | B | 1970-01-01T00:00:00.000123456Z | 42 |", + "| A | C | 1970-01-01T00:00:00.000123457Z | 43 |", + "+------+------+--------------------------------+-----+", + ]; + + let connection = state.cluster().querier().querier_grpc_connection(); + let (channel, _headers) = connection.into_grpc_connection().into_parts(); + + let mut client = FlightSqlClient::new(channel); + + // Add namespace to client headers until it is fully supported by FlightSQL + let namespace = state.cluster().namespace(); + client.add_header("iox-namespace-name", namespace).unwrap(); + + let handle = client.prepare(sql).await.unwrap(); + + let batches: Vec<_> = client + .execute(handle) + .await + .expect("ran SQL query") + .try_collect() + .await + .expect("got batches"); + + assert_batches_sorted_eq!(&expected, &batches); + } + .boxed() + })), + ], + ) + .run() + .await +} diff --git a/influxdb_iox_client/src/client/flightsql.rs b/influxdb_iox_client/src/client/flightsql.rs index 108c667a7c..8136357135 100644 --- a/influxdb_iox_client/src/client/flightsql.rs +++ b/influxdb_iox_client/src/client/flightsql.rs @@ -21,12 +21,20 @@ // specific language governing permissions and limitations // under the License. 
+use std::sync::Arc; + +use arrow::datatypes::{Schema, SchemaRef}; use arrow_flight::{ decode::FlightRecordBatchStream, error::{FlightError, Result}, - sql::{CommandStatementQuery, ProstMessageExt}, - FlightClient, FlightDescriptor, FlightInfo, Ticket, + sql::{ + ActionCreatePreparedStatementRequest, ActionCreatePreparedStatementResult, Any, + CommandPreparedStatementQuery, CommandStatementQuery, ProstMessageExt, + }, + Action, FlightClient, FlightDescriptor, FlightInfo, IpcMessage, Ticket, }; +use bytes::Bytes; +use futures_util::TryStreamExt; use prost::Message; use tonic::metadata::MetadataMap; use tonic::transport::Channel; @@ -92,15 +100,15 @@ impl FlightSqlClient { } /// Send `cmd`, encoded as protobuf, to the FlightSQL server - async fn get_flight_info_for_command<M: ProstMessageExt>( + async fn get_flight_info_for_command( &mut self, - cmd: M, + cmd: arrow_flight::sql::Any, ) -> Result<FlightInfo> { - let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec()); + let descriptor = FlightDescriptor::new_cmd(cmd.encode_to_vec()); self.inner.get_flight_info(descriptor).await } - /// Execute a SQL query on the server using `CommandStatementQuery. + /// Execute a SQL query on the server using [`CommandStatementQuery`] /// /// This involves two round trips /// @@ -112,7 +120,14 @@ impl FlightSqlClient { /// /// This implementation does not support alternate endpoints pub async fn query(&mut self, query: String) -> Result<FlightRecordBatchStream> { - let cmd = CommandStatementQuery { query }; + let msg = CommandStatementQuery { query }; + self.do_get_with_cmd(msg.as_any()).await + } + + async fn do_get_with_cmd( + &mut self, + cmd: arrow_flight::sql::Any, + ) -> Result<FlightRecordBatchStream> { let FlightInfo { schema: _, flight_descriptor: _, @@ -156,4 +171,130 @@ impl FlightSqlClient { self.inner.do_get(Ticket { ticket }).await } + + /// Create a prepared statement for execution. 
+ /// + /// Sends an [`ActionCreatePreparedStatementRequest`] message to + /// the `DoAction` endpoint of the FlightSQL server, and returns + /// the handle from the server. + /// + /// See [`Self::execute`] to run a previously prepared statement + pub async fn prepare(&mut self, query: String) -> Result<PreparedStatement> { + let cmd = ActionCreatePreparedStatementRequest { query }; + + let request = Action { + r#type: "CreatePreparedStatement".into(), + body: cmd.as_any().encode_to_vec().into(), + }; + + let mut results: Vec<Bytes> = self.inner.do_action(request).await?.try_collect().await?; + + if results.len() != 1 { + return Err(FlightError::ProtocolError(format!( + "Expected 1 response for preparing a statement, got {}", + results.len() + ))); + } + let result = results.pop().unwrap(); + + // decode the response + let response: arrow_flight::sql::Any = Message::decode(result.as_ref()) + .map_err(|e| FlightError::ExternalError(Box::new(e)))?; + + let ActionCreatePreparedStatementResult { + prepared_statement_handle, + dataset_schema, + parameter_schema, + } = Any::unpack(&response)?.ok_or_else(|| { + FlightError::ProtocolError(format!( + "Expected ActionCreatePreparedStatementResult message but got {} instead", + response.type_url + )) + })?; + + Ok(PreparedStatement::new( + prepared_statement_handle, + schema_bytes_to_schema(dataset_schema)?, + schema_bytes_to_schema(parameter_schema)?, + )) + } + + /// Execute a prepared statement using [`CommandPreparedStatementQuery`] + /// + /// This involves two round trips + /// + /// Step 1: send a [`CommandPreparedStatementQuery`] message to the + /// `GetFlightInfo` endpoint of the FlightSQL server to receive a + /// FlightInfo descriptor.
+ /// + /// Step 2: Fetch the results described in the [`FlightInfo`] + /// + /// This implementation does not support alternate endpoints + pub async fn execute( + &mut self, + statement: PreparedStatement, + ) -> Result<FlightRecordBatchStream> { + let PreparedStatement { + prepared_statement_handle, + dataset_schema: _, + parameter_schema: _, + } = statement; + // TODO handle parameters (via DoPut) + + let cmd = CommandPreparedStatementQuery { + prepared_statement_handle, + }; + + self.do_get_with_cmd(cmd.as_any()).await + } +} + +fn schema_bytes_to_schema(schema: Bytes) -> Result<SchemaRef> { + let schema = if schema.is_empty() { + Schema::empty() + } else { + Schema::try_from(IpcMessage(schema))? + }; + + Ok(Arc::new(schema)) +} + +/// Represents a "prepared statement handle" on the server +#[derive(Debug, Clone)] +pub struct PreparedStatement { + /// The handle returned from the server + prepared_statement_handle: Bytes, + + /// Schema for the result of the query + dataset_schema: SchemaRef, + + /// Schema of parameters, if any + parameter_schema: SchemaRef, +} + +impl PreparedStatement { + /// Create a `PreparedStatement` from the handle returned by the + /// server, the schema for the result of the query, and the + /// schema of parameters, if any + fn new( + prepared_statement_handle: Bytes, + dataset_schema: SchemaRef, + parameter_schema: SchemaRef, + ) -> Self { + Self { + prepared_statement_handle, + dataset_schema, + parameter_schema, + } + } + + /// Return the schema of the query + pub fn get_dataset_schema(&self) -> SchemaRef { + Arc::clone(&self.dataset_schema) + } + + /// Return the schema needed for the parameters + pub fn get_parameter_schema(&self) -> SchemaRef { + Arc::clone(&self.parameter_schema) + } } diff --git a/service_common/src/planner.rs b/service_common/src/planner.rs index cc22c4b87d..013f215699 100644 --- a/service_common/src/planner.rs +++ b/service_common/src/planner.rs @@ -87,6 +87,30 @@ impl Planner { .await } + /// Creates the response for a `DoAction` FlightSQL message, + ///
as described on [`FlightSQLPlanner::do_action`], on a + /// separate threadpool + pub async fn flight_sql_do_action<N>( + &self, + namespace_name: impl Into<String>, + namespace: Arc<N>, + msg: Any, + ) -> Result<Bytes> + where + N: QueryNamespace + 'static, + { + let namespace_name = namespace_name.into(); + let ctx = self.ctx.child_ctx("planner flight_sql_do_get"); + + self.ctx + .run(async move { + FlightSQLPlanner::do_action(namespace_name, namespace, msg, &ctx) + .await + .map_err(DataFusionError::from) + }) + .await + } + /// Creates the response for a `GetFlightInfo` FlightSQL message /// as described on [`FlightSQLPlanner::get_flight_info`], on a /// separate threadpool. diff --git a/service_grpc_flight/src/lib.rs b/service_grpc_flight/src/lib.rs index 7a9a7acca4..9c164dd56d 100644 --- a/service_grpc_flight/src/lib.rs +++ b/service_grpc_flight/src/lib.rs @@ -31,7 +31,7 @@ use request::{IoxGetRequest, RunQuery}; use service_common::{datafusion_error_to_tonic_code, planner::Planner, QueryNamespaceProvider}; use snafu::{OptionExt, ResultExt, Snafu}; use std::{fmt::Debug, pin::Pin, sync::Arc, task::Poll, time::Instant}; -use tonic::{Request, Response, Streaming}; +use tonic::{metadata::MetadataMap, Request, Response, Streaming}; use trace::{ctx::SpanContext, span::SpanExt}; use trace_http::ctx::{RequestLogContext, RequestLogContextExt}; use tracker::InstrumentedAsyncOwnedSemaphorePermit; @@ -43,6 +43,14 @@ use tracker::InstrumentedAsyncOwnedSemaphorePermit; /// for discussion on adding support to FlightSQL itself. const IOX_FLIGHT_SQL_NAMESPACE_HEADER: &str = "iox-namespace-name"; +/// Environment variable to take the default FlightSQL namespace name from +/// TODO move this to a proper CLI / Config argument +/// so it is more discoverable / documented +/// +/// Any value set in this environment variable will be overridden +/// per-request by the `iox-namespace-name` header.
+const IOX_FLIGHT_SQL_NAMESPACE_ENV_NAME: &str = "INFLUXDB_IOX_DEFAULT_FLIGHT_SQL_NAMESPACE"; + #[allow(clippy::enum_variant_names)] #[derive(Debug, Snafu)] pub enum Error { @@ -55,7 +63,7 @@ pub enum Error { #[snafu(display("Invalid handshake. No payload provided"))] InvalidHandshake {}, - #[snafu(display("Namespace {} not found", namespace_name))] + #[snafu(display("Namespace '{}' not found", namespace_name))] NamespaceNotFound { namespace_name: String }, #[snafu(display( @@ -68,8 +76,10 @@ pub enum Error { source: DataFusionError, }, - #[snafu(display("no 'iox-namespace-name' header in request"))] - NoNamespaceHeader, + #[snafu(display( + "no default flightsql namespace set and no 'iox-namespace-name' header in request" + ))] + NoFlightSqlNamespace, #[snafu(display("Invalid 'iox-namespace-name' header in request: {}", source))] InvalidNamespaceHeader { @@ -113,7 +123,7 @@ impl From<Error> for tonic::Status { | Error::InvalidNamespaceName { .. } => info!(e=%err, msg), Error::Query { .. } => info!(e=%err, msg), Error::Optimize { .. } - |Error::NoNamespaceHeader + |Error::NoFlightSqlNamespace |Error::InvalidNamespaceHeader { .. } | Error::Planning { .. } | Error::Deserialization { .. } @@ -139,7 +149,7 @@ impl Error { Self::InvalidTicket { .. } | Self::InvalidHandshake { .. } | Self::Deserialization { .. } - | Self::NoNamespaceHeader + | Self::NoFlightSqlNamespace | Self::InvalidNamespaceHeader { .. } | Self::InvalidNamespaceName { .. } => tonic::Code::InvalidArgument, Self::Planning { source, .. } | Self::Query { source, .. } => { @@ -148,7 +158,8 @@ impl Error { Self::UnsupportedMessageType { .. } => tonic::Code::Unimplemented, Error::FlightSQLPlanning { source } => match source { flightsql::Error::DeserializationTypeKnown { .. } - | flightsql::Error::InvalidUtf8 { .. } + | flightsql::Error::InvalidHandle { .. } + | flightsql::Error::Protocol { .. } | flightsql::Error::UnsupportedMessageType { .. 
} => tonic::Code::InvalidArgument, flightsql::Error::Flight { source: e } => return tonic::Status::from(e), fs_err @ flightsql::Error::Arrow { .. } => { @@ -266,9 +277,55 @@ type TonicStream<T> = Pin<Box<dyn Stream<Item = Result<T, tonic::Status>> + Send /// ┃ ┃ /// ``` /// -/// ## FlightSQL Prepared Statement (NOT YET IMPLEMENTED) +/// ## FlightSQL Prepared Statement (no bind parameters like $1, etc) /// -/// TODO sequence diagram +/// To run a prepared query via FlightSQL, the client undertakes a +/// few more steps: +/// +/// 1. Encode the query in an `ActionCreatePreparedStatementRequest` +/// request structure +/// +/// 2. Call the `DoAction` method with the request +/// +/// 3. Receive an `ActionCreatePreparedStatementResult`, which contains +/// a prepared statement "handle". +/// +/// 4. Encode the handle in a `CommandPreparedStatementQuery` +/// FlightSQL structure in a [`FlightDescriptor`] and call the +/// `GetFlightInfo` method with the [`FlightDescriptor`] +/// +/// 5. Steps 5, 6 and 7 proceed the same as for a FlightSQL ad-hoc query +/// +/// ```text +/// .───────.
+/// ╔═══════════╗ ( ) +/// ║ ║ │`───────'│ +/// ║ FlightSQL ║ │ IOx │ +/// ║ Client ║ │.───────.│ +/// ║ ║ ( ) +/// ╚═══════════╝ `───────' +/// ┃ Creates ┃ +/// 1 ┃ ActionCreatePreparedStatementRequest ┃ +/// ┃ ┃ +/// ┃ ┃ +/// ┃ DoAction(ActionCreatePreparedStatementRequest) ┃ +/// 2 ┃━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━▶┃ +/// ┃ ┃ +/// ┃ Result(ActionCreatePreparedStatementResult) ┃ +/// 3 ┃◀ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ┃ +/// ┃ ┃ +/// ┃ GetFlightInfo(CommandPreparedStatementQuery) ┃ +/// 4 ┃━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━▶┃ +/// ┃ FlightInfo(..Ticket{ ┃ +/// ┃ CommandPreparedStatementQuery}) ┃ +/// 5 ┃◀ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ┃ +/// ┃ ┃ +/// ┃ DoGet(Ticket) ┃ +/// 6 ┃━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━▶┃ +/// ┃ ┃ +/// ┃ Stream of FlightData ┃ +/// 7 ┃◀ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ┃ +/// ``` /// /// [Arrow Flight]: https://arrow.apache.org/docs/format/Flight.html /// [Arrow FlightSQL]: https://arrow.apache.org/docs/format/FlightSql.html @@ -303,7 +360,9 @@ where .server .db(&namespace, span_ctx.child_span("get namespace")) .await - .ok_or_else(|| tonic::Status::not_found(format!("Unknown namespace: {namespace}")))?; + .context(NamespaceNotFoundSnafu { + namespace_name: &namespace, + })?; let ctx = db.new_query_context(span_ctx); let (query_completed_token, physical_plan) = match query { @@ -444,17 +503,7 @@ where let span_ctx: Option<SpanContext> = request.extensions().get().cloned(); let trace = external_span_ctx.format_jaeger(); - // look for namespace information in headers - let namespace_name = request - .metadata() - .get(IOX_FLIGHT_SQL_NAMESPACE_HEADER) - .map(|v| { - v.to_str() - .context(InvalidNamespaceHeaderSnafu) - .map(|s| s.to_string()) - }) - .ok_or(Error::NoNamespaceHeader)??; - + let namespace_name = get_flightsql_namespace(request.metadata())?; let flight_descriptor = request.into_inner(); // extract the FlightSQL message @@ -467,8
+516,8 @@ where .server .db(&namespace_name, span_ctx.child_span("get namespace")) .await - .ok_or_else(|| { - tonic::Status::not_found(format!("Unknown namespace: {namespace_name}")) + .context(NamespaceNotFoundSnafu { + namespace_name: &namespace_name, })?; let ctx = db.new_query_context(span_ctx); @@ -521,16 +570,55 @@ where &self, _request: Request<Streaming<FlightData>>, ) -> Result<Response<Self::DoPutStream>, tonic::Status> { + info!("Handling flightsql do_put body"); + Err(tonic::Status::unimplemented("Not yet implemented: do_put")) } async fn do_action( &self, - _request: Request<Action>, + request: Request<Action>, ) -> Result<Response<Self::DoActionStream>, tonic::Status> { - Err(tonic::Status::unimplemented( - "Not yet implemented: do_action", - )) + let external_span_ctx: Option<RequestLogContext> = request.extensions().get().cloned(); + let span_ctx: Option<SpanContext> = request.extensions().get().cloned(); + let trace = external_span_ctx.format_jaeger(); + + let namespace_name = get_flightsql_namespace(request.metadata())?; + let Action { + r#type: action_type, + body, + } = request.into_inner(); + + // extract the FlightSQL message + let msg: Any = Message::decode(body).context(DeserializationSnafu)?; + + let type_url = msg.type_url.to_string(); + info!(%namespace_name, %action_type, %type_url, %trace, "DoAction request"); + + let db = self + .server + .db(&namespace_name, span_ctx.child_span("get namespace")) + .await + .context(NamespaceNotFoundSnafu { + namespace_name: &namespace_name, + })?; + + let ctx = db.new_query_context(span_ctx); + let body = Planner::new(&ctx) + .flight_sql_do_action(&namespace_name, db, msg) + .await + .context(PlanningSnafu); + + if let Err(e) = &body { + info!(%namespace_name, %type_url, %trace, %e, "Error running DoAction"); + } else { + debug!(%namespace_name, %type_url, %trace, "Completed DoAction request"); + }; + + let result = arrow_flight::Result { body: body? 
}; + let stream = futures::stream::iter([Ok(result)]); + + Ok(Response::new(stream.boxed())) } async fn list_actions( @@ -567,6 +655,23 @@ fn msg_from_descriptor(flight_descriptor: FlightDescriptor) -> Result<Any> { } } +/// Figure out the namespace for this request, in this order: +/// +/// 1. The [`IOX_FLIGHT_SQL_NAMESPACE_HEADER`], for example "iox-namespace-name=the_name"; +/// 2. The environment variable IOX_FLIGHT_SQL_NAMESPACE_ENV_NAME +fn get_flightsql_namespace(metadata: &MetadataMap) -> Result<String> { + if let Some(v) = metadata.get(IOX_FLIGHT_SQL_NAMESPACE_HEADER) { + let v = v.to_str().context(InvalidNamespaceHeaderSnafu)?; + return Ok(v.to_string()); + } + + if let Ok(v) = std::env::var(IOX_FLIGHT_SQL_NAMESPACE_ENV_NAME) { + return Ok(v); + } + + NoFlightSqlNamespaceSnafu.fail() +} + /// Wrapper over a FlightDataEncodeStream that adds IOx specfic /// metadata and records completion struct GetStream {