diff --git a/Cargo.lock b/Cargo.lock
index 28b48b7b3e..49cb080e8b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3635,7 +3635,11 @@ version = "0.1.0"
 dependencies = [
  "assert_matches",
  "async-trait",
+ "bytes",
+ "data_types",
  "dml",
+ "flate2",
+ "futures",
  "generated_types",
  "hyper",
  "metric",
@@ -3643,7 +3647,11 @@ dependencies = [
  "observability_deps",
  "parking_lot",
  "paste",
+ "serde",
+ "serde_urlencoded",
  "thiserror",
+ "time 0.1.0",
+ "tokio",
  "tonic",
  "trace",
  "workspace-hack",
diff --git a/influxdb_iox/src/commands/run/router2.rs b/influxdb_iox/src/commands/run/router2.rs
index 43a4a6c3b6..f7c61a54db 100644
--- a/influxdb_iox/src/commands/run/router2.rs
+++ b/influxdb_iox/src/commands/run/router2.rs
@@ -13,7 +13,10 @@ use crate::{
     },
 };
 use observability_deps::tracing::*;
-use router2::server::RouterServer;
+use router2::{
+    dml_handler::nop::NopDmlHandler,
+    server::{http::HttpDelegate, RouterServer},
+};
 use thiserror::Error;

 #[derive(Debug, Error)]
@@ -53,7 +56,11 @@ pub struct Config {
 pub async fn command(config: Config) -> Result<()> {
     let common_state = CommonServerState::from_config(config.run_config.clone())?;

-    let router_server = RouterServer::default();
+    let http = HttpDelegate::new(
+        config.run_config.max_http_request_size,
+        NopDmlHandler::default(),
+    );
+    let router_server = RouterServer::new(http, Default::default());
     let server_type = Arc::new(RouterServerType::new(router_server, &common_state));

     info!("starting router2");
diff --git a/influxdb_iox/src/influxdb_ioxd/server_type/router2.rs b/influxdb_iox/src/influxdb_ioxd/server_type/router2.rs
index 7c37c64d82..a8562c1080 100644
--- a/influxdb_iox/src/influxdb_ioxd/server_type/router2.rs
+++ b/influxdb_iox/src/influxdb_ioxd/server_type/router2.rs
@@ -1,9 +1,12 @@
-use std::{fmt::Display, sync::Arc};
+use std::{
+    fmt::{Debug, Display},
+    sync::Arc,
+};

 use async_trait::async_trait;
 use hyper::{Body, Request, Response};
 use metric::Registry;
-use router2::server::RouterServer;
+use router2::{dml_handler::DmlHandler, server::RouterServer};
 use tokio_util::sync::CancellationToken;
 use trace::TraceCollector;

@@ -14,14 +17,14 @@ use crate::influxdb_ioxd::{
 };

 #[derive(Debug)]
-pub struct RouterServerType {
-    server: RouterServer,
+pub struct RouterServerType<D> {
+    server: RouterServer<D>,
     shutdown: CancellationToken,
     trace_collector: Option<Arc<dyn TraceCollector>>,
 }

-impl RouterServerType {
-    pub fn new(server: RouterServer, common_state: &CommonServerState) -> Self {
+impl<D> RouterServerType<D> {
+    pub fn new(server: RouterServer<D>, common_state: &CommonServerState) -> Self {
         Self {
             server,
             shutdown: CancellationToken::new(),
@@ -31,7 +34,10 @@ impl RouterServerType {
 }

 #[async_trait]
-impl ServerType for RouterServerType {
+impl<D> ServerType for RouterServerType<D>
+where
+    D: DmlHandler + 'static,
+{
     type RouteError = IoxHttpErrorAdaptor;

     /// Return the [`metric::Registry`] used by the router.
@@ -51,7 +57,11 @@ impl ServerType for RouterServerType {
         &self,
         req: Request<Body>,
     ) -> Result<Response<Body>, Self::RouteError> {
-        self.server.http().route(req).map_err(IoxHttpErrorAdaptor)
+        self.server
+            .http()
+            .route(req)
+            .await
+            .map_err(IoxHttpErrorAdaptor)
     }

     /// Registers the services exposed by the router [`GrpcDelegate`] delegate.
@@ -82,7 +92,7 @@ pub struct IoxHttpErrorAdaptor(router2::server::http::Error);

 impl Display for IoxHttpErrorAdaptor {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        self.0.fmt(f)
+        Display::fmt(&self.0, f)
     }
 }

diff --git a/router2/Cargo.toml b/router2/Cargo.toml
index cc758a8c0b..b028e9acde 100644
--- a/router2/Cargo.toml
+++ b/router2/Cargo.toml
@@ -7,14 +7,22 @@ edition = "2021"

 [dependencies]
 async-trait = "0.1"
+bytes = "1.1"
+data_types = { path = "../data_types" }
 dml = { path = "../dml" }
+flate2 = "1.0"
+futures = "0.3.19"
 generated_types = { path = "../generated_types" }
 hyper = "0.14"
 metric = { path = "../metric" }
 mutable_batch_lp = { path = "../mutable_batch_lp" }
 observability_deps = { path = "../observability_deps" }
 parking_lot = "0.11"
+serde = "1.0"
+serde_urlencoded = "0.7"
 thiserror = "1.0"
+time = { path = "../time" }
+tokio = { version = "1", features = ["rt-multi-thread", "macros"] }
 tonic = "0.6"
 trace = { path = "../trace/" }
 workspace-hack = { path = "../workspace-hack"}
diff --git a/router2/src/server.rs b/router2/src/server.rs
index d6b4e16a8c..376b8c3d06 100644
--- a/router2/src/server.rs
+++ b/router2/src/server.rs
@@ -2,6 +2,8 @@

 use std::sync::Arc;

+use crate::dml_handler::DmlHandler;
+
 use self::{grpc::GrpcDelegate, http::HttpDelegate};

 pub mod grpc;
@@ -10,22 +12,22 @@ pub mod http;
 /// The [`RouterServer`] manages the lifecycle and contains all state for a
 /// `router2` server instance.
 #[derive(Debug, Default)]
-pub struct RouterServer {
+pub struct RouterServer<D> {
     metrics: Arc<metric::Registry>,
-    http: HttpDelegate,
+    http: HttpDelegate<D>,
     grpc: GrpcDelegate,
 }

-impl RouterServer {
-    /// Get a reference to the router http delegate.
-    pub fn http(&self) -> &HttpDelegate {
-        &self.http
-    }
-
-    /// Get a reference to the router grpc delegate.
-    pub fn grpc(&self) -> &GrpcDelegate {
-        &self.grpc
+impl<D> RouterServer<D> {
+    /// Initialise a new [`RouterServer`] using the provided HTTP and gRPC
+    /// handlers.
+    pub fn new(http: HttpDelegate<D>, grpc: GrpcDelegate) -> Self {
+        Self {
+            metrics: Default::default(),
+            http,
+            grpc,
+        }
     }

     /// Return the [`metric::Registry`] used by the router.
@@ -33,3 +35,18 @@ impl RouterServer {
         Arc::clone(&self.metrics)
     }
 }
+
+impl<D> RouterServer<D>
+where
+    D: DmlHandler,
+{
+    /// Get a reference to the router http delegate.
+    pub fn http(&self) -> &HttpDelegate<D> {
+        &self.http
+    }
+
+    /// Get a reference to the router grpc delegate.
+    pub fn grpc(&self) -> &GrpcDelegate {
+        &self.grpc
+    }
+}
diff --git a/router2/src/server/http.rs b/router2/src/server/http.rs
index 9432f121ed..11e6b078e2 100644
--- a/router2/src/server/http.rs
+++ b/router2/src/server/http.rs
@@ -1,14 +1,62 @@
 //! HTTP service implementations for `router2`.

-use hyper::{Body, Request, Response, StatusCode};
+use std::str::Utf8Error;
+
+use bytes::{Bytes, BytesMut};
+use data_types::names::{org_and_bucket_to_database, OrgBucketMappingError};
+use dml::{DmlMeta, DmlWrite};
+use futures::StreamExt;
+use hyper::{header::CONTENT_ENCODING, Body, Method, Request, Response, StatusCode};
+use observability_deps::tracing::*;
+use serde::Deserialize;
 use thiserror::Error;
+use time::{SystemProvider, TimeProvider};
+use trace::ctx::SpanContext;
+
+use crate::dml_handler::{DmlError, DmlHandler};

 /// Errors returned by the `router2` HTTP request handler.
 #[derive(Debug, Error)]
 pub enum Error {
     /// The requested path has no registered handler.
     #[error("not found")]
-    NotFound,
+    NoHandler,
+
+    /// An error with the org/bucket in the request.
+    #[error(transparent)]
+    InvalidOrgBucket(#[from] OrgBucketError),
+
+    /// The request body content is not valid utf8.
+    #[error("body content is not valid utf8: {0}")]
+    NonUtf8Body(Utf8Error),
+
+    /// The `Content-Encoding` header is invalid and cannot be read.
+    #[error("invalid content-encoding header: {0}")]
+    NonUtf8ContentHeader(hyper::header::ToStrError),
+
+    /// The specified `Content-Encoding` is not acceptable.
+    #[error("unacceptable content-encoding: {0}")]
+    InvalidContentEncoding(String),
+
+    /// The client disconnected.
+    #[error("client disconnected")]
+    ClientHangup(hyper::Error),
+
+    /// The client sent a request body that exceeds the configured maximum.
+    #[error("max request size ({0} bytes) exceeded")]
+    RequestSizeExceeded(usize),
+
+    /// Decoding a gzip-compressed stream of data failed.
+    #[error("error decoding gzip stream: {0}")]
+    InvalidGzip(std::io::Error),
+
+    /// Failure to decode the provided line protocol.
+    #[error("failed to parse line protocol: {0}")]
+    ParseLineProtocol(mutable_batch_lp::Error),
+
+    /// An error returned from the [`DmlHandler`].
+    #[error("dml handler error: {0}")]
+    DmlHandler(#[from] DmlError),
 }

 impl Error {
@@ -16,11 +64,65 @@ impl Error {
     /// the end user.
     pub fn as_status_code(&self) -> StatusCode {
         match self {
-            Error::NotFound => StatusCode::NOT_FOUND,
+            Error::NoHandler | Error::DmlHandler(DmlError::DatabaseNotFound(_)) => {
+                StatusCode::NOT_FOUND
+            }
+            Error::InvalidOrgBucket(_) => StatusCode::BAD_REQUEST,
+            Error::ClientHangup(_) => StatusCode::BAD_REQUEST,
+            Error::InvalidGzip(_) => StatusCode::BAD_REQUEST,
+            Error::NonUtf8ContentHeader(_) => StatusCode::BAD_REQUEST,
+            Error::NonUtf8Body(_) => StatusCode::BAD_REQUEST,
+            Error::InvalidContentEncoding(_) => {
+                // https://www.rfc-editor.org/rfc/rfc7231#section-6.5.13
+                StatusCode::UNSUPPORTED_MEDIA_TYPE
+            }
+            Error::RequestSizeExceeded(_) => StatusCode::PAYLOAD_TOO_LARGE,
+            Error::ParseLineProtocol(_) => StatusCode::BAD_REQUEST,
+            Error::DmlHandler(DmlError::Internal(_)) => StatusCode::INTERNAL_SERVER_ERROR,
         }
     }
 }

+/// Errors returned when decoding the organisation / bucket information from a
+/// HTTP request and deriving the database name from it.
+#[derive(Debug, Error)]
+pub enum OrgBucketError {
+    /// The request contains no org/bucket destination information.
+    #[error("no org/bucket destination provided")]
+    NotSpecified,
+
+    /// The request contains invalid parameters.
+    #[error("failed to deserialise org/bucket in request: {0}")]
+    DecodeFail(#[from] serde::de::value::Error),
+
+    /// The provided org/bucket could not be converted into a database name.
+    #[error(transparent)]
+    MappingFail(#[from] OrgBucketMappingError),
+}
+
+#[derive(Debug, Deserialize)]
+/// Org & bucket identifiers for a DML operation.
+pub struct OrgBucketInfo {
+    org: String,
+    bucket: String,
+}
+
+impl<T> TryFrom<&Request<T>> for OrgBucketInfo {
+    type Error = OrgBucketError;
+
+    fn try_from(req: &Request<T>) -> Result<Self, Self::Error> {
+        let query = req.uri().query().ok_or(OrgBucketError::NotSpecified)?;
+        let got: OrgBucketInfo = serde_urlencoded::from_str(query)?;
+
+        // An empty org or bucket is not acceptable.
+        if got.org.is_empty() || got.bucket.is_empty() {
+            return Err(OrgBucketError::NotSpecified);
+        }
+
+        Ok(got)
+    }
+}
+
 /// This type is responsible for servicing requests to the `router2` HTTP
 /// endpoint.
 ///
@@ -28,12 +130,377 @@ impl Error {
 /// server runner framework takes care of implementing the heath endpoint,
 /// metrics, pprof, etc.
 #[derive(Debug, Default)]
-pub struct HttpDelegate;
+pub struct HttpDelegate<D, T = SystemProvider> {
+    max_request_bytes: usize,
+    time_provider: T,
+    dml_handler: D,
+}

-impl HttpDelegate {
-    /// Routes `req` to the appropriate handler, if any, returning the handler
-    /// response.
-    pub fn route(&self, _req: Request<Body>) -> Result<Response<Body>, Error> {
-        unimplemented!()
+impl<D> HttpDelegate<D, SystemProvider> {
+    /// Initialise a new [`HttpDelegate`] passing valid requests to the
+    /// specified `dml_handler`.
+    ///
+    /// HTTP request bodies are limited to `max_request_bytes` in size,
+    /// returning an error if exceeded.
+    pub fn new(max_request_bytes: usize, dml_handler: D) -> Self {
+        Self {
+            max_request_bytes,
+            time_provider: SystemProvider::default(),
+            dml_handler,
+        }
     }
 }
+
+impl<D, T> HttpDelegate<D, T>
+where
+    D: DmlHandler,
+    T: TimeProvider,
+{
+    /// Routes `req` to the appropriate handler, if any, returning the handler
+    /// response.
+    pub async fn route(&self, req: Request<Body>) -> Result<Response<Body>, Error> {
+        match (req.method(), req.uri().path()) {
+            (&Method::POST, "/api/v2/write") => self.write_handler(req).await,
+            (&Method::POST, "/api/v2/delete") => self.delete_handler(req).await,
+            _ => return Err(Error::NoHandler),
+        }
+        .map(|_| response_no_content())
+    }
+
+    async fn write_handler(&self, req: Request<Body>) -> Result<(), Error> {
+        let span_ctx: Option<SpanContext> = req.extensions().get().cloned();
+
+        let account = OrgBucketInfo::try_from(&req)?;
+        let db_name = org_and_bucket_to_database(&account.org, &account.bucket)
+            .map_err(OrgBucketError::MappingFail)?;
+
+        trace!(org=%account.org, bucket=%account.bucket, db_name=%db_name, "processing write request");
+
+        // Read the HTTP body and convert it to a str.
+        let body = self.read_body(req).await?;
+        let body = std::str::from_utf8(&body).map_err(Error::NonUtf8Body)?;
+
+        // The time, in nanoseconds since the epoch, to assign to any points that don't
+        // contain a timestamp
+        let default_time = self.time_provider.now().timestamp_nanos();
+
+        let (tables, stats) = match mutable_batch_lp::lines_to_batches_stats(body, default_time) {
+            Ok(v) => v,
+            Err(mutable_batch_lp::Error::EmptyPayload) => {
+                debug!("nothing to write");
+                return Ok(());
+            }
+            Err(e) => return Err(Error::ParseLineProtocol(e)),
+        };
+
+        let op = DmlWrite::new(tables, DmlMeta::unsequenced(span_ctx));
+        self.dml_handler
+            .dispatch(db_name, op, stats, body.len())
+            .await?;
+
+        Ok(())
+    }
+
+    async fn delete_handler(&self, _req: Request<Body>) -> Result<(), Error> {
+        unimplemented!()
+    }
+
+    /// Parse the request's body into raw bytes, applying the configured size
+    /// limits and decoding any content encoding.
+    async fn read_body(&self, req: hyper::Request<Body>) -> Result<Bytes, Error> {
+        let encoding = req
+            .headers()
+            .get(&CONTENT_ENCODING)
+            .map(|v| v.to_str().map_err(Error::NonUtf8ContentHeader))
+            .transpose()?;
+        let ungzip = match encoding {
+            None => false,
+            Some("gzip") => true,
+            Some(v) => return Err(Error::InvalidContentEncoding(v.to_string())),
+        };
+
+        let mut payload = req.into_body();
+
+        let mut body = BytesMut::new();
+        while let Some(chunk) = payload.next().await {
+            let chunk = chunk.map_err(Error::ClientHangup)?;
+            // limit max size of in-memory payload
+            if (body.len() + chunk.len()) > self.max_request_bytes {
+                return Err(Error::RequestSizeExceeded(self.max_request_bytes));
+            }
+            body.extend_from_slice(&chunk);
+        }
+        let body = body.freeze();
+
+        // If the body is not compressed, return early.
+        if !ungzip {
+            return Ok(body);
+        }
+
+        // Unzip the gzip-encoded content
+        use std::io::Read;
+        let decoder = flate2::read::GzDecoder::new(&body[..]);
+
+        // Read at most max_request_bytes bytes to prevent a decompression bomb
+        // based DoS.
+        //
+        // In order to detect if the entire stream has been read, or truncated,
+        // read an extra byte beyond the limit and check the resulting data
+        // length - see the max_request_size_truncation test.
+        let mut decoder = decoder.take(self.max_request_bytes as u64 + 1);
+        let mut decoded_data = Vec::new();
+        decoder
+            .read_to_end(&mut decoded_data)
+            .map_err(Error::InvalidGzip)?;
+
+        // If the length is max_size+1, the body is at least max_size+1 bytes in
+        // length, and possibly longer, but truncated.
+        if decoded_data.len() > self.max_request_bytes {
+            return Err(Error::RequestSizeExceeded(self.max_request_bytes));
+        }
+
+        Ok(decoded_data.into())
+    }
+}
+
+fn response_no_content() -> Response<Body> {
+    Response::builder()
+        .status(StatusCode::NO_CONTENT)
+        .body(Body::empty())
+        .unwrap()
+}
+
+#[cfg(test)]
+mod tests {
+    use std::{io::Write, iter, sync::Arc};
+
+    use assert_matches::assert_matches;
+    use dml::DmlOperation;
+    use flate2::{write::GzEncoder, Compression};
+    use hyper::header::HeaderValue;
+
+    use crate::dml_handler::mock::{MockDmlHandler, MockDmlHandlerCall};
+
+    use super::*;
+
+    const MAX_BYTES: usize = 1024;
+
+    // Generate two write handler tests - one for a plain request and one with a
+    // gzip-encoded body (and appropriate header), asserting the handler return
+    // value & write op.
+    macro_rules! test_write_handler {
+        (
+            $name:ident,
+            query_string = $query_string:expr,   // Request URI query string
+            body = $body:expr,                   // Request body content
+            dml_handler = $dml_handler:expr,     // DML handler response (if called)
+            want_write_db = $want_write_db:expr, // Expected write DB name (empty if no write)
+            want_return = $($want_return:tt )+   // Expected HTTP response
+        ) => {
+            // Generate the two test cases by feeding the same inputs, but varying
+            // the encoding.
+            test_write_handler!(
+                $name,
+                encoding=plain,
+                query_string = $query_string,
+                body = $body,
+                dml_handler = $dml_handler,
+                want_write_db = $want_write_db,
+                want_return = $($want_return)+
+            );
+            test_write_handler!(
+                $name,
+                encoding=gzip,
+                query_string = $query_string,
+                body = $body,
+                dml_handler = $dml_handler,
+                want_write_db = $want_write_db,
+                want_return = $($want_return)+
+            );
+        };
+
+        // Actual test body generator.
+        (
+            $name:ident,
+            encoding = $encoding:tt,
+            query_string = $query_string:expr,
+            body = $body:expr,
+            dml_handler = $dml_handler:expr,
+            want_write_db = $want_write_db:expr,
+            want_return = $($want_return:tt )+
+        ) => {
+            paste::paste! {
+                #[tokio::test]
+                async fn [<test_write_handler_ $name _ $encoding>]() {
+                    let body = $body;
+                    let want_body_len = body.len();
+
+                    // Optionally generate a fragment of code to encode the body
+                    let body = test_write_handler!(encoding=$encoding, body);
+
+                    #[allow(unused_mut)]
+                    let mut request = Request::builder()
+                        .uri(format!("https://bananas.example/api/v2/write{}", $query_string))
+                        .method("POST")
+                        .body(Body::from(body))
+                        .unwrap();
+
+                    // Optionally modify request to account for the desired
+                    // encoding
+                    test_write_handler!(encoding_header=$encoding, request);
+
+                    let dml_handler = Arc::new(MockDmlHandler::default().with_dispatch_return($dml_handler));
+                    let delegate = HttpDelegate::new(MAX_BYTES, Arc::clone(&dml_handler));
+
+                    let got = delegate.route(request).await;
+                    assert_matches!(got, $($want_return)+);
+
+                    let calls = dml_handler.calls();
+                    if !$want_write_db.is_empty() {
+                        assert_eq!(calls.len(), 1);
+
+                        // Validate the write op
+                        let op = assert_matches!(&calls[0], MockDmlHandlerCall::Dispatch{ db_name, op, body_len, .. } => {
+                            assert_eq!(db_name, $want_write_db);
+                            assert_eq!(*body_len, want_body_len);
+                            op
+                        });
+                        assert_matches!(op, DmlOperation::Write(_));
+                    } else {
+                        assert!(calls.is_empty());
+                    }
+                }
+            }
+        };
+        (encoding=plain, $body:ident) => {
+            $body
+        };
+        (encoding=gzip, $body:ident) => {{
+            // Apply gzip compression to the body
+            let mut e = GzEncoder::new(Vec::new(), Compression::default());
+            e.write_all(&$body).unwrap();
+            e.finish().expect("failed to compress test body")
+        }};
+        (encoding_header=plain, $request:ident) => {};
+        (encoding_header=gzip, $request:ident) => {{
+            // Set the gzip content encoding
+            $request
+                .headers_mut()
+                .insert(CONTENT_ENCODING, HeaderValue::from_static("gzip"));
+        }};
+    }
+
+    test_write_handler!(
+        ok,
+        query_string = "?org=bananas&bucket=test",
+        body = "platanos,tag1=A,tag2=B val=42i 123456".as_bytes(),
+        dml_handler = [Ok(())],
+        want_write_db = "bananas_test",
+        want_return = Ok(r) => {
+            assert_eq!(r.status(), StatusCode::NO_CONTENT);
+        }
+    );
+
+    test_write_handler!(
+        no_query_params,
+        query_string = "",
+        body = "platanos,tag1=A,tag2=B val=42i 123456".as_bytes(),
+        dml_handler = [Ok(())],
+        want_write_db = "", // None
+        want_return = Err(Error::InvalidOrgBucket(OrgBucketError::NotSpecified))
+    );
+
+    test_write_handler!(
+        no_org_bucket,
+        query_string = "?",
+        body = "platanos,tag1=A,tag2=B val=42i 123456".as_bytes(),
+        dml_handler = [Ok(())],
+        want_write_db = "", // None
+        want_return = Err(Error::InvalidOrgBucket(OrgBucketError::DecodeFail(_)))
+    );
+
+    test_write_handler!(
+        empty_org_bucket,
+        query_string = "?org=&bucket=",
+        body = "platanos,tag1=A,tag2=B val=42i 123456".as_bytes(),
+        dml_handler = [Ok(())],
+        want_write_db = "", // None
+        want_return = Err(Error::InvalidOrgBucket(OrgBucketError::NotSpecified))
+    );
+
+    test_write_handler!(
+        invalid_org_bucket,
+        query_string = format!(
+            "?org=test&bucket={}",
+            iter::repeat("A").take(1000).collect::<String>()
+        ),
+        body = "platanos,tag1=A,tag2=B val=42i 123456".as_bytes(),
+        dml_handler = [Ok(())],
+        want_write_db = "", // None
+        want_return = Err(Error::InvalidOrgBucket(OrgBucketError::MappingFail(_)))
+    );
+
+    test_write_handler!(
+        invalid_line_protocol,
+        query_string = "?org=bananas&bucket=test",
+        body = "not line protocol".as_bytes(),
+        dml_handler = [Ok(())],
+        want_write_db = "", // None
+        want_return = Err(Error::ParseLineProtocol(_))
+    );
+
+    test_write_handler!(
+        non_utf8_body,
+        query_string = "?org=bananas&bucket=test",
+        body = vec![0xc3, 0x28],
+        dml_handler = [Ok(())],
+        want_write_db = "", // None
+        want_return = Err(Error::NonUtf8Body(_))
+    );
+
+    test_write_handler!(
+        max_request_size_truncation,
+        query_string = "?org=bananas&bucket=test",
+        body = {
+            // Generate an LP string in the form of:
+            //
+            //  bananas,A=AAAAAAAAAA(repeated)... B=42
+            //                                   ^
+            //                                   |
+            //                          MAX_BYTES boundary
+            //
+            // So that reading MAX_BYTES number of bytes produces the string:
+            //
+            //  bananas,A=AAAAAAAAAA(repeated)...
+            //
+            // Effectively trimming off the " B=42" suffix.
+            let body = "bananas,A=";
+            iter::once(body)
+                .chain(iter::repeat("A").take(MAX_BYTES - body.len()))
+                .chain(iter::once(" B=42\n"))
+                .flat_map(|s| s.bytes())
+                .collect::<Vec<u8>>()
+        },
+        dml_handler = [Ok(())],
+        want_write_db = "", // None
+        want_return = Err(Error::RequestSizeExceeded(_))
+    );
+
+    test_write_handler!(
+        db_not_found,
+        query_string = "?org=bananas&bucket=test",
+        body = "platanos,tag1=A,tag2=B val=42i 123456".as_bytes(),
+        dml_handler = [Err(DmlError::DatabaseNotFound("bananas_test".to_string()))],
+        want_write_db = "bananas_test",
+        want_return = Err(Error::DmlHandler(DmlError::DatabaseNotFound(_)))
+    );
+
+    test_write_handler!(
+        dml_handler_error,
+        query_string = "?org=bananas&bucket=test",
+        body = "platanos,tag1=A,tag2=B val=42i 123456".as_bytes(),
+        dml_handler = [Err(DmlError::Internal("💣".into()))],
+        want_write_db = "bananas_test",
+        want_return = Err(Error::DmlHandler(DmlError::Internal(_)))
+    );
+}