diff --git a/Cargo.lock b/Cargo.lock index 87de480c3f..b59bff3428 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2469,6 +2469,7 @@ dependencies = [ "arrow-flight", "arrow_util", "assert_cmd", + "authz", "backtrace", "base64 0.22.0", "clap", diff --git a/influxdb3/Cargo.toml b/influxdb3/Cargo.toml index b8bead478d..e62302c45f 100644 --- a/influxdb3/Cargo.toml +++ b/influxdb3/Cargo.toml @@ -7,6 +7,7 @@ license.workspace = true [dependencies] # Core Crates +authz.workspace = true clap_blocks.workspace = true iox_query.workspace = true iox_time.workspace = true diff --git a/influxdb3/src/commands/serve.rs b/influxdb3/src/commands/serve.rs index 03cc78fe82..1cdc4c0e86 100644 --- a/influxdb3/src/commands/serve.rs +++ b/influxdb3/src/commands/serve.rs @@ -7,7 +7,10 @@ use clap_blocks::{ object_store::{make_object_store, ObjectStoreConfig}, socket_addr::SocketAddr, }; -use influxdb3_server::{query_executor::QueryExecutorImpl, serve, CommonServerState, Server}; +use influxdb3_server::{ + auth::AllOrNothingAuthorizer, builder::ServerBuilder, query_executor::QueryExecutorImpl, serve, + CommonServerState, +}; use influxdb3_write::persister::PersisterImpl; use influxdb3_write::wal::WalImpl; use influxdb3_write::write_buffer::WriteBufferImpl; @@ -51,6 +54,9 @@ pub enum Error { #[error("Write buffer error: {0}")] WriteBuffer(#[from] influxdb3_write::write_buffer::Error), + + #[error("invalid token: {0}")] + InvalidToken(#[from] hex::FromHexError), } pub type Result = std::result::Result; @@ -238,32 +244,38 @@ pub async fn command(config: Config) -> Result<()> { trace_exporter, trace_header_parser, *config.http_bind_address, - config.bearer_token, )?; - let persister = PersisterImpl::new(Arc::clone(&object_store)); + let persister = Arc::new(PersisterImpl::new(Arc::clone(&object_store))); let wal: Option> = config .wal_directory .map(|dir| WalImpl::new(dir).map(Arc::new)) .transpose()?; // TODO: the next segment ID should be loaded from the persister - let write_buffer = Arc::new(WriteBufferImpl::new(Arc::new(persister), wal).await?); - let query_executor = QueryExecutorImpl::new( + let write_buffer = Arc::new(WriteBufferImpl::new(Arc::clone(&persister), wal).await?); + let query_executor = Arc::new(QueryExecutorImpl::new( write_buffer.catalog(), Arc::clone(&write_buffer), Arc::clone(&exec), Arc::clone(&metrics), Arc::new(config.datafusion_config), 10, - ); + )); let persister = Arc::new(PersisterImpl::new(Arc::clone(&object_store))); - let server = Server::new( - common_state, - persister, - Arc::clone(&write_buffer), - Arc::new(query_executor), - config.max_http_request_size, - ); + + let builder = ServerBuilder::new(common_state) + .max_request_size(config.max_http_request_size) + .write_buffer(write_buffer) + .query_executor(query_executor) + .persister(persister); + + let server = if let Some(token) = config.bearer_token.map(hex::decode).transpose()? { + builder + .authorizer(Arc::new(AllOrNothingAuthorizer::new(token))) + .build() + } else { + builder.build() + }; serve(server, frontend_shutdown).await?; Ok(()) diff --git a/influxdb3/tests/server/auth.rs b/influxdb3/tests/server/auth.rs index fec2f393fc..e04eb1c4cc 100644 --- a/influxdb3/tests/server/auth.rs +++ b/influxdb3/tests/server/auth.rs @@ -1,78 +1,31 @@ -use parking_lot::Mutex; +use arrow_flight::error::FlightError; +use arrow_util::assert_batches_sorted_eq; +use influxdb3_client::Precision; use reqwest::StatusCode; -use std::env; -use std::mem; -use std::panic; -use std::process::Child; -use std::process::Command; -use std::process::Stdio; -struct DropCommand { - cmd: Option, -} - -impl DropCommand { - const fn new(cmd: Child) -> Self { - Self { cmd: Some(cmd) } - } - - fn kill(&mut self) { - let mut cmd = self.cmd.take().unwrap(); - cmd.kill().unwrap(); - mem::drop(cmd); - } -} - -static COMMAND: Mutex> = parking_lot::const_mutex(None); +use crate::{collect_stream, TestServer}; #[tokio::test] async fn auth() { const HASHED_TOKEN: &str = "5315f0c4714537843face80cca8c18e27ce88e31e9be7a5232dc4dc8444f27c0227a9bd64831d3ab58f652bd0262dd8558dd08870ac9e5c650972ce9e4259439"; const TOKEN: &str = "apiv3_mp75KQAhbqv0GeQXk8MPuZ3ztaLEaR5JzS8iifk1FwuroSVyXXyrJK1c4gEr1kHkmbgzDV-j3MvQpaIMVJBAiA"; - // The binary is made before testing so we have access to it - let bin_path = { - let mut bin_path = env::current_exe().unwrap(); - bin_path.pop(); - bin_path.pop(); - bin_path.join("influxdb3") - }; - let server = DropCommand::new( - Command::new(bin_path) - .args([ - "serve", - "--object-store", - "memory", - "--bearer-token", - HASHED_TOKEN, - ]) - .stdout(Stdio::null()) - .stderr(Stdio::null()) - .spawn() - .expect("Was able to spawn a server"), - ); - *COMMAND.lock() = Some(server); - - let current_hook = panic::take_hook(); - panic::set_hook(Box::new(move |info| { - COMMAND.lock().take().unwrap().kill(); - current_hook(info); - })); + let server = TestServer::configure() + .auth_token(HASHED_TOKEN, TOKEN) + .spawn() + .await; let client = reqwest::Client::new(); - - // Wait for the server to come up - while client - .get("http://127.0.0.1:8181/health") - .bearer_auth(TOKEN) - .send() - .await - .is_err() - {} + let base = server.client_addr(); + let write_lp_url = format!("{base}/api/v3/write_lp"); + let write_lp_params = [("db", "foo")]; + let query_sql_url = format!("{base}/api/v3/query_sql"); + let query_sql_params = [("db", "foo"), ("q", "select * from cpu")]; assert_eq!( client - .post("http://127.0.0.1:8181/api/v3/write_lp?db=foo") + .post(&write_lp_url) + .query(&write_lp_params) .body("cpu,host=a val=1i 123") .send() .await @@ -82,7 +35,8 @@ async fn auth() { ); assert_eq!( client - .get("http://127.0.0.1:8181/api/v3/query_sql?db=foo&q=select+*+from+cpu") + .get(&query_sql_url) + .query(&query_sql_params) .send() .await .unwrap() @@ -91,7 +45,8 @@ async fn auth() { ); assert_eq!( client - .post("http://127.0.0.1:8181/api/v3/write_lp?db=foo") + .post(&write_lp_url) + .query(&write_lp_params) .body("cpu,host=a val=1i 123") .bearer_auth(TOKEN) .send() @@ -102,7 +57,8 @@ async fn auth() { ); assert_eq!( client - .get("http://127.0.0.1:8181/api/v3/query_sql?db=foo&q=select+*+from+cpu") + .get(&query_sql_url) + .query(&query_sql_params) .bearer_auth(TOKEN) .send() .await @@ -114,7 +70,8 @@ async fn auth() { // Test that there is an extra string after the token foo assert_eq!( client - .get("http://127.0.0.1:8181/api/v3/query_sql?db=foo&q=select+*+from+cpu") + .get(&query_sql_url) + .query(&query_sql_params) .header("Authorization", format!("Bearer {TOKEN} whee")) .send() .await @@ -124,7 +81,8 @@ async fn auth() { ); assert_eq!( client - .get("http://127.0.0.1:8181/api/v3/query_sql?db=foo&q=select+*+from+cpu") + .get(&query_sql_url) + .query(&query_sql_params) .header("Authorization", format!("bearer {TOKEN}")) .send() .await @@ -134,7 +92,8 @@ async fn auth() { ); assert_eq!( client - .get("http://127.0.0.1:8181/api/v3/query_sql?db=foo&q=select+*+from+cpu") + .get(&query_sql_url) + .query(&query_sql_params) .header("Authorization", "Bearer") .send() .await @@ -144,7 +103,8 @@ async fn auth() { ); assert_eq!( client - .get("http://127.0.0.1:8181/api/v3/query_sql?db=foo&q=select+*+from+cpu") + .get(&query_sql_url) + .query(&query_sql_params) .header("auth", format!("Bearer {TOKEN}")) .send() .await @@ -152,5 +112,97 @@ async fn auth() { .status(), StatusCode::UNAUTHORIZED ); - COMMAND.lock().take().unwrap().kill(); +} + +#[tokio::test] +async fn auth_grpc() { + const HASHED_TOKEN: &str = "5315f0c4714537843face80cca8c18e27ce88e31e9be7a5232dc4dc8444f27c0227a9bd64831d3ab58f652bd0262dd8558dd08870ac9e5c650972ce9e4259439"; + const TOKEN: &str = "apiv3_mp75KQAhbqv0GeQXk8MPuZ3ztaLEaR5JzS8iifk1FwuroSVyXXyrJK1c4gEr1kHkmbgzDV-j3MvQpaIMVJBAiA"; + + let server = TestServer::configure() + .auth_token(HASHED_TOKEN, TOKEN) + .spawn() + .await; + + // Write some data to the server, this will be authorized through the HTTP API + server + .write_lp_to_db( + "foo", + "cpu,host=s1,region=us-east usage=0.9 1\n\ + cpu,host=s1,region=us-east usage=0.89 2\n\ + cpu,host=s1,region=us-east usage=0.85 3", + Precision::Nanosecond, + ) + .await + .unwrap(); + + // Check that with a valid authorization header, it succeeds: + for header in ["authorization", "Authorization"] { + // Spin up a FlightSQL client + let mut client = server.flight_sql_client("foo").await; + + // Set the authorization header on the client: + client + .add_header(header, &format!("Bearer {TOKEN}")) + .unwrap(); + + // Make the query again, this time it should work: + let response = client.query("SELECT * FROM cpu").await.unwrap(); + let batches = collect_stream(response).await; + assert_batches_sorted_eq!( + [ + "+------+---------+--------------------------------+-------+", + "| host | region | time | usage |", + "+------+---------+--------------------------------+-------+", + "| s1 | us-east | 1970-01-01T00:00:00.000000001Z | 0.9 |", + "| s1 | us-east | 1970-01-01T00:00:00.000000002Z | 0.89 |", + "| s1 | us-east | 1970-01-01T00:00:00.000000003Z | 0.85 |", + "+------+---------+--------------------------------+-------+", + ], + &batches + ); + } + + // Check that without providing an Authentication header, it gives back + // an Unauthenticated error: + { + let mut client = server.flight_sql_client("foo").await; + let error = client.query("SELECT * FROM cpu").await.unwrap_err(); + assert!(matches!(error, FlightError::Tonic(s) if s.code() == tonic::Code::Unauthenticated)); + } + + // Create some new clients that set the authorization header incorrectly to + // ensure errors are returned: + + // Mispelled "Bearer" + { + let mut client = server.flight_sql_client("foo").await; + client + .add_header("authorization", &format!("bearer {TOKEN}")) + .unwrap(); + let error = client.query("SELECT * FROM cpu").await.unwrap_err(); + assert!(matches!(error, FlightError::Tonic(s) if s.code() == tonic::Code::Unauthenticated)); + } + + // Invalid token, this actually gives Permission denied + { + let mut client = server.flight_sql_client("foo").await; + client + .add_header("authorization", "Bearer invalid-token") + .unwrap(); + let error = client.query("SELECT * FROM cpu").await.unwrap_err(); + assert!( + matches!(error, FlightError::Tonic(s) if s.code() == tonic::Code::PermissionDenied) + ); + } + + // Mispelled header key + { + let mut client = server.flight_sql_client("foo").await; + client + .add_header("auth", &format!("Bearer {TOKEN}")) + .unwrap(); + let error = client.query("SELECT * FROM cpu").await.unwrap_err(); + assert!(matches!(error, FlightError::Tonic(s) if s.code() == tonic::Code::Unauthenticated)); + } } diff --git a/influxdb3/tests/server/main.rs b/influxdb3/tests/server/main.rs index a99147b3a7..77a2fb903c 100644 --- a/influxdb3/tests/server/main.rs +++ b/influxdb3/tests/server/main.rs @@ -17,8 +17,43 @@ mod flight; mod limits; mod query; +/// Configuration for a [`TestServer`] +#[derive(Debug, Default)] +pub struct TestConfig { + auth_token: Option<(String, String)>, +} + +impl TestConfig { + /// Set the auth token for this [`TestServer`] + pub fn auth_token, R: Into>( + mut self, + hashed_token: S, + raw_token: R, + ) -> Self { + self.auth_token = Some((hashed_token.into(), raw_token.into())); + self + } + + /// Spawn a new [`TestServer`] with this configuration + /// + /// This will run the `influxdb3 serve` command, and bind its HTTP + /// address to a random port on localhost. + pub async fn spawn(self) -> TestServer { + TestServer::spawn_inner(self).await + } + + fn as_args(&self) -> Vec<&str> { + let mut args = vec![]; + if let Some((token, _)) = &self.auth_token { + args.append(&mut vec!["--bearer-token", token]); + } + args + } +} + /// A running instance of the `influxdb3 serve` process pub struct TestServer { + config: TestConfig, bind_addr: SocketAddr, server_process: Child, http_client: reqwest::Client, @@ -30,19 +65,29 @@ impl TestServer { /// This will run the `influxdb3 serve` command, and bind its HTTP /// address to a random port on localhost. pub async fn spawn() -> Self { + Self::spawn_inner(Default::default()).await + } + + /// Configure a [`TestServer`] before spawning + pub fn configure() -> TestConfig { + TestConfig::default() + } + + async fn spawn_inner(config: TestConfig) -> Self { let bind_addr = get_local_bind_addr(); let mut command = Command::cargo_bin("influxdb3").expect("create the influxdb3 command"); let command = command .arg("serve") .args(["--http-bind", &bind_addr.to_string()]) .args(["--object-store", "memory"]) - // TODO - other configuration can be passed through + .args(config.as_args()) .stdout(Stdio::null()) .stderr(Stdio::null()); let server_process = command.spawn().expect("spawn the influxdb3 server process"); let server = Self { + config, bind_addr, server_process, http_client: reqwest::Client::new(), @@ -111,7 +156,10 @@ impl TestServer { lp: impl ToString, precision: Precision, ) -> Result<(), influxdb3_client::Error> { - let client = influxdb3_client::Client::new(self.client_addr()).unwrap(); + let mut client = influxdb3_client::Client::new(self.client_addr()).unwrap(); + if let Some((_, token)) = &self.config.auth_token { + client = client.with_auth_token(token); + } client .api_v3_write_lp(database) .body(lp.to_string()) diff --git a/influxdb3_server/src/auth.rs b/influxdb3_server/src/auth.rs new file mode 100644 index 0000000000..1954c6510f --- /dev/null +++ b/influxdb3_server/src/auth.rs @@ -0,0 +1,58 @@ +use async_trait::async_trait; +use authz::{Authorizer, Error, Permission}; +use observability_deps::tracing::{debug, warn}; +use sha2::{Digest, Sha512}; + +/// An [`Authorizer`] implementation that will grant access to all +/// requests that provide `token` +#[derive(Debug)] +pub struct AllOrNothingAuthorizer { + token: Vec, +} + +impl AllOrNothingAuthorizer { + pub fn new(token: Vec) -> Self { + Self { token } + } +} + +#[async_trait] +impl Authorizer for AllOrNothingAuthorizer { + async fn permissions( + &self, + token: Option>, + perms: &[Permission], + ) -> Result, Error> { + debug!(?perms, "requesting permissions"); + let provided = token.as_deref().ok_or(Error::NoToken)?; + if Sha512::digest(provided)[..] == self.token { + warn!("invalid token provided"); + Ok(perms.to_vec()) + } else { + Err(Error::InvalidToken) + } + } + + async fn probe(&self) -> Result<(), Error> { + Ok(()) + } +} + +/// The defult [`Authorizer`] implementation that will authorize all requests +#[derive(Debug)] +pub struct DefaultAuthorizer; + +#[async_trait] +impl Authorizer for DefaultAuthorizer { + async fn permissions( + &self, + _token: Option>, + perms: &[Permission], + ) -> Result, Error> { + Ok(perms.to_vec()) + } + + async fn probe(&self) -> Result<(), Error> { + Ok(()) + } +} diff --git a/influxdb3_server/src/builder.rs b/influxdb3_server/src/builder.rs new file mode 100644 index 0000000000..4fba3796ef --- /dev/null +++ b/influxdb3_server/src/builder.rs @@ -0,0 +1,112 @@ +use std::sync::Arc; + +use authz::Authorizer; + +use crate::{auth::DefaultAuthorizer, http::HttpApi, CommonServerState, Server}; + +#[derive(Debug)] +pub struct ServerBuilder { + common_state: CommonServerState, + max_request_size: usize, + write_buffer: W, + query_executor: Q, + persister: P, + authorizer: Arc, +} + +impl ServerBuilder { + pub fn new(common_state: CommonServerState) -> Self { + Self { + common_state, + max_request_size: usize::MAX, + write_buffer: NoWriteBuf, + query_executor: NoQueryExec, + persister: NoPersister, + authorizer: Arc::new(DefaultAuthorizer), + } + } +} + +impl ServerBuilder { + pub fn max_request_size(mut self, max_request_size: usize) -> Self { + self.max_request_size = max_request_size; + self + } + + pub fn authorizer(mut self, a: Arc) -> Self { + self.authorizer = a; + self + } +} + +#[derive(Debug)] +pub struct NoWriteBuf; +#[derive(Debug)] +pub struct WithWriteBuf(Arc); +#[derive(Debug)] +pub struct NoQueryExec; +#[derive(Debug)] +pub struct WithQueryExec(Arc); +#[derive(Debug)] +pub struct NoPersister; +#[derive(Debug)] +pub struct WithPersister

(Arc

); + +impl ServerBuilder { + pub fn write_buffer(self, wb: Arc) -> ServerBuilder, Q, P> { + ServerBuilder { + common_state: self.common_state, + max_request_size: self.max_request_size, + write_buffer: WithWriteBuf(wb), + query_executor: self.query_executor, + persister: self.persister, + authorizer: self.authorizer, + } + } +} + +impl ServerBuilder { + pub fn query_executor(self, qe: Arc) -> ServerBuilder, P> { + ServerBuilder { + common_state: self.common_state, + max_request_size: self.max_request_size, + write_buffer: self.write_buffer, + query_executor: WithQueryExec(qe), + persister: self.persister, + authorizer: self.authorizer, + } + } +} + +impl ServerBuilder { + pub fn persister

(self, p: Arc

) -> ServerBuilder> { + ServerBuilder { + common_state: self.common_state, + max_request_size: self.max_request_size, + write_buffer: self.write_buffer, + query_executor: self.query_executor, + persister: WithPersister(p), + authorizer: self.authorizer, + } + } +} + +impl ServerBuilder, WithQueryExec, WithPersister

> { + pub fn build(self) -> Server { + let persister = Arc::clone(&self.persister.0); + let authorizer = Arc::clone(&self.authorizer); + let http = Arc::new(HttpApi::new( + self.common_state.clone(), + Arc::clone(&self.write_buffer.0), + Arc::clone(&self.query_executor.0), + self.max_request_size, + Arc::clone(&authorizer), + )); + Server { + common_state: self.common_state, + http, + persister, + authorizer, + } + } +} diff --git a/influxdb3_server/src/http.rs b/influxdb3_server/src/http.rs index cd48f0cd3e..55a4559755 100644 --- a/influxdb3_server/src/http.rs +++ b/influxdb3_server/src/http.rs @@ -4,7 +4,7 @@ use crate::{query_executor, QueryKind}; use crate::{CommonServerState, QueryExecutor}; use arrow::record_batch::RecordBatch; use arrow::util::pretty; -use authz::http::AuthorizationHeaderExtension; +use authz::Authorizer; use bytes::{Bytes, BytesMut}; use data_types::NamespaceName; use datafusion::error::DataFusionError; @@ -29,8 +29,6 @@ use observability_deps::tracing::{debug, error, info}; use serde::de::DeserializeOwned; use serde::Deserialize; use serde::Serialize; -use sha2::Digest; -use sha2::Sha512; use std::convert::Infallible; use std::fmt::Debug; use std::num::NonZeroI32; @@ -195,6 +193,8 @@ pub enum AuthorizationError { Unauthorized, #[error("the request was not in the form of 'Authorization: Bearer '")] MalformedRequest, + #[error("requestor is forbidden from requested resource")] + Forbidden, #[error("to str error: {0}")] ToStr(#[from] hyper::header::ToStrError), } @@ -278,6 +278,7 @@ pub(crate) struct HttpApi { write_buffer: Arc, pub(crate) query_executor: Arc, max_request_bytes: usize, + authorizer: Arc, } impl HttpApi { @@ -286,12 +287,14 @@ impl HttpApi { write_buffer: Arc, query_executor: Arc, max_request_bytes: usize, + authorizer: Arc, ) -> Self { Self { common_state, write_buffer, query_executor, max_request_bytes, + authorizer, } } } @@ -487,47 +490,57 @@ where Ok(decoded_data.into()) } - fn authorize_request(&self, req: &mut Request) -> Result<(), AuthorizationError> { + async fn authorize_request(&self, req: &mut Request) -> Result<(), AuthorizationError> { // We won't need the authorization header anymore and we don't want to accidentally log it. // Take it out so we can use it and not log it later by accident. - let auth = req.headers_mut().remove(AUTHORIZATION); + let auth = req + .headers_mut() + .remove(AUTHORIZATION) + .map(validate_auth_header) + .transpose()?; - if let Some(bearer_token) = self.common_state.bearer_token() { - let Some(header) = &auth else { - return Err(AuthorizationError::Unauthorized); - }; + // Currently we pass an empty permissions list, but in future we may be able to derive + // the permissions based on the incoming request + let permissions = self.authorizer.permissions(auth, &[]).await?; - // Split the header value into two parts - let mut header = header.to_str()?.split(' '); + // Extend the request with the permissions, which may be useful in future + req.extensions_mut().insert(permissions); - // Check that the header is the 'Bearer' auth scheme - let bearer = header.next().ok_or(AuthorizationError::MalformedRequest)?; - if bearer != "Bearer" { - return Err(AuthorizationError::MalformedRequest); - } - - // Get the token that we want to hash to check the request is valid - let token = header.next().ok_or(AuthorizationError::MalformedRequest)?; - - // There should only be two parts the 'Bearer' scheme and the actual - // token, error otherwise - if header.next().is_some() { - return Err(AuthorizationError::MalformedRequest); - } - - // Check that the hashed token is acceptable - let authorized = &Sha512::digest(token)[..] == bearer_token; - if !authorized { - return Err(AuthorizationError::Unauthorized); - } - } - - req.extensions_mut() - .insert(AuthorizationHeaderExtension::new(auth)); Ok(()) } } +fn validate_auth_header(header: HeaderValue) -> Result, AuthorizationError> { + // Split the header value into two parts + let mut header = header.to_str()?.split(' '); + + // Check that the header is the 'Bearer' auth scheme + let bearer = header.next().ok_or(AuthorizationError::MalformedRequest)?; + if bearer != "Bearer" { + return Err(AuthorizationError::MalformedRequest); + } + + // Get the token that we want to hash to check the request is valid + let token = header.next().ok_or(AuthorizationError::MalformedRequest)?; + + // There should only be two parts the 'Bearer' scheme and the actual + // token, error otherwise + if header.next().is_some() { + return Err(AuthorizationError::MalformedRequest); + } + + Ok(token.as_bytes().to_vec()) +} + +impl From for AuthorizationError { + fn from(auth_error: authz::Error) -> Self { + match auth_error { + authz::Error::Forbidden => Self::Forbidden, + _ => Self::Unauthorized, + } + } +} + /// A valid name: /// - Starts with a letter or a number /// - Is ASCII not UTF-8 @@ -700,7 +713,7 @@ where Q: QueryExecutor, Error: From<::Error>, { - if let Err(e) = http_server.authorize_request(&mut req) { + if let Err(e) = http_server.authorize_request(&mut req).await { match e { AuthorizationError::Unauthorized => { return Ok(Response::builder() @@ -716,6 +729,12 @@ where }")) .unwrap()); } + AuthorizationError::Forbidden => { + return Ok(Response::builder() + .status(StatusCode::FORBIDDEN) + .body(Body::empty()) + .unwrap()) + } // We don't expect this to happen, but if the header is messed up // better to handle it then not at all AuthorizationError::ToStr(_) => { diff --git a/influxdb3_server/src/lib.rs b/influxdb3_server/src/lib.rs index 097802967e..dbe8dfb1b2 100644 --- a/influxdb3_server/src/lib.rs +++ b/influxdb3_server/src/lib.rs @@ -11,6 +11,8 @@ clippy::clone_on_ref_ptr, clippy::future_not_send )] +pub mod auth; +pub mod builder; mod grpc; mod http; pub mod query_executor; @@ -20,6 +22,7 @@ use crate::grpc::make_flight_server; use crate::http::route_request; use crate::http::HttpApi; use async_trait::async_trait; +use authz::Authorizer; use datafusion::execution::SendableRecordBatchStream; use hyper::service::service_fn; use influxdb3_write::{Persister, WriteBuffer}; @@ -72,7 +75,6 @@ pub struct CommonServerState { trace_exporter: Option>, trace_header_parser: TraceHeaderParser, http_addr: SocketAddr, - bearer_token: Option>, } impl CommonServerState { @@ -81,14 +83,12 @@ impl CommonServerState { trace_exporter: Option>, trace_header_parser: TraceHeaderParser, http_addr: SocketAddr, - bearer_token: Option, ) -> Result { Ok(Self { metrics, trace_exporter, trace_header_parser, http_addr, - bearer_token: bearer_token.map(hex::decode).transpose()?, }) } @@ -109,10 +109,6 @@ impl CommonServerState { pub fn metric_registry(&self) -> Arc { Arc::::clone(&self.metrics) } - - pub fn bearer_token(&self) -> Option<&[u8]> { - self.bearer_token.as_deref() - } } #[allow(dead_code)] @@ -121,6 +117,7 @@ pub struct Server { common_state: CommonServerState, http: Arc>, persister: Arc

, + authorizer: Arc, } #[async_trait] @@ -151,30 +148,9 @@ pub enum QueryKind { InfluxQl, } -impl Server -where - Q: QueryExecutor, - P: Persister, -{ - pub fn new( - common_state: CommonServerState, - persister: Arc

, - write_buffer: Arc, - query_executor: Arc, - max_http_request_size: usize, - ) -> Self { - let http = Arc::new(HttpApi::new( - common_state.clone(), - Arc::clone(&write_buffer), - Arc::clone(&query_executor), - max_http_request_size, - )); - - Self { - common_state, - http, - persister, - } +impl Server { + pub fn authorizer(&self) -> Arc { + Arc::clone(&self.authorizer) } } @@ -204,8 +180,7 @@ where let grpc_service = trace_layer.clone().layer(make_flight_server( Arc::clone(&server.http.query_executor), - // TODO - need to configure authz here: - None, + Some(server.authorizer()), )); let rest_service = hyper::service::make_service_fn(|_| { let http_server = Arc::clone(&server.http); @@ -249,6 +224,8 @@ pub async fn wait_for_signal() { #[cfg(test)] mod tests { + use crate::auth::DefaultAuthorizer; + use crate::builder::ServerBuilder; use crate::serve; use datafusion::parquet::data_type::AsBytes; use hyper::{body, Body, Client, Request, Response, StatusCode}; @@ -271,14 +248,9 @@ mod tests { let addr = get_free_port(); let trace_header_parser = trace_http::ctx::TraceHeaderParser::new(); let metrics = Arc::new(metric::Registry::new()); - let common_state = crate::CommonServerState::new( - Arc::clone(&metrics), - None, - trace_header_parser, - addr, - None, - ) - .unwrap(); + let common_state = + crate::CommonServerState::new(Arc::clone(&metrics), None, trace_header_parser, addr) + .unwrap(); let object_store: Arc = Arc::new(object_store::memory::InMemory::new()); let parquet_store = ParquetStorage::new(Arc::clone(&object_store), StorageId::from("influxdb3")); @@ -293,33 +265,31 @@ mod tests { metric_registry: Arc::clone(&metrics), mem_pool_size: usize::MAX, })); - let persister = PersisterImpl::new(Arc::clone(&object_store)); + let persister = Arc::new(PersisterImpl::new(Arc::clone(&object_store))); let write_buffer = Arc::new( influxdb3_write::write_buffer::WriteBufferImpl::new( - Arc::new(persister), + Arc::clone(&persister), None::>, ) .await .unwrap(), ); - let query_executor = crate::query_executor::QueryExecutorImpl::new( + let query_executor = Arc::new(crate::query_executor::QueryExecutorImpl::new( write_buffer.catalog(), Arc::clone(&write_buffer), Arc::clone(&exec), Arc::clone(&metrics), Arc::new(HashMap::new()), 10, - ); - let persister = Arc::new(PersisterImpl::new(Arc::clone(&object_store))); + )); - let server = crate::Server::new( - common_state, - persister, - Arc::clone(&write_buffer), - Arc::new(query_executor), - usize::MAX, - ); + let server = ServerBuilder::new(common_state) + .write_buffer(Arc::clone(&write_buffer)) + .query_executor(Arc::clone(&query_executor)) + .persister(Arc::clone(&persister)) + .authorizer(Arc::new(DefaultAuthorizer)) + .build(); let frontend_shutdown = CancellationToken::new(); let shutdown = frontend_shutdown.clone(); @@ -410,14 +380,9 @@ mod tests { let addr = get_free_port(); let trace_header_parser = trace_http::ctx::TraceHeaderParser::new(); let metrics = Arc::new(metric::Registry::new()); - let common_state = crate::CommonServerState::new( - Arc::clone(&metrics), - None, - trace_header_parser, - addr, - None, - ) - .unwrap(); + let common_state = + crate::CommonServerState::new(Arc::clone(&metrics), None, trace_header_parser, addr) + .unwrap(); let object_store: Arc = Arc::new(object_store::memory::InMemory::new()); let parquet_store = ParquetStorage::new(Arc::clone(&object_store), StorageId::from("influxdb3")); @@ -451,13 +416,12 @@ mod tests { 10, ); - let server = crate::Server::new( - common_state, - persister, - Arc::clone(&write_buffer), - Arc::new(query_executor), - usize::MAX, - ); + let server = ServerBuilder::new(common_state) + .write_buffer(Arc::clone(&write_buffer)) + .query_executor(Arc::new(query_executor)) + .persister(persister) + .authorizer(Arc::new(DefaultAuthorizer)) + .build(); let frontend_shutdown = CancellationToken::new(); let shutdown = frontend_shutdown.clone(); @@ -584,14 +548,9 @@ mod tests { let addr = get_free_port(); let trace_header_parser = trace_http::ctx::TraceHeaderParser::new(); let metrics = Arc::new(metric::Registry::new()); - let common_state = crate::CommonServerState::new( - Arc::clone(&metrics), - None, - trace_header_parser, - addr, - None, - ) - .unwrap(); + let common_state = + crate::CommonServerState::new(Arc::clone(&metrics), None, trace_header_parser, addr) + .unwrap(); let object_store: Arc = Arc::new(object_store::memory::InMemory::new()); let parquet_store = ParquetStorage::new(Arc::clone(&object_store), StorageId::from("influxdb3")); @@ -625,13 +584,12 @@ mod tests { 10, ); - let server = crate::Server::new( - common_state, - persister, - Arc::clone(&write_buffer), - Arc::new(query_executor), - usize::MAX, - ); + let server = ServerBuilder::new(common_state) + .write_buffer(Arc::clone(&write_buffer)) + .query_executor(Arc::new(query_executor)) + .persister(persister) + .authorizer(Arc::new(DefaultAuthorizer)) + .build(); let frontend_shutdown = CancellationToken::new(); let shutdown = frontend_shutdown.clone();