feat: implement `Authorizer` to authorize all HTTP requests (#24738)

* feat: add `Authorizer` impls to authz REST and gRPC

This adds two new Authorizer implementations to Edge: Default and
AllOrNothing, which will provide the two auth options for Edge.

Both gRPC requests and HTTP REST request will be authorized by
the same Authorizer implementation.

The SHA512 digest action was moved into the `Authorizer` impl.

* feat: add `ServerBuilder` to construct `Server

A builder was added to the Server in this commit, as part of an
attempt to get the server creation to be more modular.

* refactor: use test server fixture in auth e2e test

Refactored the `auth` integration test in `influxdb3` to use the
`TestServer` fixture; part of this involved extending the fixture
to be configurable, so that the `TestServer` can be spun up with
an auth token.

* test: add test for authorized gRPC

A new end-to-end test, auth_grpc, was added to check that
authorization is working with the influxdb3 Flight service.
pull/24747/head
Trevor Hilton 2024-03-08 14:18:17 -05:00 committed by GitHub
parent fad681c06c
commit c4d651fbd1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 465 additions and 204 deletions

1
Cargo.lock generated
View File

@ -2469,6 +2469,7 @@ dependencies = [
"arrow-flight",
"arrow_util",
"assert_cmd",
"authz",
"backtrace",
"base64 0.22.0",
"clap",

View File

@ -7,6 +7,7 @@ license.workspace = true
[dependencies]
# Core Crates
authz.workspace = true
clap_blocks.workspace = true
iox_query.workspace = true
iox_time.workspace = true

View File

@ -7,7 +7,10 @@ use clap_blocks::{
object_store::{make_object_store, ObjectStoreConfig},
socket_addr::SocketAddr,
};
use influxdb3_server::{query_executor::QueryExecutorImpl, serve, CommonServerState, Server};
use influxdb3_server::{
auth::AllOrNothingAuthorizer, builder::ServerBuilder, query_executor::QueryExecutorImpl, serve,
CommonServerState,
};
use influxdb3_write::persister::PersisterImpl;
use influxdb3_write::wal::WalImpl;
use influxdb3_write::write_buffer::WriteBufferImpl;
@ -51,6 +54,9 @@ pub enum Error {
#[error("Write buffer error: {0}")]
WriteBuffer(#[from] influxdb3_write::write_buffer::Error),
#[error("invalid token: {0}")]
InvalidToken(#[from] hex::FromHexError),
}
pub type Result<T, E = Error> = std::result::Result<T, E>;
@ -238,32 +244,38 @@ pub async fn command(config: Config) -> Result<()> {
trace_exporter,
trace_header_parser,
*config.http_bind_address,
config.bearer_token,
)?;
let persister = PersisterImpl::new(Arc::clone(&object_store));
let persister = Arc::new(PersisterImpl::new(Arc::clone(&object_store)));
let wal: Option<Arc<WalImpl>> = config
.wal_directory
.map(|dir| WalImpl::new(dir).map(Arc::new))
.transpose()?;
// TODO: the next segment ID should be loaded from the persister
let write_buffer = Arc::new(WriteBufferImpl::new(Arc::new(persister), wal).await?);
let query_executor = QueryExecutorImpl::new(
let write_buffer = Arc::new(WriteBufferImpl::new(Arc::clone(&persister), wal).await?);
let query_executor = Arc::new(QueryExecutorImpl::new(
write_buffer.catalog(),
Arc::clone(&write_buffer),
Arc::clone(&exec),
Arc::clone(&metrics),
Arc::new(config.datafusion_config),
10,
);
));
let persister = Arc::new(PersisterImpl::new(Arc::clone(&object_store)));
let server = Server::new(
common_state,
persister,
Arc::clone(&write_buffer),
Arc::new(query_executor),
config.max_http_request_size,
);
let builder = ServerBuilder::new(common_state)
.max_request_size(config.max_http_request_size)
.write_buffer(write_buffer)
.query_executor(query_executor)
.persister(persister);
let server = if let Some(token) = config.bearer_token.map(hex::decode).transpose()? {
builder
.authorizer(Arc::new(AllOrNothingAuthorizer::new(token)))
.build()
} else {
builder.build()
};
serve(server, frontend_shutdown).await?;
Ok(())

View File

@ -1,78 +1,31 @@
use parking_lot::Mutex;
use arrow_flight::error::FlightError;
use arrow_util::assert_batches_sorted_eq;
use influxdb3_client::Precision;
use reqwest::StatusCode;
use std::env;
use std::mem;
use std::panic;
use std::process::Child;
use std::process::Command;
use std::process::Stdio;
struct DropCommand {
cmd: Option<Child>,
}
impl DropCommand {
const fn new(cmd: Child) -> Self {
Self { cmd: Some(cmd) }
}
fn kill(&mut self) {
let mut cmd = self.cmd.take().unwrap();
cmd.kill().unwrap();
mem::drop(cmd);
}
}
static COMMAND: Mutex<Option<DropCommand>> = parking_lot::const_mutex(None);
use crate::{collect_stream, TestServer};
#[tokio::test]
async fn auth() {
const HASHED_TOKEN: &str = "5315f0c4714537843face80cca8c18e27ce88e31e9be7a5232dc4dc8444f27c0227a9bd64831d3ab58f652bd0262dd8558dd08870ac9e5c650972ce9e4259439";
const TOKEN: &str = "apiv3_mp75KQAhbqv0GeQXk8MPuZ3ztaLEaR5JzS8iifk1FwuroSVyXXyrJK1c4gEr1kHkmbgzDV-j3MvQpaIMVJBAiA";
// The binary is made before testing so we have access to it
let bin_path = {
let mut bin_path = env::current_exe().unwrap();
bin_path.pop();
bin_path.pop();
bin_path.join("influxdb3")
};
let server = DropCommand::new(
Command::new(bin_path)
.args([
"serve",
"--object-store",
"memory",
"--bearer-token",
HASHED_TOKEN,
])
.stdout(Stdio::null())
.stderr(Stdio::null())
.spawn()
.expect("Was able to spawn a server"),
);
*COMMAND.lock() = Some(server);
let current_hook = panic::take_hook();
panic::set_hook(Box::new(move |info| {
COMMAND.lock().take().unwrap().kill();
current_hook(info);
}));
let server = TestServer::configure()
.auth_token(HASHED_TOKEN, TOKEN)
.spawn()
.await;
let client = reqwest::Client::new();
// Wait for the server to come up
while client
.get("http://127.0.0.1:8181/health")
.bearer_auth(TOKEN)
.send()
.await
.is_err()
{}
let base = server.client_addr();
let write_lp_url = format!("{base}/api/v3/write_lp");
let write_lp_params = [("db", "foo")];
let query_sql_url = format!("{base}/api/v3/query_sql");
let query_sql_params = [("db", "foo"), ("q", "select * from cpu")];
assert_eq!(
client
.post("http://127.0.0.1:8181/api/v3/write_lp?db=foo")
.post(&write_lp_url)
.query(&write_lp_params)
.body("cpu,host=a val=1i 123")
.send()
.await
@ -82,7 +35,8 @@ async fn auth() {
);
assert_eq!(
client
.get("http://127.0.0.1:8181/api/v3/query_sql?db=foo&q=select+*+from+cpu")
.get(&query_sql_url)
.query(&query_sql_params)
.send()
.await
.unwrap()
@ -91,7 +45,8 @@ async fn auth() {
);
assert_eq!(
client
.post("http://127.0.0.1:8181/api/v3/write_lp?db=foo")
.post(&write_lp_url)
.query(&write_lp_params)
.body("cpu,host=a val=1i 123")
.bearer_auth(TOKEN)
.send()
@ -102,7 +57,8 @@ async fn auth() {
);
assert_eq!(
client
.get("http://127.0.0.1:8181/api/v3/query_sql?db=foo&q=select+*+from+cpu")
.get(&query_sql_url)
.query(&query_sql_params)
.bearer_auth(TOKEN)
.send()
.await
@ -114,7 +70,8 @@ async fn auth() {
// Test that there is an extra string after the token foo
assert_eq!(
client
.get("http://127.0.0.1:8181/api/v3/query_sql?db=foo&q=select+*+from+cpu")
.get(&query_sql_url)
.query(&query_sql_params)
.header("Authorization", format!("Bearer {TOKEN} whee"))
.send()
.await
@ -124,7 +81,8 @@ async fn auth() {
);
assert_eq!(
client
.get("http://127.0.0.1:8181/api/v3/query_sql?db=foo&q=select+*+from+cpu")
.get(&query_sql_url)
.query(&query_sql_params)
.header("Authorization", format!("bearer {TOKEN}"))
.send()
.await
@ -134,7 +92,8 @@ async fn auth() {
);
assert_eq!(
client
.get("http://127.0.0.1:8181/api/v3/query_sql?db=foo&q=select+*+from+cpu")
.get(&query_sql_url)
.query(&query_sql_params)
.header("Authorization", "Bearer")
.send()
.await
@ -144,7 +103,8 @@ async fn auth() {
);
assert_eq!(
client
.get("http://127.0.0.1:8181/api/v3/query_sql?db=foo&q=select+*+from+cpu")
.get(&query_sql_url)
.query(&query_sql_params)
.header("auth", format!("Bearer {TOKEN}"))
.send()
.await
@ -152,5 +112,97 @@ async fn auth() {
.status(),
StatusCode::UNAUTHORIZED
);
COMMAND.lock().take().unwrap().kill();
}
#[tokio::test]
async fn auth_grpc() {
const HASHED_TOKEN: &str = "5315f0c4714537843face80cca8c18e27ce88e31e9be7a5232dc4dc8444f27c0227a9bd64831d3ab58f652bd0262dd8558dd08870ac9e5c650972ce9e4259439";
const TOKEN: &str = "apiv3_mp75KQAhbqv0GeQXk8MPuZ3ztaLEaR5JzS8iifk1FwuroSVyXXyrJK1c4gEr1kHkmbgzDV-j3MvQpaIMVJBAiA";
let server = TestServer::configure()
.auth_token(HASHED_TOKEN, TOKEN)
.spawn()
.await;
// Write some data to the server, this will be authorized through the HTTP API
server
.write_lp_to_db(
"foo",
"cpu,host=s1,region=us-east usage=0.9 1\n\
cpu,host=s1,region=us-east usage=0.89 2\n\
cpu,host=s1,region=us-east usage=0.85 3",
Precision::Nanosecond,
)
.await
.unwrap();
// Check that with a valid authorization header, it succeeds:
for header in ["authorization", "Authorization"] {
// Spin up a FlightSQL client
let mut client = server.flight_sql_client("foo").await;
// Set the authorization header on the client:
client
.add_header(header, &format!("Bearer {TOKEN}"))
.unwrap();
// Make the query again, this time it should work:
let response = client.query("SELECT * FROM cpu").await.unwrap();
let batches = collect_stream(response).await;
assert_batches_sorted_eq!(
[
"+------+---------+--------------------------------+-------+",
"| host | region | time | usage |",
"+------+---------+--------------------------------+-------+",
"| s1 | us-east | 1970-01-01T00:00:00.000000001Z | 0.9 |",
"| s1 | us-east | 1970-01-01T00:00:00.000000002Z | 0.89 |",
"| s1 | us-east | 1970-01-01T00:00:00.000000003Z | 0.85 |",
"+------+---------+--------------------------------+-------+",
],
&batches
);
}
// Check that without providing an Authentication header, it gives back
// an Unauthenticated error:
{
let mut client = server.flight_sql_client("foo").await;
let error = client.query("SELECT * FROM cpu").await.unwrap_err();
assert!(matches!(error, FlightError::Tonic(s) if s.code() == tonic::Code::Unauthenticated));
}
// Create some new clients that set the authorization header incorrectly to
// ensure errors are returned:
// Mispelled "Bearer"
{
let mut client = server.flight_sql_client("foo").await;
client
.add_header("authorization", &format!("bearer {TOKEN}"))
.unwrap();
let error = client.query("SELECT * FROM cpu").await.unwrap_err();
assert!(matches!(error, FlightError::Tonic(s) if s.code() == tonic::Code::Unauthenticated));
}
// Invalid token, this actually gives Permission denied
{
let mut client = server.flight_sql_client("foo").await;
client
.add_header("authorization", "Bearer invalid-token")
.unwrap();
let error = client.query("SELECT * FROM cpu").await.unwrap_err();
assert!(
matches!(error, FlightError::Tonic(s) if s.code() == tonic::Code::PermissionDenied)
);
}
// Mispelled header key
{
let mut client = server.flight_sql_client("foo").await;
client
.add_header("auth", &format!("Bearer {TOKEN}"))
.unwrap();
let error = client.query("SELECT * FROM cpu").await.unwrap_err();
assert!(matches!(error, FlightError::Tonic(s) if s.code() == tonic::Code::Unauthenticated));
}
}

View File

@ -17,8 +17,43 @@ mod flight;
mod limits;
mod query;
/// Configuration for a [`TestServer`]
#[derive(Debug, Default)]
pub struct TestConfig {
auth_token: Option<(String, String)>,
}
impl TestConfig {
/// Set the auth token for this [`TestServer`]
pub fn auth_token<S: Into<String>, R: Into<String>>(
mut self,
hashed_token: S,
raw_token: R,
) -> Self {
self.auth_token = Some((hashed_token.into(), raw_token.into()));
self
}
/// Spawn a new [`TestServer`] with this configuration
///
/// This will run the `influxdb3 serve` command, and bind its HTTP
/// address to a random port on localhost.
pub async fn spawn(self) -> TestServer {
TestServer::spawn_inner(self).await
}
fn as_args(&self) -> Vec<&str> {
let mut args = vec![];
if let Some((token, _)) = &self.auth_token {
args.append(&mut vec!["--bearer-token", token]);
}
args
}
}
/// A running instance of the `influxdb3 serve` process
pub struct TestServer {
config: TestConfig,
bind_addr: SocketAddr,
server_process: Child,
http_client: reqwest::Client,
@ -30,19 +65,29 @@ impl TestServer {
/// This will run the `influxdb3 serve` command, and bind its HTTP
/// address to a random port on localhost.
pub async fn spawn() -> Self {
Self::spawn_inner(Default::default()).await
}
/// Configure a [`TestServer`] before spawning
pub fn configure() -> TestConfig {
TestConfig::default()
}
async fn spawn_inner(config: TestConfig) -> Self {
let bind_addr = get_local_bind_addr();
let mut command = Command::cargo_bin("influxdb3").expect("create the influxdb3 command");
let command = command
.arg("serve")
.args(["--http-bind", &bind_addr.to_string()])
.args(["--object-store", "memory"])
// TODO - other configuration can be passed through
.args(config.as_args())
.stdout(Stdio::null())
.stderr(Stdio::null());
let server_process = command.spawn().expect("spawn the influxdb3 server process");
let server = Self {
config,
bind_addr,
server_process,
http_client: reqwest::Client::new(),
@ -111,7 +156,10 @@ impl TestServer {
lp: impl ToString,
precision: Precision,
) -> Result<(), influxdb3_client::Error> {
let client = influxdb3_client::Client::new(self.client_addr()).unwrap();
let mut client = influxdb3_client::Client::new(self.client_addr()).unwrap();
if let Some((_, token)) = &self.config.auth_token {
client = client.with_auth_token(token);
}
client
.api_v3_write_lp(database)
.body(lp.to_string())

View File

@ -0,0 +1,58 @@
use async_trait::async_trait;
use authz::{Authorizer, Error, Permission};
use observability_deps::tracing::{debug, warn};
use sha2::{Digest, Sha512};
/// An [`Authorizer`] implementation that will grant access to all
/// requests that provide `token`
#[derive(Debug)]
pub struct AllOrNothingAuthorizer {
token: Vec<u8>,
}
impl AllOrNothingAuthorizer {
pub fn new(token: Vec<u8>) -> Self {
Self { token }
}
}
#[async_trait]
impl Authorizer for AllOrNothingAuthorizer {
async fn permissions(
&self,
token: Option<Vec<u8>>,
perms: &[Permission],
) -> Result<Vec<Permission>, Error> {
debug!(?perms, "requesting permissions");
let provided = token.as_deref().ok_or(Error::NoToken)?;
if Sha512::digest(provided)[..] == self.token {
warn!("invalid token provided");
Ok(perms.to_vec())
} else {
Err(Error::InvalidToken)
}
}
async fn probe(&self) -> Result<(), Error> {
Ok(())
}
}
/// The defult [`Authorizer`] implementation that will authorize all requests
#[derive(Debug)]
pub struct DefaultAuthorizer;
#[async_trait]
impl Authorizer for DefaultAuthorizer {
async fn permissions(
&self,
_token: Option<Vec<u8>>,
perms: &[Permission],
) -> Result<Vec<Permission>, Error> {
Ok(perms.to_vec())
}
async fn probe(&self) -> Result<(), Error> {
Ok(())
}
}

View File

@ -0,0 +1,112 @@
use std::sync::Arc;
use authz::Authorizer;
use crate::{auth::DefaultAuthorizer, http::HttpApi, CommonServerState, Server};
#[derive(Debug)]
pub struct ServerBuilder<W, Q, P> {
common_state: CommonServerState,
max_request_size: usize,
write_buffer: W,
query_executor: Q,
persister: P,
authorizer: Arc<dyn Authorizer>,
}
impl ServerBuilder<NoWriteBuf, NoQueryExec, NoPersister> {
pub fn new(common_state: CommonServerState) -> Self {
Self {
common_state,
max_request_size: usize::MAX,
write_buffer: NoWriteBuf,
query_executor: NoQueryExec,
persister: NoPersister,
authorizer: Arc::new(DefaultAuthorizer),
}
}
}
impl<W, Q, P> ServerBuilder<W, Q, P> {
pub fn max_request_size(mut self, max_request_size: usize) -> Self {
self.max_request_size = max_request_size;
self
}
pub fn authorizer(mut self, a: Arc<dyn Authorizer>) -> Self {
self.authorizer = a;
self
}
}
#[derive(Debug)]
pub struct NoWriteBuf;
#[derive(Debug)]
pub struct WithWriteBuf<W>(Arc<W>);
#[derive(Debug)]
pub struct NoQueryExec;
#[derive(Debug)]
pub struct WithQueryExec<Q>(Arc<Q>);
#[derive(Debug)]
pub struct NoPersister;
#[derive(Debug)]
pub struct WithPersister<P>(Arc<P>);
impl<Q, P> ServerBuilder<NoWriteBuf, Q, P> {
pub fn write_buffer<W>(self, wb: Arc<W>) -> ServerBuilder<WithWriteBuf<W>, Q, P> {
ServerBuilder {
common_state: self.common_state,
max_request_size: self.max_request_size,
write_buffer: WithWriteBuf(wb),
query_executor: self.query_executor,
persister: self.persister,
authorizer: self.authorizer,
}
}
}
impl<W, P> ServerBuilder<W, NoQueryExec, P> {
pub fn query_executor<Q>(self, qe: Arc<Q>) -> ServerBuilder<W, WithQueryExec<Q>, P> {
ServerBuilder {
common_state: self.common_state,
max_request_size: self.max_request_size,
write_buffer: self.write_buffer,
query_executor: WithQueryExec(qe),
persister: self.persister,
authorizer: self.authorizer,
}
}
}
impl<W, Q> ServerBuilder<W, Q, NoPersister> {
pub fn persister<P>(self, p: Arc<P>) -> ServerBuilder<W, Q, WithPersister<P>> {
ServerBuilder {
common_state: self.common_state,
max_request_size: self.max_request_size,
write_buffer: self.write_buffer,
query_executor: self.query_executor,
persister: WithPersister(p),
authorizer: self.authorizer,
}
}
}
impl<W, Q, P> ServerBuilder<WithWriteBuf<W>, WithQueryExec<Q>, WithPersister<P>> {
pub fn build(self) -> Server<W, Q, P> {
let persister = Arc::clone(&self.persister.0);
let authorizer = Arc::clone(&self.authorizer);
let http = Arc::new(HttpApi::new(
self.common_state.clone(),
Arc::clone(&self.write_buffer.0),
Arc::clone(&self.query_executor.0),
self.max_request_size,
Arc::clone(&authorizer),
));
Server {
common_state: self.common_state,
http,
persister,
authorizer,
}
}
}

View File

@ -4,7 +4,7 @@ use crate::{query_executor, QueryKind};
use crate::{CommonServerState, QueryExecutor};
use arrow::record_batch::RecordBatch;
use arrow::util::pretty;
use authz::http::AuthorizationHeaderExtension;
use authz::Authorizer;
use bytes::{Bytes, BytesMut};
use data_types::NamespaceName;
use datafusion::error::DataFusionError;
@ -29,8 +29,6 @@ use observability_deps::tracing::{debug, error, info};
use serde::de::DeserializeOwned;
use serde::Deserialize;
use serde::Serialize;
use sha2::Digest;
use sha2::Sha512;
use std::convert::Infallible;
use std::fmt::Debug;
use std::num::NonZeroI32;
@ -195,6 +193,8 @@ pub enum AuthorizationError {
Unauthorized,
#[error("the request was not in the form of 'Authorization: Bearer <token>'")]
MalformedRequest,
#[error("requestor is forbidden from requested resource")]
Forbidden,
#[error("to str error: {0}")]
ToStr(#[from] hyper::header::ToStrError),
}
@ -278,6 +278,7 @@ pub(crate) struct HttpApi<W, Q> {
write_buffer: Arc<W>,
pub(crate) query_executor: Arc<Q>,
max_request_bytes: usize,
authorizer: Arc<dyn Authorizer>,
}
impl<W, Q> HttpApi<W, Q> {
@ -286,12 +287,14 @@ impl<W, Q> HttpApi<W, Q> {
write_buffer: Arc<W>,
query_executor: Arc<Q>,
max_request_bytes: usize,
authorizer: Arc<dyn Authorizer>,
) -> Self {
Self {
common_state,
write_buffer,
query_executor,
max_request_bytes,
authorizer,
}
}
}
@ -487,47 +490,57 @@ where
Ok(decoded_data.into())
}
fn authorize_request(&self, req: &mut Request<Body>) -> Result<(), AuthorizationError> {
async fn authorize_request(&self, req: &mut Request<Body>) -> Result<(), AuthorizationError> {
// We won't need the authorization header anymore and we don't want to accidentally log it.
// Take it out so we can use it and not log it later by accident.
let auth = req.headers_mut().remove(AUTHORIZATION);
let auth = req
.headers_mut()
.remove(AUTHORIZATION)
.map(validate_auth_header)
.transpose()?;
if let Some(bearer_token) = self.common_state.bearer_token() {
let Some(header) = &auth else {
return Err(AuthorizationError::Unauthorized);
};
// Currently we pass an empty permissions list, but in future we may be able to derive
// the permissions based on the incoming request
let permissions = self.authorizer.permissions(auth, &[]).await?;
// Split the header value into two parts
let mut header = header.to_str()?.split(' ');
// Extend the request with the permissions, which may be useful in future
req.extensions_mut().insert(permissions);
// Check that the header is the 'Bearer' auth scheme
let bearer = header.next().ok_or(AuthorizationError::MalformedRequest)?;
if bearer != "Bearer" {
return Err(AuthorizationError::MalformedRequest);
}
// Get the token that we want to hash to check the request is valid
let token = header.next().ok_or(AuthorizationError::MalformedRequest)?;
// There should only be two parts the 'Bearer' scheme and the actual
// token, error otherwise
if header.next().is_some() {
return Err(AuthorizationError::MalformedRequest);
}
// Check that the hashed token is acceptable
let authorized = &Sha512::digest(token)[..] == bearer_token;
if !authorized {
return Err(AuthorizationError::Unauthorized);
}
}
req.extensions_mut()
.insert(AuthorizationHeaderExtension::new(auth));
Ok(())
}
}
fn validate_auth_header(header: HeaderValue) -> Result<Vec<u8>, AuthorizationError> {
// Split the header value into two parts
let mut header = header.to_str()?.split(' ');
// Check that the header is the 'Bearer' auth scheme
let bearer = header.next().ok_or(AuthorizationError::MalformedRequest)?;
if bearer != "Bearer" {
return Err(AuthorizationError::MalformedRequest);
}
// Get the token that we want to hash to check the request is valid
let token = header.next().ok_or(AuthorizationError::MalformedRequest)?;
// There should only be two parts the 'Bearer' scheme and the actual
// token, error otherwise
if header.next().is_some() {
return Err(AuthorizationError::MalformedRequest);
}
Ok(token.as_bytes().to_vec())
}
impl From<authz::Error> for AuthorizationError {
fn from(auth_error: authz::Error) -> Self {
match auth_error {
authz::Error::Forbidden => Self::Forbidden,
_ => Self::Unauthorized,
}
}
}
/// A valid name:
/// - Starts with a letter or a number
/// - Is ASCII not UTF-8
@ -700,7 +713,7 @@ where
Q: QueryExecutor,
Error: From<<Q as QueryExecutor>::Error>,
{
if let Err(e) = http_server.authorize_request(&mut req) {
if let Err(e) = http_server.authorize_request(&mut req).await {
match e {
AuthorizationError::Unauthorized => {
return Ok(Response::builder()
@ -716,6 +729,12 @@ where
}"))
.unwrap());
}
AuthorizationError::Forbidden => {
return Ok(Response::builder()
.status(StatusCode::FORBIDDEN)
.body(Body::empty())
.unwrap())
}
// We don't expect this to happen, but if the header is messed up
// better to handle it then not at all
AuthorizationError::ToStr(_) => {

View File

@ -11,6 +11,8 @@ clippy::clone_on_ref_ptr,
clippy::future_not_send
)]
pub mod auth;
pub mod builder;
mod grpc;
mod http;
pub mod query_executor;
@ -20,6 +22,7 @@ use crate::grpc::make_flight_server;
use crate::http::route_request;
use crate::http::HttpApi;
use async_trait::async_trait;
use authz::Authorizer;
use datafusion::execution::SendableRecordBatchStream;
use hyper::service::service_fn;
use influxdb3_write::{Persister, WriteBuffer};
@ -72,7 +75,6 @@ pub struct CommonServerState {
trace_exporter: Option<Arc<trace_exporters::export::AsyncExporter>>,
trace_header_parser: TraceHeaderParser,
http_addr: SocketAddr,
bearer_token: Option<Vec<u8>>,
}
impl CommonServerState {
@ -81,14 +83,12 @@ impl CommonServerState {
trace_exporter: Option<Arc<trace_exporters::export::AsyncExporter>>,
trace_header_parser: TraceHeaderParser,
http_addr: SocketAddr,
bearer_token: Option<String>,
) -> Result<Self> {
Ok(Self {
metrics,
trace_exporter,
trace_header_parser,
http_addr,
bearer_token: bearer_token.map(hex::decode).transpose()?,
})
}
@ -109,10 +109,6 @@ impl CommonServerState {
pub fn metric_registry(&self) -> Arc<metric::Registry> {
Arc::<metric::Registry>::clone(&self.metrics)
}
pub fn bearer_token(&self) -> Option<&[u8]> {
self.bearer_token.as_deref()
}
}
#[allow(dead_code)]
@ -121,6 +117,7 @@ pub struct Server<W, Q, P> {
common_state: CommonServerState,
http: Arc<HttpApi<W, Q>>,
persister: Arc<P>,
authorizer: Arc<dyn Authorizer>,
}
#[async_trait]
@ -151,30 +148,9 @@ pub enum QueryKind {
InfluxQl,
}
impl<W, Q, P> Server<W, Q, P>
where
Q: QueryExecutor,
P: Persister,
{
pub fn new(
common_state: CommonServerState,
persister: Arc<P>,
write_buffer: Arc<W>,
query_executor: Arc<Q>,
max_http_request_size: usize,
) -> Self {
let http = Arc::new(HttpApi::new(
common_state.clone(),
Arc::clone(&write_buffer),
Arc::clone(&query_executor),
max_http_request_size,
));
Self {
common_state,
http,
persister,
}
impl<W, Q, P> Server<W, Q, P> {
pub fn authorizer(&self) -> Arc<dyn Authorizer> {
Arc::clone(&self.authorizer)
}
}
@ -204,8 +180,7 @@ where
let grpc_service = trace_layer.clone().layer(make_flight_server(
Arc::clone(&server.http.query_executor),
// TODO - need to configure authz here:
None,
Some(server.authorizer()),
));
let rest_service = hyper::service::make_service_fn(|_| {
let http_server = Arc::clone(&server.http);
@ -249,6 +224,8 @@ pub async fn wait_for_signal() {
#[cfg(test)]
mod tests {
use crate::auth::DefaultAuthorizer;
use crate::builder::ServerBuilder;
use crate::serve;
use datafusion::parquet::data_type::AsBytes;
use hyper::{body, Body, Client, Request, Response, StatusCode};
@ -271,14 +248,9 @@ mod tests {
let addr = get_free_port();
let trace_header_parser = trace_http::ctx::TraceHeaderParser::new();
let metrics = Arc::new(metric::Registry::new());
let common_state = crate::CommonServerState::new(
Arc::clone(&metrics),
None,
trace_header_parser,
addr,
None,
)
.unwrap();
let common_state =
crate::CommonServerState::new(Arc::clone(&metrics), None, trace_header_parser, addr)
.unwrap();
let object_store: Arc<DynObjectStore> = Arc::new(object_store::memory::InMemory::new());
let parquet_store =
ParquetStorage::new(Arc::clone(&object_store), StorageId::from("influxdb3"));
@ -293,33 +265,31 @@ mod tests {
metric_registry: Arc::clone(&metrics),
mem_pool_size: usize::MAX,
}));
let persister = PersisterImpl::new(Arc::clone(&object_store));
let persister = Arc::new(PersisterImpl::new(Arc::clone(&object_store)));
let write_buffer = Arc::new(
influxdb3_write::write_buffer::WriteBufferImpl::new(
Arc::new(persister),
Arc::clone(&persister),
None::<Arc<influxdb3_write::wal::WalImpl>>,
)
.await
.unwrap(),
);
let query_executor = crate::query_executor::QueryExecutorImpl::new(
let query_executor = Arc::new(crate::query_executor::QueryExecutorImpl::new(
write_buffer.catalog(),
Arc::clone(&write_buffer),
Arc::clone(&exec),
Arc::clone(&metrics),
Arc::new(HashMap::new()),
10,
);
let persister = Arc::new(PersisterImpl::new(Arc::clone(&object_store)));
));
let server = crate::Server::new(
common_state,
persister,
Arc::clone(&write_buffer),
Arc::new(query_executor),
usize::MAX,
);
let server = ServerBuilder::new(common_state)
.write_buffer(Arc::clone(&write_buffer))
.query_executor(Arc::clone(&query_executor))
.persister(Arc::clone(&persister))
.authorizer(Arc::new(DefaultAuthorizer))
.build();
let frontend_shutdown = CancellationToken::new();
let shutdown = frontend_shutdown.clone();
@ -410,14 +380,9 @@ mod tests {
let addr = get_free_port();
let trace_header_parser = trace_http::ctx::TraceHeaderParser::new();
let metrics = Arc::new(metric::Registry::new());
let common_state = crate::CommonServerState::new(
Arc::clone(&metrics),
None,
trace_header_parser,
addr,
None,
)
.unwrap();
let common_state =
crate::CommonServerState::new(Arc::clone(&metrics), None, trace_header_parser, addr)
.unwrap();
let object_store: Arc<DynObjectStore> = Arc::new(object_store::memory::InMemory::new());
let parquet_store =
ParquetStorage::new(Arc::clone(&object_store), StorageId::from("influxdb3"));
@ -451,13 +416,12 @@ mod tests {
10,
);
let server = crate::Server::new(
common_state,
persister,
Arc::clone(&write_buffer),
Arc::new(query_executor),
usize::MAX,
);
let server = ServerBuilder::new(common_state)
.write_buffer(Arc::clone(&write_buffer))
.query_executor(Arc::new(query_executor))
.persister(persister)
.authorizer(Arc::new(DefaultAuthorizer))
.build();
let frontend_shutdown = CancellationToken::new();
let shutdown = frontend_shutdown.clone();
@ -584,14 +548,9 @@ mod tests {
let addr = get_free_port();
let trace_header_parser = trace_http::ctx::TraceHeaderParser::new();
let metrics = Arc::new(metric::Registry::new());
let common_state = crate::CommonServerState::new(
Arc::clone(&metrics),
None,
trace_header_parser,
addr,
None,
)
.unwrap();
let common_state =
crate::CommonServerState::new(Arc::clone(&metrics), None, trace_header_parser, addr)
.unwrap();
let object_store: Arc<DynObjectStore> = Arc::new(object_store::memory::InMemory::new());
let parquet_store =
ParquetStorage::new(Arc::clone(&object_store), StorageId::from("influxdb3"));
@ -625,13 +584,12 @@ mod tests {
10,
);
let server = crate::Server::new(
common_state,
persister,
Arc::clone(&write_buffer),
Arc::new(query_executor),
usize::MAX,
);
let server = ServerBuilder::new(common_state)
.write_buffer(Arc::clone(&write_buffer))
.query_executor(Arc::new(query_executor))
.persister(persister)
.authorizer(Arc::new(DefaultAuthorizer))
.build();
let frontend_shutdown = CancellationToken::new();
let shutdown = frontend_shutdown.clone();