feat: make flight responses streaming (#2876)

Co-authored-by: kodiakhq[bot] <49736102+kodiakhq[bot]@users.noreply.github.com>
Marco Neumann 2021-10-18 17:01:50 +02:00 committed by GitHub
parent 70555ab33d
commit 157b556d4c
4 changed files with 136 additions and 32 deletions
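Summary of the change: `do_get` used to run the query to completion via `collect`, buffer every `RecordBatch` in a `Vec`, and only then hand tonic a stream over the finished vector. It now returns a `GetStream` that forwards the schema message, dictionary batches, and record batches to the client while the query is still executing. A minimal, self-contained sketch of the underlying pattern, with toy types standing in for the IOx ones:

```rust
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};

// Toy stand-in for `Result<FlightData, tonic::Status>`.
type Item = Result<String, &'static str>;

// Rather than buffering all results and wrapping the finished Vec with
// `futures::stream::iter`, spawn a producer task that sends each result into
// a bounded channel as soon as it is computed; the receiver half already
// implements `Stream` and can be handed to the caller immediately.
fn streamed_response() -> impl Stream<Item = Item> {
    let (mut tx, rx) = mpsc::channel::<Item>(1);
    tokio::spawn(async move {
        for i in 0..3 {
            let msg = format!("batch {i}"); // stand-in for one record batch
            if tx.send(Ok(msg)).await.is_err() {
                return; // receiver dropped: the client went away, stop working
            }
        }
    });
    rx
}

#[tokio::main]
async fn main() {
    let mut stream = streamed_response();
    while let Some(item) = stream.next().await {
        println!("{item:?}");
    }
}
```

The receiver half of a `futures` mpsc channel implements `Stream` out of the box, which is why the `GetStream` type added below only needs bookkeeping for termination and cancellation on top of it.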

Cargo.lock (generated)

@@ -1680,6 +1680,7 @@ dependencies = [
  "parquet",
  "parquet_catalog",
  "parquet_file",
+ "pin-project",
  "pprof",
  "predicate",
  "predicates",


@@ -140,6 +140,7 @@ once_cell = { version = "1.4.0", features = ["parking_lot"] }
 parking_lot = "0.11.2"
 itertools = "0.10.1"
 parquet = "5.5"
+pin-project = "1.0"
 # used by arrow/datafusion anyway
 prettytable-rs = "0.8"
 pprof = { version = "^0.5", default-features = false, features = ["flamegraph", "protobuf"], optional = true }
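The new `pin-project` dependency is used by the `GetStream` type added in the last file below: its channel receiver must be polled through `Pin`, while its `JoinHandle` needs to stay accessible from a destructor that never moves out of the pinned struct. A minimal sketch of the `#[pin_project(PinnedDrop)]` pattern it relies on, with hypothetical type and field names:

```rust
use std::pin::Pin;
use std::task::{Context, Poll};

use futures::Stream;
use pin_project::{pin_project, pinned_drop};

#[pin_project(PinnedDrop)]
struct Labeled<S> {
    #[pin]
    inner: S, // structurally pinned: only reachable as Pin<&mut S> via projection
    label: &'static str, // ordinary field: projected as a plain &mut
}

impl<S: Stream> Stream for Labeled<S> {
    type Item = S::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project(); // safe pin projection generated by the macro
        this.inner.poll_next(cx)
    }
}

#[pinned_drop]
impl<S> PinnedDrop for Labeled<S> {
    fn drop(self: Pin<&mut Self>) {
        // runs on drop while the value is still pinned; GetStream uses the same
        // hook to abort its producer task
        println!("dropping stream {}", self.label);
    }
}
```

The generated projection hands back `Pin<&mut S>` for the pinned field and `&mut &'static str` for the plain one, so the compiler, not the author, enforces that nothing pinned is ever moved.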


@@ -774,8 +774,7 @@ mod tests {
         let prepare_sql_span = child(sql_span, "prepare_sql").unwrap();
         child(prepare_sql_span, "prepare_plan").unwrap();

-        let collect_span = child(ctx_span, "collect").unwrap();
-        let execute_span = child(collect_span, "execute_stream_partitioned").unwrap();
+        let execute_span = child(ctx_span, "execute_stream_partitioned").unwrap();
         let coalesce_span = child(execute_span, "CoalescePartitionsEx").unwrap();

         // validate spans from DataFusion ExecutionPlan are present
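Note: the expected span tree changes because, with the streaming path introduced in this commit, partitioned execution is no longer driven from inside `collect` (which previously contributed the intermediate "collect" span); `execute_stream_partitioned` is now recorded as a direct child of the context span.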


@@ -1,5 +1,6 @@
 //! Implements the native gRPC IOx query API using Arrow Flight

 use std::fmt::Debug;
+use std::task::Poll;
 use std::{pin::Pin, sync::Arc};
 use arrow::{
@@ -13,14 +14,17 @@ use arrow_flight::{
     Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
     HandshakeRequest, HandshakeResponse, PutResult, SchemaAsIpc, SchemaResult, Ticket,
 };
-use futures::Stream;
+use datafusion::physical_plan::ExecutionPlan;
+use futures::{SinkExt, Stream, StreamExt};
+use pin_project::{pin_project, pinned_drop};
 use serde::Deserialize;
 use snafu::{ResultExt, Snafu};
+use tokio::task::JoinHandle;
 use tonic::{Request, Response, Streaming};

 use data_types::{DatabaseName, DatabaseNameError};
 use observability_deps::tracing::{info, warn};
-use query::exec::ExecutionContextProvider;
+use query::exec::{ExecutionContextProvider, IOxExecutionContext};
 use server::{connection::ConnectionManager, Server};

 use crate::influxdb_ioxd::rpc::error::default_server_error_handler;
@@ -153,8 +157,6 @@ where
         Err(tonic::Status::unimplemented("Not yet implemented"))
     }

-    // TODO: Stream results back directly by using `execute` instead of `collect`
-    // https://docs.rs/datafusion/3.0.0/datafusion/physical_plan/trait.ExecutionPlan.html#tymethod.execute
     async fn do_get(
         &self,
         request: Request<Ticket>,
@@ -182,32 +184,7 @@
             .await
             .context(Planning)?;

         // execute the query
-        let results = ctx
-            .collect(Arc::clone(&physical_plan))
-            .await
-            .map_err(|e| Box::new(e) as _)
-            .context(Query {
-                database_name: &read_info.database_name,
-            })?;
-
-        let options = arrow::ipc::writer::IpcWriteOptions::default();
-        let schema = Arc::new(optimize_schema(&physical_plan.schema()));
-        let schema_flight_data = SchemaAsIpc::new(&schema, &options).into();
-
-        let mut flights = vec![schema_flight_data];
-
-        for batch in results {
-            let batch = optimize_record_batch(&batch, Arc::clone(&schema))?;
-
-            let (flight_dictionaries, flight_batch) =
-                arrow_flight::utils::flight_data_from_arrow_batch(&batch, &options);
-
-            flights.extend(flight_dictionaries);
-            flights.push(flight_batch);
-        }
-
-        let output = futures::stream::iter(flights.into_iter().map(Ok));
+        let output = GetStream::new(ctx, physical_plan, read_info.database_name).await?;

         Ok(Response::new(Box::pin(output) as Self::DoGetStream))
     }
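With the server now streaming, a Flight client sees the schema message as soon as execution starts and receives dictionaries and batches as they are produced, rather than only after the whole result set is collected. A consumer sketch (the endpoint address is made up, and the JSON layout of the ticket is an assumption; only `database_name` appears in this diff):

```rust
use arrow_flight::flight_service_client::FlightServiceClient;
use arrow_flight::Ticket;
use futures::StreamExt;

async fn run() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical address; the gRPC port is deployment-specific.
    let mut client = FlightServiceClient::connect("http://127.0.0.1:8082").await?;

    // The ticket carries the JSON-serialized `ReadInfo`; the field names here
    // are assumed, not taken from this diff.
    let ticket = Ticket {
        ticket: br#"{"database_name":"mydb","sql_query":"select 1"}"#.to_vec(),
    };

    let mut response = client.do_get(ticket).await?.into_inner();

    // The first message is the schema; dictionary batches and record batches
    // then arrive incrementally.
    while let Some(msg) = response.next().await {
        let flight_data = msg?;
        println!("FlightData with {} body bytes", flight_data.data_body.len());
    }
    Ok(())
}
```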
@@ -268,6 +245,132 @@ where
     }
 }
+
+#[pin_project(PinnedDrop)]
+struct GetStream {
+    #[pin]
+    rx: futures::channel::mpsc::Receiver<Result<FlightData, tonic::Status>>,
+    join_handle: JoinHandle<()>,
+    done: bool,
+}
+
+impl GetStream {
+    async fn new(
+        ctx: IOxExecutionContext,
+        physical_plan: Arc<dyn ExecutionPlan>,
+        database_name: String,
+    ) -> Result<Self, tonic::Status> {
+        // setup channel
+        let (mut tx, rx) =
+            futures::channel::mpsc::channel::<Result<FlightData, tonic::Status>>(1);
+
+        // get schema
+        let schema = Arc::new(optimize_schema(&physical_plan.schema()));
+
+        // setup stream
+        let options = arrow::ipc::writer::IpcWriteOptions::default();
+        let schema_flight_data = SchemaAsIpc::new(&schema, &options).into();
+        let mut stream_record_batches = ctx
+            .execute_stream(Arc::clone(&physical_plan))
+            .await
+            .map_err(|e| Box::new(e) as _)
+            .context(Query {
+                database_name: &database_name,
+            })?;
+
+        let join_handle = tokio::spawn(async move {
+            if tx.send(Ok(schema_flight_data)).await.is_err() {
+                // receiver gone
+                return;
+            }
+
+            while let Some(batch_or_err) = stream_record_batches.next().await {
+                match batch_or_err {
+                    Ok(batch) => {
+                        match optimize_record_batch(&batch, Arc::clone(&schema)) {
+                            Ok(batch) => {
+                                let (flight_dictionaries, flight_batch) =
+                                    arrow_flight::utils::flight_data_from_arrow_batch(
+                                        &batch, &options,
+                                    );
+
+                                for dict in flight_dictionaries {
+                                    if tx.send(Ok(dict)).await.is_err() {
+                                        // receiver is gone
+                                        return;
+                                    }
+                                }
+
+                                if tx.send(Ok(flight_batch)).await.is_err() {
+                                    // receiver is gone
+                                    return;
+                                }
+                            }
+                            Err(e) => {
+                                // failure sending here is OK because we're cutting the stream anyways
+                                tx.send(Err(e.into())).await.ok();
+
+                                // end stream
+                                return;
+                            }
+                        }
+                    }
+                    Err(e) => {
+                        // failure sending here is OK because we're cutting the stream anyways
+                        tx.send(Err(Error::Query {
+                            database_name: database_name.clone(),
+                            source: Box::new(e),
+                        }
+                        .into()))
+                        .await
+                        .ok();
+
+                        // end stream
+                        return;
+                    }
+                }
+            }
+        });
+
+        Ok(Self {
+            rx,
+            join_handle,
+            done: false,
+        })
+    }
+}
+
+#[pinned_drop]
+impl PinnedDrop for GetStream {
+    fn drop(self: Pin<&mut Self>) {
+        self.join_handle.abort();
+    }
+}
+
+impl Stream for GetStream {
+    type Item = Result<FlightData, tonic::Status>;
+
+    fn poll_next(
+        self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        let this = self.project();
+        if *this.done {
+            Poll::Ready(None)
+        } else {
+            match this.rx.poll_next(cx) {
+                Poll::Ready(None) => {
+                    *this.done = true;
+                    Poll::Ready(None)
+                }
+                e @ Poll::Ready(Some(Err(_))) => {
+                    *this.done = true;
+                    e
+                }
+                other => other,
+            }
+        }
+    }
+}
 /// Some batches are small slices of the underlying arrays.
 /// At this stage we only know the number of rows in the record batch
 /// and the sizes in bytes of the backing buffers of the column arrays.
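Two behaviors of `GetStream` are worth calling out. First, the channel is created with capacity 1, so a slow client backpressures query execution instead of letting batches accumulate server-side. Second, `PinnedDrop` aborts the producer task, so a client that disconnects mid-stream cancels the in-flight query. A toy demonstration of the backpressure half (assumes a tokio runtime with timers enabled; payloads are toy values):

```rust
use std::time::Duration;

use futures::{channel::mpsc, SinkExt, StreamExt};

#[tokio::main]
async fn main() {
    // Small-capacity channel, as in GetStream::new: the producer can only run
    // a bounded amount ahead of the consumer.
    let (mut tx, mut rx) = mpsc::channel::<u32>(1);

    let producer = tokio::spawn(async move {
        for i in 0..4 {
            // Once the channel is full, this `send` parks the producer until
            // the consumer makes progress.
            tx.send(i).await.expect("receiver alive");
            println!("sent {i}");
        }
    });

    for _ in 0..4 {
        tokio::time::sleep(Duration::from_millis(100)).await; // deliberately slow consumer
        let v = rx.next().await.unwrap();
        println!("received {v}");
    }
    producer.await.unwrap();
}
```

In the real handler the same mechanism means the DataFusion stream is only polled about one batch ahead of what the response stream has consumed.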