feat: support FlightSQL in 3.0 (#24678)

* feat: support FlightSQL by serving gRPC requests on same port as HTTP

This commit adds support for FlightSQL queries via gRPC to the influxdb3 service. It does so by ensuring the QueryExecutor implements the QueryNamespaceProvider trait, and the underlying QueryDatabase implements QueryNamespace. Satisfying those requirements allows the construction of a FlightServiceServer from the service_grpc_flight crate.

The FlightServiceServer is a gRPC server that can be served via tonic at the API surface; however, enabling this required some tower::Service wrangling. The influxdb3_server/src/server.rs module was introduced to house this code. The objective is to serve both gRPC (via the newly introduced tonic server) and standard REST HTTP requests (via the existing HTTP server) on the same port.

This is accomplished by the HybridService which can handle either gRPC or non-gRPC HTTP requests. The HybridService is wrapped in a HybridMakeService which allows us to serve it via hyper::Server on a single bind address.

End-to-end tests were added in influxdb3/tests/flight.rs. These cover some basic FlightSQL cases. A common.rs module was added that introduces some fixtures to aid in end-to-end tests in influxdb3.
pull/24677/head
Trevor Hilton 2024-02-26 15:07:48 -05:00 committed by GitHub
parent 75afbbd20e
commit 298055e9fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 605 additions and 81 deletions

16
Cargo.lock generated
View File

@ -435,9 +435,9 @@ dependencies = [
[[package]] [[package]]
name = "assert_cmd" name = "assert_cmd"
version = "2.0.13" version = "2.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00ad3f3a942eee60335ab4342358c161ee296829e0d16ff42fc1d6cb07815467" checksum = "ed72493ac66d5804837f480ab3766c72bdfab91a65e565fc54fa9e42db0073a8"
dependencies = [ dependencies = [
"anstyle", "anstyle",
"bstr", "bstr",
@ -2594,15 +2594,23 @@ dependencies = [
name = "influxdb3" name = "influxdb3"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"arrow",
"arrow-array",
"arrow-flight",
"arrow_util",
"assert_cmd",
"backtrace", "backtrace",
"clap", "clap",
"clap_blocks", "clap_blocks",
"console-subscriber", "console-subscriber",
"dotenvy", "dotenvy",
"futures",
"hex", "hex",
"hyper",
"influxdb3_client", "influxdb3_client",
"influxdb3_server", "influxdb3_server",
"influxdb3_write", "influxdb3_write",
"influxdb_iox_client",
"iox_query", "iox_query",
"iox_time", "iox_time",
"ioxd_common", "ioxd_common",
@ -2624,6 +2632,8 @@ dependencies = [
"tokio", "tokio",
"tokio-util", "tokio-util",
"tokio_metrics_bridge", "tokio_metrics_bridge",
"tonic 0.10.2",
"tower",
"trace", "trace",
"trace_exporters", "trace_exporters",
"trogging", "trogging",
@ -2653,6 +2663,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"arrow", "arrow",
"arrow-csv", "arrow-csv",
"arrow-flight",
"arrow-json", "arrow-json",
"arrow-schema", "arrow-schema",
"async-trait", "async-trait",
@ -2679,6 +2690,7 @@ dependencies = [
"parking_lot 0.11.2", "parking_lot 0.11.2",
"parquet", "parquet",
"parquet_file", "parquet_file",
"pin-project-lite",
"schema", "schema",
"serde", "serde",
"serde_json", "serde_json",

View File

@ -71,4 +71,16 @@ jemalloc_replacing_malloc = ["tikv-jemalloc-sys", "tikv-jemalloc-ctl"]
clippy = [] clippy = []
[dev-dependencies] [dev-dependencies]
arrow_util = { path = "../arrow_util" }
influxdb_iox_client = { path = "../influxdb_iox_client" }
# Crates.io dependencies in alphabetical order:
arrow = { workspace = true }
arrow-array = "49.0.0"
arrow-flight = "49.0.0"
assert_cmd = "2.0.14"
futures = "0.3.28"
hyper = "0.14"
reqwest = { version = "0.11.24", default-features = false, features = ["rustls-tls"] } reqwest = { version = "0.11.24", default-features = false, features = ["rustls-tls"] }
tonic.workspace = true
tower = "0.4.13"

99
influxdb3/tests/common.rs Normal file
View File

@ -0,0 +1,99 @@
use std::{
net::{SocketAddr, SocketAddrV4, TcpListener},
process::{Child, Command, Stdio},
time::Duration,
};
use assert_cmd::cargo::CommandCargoExt;
use influxdb_iox_client::flightsql::FlightSqlClient;
/// A running instance of the `influxdb3 serve` process
pub struct TestServer {
bind_addr: SocketAddr,
server_process: Child,
http_client: reqwest::Client,
}
impl TestServer {
/// Spawn a new [`TestServer`]
///
/// This will run the `influxdb3 serve` command, and bind its HTTP
/// address to a random port on localhost.
pub async fn spawn() -> Self {
let bind_addr = get_local_bind_addr();
let mut command = Command::cargo_bin("influxdb3").expect("create the influxdb3 command");
let command = command
.arg("serve")
.args(["--http-bind", &bind_addr.to_string()])
// TODO - other configuration can be passed through
.stdout(Stdio::null())
.stderr(Stdio::null());
let server_process = command.spawn().expect("spawn the influxdb3 server process");
let server = Self {
bind_addr,
server_process,
http_client: reqwest::Client::new(),
};
server.wait_until_ready().await;
server
}
/// Get the URL of the running service for use with an HTTP client
pub fn client_addr(&self) -> String {
format!("http://{addr}", addr = self.bind_addr)
}
/// Get a [`FlightSqlClient`] for making requests to the running service over gRPC
pub async fn flight_client(&self, database: &str) -> FlightSqlClient {
let channel = tonic::transport::Channel::from_shared(self.client_addr())
.expect("create tonic channel")
.connect()
.await
.expect("connect to gRPC client");
let mut client = FlightSqlClient::new(channel);
client.add_header("database", database).unwrap();
client.add_header("iox-debug", "true").unwrap();
client
}
fn kill(&mut self) {
self.server_process.kill().expect("kill the server process");
}
async fn wait_until_ready(&self) {
while self
.http_client
.get(format!("{base}/health", base = self.client_addr()))
.send()
.await
.is_err()
{
tokio::time::sleep(Duration::from_millis(10)).await;
}
}
}
impl Drop for TestServer {
fn drop(&mut self) {
self.kill();
}
}
/// Get an available bind address on localhost
///
/// This binds a [`TcpListener`] to 127.0.0.1:0, which will randomly
/// select an available port, and produces the resulting local address.
/// The [`TcpListener`] is dropped at the end of the function, thus
/// freeing the port for use by the caller.
fn get_local_bind_addr() -> SocketAddr {
let ip = std::net::Ipv4Addr::new(127, 0, 0, 1);
let port = 0;
let addr = SocketAddrV4::new(ip, port);
TcpListener::bind(addr)
.expect("bind to a socket address")
.local_addr()
.expect("get local address")
}

151
influxdb3/tests/flight.rs Normal file
View File

@ -0,0 +1,151 @@
use arrow::record_batch::RecordBatch;
use arrow_flight::{decode::FlightRecordBatchStream, sql::SqlInfo};
use arrow_util::assert_batches_sorted_eq;
use futures::TryStreamExt;
use crate::common::TestServer;
mod common;
#[tokio::test]
async fn flight() {
let server = TestServer::spawn().await;
// use the influxdb3_client to write in some data
write_lp_to_db(
&server,
"foo",
"cpu,host=s1,region=us-east usage=0.9 1\n\
cpu,host=s1,region=us-east usage=0.89 2\n\
cpu,host=s1,region=us-east usage=0.85 3",
)
.await;
let mut client = server.flight_client("foo").await;
// Ad-hoc Query:
{
let response = client.query("SELECT * FROM cpu").await.unwrap();
let batches = collect_stream(response).await;
assert_batches_sorted_eq!(
[
"+------+---------+--------------------------------+-------+",
"| host | region | time | usage |",
"+------+---------+--------------------------------+-------+",
"| s1 | us-east | 1970-01-01T00:00:00.000000001Z | 0.9 |",
"| s1 | us-east | 1970-01-01T00:00:00.000000002Z | 0.89 |",
"| s1 | us-east | 1970-01-01T00:00:00.000000003Z | 0.85 |",
"+------+---------+--------------------------------+-------+",
],
&batches
);
}
// Ad-hoc Query error:
{
let error = client
.query("SELECT * FROM invalid_table")
.await
.unwrap_err();
assert!(error
.to_string()
.contains("table 'public.iox.invalid_table' not found"));
}
// Prepared query:
{
let handle = client.prepare("SELECT * FROM cpu".into()).await.unwrap();
let stream = client.execute(handle).await.unwrap();
let batches = collect_stream(stream).await;
assert_batches_sorted_eq!(
[
"+------+---------+--------------------------------+-------+",
"| host | region | time | usage |",
"+------+---------+--------------------------------+-------+",
"| s1 | us-east | 1970-01-01T00:00:00.000000001Z | 0.9 |",
"| s1 | us-east | 1970-01-01T00:00:00.000000002Z | 0.89 |",
"| s1 | us-east | 1970-01-01T00:00:00.000000003Z | 0.85 |",
"+------+---------+--------------------------------+-------+",
],
&batches
);
}
// Get SQL Infos:
{
let infos = vec![SqlInfo::FlightSqlServerName as u32];
let stream = client.get_sql_info(infos).await.unwrap();
let batches = collect_stream(stream).await;
assert_batches_sorted_eq!(
[
"+-----------+-----------------------------+",
"| info_name | value |",
"+-----------+-----------------------------+",
"| 0 | {string_value=InfluxDB IOx} |",
"+-----------+-----------------------------+",
],
&batches
);
}
// Get Tables
{
type OptStr = std::option::Option<&'static str>;
let stream = client
.get_tables(OptStr::None, OptStr::None, OptStr::None, vec![], false)
.await
.unwrap();
let batches = collect_stream(stream).await;
assert_batches_sorted_eq!(
[
"+--------------+--------------------+-------------+------------+",
"| catalog_name | db_schema_name | table_name | table_type |",
"+--------------+--------------------+-------------+------------+",
"| public | information_schema | columns | VIEW |",
"| public | information_schema | df_settings | VIEW |",
"| public | information_schema | tables | VIEW |",
"| public | information_schema | views | VIEW |",
"| public | iox | cpu | BASE TABLE |",
"+--------------+--------------------+-------------+------------+",
],
&batches
);
}
// Get Catalogs
{
let stream = client.get_catalogs().await.unwrap();
let batches = collect_stream(stream).await;
assert_batches_sorted_eq!(
[
"+--------------+",
"| catalog_name |",
"+--------------+",
"| public |",
"+--------------+",
],
&batches
);
}
}
async fn write_lp_to_db(server: &TestServer, database: &str, lp: &'static str) {
let client = influxdb3_client::Client::new(server.client_addr()).unwrap();
client
.api_v3_write_lp(database)
.body(lp)
.send()
.await
.unwrap();
}
async fn collect_stream(stream: FlightRecordBatchStream) -> Vec<RecordBatch> {
stream
.try_collect()
.await
.expect("gather record batch stream")
}

View File

@ -30,12 +30,19 @@ trace_http = { path = "../trace_http" }
tracker = { path = "../tracker" } tracker = { path = "../tracker" }
arrow = { workspace = true, features = ["prettyprint"] } arrow = { workspace = true, features = ["prettyprint"] }
arrow-flight.workspace = true
arrow-json = "49.0.0"
arrow-schema = "49.0.0"
arrow-csv = "49.0.0"
async-trait = "0.1"
chrono = "0.4" chrono = "0.4"
datafusion = { workspace = true } datafusion = { workspace = true }
async-trait = "0.1" flate2 = "1.0.27"
futures = "0.3.28" futures = "0.3.28"
hex = "0.4.3"
hyper = "0.14" hyper = "0.14"
parking_lot = "0.11.1" parking_lot = "0.11.1"
pin-project-lite = "0.2"
thiserror = "1.0" thiserror = "1.0"
tokio = { version = "1", features = ["rt-multi-thread", "macros", "time"] } tokio = { version = "1", features = ["rt-multi-thread", "macros", "time"] }
tokio-util = { version = "0.7.9" } tokio-util = { version = "0.7.9" }
@ -43,14 +50,9 @@ tonic = { workspace = true }
serde = { version = "1.0.197", features = ["derive"] } serde = { version = "1.0.197", features = ["derive"] }
serde_json = "1.0.114" serde_json = "1.0.114"
serde_urlencoded = "0.7.0" serde_urlencoded = "0.7.0"
tower = "0.4.13"
flate2 = "1.0.27"
workspace-hack = { version = "0.1", path = "../workspace-hack" }
arrow-json = "49.0.0"
arrow-schema = "49.0.0"
arrow-csv = "49.0.0"
sha2 = "0.10.8" sha2 = "0.10.8"
hex = "0.4.3" tower = "0.4.13"
workspace-hack = { version = "0.1", path = "../workspace-hack" }
[dev-dependencies] [dev-dependencies]
parquet.workspace = true parquet.workspace = true

View File

@ -0,0 +1,14 @@
use std::sync::Arc;
use arrow_flight::flight_service_server::{
FlightService as Flight, FlightServiceServer as FlightServer,
};
use authz::Authorizer;
use iox_query::QueryNamespaceProvider;
pub(crate) fn make_flight_server<Q: QueryNamespaceProvider>(
server: Arc<Q>,
authz: Option<Arc<dyn Authorizer>>,
) -> FlightServer<impl Flight> {
service_grpc_flight::make_server(server, authz)
}

View File

@ -12,7 +12,6 @@ use hyper::header::ACCEPT;
use hyper::header::AUTHORIZATION; use hyper::header::AUTHORIZATION;
use hyper::header::CONTENT_ENCODING; use hyper::header::CONTENT_ENCODING;
use hyper::http::HeaderValue; use hyper::http::HeaderValue;
use hyper::server::conn::{AddrIncoming, AddrStream};
use hyper::{Body, Method, Request, Response, StatusCode}; use hyper::{Body, Method, Request, Response, StatusCode};
use influxdb3_write::persister::TrackedMemoryArrowWriter; use influxdb3_write::persister::TrackedMemoryArrowWriter;
use influxdb3_write::WriteBuffer; use influxdb3_write::WriteBuffer;
@ -27,11 +26,6 @@ use std::num::NonZeroI32;
use std::str::Utf8Error; use std::str::Utf8Error;
use std::sync::Arc; use std::sync::Arc;
use thiserror::Error; use thiserror::Error;
use tokio_util::sync::CancellationToken;
use tower::Layer;
use trace_http::metrics::MetricFamily;
use trace_http::metrics::RequestMetrics;
use trace_http::tower::TraceLayer;
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum Error { pub enum Error {
@ -159,13 +153,11 @@ impl Error {
pub type Result<T, E = Error> = std::result::Result<T, E>; pub type Result<T, E = Error> = std::result::Result<T, E>;
const TRACE_SERVER_NAME: &str = "http_api";
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct HttpApi<W, Q> { pub(crate) struct HttpApi<W, Q> {
common_state: CommonServerState, common_state: CommonServerState,
write_buffer: Arc<W>, write_buffer: Arc<W>,
query_executor: Arc<Q>, pub(crate) query_executor: Arc<Q>,
max_request_bytes: usize, max_request_bytes: usize,
} }
@ -449,42 +441,7 @@ pub(crate) struct WriteParams {
pub(crate) db: String, pub(crate) db: String,
} }
pub(crate) async fn serve<W: WriteBuffer, Q: QueryExecutor>( pub(crate) async fn route_request<W: WriteBuffer, Q: QueryExecutor>(
http_server: Arc<HttpApi<W, Q>>,
shutdown: CancellationToken,
) -> Result<()> {
let listener = AddrIncoming::bind(&http_server.common_state.http_addr)?;
println!("binding listener");
info!(bind_addr=%listener.local_addr(), "bound HTTP listener");
let req_metrics = RequestMetrics::new(
Arc::clone(&http_server.common_state.metrics),
MetricFamily::HttpServer,
);
let trace_layer = TraceLayer::new(
http_server.common_state.trace_header_parser.clone(),
Arc::new(req_metrics),
http_server.common_state.trace_collector().clone(),
TRACE_SERVER_NAME,
);
hyper::Server::builder(listener)
.serve(hyper::service::make_service_fn(|_conn: &AddrStream| {
let http_server = Arc::clone(&http_server);
let service = hyper::service::service_fn(move |request: Request<_>| {
route_request(Arc::clone(&http_server), request)
});
let service = trace_layer.layer(service);
futures::future::ready(Ok::<_, Infallible>(service))
}))
.with_graceful_shutdown(shutdown.cancelled())
.await?;
Ok(())
}
async fn route_request<W: WriteBuffer, Q: QueryExecutor>(
http_server: Arc<HttpApi<W, Q>>, http_server: Arc<HttpApi<W, Q>>,
mut req: Request<Body>, mut req: Request<Body>,
) -> Result<Response<Body>, Infallible> { ) -> Result<Response<Body>, Infallible> {

View File

@ -11,26 +11,43 @@ clippy::clone_on_ref_ptr,
clippy::future_not_send clippy::future_not_send
)] )]
mod grpc;
mod http; mod http;
pub mod query_executor; pub mod query_executor;
mod service;
use crate::grpc::make_flight_server;
use crate::http::route_request;
use crate::http::HttpApi; use crate::http::HttpApi;
use async_trait::async_trait; use async_trait::async_trait;
use datafusion::execution::SendableRecordBatchStream; use datafusion::execution::SendableRecordBatchStream;
use hyper::service::service_fn;
use influxdb3_write::{Persister, WriteBuffer}; use influxdb3_write::{Persister, WriteBuffer};
use observability_deps::tracing::info; use iox_query::QueryNamespaceProvider;
use observability_deps::tracing::{error, info};
use service::hybrid;
use std::convert::Infallible;
use std::fmt::Debug; use std::fmt::Debug;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use thiserror::Error; use thiserror::Error;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tower::Layer;
use trace::ctx::SpanContext; use trace::ctx::SpanContext;
use trace::TraceCollector; use trace::TraceCollector;
use trace_http::ctx::RequestLogContext; use trace_http::ctx::RequestLogContext;
use trace_http::ctx::TraceHeaderParser; use trace_http::ctx::TraceHeaderParser;
use trace_http::metrics::MetricFamily;
use trace_http::metrics::RequestMetrics;
use trace_http::tower::TraceLayer;
const TRACE_SERVER_NAME: &str = "influxdb3_http";
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum Error { pub enum Error {
#[error("hyper error: {0}")]
Hyper(#[from] hyper::Error),
#[error("http error: {0}")] #[error("http error: {0}")]
Http(#[from] http::Error), Http(#[from] http::Error),
@ -100,11 +117,12 @@ impl CommonServerState {
#[derive(Debug)] #[derive(Debug)]
pub struct Server<W, Q> { pub struct Server<W, Q> {
common_state: CommonServerState,
http: Arc<HttpApi<W, Q>>, http: Arc<HttpApi<W, Q>>,
} }
#[async_trait] #[async_trait]
pub trait QueryExecutor: Debug + Send + Sync + 'static { pub trait QueryExecutor: QueryNamespaceProvider + Debug + Send + Sync + 'static {
async fn query( async fn query(
&self, &self,
database: &str, database: &str,
@ -114,7 +132,10 @@ pub trait QueryExecutor: Debug + Send + Sync + 'static {
) -> Result<SendableRecordBatchStream>; ) -> Result<SendableRecordBatchStream>;
} }
impl<W, Q> Server<W, Q> { impl<W, Q> Server<W, Q>
where
Q: QueryExecutor,
{
pub fn new( pub fn new(
common_state: CommonServerState, common_state: CommonServerState,
_persister: Arc<dyn Persister>, _persister: Arc<dyn Persister>,
@ -124,26 +145,57 @@ impl<W, Q> Server<W, Q> {
) -> Self { ) -> Self {
let http = Arc::new(HttpApi::new( let http = Arc::new(HttpApi::new(
common_state.clone(), common_state.clone(),
Arc::<W>::clone(&write_buffer), Arc::clone(&write_buffer),
Arc::<Q>::clone(&query_executor), Arc::clone(&query_executor),
max_http_request_size, max_http_request_size,
)); ));
Self { http } Self { common_state, http }
} }
} }
pub async fn serve<W: WriteBuffer, Q: QueryExecutor>( pub async fn serve<W, Q>(server: Server<W, Q>, shutdown: CancellationToken) -> Result<()>
server: Server<W, Q>, where
shutdown: CancellationToken, W: WriteBuffer,
) -> Result<()> { Q: QueryExecutor,
{
// TODO: // TODO:
// 1. load the persisted catalog and segments from the persister // 1. load the persisted catalog and segments from the persister
// 2. load semgments into the buffer // 2. load semgments into the buffer
// 3. persist any segments from the buffer that are closed and haven't yet been persisted // 3. persist any segments from the buffer that are closed and haven't yet been persisted
// 4. start serving // 4. start serving
http::serve(Arc::clone(&server.http), shutdown).await?; let req_metrics = RequestMetrics::new(
Arc::clone(&server.common_state.metrics),
MetricFamily::HttpServer,
);
let trace_layer = TraceLayer::new(
server.common_state.trace_header_parser.clone(),
Arc::new(req_metrics),
server.common_state.trace_collector().clone(),
TRACE_SERVER_NAME,
);
let grpc_service = trace_layer.clone().layer(make_flight_server(
Arc::clone(&server.http.query_executor),
// TODO - need to configure authz here:
None,
));
let rest_service = hyper::service::make_service_fn(|_| {
let http_server = Arc::clone(&server.http);
let service = service_fn(move |req: hyper::Request<hyper::Body>| {
route_request(Arc::clone(&http_server), req)
});
let service = trace_layer.layer(service);
futures::future::ready(Ok::<_, Infallible>(service))
});
let hybrid_make_service = hybrid(rest_service, grpc_service);
hyper::Server::bind(&server.common_state.http_addr)
.serve(hybrid_make_service)
.with_graceful_shutdown(shutdown.cancelled())
.await?;
Ok(()) Ok(())
} }

View File

@ -29,7 +29,7 @@ use iox_query::query_log::StateReceived;
use iox_query::QueryNamespaceProvider; use iox_query::QueryNamespaceProvider;
use iox_query::{QueryChunk, QueryChunkData, QueryNamespace}; use iox_query::{QueryChunk, QueryChunkData, QueryNamespace};
use metric::Registry; use metric::Registry;
use observability_deps::tracing::info; use observability_deps::tracing::{debug, info, trace};
use schema::sort::SortKey; use schema::sort::SortKey;
use schema::Schema; use schema::Schema;
use std::any::Any; use std::any::Any;
@ -187,19 +187,37 @@ impl<B: WriteBuffer> QueryDatabase<B> {
query_log, query_log,
} }
} }
async fn query_table(&self, table_name: &str) -> Option<Arc<QueryTable<B>>> {
self.db_schema.get_table_schema(table_name).map(|schema| {
Arc::new(QueryTable {
db_schema: Arc::clone(&self.db_schema),
name: table_name.into(),
schema: schema.clone(),
write_buffer: Arc::clone(&self.write_buffer),
})
})
}
} }
#[async_trait] #[async_trait]
impl<B: WriteBuffer> QueryNamespace for QueryDatabase<B> { impl<B: WriteBuffer> QueryNamespace for QueryDatabase<B> {
async fn chunks( async fn chunks(
&self, &self,
_table_name: &str, table_name: &str,
_filters: &[Expr], filters: &[Expr],
_projection: Option<&Vec<usize>>, projection: Option<&Vec<usize>>,
_ctx: IOxSessionContext, ctx: IOxSessionContext,
) -> Result<Vec<Arc<dyn QueryChunk>>, DataFusionError> { ) -> Result<Vec<Arc<dyn QueryChunk>>, DataFusionError> {
info!("called chunks on querydatabase"); let _span_recorder = SpanRecorder::new(ctx.child_span("QueryDatabase::chunks"));
todo!() debug!(%table_name, ?filters, "Finding chunks for table");
let Some(table) = self.query_table(table_name).await else {
trace!(%table_name, "No entry for table");
return Ok(vec![]);
};
table.chunks(&ctx.inner().state(), projection, filters, None)
} }
fn retention_time_ns(&self) -> Option<i64> { fn retention_time_ns(&self) -> Option<i64> {
@ -282,14 +300,7 @@ impl<B: WriteBuffer> SchemaProvider for QueryDatabase<B> {
} }
async fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>> { async fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>> {
self.db_schema.get_table_schema(name).map(|schema| { self.query_table(name).await.map(|qt| qt as _)
Arc::new(QueryTable {
db_schema: Arc::clone(&self.db_schema),
name: name.into(),
schema: schema.clone(),
write_buffer: Arc::clone(&self.write_buffer),
}) as Arc<dyn TableProvider>
})
} }
fn table_exist(&self, name: &str) -> bool { fn table_exist(&self, name: &str) -> bool {
@ -313,6 +324,7 @@ impl<B: WriteBuffer> QueryTable<B> {
filters: &[Expr], filters: &[Expr],
_limit: Option<usize>, _limit: Option<usize>,
) -> Result<Vec<Arc<dyn QueryChunk>>, DataFusionError> { ) -> Result<Vec<Arc<dyn QueryChunk>>, DataFusionError> {
// TODO - this is only pulling from write buffer, and not parquet?
self.write_buffer.get_table_chunks( self.write_buffer.get_table_chunks(
&self.db_schema.name, &self.db_schema.name,
self.name.as_ref(), self.name.as_ref(),

View File

@ -0,0 +1,213 @@
use std::pin::Pin;
use std::task::{Context, Poll};
use futures::Future;
use hyper::HeaderMap;
use hyper::{body::HttpBody, Body, Request, Response};
use pin_project_lite::pin_project;
use tower::Service;
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
pub(crate) fn hybrid<MakeRest, Grpc>(
make_rest: MakeRest,
grpc: Grpc,
) -> HybridMakeService<MakeRest, Grpc> {
HybridMakeService { make_rest, grpc }
}
pub struct HybridMakeService<MakeRest, Grpc> {
make_rest: MakeRest,
grpc: Grpc,
}
impl<ConnInfo, MakeRest, Grpc> Service<ConnInfo> for HybridMakeService<MakeRest, Grpc>
where
MakeRest: Service<ConnInfo>,
Grpc: Clone,
{
type Response = HybridService<MakeRest::Response, Grpc>;
type Error = MakeRest::Error;
type Future = HybridMakeServiceFuture<MakeRest::Future, Grpc>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.make_rest.poll_ready(cx)
}
fn call(&mut self, conn_info: ConnInfo) -> Self::Future {
HybridMakeServiceFuture {
rest_future: self.make_rest.call(conn_info),
grpc: Some(self.grpc.clone()),
}
}
}
pin_project! {
pub struct HybridMakeServiceFuture<RestFuture, Grpc> {
#[pin]
rest_future: RestFuture,
grpc: Option<Grpc>,
}
}
impl<RestFuture, Rest, RestError, Grpc> Future for HybridMakeServiceFuture<RestFuture, Grpc>
where
RestFuture: Future<Output = Result<Rest, RestError>>,
{
type Output = Result<HybridService<Rest, Grpc>, RestError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
match this.rest_future.poll(cx) {
Poll::Ready(Ok(rest)) => Poll::Ready(Ok(HybridService {
rest,
grpc: this.grpc.take().expect("future polled after execution"),
})),
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
}
}
}
pub struct HybridService<Rest, Grpc> {
rest: Rest,
grpc: Grpc,
}
impl<Rest, Grpc, RestBody, GrpcBody> Service<Request<Body>> for HybridService<Rest, Grpc>
where
Rest: Service<Request<Body>, Response = Response<RestBody>>,
Grpc: Service<Request<Body>, Response = Response<GrpcBody>>,
Rest::Error: Into<BoxError>,
Grpc::Error: Into<BoxError>,
{
type Response = Response<HybridBody<RestBody, GrpcBody>>;
type Error = BoxError;
type Future = HybridFuture<Rest::Future, Grpc::Future>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self.rest.poll_ready(cx) {
Poll::Ready(Ok(())) => match self.grpc.poll_ready(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
Poll::Pending => Poll::Pending,
},
Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
Poll::Pending => Poll::Pending,
}
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
match (
req.version(),
req.headers().get(hyper::header::CONTENT_TYPE),
) {
(hyper::Version::HTTP_2, Some(hv))
if hv.as_bytes().starts_with(b"application/grpc") =>
{
HybridFuture::Grpc {
grpc_future: self.grpc.call(req),
}
}
_ => HybridFuture::Rest {
rest_future: self.rest.call(req),
},
}
}
}
pin_project! {
#[project = HybridBodyProj]
pub enum HybridBody<RestBody, GrpcBody> {
Rest {
#[pin]
rest_body: RestBody
},
Grpc {
#[pin]
grpc_body: GrpcBody
},
}
}
impl<RestBody, GrpcBody> HttpBody for HybridBody<RestBody, GrpcBody>
where
RestBody: HttpBody + Send + Unpin,
GrpcBody: HttpBody<Data = RestBody::Data> + Send + Unpin,
RestBody::Error: Into<BoxError>,
GrpcBody::Error: Into<BoxError>,
{
type Data = RestBody::Data;
type Error = BoxError;
fn is_end_stream(&self) -> bool {
match self {
Self::Rest { rest_body } => rest_body.is_end_stream(),
Self::Grpc { grpc_body } => grpc_body.is_end_stream(),
}
}
fn poll_data(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
match self.project() {
HybridBodyProj::Rest { rest_body } => rest_body.poll_data(cx).map_err(Into::into),
HybridBodyProj::Grpc { grpc_body } => grpc_body.poll_data(cx).map_err(Into::into),
}
}
fn poll_trailers(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
match self.project() {
HybridBodyProj::Rest { rest_body } => rest_body.poll_trailers(cx).map_err(Into::into),
HybridBodyProj::Grpc { grpc_body } => grpc_body.poll_trailers(cx).map_err(Into::into),
}
}
}
pin_project! {
#[project = HybridFutureProj]
pub enum HybridFuture<RestFuture, GrpcFuture> {
Rest {
#[pin]
rest_future: RestFuture,
},
Grpc {
#[pin]
grpc_future: GrpcFuture,
}
}
}
impl<RestFuture, GrpcFuture, RestBody, GrpcBody, RestError, GrpcError> Future
for HybridFuture<RestFuture, GrpcFuture>
where
RestFuture: Future<Output = Result<Response<RestBody>, RestError>>,
GrpcFuture: Future<Output = Result<Response<GrpcBody>, GrpcError>>,
RestError: Into<BoxError>,
GrpcError: Into<BoxError>,
{
type Output = Result<Response<HybridBody<RestBody, GrpcBody>>, BoxError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project() {
HybridFutureProj::Rest { rest_future } => match rest_future.poll(cx) {
Poll::Ready(Ok(res)) => {
Poll::Ready(Ok(res.map(|rest_body| HybridBody::Rest { rest_body })))
}
Poll::Ready(Err(err)) => Poll::Ready(Err(err.into())),
Poll::Pending => Poll::Pending,
},
HybridFutureProj::Grpc { grpc_future } => match grpc_future.poll(cx) {
Poll::Ready(Ok(res)) => {
Poll::Ready(Ok(res.map(|grpc_body| HybridBody::Grpc { grpc_body })))
}
Poll::Ready(Err(err)) => Poll::Ready(Err(err.into())),
Poll::Pending => Poll::Pending,
},
}
}
}