feat: implement `Authorizer` to authorize all HTTP requests (#24738)
* feat: add `Authorizer` impls to authz REST and gRPC This adds two new Authorizer implementations to Edge: Default and AllOrNothing, which will provide the two auth options for Edge. Both gRPC requests and HTTP REST request will be authorized by the same Authorizer implementation. The SHA512 digest action was moved into the `Authorizer` impl. * feat: add `ServerBuilder` to construct `Server A builder was added to the Server in this commit, as part of an attempt to get the server creation to be more modular. * refactor: use test server fixture in auth e2e test Refactored the `auth` integration test in `influxdb3` to use the `TestServer` fixture; part of this involved extending the fixture to be configurable, so that the `TestServer` can be spun up with an auth token. * test: add test for authorized gRPC A new end-to-end test, auth_grpc, was added to check that authorization is working with the influxdb3 Flight service.pull/24747/head
parent
fad681c06c
commit
c4d651fbd1
|
@ -2469,6 +2469,7 @@ dependencies = [
|
|||
"arrow-flight",
|
||||
"arrow_util",
|
||||
"assert_cmd",
|
||||
"authz",
|
||||
"backtrace",
|
||||
"base64 0.22.0",
|
||||
"clap",
|
||||
|
|
|
@ -7,6 +7,7 @@ license.workspace = true
|
|||
|
||||
[dependencies]
|
||||
# Core Crates
|
||||
authz.workspace = true
|
||||
clap_blocks.workspace = true
|
||||
iox_query.workspace = true
|
||||
iox_time.workspace = true
|
||||
|
|
|
@ -7,7 +7,10 @@ use clap_blocks::{
|
|||
object_store::{make_object_store, ObjectStoreConfig},
|
||||
socket_addr::SocketAddr,
|
||||
};
|
||||
use influxdb3_server::{query_executor::QueryExecutorImpl, serve, CommonServerState, Server};
|
||||
use influxdb3_server::{
|
||||
auth::AllOrNothingAuthorizer, builder::ServerBuilder, query_executor::QueryExecutorImpl, serve,
|
||||
CommonServerState,
|
||||
};
|
||||
use influxdb3_write::persister::PersisterImpl;
|
||||
use influxdb3_write::wal::WalImpl;
|
||||
use influxdb3_write::write_buffer::WriteBufferImpl;
|
||||
|
@ -51,6 +54,9 @@ pub enum Error {
|
|||
|
||||
#[error("Write buffer error: {0}")]
|
||||
WriteBuffer(#[from] influxdb3_write::write_buffer::Error),
|
||||
|
||||
#[error("invalid token: {0}")]
|
||||
InvalidToken(#[from] hex::FromHexError),
|
||||
}
|
||||
|
||||
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
||||
|
@ -238,32 +244,38 @@ pub async fn command(config: Config) -> Result<()> {
|
|||
trace_exporter,
|
||||
trace_header_parser,
|
||||
*config.http_bind_address,
|
||||
config.bearer_token,
|
||||
)?;
|
||||
let persister = PersisterImpl::new(Arc::clone(&object_store));
|
||||
let persister = Arc::new(PersisterImpl::new(Arc::clone(&object_store)));
|
||||
let wal: Option<Arc<WalImpl>> = config
|
||||
.wal_directory
|
||||
.map(|dir| WalImpl::new(dir).map(Arc::new))
|
||||
.transpose()?;
|
||||
// TODO: the next segment ID should be loaded from the persister
|
||||
let write_buffer = Arc::new(WriteBufferImpl::new(Arc::new(persister), wal).await?);
|
||||
let query_executor = QueryExecutorImpl::new(
|
||||
let write_buffer = Arc::new(WriteBufferImpl::new(Arc::clone(&persister), wal).await?);
|
||||
let query_executor = Arc::new(QueryExecutorImpl::new(
|
||||
write_buffer.catalog(),
|
||||
Arc::clone(&write_buffer),
|
||||
Arc::clone(&exec),
|
||||
Arc::clone(&metrics),
|
||||
Arc::new(config.datafusion_config),
|
||||
10,
|
||||
);
|
||||
));
|
||||
|
||||
let persister = Arc::new(PersisterImpl::new(Arc::clone(&object_store)));
|
||||
let server = Server::new(
|
||||
common_state,
|
||||
persister,
|
||||
Arc::clone(&write_buffer),
|
||||
Arc::new(query_executor),
|
||||
config.max_http_request_size,
|
||||
);
|
||||
|
||||
let builder = ServerBuilder::new(common_state)
|
||||
.max_request_size(config.max_http_request_size)
|
||||
.write_buffer(write_buffer)
|
||||
.query_executor(query_executor)
|
||||
.persister(persister);
|
||||
|
||||
let server = if let Some(token) = config.bearer_token.map(hex::decode).transpose()? {
|
||||
builder
|
||||
.authorizer(Arc::new(AllOrNothingAuthorizer::new(token)))
|
||||
.build()
|
||||
} else {
|
||||
builder.build()
|
||||
};
|
||||
serve(server, frontend_shutdown).await?;
|
||||
|
||||
Ok(())
|
||||
|
|
|
@ -1,78 +1,31 @@
|
|||
use parking_lot::Mutex;
|
||||
use arrow_flight::error::FlightError;
|
||||
use arrow_util::assert_batches_sorted_eq;
|
||||
use influxdb3_client::Precision;
|
||||
use reqwest::StatusCode;
|
||||
use std::env;
|
||||
use std::mem;
|
||||
use std::panic;
|
||||
use std::process::Child;
|
||||
use std::process::Command;
|
||||
use std::process::Stdio;
|
||||
|
||||
struct DropCommand {
|
||||
cmd: Option<Child>,
|
||||
}
|
||||
|
||||
impl DropCommand {
|
||||
const fn new(cmd: Child) -> Self {
|
||||
Self { cmd: Some(cmd) }
|
||||
}
|
||||
|
||||
fn kill(&mut self) {
|
||||
let mut cmd = self.cmd.take().unwrap();
|
||||
cmd.kill().unwrap();
|
||||
mem::drop(cmd);
|
||||
}
|
||||
}
|
||||
|
||||
static COMMAND: Mutex<Option<DropCommand>> = parking_lot::const_mutex(None);
|
||||
use crate::{collect_stream, TestServer};
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth() {
|
||||
const HASHED_TOKEN: &str = "5315f0c4714537843face80cca8c18e27ce88e31e9be7a5232dc4dc8444f27c0227a9bd64831d3ab58f652bd0262dd8558dd08870ac9e5c650972ce9e4259439";
|
||||
const TOKEN: &str = "apiv3_mp75KQAhbqv0GeQXk8MPuZ3ztaLEaR5JzS8iifk1FwuroSVyXXyrJK1c4gEr1kHkmbgzDV-j3MvQpaIMVJBAiA";
|
||||
// The binary is made before testing so we have access to it
|
||||
let bin_path = {
|
||||
let mut bin_path = env::current_exe().unwrap();
|
||||
bin_path.pop();
|
||||
bin_path.pop();
|
||||
bin_path.join("influxdb3")
|
||||
};
|
||||
let server = DropCommand::new(
|
||||
Command::new(bin_path)
|
||||
.args([
|
||||
"serve",
|
||||
"--object-store",
|
||||
"memory",
|
||||
"--bearer-token",
|
||||
HASHED_TOKEN,
|
||||
])
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.spawn()
|
||||
.expect("Was able to spawn a server"),
|
||||
);
|
||||
|
||||
*COMMAND.lock() = Some(server);
|
||||
|
||||
let current_hook = panic::take_hook();
|
||||
panic::set_hook(Box::new(move |info| {
|
||||
COMMAND.lock().take().unwrap().kill();
|
||||
current_hook(info);
|
||||
}));
|
||||
let server = TestServer::configure()
|
||||
.auth_token(HASHED_TOKEN, TOKEN)
|
||||
.spawn()
|
||||
.await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Wait for the server to come up
|
||||
while client
|
||||
.get("http://127.0.0.1:8181/health")
|
||||
.bearer_auth(TOKEN)
|
||||
.send()
|
||||
.await
|
||||
.is_err()
|
||||
{}
|
||||
let base = server.client_addr();
|
||||
let write_lp_url = format!("{base}/api/v3/write_lp");
|
||||
let write_lp_params = [("db", "foo")];
|
||||
let query_sql_url = format!("{base}/api/v3/query_sql");
|
||||
let query_sql_params = [("db", "foo"), ("q", "select * from cpu")];
|
||||
|
||||
assert_eq!(
|
||||
client
|
||||
.post("http://127.0.0.1:8181/api/v3/write_lp?db=foo")
|
||||
.post(&write_lp_url)
|
||||
.query(&write_lp_params)
|
||||
.body("cpu,host=a val=1i 123")
|
||||
.send()
|
||||
.await
|
||||
|
@ -82,7 +35,8 @@ async fn auth() {
|
|||
);
|
||||
assert_eq!(
|
||||
client
|
||||
.get("http://127.0.0.1:8181/api/v3/query_sql?db=foo&q=select+*+from+cpu")
|
||||
.get(&query_sql_url)
|
||||
.query(&query_sql_params)
|
||||
.send()
|
||||
.await
|
||||
.unwrap()
|
||||
|
@ -91,7 +45,8 @@ async fn auth() {
|
|||
);
|
||||
assert_eq!(
|
||||
client
|
||||
.post("http://127.0.0.1:8181/api/v3/write_lp?db=foo")
|
||||
.post(&write_lp_url)
|
||||
.query(&write_lp_params)
|
||||
.body("cpu,host=a val=1i 123")
|
||||
.bearer_auth(TOKEN)
|
||||
.send()
|
||||
|
@ -102,7 +57,8 @@ async fn auth() {
|
|||
);
|
||||
assert_eq!(
|
||||
client
|
||||
.get("http://127.0.0.1:8181/api/v3/query_sql?db=foo&q=select+*+from+cpu")
|
||||
.get(&query_sql_url)
|
||||
.query(&query_sql_params)
|
||||
.bearer_auth(TOKEN)
|
||||
.send()
|
||||
.await
|
||||
|
@ -114,7 +70,8 @@ async fn auth() {
|
|||
// Test that there is an extra string after the token foo
|
||||
assert_eq!(
|
||||
client
|
||||
.get("http://127.0.0.1:8181/api/v3/query_sql?db=foo&q=select+*+from+cpu")
|
||||
.get(&query_sql_url)
|
||||
.query(&query_sql_params)
|
||||
.header("Authorization", format!("Bearer {TOKEN} whee"))
|
||||
.send()
|
||||
.await
|
||||
|
@ -124,7 +81,8 @@ async fn auth() {
|
|||
);
|
||||
assert_eq!(
|
||||
client
|
||||
.get("http://127.0.0.1:8181/api/v3/query_sql?db=foo&q=select+*+from+cpu")
|
||||
.get(&query_sql_url)
|
||||
.query(&query_sql_params)
|
||||
.header("Authorization", format!("bearer {TOKEN}"))
|
||||
.send()
|
||||
.await
|
||||
|
@ -134,7 +92,8 @@ async fn auth() {
|
|||
);
|
||||
assert_eq!(
|
||||
client
|
||||
.get("http://127.0.0.1:8181/api/v3/query_sql?db=foo&q=select+*+from+cpu")
|
||||
.get(&query_sql_url)
|
||||
.query(&query_sql_params)
|
||||
.header("Authorization", "Bearer")
|
||||
.send()
|
||||
.await
|
||||
|
@ -144,7 +103,8 @@ async fn auth() {
|
|||
);
|
||||
assert_eq!(
|
||||
client
|
||||
.get("http://127.0.0.1:8181/api/v3/query_sql?db=foo&q=select+*+from+cpu")
|
||||
.get(&query_sql_url)
|
||||
.query(&query_sql_params)
|
||||
.header("auth", format!("Bearer {TOKEN}"))
|
||||
.send()
|
||||
.await
|
||||
|
@ -152,5 +112,97 @@ async fn auth() {
|
|||
.status(),
|
||||
StatusCode::UNAUTHORIZED
|
||||
);
|
||||
COMMAND.lock().take().unwrap().kill();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_grpc() {
|
||||
const HASHED_TOKEN: &str = "5315f0c4714537843face80cca8c18e27ce88e31e9be7a5232dc4dc8444f27c0227a9bd64831d3ab58f652bd0262dd8558dd08870ac9e5c650972ce9e4259439";
|
||||
const TOKEN: &str = "apiv3_mp75KQAhbqv0GeQXk8MPuZ3ztaLEaR5JzS8iifk1FwuroSVyXXyrJK1c4gEr1kHkmbgzDV-j3MvQpaIMVJBAiA";
|
||||
|
||||
let server = TestServer::configure()
|
||||
.auth_token(HASHED_TOKEN, TOKEN)
|
||||
.spawn()
|
||||
.await;
|
||||
|
||||
// Write some data to the server, this will be authorized through the HTTP API
|
||||
server
|
||||
.write_lp_to_db(
|
||||
"foo",
|
||||
"cpu,host=s1,region=us-east usage=0.9 1\n\
|
||||
cpu,host=s1,region=us-east usage=0.89 2\n\
|
||||
cpu,host=s1,region=us-east usage=0.85 3",
|
||||
Precision::Nanosecond,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Check that with a valid authorization header, it succeeds:
|
||||
for header in ["authorization", "Authorization"] {
|
||||
// Spin up a FlightSQL client
|
||||
let mut client = server.flight_sql_client("foo").await;
|
||||
|
||||
// Set the authorization header on the client:
|
||||
client
|
||||
.add_header(header, &format!("Bearer {TOKEN}"))
|
||||
.unwrap();
|
||||
|
||||
// Make the query again, this time it should work:
|
||||
let response = client.query("SELECT * FROM cpu").await.unwrap();
|
||||
let batches = collect_stream(response).await;
|
||||
assert_batches_sorted_eq!(
|
||||
[
|
||||
"+------+---------+--------------------------------+-------+",
|
||||
"| host | region | time | usage |",
|
||||
"+------+---------+--------------------------------+-------+",
|
||||
"| s1 | us-east | 1970-01-01T00:00:00.000000001Z | 0.9 |",
|
||||
"| s1 | us-east | 1970-01-01T00:00:00.000000002Z | 0.89 |",
|
||||
"| s1 | us-east | 1970-01-01T00:00:00.000000003Z | 0.85 |",
|
||||
"+------+---------+--------------------------------+-------+",
|
||||
],
|
||||
&batches
|
||||
);
|
||||
}
|
||||
|
||||
// Check that without providing an Authentication header, it gives back
|
||||
// an Unauthenticated error:
|
||||
{
|
||||
let mut client = server.flight_sql_client("foo").await;
|
||||
let error = client.query("SELECT * FROM cpu").await.unwrap_err();
|
||||
assert!(matches!(error, FlightError::Tonic(s) if s.code() == tonic::Code::Unauthenticated));
|
||||
}
|
||||
|
||||
// Create some new clients that set the authorization header incorrectly to
|
||||
// ensure errors are returned:
|
||||
|
||||
// Mispelled "Bearer"
|
||||
{
|
||||
let mut client = server.flight_sql_client("foo").await;
|
||||
client
|
||||
.add_header("authorization", &format!("bearer {TOKEN}"))
|
||||
.unwrap();
|
||||
let error = client.query("SELECT * FROM cpu").await.unwrap_err();
|
||||
assert!(matches!(error, FlightError::Tonic(s) if s.code() == tonic::Code::Unauthenticated));
|
||||
}
|
||||
|
||||
// Invalid token, this actually gives Permission denied
|
||||
{
|
||||
let mut client = server.flight_sql_client("foo").await;
|
||||
client
|
||||
.add_header("authorization", "Bearer invalid-token")
|
||||
.unwrap();
|
||||
let error = client.query("SELECT * FROM cpu").await.unwrap_err();
|
||||
assert!(
|
||||
matches!(error, FlightError::Tonic(s) if s.code() == tonic::Code::PermissionDenied)
|
||||
);
|
||||
}
|
||||
|
||||
// Mispelled header key
|
||||
{
|
||||
let mut client = server.flight_sql_client("foo").await;
|
||||
client
|
||||
.add_header("auth", &format!("Bearer {TOKEN}"))
|
||||
.unwrap();
|
||||
let error = client.query("SELECT * FROM cpu").await.unwrap_err();
|
||||
assert!(matches!(error, FlightError::Tonic(s) if s.code() == tonic::Code::Unauthenticated));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,8 +17,43 @@ mod flight;
|
|||
mod limits;
|
||||
mod query;
|
||||
|
||||
/// Configuration for a [`TestServer`]
|
||||
#[derive(Debug, Default)]
|
||||
pub struct TestConfig {
|
||||
auth_token: Option<(String, String)>,
|
||||
}
|
||||
|
||||
impl TestConfig {
|
||||
/// Set the auth token for this [`TestServer`]
|
||||
pub fn auth_token<S: Into<String>, R: Into<String>>(
|
||||
mut self,
|
||||
hashed_token: S,
|
||||
raw_token: R,
|
||||
) -> Self {
|
||||
self.auth_token = Some((hashed_token.into(), raw_token.into()));
|
||||
self
|
||||
}
|
||||
|
||||
/// Spawn a new [`TestServer`] with this configuration
|
||||
///
|
||||
/// This will run the `influxdb3 serve` command, and bind its HTTP
|
||||
/// address to a random port on localhost.
|
||||
pub async fn spawn(self) -> TestServer {
|
||||
TestServer::spawn_inner(self).await
|
||||
}
|
||||
|
||||
fn as_args(&self) -> Vec<&str> {
|
||||
let mut args = vec![];
|
||||
if let Some((token, _)) = &self.auth_token {
|
||||
args.append(&mut vec!["--bearer-token", token]);
|
||||
}
|
||||
args
|
||||
}
|
||||
}
|
||||
|
||||
/// A running instance of the `influxdb3 serve` process
|
||||
pub struct TestServer {
|
||||
config: TestConfig,
|
||||
bind_addr: SocketAddr,
|
||||
server_process: Child,
|
||||
http_client: reqwest::Client,
|
||||
|
@ -30,19 +65,29 @@ impl TestServer {
|
|||
/// This will run the `influxdb3 serve` command, and bind its HTTP
|
||||
/// address to a random port on localhost.
|
||||
pub async fn spawn() -> Self {
|
||||
Self::spawn_inner(Default::default()).await
|
||||
}
|
||||
|
||||
/// Configure a [`TestServer`] before spawning
|
||||
pub fn configure() -> TestConfig {
|
||||
TestConfig::default()
|
||||
}
|
||||
|
||||
async fn spawn_inner(config: TestConfig) -> Self {
|
||||
let bind_addr = get_local_bind_addr();
|
||||
let mut command = Command::cargo_bin("influxdb3").expect("create the influxdb3 command");
|
||||
let command = command
|
||||
.arg("serve")
|
||||
.args(["--http-bind", &bind_addr.to_string()])
|
||||
.args(["--object-store", "memory"])
|
||||
// TODO - other configuration can be passed through
|
||||
.args(config.as_args())
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null());
|
||||
|
||||
let server_process = command.spawn().expect("spawn the influxdb3 server process");
|
||||
|
||||
let server = Self {
|
||||
config,
|
||||
bind_addr,
|
||||
server_process,
|
||||
http_client: reqwest::Client::new(),
|
||||
|
@ -111,7 +156,10 @@ impl TestServer {
|
|||
lp: impl ToString,
|
||||
precision: Precision,
|
||||
) -> Result<(), influxdb3_client::Error> {
|
||||
let client = influxdb3_client::Client::new(self.client_addr()).unwrap();
|
||||
let mut client = influxdb3_client::Client::new(self.client_addr()).unwrap();
|
||||
if let Some((_, token)) = &self.config.auth_token {
|
||||
client = client.with_auth_token(token);
|
||||
}
|
||||
client
|
||||
.api_v3_write_lp(database)
|
||||
.body(lp.to_string())
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
use async_trait::async_trait;
|
||||
use authz::{Authorizer, Error, Permission};
|
||||
use observability_deps::tracing::{debug, warn};
|
||||
use sha2::{Digest, Sha512};
|
||||
|
||||
/// An [`Authorizer`] implementation that will grant access to all
|
||||
/// requests that provide `token`
|
||||
#[derive(Debug)]
|
||||
pub struct AllOrNothingAuthorizer {
|
||||
token: Vec<u8>,
|
||||
}
|
||||
|
||||
impl AllOrNothingAuthorizer {
|
||||
pub fn new(token: Vec<u8>) -> Self {
|
||||
Self { token }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Authorizer for AllOrNothingAuthorizer {
|
||||
async fn permissions(
|
||||
&self,
|
||||
token: Option<Vec<u8>>,
|
||||
perms: &[Permission],
|
||||
) -> Result<Vec<Permission>, Error> {
|
||||
debug!(?perms, "requesting permissions");
|
||||
let provided = token.as_deref().ok_or(Error::NoToken)?;
|
||||
if Sha512::digest(provided)[..] == self.token {
|
||||
warn!("invalid token provided");
|
||||
Ok(perms.to_vec())
|
||||
} else {
|
||||
Err(Error::InvalidToken)
|
||||
}
|
||||
}
|
||||
|
||||
async fn probe(&self) -> Result<(), Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// The defult [`Authorizer`] implementation that will authorize all requests
|
||||
#[derive(Debug)]
|
||||
pub struct DefaultAuthorizer;
|
||||
|
||||
#[async_trait]
|
||||
impl Authorizer for DefaultAuthorizer {
|
||||
async fn permissions(
|
||||
&self,
|
||||
_token: Option<Vec<u8>>,
|
||||
perms: &[Permission],
|
||||
) -> Result<Vec<Permission>, Error> {
|
||||
Ok(perms.to_vec())
|
||||
}
|
||||
|
||||
async fn probe(&self) -> Result<(), Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
@ -0,0 +1,112 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use authz::Authorizer;
|
||||
|
||||
use crate::{auth::DefaultAuthorizer, http::HttpApi, CommonServerState, Server};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ServerBuilder<W, Q, P> {
|
||||
common_state: CommonServerState,
|
||||
max_request_size: usize,
|
||||
write_buffer: W,
|
||||
query_executor: Q,
|
||||
persister: P,
|
||||
authorizer: Arc<dyn Authorizer>,
|
||||
}
|
||||
|
||||
impl ServerBuilder<NoWriteBuf, NoQueryExec, NoPersister> {
|
||||
pub fn new(common_state: CommonServerState) -> Self {
|
||||
Self {
|
||||
common_state,
|
||||
max_request_size: usize::MAX,
|
||||
write_buffer: NoWriteBuf,
|
||||
query_executor: NoQueryExec,
|
||||
persister: NoPersister,
|
||||
authorizer: Arc::new(DefaultAuthorizer),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<W, Q, P> ServerBuilder<W, Q, P> {
|
||||
pub fn max_request_size(mut self, max_request_size: usize) -> Self {
|
||||
self.max_request_size = max_request_size;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn authorizer(mut self, a: Arc<dyn Authorizer>) -> Self {
|
||||
self.authorizer = a;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct NoWriteBuf;
|
||||
#[derive(Debug)]
|
||||
pub struct WithWriteBuf<W>(Arc<W>);
|
||||
#[derive(Debug)]
|
||||
pub struct NoQueryExec;
|
||||
#[derive(Debug)]
|
||||
pub struct WithQueryExec<Q>(Arc<Q>);
|
||||
#[derive(Debug)]
|
||||
pub struct NoPersister;
|
||||
#[derive(Debug)]
|
||||
pub struct WithPersister<P>(Arc<P>);
|
||||
|
||||
impl<Q, P> ServerBuilder<NoWriteBuf, Q, P> {
|
||||
pub fn write_buffer<W>(self, wb: Arc<W>) -> ServerBuilder<WithWriteBuf<W>, Q, P> {
|
||||
ServerBuilder {
|
||||
common_state: self.common_state,
|
||||
max_request_size: self.max_request_size,
|
||||
write_buffer: WithWriteBuf(wb),
|
||||
query_executor: self.query_executor,
|
||||
persister: self.persister,
|
||||
authorizer: self.authorizer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<W, P> ServerBuilder<W, NoQueryExec, P> {
|
||||
pub fn query_executor<Q>(self, qe: Arc<Q>) -> ServerBuilder<W, WithQueryExec<Q>, P> {
|
||||
ServerBuilder {
|
||||
common_state: self.common_state,
|
||||
max_request_size: self.max_request_size,
|
||||
write_buffer: self.write_buffer,
|
||||
query_executor: WithQueryExec(qe),
|
||||
persister: self.persister,
|
||||
authorizer: self.authorizer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<W, Q> ServerBuilder<W, Q, NoPersister> {
|
||||
pub fn persister<P>(self, p: Arc<P>) -> ServerBuilder<W, Q, WithPersister<P>> {
|
||||
ServerBuilder {
|
||||
common_state: self.common_state,
|
||||
max_request_size: self.max_request_size,
|
||||
write_buffer: self.write_buffer,
|
||||
query_executor: self.query_executor,
|
||||
persister: WithPersister(p),
|
||||
authorizer: self.authorizer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<W, Q, P> ServerBuilder<WithWriteBuf<W>, WithQueryExec<Q>, WithPersister<P>> {
|
||||
pub fn build(self) -> Server<W, Q, P> {
|
||||
let persister = Arc::clone(&self.persister.0);
|
||||
let authorizer = Arc::clone(&self.authorizer);
|
||||
let http = Arc::new(HttpApi::new(
|
||||
self.common_state.clone(),
|
||||
Arc::clone(&self.write_buffer.0),
|
||||
Arc::clone(&self.query_executor.0),
|
||||
self.max_request_size,
|
||||
Arc::clone(&authorizer),
|
||||
));
|
||||
Server {
|
||||
common_state: self.common_state,
|
||||
http,
|
||||
persister,
|
||||
authorizer,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -4,7 +4,7 @@ use crate::{query_executor, QueryKind};
|
|||
use crate::{CommonServerState, QueryExecutor};
|
||||
use arrow::record_batch::RecordBatch;
|
||||
use arrow::util::pretty;
|
||||
use authz::http::AuthorizationHeaderExtension;
|
||||
use authz::Authorizer;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use data_types::NamespaceName;
|
||||
use datafusion::error::DataFusionError;
|
||||
|
@ -29,8 +29,6 @@ use observability_deps::tracing::{debug, error, info};
|
|||
use serde::de::DeserializeOwned;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use sha2::Digest;
|
||||
use sha2::Sha512;
|
||||
use std::convert::Infallible;
|
||||
use std::fmt::Debug;
|
||||
use std::num::NonZeroI32;
|
||||
|
@ -195,6 +193,8 @@ pub enum AuthorizationError {
|
|||
Unauthorized,
|
||||
#[error("the request was not in the form of 'Authorization: Bearer <token>'")]
|
||||
MalformedRequest,
|
||||
#[error("requestor is forbidden from requested resource")]
|
||||
Forbidden,
|
||||
#[error("to str error: {0}")]
|
||||
ToStr(#[from] hyper::header::ToStrError),
|
||||
}
|
||||
|
@ -278,6 +278,7 @@ pub(crate) struct HttpApi<W, Q> {
|
|||
write_buffer: Arc<W>,
|
||||
pub(crate) query_executor: Arc<Q>,
|
||||
max_request_bytes: usize,
|
||||
authorizer: Arc<dyn Authorizer>,
|
||||
}
|
||||
|
||||
impl<W, Q> HttpApi<W, Q> {
|
||||
|
@ -286,12 +287,14 @@ impl<W, Q> HttpApi<W, Q> {
|
|||
write_buffer: Arc<W>,
|
||||
query_executor: Arc<Q>,
|
||||
max_request_bytes: usize,
|
||||
authorizer: Arc<dyn Authorizer>,
|
||||
) -> Self {
|
||||
Self {
|
||||
common_state,
|
||||
write_buffer,
|
||||
query_executor,
|
||||
max_request_bytes,
|
||||
authorizer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -487,47 +490,57 @@ where
|
|||
Ok(decoded_data.into())
|
||||
}
|
||||
|
||||
fn authorize_request(&self, req: &mut Request<Body>) -> Result<(), AuthorizationError> {
|
||||
async fn authorize_request(&self, req: &mut Request<Body>) -> Result<(), AuthorizationError> {
|
||||
// We won't need the authorization header anymore and we don't want to accidentally log it.
|
||||
// Take it out so we can use it and not log it later by accident.
|
||||
let auth = req.headers_mut().remove(AUTHORIZATION);
|
||||
let auth = req
|
||||
.headers_mut()
|
||||
.remove(AUTHORIZATION)
|
||||
.map(validate_auth_header)
|
||||
.transpose()?;
|
||||
|
||||
if let Some(bearer_token) = self.common_state.bearer_token() {
|
||||
let Some(header) = &auth else {
|
||||
return Err(AuthorizationError::Unauthorized);
|
||||
};
|
||||
// Currently we pass an empty permissions list, but in future we may be able to derive
|
||||
// the permissions based on the incoming request
|
||||
let permissions = self.authorizer.permissions(auth, &[]).await?;
|
||||
|
||||
// Split the header value into two parts
|
||||
let mut header = header.to_str()?.split(' ');
|
||||
// Extend the request with the permissions, which may be useful in future
|
||||
req.extensions_mut().insert(permissions);
|
||||
|
||||
// Check that the header is the 'Bearer' auth scheme
|
||||
let bearer = header.next().ok_or(AuthorizationError::MalformedRequest)?;
|
||||
if bearer != "Bearer" {
|
||||
return Err(AuthorizationError::MalformedRequest);
|
||||
}
|
||||
|
||||
// Get the token that we want to hash to check the request is valid
|
||||
let token = header.next().ok_or(AuthorizationError::MalformedRequest)?;
|
||||
|
||||
// There should only be two parts the 'Bearer' scheme and the actual
|
||||
// token, error otherwise
|
||||
if header.next().is_some() {
|
||||
return Err(AuthorizationError::MalformedRequest);
|
||||
}
|
||||
|
||||
// Check that the hashed token is acceptable
|
||||
let authorized = &Sha512::digest(token)[..] == bearer_token;
|
||||
if !authorized {
|
||||
return Err(AuthorizationError::Unauthorized);
|
||||
}
|
||||
}
|
||||
|
||||
req.extensions_mut()
|
||||
.insert(AuthorizationHeaderExtension::new(auth));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_auth_header(header: HeaderValue) -> Result<Vec<u8>, AuthorizationError> {
|
||||
// Split the header value into two parts
|
||||
let mut header = header.to_str()?.split(' ');
|
||||
|
||||
// Check that the header is the 'Bearer' auth scheme
|
||||
let bearer = header.next().ok_or(AuthorizationError::MalformedRequest)?;
|
||||
if bearer != "Bearer" {
|
||||
return Err(AuthorizationError::MalformedRequest);
|
||||
}
|
||||
|
||||
// Get the token that we want to hash to check the request is valid
|
||||
let token = header.next().ok_or(AuthorizationError::MalformedRequest)?;
|
||||
|
||||
// There should only be two parts the 'Bearer' scheme and the actual
|
||||
// token, error otherwise
|
||||
if header.next().is_some() {
|
||||
return Err(AuthorizationError::MalformedRequest);
|
||||
}
|
||||
|
||||
Ok(token.as_bytes().to_vec())
|
||||
}
|
||||
|
||||
impl From<authz::Error> for AuthorizationError {
|
||||
fn from(auth_error: authz::Error) -> Self {
|
||||
match auth_error {
|
||||
authz::Error::Forbidden => Self::Forbidden,
|
||||
_ => Self::Unauthorized,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A valid name:
|
||||
/// - Starts with a letter or a number
|
||||
/// - Is ASCII not UTF-8
|
||||
|
@ -700,7 +713,7 @@ where
|
|||
Q: QueryExecutor,
|
||||
Error: From<<Q as QueryExecutor>::Error>,
|
||||
{
|
||||
if let Err(e) = http_server.authorize_request(&mut req) {
|
||||
if let Err(e) = http_server.authorize_request(&mut req).await {
|
||||
match e {
|
||||
AuthorizationError::Unauthorized => {
|
||||
return Ok(Response::builder()
|
||||
|
@ -716,6 +729,12 @@ where
|
|||
}"))
|
||||
.unwrap());
|
||||
}
|
||||
AuthorizationError::Forbidden => {
|
||||
return Ok(Response::builder()
|
||||
.status(StatusCode::FORBIDDEN)
|
||||
.body(Body::empty())
|
||||
.unwrap())
|
||||
}
|
||||
// We don't expect this to happen, but if the header is messed up
|
||||
// better to handle it then not at all
|
||||
AuthorizationError::ToStr(_) => {
|
||||
|
|
|
@ -11,6 +11,8 @@ clippy::clone_on_ref_ptr,
|
|||
clippy::future_not_send
|
||||
)]
|
||||
|
||||
pub mod auth;
|
||||
pub mod builder;
|
||||
mod grpc;
|
||||
mod http;
|
||||
pub mod query_executor;
|
||||
|
@ -20,6 +22,7 @@ use crate::grpc::make_flight_server;
|
|||
use crate::http::route_request;
|
||||
use crate::http::HttpApi;
|
||||
use async_trait::async_trait;
|
||||
use authz::Authorizer;
|
||||
use datafusion::execution::SendableRecordBatchStream;
|
||||
use hyper::service::service_fn;
|
||||
use influxdb3_write::{Persister, WriteBuffer};
|
||||
|
@ -72,7 +75,6 @@ pub struct CommonServerState {
|
|||
trace_exporter: Option<Arc<trace_exporters::export::AsyncExporter>>,
|
||||
trace_header_parser: TraceHeaderParser,
|
||||
http_addr: SocketAddr,
|
||||
bearer_token: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
impl CommonServerState {
|
||||
|
@ -81,14 +83,12 @@ impl CommonServerState {
|
|||
trace_exporter: Option<Arc<trace_exporters::export::AsyncExporter>>,
|
||||
trace_header_parser: TraceHeaderParser,
|
||||
http_addr: SocketAddr,
|
||||
bearer_token: Option<String>,
|
||||
) -> Result<Self> {
|
||||
Ok(Self {
|
||||
metrics,
|
||||
trace_exporter,
|
||||
trace_header_parser,
|
||||
http_addr,
|
||||
bearer_token: bearer_token.map(hex::decode).transpose()?,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -109,10 +109,6 @@ impl CommonServerState {
|
|||
pub fn metric_registry(&self) -> Arc<metric::Registry> {
|
||||
Arc::<metric::Registry>::clone(&self.metrics)
|
||||
}
|
||||
|
||||
pub fn bearer_token(&self) -> Option<&[u8]> {
|
||||
self.bearer_token.as_deref()
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
|
@ -121,6 +117,7 @@ pub struct Server<W, Q, P> {
|
|||
common_state: CommonServerState,
|
||||
http: Arc<HttpApi<W, Q>>,
|
||||
persister: Arc<P>,
|
||||
authorizer: Arc<dyn Authorizer>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
@ -151,30 +148,9 @@ pub enum QueryKind {
|
|||
InfluxQl,
|
||||
}
|
||||
|
||||
impl<W, Q, P> Server<W, Q, P>
|
||||
where
|
||||
Q: QueryExecutor,
|
||||
P: Persister,
|
||||
{
|
||||
pub fn new(
|
||||
common_state: CommonServerState,
|
||||
persister: Arc<P>,
|
||||
write_buffer: Arc<W>,
|
||||
query_executor: Arc<Q>,
|
||||
max_http_request_size: usize,
|
||||
) -> Self {
|
||||
let http = Arc::new(HttpApi::new(
|
||||
common_state.clone(),
|
||||
Arc::clone(&write_buffer),
|
||||
Arc::clone(&query_executor),
|
||||
max_http_request_size,
|
||||
));
|
||||
|
||||
Self {
|
||||
common_state,
|
||||
http,
|
||||
persister,
|
||||
}
|
||||
impl<W, Q, P> Server<W, Q, P> {
|
||||
pub fn authorizer(&self) -> Arc<dyn Authorizer> {
|
||||
Arc::clone(&self.authorizer)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -204,8 +180,7 @@ where
|
|||
|
||||
let grpc_service = trace_layer.clone().layer(make_flight_server(
|
||||
Arc::clone(&server.http.query_executor),
|
||||
// TODO - need to configure authz here:
|
||||
None,
|
||||
Some(server.authorizer()),
|
||||
));
|
||||
let rest_service = hyper::service::make_service_fn(|_| {
|
||||
let http_server = Arc::clone(&server.http);
|
||||
|
@ -249,6 +224,8 @@ pub async fn wait_for_signal() {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::auth::DefaultAuthorizer;
|
||||
use crate::builder::ServerBuilder;
|
||||
use crate::serve;
|
||||
use datafusion::parquet::data_type::AsBytes;
|
||||
use hyper::{body, Body, Client, Request, Response, StatusCode};
|
||||
|
@ -271,14 +248,9 @@ mod tests {
|
|||
let addr = get_free_port();
|
||||
let trace_header_parser = trace_http::ctx::TraceHeaderParser::new();
|
||||
let metrics = Arc::new(metric::Registry::new());
|
||||
let common_state = crate::CommonServerState::new(
|
||||
Arc::clone(&metrics),
|
||||
None,
|
||||
trace_header_parser,
|
||||
addr,
|
||||
None,
|
||||
)
|
||||
.unwrap();
|
||||
let common_state =
|
||||
crate::CommonServerState::new(Arc::clone(&metrics), None, trace_header_parser, addr)
|
||||
.unwrap();
|
||||
let object_store: Arc<DynObjectStore> = Arc::new(object_store::memory::InMemory::new());
|
||||
let parquet_store =
|
||||
ParquetStorage::new(Arc::clone(&object_store), StorageId::from("influxdb3"));
|
||||
|
@ -293,33 +265,31 @@ mod tests {
|
|||
metric_registry: Arc::clone(&metrics),
|
||||
mem_pool_size: usize::MAX,
|
||||
}));
|
||||
let persister = PersisterImpl::new(Arc::clone(&object_store));
|
||||
let persister = Arc::new(PersisterImpl::new(Arc::clone(&object_store)));
|
||||
|
||||
let write_buffer = Arc::new(
|
||||
influxdb3_write::write_buffer::WriteBufferImpl::new(
|
||||
Arc::new(persister),
|
||||
Arc::clone(&persister),
|
||||
None::<Arc<influxdb3_write::wal::WalImpl>>,
|
||||
)
|
||||
.await
|
||||
.unwrap(),
|
||||
);
|
||||
let query_executor = crate::query_executor::QueryExecutorImpl::new(
|
||||
let query_executor = Arc::new(crate::query_executor::QueryExecutorImpl::new(
|
||||
write_buffer.catalog(),
|
||||
Arc::clone(&write_buffer),
|
||||
Arc::clone(&exec),
|
||||
Arc::clone(&metrics),
|
||||
Arc::new(HashMap::new()),
|
||||
10,
|
||||
);
|
||||
let persister = Arc::new(PersisterImpl::new(Arc::clone(&object_store)));
|
||||
));
|
||||
|
||||
let server = crate::Server::new(
|
||||
common_state,
|
||||
persister,
|
||||
Arc::clone(&write_buffer),
|
||||
Arc::new(query_executor),
|
||||
usize::MAX,
|
||||
);
|
||||
let server = ServerBuilder::new(common_state)
|
||||
.write_buffer(Arc::clone(&write_buffer))
|
||||
.query_executor(Arc::clone(&query_executor))
|
||||
.persister(Arc::clone(&persister))
|
||||
.authorizer(Arc::new(DefaultAuthorizer))
|
||||
.build();
|
||||
let frontend_shutdown = CancellationToken::new();
|
||||
let shutdown = frontend_shutdown.clone();
|
||||
|
||||
|
@ -410,14 +380,9 @@ mod tests {
|
|||
let addr = get_free_port();
|
||||
let trace_header_parser = trace_http::ctx::TraceHeaderParser::new();
|
||||
let metrics = Arc::new(metric::Registry::new());
|
||||
let common_state = crate::CommonServerState::new(
|
||||
Arc::clone(&metrics),
|
||||
None,
|
||||
trace_header_parser,
|
||||
addr,
|
||||
None,
|
||||
)
|
||||
.unwrap();
|
||||
let common_state =
|
||||
crate::CommonServerState::new(Arc::clone(&metrics), None, trace_header_parser, addr)
|
||||
.unwrap();
|
||||
let object_store: Arc<DynObjectStore> = Arc::new(object_store::memory::InMemory::new());
|
||||
let parquet_store =
|
||||
ParquetStorage::new(Arc::clone(&object_store), StorageId::from("influxdb3"));
|
||||
|
@ -451,13 +416,12 @@ mod tests {
|
|||
10,
|
||||
);
|
||||
|
||||
let server = crate::Server::new(
|
||||
common_state,
|
||||
persister,
|
||||
Arc::clone(&write_buffer),
|
||||
Arc::new(query_executor),
|
||||
usize::MAX,
|
||||
);
|
||||
let server = ServerBuilder::new(common_state)
|
||||
.write_buffer(Arc::clone(&write_buffer))
|
||||
.query_executor(Arc::new(query_executor))
|
||||
.persister(persister)
|
||||
.authorizer(Arc::new(DefaultAuthorizer))
|
||||
.build();
|
||||
let frontend_shutdown = CancellationToken::new();
|
||||
let shutdown = frontend_shutdown.clone();
|
||||
|
||||
|
@ -584,14 +548,9 @@ mod tests {
|
|||
let addr = get_free_port();
|
||||
let trace_header_parser = trace_http::ctx::TraceHeaderParser::new();
|
||||
let metrics = Arc::new(metric::Registry::new());
|
||||
let common_state = crate::CommonServerState::new(
|
||||
Arc::clone(&metrics),
|
||||
None,
|
||||
trace_header_parser,
|
||||
addr,
|
||||
None,
|
||||
)
|
||||
.unwrap();
|
||||
let common_state =
|
||||
crate::CommonServerState::new(Arc::clone(&metrics), None, trace_header_parser, addr)
|
||||
.unwrap();
|
||||
let object_store: Arc<DynObjectStore> = Arc::new(object_store::memory::InMemory::new());
|
||||
let parquet_store =
|
||||
ParquetStorage::new(Arc::clone(&object_store), StorageId::from("influxdb3"));
|
||||
|
@ -625,13 +584,12 @@ mod tests {
|
|||
10,
|
||||
);
|
||||
|
||||
let server = crate::Server::new(
|
||||
common_state,
|
||||
persister,
|
||||
Arc::clone(&write_buffer),
|
||||
Arc::new(query_executor),
|
||||
usize::MAX,
|
||||
);
|
||||
let server = ServerBuilder::new(common_state)
|
||||
.write_buffer(Arc::clone(&write_buffer))
|
||||
.query_executor(Arc::new(query_executor))
|
||||
.persister(persister)
|
||||
.authorizer(Arc::new(DefaultAuthorizer))
|
||||
.build();
|
||||
let frontend_shutdown = CancellationToken::new();
|
||||
let shutdown = frontend_shutdown.clone();
|
||||
|
||||
|
|
Loading…
Reference in New Issue