From c4d651fbd16258a1aab50335b6ebdba5973d5b8e Mon Sep 17 00:00:00 2001 From: Trevor Hilton Date: Fri, 8 Mar 2024 14:18:17 -0500 Subject: [PATCH] feat: implement `Authorizer` to authorize all HTTP requests (#24738) * feat: add `Authorizer` impls to authz REST and gRPC This adds two new Authorizer implementations to Edge: Default and AllOrNothing, which will provide the two auth options for Edge. Both gRPC requests and HTTP REST request will be authorized by the same Authorizer implementation. The SHA512 digest action was moved into the `Authorizer` impl. * feat: add `ServerBuilder` to construct `Server A builder was added to the Server in this commit, as part of an attempt to get the server creation to be more modular. * refactor: use test server fixture in auth e2e test Refactored the `auth` integration test in `influxdb3` to use the `TestServer` fixture; part of this involved extending the fixture to be configurable, so that the `TestServer` can be spun up with an auth token. * test: add test for authorized gRPC A new end-to-end test, auth_grpc, was added to check that authorization is working with the influxdb3 Flight service. --- Cargo.lock | 1 + influxdb3/Cargo.toml | 1 + influxdb3/src/commands/serve.rs | 38 ++++--- influxdb3/tests/server/auth.rs | 192 ++++++++++++++++++++------------ influxdb3/tests/server/main.rs | 52 ++++++++- influxdb3_server/src/auth.rs | 58 ++++++++++ influxdb3_server/src/builder.rs | 112 +++++++++++++++++++ influxdb3_server/src/http.rs | 91 +++++++++------ influxdb3_server/src/lib.rs | 124 +++++++-------------- 9 files changed, 465 insertions(+), 204 deletions(-) create mode 100644 influxdb3_server/src/auth.rs create mode 100644 influxdb3_server/src/builder.rs diff --git a/Cargo.lock b/Cargo.lock index 87de480c3f..b59bff3428 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2469,6 +2469,7 @@ dependencies = [ "arrow-flight", "arrow_util", "assert_cmd", + "authz", "backtrace", "base64 0.22.0", "clap", diff --git a/influxdb3/Cargo.toml b/influxdb3/Cargo.toml index b8bead478d..e62302c45f 100644 --- a/influxdb3/Cargo.toml +++ b/influxdb3/Cargo.toml @@ -7,6 +7,7 @@ license.workspace = true [dependencies] # Core Crates +authz.workspace = true clap_blocks.workspace = true iox_query.workspace = true iox_time.workspace = true diff --git a/influxdb3/src/commands/serve.rs b/influxdb3/src/commands/serve.rs index 03cc78fe82..1cdc4c0e86 100644 --- a/influxdb3/src/commands/serve.rs +++ b/influxdb3/src/commands/serve.rs @@ -7,7 +7,10 @@ use clap_blocks::{ object_store::{make_object_store, ObjectStoreConfig}, socket_addr::SocketAddr, }; -use influxdb3_server::{query_executor::QueryExecutorImpl, serve, CommonServerState, Server}; +use influxdb3_server::{ + auth::AllOrNothingAuthorizer, builder::ServerBuilder, query_executor::QueryExecutorImpl, serve, + CommonServerState, +}; use influxdb3_write::persister::PersisterImpl; use influxdb3_write::wal::WalImpl; use influxdb3_write::write_buffer::WriteBufferImpl; @@ -51,6 +54,9 @@ pub enum Error { #[error("Write buffer error: {0}")] WriteBuffer(#[from] influxdb3_write::write_buffer::Error), + + #[error("invalid token: {0}")] + InvalidToken(#[from] hex::FromHexError), } pub type Result = std::result::Result; @@ -238,32 +244,38 @@ pub async fn command(config: Config) -> Result<()> { trace_exporter, trace_header_parser, *config.http_bind_address, - config.bearer_token, )?; - let persister = PersisterImpl::new(Arc::clone(&object_store)); + let persister = Arc::new(PersisterImpl::new(Arc::clone(&object_store))); let wal: Option> = config 
.wal_directory .map(|dir| WalImpl::new(dir).map(Arc::new)) .transpose()?; // TODO: the next segment ID should be loaded from the persister - let write_buffer = Arc::new(WriteBufferImpl::new(Arc::new(persister), wal).await?); - let query_executor = QueryExecutorImpl::new( + let write_buffer = Arc::new(WriteBufferImpl::new(Arc::clone(&persister), wal).await?); + let query_executor = Arc::new(QueryExecutorImpl::new( write_buffer.catalog(), Arc::clone(&write_buffer), Arc::clone(&exec), Arc::clone(&metrics), Arc::new(config.datafusion_config), 10, - ); + )); let persister = Arc::new(PersisterImpl::new(Arc::clone(&object_store))); - let server = Server::new( - common_state, - persister, - Arc::clone(&write_buffer), - Arc::new(query_executor), - config.max_http_request_size, - ); + + let builder = ServerBuilder::new(common_state) + .max_request_size(config.max_http_request_size) + .write_buffer(write_buffer) + .query_executor(query_executor) + .persister(persister); + + let server = if let Some(token) = config.bearer_token.map(hex::decode).transpose()? { + builder + .authorizer(Arc::new(AllOrNothingAuthorizer::new(token))) + .build() + } else { + builder.build() + }; serve(server, frontend_shutdown).await?; Ok(()) diff --git a/influxdb3/tests/server/auth.rs b/influxdb3/tests/server/auth.rs index fec2f393fc..e04eb1c4cc 100644 --- a/influxdb3/tests/server/auth.rs +++ b/influxdb3/tests/server/auth.rs @@ -1,78 +1,31 @@ -use parking_lot::Mutex; +use arrow_flight::error::FlightError; +use arrow_util::assert_batches_sorted_eq; +use influxdb3_client::Precision; use reqwest::StatusCode; -use std::env; -use std::mem; -use std::panic; -use std::process::Child; -use std::process::Command; -use std::process::Stdio; -struct DropCommand { - cmd: Option, -} - -impl DropCommand { - const fn new(cmd: Child) -> Self { - Self { cmd: Some(cmd) } - } - - fn kill(&mut self) { - let mut cmd = self.cmd.take().unwrap(); - cmd.kill().unwrap(); - mem::drop(cmd); - } -} - -static COMMAND: Mutex> = parking_lot::const_mutex(None); +use crate::{collect_stream, TestServer}; #[tokio::test] async fn auth() { const HASHED_TOKEN: &str = "5315f0c4714537843face80cca8c18e27ce88e31e9be7a5232dc4dc8444f27c0227a9bd64831d3ab58f652bd0262dd8558dd08870ac9e5c650972ce9e4259439"; const TOKEN: &str = "apiv3_mp75KQAhbqv0GeQXk8MPuZ3ztaLEaR5JzS8iifk1FwuroSVyXXyrJK1c4gEr1kHkmbgzDV-j3MvQpaIMVJBAiA"; - // The binary is made before testing so we have access to it - let bin_path = { - let mut bin_path = env::current_exe().unwrap(); - bin_path.pop(); - bin_path.pop(); - bin_path.join("influxdb3") - }; - let server = DropCommand::new( - Command::new(bin_path) - .args([ - "serve", - "--object-store", - "memory", - "--bearer-token", - HASHED_TOKEN, - ]) - .stdout(Stdio::null()) - .stderr(Stdio::null()) - .spawn() - .expect("Was able to spawn a server"), - ); - *COMMAND.lock() = Some(server); - - let current_hook = panic::take_hook(); - panic::set_hook(Box::new(move |info| { - COMMAND.lock().take().unwrap().kill(); - current_hook(info); - })); + let server = TestServer::configure() + .auth_token(HASHED_TOKEN, TOKEN) + .spawn() + .await; let client = reqwest::Client::new(); - - // Wait for the server to come up - while client - .get("http://127.0.0.1:8181/health") - .bearer_auth(TOKEN) - .send() - .await - .is_err() - {} + let base = server.client_addr(); + let write_lp_url = format!("{base}/api/v3/write_lp"); + let write_lp_params = [("db", "foo")]; + let query_sql_url = format!("{base}/api/v3/query_sql"); + let query_sql_params = [("db", "foo"), ("q", "select 
* from cpu")]; assert_eq!( client - .post("http://127.0.0.1:8181/api/v3/write_lp?db=foo") + .post(&write_lp_url) + .query(&write_lp_params) .body("cpu,host=a val=1i 123") .send() .await @@ -82,7 +35,8 @@ async fn auth() { ); assert_eq!( client - .get("http://127.0.0.1:8181/api/v3/query_sql?db=foo&q=select+*+from+cpu") + .get(&query_sql_url) + .query(&query_sql_params) .send() .await .unwrap() @@ -91,7 +45,8 @@ async fn auth() { ); assert_eq!( client - .post("http://127.0.0.1:8181/api/v3/write_lp?db=foo") + .post(&write_lp_url) + .query(&write_lp_params) .body("cpu,host=a val=1i 123") .bearer_auth(TOKEN) .send() @@ -102,7 +57,8 @@ async fn auth() { ); assert_eq!( client - .get("http://127.0.0.1:8181/api/v3/query_sql?db=foo&q=select+*+from+cpu") + .get(&query_sql_url) + .query(&query_sql_params) .bearer_auth(TOKEN) .send() .await @@ -114,7 +70,8 @@ async fn auth() { // Test that there is an extra string after the token foo assert_eq!( client - .get("http://127.0.0.1:8181/api/v3/query_sql?db=foo&q=select+*+from+cpu") + .get(&query_sql_url) + .query(&query_sql_params) .header("Authorization", format!("Bearer {TOKEN} whee")) .send() .await @@ -124,7 +81,8 @@ async fn auth() { ); assert_eq!( client - .get("http://127.0.0.1:8181/api/v3/query_sql?db=foo&q=select+*+from+cpu") + .get(&query_sql_url) + .query(&query_sql_params) .header("Authorization", format!("bearer {TOKEN}")) .send() .await @@ -134,7 +92,8 @@ async fn auth() { ); assert_eq!( client - .get("http://127.0.0.1:8181/api/v3/query_sql?db=foo&q=select+*+from+cpu") + .get(&query_sql_url) + .query(&query_sql_params) .header("Authorization", "Bearer") .send() .await @@ -144,7 +103,8 @@ async fn auth() { ); assert_eq!( client - .get("http://127.0.0.1:8181/api/v3/query_sql?db=foo&q=select+*+from+cpu") + .get(&query_sql_url) + .query(&query_sql_params) .header("auth", format!("Bearer {TOKEN}")) .send() .await @@ -152,5 +112,97 @@ async fn auth() { .status(), StatusCode::UNAUTHORIZED ); - COMMAND.lock().take().unwrap().kill(); +} + +#[tokio::test] +async fn auth_grpc() { + const HASHED_TOKEN: &str = "5315f0c4714537843face80cca8c18e27ce88e31e9be7a5232dc4dc8444f27c0227a9bd64831d3ab58f652bd0262dd8558dd08870ac9e5c650972ce9e4259439"; + const TOKEN: &str = "apiv3_mp75KQAhbqv0GeQXk8MPuZ3ztaLEaR5JzS8iifk1FwuroSVyXXyrJK1c4gEr1kHkmbgzDV-j3MvQpaIMVJBAiA"; + + let server = TestServer::configure() + .auth_token(HASHED_TOKEN, TOKEN) + .spawn() + .await; + + // Write some data to the server, this will be authorized through the HTTP API + server + .write_lp_to_db( + "foo", + "cpu,host=s1,region=us-east usage=0.9 1\n\ + cpu,host=s1,region=us-east usage=0.89 2\n\ + cpu,host=s1,region=us-east usage=0.85 3", + Precision::Nanosecond, + ) + .await + .unwrap(); + + // Check that with a valid authorization header, it succeeds: + for header in ["authorization", "Authorization"] { + // Spin up a FlightSQL client + let mut client = server.flight_sql_client("foo").await; + + // Set the authorization header on the client: + client + .add_header(header, &format!("Bearer {TOKEN}")) + .unwrap(); + + // Make the query again, this time it should work: + let response = client.query("SELECT * FROM cpu").await.unwrap(); + let batches = collect_stream(response).await; + assert_batches_sorted_eq!( + [ + "+------+---------+--------------------------------+-------+", + "| host | region | time | usage |", + "+------+---------+--------------------------------+-------+", + "| s1 | us-east | 1970-01-01T00:00:00.000000001Z | 0.9 |", + "| s1 | us-east | 1970-01-01T00:00:00.000000002Z | 0.89 
|",
+                "| s1   | us-east | 1970-01-01T00:00:00.000000003Z | 0.85  |",
+                "+------+---------+--------------------------------+-------+",
+            ],
+            &batches
+        );
+    }
+
+    // Check that without providing an Authorization header, it gives back
+    // an Unauthenticated error:
+    {
+        let mut client = server.flight_sql_client("foo").await;
+        let error = client.query("SELECT * FROM cpu").await.unwrap_err();
+        assert!(matches!(error, FlightError::Tonic(s) if s.code() == tonic::Code::Unauthenticated));
+    }
+
+    // Create some new clients that set the authorization header incorrectly to
+    // ensure errors are returned:
+
+    // Misspelled "Bearer"
+    {
+        let mut client = server.flight_sql_client("foo").await;
+        client
+            .add_header("authorization", &format!("bearer {TOKEN}"))
+            .unwrap();
+        let error = client.query("SELECT * FROM cpu").await.unwrap_err();
+        assert!(matches!(error, FlightError::Tonic(s) if s.code() == tonic::Code::Unauthenticated));
+    }
+
+    // Invalid token; this gives PermissionDenied rather than Unauthenticated
+    {
+        let mut client = server.flight_sql_client("foo").await;
+        client
+            .add_header("authorization", "Bearer invalid-token")
+            .unwrap();
+        let error = client.query("SELECT * FROM cpu").await.unwrap_err();
+        assert!(
+            matches!(error, FlightError::Tonic(s) if s.code() == tonic::Code::PermissionDenied)
+        );
+    }
+
+    // Misspelled header key
+    {
+        let mut client = server.flight_sql_client("foo").await;
+        client
+            .add_header("auth", &format!("Bearer {TOKEN}"))
+            .unwrap();
+        let error = client.query("SELECT * FROM cpu").await.unwrap_err();
+        assert!(matches!(error, FlightError::Tonic(s) if s.code() == tonic::Code::Unauthenticated));
+    }
+}
diff --git a/influxdb3/tests/server/main.rs b/influxdb3/tests/server/main.rs
index a99147b3a7..77a2fb903c 100644
--- a/influxdb3/tests/server/main.rs
+++ b/influxdb3/tests/server/main.rs
@@ -17,8 +17,43 @@ mod flight;
 mod limits;
 mod query;
 
+/// Configuration for a [`TestServer`]
+#[derive(Debug, Default)]
+pub struct TestConfig {
+    auth_token: Option<(String, String)>,
+}
+
+impl TestConfig {
+    /// Set the auth token for this [`TestServer`]
+    pub fn auth_token<S: Into<String>, R: Into<String>>(
+        mut self,
+        hashed_token: S,
+        raw_token: R,
+    ) -> Self {
+        self.auth_token = Some((hashed_token.into(), raw_token.into()));
+        self
+    }
+
+    /// Spawn a new [`TestServer`] with this configuration
+    ///
+    /// This will run the `influxdb3 serve` command, and bind its HTTP
+    /// address to a random port on localhost.
+    pub async fn spawn(self) -> TestServer {
+        TestServer::spawn_inner(self).await
+    }
+
+    fn as_args(&self) -> Vec<&str> {
+        let mut args = vec![];
+        if let Some((token, _)) = &self.auth_token {
+            args.append(&mut vec!["--bearer-token", token]);
+        }
+        args
+    }
+}
+
 /// A running instance of the `influxdb3 serve` process
 pub struct TestServer {
+    config: TestConfig,
     bind_addr: SocketAddr,
     server_process: Child,
     http_client: reqwest::Client,
@@ -30,19 +65,29 @@ impl TestServer {
     /// This will run the `influxdb3 serve` command, and bind its HTTP
     /// address to a random port on localhost.
pub async fn spawn() -> Self { + Self::spawn_inner(Default::default()).await + } + + /// Configure a [`TestServer`] before spawning + pub fn configure() -> TestConfig { + TestConfig::default() + } + + async fn spawn_inner(config: TestConfig) -> Self { let bind_addr = get_local_bind_addr(); let mut command = Command::cargo_bin("influxdb3").expect("create the influxdb3 command"); let command = command .arg("serve") .args(["--http-bind", &bind_addr.to_string()]) .args(["--object-store", "memory"]) - // TODO - other configuration can be passed through + .args(config.as_args()) .stdout(Stdio::null()) .stderr(Stdio::null()); let server_process = command.spawn().expect("spawn the influxdb3 server process"); let server = Self { + config, bind_addr, server_process, http_client: reqwest::Client::new(), @@ -111,7 +156,10 @@ impl TestServer { lp: impl ToString, precision: Precision, ) -> Result<(), influxdb3_client::Error> { - let client = influxdb3_client::Client::new(self.client_addr()).unwrap(); + let mut client = influxdb3_client::Client::new(self.client_addr()).unwrap(); + if let Some((_, token)) = &self.config.auth_token { + client = client.with_auth_token(token); + } client .api_v3_write_lp(database) .body(lp.to_string()) diff --git a/influxdb3_server/src/auth.rs b/influxdb3_server/src/auth.rs new file mode 100644 index 0000000000..1954c6510f --- /dev/null +++ b/influxdb3_server/src/auth.rs @@ -0,0 +1,58 @@ +use async_trait::async_trait; +use authz::{Authorizer, Error, Permission}; +use observability_deps::tracing::{debug, warn}; +use sha2::{Digest, Sha512}; + +/// An [`Authorizer`] implementation that will grant access to all +/// requests that provide `token` +#[derive(Debug)] +pub struct AllOrNothingAuthorizer { + token: Vec, +} + +impl AllOrNothingAuthorizer { + pub fn new(token: Vec) -> Self { + Self { token } + } +} + +#[async_trait] +impl Authorizer for AllOrNothingAuthorizer { + async fn permissions( + &self, + token: Option>, + perms: &[Permission], + ) -> Result, Error> { + debug!(?perms, "requesting permissions"); + let provided = token.as_deref().ok_or(Error::NoToken)?; + if Sha512::digest(provided)[..] 
+            == self.token
+        {
+            Ok(perms.to_vec())
+        } else {
+            warn!("invalid token provided");
+            Err(Error::InvalidToken)
+        }
+    }
+
+    async fn probe(&self) -> Result<(), Error> {
+        Ok(())
+    }
+}
+
+/// The default [`Authorizer`] implementation that will authorize all requests
+#[derive(Debug)]
+pub struct DefaultAuthorizer;
+
+#[async_trait]
+impl Authorizer for DefaultAuthorizer {
+    async fn permissions(
+        &self,
+        _token: Option<Vec<u8>>,
+        perms: &[Permission],
+    ) -> Result<Vec<Permission>, Error> {
+        Ok(perms.to_vec())
+    }
+
+    async fn probe(&self) -> Result<(), Error> {
+        Ok(())
+    }
+}
diff --git a/influxdb3_server/src/builder.rs b/influxdb3_server/src/builder.rs
new file mode 100644
index 0000000000..4fba3796ef
--- /dev/null
+++ b/influxdb3_server/src/builder.rs
@@ -0,0 +1,112 @@
+use std::sync::Arc;
+
+use authz::Authorizer;
+
+use crate::{auth::DefaultAuthorizer, http::HttpApi, CommonServerState, Server};
+
+#[derive(Debug)]
+pub struct ServerBuilder<W = NoWriteBuf, Q = NoQueryExec, P = NoPersister> {
+    common_state: CommonServerState,
+    max_request_size: usize,
+    write_buffer: W,
+    query_executor: Q,
+    persister: P,
+    authorizer: Arc<dyn Authorizer>,
+}
+
+impl ServerBuilder {
+    pub fn new(common_state: CommonServerState) -> Self {
+        Self {
+            common_state,
+            max_request_size: usize::MAX,
+            write_buffer: NoWriteBuf,
+            query_executor: NoQueryExec,
+            persister: NoPersister,
+            authorizer: Arc::new(DefaultAuthorizer),
+        }
+    }
+}
+
+impl<W, Q, P> ServerBuilder<W, Q, P> {
+    pub fn max_request_size(mut self, max_request_size: usize) -> Self {
+        self.max_request_size = max_request_size;
+        self
+    }
+
+    pub fn authorizer(mut self, a: Arc<dyn Authorizer>) -> Self {
+        self.authorizer = a;
+        self
+    }
+}
+
+#[derive(Debug)]
+pub struct NoWriteBuf;
+#[derive(Debug)]
+pub struct WithWriteBuf<W>(Arc<W>);
+#[derive(Debug)]
+pub struct NoQueryExec;
+#[derive(Debug)]
+pub struct WithQueryExec<Q>(Arc<Q>);
+#[derive(Debug)]
+pub struct NoPersister;
+#[derive(Debug)]
+pub struct WithPersister<P>(Arc<P>);
+
+impl<Q, P> ServerBuilder<NoWriteBuf, Q, P> {
+    pub fn write_buffer<W>(self, wb: Arc<W>) -> ServerBuilder<WithWriteBuf<W>, Q, P> {
+        ServerBuilder {
+            common_state: self.common_state,
+            max_request_size: self.max_request_size,
+            write_buffer: WithWriteBuf(wb),
+            query_executor: self.query_executor,
+            persister: self.persister,
+            authorizer: self.authorizer,
+        }
+    }
+}
+
+impl<W, P> ServerBuilder<W, NoQueryExec, P> {
+    pub fn query_executor<Q>(self, qe: Arc<Q>) -> ServerBuilder<W, WithQueryExec<Q>, P> {
+        ServerBuilder {
+            common_state: self.common_state,
+            max_request_size: self.max_request_size,
+            write_buffer: self.write_buffer,
+            query_executor: WithQueryExec(qe),
+            persister: self.persister,
+            authorizer: self.authorizer,
+        }
+    }
+}
+
+impl<W, Q> ServerBuilder<W, Q, NoPersister> {
+    pub fn persister<P>(self, p: Arc<P>) -> ServerBuilder<W, Q, WithPersister<P>> {
+        ServerBuilder {
+            common_state: self.common_state,
+            max_request_size: self.max_request_size,
+            write_buffer: self.write_buffer,
+            query_executor: self.query_executor,
+            persister: WithPersister(p),
+            authorizer: self.authorizer,
+        }
+    }
+}
+
+impl<W, Q, P> ServerBuilder<WithWriteBuf<W>, WithQueryExec<Q>, WithPersister<P>> {
+    pub fn build(self) -> Server {
+        let persister = Arc::clone(&self.persister.0);
+        let authorizer = Arc::clone(&self.authorizer);
+        let http = Arc::new(HttpApi::new(
+            self.common_state.clone(),
+            Arc::clone(&self.write_buffer.0),
+            Arc::clone(&self.query_executor.0),
+            self.max_request_size,
+            Arc::clone(&authorizer),
+        ));
+        Server {
+            common_state: self.common_state,
+            http,
+            persister,
+            authorizer,
+        }
+    }
+}
diff --git a/influxdb3_server/src/http.rs b/influxdb3_server/src/http.rs
index cd48f0cd3e..55a4559755 100644
--- a/influxdb3_server/src/http.rs
+++ b/influxdb3_server/src/http.rs
@@ -4,7 +4,7 @@ use crate::{query_executor, QueryKind};
 use crate::{CommonServerState, QueryExecutor};
 use arrow::record_batch::RecordBatch;
 use arrow::util::pretty;
-use authz::http::AuthorizationHeaderExtension;
+use authz::Authorizer;
 use bytes::{Bytes, BytesMut};
 use data_types::NamespaceName;
 use datafusion::error::DataFusionError;
@@ -29,8 +29,6 @@ use observability_deps::tracing::{debug, error, info};
 use serde::de::DeserializeOwned;
 use serde::Deserialize;
 use serde::Serialize;
-use sha2::Digest;
-use sha2::Sha512;
 use std::convert::Infallible;
 use std::fmt::Debug;
 use std::num::NonZeroI32;
@@ -195,6 +193,8 @@
     Unauthorized,
     #[error("the request was not in the form of 'Authorization: Bearer <token>'")]
     MalformedRequest,
+    #[error("requestor is forbidden from requested resource")]
+    Forbidden,
     #[error("to str error: {0}")]
     ToStr(#[from] hyper::header::ToStrError),
 }
@@ -278,6 +278,7 @@ pub(crate) struct HttpApi {
     write_buffer: Arc,
     pub(crate) query_executor: Arc,
     max_request_bytes: usize,
+    authorizer: Arc<dyn Authorizer>,
 }
 
 impl HttpApi {
@@ -286,12 +287,14 @@
         write_buffer: Arc,
         query_executor: Arc,
         max_request_bytes: usize,
+        authorizer: Arc<dyn Authorizer>,
     ) -> Self {
         Self {
            common_state,
            write_buffer,
            query_executor,
            max_request_bytes,
+            authorizer,
         }
     }
 }
@@ -487,47 +490,57 @@
-    fn authorize_request(&self, req: &mut Request<Body>) -> Result<(), AuthorizationError> {
+    async fn authorize_request(&self, req: &mut Request<Body>) -> Result<(), AuthorizationError> {
         // We won't need the authorization header anymore and we don't want to accidentally log it.
         // Take it out so we can use it and not log it later by accident.
- let auth = req.headers_mut().remove(AUTHORIZATION); + let auth = req + .headers_mut() + .remove(AUTHORIZATION) + .map(validate_auth_header) + .transpose()?; - if let Some(bearer_token) = self.common_state.bearer_token() { - let Some(header) = &auth else { - return Err(AuthorizationError::Unauthorized); - }; + // Currently we pass an empty permissions list, but in future we may be able to derive + // the permissions based on the incoming request + let permissions = self.authorizer.permissions(auth, &[]).await?; - // Split the header value into two parts - let mut header = header.to_str()?.split(' '); + // Extend the request with the permissions, which may be useful in future + req.extensions_mut().insert(permissions); - // Check that the header is the 'Bearer' auth scheme - let bearer = header.next().ok_or(AuthorizationError::MalformedRequest)?; - if bearer != "Bearer" { - return Err(AuthorizationError::MalformedRequest); - } - - // Get the token that we want to hash to check the request is valid - let token = header.next().ok_or(AuthorizationError::MalformedRequest)?; - - // There should only be two parts the 'Bearer' scheme and the actual - // token, error otherwise - if header.next().is_some() { - return Err(AuthorizationError::MalformedRequest); - } - - // Check that the hashed token is acceptable - let authorized = &Sha512::digest(token)[..] == bearer_token; - if !authorized { - return Err(AuthorizationError::Unauthorized); - } - } - - req.extensions_mut() - .insert(AuthorizationHeaderExtension::new(auth)); Ok(()) } } +fn validate_auth_header(header: HeaderValue) -> Result, AuthorizationError> { + // Split the header value into two parts + let mut header = header.to_str()?.split(' '); + + // Check that the header is the 'Bearer' auth scheme + let bearer = header.next().ok_or(AuthorizationError::MalformedRequest)?; + if bearer != "Bearer" { + return Err(AuthorizationError::MalformedRequest); + } + + // Get the token that we want to hash to check the request is valid + let token = header.next().ok_or(AuthorizationError::MalformedRequest)?; + + // There should only be two parts the 'Bearer' scheme and the actual + // token, error otherwise + if header.next().is_some() { + return Err(AuthorizationError::MalformedRequest); + } + + Ok(token.as_bytes().to_vec()) +} + +impl From for AuthorizationError { + fn from(auth_error: authz::Error) -> Self { + match auth_error { + authz::Error::Forbidden => Self::Forbidden, + _ => Self::Unauthorized, + } + } +} + /// A valid name: /// - Starts with a letter or a number /// - Is ASCII not UTF-8 @@ -700,7 +713,7 @@ where Q: QueryExecutor, Error: From<::Error>, { - if let Err(e) = http_server.authorize_request(&mut req) { + if let Err(e) = http_server.authorize_request(&mut req).await { match e { AuthorizationError::Unauthorized => { return Ok(Response::builder() @@ -716,6 +729,12 @@ where }")) .unwrap()); } + AuthorizationError::Forbidden => { + return Ok(Response::builder() + .status(StatusCode::FORBIDDEN) + .body(Body::empty()) + .unwrap()) + } // We don't expect this to happen, but if the header is messed up // better to handle it then not at all AuthorizationError::ToStr(_) => { diff --git a/influxdb3_server/src/lib.rs b/influxdb3_server/src/lib.rs index 097802967e..dbe8dfb1b2 100644 --- a/influxdb3_server/src/lib.rs +++ b/influxdb3_server/src/lib.rs @@ -11,6 +11,8 @@ clippy::clone_on_ref_ptr, clippy::future_not_send )] +pub mod auth; +pub mod builder; mod grpc; mod http; pub mod query_executor; @@ -20,6 +22,7 @@ use crate::grpc::make_flight_server; 
 use crate::http::route_request;
 use crate::http::HttpApi;
 use async_trait::async_trait;
+use authz::Authorizer;
 use datafusion::execution::SendableRecordBatchStream;
 use hyper::service::service_fn;
 use influxdb3_write::{Persister, WriteBuffer};
@@ -72,7 +75,6 @@ pub struct CommonServerState {
     trace_exporter: Option>,
     trace_header_parser: TraceHeaderParser,
     http_addr: SocketAddr,
-    bearer_token: Option<Vec<u8>>,
 }
 
 impl CommonServerState {
@@ -81,14 +83,12 @@
         trace_exporter: Option>,
         trace_header_parser: TraceHeaderParser,
         http_addr: SocketAddr,
-        bearer_token: Option<String>,
     ) -> Result<Self> {
         Ok(Self {
             metrics,
             trace_exporter,
             trace_header_parser,
             http_addr,
-            bearer_token: bearer_token.map(hex::decode).transpose()?,
         })
     }
@@ -109,10 +109,6 @@
     pub fn metric_registry(&self) -> Arc<metric::Registry> {
         Arc::<metric::Registry>::clone(&self.metrics)
     }
-
-    pub fn bearer_token(&self) -> Option<&[u8]> {
-        self.bearer_token.as_deref()
-    }
 }
 
 #[allow(dead_code)]
@@ -121,6 +117,7 @@ pub struct Server {
     common_state: CommonServerState,
     http: Arc>,
     persister: Arc<P>,
+    authorizer: Arc<dyn Authorizer>,
 }
 
 #[async_trait]
@@ -151,30 +148,9 @@ pub enum QueryKind {
     InfluxQl,
 }
 
-impl Server
-where
-    Q: QueryExecutor,
-    P: Persister,
-{
-    pub fn new(
-        common_state: CommonServerState,
-        persister: Arc

, - write_buffer: Arc, - query_executor: Arc, - max_http_request_size: usize, - ) -> Self { - let http = Arc::new(HttpApi::new( - common_state.clone(), - Arc::clone(&write_buffer), - Arc::clone(&query_executor), - max_http_request_size, - )); - - Self { - common_state, - http, - persister, - } +impl Server { + pub fn authorizer(&self) -> Arc { + Arc::clone(&self.authorizer) } } @@ -204,8 +180,7 @@ where let grpc_service = trace_layer.clone().layer(make_flight_server( Arc::clone(&server.http.query_executor), - // TODO - need to configure authz here: - None, + Some(server.authorizer()), )); let rest_service = hyper::service::make_service_fn(|_| { let http_server = Arc::clone(&server.http); @@ -249,6 +224,8 @@ pub async fn wait_for_signal() { #[cfg(test)] mod tests { + use crate::auth::DefaultAuthorizer; + use crate::builder::ServerBuilder; use crate::serve; use datafusion::parquet::data_type::AsBytes; use hyper::{body, Body, Client, Request, Response, StatusCode}; @@ -271,14 +248,9 @@ mod tests { let addr = get_free_port(); let trace_header_parser = trace_http::ctx::TraceHeaderParser::new(); let metrics = Arc::new(metric::Registry::new()); - let common_state = crate::CommonServerState::new( - Arc::clone(&metrics), - None, - trace_header_parser, - addr, - None, - ) - .unwrap(); + let common_state = + crate::CommonServerState::new(Arc::clone(&metrics), None, trace_header_parser, addr) + .unwrap(); let object_store: Arc = Arc::new(object_store::memory::InMemory::new()); let parquet_store = ParquetStorage::new(Arc::clone(&object_store), StorageId::from("influxdb3")); @@ -293,33 +265,31 @@ mod tests { metric_registry: Arc::clone(&metrics), mem_pool_size: usize::MAX, })); - let persister = PersisterImpl::new(Arc::clone(&object_store)); + let persister = Arc::new(PersisterImpl::new(Arc::clone(&object_store))); let write_buffer = Arc::new( influxdb3_write::write_buffer::WriteBufferImpl::new( - Arc::new(persister), + Arc::clone(&persister), None::>, ) .await .unwrap(), ); - let query_executor = crate::query_executor::QueryExecutorImpl::new( + let query_executor = Arc::new(crate::query_executor::QueryExecutorImpl::new( write_buffer.catalog(), Arc::clone(&write_buffer), Arc::clone(&exec), Arc::clone(&metrics), Arc::new(HashMap::new()), 10, - ); - let persister = Arc::new(PersisterImpl::new(Arc::clone(&object_store))); + )); - let server = crate::Server::new( - common_state, - persister, - Arc::clone(&write_buffer), - Arc::new(query_executor), - usize::MAX, - ); + let server = ServerBuilder::new(common_state) + .write_buffer(Arc::clone(&write_buffer)) + .query_executor(Arc::clone(&query_executor)) + .persister(Arc::clone(&persister)) + .authorizer(Arc::new(DefaultAuthorizer)) + .build(); let frontend_shutdown = CancellationToken::new(); let shutdown = frontend_shutdown.clone(); @@ -410,14 +380,9 @@ mod tests { let addr = get_free_port(); let trace_header_parser = trace_http::ctx::TraceHeaderParser::new(); let metrics = Arc::new(metric::Registry::new()); - let common_state = crate::CommonServerState::new( - Arc::clone(&metrics), - None, - trace_header_parser, - addr, - None, - ) - .unwrap(); + let common_state = + crate::CommonServerState::new(Arc::clone(&metrics), None, trace_header_parser, addr) + .unwrap(); let object_store: Arc = Arc::new(object_store::memory::InMemory::new()); let parquet_store = ParquetStorage::new(Arc::clone(&object_store), StorageId::from("influxdb3")); @@ -451,13 +416,12 @@ mod tests { 10, ); - let server = crate::Server::new( - common_state, - persister, - 
Arc::clone(&write_buffer), - Arc::new(query_executor), - usize::MAX, - ); + let server = ServerBuilder::new(common_state) + .write_buffer(Arc::clone(&write_buffer)) + .query_executor(Arc::new(query_executor)) + .persister(persister) + .authorizer(Arc::new(DefaultAuthorizer)) + .build(); let frontend_shutdown = CancellationToken::new(); let shutdown = frontend_shutdown.clone(); @@ -584,14 +548,9 @@ mod tests { let addr = get_free_port(); let trace_header_parser = trace_http::ctx::TraceHeaderParser::new(); let metrics = Arc::new(metric::Registry::new()); - let common_state = crate::CommonServerState::new( - Arc::clone(&metrics), - None, - trace_header_parser, - addr, - None, - ) - .unwrap(); + let common_state = + crate::CommonServerState::new(Arc::clone(&metrics), None, trace_header_parser, addr) + .unwrap(); let object_store: Arc = Arc::new(object_store::memory::InMemory::new()); let parquet_store = ParquetStorage::new(Arc::clone(&object_store), StorageId::from("influxdb3")); @@ -625,13 +584,12 @@ mod tests { 10, ); - let server = crate::Server::new( - common_state, - persister, - Arc::clone(&write_buffer), - Arc::new(query_executor), - usize::MAX, - ); + let server = ServerBuilder::new(common_state) + .write_buffer(Arc::clone(&write_buffer)) + .query_executor(Arc::new(query_executor)) + .persister(persister) + .authorizer(Arc::new(DefaultAuthorizer)) + .build(); let frontend_shutdown = CancellationToken::new(); let shutdown = frontend_shutdown.clone();
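A note on the token format used throughout this patch: the `--bearer-token` flag takes the hex-encoded SHA-512 digest of the raw API token, which `serve.rs` hex-decodes and `AllOrNothingAuthorizer` compares against `Sha512::digest(provided_token)`. The snippet below is a minimal illustrative sketch of how such a hashed token could be produced, assuming only the `sha2` and `hex` crates already used by this patch; the `hashed_bearer_token` helper name is hypothetical and not part of the patch.

// Illustrative only: derive the value to pass to `influxdb3 serve --bearer-token`
// from a raw API token (the raw token is what clients send as
// `Authorization: Bearer <raw token>`).
use sha2::{Digest, Sha512};

fn hashed_bearer_token(raw_token: &str) -> String {
    // Hex-encode the SHA-512 digest of the raw token; this mirrors the
    // comparison done by AllOrNothingAuthorizer after hex::decode in serve.rs.
    hex::encode(&Sha512::digest(raw_token.as_bytes())[..])
}

fn main() {
    let raw = std::env::args()
        .nth(1)
        .expect("pass the raw API token as the first argument");
    println!("{}", hashed_bearer_token(&raw));
}

Generating the hashed value this way and handing the raw token to clients reproduces the pairing of `HASHED_TOKEN` and `TOKEN` used in the `auth` and `auth_grpc` tests above.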