diff --git a/Cargo.lock b/Cargo.lock index aa7f573227..9b88063e10 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3326,6 +3326,7 @@ version = "0.1.0" dependencies = [ "arrow-flight", "async-trait", + "authz", "clap_blocks", "data_types", "datafusion_util", @@ -5288,6 +5289,8 @@ dependencies = [ "arrow-flight", "arrow_util", "assert_matches", + "async-trait", + "authz", "bytes", "data_types", "datafusion", diff --git a/authz/src/lib.rs b/authz/src/lib.rs index 8942d0451b..65a81ac1df 100644 --- a/authz/src/lib.rs +++ b/authz/src/lib.rs @@ -38,14 +38,14 @@ pub trait Authorizer: std::fmt::Debug + Send + Sync { /// empty permission sets. async fn permissions( &self, - token: &[u8], + token: Option<&[u8]>, perms: &[Permission], ) -> Result, Error>; /// Make a test request that determines if end-to-end communication /// with the service is working. async fn probe(&self) -> Result<(), Error> { - self.permissions(b"", &[]).await?; + self.permissions(Some(b""), &[]).await?; Ok(()) } @@ -55,7 +55,7 @@ pub trait Authorizer: std::fmt::Debug + Send + Sync { /// error is returned. async fn require_any_permission( &self, - token: &[u8], + token: Option<&[u8]>, perms: &[Permission], ) -> Result<(), Error> { if self.permissions(token, perms).await?.is_empty() { @@ -66,6 +66,31 @@ pub trait Authorizer: std::fmt::Debug + Send + Sync { } } +#[async_trait] +impl Authorizer for Option { + async fn permissions( + &self, + token: Option<&[u8]>, + perms: &[Permission], + ) -> Result, Error> { + match self { + Some(authz) => authz.permissions(token, perms).await, + None => Ok(perms.to_vec()), + } + } +} + +#[async_trait] +impl + std::fmt::Debug + Send + Sync> Authorizer for T { + async fn permissions( + &self, + token: Option<&[u8]>, + perms: &[Permission], + ) -> Result, Error> { + self.as_ref().permissions(token, perms).await + } +} + /// Authorizer implementation using influxdata.iox.authz.v1 protocol. #[derive(Clone, Debug)] pub struct IoxAuthorizer { @@ -92,11 +117,11 @@ impl IoxAuthorizer { impl Authorizer for IoxAuthorizer { async fn permissions( &self, - token: &[u8], + token: Option<&[u8]>, perms: &[Permission], ) -> Result, Error> { let req = proto::AuthorizeRequest { - token: token.to_vec(), + token: token.ok_or(Error::NoToken)?.to_vec(), permissions: perms .iter() .filter_map(|p| p.clone().try_into().ok()) @@ -131,9 +156,13 @@ pub enum Error { source: Box, }, - /// The token's permissions do not allow the operation.. + /// The token's permissions do not allow the operation. #[snafu(display("forbidden"))] Forbidden, + + /// No token has been supplied, but is required. + #[snafu(display("no token"))] + NoToken, } impl Error { diff --git a/influxdb_iox/src/commands/run/all_in_one.rs b/influxdb_iox/src/commands/run/all_in_one.rs index c2ca46d383..9cc74534ac 100644 --- a/influxdb_iox/src/commands/run/all_in_one.rs +++ b/influxdb_iox/src/commands/run/all_in_one.rs @@ -3,6 +3,7 @@ use crate::process_info::setup_metric_registry; use super::main; +use authz::Authorizer; use clap_blocks::{ authz::AuthzConfig, catalog_dsn::CatalogDsnConfig, @@ -596,10 +597,8 @@ pub async fn command(config: Config) -> Result<()> { let time_provider: Arc = Arc::new(SystemProvider::new()); let authz = authz_config.authorizer()?; - if let Some(authz) = &authz { - // Verify the connection to the authorizer, if configured. - authz.probe().await?; - } + // Verify the connection to the authorizer, if configured. + authz.probe().await?; // create common state from the router and use it below let common_state = CommonServerState::from_config(router_run_config.clone())?; @@ -627,7 +626,7 @@ pub async fn command(config: Config) -> Result<()> { Arc::clone(&metrics), Arc::clone(&catalog), Arc::clone(&object_store), - authz.map(|a| Arc::clone(&a)), + authz.as_ref().map(Arc::clone), &router_config, ) .await?; @@ -682,6 +681,7 @@ pub async fn command(config: Config) -> Result<()> { ingester_addresses, querier_config, rpc_write: true, + authz: authz.as_ref().map(Arc::clone), }) .await?; diff --git a/influxdb_iox/src/commands/run/querier.rs b/influxdb_iox/src/commands/run/querier.rs index 2586593bfd..170748d150 100644 --- a/influxdb_iox/src/commands/run/querier.rs +++ b/influxdb_iox/src/commands/run/querier.rs @@ -3,9 +3,10 @@ use crate::process_info::setup_metric_registry; use super::main; +use authz::Authorizer; use clap_blocks::{ - catalog_dsn::CatalogDsnConfig, object_store::make_object_store, querier::QuerierConfig, - run_config::RunConfig, + authz::AuthzConfig, catalog_dsn::CatalogDsnConfig, object_store::make_object_store, + querier::QuerierConfig, run_config::RunConfig, }; use iox_query::exec::Executor; use iox_time::{SystemProvider, TimeProvider}; @@ -42,6 +43,12 @@ pub enum Error { #[error("Querier error: {0}")] Querier(#[from] ioxd_querier::Error), + + #[error("Authz configuration error: {0}")] + AuthzConfig(#[from] clap_blocks::authz::Error), + + #[error("Authz service error: {0}")] + AuthzService(#[from] authz::Error), } #[derive(Debug, clap::Parser)] @@ -60,6 +67,10 @@ Configuration is loaded from the following sources (highest precedence first): - pre-configured default values" )] pub struct Config { + /// Authorizer options. + #[clap(flatten)] + pub(crate) authz_config: AuthzConfig, + #[clap(flatten)] pub(crate) run_config: RunConfig, @@ -92,6 +103,10 @@ pub async fn command(config: Config) -> Result<(), Error> { let time_provider = Arc::new(SystemProvider::new()); + let authz = config.authz_config.authorizer()?; + // Verify the connection to the authorizer, if configured. + authz.probe().await?; + let num_query_threads = config.querier_config.num_query_threads(); let num_threads = num_query_threads.unwrap_or_else(|| { NonZeroUsize::new(num_cpus::get()).unwrap_or_else(|| NonZeroUsize::new(1).unwrap()) @@ -123,6 +138,7 @@ pub async fn command(config: Config) -> Result<(), Error> { ingester_addresses, querier_config: config.querier_config, rpc_write, + authz: authz.as_ref().map(Arc::clone), }) .await?; diff --git a/influxdb_iox/src/commands/run/router2.rs b/influxdb_iox/src/commands/run/router2.rs index f4d13a7e48..38fcc83cc6 100644 --- a/influxdb_iox/src/commands/run/router2.rs +++ b/influxdb_iox/src/commands/run/router2.rs @@ -1,6 +1,7 @@ //! Command line options for running a router2 that uses the RPC write path. use super::main; use crate::process_info::setup_metric_registry; +use authz::Authorizer; use clap_blocks::{ authz::AuthzConfig, catalog_dsn::CatalogDsnConfig, object_store::make_object_store, router2::Router2Config, run_config::RunConfig, @@ -105,10 +106,8 @@ pub async fn command(config: Config) -> Result<()> { &metrics, )); let authz = config.authz_config.authorizer()?; - if let Some(authz) = &authz { - // Verify the connection to the authorizer, if configured. - authz.probe().await?; - } + // Verify the connection to the authorizer, if configured. + authz.probe().await?; let server_type = create_router2_server_type( &common_state, diff --git a/ioxd_querier/Cargo.toml b/ioxd_querier/Cargo.toml index 802e7ff47c..e5fb6ff788 100644 --- a/ioxd_querier/Cargo.toml +++ b/ioxd_querier/Cargo.toml @@ -7,6 +7,7 @@ license.workspace = true [dependencies] # Workspace dependencies, in alphabetical order +authz = { path = "../authz" } clap_blocks = { path = "../clap_blocks" } data_types = { path = "../data_types" } datafusion_util = { path = "../datafusion_util"} diff --git a/ioxd_querier/src/lib.rs b/ioxd_querier/src/lib.rs index d7ceb8d792..1d27810743 100644 --- a/ioxd_querier/src/lib.rs +++ b/ioxd_querier/src/lib.rs @@ -1,4 +1,5 @@ use async_trait::async_trait; +use authz::Authorizer; use clap_blocks::querier::{IngesterAddresses, QuerierConfig}; use datafusion_util::config::register_iox_object_store; use hyper::{Body, Request, Response}; @@ -34,6 +35,7 @@ pub struct QuerierServerType { database: Arc, server: QuerierServer, trace_collector: Option>, + authz: Option>, } impl std::fmt::Debug for QuerierServerType { @@ -47,11 +49,13 @@ impl QuerierServerType { server: QuerierServer, database: Arc, common_state: &CommonServerState, + authz: Option>, ) -> Self { Self { server, database, trace_collector: common_state.trace_collector(), + authz, } } } @@ -81,7 +85,10 @@ impl ServerType for QuerierServer let builder = setup_builder!(builder_input, self); add_service!( builder, - rpc::query::make_flight_server(Arc::clone(&self.database)) + rpc::query::make_flight_server( + Arc::clone(&self.database), + self.authz.as_ref().map(Arc::clone) + ) ); add_service!( builder, @@ -154,6 +161,7 @@ pub struct QuerierServerTypeArgs<'a> { pub ingester_addresses: IngesterAddresses, pub querier_config: QuerierConfig, pub rpc_write: bool, + pub authz: Option>, } #[derive(Debug, Error)] @@ -244,5 +252,6 @@ pub async fn create_querier_server_type( querier, database, args.common_state, + args.authz.as_ref().map(Arc::clone), ))) } diff --git a/ioxd_querier/src/rpc/query.rs b/ioxd_querier/src/rpc/query.rs index 9a55b82af3..5ac56fa493 100644 --- a/ioxd_querier/src/rpc/query.rs +++ b/ioxd_querier/src/rpc/query.rs @@ -1,3 +1,4 @@ +use authz::Authorizer; use std::sync::Arc; use arrow_flight::flight_service_server::{ @@ -6,8 +7,11 @@ use arrow_flight::flight_service_server::{ use generated_types::storage_server::{Storage, StorageServer}; use querier::QuerierDatabase; -pub fn make_flight_server(server: Arc) -> FlightServer { - service_grpc_flight::make_server(server) +pub fn make_flight_server( + server: Arc, + authz: Option>, +) -> FlightServer { + service_grpc_flight::make_server(server, authz) } pub fn make_storage_server(server: Arc) -> StorageServer { diff --git a/router/src/server/http.rs b/router/src/server/http.rs index 9ffdfe9b83..4b25f3fa06 100644 --- a/router/src/server/http.rs +++ b/router/src/server/http.rs @@ -145,6 +145,7 @@ impl From for Error { fn from(value: authz::Error) -> Self { match value { authz::Error::Forbidden => Self::Forbidden, + authz::Error::NoToken => Self::Unauthenticated, e => Self::Authorizer(e), } } @@ -417,36 +418,26 @@ where let namespace = org_and_bucket_to_namespace(&write_info.org, &write_info.bucket) .map_err(OrgBucketError::MappingFail)?; - if let Some(authz) = &self.authz { - let token = req - .extensions() - .get::() - .and_then(|v| v.as_ref()) - .and_then(|v| { - let s = v.as_ref(); - if s.len() < b"Token ".len() { - None - } else { - match s.split_at(b"Token ".len()) { - (b"Token ", token) => Some(token), - _ => None, - } + let token = req + .extensions() + .get::() + .and_then(|v| v.as_ref()) + .and_then(|v| { + let s = v.as_ref(); + if s.len() < b"Token ".len() { + None + } else { + match s.split_at(b"Token ".len()) { + (b"Token ", token) => Some(token), + _ => None, } - }) - .ok_or(Error::Unauthenticated)?; - - let perms = [Permission::ResourceAction( - Resource::Namespace(namespace.to_string()), - Action::Write, - )]; - authz - .require_any_permission(token, &perms) - .await - .map_err(|e| match e { - authz::Error::Forbidden => Error::Forbidden, - e => e.into(), - })?; - } + } + }); + let perms = [Permission::ResourceAction( + Resource::Namespace(namespace.to_string()), + Action::Write, + )]; + self.authz.require_any_permission(token, &perms).await?; trace!( org=%write_info.org, @@ -1409,13 +1400,14 @@ mod tests { impl Authorizer for MockAuthorizer { async fn permissions( &self, - token: &[u8], + token: Option<&[u8]>, perms: &[Permission], ) -> Result, authz::Error> { match token { - b"GOOD" => Ok(perms.to_vec()), - b"UGLY" => Err(authz::Error::verification("test", "test error")), - _ => Ok(vec![]), + Some(b"GOOD") => Ok(perms.to_vec()), + Some(b"UGLY") => Err(authz::Error::verification("test", "test error")), + Some(_) => Ok(vec![]), + None => Err(authz::Error::NoToken), } } } diff --git a/service_grpc_flight/Cargo.toml b/service_grpc_flight/Cargo.toml index b030549151..e9428cb84c 100644 --- a/service_grpc_flight/Cargo.toml +++ b/service_grpc_flight/Cargo.toml @@ -7,6 +7,7 @@ license.workspace = true [dependencies] # Workspace dependencies, in alphabetical order +authz = { path = "../authz" } arrow_util = { path = "../arrow_util" } data_types = { path = "../data_types" } datafusion = { workspace = true } @@ -22,6 +23,7 @@ tracker = { path = "../tracker" } # Crates.io dependencies, in alphabetical order arrow = { workspace = true, features = ["prettyprint"] } arrow-flight = { workspace = true, features=["flight-sql-experimental"] } +async-trait = "0.1" bytes = "1.4" futures = "0.3" prost = "0.11" diff --git a/service_grpc_flight/src/lib.rs b/service_grpc_flight/src/lib.rs index f69e814310..819bbff5d5 100644 --- a/service_grpc_flight/src/lib.rs +++ b/service_grpc_flight/src/lib.rs @@ -15,6 +15,7 @@ use arrow_flight::{ HandshakeRequest, HandshakeResponse, PutResult, SchemaAsIpc, SchemaResult, Ticket, }; use arrow_util::flight::prepare_schema_for_flight; +use authz::Authorizer; use bytes::Bytes; use data_types::NamespaceNameError; use datafusion::{error::DataFusionError, physical_plan::ExecutionPlan}; @@ -95,6 +96,15 @@ pub enum Error { #[snafu(display("Unsupported message type: {}", description))] UnsupportedMessageType { description: String }, + + #[snafu(display("Unauthenticated"))] + Unauthenticated, + + #[snafu(display("Permission denied"))] + PermissionDenied, + + #[snafu(display("Authz error: {}", source))] + Authz { source: authz::Error }, } pub type Result = std::result::Result; @@ -109,17 +119,20 @@ impl From for tonic::Status { Error::NamespaceNotFound { .. } | Error::InvalidTicket { .. } | Error::InvalidHandshake { .. } + | Error::Unauthenticated { .. } + | Error::PermissionDenied { .. } // TODO(edd): this should be `debug`. Keeping at info while IOx in early development | Error::InvalidNamespaceName { .. } => info!(e=%err, msg), Error::Query { .. } => info!(e=%err, msg), Error::Optimize { .. } - |Error::NoFlightSQLNamespace - |Error::InvalidNamespaceHeader { .. } + | Error::NoFlightSQLNamespace + | Error::InvalidNamespaceHeader { .. } | Error::Planning { .. } | Error::Deserialization { .. } | Error::InternalCreatingTicket { .. } - | Error::UnsupportedMessageType { .. } - | Error::FlightSQL { .. } + | Error::UnsupportedMessageType { .. } + | Error::FlightSQL { .. } + | Error::Authz { .. } => { warn!(e=%err, msg) } @@ -146,7 +159,7 @@ impl Error { datafusion_error_to_tonic_code(&source) } Self::UnsupportedMessageType { .. } => tonic::Code::Unimplemented, - Error::FlightSQL { source } => match source { + Self::FlightSQL { source } => match source { flightsql::Error::InvalidHandle { .. } | flightsql::Error::Decode { .. } | flightsql::Error::Protocol { .. } @@ -159,7 +172,11 @@ impl Error { } flightsql::Error::DataFusion { source } => datafusion_error_to_tonic_code(&source), }, - Self::InternalCreatingTicket { .. } | Self::Optimize { .. } => tonic::Code::Internal, + Self::InternalCreatingTicket { .. } | Self::Optimize { .. } | Self::Authz { .. } => { + tonic::Code::Internal + } + Self::Unauthenticated => tonic::Code::Unauthenticated, + Self::PermissionDenied => tonic::Code::PermissionDenied, }; tonic::Status::new(code, msg) @@ -178,6 +195,16 @@ impl From for Error { } } +impl From for Error { + fn from(source: authz::Error) -> Self { + match source { + authz::Error::Forbidden => Self::PermissionDenied, + authz::Error::NoToken => Self::Unauthenticated, + source => Self::Authz { source }, + } + } +} + type TonicStream = Pin> + Send + 'static>>; /// Concrete implementation of the IOx client protocol, implemented as @@ -331,13 +358,17 @@ where S: QueryNamespaceProvider, { server: Arc, + authz: Option>, } -pub fn make_server(server: Arc) -> FlightServer +pub fn make_server( + server: Arc, + authz: Option>, +) -> FlightServer where S: QueryNamespaceProvider, { - FlightServer::new(FlightService { server }) + FlightServer::new(FlightService { server, authz }) } impl FlightService @@ -424,6 +455,7 @@ where let external_span_ctx: Option = request.extensions().get().cloned(); let trace = external_span_ctx.format_jaeger(); let span_ctx: Option = request.extensions().get().cloned(); + let authz_token = get_flight_authz(request.metadata()); let ticket = request.into_inner(); // attempt to decode ticket @@ -437,6 +469,18 @@ where let namespace_name = request.namespace_name(); let query = request.query(); + let perms = match query { + RunQuery::FlightSQL(cmd) => flightsql_permissions(namespace_name, cmd), + RunQuery::Sql(_) | RunQuery::InfluxQL(_) => vec![authz::Permission::ResourceAction( + authz::Resource::Namespace(namespace_name.to_string()), + authz::Action::Read, + )], + }; + self.authz + .require_any_permission(authz_token.as_deref(), &perms) + .await + .map_err(Error::from)?; + let permit = self .server .acquire_semaphore(span_ctx.child_span("query rate limit semaphore")) @@ -506,13 +550,19 @@ where let trace = external_span_ctx.format_jaeger(); let namespace_name = get_flightsql_namespace(request.metadata())?; + let authz_token = get_flight_authz(request.metadata()); let flight_descriptor = request.into_inner(); // extract the FlightSQL message let cmd = cmd_from_descriptor(flight_descriptor.clone())?; - info!(%namespace_name, %cmd, %trace, "GetFlightInfo request"); + let perms = flightsql_permissions(&namespace_name, &cmd); + self.authz + .require_any_permission(authz_token.as_deref(), &perms) + .await + .map_err(Error::from)?; + let db = self .server .db(&namespace_name, span_ctx.child_span("get namespace")) @@ -585,6 +635,7 @@ where let trace = external_span_ctx.format_jaeger(); let namespace_name = get_flightsql_namespace(request.metadata())?; + let authz_token = get_flight_authz(request.metadata()); let Action { r#type: action_type, body, @@ -595,6 +646,12 @@ where info!(%namespace_name, %action_type, %cmd, %trace, "DoAction request"); + let perms = flightsql_permissions(&namespace_name, &cmd); + self.authz + .require_any_permission(authz_token.as_deref(), &perms) + .await + .map_err(Error::from)?; + let db = self .server .db(&namespace_name, span_ctx.child_span("get namespace")) @@ -663,6 +720,34 @@ fn get_flightsql_namespace(metadata: &MetadataMap) -> Result { NoFlightSQLNamespaceSnafu.fail() } +/// Retrieve the authorization token associated with the request. +fn get_flight_authz(metadata: &MetadataMap) -> Option> { + let val = metadata.get("authorization")?.as_ref(); + if val.len() < b"Bearer ".len() { + return None; + } + match val.split_at(b"Bearer ".len()) { + (b"Bearer ", token) => Some(token.to_vec()), + _ => None, + } +} + +fn flightsql_permissions(namespace_name: &str, cmd: &FlightSQLCommand) -> Vec { + let resource = authz::Resource::Namespace(namespace_name.to_string()); + let action = match cmd { + FlightSQLCommand::CommandStatementQuery(_) => authz::Action::Read, + FlightSQLCommand::CommandPreparedStatementQuery(_) => authz::Action::Read, + FlightSQLCommand::CommandGetSqlInfo(_) => authz::Action::ReadSchema, + FlightSQLCommand::CommandGetCatalogs(_) => authz::Action::ReadSchema, + FlightSQLCommand::CommandGetDbSchemas(_) => authz::Action::ReadSchema, + FlightSQLCommand::CommandGetTables(_) => authz::Action::ReadSchema, + FlightSQLCommand::CommandGetTableTypes(_) => authz::Action::ReadSchema, + FlightSQLCommand::ActionCreatePreparedStatementRequest(_) => authz::Action::Read, + FlightSQLCommand::ActionClosePreparedStatementRequest(_) => authz::Action::Read, + }; + vec![authz::Permission::ResourceAction(resource, action)] +} + /// Wrapper over a FlightDataEncodeStream that adds IOx specfic /// metadata and records completion struct GetStream { @@ -831,10 +916,14 @@ impl Stream for GetStream { #[cfg(test)] mod tests { + use arrow_flight::sql::ProstMessageExt; + use async_trait::async_trait; + use authz::Permission; use futures::Future; use metric::{Attributes, Metric, U64Gauge}; use service_common::test_util::TestDatabaseStore; use tokio::pin; + use tonic::metadata::{MetadataKey, MetadataValue}; use super::*; @@ -864,6 +953,7 @@ mod tests { let service = FlightService { server: Arc::clone(&test_storage), + authz: Option::>::None, }; let ticket = Ticket { ticket: br#"{"namespace_name": "my_db", "sql_query": "SELECT 1;"}"# @@ -995,4 +1085,171 @@ mod tests { .fetch(); assert_eq!(actual, expected); } + + #[derive(Debug)] + struct MockAuthorizer {} + + #[async_trait] + impl Authorizer for MockAuthorizer { + async fn permissions( + &self, + token: Option<&[u8]>, + perms: &[Permission], + ) -> Result, authz::Error> { + match token { + Some(b"GOOD") => Ok(perms.to_vec()), + Some(b"BAD") => Ok(vec![]), + Some(b"UGLY") => Err(authz::Error::verification("test", "test error")), + Some(_) => panic!("unexpected token"), + None => Err(authz::Error::NoToken), + } + } + } + + #[tokio::test] + async fn do_get_authz() { + let test_storage = Arc::new(TestDatabaseStore::default()); + test_storage.clone().db_or_create("bananas").await; + + let svc = FlightService { + server: Arc::clone(&test_storage), + authz: Some(Arc::new(MockAuthorizer {})), + }; + + async fn assert_code( + svc: &FlightService, + want: tonic::Code, + request: tonic::Request, + ) { + let got = match svc.do_get(request).await { + Ok(_) => tonic::Code::Ok, + Err(e) => e.code(), + }; + assert_eq!(want, got); + } + + fn request( + query: RunQuery, + authorization: &'static str, + ) -> tonic::Request { + let mut req = tonic::Request::new( + IoxGetRequest::new("bananas".to_string(), query) + .try_encode() + .unwrap(), + ); + if !authorization.is_empty() { + req.metadata_mut().insert( + MetadataKey::from_static("authorization"), + MetadataValue::from_static(authorization), + ); + } + req + } + + fn sql_request(authorization: &'static str) -> tonic::Request { + request(RunQuery::Sql("SELECT 1".to_string()), authorization) + } + + fn influxql_request(authorization: &'static str) -> tonic::Request { + request( + RunQuery::InfluxQL("SHOW DATABASES".to_string()), + authorization, + ) + } + + fn flightsql_request(authorization: &'static str) -> tonic::Request { + request( + RunQuery::FlightSQL(FlightSQLCommand::CommandGetCatalogs( + arrow_flight::sql::CommandGetCatalogs {}, + )), + authorization, + ) + } + + assert_code(&svc, tonic::Code::Unauthenticated, sql_request("")).await; + assert_code(&svc, tonic::Code::Ok, sql_request("Bearer GOOD")).await; + assert_code( + &svc, + tonic::Code::PermissionDenied, + sql_request("Bearer BAD"), + ) + .await; + assert_code(&svc, tonic::Code::Internal, sql_request("Bearer UGLY")).await; + + assert_code(&svc, tonic::Code::Unauthenticated, influxql_request("")).await; + + assert_code( + &svc, + tonic::Code::InvalidArgument, // SHOW DATABASE has not been implemented yet. + influxql_request("Bearer GOOD"), + ) + .await; + assert_code( + &svc, + tonic::Code::PermissionDenied, + influxql_request("Bearer BAD"), + ) + .await; + assert_code(&svc, tonic::Code::Internal, influxql_request("Bearer UGLY")).await; + + assert_code(&svc, tonic::Code::Unauthenticated, flightsql_request("")).await; + assert_code(&svc, tonic::Code::Ok, flightsql_request("Bearer GOOD")).await; + assert_code( + &svc, + tonic::Code::PermissionDenied, + flightsql_request("Bearer BAD"), + ) + .await; + assert_code( + &svc, + tonic::Code::Internal, + flightsql_request("Bearer UGLY"), + ) + .await; + } + + #[tokio::test] + async fn get_flight_info_authz() { + let test_storage = Arc::new(TestDatabaseStore::default()); + test_storage.clone().db_or_create("bananas").await; + + let svc = FlightService { + server: Arc::clone(&test_storage), + authz: Some(Arc::new(MockAuthorizer {})), + }; + + async fn assert_code( + svc: &FlightService, + want: tonic::Code, + request: tonic::Request, + ) { + let got = match svc.get_flight_info(request).await { + Ok(_) => tonic::Code::Ok, + Err(e) => e.code(), + }; + assert_eq!(want, got); + } + + fn request(authorization: &'static str) -> tonic::Request { + let cmd = arrow_flight::sql::CommandGetCatalogs {}; + let mut req = + tonic::Request::new(FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec())); + req.metadata_mut().insert( + MetadataKey::from_static("iox-namespace-name"), + MetadataValue::from_static("bananas"), + ); + if !authorization.is_empty() { + req.metadata_mut().insert( + MetadataKey::from_static("authorization"), + MetadataValue::from_static(authorization), + ); + } + req + } + + assert_code(&svc, tonic::Code::Unauthenticated, request("")).await; + assert_code(&svc, tonic::Code::Ok, request("Bearer GOOD")).await; + assert_code(&svc, tonic::Code::PermissionDenied, request("Bearer BAD")).await; + assert_code(&svc, tonic::Code::Internal, request("Bearer UGLY")).await; + } }