diff --git a/Cargo.lock b/Cargo.lock index e67843d1ac..0a7bd9f9fc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1793,6 +1793,7 @@ dependencies = [ "tonic", "tonic-health", "tonic-reflection", + "tower", "trace", "trace_exporters", "trace_http", diff --git a/Cargo.toml b/Cargo.toml index 0fe3496bc7..31827e74f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -151,6 +151,7 @@ tokio-util = { version = "0.6.3" } tonic = "0.5.0" tonic-health = "0.4.0" tonic-reflection = "0.2.0" +tower = "0.4" uuid = { version = "0.8", features = ["v4"] } # jemalloc-sys with unprefixed_malloc_on_supported_platforms feature and heappy are mutually exclusive diff --git a/src/influxdb_ioxd.rs b/src/influxdb_ioxd.rs index 097fc4a6c1..000a5b5937 100644 --- a/src/influxdb_ioxd.rs +++ b/src/influxdb_ioxd.rs @@ -242,7 +242,7 @@ async fn serve( grpc_listener, Arc::clone(&application), Arc::clone(&app_server), - trace_collector, + trace_collector.clone(), frontend_shutdown.clone(), config.initial_serving_state.into(), ) @@ -258,6 +258,7 @@ async fn serve( Arc::clone(&app_server), frontend_shutdown.clone(), max_http_request_size, + trace_collector, ) .fuse(); info!("HTTP server listening"); diff --git a/src/influxdb_ioxd/http.rs b/src/influxdb_ioxd/http.rs index b1cbe8980e..c12fff248f 100644 --- a/src/influxdb_ioxd/http.rs +++ b/src/influxdb_ioxd/http.rs @@ -16,6 +16,8 @@ mod heappy; #[cfg(feature = "pprof")] mod pprof; +mod tower; + // Influx crates use super::planner::Planner; use data_types::{ @@ -35,7 +37,7 @@ use http::header::{CONTENT_ENCODING, CONTENT_TYPE}; use hyper::{http::HeaderValue, Body, Method, Request, Response, StatusCode}; use metrics::KeyValue; use observability_deps::tracing::{self, debug, error}; -use routerify::{prelude::*, Middleware, RequestInfo, Router, RouterError, RouterService}; +use routerify::{prelude::*, Middleware, RequestInfo, Router, RouterError}; use serde::Deserialize; use snafu::{OptionExt, ResultExt, Snafu}; @@ -47,6 +49,7 @@ use std::{ sync::Arc, }; use tokio_util::sync::CancellationToken; +use trace::TraceCollector; /// Constants used in API error codes. /// @@ -950,15 +953,16 @@ pub async fn serve( server: Arc>, shutdown: CancellationToken, max_request_size: usize, + trace_collector: Option>, ) -> Result<(), hyper::Error> where M: ConnectionManager + Send + Sync + Debug + 'static, { let router = router(application, server, max_request_size); - let service = RouterService::new(router).unwrap(); + let new_service = tower::MakeService::new(router, trace_collector); hyper::Server::builder(addr) - .serve(service) + .serve(new_service) .with_graceful_shutdown(shutdown.cancelled()) .await } @@ -981,6 +985,7 @@ mod tests { use serde::de::DeserializeOwned; use server::{db::Db, rules::ProvidedDatabaseRules, ApplicationState, ConnectionManagerImpl}; use tokio_stream::wrappers::ReceiverStream; + use trace::RingBufferTraceCollector; fn make_application() -> Arc { Arc::new(ApplicationState::new( @@ -1001,7 +1006,7 @@ mod tests { async fn test_health() { let application = make_application(); let app_server = make_server(Arc::clone(&application)); - let server_url = test_server(application, Arc::clone(&app_server)); + let server_url = test_server(application, Arc::clone(&app_server), None); let client = Client::new(); let response = client.get(&format!("{}/health", server_url)).send().await; @@ -1020,7 +1025,7 @@ mod tests { .register_metric("my_metric", "description"); let app_server = make_server(Arc::clone(&application)); - let server_url = test_server(application, Arc::clone(&app_server)); + let server_url = test_server(application, Arc::clone(&app_server), None); metric.recorder(&[("tag", "value")]).inc(20); @@ -1037,6 +1042,36 @@ mod tests { assert!(data.contains(&"\nmy_metric_total{tag=\"value\"} 20\n")); } + #[tokio::test] + async fn test_tracing() { + let application = make_application(); + let app_server = make_server(Arc::clone(&application)); + let trace_collector = Arc::new(RingBufferTraceCollector::new(5)); + + let server_url = test_server( + application, + Arc::clone(&app_server), + Some(Arc::::clone(&trace_collector)), + ); + + let client = Client::new(); + let response = client + .get(&format!("{}/health", server_url)) + .header("uber-trace-id", "34f3495:36e34:0:1") + .send() + .await; + + // Print the response so if the test fails, we have a log of what went wrong + check_response("health", response, StatusCode::OK, Some("OK")).await; + + let mut spans = trace_collector.spans(); + assert_eq!(spans.len(), 1); + + let span = spans.pop().unwrap(); + assert_eq!(span.ctx.trace_id.get(), 0x34f3495); + assert_eq!(span.ctx.parent_span_id.unwrap().get(), 0x36e34); + } + #[tokio::test] async fn test_write() { let application = make_application(); @@ -1047,7 +1082,7 @@ mod tests { .create_database(make_rules("MyOrg_MyBucket")) .await .unwrap(); - let server_url = test_server(application, Arc::clone(&app_server)); + let server_url = test_server(application, Arc::clone(&app_server), None); let client = Client::new(); @@ -1096,7 +1131,7 @@ mod tests { .await .unwrap(); - let server_url = test_server(application, Arc::clone(&app_server)); + let server_url = test_server(application, Arc::clone(&app_server), None); let client = Client::new(); @@ -1184,7 +1219,7 @@ mod tests { .create_database(make_rules("MyOrg_MyBucket")) .await .unwrap(); - let server_url = test_server(application, Arc::clone(&app_server)); + let server_url = test_server(application, Arc::clone(&app_server), None); let client = Client::new(); @@ -1474,6 +1509,7 @@ mod tests { fn test_server( application: Arc, server: Arc>, + trace_collector: Option>, ) -> String { // NB: specify port 0 to let the OS pick the port. let bind_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0); @@ -1486,6 +1522,7 @@ mod tests { server, CancellationToken::new(), TEST_MAX_REQUEST_SIZE, + trace_collector, )); println!("Started server at {}", server_url); server_url @@ -1509,7 +1546,7 @@ mod tests { .create_database(make_rules("MyOrg_MyBucket")) .await .unwrap(); - let server_url = test_server(application, Arc::clone(&app_server)); + let server_url = test_server(application, Arc::clone(&app_server), None); (app_server, server_url) } diff --git a/src/influxdb_ioxd/http/tower.rs b/src/influxdb_ioxd/http/tower.rs new file mode 100644 index 0000000000..034507a193 --- /dev/null +++ b/src/influxdb_ioxd/http/tower.rs @@ -0,0 +1,71 @@ +use std::convert::Infallible; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use futures::future::BoxFuture; +use futures::{ready, FutureExt}; +use hyper::server::conn::AddrStream; +use hyper::Body; +use routerify::{RequestService, Router, RouterService}; + +use trace::TraceCollector; +use trace_http::tower::{TraceLayer, TraceService}; + +use super::ApplicationError; +use tower::Layer; + +/// `MakeService` can be thought of as a hyper-compatible connection factory +/// +/// Specifically it implements the necessary trait to be used with `hyper::server::Builder::serve` +pub struct MakeService { + inner: RouterService, + trace_layer: trace_http::tower::TraceLayer, +} + +impl MakeService { + pub fn new( + router: Router, + collector: Option>, + ) -> Self { + Self { + inner: RouterService::new(router).unwrap(), + trace_layer: TraceLayer::new(collector), + } + } +} + +impl tower::Service<&AddrStream> for MakeService { + type Response = Service; + type Error = Infallible; + type Future = MakeServiceFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, conn: &AddrStream) -> Self::Future { + MakeServiceFuture { + inner: self.inner.call(conn), + trace_layer: self.trace_layer.clone(), + } + } +} + +/// A future produced by `MakeService` that resolves to a `Service` +pub struct MakeServiceFuture { + inner: BoxFuture<'static, Result, Infallible>>, + trace_layer: trace_http::tower::TraceLayer, +} + +impl Future for MakeServiceFuture { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let maybe_service = ready!(self.inner.poll_unpin(cx)); + Poll::Ready(maybe_service.map(|service| self.trace_layer.layer(service))) + } +} + +pub type Service = TraceService>;