feat: make flight responses streaming (#2876)

Co-authored-by: kodiakhq[bot] <49736102+kodiakhq[bot]@users.noreply.github.com>

parent 70555ab33d
commit 157b556d4c
@@ -1680,6 +1680,7 @@ dependencies = [
  "parquet",
  "parquet_catalog",
  "parquet_file",
+ "pin-project",
  "pprof",
  "predicate",
  "predicates",
@@ -140,6 +140,7 @@ once_cell = { version = "1.4.0", features = ["parking_lot"] }
 parking_lot = "0.11.2"
 itertools = "0.10.1"
 parquet = "5.5"
+pin-project = "1.0"
 # used by arrow/datafusion anyway
 prettytable-rs = "0.8"
 pprof = { version = "^0.5", default-features = false, features = ["flamegraph", "protobuf"], optional = true }
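(pin-project is the only new dependency; it provides the `#[pin_project(PinnedDrop)]` machinery used by the `GetStream` type introduced further down.)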
@@ -774,8 +774,7 @@ mod tests {
         let prepare_sql_span = child(sql_span, "prepare_sql").unwrap();
         child(prepare_sql_span, "prepare_plan").unwrap();

-        let collect_span = child(ctx_span, "collect").unwrap();
-        let execute_span = child(collect_span, "execute_stream_partitioned").unwrap();
+        let execute_span = child(ctx_span, "execute_stream_partitioned").unwrap();
         let coalesce_span = child(execute_span, "CoalescePartitionsEx").unwrap();

         // validate spans from DataFusion ExecutionPlan are present
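(Note on the test change: with `collect` gone from the read path, `execute_stream_partitioned` no longer nests under a `collect` span; it now hangs directly off the context span, so the expected span tree loses one level.)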
@@ -1,5 +1,6 @@
 //! Implements the native gRPC IOx query API using Arrow Flight
 use std::fmt::Debug;
+use std::task::Poll;
 use std::{pin::Pin, sync::Arc};

 use arrow::{
@@ -13,14 +14,17 @@ use arrow_flight::{
     Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
     HandshakeRequest, HandshakeResponse, PutResult, SchemaAsIpc, SchemaResult, Ticket,
 };
-use futures::Stream;
+use datafusion::physical_plan::ExecutionPlan;
+use futures::{SinkExt, Stream, StreamExt};
+use pin_project::{pin_project, pinned_drop};
 use serde::Deserialize;
 use snafu::{ResultExt, Snafu};
+use tokio::task::JoinHandle;
 use tonic::{Request, Response, Streaming};

 use data_types::{DatabaseName, DatabaseNameError};
 use observability_deps::tracing::{info, warn};
-use query::exec::ExecutionContextProvider;
+use query::exec::{ExecutionContextProvider, IOxExecutionContext};
 use server::{connection::ConnectionManager, Server};

 use crate::influxdb_ioxd::rpc::error::default_server_error_handler;
@@ -153,8 +157,6 @@ where
         Err(tonic::Status::unimplemented("Not yet implemented"))
     }

-    // TODO: Stream results back directly by using `execute` instead of `collect`
-    // https://docs.rs/datafusion/3.0.0/datafusion/physical_plan/trait.ExecutionPlan.html#tymethod.execute
     async fn do_get(
         &self,
         request: Request<Ticket>,
@@ -182,32 +184,7 @@ where
             .await
             .context(Planning)?;

-        // execute the query
-        let results = ctx
-            .collect(Arc::clone(&physical_plan))
-            .await
-            .map_err(|e| Box::new(e) as _)
-            .context(Query {
-                database_name: &read_info.database_name,
-            })?;
-
-        let options = arrow::ipc::writer::IpcWriteOptions::default();
-        let schema = Arc::new(optimize_schema(&physical_plan.schema()));
-        let schema_flight_data = SchemaAsIpc::new(&schema, &options).into();
-
-        let mut flights = vec![schema_flight_data];
-
-        for batch in results {
-            let batch = optimize_record_batch(&batch, Arc::clone(&schema))?;
-
-            let (flight_dictionaries, flight_batch) =
-                arrow_flight::utils::flight_data_from_arrow_batch(&batch, &options);
-
-            flights.extend(flight_dictionaries);
-            flights.push(flight_batch);
-        }
-
-        let output = futures::stream::iter(flights.into_iter().map(Ok));
+        let output = GetStream::new(ctx, physical_plan, read_info.database_name).await?;

         Ok(Response::new(Box::pin(output) as Self::DoGetStream))
     }
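The old path materialized every RecordBatch with `collect` before building the response; the new path hands tonic a stream fed by a background task through a bounded channel, so batches flow to the client as they are produced. A minimal, self-contained sketch of that channel-backed streaming pattern follows; all names in it are illustrative, not IOx or arrow-flight APIs:

use futures::{SinkExt, StreamExt};

#[tokio::main]
async fn main() {
    // Bounded channel: capacity 1 gives backpressure, so the producer
    // only runs ahead of the consumer by one item.
    let (mut tx, mut rx) = futures::channel::mpsc::channel::<Result<u32, String>>(1);

    // Producer task, standing in for the DataFusion batch stream.
    let producer = tokio::spawn(async move {
        for i in 0..5u32 {
            // `send` waits for channel capacity; an Err means the
            // receiver (the client) is gone, so stop producing.
            if tx.send(Ok(i)).await.is_err() {
                return;
            }
        }
    });

    // The receiver itself implements `Stream`; a gRPC response can poll
    // it item by item instead of holding all results in memory.
    while let Some(item) = rx.next().await {
        println!("got {:?}", item);
    }
    producer.await.unwrap();
}

This is the same shape as `GetStream` below: a sender owned by a spawned task, a receiver wrapped into the response stream.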
@@ -268,6 +245,132 @@ where
     }
 }

+#[pin_project(PinnedDrop)]
+struct GetStream {
+    #[pin]
+    rx: futures::channel::mpsc::Receiver<Result<FlightData, tonic::Status>>,
+    join_handle: JoinHandle<()>,
+    done: bool,
+}
+
+impl GetStream {
+    async fn new(
+        ctx: IOxExecutionContext,
+        physical_plan: Arc<dyn ExecutionPlan>,
+        database_name: String,
+    ) -> Result<Self, tonic::Status> {
+        // setup channel
+        let (mut tx, rx) = futures::channel::mpsc::channel::<Result<FlightData, tonic::Status>>(1);
+
+        // get schema
+        let schema = Arc::new(optimize_schema(&physical_plan.schema()));
+
+        // setup stream
+        let options = arrow::ipc::writer::IpcWriteOptions::default();
+        let schema_flight_data = SchemaAsIpc::new(&schema, &options).into();
+        let mut stream_record_batches = ctx
+            .execute_stream(Arc::clone(&physical_plan))
+            .await
+            .map_err(|e| Box::new(e) as _)
+            .context(Query {
+                database_name: &database_name,
+            })?;
+
+        let join_handle = tokio::spawn(async move {
+            if tx.send(Ok(schema_flight_data)).await.is_err() {
+                // receiver gone
+                return;
+            }
+
+            while let Some(batch_or_err) = stream_record_batches.next().await {
+                match batch_or_err {
+                    Ok(batch) => {
+                        match optimize_record_batch(&batch, Arc::clone(&schema)) {
+                            Ok(batch) => {
+                                let (flight_dictionaries, flight_batch) =
+                                    arrow_flight::utils::flight_data_from_arrow_batch(
+                                        &batch, &options,
+                                    );
+
+                                for dict in flight_dictionaries {
+                                    if tx.send(Ok(dict)).await.is_err() {
+                                        // receiver is gone
+                                        return;
+                                    }
+                                }
+
+                                if tx.send(Ok(flight_batch)).await.is_err() {
+                                    // receiver is gone
+                                    return;
+                                }
+                            }
+                            Err(e) => {
+                                // failure sending here is OK because we're cutting the stream anyways
+                                tx.send(Err(e.into())).await.ok();
+
+                                // end stream
+                                return;
+                            }
+                        }
+                    }
+                    Err(e) => {
+                        // failure sending here is OK because we're cutting the stream anyways
+                        tx.send(Err(Error::Query {
+                            database_name: database_name.clone(),
+                            source: Box::new(e),
+                        }
+                        .into()))
+                        .await
+                        .ok();
+
+                        // end stream
+                        return;
+                    }
+                }
+            }
+        });
+
+        Ok(Self {
+            rx,
+            join_handle,
+            done: false,
+        })
+    }
+}
+
+#[pinned_drop]
+impl PinnedDrop for GetStream {
+    fn drop(self: Pin<&mut Self>) {
+        self.join_handle.abort();
+    }
+}
+
+impl Stream for GetStream {
+    type Item = Result<FlightData, tonic::Status>;
+
+    fn poll_next(
+        self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        let this = self.project();
+        if *this.done {
+            Poll::Ready(None)
+        } else {
+            match this.rx.poll_next(cx) {
+                Poll::Ready(None) => {
+                    *this.done = true;
+                    Poll::Ready(None)
+                }
+                e @ Poll::Ready(Some(Err(_))) => {
+                    *this.done = true;
+                    e
+                }
+                other => other,
+            }
+        }
+    }
+}
+
 /// Some batches are small slices of the underlying arrays.
 /// At this stage we only know the number of rows in the record batch
 /// and the sizes in bytes of the backing buffers of the column arrays.