diff --git a/Cargo.lock b/Cargo.lock index fac664f096..560ecb3a23 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -161,8 +161,8 @@ dependencies = [ "prost 0.8.0", "prost-derive 0.8.0", "tokio", - "tonic 0.5.2", - "tonic-build 0.5.2", + "tonic 0.5.0", + "tonic-build 0.5.1", ] [[package]] @@ -270,7 +270,7 @@ dependencies = [ "log", "md5", "oauth2", - "paste", + "paste 1.0.5", "quick-error", "reqwest", "serde", @@ -371,9 +371,9 @@ dependencies = [ [[package]] name = "bitflags" -version = "1.3.1" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2da1976d75adbe5fbc88130ecd119529cf1cc6a93ae1546d8696ee66f0d21af1" +checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693" [[package]] name = "bitvec" @@ -868,7 +868,7 @@ dependencies = [ "num_cpus", "ordered-float 2.7.0", "parquet", - "paste", + "paste 1.0.5", "pin-project-lite", "rand 0.8.4", "smallvec", @@ -1303,8 +1303,8 @@ dependencies = [ "serde", "serde_json", "thiserror", - "tonic 0.5.2", - "tonic-build 0.5.2", + "tonic 0.5.0", + "tonic-build 0.5.1", ] [[package]] @@ -1371,7 +1371,7 @@ dependencies = [ "futures", "grpc-router-test-gen", "observability_deps", - "paste", + "paste 1.0.5", "prost 0.8.0", "prost-build 0.8.0", "prost-types 0.8.0", @@ -1379,8 +1379,8 @@ dependencies = [ "tokio", "tokio-stream", "tokio-util", - "tonic 0.5.2", - "tonic-build 0.5.2", + "tonic 0.5.0", + "tonic-build 0.5.1", "tonic-reflection", ] @@ -1391,8 +1391,8 @@ dependencies = [ "prost 0.8.0", "prost-build 0.8.0", "prost-types 0.8.0", - "tonic 0.5.2", - "tonic-build 0.5.2", + "tonic 0.5.0", + "tonic-build 0.5.1", ] [[package]] @@ -1490,9 +1490,9 @@ dependencies = [ [[package]] name = "http-body" -version = "0.4.3" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "399c583b2979440c60be0821a6199eca73bc3c8dcd9d070d75ac726e2c6186e5" +checksum = "60daa14be0e0786db0f03a9e57cb404c9d756eed2b6c62b9ea98ec5743ec75a9" dependencies = [ "bytes", "http", @@ -1714,9 +1714,10 @@ dependencies = [ "tokio", "tokio-stream", "tokio-util", - "tonic 0.5.2", + "tonic 0.5.0", "tonic-health", "tonic-reflection", + "trace", "tracker", "trogging", "uuid", @@ -1740,7 +1741,8 @@ dependencies = [ "serde_json", "thiserror", "tokio", - "tonic 0.5.2", + "tonic 0.5.0", + "tower", ] [[package]] @@ -1899,9 +1901,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.99" +version = "0.2.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7f823d141fe0a24df1e23b4af4e3c7ba9e5966ec514ea068c93024aa7deb765" +checksum = "320cfe77175da3a483efed4bc0adc1968ca050b098ce4f2f1c13a56626128790" [[package]] name = "libloading" @@ -2004,9 +2006,9 @@ dependencies = [ [[package]] name = "matches" -version = "0.1.9" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f" +checksum = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08" [[package]] name = "md5" @@ -2174,9 +2176,9 @@ dependencies = [ [[package]] name = "native-tls" -version = "0.2.8" +version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48ba9f7719b5a0f42f338907614285fb5fd70e53858141f69898a1fb7203b24d" +checksum = "b8d96b2e1c8da3957d58100b09f102c6d9cfdfced01b7ec5a8974044bb09dbd4" dependencies = [ "lazy_static", "libc", @@ -2637,7 +2639,7 @@ dependencies = [ "cfg-if", "instant", "libc", - "redox_syscall 0.2.10", + "redox_syscall 0.2.9", "smallvec", "winapi", ] @@ -2707,12 +2709,31 @@ dependencies = [ "uuid", ] +[[package]] +name = "paste" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45ca20c77d80be666aef2b45486da86238fabe33e38306bd3118fe4af33fa880" +dependencies = [ + "paste-impl", + "proc-macro-hack", +] + [[package]] name = "paste" version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf547ad0c65e31259204bd90935776d1c693cec2f4ff7abb7a1bbbd40dfe58" +[[package]] +name = "paste-impl" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d95a7db200b97ef370c8e6de0088252f7e0dfff7d047a28528e47456c0fc98b6" +dependencies = [ + "proc-macro-hack", +] + [[package]] name = "peeking_take_while" version = "0.1.2" @@ -2910,9 +2931,9 @@ checksum = "57e35a3326b75e49aa85f5dc6ec15b41108cf5aee58eabb1f274dd18b73c2451" [[package]] name = "predicates-tree" -version = "1.0.3" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7dd0fd014130206c9352efbdc92be592751b2b9274dff685348341082c6ea3d" +checksum = "15f553275e5721409451eb85e15fd9a860a6e5ab4496eb215987502b5f5391f2" dependencies = [ "predicates-core", "treeline", @@ -3106,9 +3127,9 @@ dependencies = [ [[package]] name = "protobuf" -version = "2.25.0" +version = "2.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "020f86b07722c5c4291f7c723eac4676b3892d47d9a7708dc2779696407f039b" +checksum = "db50e77ae196458ccd3dc58a31ea1a90b0698ab1b7928d89f644c25d72070267" [[package]] name = "query" @@ -3378,9 +3399,9 @@ checksum = "41cc0f7e4d5d4544e8861606a285bb08d3e70712ccc7d2b84d7c0ccfaf4b05ce" [[package]] name = "redox_syscall" -version = "0.2.10" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8383f39639269cde97d255a32bdb68c047337295414940c68bdd30c2e13203ff" +checksum = "5ab49abadf3f9e1c4bc499e8845e152ad87d2ad2d30371841171169e9d75feee" dependencies = [ "bitflags", ] @@ -3403,7 +3424,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "528532f3d801c87aec9def2add9ca802fe569e44a544afe633765267840abe64" dependencies = [ "getrandom 0.2.3", - "redox_syscall 0.2.10", + "redox_syscall 0.2.9", ] [[package]] @@ -3988,9 +4009,9 @@ dependencies = [ [[package]] name = "slab" -version = "0.4.4" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c307a32c1c5c437f38c7fd45d753050587732ba8628319fbdf12a7e289ccc590" +checksum = "f173ac3d1a7e3b28003f40de0b5ce7fe2710f9b9dc3fc38664cebee46b3b6527" [[package]] name = "smallvec" @@ -4249,7 +4270,7 @@ dependencies = [ "cfg-if", "libc", "rand 0.8.4", - "redox_syscall 0.2.10", + "redox_syscall 0.2.9", "remove_dir_all", "winapi", ] @@ -4346,20 +4367,20 @@ dependencies = [ [[package]] name = "tikv-jemalloc-ctl" -version = "0.4.2" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb833c46ecbf8b6daeccb347cefcabf9c1beb5c9b0f853e1cec45632d9963e69" +checksum = "f28c80e4338857639f443169a601fafe49866aed8d7a8d565c2f5bfb1a021adf" dependencies = [ "libc", - "paste", + "paste 0.1.18", "tikv-jemalloc-sys", ] [[package]] name = "tikv-jemalloc-sys" -version = "0.4.2+5.2.1-patched.2" +version = "0.4.1+5.2.1-patched" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5844e429d797c62945a566f8da4e24c7fe3fbd5d6617fd8bf7a0b7dc1ee0f22e" +checksum = "8a26331b05179d4cb505c8d6814a7e18d298972f0a551b0e3cefccff927f86d3" dependencies = [ "cc", "fs_extra", @@ -4451,9 +4472,9 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" [[package]] name = "tokio" -version = "1.10.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01cf844b23c6131f624accf65ce0e4e9956a8bb329400ea5bcc26ae3a5c20b0b" +checksum = "4b7b349f11a7047e6d1276853e612d152f5e8a352c61917887cc2169e2366b4c" dependencies = [ "autocfg", "bytes", @@ -4576,9 +4597,9 @@ dependencies = [ [[package]] name = "tonic" -version = "0.5.2" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "796c5e1cd49905e65dd8e700d4cb1dffcbfdb4fc9d017de08c1a537afd83627c" +checksum = "b584f064fdfc50017ec39162d5aebce49912f1eb16fd128e04b7f4ce4907c7e5" dependencies = [ "async-stream", "async-trait", @@ -4619,9 +4640,9 @@ dependencies = [ [[package]] name = "tonic-build" -version = "0.5.2" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12b52d07035516c2b74337d2ac7746075e7dcae7643816c1b12c5ff8a7484c08" +checksum = "d12faebbe071b06f486be82cc9318350814fdd07fcb28f3690840cd770599283" dependencies = [ "proc-macro2", "prost-build 0.8.0", @@ -4631,17 +4652,17 @@ dependencies = [ [[package]] name = "tonic-health" -version = "0.4.1" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "493fcae35818dffa28437b210a615119d791116c1cac80716f571f35dd55b1b9" +checksum = "14e6de0a7a1b27d9899031b01b83eb09fdc36f3fe8e6254a81840006a463c6d5" dependencies = [ "async-stream", "bytes", "prost 0.8.0", "tokio", "tokio-stream", - "tonic 0.5.2", - "tonic-build 0.5.2", + "tonic 0.5.0", + "tonic-build 0.5.1", ] [[package]] @@ -4655,8 +4676,8 @@ dependencies = [ "prost-types 0.8.0", "tokio", "tokio-stream", - "tonic 0.5.2", - "tonic-build 0.5.2", + "tonic 0.5.0", + "tonic-build 0.5.1", ] [[package]] @@ -4691,6 +4712,24 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "360dfd1d6d30e05fda32ace2c8c70e9c0a9da713275777f5a4dbb8a1893930c6" +[[package]] +name = "trace" +version = "0.1.0" +dependencies = [ + "chrono", + "futures", + "http", + "http-body", + "observability_deps", + "parking_lot", + "pin-project 1.0.8", + "rand 0.8.4", + "serde", + "serde_json", + "snafu", + "tower", +] + [[package]] name = "tracing" version = "0.1.26" diff --git a/Cargo.toml b/Cargo.toml index b1ad84bcf7..088b378923 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -75,6 +75,7 @@ panic_logging = { path = "panic_logging" } query = { path = "query" } read_buffer = { path = "read_buffer" } server = { path = "server" } +trace = { path = "trace" } tracker = { path = "tracker" } trogging = { path = "trogging", default-features = false, features = ["structopt"] } @@ -111,7 +112,7 @@ serde_urlencoded = "0.7.0" snafu = "0.6.9" structopt = "0.3.21" thiserror = "1.0.23" -tikv-jemallocator = {version = "0.4.0", features = ["unprefixed_malloc_on_supported_platforms"] } +tikv-jemallocator = { version = "0.4.0", features = ["unprefixed_malloc_on_supported_platforms"] } tikv-jemalloc-ctl = "0.4.0" tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "parking_lot", "signal"] } tokio-stream = { version = "0.1.2", features = ["net"] } diff --git a/influxdb_iox_client/Cargo.toml b/influxdb_iox_client/Cargo.toml index fc7449c849..00c659457b 100644 --- a/influxdb_iox_client/Cargo.toml +++ b/influxdb_iox_client/Cargo.toml @@ -15,7 +15,7 @@ generated_types = { path = "../generated_types" } # Crates.io dependencies, in alphabetical order arrow = { version = "5.0", optional = true } -arrow-flight = { version = "5.0", optional = true} +arrow-flight = { version = "5.0", optional = true } futures-util = { version = "0.3.1", optional = true } http = "0.2.3" hyper = "0.14" @@ -26,6 +26,7 @@ serde_json = { version = "1.0.44", optional = true } thiserror = "1.0.23" tokio = { version = "1.0", features = ["macros"] } tonic = { version = "0.5.0" } +tower = "0.4" [dev-dependencies] # In alphabetical order serde_json = "1.0" diff --git a/influxdb_iox_client/src/client/management.rs b/influxdb_iox_client/src/client/management.rs index e275ea1af0..784b58748e 100644 --- a/influxdb_iox_client/src/client/management.rs +++ b/influxdb_iox_client/src/client/management.rs @@ -338,7 +338,7 @@ pub struct Client { impl Client { /// Creates a new client with the provided connection - pub fn new(channel: tonic::transport::Channel) -> Self { + pub fn new(channel: Connection) -> Self { Self { inner: ManagementServiceClient::new(channel), } diff --git a/influxdb_iox_client/src/client/operations.rs b/influxdb_iox_client/src/client/operations.rs index 472627ba58..dcc0fb6fb0 100644 --- a/influxdb_iox_client/src/client/operations.rs +++ b/influxdb_iox_client/src/client/operations.rs @@ -59,7 +59,7 @@ pub struct Client { impl Client { /// Creates a new client with the provided connection - pub fn new(channel: tonic::transport::Channel) -> Self { + pub fn new(channel: Connection) -> Self { Self { inner: OperationsClient::new(channel), } diff --git a/influxdb_iox_client/src/client/write.rs b/influxdb_iox_client/src/client/write.rs index df07c26091..a912896df2 100644 --- a/influxdb_iox_client/src/client/write.rs +++ b/influxdb_iox_client/src/client/write.rs @@ -47,7 +47,7 @@ pub struct Client { impl Client { /// Creates a new client with the provided connection - pub fn new(channel: tonic::transport::Channel) -> Self { + pub fn new(channel: Connection) -> Self { Self { inner: WriteServiceClient::new(channel.clone()), inner_pb: PBWriteServiceClient::new(channel), diff --git a/influxdb_iox_client/src/connection.rs b/influxdb_iox_client/src/connection.rs index 58e0f622ed..43bee07444 100644 --- a/influxdb_iox_client/src/connection.rs +++ b/influxdb_iox_client/src/connection.rs @@ -1,11 +1,13 @@ -use http::{uri::InvalidUri, Uri}; +use crate::tower::{SetRequestHeadersLayer, SetRequestHeadersService}; +use http::header::HeaderName; +use http::{uri::InvalidUri, HeaderValue, Uri}; use std::convert::TryInto; use std::time::Duration; use thiserror::Error; use tonic::transport::Endpoint; /// The connection type used for clients -pub type Connection = tonic::transport::Channel; +pub type Connection = SetRequestHeadersService; /// The default User-Agent header sent by the HTTP client. pub const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")); @@ -51,6 +53,7 @@ pub type Result = std::result::Result; #[derive(Debug)] pub struct Builder { user_agent: String, + headers: Vec<(HeaderName, HeaderValue)>, connect_timeout: Duration, timeout: Duration, } @@ -61,6 +64,7 @@ impl std::default::Default for Builder { user_agent: USER_AGENT.into(), connect_timeout: DEFAULT_CONNECT_TIMEOUT, timeout: DEFAULT_TIMEOUT, + headers: Default::default(), } } } @@ -73,18 +77,17 @@ impl Builder { { let endpoint = Endpoint::from(dst.try_into()?) .user_agent(self.user_agent)? + .connect_timeout(self.connect_timeout) .timeout(self.timeout); - // Manually construct connector to workaround https://github.com/hyperium/tonic/issues/498 - let mut connector = hyper::client::HttpConnector::new(); - connector.set_connect_timeout(Some(self.connect_timeout)); + let channel = endpoint.connect().await?; - // Defaults from from tonic::channel::Endpoint - connector.enforce_http(false); - connector.set_nodelay(true); - connector.set_keepalive(None); + // Compose channel with new tower middleware stack + let channel = tower::ServiceBuilder::new() + .layer(SetRequestHeadersLayer::new(self.headers)) + .service(channel); - Ok(endpoint.connect_with_connector(connector).await?) + Ok(channel) } /// Set the `User-Agent` header sent by this client. @@ -95,6 +98,13 @@ impl Builder { } } + /// Sets a header to be included on all requests + pub fn header(self, header: impl Into, value: impl Into) -> Self { + let mut headers = self.headers; + headers.push((header.into(), value.into())); + Self { headers, ..self } + } + /// Sets the maximum duration of time the client will wait for the IOx /// server to accept the TCP connection before aborting the request. /// diff --git a/influxdb_iox_client/src/lib.rs b/influxdb_iox_client/src/lib.rs index 3edafc1449..ddff0c3f80 100644 --- a/influxdb_iox_client/src/lib.rs +++ b/influxdb_iox_client/src/lib.rs @@ -27,3 +27,5 @@ pub mod connection; pub mod format; mod client; + +mod tower; diff --git a/influxdb_iox_client/src/tower.rs b/influxdb_iox_client/src/tower.rs new file mode 100644 index 0000000000..3f5225b7a2 --- /dev/null +++ b/influxdb_iox_client/src/tower.rs @@ -0,0 +1,60 @@ +use http::header::HeaderName; +use http::{HeaderValue, Request, Response}; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tower::{Layer, Service}; + +/// `SetRequestHeadersLayer` sets the provided headers on all requests flowing through it +/// unless they're already set +#[derive(Debug, Clone)] +pub(crate) struct SetRequestHeadersLayer { + headers: Arc>, +} + +impl SetRequestHeadersLayer { + pub(crate) fn new(headers: Vec<(HeaderName, HeaderValue)>) -> Self { + Self { + headers: Arc::new(headers), + } + } +} + +impl Layer for SetRequestHeadersLayer { + type Service = SetRequestHeadersService; + + fn layer(&self, service: S) -> Self::Service { + SetRequestHeadersService { + service, + headers: Arc::clone(&self.headers), + } + } +} + +/// SetRequestHeadersService wraps an inner tower::Service and sets the provided +/// headers on requests flowing through it +#[derive(Debug, Clone)] +pub struct SetRequestHeadersService { + service: S, + headers: Arc>, +} + +impl Service> for SetRequestHeadersService +where + S: Service, Response = Response>, +{ + type Response = Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, mut request: Request) -> Self::Future { + let headers = request.headers_mut(); + for (name, value) in self.headers.iter() { + headers.insert(name, value.clone()); + } + self.service.call(request) + } +} diff --git a/src/influxdb_ioxd/rpc.rs b/src/influxdb_ioxd/rpc.rs index 1cc5014be9..d8ffafb81b 100644 --- a/src/influxdb_ioxd/rpc.rs +++ b/src/influxdb_ioxd/rpc.rs @@ -76,6 +76,9 @@ pub async fn serve( where M: ConnectionManager + Send + Sync + Debug + 'static, { + // TODO: Replace this with a jaeger collector + let trace_collector = Arc::new(trace::LogTraceCollector::new()); + let stream = TcpListenerStream::new(socket); let (mut health_reporter, health_service) = tonic_health::server::health_reporter(); @@ -84,7 +87,8 @@ where .build() .context(ReflectionError)?; - let mut builder = tonic::transport::Server::builder(); + let builder = tonic::transport::Server::builder(); + let mut builder = builder.layer(trace::tower::TraceLayer::new(trace_collector)); // important that this one is NOT gated so that it can answer health requests add_service!(builder, health_reporter, health_service); diff --git a/src/main.rs b/src/main.rs index a272f1f350..4758051264 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,6 +18,7 @@ use observability_deps::tracing::warn; use crate::commands::tracing::TracingGuard; use influxdb_iox_client::connection::Builder; +use std::str::FromStr; use tikv_jemallocator::Jemalloc; mod commands { @@ -118,7 +119,13 @@ struct Config { env = "IOX_ADDR", default_value = "http://127.0.0.1:8082" )] - host: String, /* TODO: This must be on the root due to https://github.com/clap-rs/clap/pull/2253 */ + host: String, + + /// Additional headers to add to CLI requests + /// + /// Values should be key value pairs separated by ':' + #[structopt(long, global = true)] + header: Vec>, #[structopt(long)] /// Set the maximum number of threads to use. Defaults to the number of @@ -150,10 +157,15 @@ fn main() -> Result<(), std::io::Error> { let tokio_runtime = get_runtime(config.num_threads)?; tokio_runtime.block_on(async move { let host = config.host; + let headers = config.header; let log_verbose_count = config.log_verbose_count; let connection = || async move { - match Builder::default().build(&host).await { + let builder = headers.into_iter().fold(Builder::default(), |builder, kv| { + builder.header(kv.key, kv.value) + }); + + match builder.build(&host).await { Ok(connection) => connection, Err(e) => { eprintln!("Error connecting to {}: {}", host, e); @@ -322,3 +334,35 @@ unsafe fn set_signal_handler(signal: libc::c_int, handler: unsafe extern "C" fn( sigaction(signal, &action, std::ptr::null_mut()); } } + +/// A ':' separated key value pair +#[derive(Debug, Clone)] +struct KeyValue { + pub key: K, + pub value: V, +} + +impl std::str::FromStr for KeyValue +where + K: FromStr, + V: FromStr, + K::Err: std::fmt::Display, + V::Err: std::fmt::Display, +{ + type Err = String; + + fn from_str(s: &str) -> Result { + use itertools::Itertools; + match s.split(':').collect_tuple() { + Some((key, value)) => { + let key = K::from_str(key).map_err(|e| e.to_string())?; + let value = V::from_str(value).map_err(|e| e.to_string())?; + Ok(Self { key, value }) + } + None => Err(format!( + "Invalid key value pair - expected 'KEY:VALUE' got '{}'", + s + )), + } + } +} diff --git a/tests/common/server_fixture.rs b/tests/common/server_fixture.rs index df52839000..27f35cab54 100644 --- a/tests/common/server_fixture.rs +++ b/tests/common/server_fixture.rs @@ -16,6 +16,7 @@ use futures::prelude::*; use generated_types::influxdata::iox::management::v1::{ database_status::DatabaseState, ServerStatus, }; +use influxdb_iox_client::connection::Connection; use once_cell::sync::OnceCell; use tempfile::{NamedTempFile, TempDir}; use tokio::sync::Mutex; @@ -77,7 +78,7 @@ const TOKEN: &str = "InfluxDB IOx doesn't have authentication yet"; /// testing. pub struct ServerFixture { server: Arc, - grpc_channel: tonic::transport::Channel, + grpc_channel: Connection, } /// Specifieds should we configure a server initially @@ -162,7 +163,7 @@ impl ServerFixture { /// Return a channel connected to the gRPC API. Panics if the /// server is not yet up - pub fn grpc_channel(&self) -> tonic::transport::Channel { + pub fn grpc_channel(&self) -> Connection { self.grpc_channel.clone() } @@ -471,9 +472,7 @@ impl TestServer { } /// Create a connection channel for the gRPC endpoint - async fn grpc_channel( - &self, - ) -> influxdb_iox_client::connection::Result { + async fn grpc_channel(&self) -> influxdb_iox_client::connection::Result { grpc_channel(&self.addrs).await } @@ -485,7 +484,7 @@ impl TestServer { /// Create a connection channel for the gRPC endpoint pub async fn grpc_channel( addrs: &BindAddresses, -) -> influxdb_iox_client::connection::Result { +) -> influxdb_iox_client::connection::Result { influxdb_iox_client::connection::Builder::default() .build(&addrs.grpc_base) .await diff --git a/tests/end_to_end_cases/scenario.rs b/tests/end_to_end_cases/scenario.rs index 628b525de8..d9dfe5a14a 100644 --- a/tests/end_to_end_cases/scenario.rs +++ b/tests/end_to_end_cases/scenario.rs @@ -22,7 +22,7 @@ use generated_types::{ influxdata::iox::management::v1::{self as management, *}, ReadSource, TimestampRange, }; -use influxdb_iox_client::flight::PerformQuery; +use influxdb_iox_client::{connection::Connection, flight::PerformQuery}; use crate::common::server_fixture::{ServerFixture, DEFAULT_SERVER_ID}; @@ -361,7 +361,7 @@ impl DatabaseBuilder { } // Build a database - pub async fn build(self, channel: tonic::transport::Channel) { + pub async fn build(self, channel: Connection) { let mut management_client = influxdb_iox_client::management::Client::new(channel); let routing_rules = if self.write_buffer.is_some() { @@ -427,19 +427,13 @@ impl DatabaseBuilder { /// given a channel to talk with the management api, create a new /// database with the specified name configured with a 10MB mutable /// buffer, partitioned on table -pub async fn create_readable_database( - db_name: impl Into, - channel: tonic::transport::Channel, -) { +pub async fn create_readable_database(db_name: impl Into, channel: Connection) { DatabaseBuilder::new(db_name.into()).build(channel).await } /// given a channel to talk with the management api, create a new /// database with no mutable buffer configured, no partitioning rules -pub async fn create_unreadable_database( - db_name: impl Into, - channel: tonic::transport::Channel, -) { +pub async fn create_unreadable_database(db_name: impl Into, channel: Connection) { let mut management_client = influxdb_iox_client::management::Client::new(channel); let rules = DatabaseRules { @@ -456,10 +450,7 @@ pub async fn create_unreadable_database( /// given a channel to talk with the management api, create a new /// database with the specified name configured with a 10MB mutable /// buffer, partitioned on table, with some data written into two partitions -pub async fn create_two_partition_database( - db_name: impl Into, - channel: tonic::transport::Channel, -) { +pub async fn create_two_partition_database(db_name: impl Into, channel: Connection) { let mut write_client = influxdb_iox_client::write::Client::new(channel.clone()); let db_name = db_name.into(); diff --git a/tests/end_to_end_cases/storage_api.rs b/tests/end_to_end_cases/storage_api.rs index 196c66b914..047841648a 100644 --- a/tests/end_to_end_cases/storage_api.rs +++ b/tests/end_to_end_cases/storage_api.rs @@ -14,9 +14,9 @@ use generated_types::{ MeasurementTagValuesRequest, Node, Predicate, ReadFilterRequest, ReadGroupRequest, ReadWindowAggregateRequest, Tag, TagKeysRequest, TagValuesRequest, TimestampRange, }; +use influxdb_iox_client::connection::Connection; use std::str; use test_helpers::tag_key_bytes_to_strings; -use tonic::transport::Channel; #[tokio::test] pub async fn test() { @@ -41,7 +41,7 @@ pub async fn test() { } /// Validate that capabilities storage endpoint is hooked up -async fn capabilities_endpoint(storage_client: &mut StorageClient) { +async fn capabilities_endpoint(storage_client: &mut StorageClient) { let capabilities_response = storage_client.capabilities(Empty {}).await.unwrap(); let capabilities_response = capabilities_response.into_inner(); assert_eq!( @@ -52,7 +52,7 @@ async fn capabilities_endpoint(storage_client: &mut StorageClient) { ); } -async fn read_filter_endpoint(storage_client: &mut StorageClient, scenario: &Scenario) { +async fn read_filter_endpoint(storage_client: &mut StorageClient, scenario: &Scenario) { let read_source = scenario.read_source(); let range = scenario.timestamp_range(); @@ -100,7 +100,7 @@ async fn read_filter_endpoint(storage_client: &mut StorageClient, scena ); } -async fn tag_keys_endpoint(storage_client: &mut StorageClient, scenario: &Scenario) { +async fn tag_keys_endpoint(storage_client: &mut StorageClient, scenario: &Scenario) { let read_source = scenario.read_source(); let range = scenario.timestamp_range(); let predicate = make_tag_predicate("host", "server01"); @@ -124,7 +124,7 @@ async fn tag_keys_endpoint(storage_client: &mut StorageClient, scenario assert_eq!(keys, vec!["_m(0x00)", "host", "name", "region", "_f(0xff)"]); } -async fn tag_values_endpoint(storage_client: &mut StorageClient, scenario: &Scenario) { +async fn tag_values_endpoint(storage_client: &mut StorageClient, scenario: &Scenario) { let read_source = scenario.read_source(); let range = scenario.timestamp_range(); let predicate = make_tag_predicate("host", "server01"); @@ -154,7 +154,7 @@ async fn tag_values_endpoint(storage_client: &mut StorageClient, scenar } async fn measurement_names_endpoint( - storage_client: &mut StorageClient, + storage_client: &mut StorageClient, scenario: &Scenario, ) { let read_source = scenario.read_source(); @@ -186,7 +186,7 @@ async fn measurement_names_endpoint( } async fn measurement_tag_keys_endpoint( - storage_client: &mut StorageClient, + storage_client: &mut StorageClient, scenario: &Scenario, ) { let read_source = scenario.read_source(); @@ -222,7 +222,7 @@ async fn measurement_tag_keys_endpoint( } async fn measurement_tag_values_endpoint( - storage_client: &mut StorageClient, + storage_client: &mut StorageClient, scenario: &Scenario, ) { let read_source = scenario.read_source(); @@ -259,7 +259,7 @@ async fn measurement_tag_values_endpoint( } async fn measurement_fields_endpoint( - storage_client: &mut StorageClient, + storage_client: &mut StorageClient, scenario: &Scenario, ) { let read_source = scenario.read_source(); @@ -381,7 +381,7 @@ async fn load_read_group_data(client: &influxdb2_client::Client, scenario: &Scen // Standalone test for read_group with group keys and no aggregate // assumes that load_read_group_data has been previously run async fn test_read_group_none_agg( - storage_client: &mut StorageClient, + storage_client: &mut StorageClient, read_source: &std::option::Option, ) { // read_group(group_keys: region, agg: None) @@ -434,7 +434,7 @@ async fn test_read_group_none_agg( /// Test that predicates make it through async fn test_read_group_none_agg_with_predicate( - storage_client: &mut StorageClient, + storage_client: &mut StorageClient, read_source: &std::option::Option, ) { let read_group_request = ReadGroupRequest { @@ -480,7 +480,7 @@ async fn test_read_group_none_agg_with_predicate( // "aggregate" (not a "selector" style). assumes that // load_read_group_data has been previously run async fn test_read_group_sum_agg( - storage_client: &mut StorageClient, + storage_client: &mut StorageClient, read_source: &std::option::Option, ) { // read_group(group_keys: region, agg: Sum) @@ -535,7 +535,7 @@ async fn test_read_group_sum_agg( // "selector" function last. assumes that // load_read_group_data has been previously run async fn test_read_group_last_agg( - storage_client: &mut StorageClient, + storage_client: &mut StorageClient, read_source: &std::option::Option, ) { // read_group(group_keys: region, agg: Last) @@ -754,7 +754,7 @@ fn make_field_predicate(field_name: impl Into) -> Predicate { /// Make a read_group request and returns the results in a comparable format async fn do_read_filter_request( - storage_client: &mut StorageClient, + storage_client: &mut StorageClient, request: ReadFilterRequest, ) -> Vec { let request = tonic::Request::new(request); @@ -781,7 +781,7 @@ async fn do_read_filter_request( /// Make a read_group request and returns the results in a comparable format async fn do_read_group_request( - storage_client: &mut StorageClient, + storage_client: &mut StorageClient, request: ReadGroupRequest, ) -> Vec { let request = tonic::Request::new(request); diff --git a/trace/Cargo.toml b/trace/Cargo.toml new file mode 100644 index 0000000000..abff942b4f --- /dev/null +++ b/trace/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "trace" +version = "0.1.0" +authors = ["Raphael Taylor-Davies "] +edition = "2018" +description = "Distributed tracing support within IOx" + +[dependencies] + +chrono = { version = "0.4", features = ["serde"] } +futures = "0.3" +http = "0.2.0" +http-body = "0.4" +observability_deps = { path = "../observability_deps" } +parking_lot = "0.11" +pin-project = "1.0" +rand = "0.8.3" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0.44" +snafu = "0.6" +tower = "0.4" + +[dev-dependencies] diff --git a/trace/src/ctx.rs b/trace/src/ctx.rs new file mode 100644 index 0000000000..11646474cb --- /dev/null +++ b/trace/src/ctx.rs @@ -0,0 +1,277 @@ +use std::num::{NonZeroU128, NonZeroU64, ParseIntError}; +use std::str::FromStr; +use std::sync::Arc; + +use http::HeaderMap; +use rand::Rng; +use serde::{Deserialize, Serialize}; +use snafu::Snafu; + +use crate::{ + span::{Span, SpanStatus}, + TraceCollector, +}; + +const B3_FLAGS: &str = "X-B3-Flags"; +const B3_SAMPLED_HEADER: &str = "X-B3-Sampled"; +const B3_TRACE_ID_HEADER: &str = "X-B3-TraceId"; +const B3_PARENT_SPAN_ID_HEADER: &str = "X-B3-ParentSpanId"; +const B3_SPAN_ID_HEADER: &str = "X-B3-SpanId"; + +/// Error decoding SpanContext from transport representation +#[derive(Debug, Snafu)] +pub enum ContextError { + #[snafu(display("header '{}' not found", header))] + Missing { header: &'static str }, + + #[snafu(display("header '{}' has non-UTF8 content: {}", header, source))] + InvalidUtf8 { + header: &'static str, + source: http::header::ToStrError, + }, + + #[snafu(display("error decoding header '{}': {}", header, source))] + HeaderDecodeError { + header: &'static str, + source: DecodeError, + }, +} + +/// Error decoding a specific header value +#[derive(Debug, Snafu)] +pub enum DecodeError { + #[snafu(display("value decode error: {}", source))] + ValueDecodeError { source: ParseIntError }, + + #[snafu(display("value cannot be 0"))] + ZeroError, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub struct TraceId(pub NonZeroU128); + +impl<'a> FromStr for TraceId { + type Err = DecodeError; + + fn from_str(s: &str) -> Result { + Ok(Self( + NonZeroU128::new( + s.parse() + .map_err(|source| DecodeError::ValueDecodeError { source })?, + ) + .ok_or(DecodeError::ZeroError)?, + )) + } +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub struct SpanId(pub NonZeroU64); + +impl SpanId { + pub fn gen() -> Self { + // Should this be a UUID? + Self(rand::thread_rng().gen()) + } +} + +impl<'a> FromStr for SpanId { + type Err = DecodeError; + + fn from_str(s: &str) -> Result { + Ok(Self( + NonZeroU64::new( + s.parse() + .map_err(|source| DecodeError::ValueDecodeError { source })?, + ) + .ok_or(DecodeError::ZeroError)?, + )) + } +} + +/// The immutable context of a `Span` +/// +/// Importantly this contains all the information necessary to create a child `Span` +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SpanContext { + pub trace_id: TraceId, + + pub parent_span_id: Option, + + pub span_id: SpanId, + + #[serde(skip)] + pub collector: Option>, +} + +impl SpanContext { + /// Creates a new child of the Span described by this TraceContext + pub fn child<'a>(&self, name: &'a str) -> Span<'a> { + Span { + name, + ctx: Self { + trace_id: self.trace_id, + span_id: SpanId::gen(), + collector: self.collector.clone(), + parent_span_id: Some(self.span_id), + }, + start: None, + end: None, + status: SpanStatus::Unknown, + metadata: Default::default(), + events: Default::default(), + } + } + + /// Create a SpanContext for the trace described in the request's headers + /// + /// Follows the B3 multiple header encoding defined here + /// - + pub fn from_headers( + collector: &Arc, + headers: &HeaderMap, + ) -> Result, ContextError> { + let debug = decoded_header(headers, B3_FLAGS)? + .map(|header| header == "1") + .unwrap_or(false); + + let sampled = match debug { + // Debug implies an accept decision + true => true, + false => decoded_header(headers, B3_SAMPLED_HEADER)? + .map(|value| value == "1" || value == "true") + .unwrap_or(false), + }; + + if !sampled { + return Ok(None); + } + + Ok(Some(Self { + trace_id: required_header(headers, B3_TRACE_ID_HEADER)?, + parent_span_id: parsed_header(headers, B3_PARENT_SPAN_ID_HEADER)?, + span_id: required_header(headers, B3_SPAN_ID_HEADER)?, + collector: Some(Arc::clone(collector)), + })) + } +} + +/// Decodes a given header from the provided HeaderMap to a string +/// +/// - Returns Ok(None) if the header doesn't exist +/// - Returns Err if the header fails to decode to a string +/// - Returns Ok(Some(_)) otherwise +fn decoded_header<'a>( + headers: &'a HeaderMap, + header: &'static str, +) -> Result, ContextError> { + headers + .get(header) + .map(|value| { + value + .to_str() + .map_err(|source| ContextError::InvalidUtf8 { header, source }) + }) + .transpose() +} + +/// Decodes and parses a given header from the provided HeaderMap +/// +/// - Returns Ok(None) if the header doesn't exist +/// - Returns Err if the header fails to decode to a string or fails to parse +/// - Returns Ok(Some(_)) otherwise +fn parsed_header>( + headers: &HeaderMap, + header: &'static str, +) -> Result, ContextError> { + decoded_header(headers, header)? + .map(FromStr::from_str) + .transpose() + .map_err(|source| ContextError::HeaderDecodeError { source, header }) +} + +/// Decodes and parses a given required header from the provided HeaderMap +/// +/// - Returns Err if the header fails to decode to a string, fails to parse, or doesn't exist +/// - Returns Ok(str) otherwise +fn required_header>( + headers: &HeaderMap, + header: &'static str, +) -> Result { + parsed_header(headers, header)?.ok_or(ContextError::Missing { header }) +} + +#[cfg(test)] +mod tests { + use super::*; + use http::HeaderValue; + + #[test] + fn test_decode() { + let collector: Arc = Arc::new(crate::LogTraceCollector::new()); + let mut headers = HeaderMap::new(); + + // No headers should be None + assert!(SpanContext::from_headers(&collector, &headers) + .unwrap() + .is_none()); + + headers.insert(B3_SAMPLED_HEADER, HeaderValue::from_static("0")); + + // Not sampled + assert!(SpanContext::from_headers(&collector, &headers) + .unwrap() + .is_none()); + + headers.insert(B3_SAMPLED_HEADER, HeaderValue::from_static("1")); + + // Missing required headers + assert_eq!( + SpanContext::from_headers(&collector, &headers) + .unwrap_err() + .to_string(), + "header 'X-B3-TraceId' not found" + ); + + headers.insert(B3_TRACE_ID_HEADER, HeaderValue::from_static("99999999")); + headers.insert(B3_SPAN_ID_HEADER, HeaderValue::from_static("69559")); + + let span = SpanContext::from_headers(&collector, &headers) + .unwrap() + .unwrap(); + + assert_eq!(span.span_id.0.get(), 69559); + assert_eq!(span.trace_id.0.get(), 99999999); + assert!(span.parent_span_id.is_none()); + + headers.insert( + B3_PARENT_SPAN_ID_HEADER, + HeaderValue::from_static("4595945"), + ); + + let span = SpanContext::from_headers(&collector, &headers) + .unwrap() + .unwrap(); + + assert_eq!(span.span_id.0.get(), 69559); + assert_eq!(span.trace_id.0.get(), 99999999); + assert_eq!(span.parent_span_id.unwrap().0.get(), 4595945); + + headers.insert(B3_SPAN_ID_HEADER, HeaderValue::from_static("not a number")); + + assert_eq!( + SpanContext::from_headers(&collector, &headers) + .unwrap_err() + .to_string(), + "error decoding header 'X-B3-SpanId': value decode error: invalid digit found in string" + ); + + headers.insert(B3_SPAN_ID_HEADER, HeaderValue::from_static("0")); + + assert_eq!( + SpanContext::from_headers(&collector, &headers) + .unwrap_err() + .to_string(), + "error decoding header 'X-B3-SpanId': value cannot be 0" + ); + } +} diff --git a/trace/src/lib.rs b/trace/src/lib.rs new file mode 100644 index 0000000000..d46ec1a878 --- /dev/null +++ b/trace/src/lib.rs @@ -0,0 +1,76 @@ +#![deny(broken_intra_doc_links, rustdoc::bare_urls, rust_2018_idioms)] +#![warn( + missing_debug_implementations, + clippy::explicit_iter_loop, + clippy::use_self, + clippy::clone_on_ref_ptr, + clippy::future_not_send +)] + +use std::collections::VecDeque; + +use parking_lot::Mutex; + +use observability_deps::tracing::info; + +use crate::span::Span; + +pub mod ctx; +pub mod span; +pub mod tower; + +/// A TraceCollector is a sink for completed `Span` +pub trait TraceCollector: std::fmt::Debug + Send + Sync { + fn export(&self, span: &span::Span<'_>); +} + +/// A basic trace collector that prints to stdout +#[derive(Debug)] +pub struct LogTraceCollector {} + +impl LogTraceCollector { + pub fn new() -> Self { + Self {} + } +} + +impl Default for LogTraceCollector { + fn default() -> Self { + Self::new() + } +} + +impl TraceCollector for LogTraceCollector { + fn export(&self, span: &Span<'_>) { + info!("completed span {}", span.json()) + } +} + +/// A trace collector that maintains a ring buffer of spans +#[derive(Debug)] +pub struct RingBufferTraceCollector { + buffer: Mutex>, +} + +impl RingBufferTraceCollector { + pub fn new(capacity: usize) -> Self { + Self { + buffer: Mutex::new(VecDeque::with_capacity(capacity)), + } + } + + pub fn spans(&self) -> Vec { + self.buffer.lock().iter().cloned().collect() + } +} + +impl TraceCollector for RingBufferTraceCollector { + fn export(&self, span: &Span<'_>) { + let serialized = span.json(); + let mut buffer = self.buffer.lock(); + if buffer.len() == buffer.capacity() { + buffer.pop_front(); + } + buffer.push_back(serialized); + } +} diff --git a/trace/src/span.rs b/trace/src/span.rs new file mode 100644 index 0000000000..bde4fc4154 --- /dev/null +++ b/trace/src/span.rs @@ -0,0 +1,252 @@ +use std::collections::HashMap; +use std::ops::{Deref, DerefMut}; + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +use observability_deps::tracing::error; + +use crate::ctx::SpanContext; + +#[derive(Debug, Copy, Clone, Serialize, Deserialize)] +pub enum SpanStatus { + Unknown, + Ok, + Err, +} + +/// A `Span` is a representation of a an interval of time spent performing some operation +/// +/// A `Span` has a name, metadata, a start and end time and a unique ID. Additionally they +/// have relationships with other Spans that together comprise a Trace +/// +/// On Drop a `Span` is published to the registered collector +/// +#[derive(Debug, Serialize, Deserialize)] +pub struct Span<'a> { + pub name: &'a str, + + //#[serde(flatten)] - https://github.com/serde-rs/json/issues/505 + pub ctx: SpanContext, + + pub start: Option>, + + pub end: Option>, + + pub status: SpanStatus, + + #[serde(borrow)] + pub metadata: HashMap<&'a str, MetaValue<'a>>, + + #[serde(borrow)] + pub events: Vec>, +} + +impl<'a> Span<'a> { + pub fn event(&mut self, meta: impl Into>) { + let event = SpanEvent { + time: Utc::now(), + msg: meta.into(), + }; + self.events.push(event) + } + + pub fn error(&mut self, meta: impl Into>) { + self.event(meta); + self.status = SpanStatus::Err; + } + + pub fn json(&self) -> String { + match serde_json::to_string(self) { + Ok(serialized) => serialized, + Err(e) => { + error!(%e, "error serializing span to JSON"); + format!("\"Invalid span: {}\"", e) + } + } + } +} + +impl<'a> Drop for Span<'a> { + fn drop(&mut self) { + if let Some(collector) = &self.ctx.collector { + collector.export(self) + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SpanEvent<'a> { + pub time: DateTime, + + #[serde(borrow)] + pub msg: MetaValue<'a>, +} + +/// Values that can be stored in a Span's metadata and events +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +#[serde(untagged)] +pub enum MetaValue<'a> { + String(&'a str), + Float(f64), + Int(i64), +} + +impl<'a> From<&'a str> for MetaValue<'a> { + fn from(v: &'a str) -> Self { + Self::String(v) + } +} + +impl<'a> From for MetaValue<'a> { + fn from(v: f64) -> Self { + Self::Float(v) + } +} + +impl<'a> From for MetaValue<'a> { + fn from(v: i64) -> Self { + Self::Int(v) + } +} + +/// Updates the start and end times on the provided Span +#[derive(Debug)] +pub struct EnteredSpan<'a> { + span: Span<'a>, +} + +impl<'a> Deref for EnteredSpan<'a> { + type Target = Span<'a>; + + fn deref(&self) -> &Self::Target { + &self.span + } +} + +impl<'a> DerefMut for EnteredSpan<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.span + } +} + +impl<'a> EnteredSpan<'a> { + pub fn new(mut span: Span<'a>) -> Self { + span.start = Some(Utc::now()); + Self { span } + } +} + +impl<'a> Drop for EnteredSpan<'a> { + fn drop(&mut self) { + let now = Utc::now(); + + // SystemTime is not monotonic so must also check min + + self.span.start = Some(match self.span.start { + Some(a) => a.min(now), + None => now, + }); + + self.span.end = Some(match self.span.end { + Some(a) => a.max(now), + None => now, + }); + } +} + +#[cfg(test)] +mod tests { + use std::num::{NonZeroU128, NonZeroU64}; + use std::sync::Arc; + + use chrono::TimeZone; + + use crate::ctx::{SpanId, TraceId}; + use crate::{RingBufferTraceCollector, TraceCollector}; + + use super::*; + + fn make_span(collector: Arc) -> Span<'static> { + Span { + name: "foo", + ctx: SpanContext { + trace_id: TraceId(NonZeroU128::new(23948923).unwrap()), + parent_span_id: None, + span_id: SpanId(NonZeroU64::new(3498394).unwrap()), + collector: Some(collector), + }, + start: None, + end: None, + status: SpanStatus::Unknown, + metadata: Default::default(), + events: vec![], + } + } + + #[test] + fn test_span() { + let collector = Arc::new(RingBufferTraceCollector::new(5)); + + let mut span = make_span(Arc::::clone(&collector)); + + assert_eq!( + span.json(), + r#"{"name":"foo","ctx":{"trace_id":23948923,"parent_span_id":null,"span_id":3498394},"start":null,"end":null,"status":"Unknown","metadata":{},"events":[]}"# + ); + + span.events.push(SpanEvent { + time: Utc.timestamp_nanos(1000), + msg: MetaValue::String("this is a test event"), + }); + + assert_eq!( + span.json(), + r#"{"name":"foo","ctx":{"trace_id":23948923,"parent_span_id":null,"span_id":3498394},"start":null,"end":null,"status":"Unknown","metadata":{},"events":[{"time":"1970-01-01T00:00:00.000001Z","msg":"this is a test event"}]}"# + ); + + span.metadata.insert("foo", MetaValue::String("bar")); + span.start = Some(Utc.timestamp_nanos(100)); + + assert_eq!( + span.json(), + r#"{"name":"foo","ctx":{"trace_id":23948923,"parent_span_id":null,"span_id":3498394},"start":"1970-01-01T00:00:00.000000100Z","end":null,"status":"Unknown","metadata":{"foo":"bar"},"events":[{"time":"1970-01-01T00:00:00.000001Z","msg":"this is a test event"}]}"# + ); + + span.status = SpanStatus::Ok; + span.ctx.parent_span_id = Some(SpanId(NonZeroU64::new(23493).unwrap())); + + let expected = r#"{"name":"foo","ctx":{"trace_id":23948923,"parent_span_id":23493,"span_id":3498394},"start":"1970-01-01T00:00:00.000000100Z","end":null,"status":"Ok","metadata":{"foo":"bar"},"events":[{"time":"1970-01-01T00:00:00.000001Z","msg":"this is a test event"}]}"#; + assert_eq!(span.json(), expected); + + std::mem::drop(span); + + // Should publish span + let spans = collector.spans(); + assert_eq!(spans.len(), 1); + assert_eq!(spans[0], expected) + } + + #[test] + fn test_entered_span() { + let collector = Arc::new(RingBufferTraceCollector::new(5)); + + let span = make_span(Arc::::clone(&collector)); + + let entered = EnteredSpan::new(span); + + std::thread::sleep(std::time::Duration::from_millis(100)); + + std::mem::drop(entered); + + // Span should have been published on drop with set spans + let spans = collector.spans(); + assert_eq!(spans.len(), 1); + + let span: Span<'_> = serde_json::from_str(spans[0].as_str()).unwrap(); + + assert!(span.start.is_some()); + assert!(span.end.is_some()); + assert!(span.start.unwrap() < span.end.unwrap()); + } +} diff --git a/trace/src/tower.rs b/trace/src/tower.rs new file mode 100644 index 0000000000..313f173735 --- /dev/null +++ b/trace/src/tower.rs @@ -0,0 +1,187 @@ +//! +//! Tower plumbing for adding tracing instrumentation to an HTTP service stack +//! +//! This is loosely based on tower-http's trace crate but with the tokio-tracing +//! specific bits removed and less generics. +//! +//! For those not familiar with tower: +//! +//! - A Layer produces a Service +//! - A Service can then be called with a request which returns a Future +//! - This Future returns a response which contains a Body +//! - This Body contains the data payload (potentially streamed) +//! + +use crate::{ctx::SpanContext, span::EnteredSpan, TraceCollector}; +use futures::ready; +use http::{Request, Response}; +use http_body::SizeHint; +use observability_deps::tracing::error; +use pin_project::pin_project; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tower::{Layer, Service}; + +/// `TraceLayer` implements `tower::Layer` and can be used to decorate a +/// `tower::Service` to collect information about requests flowing through it +#[derive(Debug, Clone)] +pub struct TraceLayer { + collector: Arc, +} + +impl TraceLayer { + pub fn new(collector: Arc) -> Self { + Self { collector } + } +} + +impl Layer for TraceLayer { + type Service = TraceService; + + fn layer(&self, service: S) -> Self::Service { + TraceService { + service, + collector: Arc::clone(&self.collector), + } + } +} + +/// TraceService wraps an inner tower::Service and instruments its returned futures +#[derive(Debug, Clone)] +pub struct TraceService { + service: S, + collector: Arc, +} + +impl Service> for TraceService +where + S: Service, Response = Response>, + ResBody: http_body::Body, +{ + type Response = Response>; + type Error = S::Error; + type Future = TracedFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, mut request: Request) -> Self::Future { + let span = match SpanContext::from_headers(&self.collector, request.headers()) { + Ok(Some(ctx)) => { + let span = ctx.child("IOx"); + + // Add context to request for use by service handlers + request.extensions_mut().insert(span.ctx.clone()); + + // Create Span to use to instrument request + Some(EnteredSpan::new(span)) + } + Ok(None) => None, + Err(e) => { + error!(%e, "error extracting trace context from request"); + None + } + }; + + TracedFuture { + span, + inner: self.service.call(request), + } + } +} + +/// `TracedFuture` wraps a future returned by a `tower::Service` and +/// instruments the returned body if any +#[pin_project] +#[derive(Debug)] +pub struct TracedFuture { + span: Option>, + #[pin] + inner: F, +} + +impl Future for TracedFuture +where + F: Future, Error>>, + ResBody: http_body::Body, +{ + type Output = Result>, Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let result = ready!(self.as_mut().project().inner.poll(cx)); + if let Some(span) = self.as_mut().project().span.as_mut() { + match &result { + Ok(_) => span.event("request processed"), + Err(_) => span.error("error processing request"), + } + } + + match result { + Ok(response) => Poll::Ready(Ok(response.map(|body| TracedBody { + span: self.as_mut().project().span.take(), + inner: body, + }))), + Err(e) => Poll::Ready(Err(e)), + } + } +} + +/// `TracedBody` wraps a `http_body::Body` and instruments it +#[pin_project] +#[derive(Debug)] +pub struct TracedBody { + span: Option>, + #[pin] + inner: B, +} + +impl http_body::Body for TracedBody { + type Data = B::Data; + type Error = B::Error; + + fn poll_data( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let maybe_result = ready!(self.as_mut().project().inner.poll_data(cx)); + let result = match maybe_result { + Some(result) => result, + None => return Poll::Ready(None), + }; + + if let Some(span) = self.as_mut().project().span.as_mut() { + match &result { + Ok(_) => span.event("returned body data"), + Err(_) => span.error("eos getting body"), + } + } + Poll::Ready(Some(result)) + } + + fn poll_trailers( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + // TODO: Classify response status and set SpanStatus + + let result = ready!(self.as_mut().project().inner.poll_trailers(cx)); + if let Some(span) = self.as_mut().project().span.as_mut() { + match &result { + Ok(_) => span.event("returned trailers"), + Err(_) => span.error("eos getting trailers"), + } + } + Poll::Ready(result) + } + + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } + + fn size_hint(&self) -> SizeHint { + self.inner.size_hint() + } +}