From 298055e9fb0f27c332844ca3142c8bac50893a89 Mon Sep 17 00:00:00 2001 From: Trevor Hilton Date: Mon, 26 Feb 2024 15:07:48 -0500 Subject: [PATCH] feat: support FlightSQL in 3.0 (#24678) * feat: support FlightSQL by serving gRPC requests on same port as HTTP This commit adds support for FlightSQL queries via gRPC to the influxdb3 service. It does so by ensuring the QueryExecutor implements the QueryNamespaceProvider trait, and the underlying QueryDatabase implements QueryNamespace. Satisfying those requirements allows the construction of a FlightServiceServer from the service_grpc_flight crate. The FlightServiceServer is a gRPC server that can be served via tonic at the API surface; however, enabling this required some tower::Service wrangling. The influxdb3_server/src/server.rs module was introduced to house this code. The objective is to serve both gRPC (via the newly introduced tonic server) and standard REST HTTP requests (via the existing HTTP server) on the same port. This is accomplished by the HybridService which can handle either gRPC or non-gRPC HTTP requests. The HybridService is wrapped in a HybridMakeService which allows us to serve it via hyper::Server on a single bind address. End-to-end tests were added in influxdb3/tests/flight.rs. These cover some basic FlightSQL cases. A common.rs module was added that introduces some fixtures to aid in end-to-end tests in influxdb3. --- Cargo.lock | 16 +- influxdb3/Cargo.toml | 12 ++ influxdb3/tests/common.rs | 99 ++++++++++++ influxdb3/tests/flight.rs | 151 ++++++++++++++++++ influxdb3_server/Cargo.toml | 18 ++- influxdb3_server/src/grpc.rs | 14 ++ influxdb3_server/src/http.rs | 47 +----- influxdb3_server/src/lib.rs | 74 +++++++-- influxdb3_server/src/query_executor.rs | 42 +++-- influxdb3_server/src/service.rs | 213 +++++++++++++++++++++++++ 10 files changed, 605 insertions(+), 81 deletions(-) create mode 100644 influxdb3/tests/common.rs create mode 100644 influxdb3/tests/flight.rs create mode 100644 influxdb3_server/src/grpc.rs create mode 100644 influxdb3_server/src/service.rs diff --git a/Cargo.lock b/Cargo.lock index cec919e1f2..a384adc017 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -435,9 +435,9 @@ dependencies = [ [[package]] name = "assert_cmd" -version = "2.0.13" +version = "2.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00ad3f3a942eee60335ab4342358c161ee296829e0d16ff42fc1d6cb07815467" +checksum = "ed72493ac66d5804837f480ab3766c72bdfab91a65e565fc54fa9e42db0073a8" dependencies = [ "anstyle", "bstr", @@ -2594,15 +2594,23 @@ dependencies = [ name = "influxdb3" version = "0.1.0" dependencies = [ + "arrow", + "arrow-array", + "arrow-flight", + "arrow_util", + "assert_cmd", "backtrace", "clap", "clap_blocks", "console-subscriber", "dotenvy", + "futures", "hex", + "hyper", "influxdb3_client", "influxdb3_server", "influxdb3_write", + "influxdb_iox_client", "iox_query", "iox_time", "ioxd_common", @@ -2624,6 +2632,8 @@ dependencies = [ "tokio", "tokio-util", "tokio_metrics_bridge", + "tonic 0.10.2", + "tower", "trace", "trace_exporters", "trogging", @@ -2653,6 +2663,7 @@ version = "0.1.0" dependencies = [ "arrow", "arrow-csv", + "arrow-flight", "arrow-json", "arrow-schema", "async-trait", @@ -2679,6 +2690,7 @@ dependencies = [ "parking_lot 0.11.2", "parquet", "parquet_file", + "pin-project-lite", "schema", "serde", "serde_json", diff --git a/influxdb3/Cargo.toml b/influxdb3/Cargo.toml index ff952ab4b2..40cfb1293d 100644 --- a/influxdb3/Cargo.toml +++ b/influxdb3/Cargo.toml @@ -71,4 +71,16 @@ jemalloc_replacing_malloc = ["tikv-jemalloc-sys", "tikv-jemalloc-ctl"] clippy = [] [dev-dependencies] +arrow_util = { path = "../arrow_util" } +influxdb_iox_client = { path = "../influxdb_iox_client" } + +# Crates.io dependencies in alphabetical order: +arrow = { workspace = true } +arrow-array = "49.0.0" +arrow-flight = "49.0.0" +assert_cmd = "2.0.14" +futures = "0.3.28" +hyper = "0.14" reqwest = { version = "0.11.24", default-features = false, features = ["rustls-tls"] } +tonic.workspace = true +tower = "0.4.13" diff --git a/influxdb3/tests/common.rs b/influxdb3/tests/common.rs new file mode 100644 index 0000000000..3b663b978d --- /dev/null +++ b/influxdb3/tests/common.rs @@ -0,0 +1,99 @@ +use std::{ + net::{SocketAddr, SocketAddrV4, TcpListener}, + process::{Child, Command, Stdio}, + time::Duration, +}; + +use assert_cmd::cargo::CommandCargoExt; +use influxdb_iox_client::flightsql::FlightSqlClient; + +/// A running instance of the `influxdb3 serve` process +pub struct TestServer { + bind_addr: SocketAddr, + server_process: Child, + http_client: reqwest::Client, +} + +impl TestServer { + /// Spawn a new [`TestServer`] + /// + /// This will run the `influxdb3 serve` command, and bind its HTTP + /// address to a random port on localhost. + pub async fn spawn() -> Self { + let bind_addr = get_local_bind_addr(); + let mut command = Command::cargo_bin("influxdb3").expect("create the influxdb3 command"); + let command = command + .arg("serve") + .args(["--http-bind", &bind_addr.to_string()]) + // TODO - other configuration can be passed through + .stdout(Stdio::null()) + .stderr(Stdio::null()); + + let server_process = command.spawn().expect("spawn the influxdb3 server process"); + + let server = Self { + bind_addr, + server_process, + http_client: reqwest::Client::new(), + }; + + server.wait_until_ready().await; + server + } + + /// Get the URL of the running service for use with an HTTP client + pub fn client_addr(&self) -> String { + format!("http://{addr}", addr = self.bind_addr) + } + + /// Get a [`FlightSqlClient`] for making requests to the running service over gRPC + pub async fn flight_client(&self, database: &str) -> FlightSqlClient { + let channel = tonic::transport::Channel::from_shared(self.client_addr()) + .expect("create tonic channel") + .connect() + .await + .expect("connect to gRPC client"); + let mut client = FlightSqlClient::new(channel); + client.add_header("database", database).unwrap(); + client.add_header("iox-debug", "true").unwrap(); + client + } + + fn kill(&mut self) { + self.server_process.kill().expect("kill the server process"); + } + + async fn wait_until_ready(&self) { + while self + .http_client + .get(format!("{base}/health", base = self.client_addr())) + .send() + .await + .is_err() + { + tokio::time::sleep(Duration::from_millis(10)).await; + } + } +} + +impl Drop for TestServer { + fn drop(&mut self) { + self.kill(); + } +} + +/// Get an available bind address on localhost +/// +/// This binds a [`TcpListener`] to 127.0.0.1:0, which will randomly +/// select an available port, and produces the resulting local address. +/// The [`TcpListener`] is dropped at the end of the function, thus +/// freeing the port for use by the caller. +fn get_local_bind_addr() -> SocketAddr { + let ip = std::net::Ipv4Addr::new(127, 0, 0, 1); + let port = 0; + let addr = SocketAddrV4::new(ip, port); + TcpListener::bind(addr) + .expect("bind to a socket address") + .local_addr() + .expect("get local address") +} diff --git a/influxdb3/tests/flight.rs b/influxdb3/tests/flight.rs new file mode 100644 index 0000000000..9bad90f0fe --- /dev/null +++ b/influxdb3/tests/flight.rs @@ -0,0 +1,151 @@ +use arrow::record_batch::RecordBatch; +use arrow_flight::{decode::FlightRecordBatchStream, sql::SqlInfo}; +use arrow_util::assert_batches_sorted_eq; +use futures::TryStreamExt; + +use crate::common::TestServer; + +mod common; + +#[tokio::test] +async fn flight() { + let server = TestServer::spawn().await; + + // use the influxdb3_client to write in some data + write_lp_to_db( + &server, + "foo", + "cpu,host=s1,region=us-east usage=0.9 1\n\ + cpu,host=s1,region=us-east usage=0.89 2\n\ + cpu,host=s1,region=us-east usage=0.85 3", + ) + .await; + + let mut client = server.flight_client("foo").await; + + // Ad-hoc Query: + { + let response = client.query("SELECT * FROM cpu").await.unwrap(); + + let batches = collect_stream(response).await; + assert_batches_sorted_eq!( + [ + "+------+---------+--------------------------------+-------+", + "| host | region | time | usage |", + "+------+---------+--------------------------------+-------+", + "| s1 | us-east | 1970-01-01T00:00:00.000000001Z | 0.9 |", + "| s1 | us-east | 1970-01-01T00:00:00.000000002Z | 0.89 |", + "| s1 | us-east | 1970-01-01T00:00:00.000000003Z | 0.85 |", + "+------+---------+--------------------------------+-------+", + ], + &batches + ); + } + + // Ad-hoc Query error: + { + let error = client + .query("SELECT * FROM invalid_table") + .await + .unwrap_err(); + + assert!(error + .to_string() + .contains("table 'public.iox.invalid_table' not found")); + } + + // Prepared query: + { + let handle = client.prepare("SELECT * FROM cpu".into()).await.unwrap(); + let stream = client.execute(handle).await.unwrap(); + + let batches = collect_stream(stream).await; + assert_batches_sorted_eq!( + [ + "+------+---------+--------------------------------+-------+", + "| host | region | time | usage |", + "+------+---------+--------------------------------+-------+", + "| s1 | us-east | 1970-01-01T00:00:00.000000001Z | 0.9 |", + "| s1 | us-east | 1970-01-01T00:00:00.000000002Z | 0.89 |", + "| s1 | us-east | 1970-01-01T00:00:00.000000003Z | 0.85 |", + "+------+---------+--------------------------------+-------+", + ], + &batches + ); + } + + // Get SQL Infos: + { + let infos = vec![SqlInfo::FlightSqlServerName as u32]; + let stream = client.get_sql_info(infos).await.unwrap(); + let batches = collect_stream(stream).await; + assert_batches_sorted_eq!( + [ + "+-----------+-----------------------------+", + "| info_name | value |", + "+-----------+-----------------------------+", + "| 0 | {string_value=InfluxDB IOx} |", + "+-----------+-----------------------------+", + ], + &batches + ); + } + + // Get Tables + { + type OptStr = std::option::Option<&'static str>; + let stream = client + .get_tables(OptStr::None, OptStr::None, OptStr::None, vec![], false) + .await + .unwrap(); + let batches = collect_stream(stream).await; + + assert_batches_sorted_eq!( + [ + "+--------------+--------------------+-------------+------------+", + "| catalog_name | db_schema_name | table_name | table_type |", + "+--------------+--------------------+-------------+------------+", + "| public | information_schema | columns | VIEW |", + "| public | information_schema | df_settings | VIEW |", + "| public | information_schema | tables | VIEW |", + "| public | information_schema | views | VIEW |", + "| public | iox | cpu | BASE TABLE |", + "+--------------+--------------------+-------------+------------+", + ], + &batches + ); + } + + // Get Catalogs + { + let stream = client.get_catalogs().await.unwrap(); + let batches = collect_stream(stream).await; + assert_batches_sorted_eq!( + [ + "+--------------+", + "| catalog_name |", + "+--------------+", + "| public |", + "+--------------+", + ], + &batches + ); + } +} + +async fn write_lp_to_db(server: &TestServer, database: &str, lp: &'static str) { + let client = influxdb3_client::Client::new(server.client_addr()).unwrap(); + client + .api_v3_write_lp(database) + .body(lp) + .send() + .await + .unwrap(); +} + +async fn collect_stream(stream: FlightRecordBatchStream) -> Vec { + stream + .try_collect() + .await + .expect("gather record batch stream") +} diff --git a/influxdb3_server/Cargo.toml b/influxdb3_server/Cargo.toml index ea14d35046..5c7aae45f2 100644 --- a/influxdb3_server/Cargo.toml +++ b/influxdb3_server/Cargo.toml @@ -30,12 +30,19 @@ trace_http = { path = "../trace_http" } tracker = { path = "../tracker" } arrow = { workspace = true, features = ["prettyprint"] } +arrow-flight.workspace = true +arrow-json = "49.0.0" +arrow-schema = "49.0.0" +arrow-csv = "49.0.0" +async-trait = "0.1" chrono = "0.4" datafusion = { workspace = true } -async-trait = "0.1" +flate2 = "1.0.27" futures = "0.3.28" +hex = "0.4.3" hyper = "0.14" parking_lot = "0.11.1" +pin-project-lite = "0.2" thiserror = "1.0" tokio = { version = "1", features = ["rt-multi-thread", "macros", "time"] } tokio-util = { version = "0.7.9" } @@ -43,14 +50,9 @@ tonic = { workspace = true } serde = { version = "1.0.197", features = ["derive"] } serde_json = "1.0.114" serde_urlencoded = "0.7.0" -tower = "0.4.13" -flate2 = "1.0.27" -workspace-hack = { version = "0.1", path = "../workspace-hack" } -arrow-json = "49.0.0" -arrow-schema = "49.0.0" -arrow-csv = "49.0.0" sha2 = "0.10.8" -hex = "0.4.3" +tower = "0.4.13" +workspace-hack = { version = "0.1", path = "../workspace-hack" } [dev-dependencies] parquet.workspace = true diff --git a/influxdb3_server/src/grpc.rs b/influxdb3_server/src/grpc.rs new file mode 100644 index 0000000000..30dbd28a14 --- /dev/null +++ b/influxdb3_server/src/grpc.rs @@ -0,0 +1,14 @@ +use std::sync::Arc; + +use arrow_flight::flight_service_server::{ + FlightService as Flight, FlightServiceServer as FlightServer, +}; +use authz::Authorizer; +use iox_query::QueryNamespaceProvider; + +pub(crate) fn make_flight_server( + server: Arc, + authz: Option>, +) -> FlightServer { + service_grpc_flight::make_server(server, authz) +} diff --git a/influxdb3_server/src/http.rs b/influxdb3_server/src/http.rs index 2df7abfac7..1de02da081 100644 --- a/influxdb3_server/src/http.rs +++ b/influxdb3_server/src/http.rs @@ -12,7 +12,6 @@ use hyper::header::ACCEPT; use hyper::header::AUTHORIZATION; use hyper::header::CONTENT_ENCODING; use hyper::http::HeaderValue; -use hyper::server::conn::{AddrIncoming, AddrStream}; use hyper::{Body, Method, Request, Response, StatusCode}; use influxdb3_write::persister::TrackedMemoryArrowWriter; use influxdb3_write::WriteBuffer; @@ -27,11 +26,6 @@ use std::num::NonZeroI32; use std::str::Utf8Error; use std::sync::Arc; use thiserror::Error; -use tokio_util::sync::CancellationToken; -use tower::Layer; -use trace_http::metrics::MetricFamily; -use trace_http::metrics::RequestMetrics; -use trace_http::tower::TraceLayer; #[derive(Debug, Error)] pub enum Error { @@ -159,13 +153,11 @@ impl Error { pub type Result = std::result::Result; -const TRACE_SERVER_NAME: &str = "http_api"; - #[derive(Debug)] pub(crate) struct HttpApi { common_state: CommonServerState, write_buffer: Arc, - query_executor: Arc, + pub(crate) query_executor: Arc, max_request_bytes: usize, } @@ -449,42 +441,7 @@ pub(crate) struct WriteParams { pub(crate) db: String, } -pub(crate) async fn serve( - http_server: Arc>, - shutdown: CancellationToken, -) -> Result<()> { - let listener = AddrIncoming::bind(&http_server.common_state.http_addr)?; - println!("binding listener"); - info!(bind_addr=%listener.local_addr(), "bound HTTP listener"); - - let req_metrics = RequestMetrics::new( - Arc::clone(&http_server.common_state.metrics), - MetricFamily::HttpServer, - ); - let trace_layer = TraceLayer::new( - http_server.common_state.trace_header_parser.clone(), - Arc::new(req_metrics), - http_server.common_state.trace_collector().clone(), - TRACE_SERVER_NAME, - ); - - hyper::Server::builder(listener) - .serve(hyper::service::make_service_fn(|_conn: &AddrStream| { - let http_server = Arc::clone(&http_server); - let service = hyper::service::service_fn(move |request: Request<_>| { - route_request(Arc::clone(&http_server), request) - }); - - let service = trace_layer.layer(service); - futures::future::ready(Ok::<_, Infallible>(service)) - })) - .with_graceful_shutdown(shutdown.cancelled()) - .await?; - - Ok(()) -} - -async fn route_request( +pub(crate) async fn route_request( http_server: Arc>, mut req: Request, ) -> Result, Infallible> { diff --git a/influxdb3_server/src/lib.rs b/influxdb3_server/src/lib.rs index 46fc1ed564..302c977c7b 100644 --- a/influxdb3_server/src/lib.rs +++ b/influxdb3_server/src/lib.rs @@ -11,26 +11,43 @@ clippy::clone_on_ref_ptr, clippy::future_not_send )] +mod grpc; mod http; pub mod query_executor; +mod service; +use crate::grpc::make_flight_server; +use crate::http::route_request; use crate::http::HttpApi; use async_trait::async_trait; use datafusion::execution::SendableRecordBatchStream; +use hyper::service::service_fn; use influxdb3_write::{Persister, WriteBuffer}; -use observability_deps::tracing::info; +use iox_query::QueryNamespaceProvider; +use observability_deps::tracing::{error, info}; +use service::hybrid; +use std::convert::Infallible; use std::fmt::Debug; use std::net::SocketAddr; use std::sync::Arc; use thiserror::Error; use tokio_util::sync::CancellationToken; +use tower::Layer; use trace::ctx::SpanContext; use trace::TraceCollector; use trace_http::ctx::RequestLogContext; use trace_http::ctx::TraceHeaderParser; +use trace_http::metrics::MetricFamily; +use trace_http::metrics::RequestMetrics; +use trace_http::tower::TraceLayer; + +const TRACE_SERVER_NAME: &str = "influxdb3_http"; #[derive(Debug, Error)] pub enum Error { + #[error("hyper error: {0}")] + Hyper(#[from] hyper::Error), + #[error("http error: {0}")] Http(#[from] http::Error), @@ -100,11 +117,12 @@ impl CommonServerState { #[derive(Debug)] pub struct Server { + common_state: CommonServerState, http: Arc>, } #[async_trait] -pub trait QueryExecutor: Debug + Send + Sync + 'static { +pub trait QueryExecutor: QueryNamespaceProvider + Debug + Send + Sync + 'static { async fn query( &self, database: &str, @@ -114,7 +132,10 @@ pub trait QueryExecutor: Debug + Send + Sync + 'static { ) -> Result; } -impl Server { +impl Server +where + Q: QueryExecutor, +{ pub fn new( common_state: CommonServerState, _persister: Arc, @@ -124,26 +145,57 @@ impl Server { ) -> Self { let http = Arc::new(HttpApi::new( common_state.clone(), - Arc::::clone(&write_buffer), - Arc::::clone(&query_executor), + Arc::clone(&write_buffer), + Arc::clone(&query_executor), max_http_request_size, )); - Self { http } + Self { common_state, http } } } -pub async fn serve( - server: Server, - shutdown: CancellationToken, -) -> Result<()> { +pub async fn serve(server: Server, shutdown: CancellationToken) -> Result<()> +where + W: WriteBuffer, + Q: QueryExecutor, +{ // TODO: // 1. load the persisted catalog and segments from the persister // 2. load semgments into the buffer // 3. persist any segments from the buffer that are closed and haven't yet been persisted // 4. start serving - http::serve(Arc::clone(&server.http), shutdown).await?; + let req_metrics = RequestMetrics::new( + Arc::clone(&server.common_state.metrics), + MetricFamily::HttpServer, + ); + let trace_layer = TraceLayer::new( + server.common_state.trace_header_parser.clone(), + Arc::new(req_metrics), + server.common_state.trace_collector().clone(), + TRACE_SERVER_NAME, + ); + + let grpc_service = trace_layer.clone().layer(make_flight_server( + Arc::clone(&server.http.query_executor), + // TODO - need to configure authz here: + None, + )); + let rest_service = hyper::service::make_service_fn(|_| { + let http_server = Arc::clone(&server.http); + let service = service_fn(move |req: hyper::Request| { + route_request(Arc::clone(&http_server), req) + }); + let service = trace_layer.layer(service); + futures::future::ready(Ok::<_, Infallible>(service)) + }); + + let hybrid_make_service = hybrid(rest_service, grpc_service); + + hyper::Server::bind(&server.common_state.http_addr) + .serve(hybrid_make_service) + .with_graceful_shutdown(shutdown.cancelled()) + .await?; Ok(()) } diff --git a/influxdb3_server/src/query_executor.rs b/influxdb3_server/src/query_executor.rs index fe5eddc158..88871489f1 100644 --- a/influxdb3_server/src/query_executor.rs +++ b/influxdb3_server/src/query_executor.rs @@ -29,7 +29,7 @@ use iox_query::query_log::StateReceived; use iox_query::QueryNamespaceProvider; use iox_query::{QueryChunk, QueryChunkData, QueryNamespace}; use metric::Registry; -use observability_deps::tracing::info; +use observability_deps::tracing::{debug, info, trace}; use schema::sort::SortKey; use schema::Schema; use std::any::Any; @@ -187,19 +187,37 @@ impl QueryDatabase { query_log, } } + + async fn query_table(&self, table_name: &str) -> Option>> { + self.db_schema.get_table_schema(table_name).map(|schema| { + Arc::new(QueryTable { + db_schema: Arc::clone(&self.db_schema), + name: table_name.into(), + schema: schema.clone(), + write_buffer: Arc::clone(&self.write_buffer), + }) + }) + } } #[async_trait] impl QueryNamespace for QueryDatabase { async fn chunks( &self, - _table_name: &str, - _filters: &[Expr], - _projection: Option<&Vec>, - _ctx: IOxSessionContext, + table_name: &str, + filters: &[Expr], + projection: Option<&Vec>, + ctx: IOxSessionContext, ) -> Result>, DataFusionError> { - info!("called chunks on querydatabase"); - todo!() + let _span_recorder = SpanRecorder::new(ctx.child_span("QueryDatabase::chunks")); + debug!(%table_name, ?filters, "Finding chunks for table"); + + let Some(table) = self.query_table(table_name).await else { + trace!(%table_name, "No entry for table"); + return Ok(vec![]); + }; + + table.chunks(&ctx.inner().state(), projection, filters, None) } fn retention_time_ns(&self) -> Option { @@ -282,14 +300,7 @@ impl SchemaProvider for QueryDatabase { } async fn table(&self, name: &str) -> Option> { - self.db_schema.get_table_schema(name).map(|schema| { - Arc::new(QueryTable { - db_schema: Arc::clone(&self.db_schema), - name: name.into(), - schema: schema.clone(), - write_buffer: Arc::clone(&self.write_buffer), - }) as Arc - }) + self.query_table(name).await.map(|qt| qt as _) } fn table_exist(&self, name: &str) -> bool { @@ -313,6 +324,7 @@ impl QueryTable { filters: &[Expr], _limit: Option, ) -> Result>, DataFusionError> { + // TODO - this is only pulling from write buffer, and not parquet? self.write_buffer.get_table_chunks( &self.db_schema.name, self.name.as_ref(), diff --git a/influxdb3_server/src/service.rs b/influxdb3_server/src/service.rs new file mode 100644 index 0000000000..873354657e --- /dev/null +++ b/influxdb3_server/src/service.rs @@ -0,0 +1,213 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures::Future; +use hyper::HeaderMap; +use hyper::{body::HttpBody, Body, Request, Response}; +use pin_project_lite::pin_project; +use tower::Service; + +type BoxError = Box; + +pub(crate) fn hybrid( + make_rest: MakeRest, + grpc: Grpc, +) -> HybridMakeService { + HybridMakeService { make_rest, grpc } +} + +pub struct HybridMakeService { + make_rest: MakeRest, + grpc: Grpc, +} + +impl Service for HybridMakeService +where + MakeRest: Service, + Grpc: Clone, +{ + type Response = HybridService; + type Error = MakeRest::Error; + type Future = HybridMakeServiceFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.make_rest.poll_ready(cx) + } + + fn call(&mut self, conn_info: ConnInfo) -> Self::Future { + HybridMakeServiceFuture { + rest_future: self.make_rest.call(conn_info), + grpc: Some(self.grpc.clone()), + } + } +} + +pin_project! { + pub struct HybridMakeServiceFuture { + #[pin] + rest_future: RestFuture, + grpc: Option, + } +} + +impl Future for HybridMakeServiceFuture +where + RestFuture: Future>, +{ + type Output = Result, RestError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + match this.rest_future.poll(cx) { + Poll::Ready(Ok(rest)) => Poll::Ready(Ok(HybridService { + rest, + grpc: this.grpc.take().expect("future polled after execution"), + })), + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + } + } +} + +pub struct HybridService { + rest: Rest, + grpc: Grpc, +} + +impl Service> for HybridService +where + Rest: Service, Response = Response>, + Grpc: Service, Response = Response>, + Rest::Error: Into, + Grpc::Error: Into, +{ + type Response = Response>; + type Error = BoxError; + type Future = HybridFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + match self.rest.poll_ready(cx) { + Poll::Ready(Ok(())) => match self.grpc.poll_ready(cx) { + Poll::Ready(Ok(())) => Poll::Ready(Ok(())), + Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), + Poll::Pending => Poll::Pending, + }, + Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), + Poll::Pending => Poll::Pending, + } + } + + fn call(&mut self, req: Request) -> Self::Future { + match ( + req.version(), + req.headers().get(hyper::header::CONTENT_TYPE), + ) { + (hyper::Version::HTTP_2, Some(hv)) + if hv.as_bytes().starts_with(b"application/grpc") => + { + HybridFuture::Grpc { + grpc_future: self.grpc.call(req), + } + } + _ => HybridFuture::Rest { + rest_future: self.rest.call(req), + }, + } + } +} + +pin_project! { + #[project = HybridBodyProj] + pub enum HybridBody { + Rest { + #[pin] + rest_body: RestBody + }, + Grpc { + #[pin] + grpc_body: GrpcBody + }, + } +} + +impl HttpBody for HybridBody +where + RestBody: HttpBody + Send + Unpin, + GrpcBody: HttpBody + Send + Unpin, + RestBody::Error: Into, + GrpcBody::Error: Into, +{ + type Data = RestBody::Data; + type Error = BoxError; + + fn is_end_stream(&self) -> bool { + match self { + Self::Rest { rest_body } => rest_body.is_end_stream(), + Self::Grpc { grpc_body } => grpc_body.is_end_stream(), + } + } + + fn poll_data( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + match self.project() { + HybridBodyProj::Rest { rest_body } => rest_body.poll_data(cx).map_err(Into::into), + HybridBodyProj::Grpc { grpc_body } => grpc_body.poll_data(cx).map_err(Into::into), + } + } + + fn poll_trailers( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + match self.project() { + HybridBodyProj::Rest { rest_body } => rest_body.poll_trailers(cx).map_err(Into::into), + HybridBodyProj::Grpc { grpc_body } => grpc_body.poll_trailers(cx).map_err(Into::into), + } + } +} + +pin_project! { + #[project = HybridFutureProj] + pub enum HybridFuture { + Rest { + #[pin] + rest_future: RestFuture, + }, + Grpc { + #[pin] + grpc_future: GrpcFuture, + } + } +} + +impl Future + for HybridFuture +where + RestFuture: Future, RestError>>, + GrpcFuture: Future, GrpcError>>, + RestError: Into, + GrpcError: Into, +{ + type Output = Result>, BoxError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project() { + HybridFutureProj::Rest { rest_future } => match rest_future.poll(cx) { + Poll::Ready(Ok(res)) => { + Poll::Ready(Ok(res.map(|rest_body| HybridBody::Rest { rest_body }))) + } + Poll::Ready(Err(err)) => Poll::Ready(Err(err.into())), + Poll::Pending => Poll::Pending, + }, + HybridFutureProj::Grpc { grpc_future } => match grpc_future.poll(cx) { + Poll::Ready(Ok(res)) => { + Poll::Ready(Ok(res.map(|grpc_body| HybridBody::Grpc { grpc_body }))) + } + Poll::Ready(Err(err)) => Poll::Ready(Err(err.into())), + Poll::Pending => Poll::Pending, + }, + } + } +}