fix: unambiguous bucket/org to DB mappings

Previously the $ORG and $BUCKET were joined as:

	$ORG + "_" + $BUCKET

Which is fine unless either $ORG or $BUCKET includes a "_", such as:

	$ORG = "org_a"
	$BUCKET = "bucket"

	and

	$ORG = "org"
	$BUCKET = "a_bucket"

This change continues to join $ORG and $BUCKET with an underscore, but
disallows underscores in either $ORG or $BUCKET. It appears these values
are non-zero u64s in the gRPC protocol converted to their base-10 string
representations for the DB name, so this seems safe to enforce.

In addition, this change introduces a `DatabaseName` type to avoid
passing bare strings around, and allow consuming code to ensure only
valid database names are provided at compile time. This type works with
both owned & borrowed content so doesn't force a string copy where we
can avoid it, and derefs to `str` to make it easier to use with existing
code.

I've been minimally invasive in pushing the `DatabaseName` through the
existing code and figured I'd see what the sentiment is first.
Candidates for conversion from `str` to `DatabaseName` that seem to make
sense to me include:

	- `DatabaseStore` trait
	- `RemoteServer` trait
	- Others? Basically anywhere other than the "edge" API inputs

Fixes #436 (thanks @zeebo)
pull/24376/head
Dom 2020-12-02 17:12:39 +00:00
parent 8c0e14e039
commit f90a95fd80
7 changed files with 275 additions and 93 deletions

View File

@ -0,0 +1,147 @@
use snafu::Snafu;
use std::{borrow::Cow, ops::Range};
/// Length constraints for a database name.
///
/// A `Range` is half-open, so this covers [1, 64) — i.e. a valid name is
/// between 1 and 63 characters long, inclusive.
const LENGTH_CONSTRAINT: Range<usize> = 1..64;
/// Database name validation errors.
#[derive(Debug, Snafu)]
pub enum DatabaseNameError {
    /// The name's length falls outside of [`LENGTH_CONSTRAINT`].
    ///
    /// `LENGTH_CONSTRAINT` is a half-open range (its `end` is exclusive), so
    /// the largest permitted length is `LENGTH_CONSTRAINT.end - 1`. The
    /// message below prints the inclusive upper bound to avoid telling the
    /// user a 64-character name is acceptable when it is not.
    #[snafu(display(
        "Database name {} length must be between {} and {} characters",
        name,
        LENGTH_CONSTRAINT.start,
        LENGTH_CONSTRAINT.end - 1
    ))]
    LengthConstraint {
        // The rejected name.
        name: String,
    },

    /// The name contains a character outside the allowed set.
    ///
    /// NOTE: keep this message in sync with the character check in
    /// `DatabaseName::new`.
    #[snafu(display(
        "Database name {} contains invalid characters (allowed: alphanumeric, _ and -)",
        name
    ))]
    BadChars {
        // The rejected name.
        name: String,
    },
}
/// A correctly formed database name.
///
/// Using this wrapper type allows the consuming code to enforce the invariant
/// that only valid names are provided.
///
/// This type derefs to a `str` and therefore can be used in place of anything
/// that is expecting a `str`:
///
/// ```rust
/// # use data_types::DatabaseName;
/// fn print_database(s: &str) {
///     println!("database name: {}", s);
/// }
///
/// let db = DatabaseName::new("data").unwrap();
/// print_database(&db);
/// ```
///
/// But this is not reciprocal - functions that wish to accept only
/// pre-validated names can use `DatabaseName` as a parameter.
//
// NOTE(review): deriving `serde::Deserialize` builds a `DatabaseName`
// without running the validation in `new()` — presumably deserialized
// input is already trusted; confirm, or consider routing deserialization
// through `TryFrom` (e.g. `#[serde(try_from = "String")]`).
//
// The inner `Cow` lets the type hold either borrowed or owned content,
// avoiding a copy when the caller already has a `&str` of sufficient
// lifetime.
#[derive(
    Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize,
)]
pub struct DatabaseName<'a>(Cow<'a, str>);
impl<'a> DatabaseName<'a> {
    /// Validate `name` and wrap it as a [`DatabaseName`].
    ///
    /// # Errors
    ///
    /// Returns [`DatabaseNameError::LengthConstraint`] if the length falls
    /// outside [`LENGTH_CONSTRAINT`], or [`DatabaseNameError::BadChars`] if
    /// the name contains anything other than alphanumerics, `_` or `-`.
    pub fn new<T: Into<Cow<'a, str>>>(name: T) -> Result<Self, DatabaseNameError> {
        let name: Cow<'a, str> = name.into();

        if !LENGTH_CONSTRAINT.contains(&name.len()) {
            return Err(DatabaseNameError::LengthConstraint {
                name: name.to_string(),
            });
        }

        // Only alphanumerics, underscores and hyphens are permitted.
        //
        // NOTE: If changing these characters, please update the error message
        // in `DatabaseNameError::BadChars`.
        let is_valid = |c: char| c.is_alphanumeric() || c == '_' || c == '-';

        match name.chars().all(is_valid) {
            true => Ok(Self(name)),
            false => Err(DatabaseNameError::BadChars {
                name: name.to_string(),
            }),
        }
    }
}
impl<'a> std::convert::TryFrom<&'a str> for DatabaseName<'a> {
type Error = DatabaseNameError;
fn try_from(v: &'a str) -> Result<Self, Self::Error> {
Self::new(v)
}
}
impl<'a> std::convert::TryFrom<String> for DatabaseName<'a> {
type Error = DatabaseNameError;
fn try_from(v: String) -> Result<Self, Self::Error> {
Self::new(v)
}
}
impl<'a> std::ops::Deref for DatabaseName<'a> {
    type Target = str;

    /// Expose the inner name so a `DatabaseName` can be passed anywhere a
    /// `&str` is accepted.
    fn deref(&self) -> &Self::Target {
        // Deref-coerces the inner `Cow<'a, str>` to `&str`.
        &self.0
    }
}
impl<'a> std::fmt::Display for DatabaseName<'a> {
    /// Formats exactly as the underlying string, forwarding the formatter so
    /// width/padding flags behave the same as for a plain `str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::TryFrom;

    #[test]
    fn test_deref() {
        // A valid name derefs to the underlying &str unchanged.
        let db = DatabaseName::new("my_example_name").unwrap();
        assert_eq!(&*db, "my_example_name");
    }

    #[test]
    fn test_too_short() {
        // The empty string is below the minimum length of 1.
        let got = DatabaseName::try_from(String::new()).unwrap_err();
        assert!(matches!(
            got,
            DatabaseNameError::LengthConstraint { name: _n }
        ));
    }

    #[test]
    fn test_too_long() {
        // Anything past the upper length bound is rejected.
        let name = "my_example_name_that_is_quite_a_bit_longer_than_allowed_even_though_database_names_can_be_quite_long_bananas";
        let got = DatabaseName::try_from(name.to_string()).unwrap_err();
        assert!(matches!(
            got,
            DatabaseNameError::LengthConstraint { name: _n }
        ));
    }

    #[test]
    fn test_bad_chars() {
        // '!' is outside the alphanumeric/_/- character set.
        let got = DatabaseName::new("example!").unwrap_err();
        assert!(matches!(got, DatabaseNameError::BadChars { name: _n }));
    }
}

View File

@ -16,3 +16,6 @@ pub mod database_rules;
pub mod error;
pub mod partition_metadata;
pub mod table_schema;
mod database_name;
pub use database_name::*;

View File

@ -129,14 +129,6 @@ pub trait DatabaseStore: Debug + Send + Sync {
async fn db_or_create(&self, name: &str) -> Result<Arc<Self::Database>, Self::Error>;
}
/// Compatibility: return the database name to use for the specified
/// org and bucket name.
///
/// TODO move to somewhere else / change the traits to take the database name directly
pub fn org_and_bucket_to_database(org: impl Into<String>, bucket: &str) -> String {
org.into() + "_" + bucket
}
// Note: I would like to compile this module only in the 'test' cfg,
// but when I do so then other modules can not find them. For example:
//

View File

@ -13,6 +13,7 @@ use arrow_deps::arrow::record_batch::RecordBatch;
use data_types::{
data::{lines_to_replicated_write, ReplicatedWrite},
database_rules::{DatabaseRules, HostGroup, HostGroupId, MatchTables},
{DatabaseName, DatabaseNameError},
};
use influxdb_line_protocol::ParsedLine;
use object_store::ObjectStore;
@ -23,7 +24,7 @@ use async_trait::async_trait;
use bytes::Bytes;
use futures::stream::TryStreamExt;
use serde::{Deserialize, Serialize};
use snafu::{ensure, OptionExt, ResultExt, Snafu};
use snafu::{OptionExt, ResultExt, Snafu};
type DatabaseError = Box<dyn std::error::Error + Send + Sync + 'static>;
@ -33,11 +34,8 @@ pub enum Error {
ServerError { source: std::io::Error },
#[snafu(display("database not found: {}", db_name))]
DatabaseNotFound { db_name: String },
#[snafu(display(
"invalid database name {} can only contain alphanumeric, _ and - characters",
db_name
))]
InvalidDatabaseName { db_name: String },
#[snafu(display("invalid database: {}", source))]
InvalidDatabaseName { source: DatabaseNameError },
#[snafu(display("database error: {}", source))]
UnknownDatabaseError { source: DatabaseError },
#[snafu(display("no local buffer for database: {}", db))]
@ -79,7 +77,7 @@ pub struct Server<M: ConnectionManager> {
struct Config {
// id is optional because this may not be set on startup. It might be set via an API call
id: Option<u32>,
databases: BTreeMap<String, Db>,
databases: BTreeMap<DatabaseName<'static>, Db>,
host_groups: BTreeMap<HostGroupId, HostGroup>,
}
@ -111,16 +109,10 @@ impl<M: ConnectionManager> Server<M> {
// Return an error if this server hasn't yet been setup with an id
self.require_id()?;
let db_name = db_name.into();
ensure!(
db_name
.chars()
.all(|c| c.is_alphanumeric() || c == '_' || c == '-'),
InvalidDatabaseName { db_name }
);
let db_name = DatabaseName::new(db_name.into()).context(InvalidDatabaseName)?;
let buffer = if rules.store_locally {
Some(WriteBufferDb::new(&db_name))
Some(WriteBufferDb::new(db_name.to_string()))
} else {
None
};
@ -201,41 +193,47 @@ impl<M: ConnectionManager> Server<M> {
pub async fn write_lines(&self, db_name: &str, lines: &[ParsedLine<'_>]) -> Result<()> {
let id = self.require_id()?;
let db_name = DatabaseName::new(db_name).context(InvalidDatabaseName)?;
let db = self
.config
.databases
.get(db_name)
.context(DatabaseNotFound { db_name })?;
.get(&db_name)
.context(DatabaseNotFound {
db_name: db_name.to_string(),
})?;
let sequence = db.next_sequence();
let write = lines_to_replicated_write(id, sequence, lines, &db.rules);
self.handle_replicated_write(db_name, db, write).await?;
self.handle_replicated_write(&db_name, db, write).await?;
Ok(())
}
/// Executes a query against the local write buffer database, if one exists.
pub async fn query_local(&self, db_name: &str, query: &str) -> Result<Vec<RecordBatch>> {
let db_name = DatabaseName::new(db_name).context(InvalidDatabaseName)?;
let db = self
.config
.databases
.get(db_name)
.context(DatabaseNotFound { db_name })?;
.get(&db_name)
.context(DatabaseNotFound {
db_name: db_name.to_string(),
})?;
let buff = db
.local_store
.as_ref()
.context(NoLocalBuffer { db: db_name })?;
let buff = db.local_store.as_ref().context(NoLocalBuffer {
db: db_name.to_string(),
})?;
buff.query(query)
.await
.map_err(|e| Box::new(e) as DatabaseError)
.context(UnknownDatabaseError {})
}
pub async fn handle_replicated_write(
pub async fn handle_replicated_write<'a>(
&self,
db_name: &str,
db_name: &DatabaseName<'a>,
db: &Db,
write: ReplicatedWrite,
) -> Result<()> {
@ -268,10 +266,10 @@ impl<M: ConnectionManager> Server<M> {
// replicates to a single host in the group based on hashing rules. If that host is unavailable
// an error will be returned. The request may still succeed if enough of the other host groups
// have returned a success.
async fn replicate_to_host_group(
async fn replicate_to_host_group<'a>(
&self,
host_group_id: &str,
db_name: &str,
db_name: &DatabaseName<'a>,
write: &ReplicatedWrite,
) -> Result<()> {
let group = self
@ -420,12 +418,8 @@ mod tests {
..Default::default()
};
let got = server.create_database(name, rules).await.unwrap_err();
if let Error::InvalidDatabaseName { db_name: got } = got {
if got != name {
panic!(format!("expected name {} got {}", name, got));
}
} else {
panic!("unexpected error");
if !matches!(got, Error::InvalidDatabaseName { source: _s }) {
panic!("expected invalid name error");
}
}

View File

@ -1,2 +1,45 @@
pub mod http_routes;
pub mod rpc;
use data_types::{DatabaseName, DatabaseNameError};
use snafu::{ResultExt, Snafu};
/// Errors raised when mapping an (org, bucket) pair to an IOx database name.
#[derive(Debug, Snafu)]
pub enum OrgBucketMappingError {
    /// Either the org or the bucket contains the `_` separator, which would
    /// make the joined `org_bucket` name ambiguous.
    #[snafu(display(
        "Internal error accessing org {}, bucket {}: the '_' character is reserved",
        org,
        bucket_name,
    ))]
    InvalidBucketOrgName { org: String, bucket_name: String },

    /// The joined name failed `DatabaseName` validation (length or
    /// character set).
    #[snafu(display("Invalid database name: {}", source))]
    InvalidDatabaseName { source: DatabaseNameError },
}
/// Map an InfluxDB 2.X org & bucket into an IOx DatabaseName.
///
/// This function ensures the mapping is unambiguous by requiring both `org` and
/// `bucket` to not contain the `_` character in addition to the [`DatabaseName`]
/// validation.
///
/// # Errors
///
/// Returns [`OrgBucketMappingError::InvalidBucketOrgName`] if either input
/// contains the separator, or [`OrgBucketMappingError::InvalidDatabaseName`]
/// if the joined name fails [`DatabaseName`] validation.
pub(crate) fn org_and_bucket_to_database<'a, O: AsRef<str>, B: AsRef<str>>(
    org: O,
    bucket: B,
) -> Result<DatabaseName<'a>, OrgBucketMappingError> {
    const SEPARATOR: char = '_';

    let (org, bucket) = (org.as_ref(), bucket.as_ref());

    // Reject the separator in either input: otherwise "a_b" + "c" and
    // "a" + "b_c" would both map to the database "a_b_c".
    if org.contains(SEPARATOR) || bucket.contains(SEPARATOR) {
        return InvalidBucketOrgName {
            bucket_name: bucket.to_string(),
            org: org.to_string(),
        }
        .fail();
    }

    let db_name = format!("{}{}{}", org, SEPARATOR, bucket);

    DatabaseName::new(db_name).context(InvalidDatabaseName)
}

View File

@ -17,8 +17,9 @@ use tracing::{debug, error, info};
use arrow_deps::arrow;
use influxdb_line_protocol::parse_lines;
use query::{org_and_bucket_to_database, DatabaseStore, SQLDatabase, TSDatabase};
use query::{DatabaseStore, SQLDatabase, TSDatabase};
use super::{org_and_bucket_to_database, OrgBucketMappingError};
use bytes::{Bytes, BytesMut};
use futures::{self, StreamExt};
use hyper::{Body, Method, StatusCode};
@ -41,6 +42,8 @@ pub enum ApplicationError {
bucket_name: String,
source: Box<dyn std::error::Error + Send + Sync>,
},
#[snafu(display("Internal error mapping org & bucket: {}", source))]
BucketMappingError { source: OrgBucketMappingError },
#[snafu(display(
"Internal error writing points into org {}, bucket {}: {}",
@ -122,6 +125,7 @@ impl ApplicationError {
pub fn status_code(&self) -> StatusCode {
match self {
Self::BucketByName { .. } => StatusCode::INTERNAL_SERVER_ERROR,
Self::BucketMappingError { .. } => StatusCode::INTERNAL_SERVER_ERROR,
Self::WritingPoints { .. } => StatusCode::INTERNAL_SERVER_ERROR,
Self::Query { .. } => StatusCode::INTERNAL_SERVER_ERROR,
Self::QueryError { .. } => StatusCode::BAD_REQUEST,
@ -213,7 +217,8 @@ async fn write<T: DatabaseStore>(
query_string: String::from(query),
})?;
let db_name = org_and_bucket_to_database(&write_info.org, &write_info.bucket);
let db_name = org_and_bucket_to_database(&write_info.org, &write_info.bucket)
.context(BucketMappingError)?;
let db = storage
.db_or_create(&db_name)
@ -273,7 +278,8 @@ async fn read<T: DatabaseStore>(
query_string: query,
})?;
let db_name = org_and_bucket_to_database(&read_info.org, &read_info.bucket);
let db_name = org_and_bucket_to_database(&read_info.org, &read_info.bucket)
.context(BucketMappingError)?;
let db = storage.db(&db_name).await.context(BucketNotFound {
org: read_info.org.clone(),

View File

@ -23,15 +23,16 @@ use generated_types::{node, Node};
use query::exec::fieldlist::FieldList;
use query::group_by::GroupByAndAggregate;
use crate::server::org_and_bucket_to_database;
use crate::server::rpc::expr::{self, AddRPCNode, Loggable, SpecialTagKeys};
use crate::server::rpc::input::GrpcInputs;
use data_types::DatabaseName;
use query::{
exec::{
seriesset::{Error as SeriesSetError, SeriesSetItem},
Executor as QueryExecutor,
},
org_and_bucket_to_database,
predicate::PredicateBuilder,
DatabaseStore, TSDatabase,
};
@ -785,11 +786,9 @@ impl SetRange for PredicateBuilder {
}
}
fn get_database_name(input: &impl GrpcInputs) -> Result<String, Status> {
Ok(org_and_bucket_to_database(
input.org_id()?,
&input.bucket_name()?,
))
fn get_database_name<'a>(input: &impl GrpcInputs) -> Result<DatabaseName<'a>, Status> {
org_and_bucket_to_database(input.org_id()?.to_string(), &input.bucket_name()?)
.map_err(|e| Status::internal(e.to_string()))
}
// The following code implements the business logic of the requests as
@ -802,7 +801,7 @@ fn get_database_name(input: &impl GrpcInputs) -> Result<String, Status> {
async fn measurement_name_impl<T>(
db_store: Arc<T>,
executor: Arc<QueryExecutor>,
db_name: String,
db_name: DatabaseName<'static>,
range: Option<TimestampRange>,
) -> Result<StringValuesResponse>
where
@ -813,11 +812,13 @@ where
let plan = db_store
.db(&db_name)
.await
.context(DatabaseNotFound { db_name: &db_name })?
.context(DatabaseNotFound {
db_name: db_name.to_string(),
})?
.table_names(predicate)
.await
.map_err(|e| Error::ListingTables {
db_name: db_name.clone(),
db_name: db_name.to_string(),
source: Box::new(e),
})?;
@ -825,7 +826,7 @@ where
.to_string_set(plan)
.await
.map_err(|e| Error::ListingTables {
db_name: db_name.clone(),
db_name: db_name.to_string(),
source: Box::new(e),
})?;
@ -842,7 +843,7 @@ where
async fn tag_keys_impl<T>(
db_store: Arc<T>,
executor: Arc<QueryExecutor>,
db_name: String,
db_name: DatabaseName<'static>,
measurement: Option<String>,
range: Option<TimestampRange>,
rpc_predicate: Option<Predicate>,
@ -861,16 +862,15 @@ where
})?
.build();
let db = db_store
.db(&db_name)
.await
.context(DatabaseNotFound { db_name: &db_name })?;
let db = db_store.db(&db_name).await.context(DatabaseNotFound {
db_name: db_name.to_string(),
})?;
let tag_key_plan = db
.tag_column_names(predicate)
.await
.map_err(|e| Error::ListingColumns {
db_name: db_name.clone(),
db_name: db_name.to_string(),
source: Box::new(e),
})?;
@ -879,7 +879,7 @@ where
.to_string_set(tag_key_plan)
.await
.map_err(|e| Error::ListingColumns {
db_name: db_name.clone(),
db_name: db_name.to_string(),
source: Box::new(e),
})?;
@ -896,7 +896,7 @@ where
async fn tag_values_impl<T>(
db_store: Arc<T>,
executor: Arc<QueryExecutor>,
db_name: String,
db_name: DatabaseName<'static>,
tag_name: String,
measurement: Option<String>,
range: Option<TimestampRange>,
@ -916,16 +916,15 @@ where
})?
.build();
let db = db_store
.db(&db_name)
.await
.context(DatabaseNotFound { db_name: &db_name })?;
let db = db_store.db(&db_name).await.context(DatabaseNotFound {
db_name: db_name.to_string(),
})?;
let tag_value_plan =
db.column_values(&tag_name, predicate)
.await
.map_err(|e| Error::ListingTagValues {
db_name: db_name.clone(),
db_name: db_name.to_string(),
tag_name: tag_name.clone(),
source: Box::new(e),
})?;
@ -935,7 +934,7 @@ where
.to_string_set(tag_value_plan)
.await
.map_err(|e| Error::ListingTagValues {
db_name: db_name.clone(),
db_name: db_name.to_string(),
tag_name: tag_name.clone(),
source: Box::new(e),
})?;
@ -954,11 +953,11 @@ where
}
/// Launch async tasks that send the result of executing read_filter to `tx`
async fn read_filter_impl<T>(
async fn read_filter_impl<'a, T>(
tx: mpsc::Sender<Result<ReadResponse, Status>>,
db_store: Arc<T>,
executor: Arc<QueryExecutor>,
db_name: String,
db_name: DatabaseName<'static>,
range: Option<TimestampRange>,
rpc_predicate: Option<Predicate>,
) -> Result<()>
@ -975,16 +974,15 @@ where
})?
.build();
let db = db_store
.db(&db_name)
.await
.context(DatabaseNotFound { db_name: &db_name })?;
let db = db_store.db(&db_name).await.context(DatabaseNotFound {
db_name: db_name.to_string(),
})?;
let series_plan =
db.query_series(predicate)
.await
.map_err(|e| Error::PlanningFilteringSeries {
db_name: db_name.clone(),
db_name: db_name.to_string(),
source: Box::new(e),
})?;
@ -1004,7 +1002,7 @@ where
.to_series_set(series_plan, tx_series)
.await
.map_err(|e| Error::FilteringSeries {
db_name: db_name.clone(),
db_name: db_name.to_string(),
source: Box::new(e),
})
.log_if_error("Running series set plan")
@ -1040,7 +1038,7 @@ async fn query_group_impl<T>(
tx: mpsc::Sender<Result<ReadResponse, Status>>,
db_store: Arc<T>,
executor: Arc<QueryExecutor>,
db_name: String,
db_name: DatabaseName<'static>,
range: Option<TimestampRange>,
rpc_predicate: Option<Predicate>,
gby_agg: GroupByAndAggregate,
@ -1058,16 +1056,15 @@ where
})?
.build();
let db = db_store
.db(&db_name)
.await
.context(DatabaseNotFound { db_name: &db_name })?;
let db = db_store.db(&db_name).await.context(DatabaseNotFound {
db_name: db_name.to_string(),
})?;
let grouped_series_set_plan =
db.query_groups(predicate, gby_agg)
.await
.map_err(|e| Error::PlanningFilteringSeries {
db_name: db_name.clone(),
db_name: db_name.to_string(),
source: Box::new(e),
})?;
@ -1087,7 +1084,7 @@ where
.to_series_set(grouped_series_set_plan, tx_series)
.await
.map_err(|e| Error::GroupingSeries {
db_name: db_name.clone(),
db_name: db_name.to_string(),
source: Box::new(e),
})
.log_if_error("Running Grouped SeriesSet Plan")
@ -1100,7 +1097,7 @@ where
async fn field_names_impl<T>(
db_store: Arc<T>,
executor: Arc<QueryExecutor>,
db_name: String,
db_name: DatabaseName<'static>,
measurement: Option<String>,
range: Option<TimestampRange>,
rpc_predicate: Option<Predicate>,
@ -1119,16 +1116,15 @@ where
})?
.build();
let db = db_store
.db(&db_name)
.await
.context(DatabaseNotFound { db_name: &db_name })?;
let db = db_store.db(&db_name).await.context(DatabaseNotFound {
db_name: db_name.to_string(),
})?;
let fieldlist_plan =
db.field_column_names(predicate)
.await
.map_err(|e| Error::ListingFields {
db_name: db_name.clone(),
db_name: db_name.to_string(),
source: Box::new(e),
})?;
@ -1137,7 +1133,7 @@ where
.to_fieldlist(fieldlist_plan)
.await
.map_err(|e| Error::ListingFields {
db_name: db_name.clone(),
db_name: db_name.to_string(),
source: Box::new(e),
})?;
@ -2222,7 +2218,7 @@ mod tests {
org_id: u64,
bucket_id: u64,
/// The influxdb_iox database name corresponding to `org_id` and `bucket_id`
db_name: String,
db_name: DatabaseName<'static>,
}
impl OrgAndBucket {
@ -2233,7 +2229,8 @@ mod tests {
.expect("bucket_id was valid")
.to_string();
let db_name = org_and_bucket_to_database(&org_id_str, &bucket_id_str);
let db_name = org_and_bucket_to_database(&org_id_str, &bucket_id_str)
.expect("mock database name construction failed");
Self {
org_id,