Merge branch 'main' into savage/replace-dml-operation-for-ingester-rpc-write
commit
12d5b5e529
|
@ -3650,6 +3650,7 @@ dependencies = [
|
|||
"schema",
|
||||
"snafu",
|
||||
"thiserror",
|
||||
"unicode-segmentation",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
|
|
|
@ -47,14 +47,68 @@
|
|||
//! [`build_column_values()`] function can be used to obtain the set of
|
||||
//! [`TemplatePart::TagValue`] the key was constructed from.
|
||||
//!
|
||||
//! ### Value Truncation
|
||||
//!
|
||||
//! Partition key parts are limited to, at most, 200 bytes in length
|
||||
//! ([`PARTITION_KEY_MAX_PART_LEN`]). If any single partition key part exceeds
|
||||
//! this length limit, it is truncated and the truncation marker `#`
|
||||
//! ([`PARTITION_KEY_PART_TRUNCATED`]) is appended.
|
||||
//!
|
||||
//! When rebuilding column values using [`build_column_values()`], a truncated
|
||||
//! key part yields [`ColumnValue::Prefix`], which can only be used for prefix
|
||||
//! matching - equality matching against a string always returns false.
|
||||
//!
|
||||
//! Two considerations must be made when truncating the generated key:
|
||||
//!
|
||||
//! * The string may contain encoded sequences in the form %XX, and the string
|
||||
//! should not be split within an encoded sequence, or decoding the string
|
||||
//! will fail.
|
||||
//!
|
||||
//! * This may be a unicode string - what the user might consider a "character"
|
||||
//! may in fact be multiple unicode code-points, each of which may span
|
||||
//! multiple bytes.
|
||||
//!
|
||||
//! Slicing a unicode code-point in half may lead to an invalid UTF-8 string,
|
||||
//! which will prevent it from being used in Rust (and likely many other
|
||||
//! languages/systems). Because partition keys are represented as strings and
|
||||
//! not bytes, splitting a code-point in half MUST be avoided.
|
||||
//!
|
||||
//! Further to this, a sequence of multiple code-points can represent a single
|
||||
//! "character" - this is called a grapheme. For example, the representation of
|
||||
//! the Tamil "ni" character "நி" is composed of two multi-byte code-points; the
|
||||
//! Tamil letter "na" which renders as "ந" and the vowel sign "ி", each 3 bytes
|
||||
//! long. If split after the first 3 bytes, the compound "ni" character will be
|
||||
//! incorrectly rendered as the single "na"/"ந" character.
|
||||
//!
|
||||
//! Depending on what the consumer of the split string considers a character,
|
||||
//! prefix/equality matching may produce differing results if a grapheme is
|
||||
//! split. If the caller performs a byte-wise comparison, everything is fine -
|
||||
//! if they perform a "character" comparison, then the equality may be lost
|
||||
//! depending on what they consider a character.
|
||||
//!
|
||||
//! Therefore this implementation takes the conservative approach of never
|
||||
//! splitting code-points (for UTF-8 correctness) nor graphemes for simplicity
|
||||
//! and compatibility for the consumer. This may be relaxed in the future to
|
||||
//! allow splitting graphemes, but by being conservative we give ourselves this
|
||||
//! option - we can't easily do the reverse!
|
||||
//!
|
||||
//! ## Part Limit & Maximum Key Size
|
||||
//!
|
||||
//! The number of parts in a partition template is limited to 8
|
||||
//! ([`MAXIMUM_NUMBER_OF_TEMPLATE_PARTS`]), validated at creation time.
|
||||
//!
|
||||
//! Together with the above value truncation, this bounds the maximum length of
|
||||
//! a partition key to 1,607 bytes (1.57 KiB).
|
||||
//!
|
||||
//! ### Reserved Characters
|
||||
//!
|
||||
//! Reserved characters that are percent encoded (in addition to non-printable
|
||||
//! Reserved characters that are percent encoded (in addition to non-ASCII
|
||||
//! characters), and their meaning:
|
||||
//!
|
||||
//! * `|` - partition key part delimiter ([`PARTITION_KEY_DELIMITER`])
|
||||
//! * `!` - NULL/missing partition key part ([`PARTITION_KEY_VALUE_NULL`])
|
||||
//! * `^` - empty string partition key part ([`PARTITION_KEY_VALUE_EMPTY`])
|
||||
//! * `#` - key part truncation marker ([`PARTITION_KEY_PART_TRUNCATED`])
|
||||
//! * `%` - required for unambiguous reversal of percent encoding
|
||||
//!
|
||||
//! These characters are defined in [`ENCODED_PARTITION_KEY_CHARS`] and chosen
|
||||
|
@ -74,13 +128,14 @@
|
|||
//!
|
||||
//! The following partition keys are derived:
|
||||
//!
|
||||
//! * `time=2023-01-01, a=bananas, b=plátanos` -> `2023|bananas|plátanos
|
||||
//! * `time=2023-01-01, a=bananas, b=plátanos` -> `2023|bananas|plátanos`
|
||||
//! * `time=2023-01-01, b=plátanos` -> `2023|!|plátanos`
|
||||
//! * `time=2023-01-01, another=cat, b=plátanos` -> `2023|!|plátanos`
|
||||
//! * `time=2023-01-01` -> `2023|!|!`
|
||||
//! * `time=2023-01-01, a=cat|dog, b=!` -> `2023|cat%7Cdog|%21`
|
||||
//! * `time=2023-01-01, a=%50` -> `2023|%2550|!`
|
||||
//! * `time=2023-01-01, a=` -> `2023|^|!`
|
||||
//! * `time=2023-01-01, a=<long string>` -> `2023|<long string>#|!`
|
||||
//!
|
||||
//! When using the default partitioning template (YYYY-MM-DD) there is no
|
||||
//! encoding necessary, as the derived partition key contains a single part, and
|
||||
|
@ -144,6 +199,17 @@ pub const PARTITION_KEY_VALUE_NULL: char = '!';
|
|||
/// The `str` form of the [`PARTITION_KEY_VALUE_NULL`] character.
|
||||
pub const PARTITION_KEY_VALUE_NULL_STR: &str = "!";
|
||||
|
||||
/// The maximum permissible length of a partition key part, after encoding
|
||||
/// reserved & non-ASCII characters.
|
||||
pub const PARTITION_KEY_MAX_PART_LEN: usize = 200;
|
||||
|
||||
/// The truncation sentinel character, used to explicitly identify a partition
|
||||
/// key as having been truncated.
|
||||
///
|
||||
/// Truncated partition key parts can only be used for prefix matching, and
|
||||
/// yield a [`ColumnValue::Prefix`] from [`build_column_values()`].
|
||||
pub const PARTITION_KEY_PART_TRUNCATED: char = '#';
|
||||
|
||||
/// The minimal set of characters that must be encoded during partition key
|
||||
/// generation when they form part of a partition key part, in order to be
|
||||
/// unambiguously reversible.
|
||||
|
@ -153,6 +219,7 @@ pub const ENCODED_PARTITION_KEY_CHARS: AsciiSet = CONTROLS
|
|||
.add(PARTITION_KEY_DELIMITER as u8)
|
||||
.add(PARTITION_KEY_VALUE_NULL as u8)
|
||||
.add(PARTITION_KEY_VALUE_EMPTY as u8)
|
||||
.add(PARTITION_KEY_PART_TRUNCATED as u8)
|
||||
.add(b'%'); // Required for reversible unambiguous encoding
|
||||
|
||||
/// Allocationless and protobufless access to the parts of a template needed to
|
||||
|
@ -218,6 +285,12 @@ impl TablePartitionTemplateOverride {
|
|||
}
|
||||
}
|
||||
|
||||
/// Returns the number of parts in this template.
|
||||
#[allow(clippy::len_without_is_empty)] // Senseless - there must always be >0 parts.
|
||||
pub fn len(&self) -> usize {
|
||||
self.parts().count()
|
||||
}
|
||||
|
||||
/// Iterate through the protobuf parts and lend out what the `mutable_batch` crate needs to
|
||||
/// build `PartitionKey`s. If this table doesn't have a custom template, use the application
|
||||
/// default of partitioning by day.
|
||||
|
@ -366,6 +439,54 @@ mod serialization {
|
|||
}
|
||||
}
|
||||
|
||||
/// The value of a column, reversed from a partition key.
|
||||
///
|
||||
/// See [`build_column_values()`].
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum ColumnValue<'a> {
|
||||
/// The inner value is the exact, unmodified input column value.
|
||||
Identity(Cow<'a, str>),
|
||||
|
||||
/// The inner value is a variable length prefix of the input column value.
|
||||
///
|
||||
/// The string value is always guaranteed to be valid UTF-8.
|
||||
///
|
||||
/// Attempting to equality match this variant against a string will always
|
||||
/// be false - use [`ColumnValue::is_prefix_match_of()`] to prefix match
|
||||
/// instead.
|
||||
Prefix(Cow<'a, str>),
|
||||
}
|
||||
|
||||
impl<'a> ColumnValue<'a> {
|
||||
/// Returns true if `other` is a byte-wise prefix match of `self`.
|
||||
///
|
||||
/// This method can be called for both [`ColumnValue::Identity`] and
|
||||
/// [`ColumnValue::Prefix`].
|
||||
pub fn is_prefix_match_of<T>(&self, other: T) -> bool
|
||||
where
|
||||
T: AsRef<[u8]>,
|
||||
{
|
||||
let this = match self {
|
||||
ColumnValue::Identity(v) => v.as_bytes(),
|
||||
ColumnValue::Prefix(v) => v.as_bytes(),
|
||||
};
|
||||
|
||||
other.as_ref().starts_with(this)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> PartialEq<T> for ColumnValue<'a>
|
||||
where
|
||||
T: AsRef<str>,
|
||||
{
|
||||
fn eq(&self, other: &T) -> bool {
|
||||
match self {
|
||||
ColumnValue::Identity(v) => other.as_ref().eq(v.as_ref()),
|
||||
ColumnValue::Prefix(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Reverse a `partition_key` generated from the given partition key `template`,
|
||||
/// reconstructing the set of tag values in the form of `(column name, column
|
||||
/// value)` tuples that the `partition_key` was generated from.
|
||||
|
@ -381,7 +502,7 @@ mod serialization {
|
|||
pub fn build_column_values<'a>(
|
||||
template: &'a TablePartitionTemplateOverride,
|
||||
partition_key: &'a str,
|
||||
) -> impl Iterator<Item = (&'a str, Cow<'a, str>)> {
|
||||
) -> impl Iterator<Item = (&'a str, ColumnValue<'a>)> {
|
||||
// Exploded parts of the generated key on the "/" character.
|
||||
//
|
||||
// Any uses of the "/" character within the partition key's user-provided
|
||||
|
@ -424,12 +545,30 @@ pub fn build_column_values<'a>(
|
|||
})
|
||||
// Reverse the urlencoding of all value parts
|
||||
.map(|(name, value)| {
|
||||
(
|
||||
name,
|
||||
percent_decode_str(value)
|
||||
.decode_utf8()
|
||||
.expect("invalid partition key part encoding"),
|
||||
)
|
||||
let decoded = percent_decode_str(value)
|
||||
.decode_utf8()
|
||||
.expect("invalid partition key part encoding");
|
||||
|
||||
// Inspect the final character in the string, pre-decoding, to
|
||||
// determine if it has been truncated.
|
||||
if value
|
||||
.as_bytes()
|
||||
.last()
|
||||
.map(|v| *v == PARTITION_KEY_PART_TRUNCATED as u8)
|
||||
.unwrap_or_default()
|
||||
{
|
||||
// Remove the truncation marker.
|
||||
let len = decoded.len() - 1;
|
||||
|
||||
// Only allocate if needed; re-borrow a subslice of `Cow::Borrowed` if not.
|
||||
let column_cow = match decoded {
|
||||
Cow::Borrowed(s) => Cow::Borrowed(&s[..len]),
|
||||
Cow::Owned(s) => Cow::Owned(s[..len].to_string()),
|
||||
};
|
||||
return (name, ColumnValue::Prefix(column_cow));
|
||||
}
|
||||
|
||||
(name, ColumnValue::Identity(decoded))
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -467,6 +606,21 @@ mod tests {
|
|||
use sqlx::Encode;
|
||||
use test_helpers::assert_error;
|
||||
|
||||
#[test]
|
||||
fn test_max_partition_key_len() {
|
||||
let max_len: usize =
|
||||
// 8 parts, at most 200 bytes long.
|
||||
(MAXIMUM_NUMBER_OF_TEMPLATE_PARTS * PARTITION_KEY_MAX_PART_LEN)
|
||||
// 7 delimiting characters between parts.
|
||||
+ (MAXIMUM_NUMBER_OF_TEMPLATE_PARTS - 1);
|
||||
|
||||
// If this changes, the module documentation should be changed too.
|
||||
//
|
||||
// This shouldn't change without consideration of primary key overlap as
|
||||
// a result.
|
||||
assert_eq!(max_len, 1_607, "update module docs please");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_parts_is_invalid() {
|
||||
let err = serialization::Wrapper::try_from(proto::PartitionTemplate { parts: vec![] });
|
||||
|
@ -555,6 +709,17 @@ mod tests {
|
|||
assert_error!(err, ValidationError::InvalidStrftime(ref format) if format.is_empty());
|
||||
}
|
||||
|
||||
fn identity(s: &str) -> ColumnValue<'_> {
|
||||
ColumnValue::Identity(s.into())
|
||||
}
|
||||
|
||||
fn prefix<'a, T>(s: T) -> ColumnValue<'a>
|
||||
where
|
||||
T: Into<Cow<'a, str>>,
|
||||
{
|
||||
ColumnValue::Prefix(s.into())
|
||||
}
|
||||
|
||||
/// Generate a test that asserts "partition_key" is reversible, yielding
|
||||
/// "want" assuming the partition "template" was used.
|
||||
macro_rules! test_build_column_values {
|
||||
|
@ -570,17 +735,13 @@ mod tests {
|
|||
let template = $template.into_iter().collect::<Vec<_>>();
|
||||
let template = test_table_partition_override(template);
|
||||
|
||||
// normalise the values into a (str, string) for the comparison
|
||||
// normalise the values into a (str, ColumnValue) for the comparison
|
||||
let want = $want
|
||||
.into_iter()
|
||||
.map(|(k, v)| {
|
||||
let v: &str = v;
|
||||
(k, v.to_string())
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let got = build_column_values(&template, $partition_key)
|
||||
.map(|(k, v)| (k, v.to_string()))
|
||||
let input = String::from($partition_key);
|
||||
let got = build_column_values(&template, input.as_str())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(got, want);
|
||||
|
@ -597,7 +758,7 @@ mod tests {
|
|||
TemplatePart::TagValue("b"),
|
||||
],
|
||||
partition_key = "2023|bananas|plátanos",
|
||||
want = [("a", "bananas"), ("b", "plátanos")]
|
||||
want = [("a", identity("bananas")), ("b", identity("plátanos"))]
|
||||
);
|
||||
|
||||
test_build_column_values!(
|
||||
|
@ -608,7 +769,7 @@ mod tests {
|
|||
TemplatePart::TagValue("b"),
|
||||
],
|
||||
partition_key = "2023|!|plátanos",
|
||||
want = [("b", "plátanos")]
|
||||
want = [("b", identity("plátanos"))]
|
||||
);
|
||||
|
||||
test_build_column_values!(
|
||||
|
@ -630,7 +791,7 @@ mod tests {
|
|||
TemplatePart::TagValue("b"),
|
||||
],
|
||||
partition_key = "2023|cat%7Cdog|%21",
|
||||
want = [("a", "cat|dog"), ("b", "!")]
|
||||
want = [("a", identity("cat|dog")), ("b", identity("!"))]
|
||||
);
|
||||
|
||||
test_build_column_values!(
|
||||
|
@ -641,7 +802,40 @@ mod tests {
|
|||
TemplatePart::TagValue("b"),
|
||||
],
|
||||
partition_key = "2023|%2550|!",
|
||||
want = [("a", "%50")]
|
||||
want = [("a", identity("%50"))]
|
||||
);
|
||||
|
||||
test_build_column_values!(
|
||||
module_doc_example_7,
|
||||
template = [
|
||||
TemplatePart::TimeFormat("%Y"),
|
||||
TemplatePart::TagValue("a"),
|
||||
TemplatePart::TagValue("b"),
|
||||
],
|
||||
partition_key = "2023|BANANAS#|!",
|
||||
want = [("a", prefix("BANANAS"))]
|
||||
);
|
||||
|
||||
test_build_column_values!(
|
||||
unicode_code_point_prefix,
|
||||
template = [
|
||||
TemplatePart::TimeFormat("%Y"),
|
||||
TemplatePart::TagValue("a"),
|
||||
TemplatePart::TagValue("b"),
|
||||
],
|
||||
partition_key = "2023|%28%E3%83%8E%E0%B2%A0%E7%9B%8A%E0%B2%A0%29%E3%83%8E%E5%BD%A1%E2%94%BB%E2%94%81%E2%94%BB#|!",
|
||||
want = [("a", prefix("(ノಠ益ಠ)ノ彡┻━┻"))]
|
||||
);
|
||||
|
||||
test_build_column_values!(
|
||||
unicode_grapheme,
|
||||
template = [
|
||||
TemplatePart::TimeFormat("%Y"),
|
||||
TemplatePart::TagValue("a"),
|
||||
TemplatePart::TagValue("b"),
|
||||
],
|
||||
partition_key = "2023|%E0%AE%A8%E0%AE%BF#|!",
|
||||
want = [("a", prefix("நி"))]
|
||||
);
|
||||
|
||||
test_build_column_values!(
|
||||
|
@ -651,8 +845,8 @@ mod tests {
|
|||
TemplatePart::TagValue("a"),
|
||||
TemplatePart::TagValue("b"),
|
||||
],
|
||||
partition_key = "2023|is%7Cnot%21ambiguous%2510|!",
|
||||
want = [("a", "is|not!ambiguous%10")]
|
||||
partition_key = "2023|is%7Cnot%21ambiguous%2510%23|!",
|
||||
want = [("a", identity("is|not!ambiguous%10#"))]
|
||||
);
|
||||
|
||||
test_build_column_values!(
|
||||
|
@ -671,11 +865,30 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_partition_key_char_str_equality() {
|
||||
assert_eq!(
|
||||
PARTITION_KEY_VALUE_EMPTY.to_string(),
|
||||
PARTITION_KEY_VALUE_EMPTY_STR
|
||||
);
|
||||
fn test_column_value_partial_eq() {
|
||||
assert_eq!(identity("bananas"), "bananas");
|
||||
|
||||
assert_ne!(identity("bananas"), "bananas2");
|
||||
assert_ne!(identity("bananas2"), "bananas");
|
||||
|
||||
assert_ne!(prefix("bananas"), "bananas");
|
||||
assert_ne!(prefix("bananas"), "bananas2");
|
||||
assert_ne!(prefix("bananas2"), "bananas");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_column_value_is_prefix_match() {
|
||||
let b = "bananas".to_string();
|
||||
assert!(identity("bananas").is_prefix_match_of(b));
|
||||
|
||||
assert!(identity("bananas").is_prefix_match_of("bananas"));
|
||||
assert!(identity("bananas").is_prefix_match_of("bananas2"));
|
||||
|
||||
assert!(prefix("bananas").is_prefix_match_of("bananas"));
|
||||
assert!(prefix("bananas").is_prefix_match_of("bananas2"));
|
||||
|
||||
assert!(!identity("bananas2").is_prefix_match_of("bananas"));
|
||||
assert!(!prefix("bananas2").is_prefix_match_of("bananas"));
|
||||
}
|
||||
|
||||
/// This test asserts the default derived partitioning scheme with no
|
||||
|
@ -695,6 +908,14 @@ mod tests {
|
|||
assert_matches!(got.as_slice(), [TemplatePart::TimeFormat("%Y-%m-%d")]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn len_of_default_template_is_1() {
|
||||
let ns = NamespacePartitionTemplateOverride::default();
|
||||
let t = TablePartitionTemplateOverride::try_new(None, &ns).unwrap();
|
||||
|
||||
assert_eq!(t.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_custom_table_template_specified_gets_namespace_template() {
|
||||
let namespace_template =
|
||||
|
@ -707,6 +928,7 @@ mod tests {
|
|||
let table_template =
|
||||
TablePartitionTemplateOverride::try_new(None, &namespace_template).unwrap();
|
||||
|
||||
assert_eq!(table_template.len(), 1);
|
||||
assert_eq!(table_template.0, namespace_template.0);
|
||||
}
|
||||
|
||||
|
@ -730,6 +952,7 @@ mod tests {
|
|||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(table_template.len(), 1);
|
||||
assert_eq!(table_template.0.unwrap().inner(), &custom_table_template);
|
||||
}
|
||||
|
||||
|
@ -779,5 +1002,6 @@ mod tests {
|
|||
);
|
||||
let table_json_str: String = buf.iter().map(extract_sqlite_argument_text).collect();
|
||||
assert_eq!(table_json_str, expected_json_str);
|
||||
assert_eq!(table.len(), 2);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ itertools = "0.10"
|
|||
workspace-hack = { version = "0.1", path = "../workspace-hack" }
|
||||
percent-encoding = "2.2.0"
|
||||
thiserror = "1.0.40"
|
||||
unicode-segmentation = "1.10.1"
|
||||
|
||||
[dev-dependencies]
|
||||
assert_matches = "1.5.0"
|
||||
|
|
|
@ -11,11 +11,13 @@ use std::{borrow::Cow, ops::Range};
|
|||
|
||||
use data_types::partition_template::{
|
||||
TablePartitionTemplateOverride, TemplatePart, ENCODED_PARTITION_KEY_CHARS,
|
||||
PARTITION_KEY_DELIMITER, PARTITION_KEY_VALUE_EMPTY_STR, PARTITION_KEY_VALUE_NULL_STR,
|
||||
MAXIMUM_NUMBER_OF_TEMPLATE_PARTS, PARTITION_KEY_DELIMITER, PARTITION_KEY_MAX_PART_LEN,
|
||||
PARTITION_KEY_PART_TRUNCATED, PARTITION_KEY_VALUE_EMPTY_STR, PARTITION_KEY_VALUE_NULL_STR,
|
||||
};
|
||||
use percent_encoding::utf8_percent_encode;
|
||||
use schema::{InfluxColumnType, TIME_COLUMN_NAME};
|
||||
use thiserror::Error;
|
||||
use unicode_segmentation::UnicodeSegmentation;
|
||||
|
||||
use crate::{
|
||||
column::{Column, ColumnData},
|
||||
|
@ -49,6 +51,14 @@ pub fn partition_batch<'a>(
|
|||
batch: &'a MutableBatch,
|
||||
template: &'a TablePartitionTemplateOverride,
|
||||
) -> impl Iterator<Item = (Result<String, PartitionKeyError>, Range<usize>)> + 'a {
|
||||
let parts = template.len();
|
||||
if parts > MAXIMUM_NUMBER_OF_TEMPLATE_PARTS {
|
||||
panic!(
|
||||
"partition template contains {} parts, which exceeds the maximum of {} parts",
|
||||
parts, MAXIMUM_NUMBER_OF_TEMPLATE_PARTS
|
||||
);
|
||||
}
|
||||
|
||||
range_encode(partition_keys(batch, template.parts()))
|
||||
}
|
||||
|
||||
|
@ -88,13 +98,9 @@ impl<'a> Template<'a> {
|
|||
// potentially different key.
|
||||
*last_key = Some(this_key);
|
||||
|
||||
out.write_str(never_empty(
|
||||
Cow::from(utf8_percent_encode(
|
||||
dictionary.lookup_id(this_key).unwrap(),
|
||||
&ENCODED_PARTITION_KEY_CHARS,
|
||||
))
|
||||
.as_ref(),
|
||||
))?
|
||||
out.write_str(
|
||||
encode_key_part(dictionary.lookup_id(this_key).unwrap()).as_ref(),
|
||||
)?
|
||||
}
|
||||
_ => return Err(PartitionKeyError::TagValueNotTag(col.influx_type())),
|
||||
},
|
||||
|
@ -147,13 +153,49 @@ impl<'a> Template<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Return `s` if it is non-empty, else [`PARTITION_KEY_VALUE_EMPTY_STR`].
|
||||
#[inline(always)]
|
||||
fn never_empty(s: &str) -> &str {
|
||||
if s.is_empty() {
|
||||
return PARTITION_KEY_VALUE_EMPTY_STR;
|
||||
fn encode_key_part(s: &str) -> Cow<'_, str> {
|
||||
// Encode reserved characters and non-ascii characters.
|
||||
let as_str: Cow<'_, str> = utf8_percent_encode(s, &ENCODED_PARTITION_KEY_CHARS).into();
|
||||
|
||||
match as_str.len() {
|
||||
0 => Cow::Borrowed(PARTITION_KEY_VALUE_EMPTY_STR),
|
||||
1..=PARTITION_KEY_MAX_PART_LEN => as_str,
|
||||
_ => {
|
||||
// This string exceeds the maximum byte length limit and must be
|
||||
// truncated.
|
||||
//
|
||||
// Truncation of unicode strings can be tricky - this implementation
|
||||
// avoids splitting unicode code-points nor graphemes. See the
|
||||
// partition_template module docs in data_types before altering
|
||||
// this.
|
||||
|
||||
// Preallocate the string to hold the long partition key part.
|
||||
let mut buf = String::with_capacity(PARTITION_KEY_MAX_PART_LEN);
|
||||
|
||||
// This is a slow path, re-encoding the original input string -
|
||||
// fortunately this is an uncommon path.
|
||||
//
|
||||
// Walk the string, encoding each grapheme (which includes spaces)
|
||||
// individually, tracking the total length of the encoded string.
|
||||
// Once it hits 199 bytes, stop and append a #.
|
||||
|
||||
let mut bytes = 0;
|
||||
s.graphemes(true)
|
||||
.map(|v| Cow::from(utf8_percent_encode(v, &ENCODED_PARTITION_KEY_CHARS)))
|
||||
.take_while(|v| {
|
||||
bytes += v.len(); // Byte length of encoded grapheme
|
||||
bytes < PARTITION_KEY_MAX_PART_LEN
|
||||
})
|
||||
.for_each(|v| buf.push_str(v.as_ref()));
|
||||
|
||||
// Append the truncation marker.
|
||||
buf.push(PARTITION_KEY_PART_TRUNCATED);
|
||||
|
||||
assert!(buf.len() <= PARTITION_KEY_MAX_PART_LEN);
|
||||
|
||||
Cow::Owned(buf)
|
||||
}
|
||||
}
|
||||
s
|
||||
}
|
||||
|
||||
/// Returns an iterator of partition keys for the given table batch.
|
||||
|
@ -293,7 +335,9 @@ mod tests {
|
|||
use crate::writer::Writer;
|
||||
|
||||
use assert_matches::assert_matches;
|
||||
use data_types::partition_template::{build_column_values, test_table_partition_override};
|
||||
use data_types::partition_template::{
|
||||
build_column_values, test_table_partition_override, ColumnValue,
|
||||
};
|
||||
use proptest::{prelude::*, prop_compose, proptest, strategy::Strategy};
|
||||
use rand::prelude::*;
|
||||
|
||||
|
@ -529,6 +573,20 @@ mod tests {
|
|||
assert_matches::assert_matches!(got, Err(PartitionKeyError::TagValueNotTag(_)));
|
||||
}
|
||||
|
||||
fn identity<'a, T>(s: T) -> ColumnValue<'a>
|
||||
where
|
||||
T: Into<Cow<'a, str>>,
|
||||
{
|
||||
ColumnValue::Identity(s.into())
|
||||
}
|
||||
|
||||
fn prefix<'a, T>(s: T) -> ColumnValue<'a>
|
||||
where
|
||||
T: Into<Cow<'a, str>>,
|
||||
{
|
||||
ColumnValue::Prefix(s.into())
|
||||
}
|
||||
|
||||
// Generate a test that asserts the derived partition key matches
|
||||
// "want_key", when using the provided "template" parts and set of "tags".
|
||||
//
|
||||
|
@ -557,8 +615,9 @@ mod tests {
|
|||
.unwrap();
|
||||
|
||||
for (col, value) in $tags {
|
||||
let v = String::from(value);
|
||||
writer
|
||||
.write_tag(col, Some(&[0b00000001]), vec![value].into_iter())
|
||||
.write_tag(col, Some(&[0b00000001]), vec![v.as_str()].into_iter())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
|
@ -569,24 +628,19 @@ mod tests {
|
|||
// normalise the values.
|
||||
let keys = generate_denormalised_keys(&batch, template.parts())
|
||||
.unwrap();
|
||||
assert_eq!(keys, vec![$want_key.to_string()]);
|
||||
assert_eq!(keys, vec![$want_key.to_string()], "generated key differs");
|
||||
|
||||
// Reverse the encoding.
|
||||
let reversed = build_column_values(&template, &keys[0]);
|
||||
|
||||
// normalise the tags into a (str, string) for the comparison
|
||||
let want = $want_reversed_tags
|
||||
// Expect the tags to be (str, ColumnValue) for the
|
||||
// comparison
|
||||
let want: Vec<(&str, ColumnValue<'_>)> = $want_reversed_tags
|
||||
.into_iter()
|
||||
.map(|(k, v)| {
|
||||
let v: &str = v;
|
||||
(k, v.to_string())
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
.collect();
|
||||
|
||||
let got = reversed
|
||||
.map(|(k, v)| (k, v.to_string()))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(got, want);
|
||||
let got = reversed.collect::<Vec<_>>();
|
||||
assert_eq!(got, want, "reversed key differs");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -601,7 +655,7 @@ mod tests {
|
|||
],
|
||||
tags = [("a", "bananas"), ("b", "are_good")],
|
||||
want_key = "2023|bananas|are_good",
|
||||
want_reversed_tags = [("a", "bananas"), ("b", "are_good")]
|
||||
want_reversed_tags = [("a", identity("bananas")), ("b", identity("are_good"))]
|
||||
);
|
||||
|
||||
test_partition_key!(
|
||||
|
@ -613,7 +667,7 @@ mod tests {
|
|||
],
|
||||
tags = [("a", "bananas"), ("b", "plátanos")],
|
||||
want_key = "2023|bananas|pl%C3%A1tanos",
|
||||
want_reversed_tags = [("a", "bananas"), ("b", "plátanos")]
|
||||
want_reversed_tags = [("a", identity("bananas")), ("b", identity("plátanos"))]
|
||||
);
|
||||
|
||||
test_partition_key!(
|
||||
|
@ -629,7 +683,7 @@ mod tests {
|
|||
template = [TemplatePart::TagValue("a")],
|
||||
tags = [("a", "")],
|
||||
want_key = "^",
|
||||
want_reversed_tags = [("a", "")]
|
||||
want_reversed_tags = [("a", identity(""))]
|
||||
);
|
||||
|
||||
test_partition_key!(
|
||||
|
@ -637,7 +691,7 @@ mod tests {
|
|||
template = [TemplatePart::TagValue("a"), TemplatePart::TagValue("b")],
|
||||
tags = [("a", "bananas")],
|
||||
want_key = "bananas|!",
|
||||
want_reversed_tags = [("a", "bananas")]
|
||||
want_reversed_tags = [("a", identity("bananas"))]
|
||||
);
|
||||
|
||||
test_partition_key!(
|
||||
|
@ -652,7 +706,256 @@ mod tests {
|
|||
],
|
||||
tags = [("a", "|"), ("b", "!"), ("d", "%7C%21%257C"), ("e", "^")],
|
||||
want_key = "2023|%7C|%21|!|%257C%2521%25257C|%5E",
|
||||
want_reversed_tags = [("a", "|"), ("b", "!"), ("d", "%7C%21%257C"), ("e", "^")]
|
||||
want_reversed_tags = [
|
||||
("a", identity("|")),
|
||||
("b", identity("!")),
|
||||
("d", identity("%7C%21%257C")),
|
||||
("e", identity("^"))
|
||||
]
|
||||
);
|
||||
|
||||
test_partition_key!(
|
||||
truncated_char_reserved,
|
||||
template = [TemplatePart::TagValue("a")],
|
||||
tags = [("a", "#")],
|
||||
want_key = "%23",
|
||||
want_reversed_tags = [("a", identity("#"))]
|
||||
);
|
||||
|
||||
// Keys < 200 bytes long should not be truncated.
|
||||
test_partition_key!(
|
||||
truncate_length_199,
|
||||
template = [TemplatePart::TagValue("a")],
|
||||
tags = [("a", "A".repeat(199))],
|
||||
want_key = "A".repeat(199),
|
||||
want_reversed_tags = [("a", identity("A".repeat(199)))]
|
||||
);
|
||||
|
||||
// Keys of exactly 200 bytes long should not be truncated.
|
||||
test_partition_key!(
|
||||
truncate_length_200,
|
||||
template = [TemplatePart::TagValue("a")],
|
||||
tags = [("a", "A".repeat(200))],
|
||||
want_key = "A".repeat(200),
|
||||
want_reversed_tags = [("a", identity("A".repeat(200)))]
|
||||
);
|
||||
|
||||
// Keys > 200 bytes long should be truncated to exactly 200 bytes,
|
||||
// terminated by a # character.
|
||||
test_partition_key!(
|
||||
truncate_length_201,
|
||||
template = [TemplatePart::TagValue("a")],
|
||||
tags = [("a", "A".repeat(201))],
|
||||
want_key = format!("{}#", "A".repeat(199)),
|
||||
want_reversed_tags = [("a", prefix("A".repeat(199)))]
|
||||
);
|
||||
|
||||
// A key ending in an encoded sequence that does not cross the cut-off point
|
||||
// is preserved.
|
||||
//
|
||||
// This subtest generates a key of:
|
||||
//
|
||||
// `A..<repeats>%`
|
||||
// ^ cutoff
|
||||
//
|
||||
// Which when encoded, becomes:
|
||||
//
|
||||
// `A..<repeats>%25`
|
||||
// ^ cutoff
|
||||
//
|
||||
// So the entire encoded sequence should be preserved.
|
||||
test_partition_key!(
|
||||
truncate_encoding_sequence_ok,
|
||||
template = [TemplatePart::TagValue("a")],
|
||||
tags = [("a", format!("{}%", "A".repeat(197)))],
|
||||
want_key = format!("{}%25", "A".repeat(197)), // Not truncated
|
||||
want_reversed_tags = [("a", identity(format!("{}%", "A".repeat(197))))]
|
||||
);
|
||||
|
||||
// A key ending in an encoded sequence should not be split.
|
||||
//
|
||||
// This subtest generates a key of:
|
||||
//
|
||||
// `A..<repeats>%`
|
||||
// ^ cutoff
|
||||
//
|
||||
// Which when encoded, becomes:
|
||||
//
|
||||
// `A..<repeats>% 25` (space added for clarity)
|
||||
// ^ cutoff
|
||||
//
|
||||
// Where naive slicing would result in truncating an encoding sequence and
|
||||
// therefore the whole encoded sequence should be truncated.
|
||||
test_partition_key!(
|
||||
truncate_encoding_sequence_truncated_1,
|
||||
template = [TemplatePart::TagValue("a")],
|
||||
tags = [("a", format!("{}%", "A".repeat(198)))],
|
||||
want_key = format!("{}#", "A".repeat(198)), // Truncated
|
||||
want_reversed_tags = [("a", prefix("A".repeat(198)))]
|
||||
);
|
||||
|
||||
// A key ending in an encoded sequence should not be split.
|
||||
//
|
||||
// This subtest generates a key of:
|
||||
//
|
||||
// `A..<repeats>%`
|
||||
// ^ cutoff
|
||||
//
|
||||
// Which when encoded, becomes:
|
||||
//
|
||||
// `A..<repeats>%2 5` (space added for clarity)
|
||||
// ^ cutoff
|
||||
//
|
||||
// Where naive slicing would result in truncating an encoding sequence and
|
||||
// therefore the whole encoded sequence should be truncated.
|
||||
test_partition_key!(
|
||||
truncate_encoding_sequence_truncated_2,
|
||||
template = [TemplatePart::TagValue("a")],
|
||||
tags = [("a", format!("{}%", "A".repeat(199)))],
|
||||
want_key = format!("{}#", "A".repeat(199)), // Truncated
|
||||
want_reversed_tags = [("a", prefix("A".repeat(199)))]
|
||||
);
|
||||
|
||||
// A key ending in a unicode code-point should never be split.
|
||||
//
|
||||
// This subtest generates a key of:
|
||||
//
|
||||
// `A..<repeats>🍌`
|
||||
// ^ cutoff
|
||||
//
|
||||
// Which when encoded, becomes:
|
||||
//
|
||||
// `A..<repeats>%F0%9F%8D%8C`
|
||||
// ^ cutoff
|
||||
//
|
||||
// Therefore the entire code-point should be removed from the truncated
|
||||
// output.
|
||||
//
|
||||
// This test MUST NOT fail, or an invalid UTF-8 string is being generated
|
||||
// which is unusable in languages (like Rust).
|
||||
//
|
||||
// Advances the cut-off to ensure the position within the code-point doesn't
|
||||
// affect the output.
|
||||
test_partition_key!(
|
||||
truncate_within_code_point_1,
|
||||
template = [TemplatePart::TagValue("a")],
|
||||
tags = [("a", format!("{}🍌", "A".repeat(194)))],
|
||||
want_key = format!("{}#", "A".repeat(194)),
|
||||
want_reversed_tags = [("a", prefix("A".repeat(194)))]
|
||||
);
|
||||
test_partition_key!(
|
||||
truncate_within_code_point_2,
|
||||
template = [TemplatePart::TagValue("a")],
|
||||
tags = [("a", format!("{}🍌", "A".repeat(195)))],
|
||||
want_key = format!("{}#", "A".repeat(195)),
|
||||
want_reversed_tags = [("a", prefix("A".repeat(195)))]
|
||||
);
|
||||
test_partition_key!(
|
||||
truncate_within_code_point_3,
|
||||
template = [TemplatePart::TagValue("a")],
|
||||
tags = [("a", format!("{}🍌", "A".repeat(196)))],
|
||||
want_key = format!("{}#", "A".repeat(196)),
|
||||
want_reversed_tags = [("a", prefix("A".repeat(196)))]
|
||||
);
|
||||
|
||||
// A key ending in a unicode grapheme should never be split.
|
||||
//
|
||||
// This subtest generates a key of:
|
||||
//
|
||||
// `A..<repeats>நிbananas`
|
||||
// ^ cutoff
|
||||
//
|
||||
// Which when encoded, becomes:
|
||||
//
|
||||
// `A..<repeats>நிbananas` (within a grapheme)
|
||||
// ^ cutoff
|
||||
//
|
||||
// Therefore the entire grapheme (நி) should be removed from the truncated
|
||||
// output.
|
||||
//
|
||||
// This is a conservative implementation, and may be relaxed in the future.
|
||||
//
|
||||
// This first test asserts that a grapheme can be included, and then
|
||||
// subsequent tests increment the cut-off point by 1 byte each time.
|
||||
test_partition_key!(
|
||||
truncate_within_grapheme_0,
|
||||
template = [TemplatePart::TagValue("a")],
|
||||
tags = [("a", format!("{}நிbananas", "A".repeat(181)))],
|
||||
want_key = format!("{}%E0%AE%A8%E0%AE%BF#", "A".repeat(181)),
|
||||
want_reversed_tags = [("a", prefix(format!("{}நி", "A".repeat(181))))]
|
||||
);
|
||||
test_partition_key!(
|
||||
truncate_within_grapheme_1,
|
||||
template = [TemplatePart::TagValue("a")],
|
||||
tags = [("a", format!("{}நிbananas", "A".repeat(182)))],
|
||||
want_key = format!("{}#", "A".repeat(182)),
|
||||
want_reversed_tags = [("a", prefix("A".repeat(182)))]
|
||||
);
|
||||
test_partition_key!(
|
||||
truncate_within_grapheme_2,
|
||||
template = [TemplatePart::TagValue("a")],
|
||||
tags = [("a", format!("{}நிbananas", "A".repeat(183)))],
|
||||
want_key = format!("{}#", "A".repeat(183)),
|
||||
want_reversed_tags = [("a", prefix("A".repeat(183)))]
|
||||
);
|
||||
test_partition_key!(
|
||||
truncate_within_grapheme_3,
|
||||
template = [TemplatePart::TagValue("a")],
|
||||
tags = [("a", format!("{}நிbananas", "A".repeat(184)))],
|
||||
want_key = format!("{}#", "A".repeat(184)),
|
||||
want_reversed_tags = [("a", prefix("A".repeat(184)))]
|
||||
);
|
||||
test_partition_key!(
|
||||
truncate_within_grapheme_4,
|
||||
template = [TemplatePart::TagValue("a")],
|
||||
tags = [("a", format!("{}நிbananas", "A".repeat(185)))],
|
||||
want_key = format!("{}#", "A".repeat(185)),
|
||||
want_reversed_tags = [("a", prefix("A".repeat(185)))]
|
||||
);
|
||||
test_partition_key!(
|
||||
truncate_within_grapheme_5,
|
||||
template = [TemplatePart::TagValue("a")],
|
||||
tags = [("a", format!("{}நிbananas", "A".repeat(186)))],
|
||||
want_key = format!("{}#", "A".repeat(186)),
|
||||
want_reversed_tags = [("a", prefix("A".repeat(186)))]
|
||||
);
|
||||
test_partition_key!(
|
||||
truncate_within_grapheme_6,
|
||||
template = [TemplatePart::TagValue("a")],
|
||||
tags = [("a", format!("{}நிbananas", "A".repeat(187)))],
|
||||
want_key = format!("{}#", "A".repeat(187)),
|
||||
want_reversed_tags = [("a", prefix("A".repeat(187)))]
|
||||
);
|
||||
test_partition_key!(
|
||||
truncate_within_grapheme_7,
|
||||
template = [TemplatePart::TagValue("a")],
|
||||
tags = [("a", format!("{}நிbananas", "A".repeat(188)))],
|
||||
want_key = format!("{}#", "A".repeat(188)),
|
||||
want_reversed_tags = [("a", prefix("A".repeat(188)))]
|
||||
);
|
||||
test_partition_key!(
|
||||
truncate_within_grapheme_8,
|
||||
template = [TemplatePart::TagValue("a")],
|
||||
tags = [("a", format!("{}நிbananas", "A".repeat(189)))],
|
||||
want_key = format!("{}#", "A".repeat(189)),
|
||||
want_reversed_tags = [("a", prefix("A".repeat(189)))]
|
||||
);
|
||||
test_partition_key!(
|
||||
truncate_within_grapheme_9,
|
||||
template = [TemplatePart::TagValue("a")],
|
||||
tags = [("a", format!("{}நிbananas", "A".repeat(190)))],
|
||||
want_key = format!("{}#", "A".repeat(190)),
|
||||
want_reversed_tags = [("a", prefix("A".repeat(190)))]
|
||||
);
|
||||
|
||||
// As above, but the grapheme is the last portion of the generated string
|
||||
// (no trailing bananas).
|
||||
test_partition_key!(
|
||||
truncate_grapheme_identity,
|
||||
template = [TemplatePart::TagValue("a")],
|
||||
tags = [("a", format!("{}நி", "A".repeat(182)))],
|
||||
want_key = format!("{}%E0%AE%A8%E0%AE%BF", "A".repeat(182)),
|
||||
want_reversed_tags = [("a", identity(format!("{}நி", "A".repeat(182))))]
|
||||
);
|
||||
|
||||
/// A test using an invalid strftime format string.
|
||||
|
@ -679,6 +982,20 @@ mod tests {
|
|||
assert_matches!(ret, Err(PartitionKeyError::InvalidStrftime));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(
|
||||
expected = "partition template contains 9 parts, which exceeds the maximum of 8 parts"
|
||||
)]
|
||||
fn test_too_many_parts() {
|
||||
let template = test_table_partition_override(
|
||||
std::iter::repeat(TemplatePart::TagValue("bananas"))
|
||||
.take(9)
|
||||
.collect(),
|
||||
);
|
||||
|
||||
let _ = partition_batch(&MutableBatch::new(), &template);
|
||||
}
|
||||
|
||||
// These values are arbitrarily chosen when building an input to the
|
||||
// partitioner.
|
||||
|
||||
|
@ -704,11 +1021,11 @@ mod tests {
|
|||
];
|
||||
|
||||
prop_compose! {
|
||||
/// Yields a vector of up to 12 unique template parts, chosen from
|
||||
/// [`TEST_TEMPLATE_PARTS`].
|
||||
/// Yields a vector of up to [`MAXIMUM_NUMBER_OF_TEMPLATE_PARTS`] unique
|
||||
/// template parts, chosen from [`TEST_TEMPLATE_PARTS`].
|
||||
fn arbitrary_template_parts()(set in proptest::collection::vec(
|
||||
proptest::sample::select(TEST_TEMPLATE_PARTS),
|
||||
(1, 12) // Set size range
|
||||
(1, MAXIMUM_NUMBER_OF_TEMPLATE_PARTS) // Set size range
|
||||
)) -> Vec<TemplatePart<'static>> {
|
||||
let mut set = set;
|
||||
set.dedup_by(|a, b| format!("{a:?}") == format!("{b:?}"));
|
||||
|
@ -762,20 +1079,44 @@ mod tests {
|
|||
assert_eq!(keys.len(), 1);
|
||||
|
||||
// Reverse the encoding.
|
||||
let reversed = build_column_values(&template, &keys[0]).map(|(k, v)| (k, v.to_string())).collect::<Vec<_>>();
|
||||
let reversed: Vec<(&str, ColumnValue<'_>)> = build_column_values(&template, &keys[0]).collect();
|
||||
|
||||
// Build the expected set of reversed tags by filtering out any
|
||||
// NULL tags (preserving empty string values).
|
||||
let want_reversed = template.parts().filter_map(|v| match v {
|
||||
let want_reversed: Vec<(&str, String)> = template.parts().filter_map(|v| match v {
|
||||
TemplatePart::TagValue(col_name) if tag_values.contains_key(col_name) => {
|
||||
// This tag had a (potentially empty) value wrote and should
|
||||
// appear in the reversed output.
|
||||
Some((col_name, tag_values.get(col_name).unwrap().to_string()))
|
||||
}
|
||||
_ => None,
|
||||
}).collect::<Vec<_>>();
|
||||
}).collect();
|
||||
|
||||
assert_eq!(reversed, want_reversed);
|
||||
assert_eq!(want_reversed.len(), reversed.len());
|
||||
|
||||
for (want, got) in want_reversed.iter().zip(reversed.iter()) {
|
||||
assert_eq!(got.0, want.0, "column names differ");
|
||||
|
||||
match got.1 {
|
||||
ColumnValue::Identity(_) => {
|
||||
// An identity is both equal to, and a prefix of, the
|
||||
// original value.
|
||||
assert_eq!(got.1, want.1, "identity values differ");
|
||||
assert!(
|
||||
got.1.is_prefix_match_of(&want.1),
|
||||
"prefix mismatch; {:?} is not a prefix of {:?}",
|
||||
got.1,
|
||||
want.1
|
||||
);
|
||||
},
|
||||
ColumnValue::Prefix(_) => assert!(
|
||||
got.1.is_prefix_match_of(&want.1),
|
||||
"prefix mismatch; {:?} is not a prefix of {:?}",
|
||||
got.1,
|
||||
want.1
|
||||
),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/// A property test that asserts the partitioner tolerates (does not
|
||||
|
|
|
@ -1,12 +1,10 @@
|
|||
use std::{borrow::Cow, fmt::Write};
|
||||
use std::fmt::Write;
|
||||
|
||||
use chrono::{format::StrftimeItems, TimeZone, Utc};
|
||||
use data_types::partition_template::ENCODED_PARTITION_KEY_CHARS;
|
||||
use percent_encoding::utf8_percent_encode;
|
||||
|
||||
use crate::PartitionKeyError;
|
||||
|
||||
use super::never_empty;
|
||||
use super::encode_key_part;
|
||||
|
||||
/// The number of nanoseconds in 1 day, definitely recited from memory.
|
||||
const DAY_NANOSECONDS: i64 = 86_400_000_000_000;
|
||||
|
@ -210,10 +208,7 @@ impl<'a> StrftimeFormatter<'a> {
|
|||
.map_err(|_| PartitionKeyError::InvalidStrftime)?;
|
||||
|
||||
// Encode any reserved characters in this new string.
|
||||
buf.1 = never_empty(
|
||||
Cow::from(utf8_percent_encode(&buf.1, &ENCODED_PARTITION_KEY_CHARS)).as_ref(),
|
||||
)
|
||||
.to_string();
|
||||
buf.1 = encode_key_part(&buf.1).to_string();
|
||||
|
||||
// Render this new value to the caller's buffer
|
||||
out.write_str(&buf.1)?;
|
||||
|
@ -344,7 +339,7 @@ mod tests {
|
|||
/// formatter, therefore this test asserts the following property:
|
||||
///
|
||||
/// For any timestamp and formatter, the output of this type must
|
||||
/// match the output of chrono's formatter, after URL encoding.
|
||||
/// match the output of chrono's formatter, after key encoding.
|
||||
///
|
||||
/// Validating this asserts correctness of the wrapper itself, assuming
|
||||
/// chrono's formatter produces correct output. Note the encoding is
|
||||
|
@ -366,10 +361,7 @@ mod tests {
|
|||
Utc.timestamp_nanos(ts)
|
||||
.format_with_items(items.clone())
|
||||
);
|
||||
let control = never_empty(
|
||||
Cow::from(utf8_percent_encode(&control, &ENCODED_PARTITION_KEY_CHARS)).as_ref(),
|
||||
)
|
||||
.to_string();
|
||||
let control = encode_key_part(&control);
|
||||
|
||||
// Generate the test string.
|
||||
let mut test = String::new();
|
||||
|
|
|
@ -97,17 +97,16 @@
|
|||
//! [selector functions]: https://docs.influxdata.com/influxdb/v1.8/query_language/functions/#selectors
|
||||
use std::{fmt::Debug, sync::Arc};
|
||||
|
||||
use arrow::{array::ArrayRef, datatypes::DataType};
|
||||
use arrow::datatypes::DataType;
|
||||
use datafusion::{
|
||||
error::{DataFusionError, Result as DataFusionResult},
|
||||
error::Result as DataFusionResult,
|
||||
logical_expr::{AccumulatorFunctionImplementation, Signature, Volatility},
|
||||
physical_plan::{udaf::AggregateUDF, Accumulator},
|
||||
prelude::SessionContext,
|
||||
scalar::ScalarValue,
|
||||
};
|
||||
|
||||
mod internal;
|
||||
use internal::{FirstSelector, LastSelector, MaxSelector, MinSelector, Selector};
|
||||
use internal::{Comparison, Selector, Target};
|
||||
|
||||
mod type_handling;
|
||||
use type_handling::AggType;
|
||||
|
@ -228,34 +227,30 @@ impl FactoryBuilder {
|
|||
let other_types = agg_type.other_types;
|
||||
|
||||
let accumulator: Box<dyn Accumulator> = match selector_type {
|
||||
SelectorType::First => Box::new(SelectorAccumulator::new(FirstSelector::new(
|
||||
SelectorType::First => Box::new(Selector::new(
|
||||
Comparison::Min,
|
||||
Target::Time,
|
||||
value_type,
|
||||
other_types.iter().cloned(),
|
||||
)?)),
|
||||
SelectorType::Last => {
|
||||
if !other_types.is_empty() {
|
||||
return Err(DataFusionError::NotImplemented(
|
||||
"selector last w/ additional args".to_string(),
|
||||
));
|
||||
}
|
||||
Box::new(SelectorAccumulator::new(LastSelector::new(value_type)?))
|
||||
}
|
||||
SelectorType::Min => {
|
||||
if !other_types.is_empty() {
|
||||
return Err(DataFusionError::NotImplemented(
|
||||
"selector min w/ additional args".to_string(),
|
||||
));
|
||||
}
|
||||
Box::new(SelectorAccumulator::new(MinSelector::new(value_type)?))
|
||||
}
|
||||
SelectorType::Max => {
|
||||
if !other_types.is_empty() {
|
||||
return Err(DataFusionError::NotImplemented(
|
||||
"selector max w/ additional args".to_string(),
|
||||
));
|
||||
}
|
||||
Box::new(SelectorAccumulator::new(MaxSelector::new(value_type)?))
|
||||
}
|
||||
)?),
|
||||
SelectorType::Last => Box::new(Selector::new(
|
||||
Comparison::Max,
|
||||
Target::Time,
|
||||
value_type,
|
||||
other_types.iter().cloned(),
|
||||
)?),
|
||||
SelectorType::Min => Box::new(Selector::new(
|
||||
Comparison::Min,
|
||||
Target::Value,
|
||||
value_type,
|
||||
other_types.iter().cloned(),
|
||||
)?),
|
||||
SelectorType::Max => Box::new(Selector::new(
|
||||
Comparison::Max,
|
||||
Target::Value,
|
||||
value_type,
|
||||
other_types.iter().cloned(),
|
||||
)?),
|
||||
};
|
||||
Ok(accumulator)
|
||||
})
|
||||
|
@ -293,79 +288,6 @@ fn make_uda(name: &str, factory_builder: FactoryBuilder) -> AggregateUDF {
|
|||
)
|
||||
}
|
||||
|
||||
/// Structure that implements the Accumulator trait for DataFusion
|
||||
/// and processes (value, timestamp) pair and computes values
|
||||
#[derive(Debug)]
|
||||
struct SelectorAccumulator<SELECTOR>
|
||||
where
|
||||
SELECTOR: Selector,
|
||||
{
|
||||
// The underlying implementation for the selector
|
||||
selector: SELECTOR,
|
||||
}
|
||||
|
||||
impl<SELECTOR> SelectorAccumulator<SELECTOR>
|
||||
where
|
||||
SELECTOR: Selector,
|
||||
{
|
||||
pub fn new(selector: SELECTOR) -> Self {
|
||||
Self { selector }
|
||||
}
|
||||
}
|
||||
|
||||
impl<SELECTOR> Accumulator for SelectorAccumulator<SELECTOR>
|
||||
where
|
||||
SELECTOR: Selector + 'static,
|
||||
{
|
||||
// this function serializes our state to a vector of
|
||||
// `ScalarValue`s, which DataFusion uses to pass this state
|
||||
// between execution stages.
|
||||
fn state(&self) -> DataFusionResult<Vec<ScalarValue>> {
|
||||
self.selector.datafusion_state()
|
||||
}
|
||||
|
||||
/// Allocated size required for this accumulator, in bytes,
|
||||
/// including `Self`. Allocated means that for internal
|
||||
/// containers such as `Vec`, the `capacity` should be used not
|
||||
/// the `len`
|
||||
fn size(&self) -> usize {
|
||||
std::mem::size_of_val(self) - std::mem::size_of_val(&self.selector) + self.selector.size()
|
||||
}
|
||||
|
||||
// Return the final value of this aggregator.
|
||||
fn evaluate(&self) -> DataFusionResult<ScalarValue> {
|
||||
self.selector.evaluate()
|
||||
}
|
||||
|
||||
// This function receives one entry per argument of this
|
||||
// accumulator and updates the selector state function appropriately
|
||||
fn update_batch(&mut self, values: &[ArrayRef]) -> DataFusionResult<()> {
|
||||
if values.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if values.len() < 2 {
|
||||
return Err(DataFusionError::Internal(format!(
|
||||
"Internal error: Expected at least 2 arguments passed to selector function but got {}",
|
||||
values.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// invoke the actual worker function.
|
||||
self.selector
|
||||
.update_batch(&values[0], &values[1], &values[2..])?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// The input values and accumulator state are the same types for
|
||||
// selectors, and thus we can merge intermediate states with the
|
||||
// same function as inputs
|
||||
fn merge_batch(&mut self, states: &[ArrayRef]) -> DataFusionResult<()> {
|
||||
// merge is the same operation as update for these selectors
|
||||
self.update_batch(states)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use arrow::{
|
||||
|
@ -380,7 +302,7 @@ mod test {
|
|||
use datafusion::{datasource::MemTable, prelude::*};
|
||||
|
||||
use super::*;
|
||||
use utils::{run_case, run_case_err};
|
||||
use utils::{run_case, run_cases_err};
|
||||
|
||||
mod first {
|
||||
use super::*;
|
||||
|
@ -547,37 +469,24 @@ mod test {
|
|||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_time_tie_breaker() {
|
||||
run_case(
|
||||
selector_first().call(vec![col("f64_value"), col("time_dup")]),
|
||||
vec![
|
||||
"+------------------------------------------------+",
|
||||
"| selector_first(t.f64_value,t.time_dup) |",
|
||||
"+------------------------------------------------+",
|
||||
"| {value: 2.0, time: 1970-01-01T00:00:00.000001} |",
|
||||
"+------------------------------------------------+",
|
||||
],
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_err() {
|
||||
run_case_err(
|
||||
selector_first().call(vec![]),
|
||||
"Error during planning: selector_first requires at least 2 arguments, got 0",
|
||||
)
|
||||
.await;
|
||||
|
||||
run_case_err(
|
||||
selector_first().call(vec![col("f64_value")]),
|
||||
"Error during planning: selector_first requires at least 2 arguments, got 1",
|
||||
)
|
||||
.await;
|
||||
|
||||
run_case_err(
|
||||
selector_first().call(vec![col("time"), col("f64_value")]),
|
||||
"Error during planning: selector_first second argument must be a timestamp, but got Float64",
|
||||
)
|
||||
.await;
|
||||
|
||||
run_case_err(
|
||||
selector_first().call(vec![col("time"), col("f64_value"), col("bool_value")]),
|
||||
"Error during planning: selector_first second argument must be a timestamp, but got Float64",
|
||||
)
|
||||
.await;
|
||||
|
||||
run_case_err(
|
||||
selector_first().call(vec![col("f64_value"), col("bool_value"), col("time")]),
|
||||
"Error during planning: selector_first second argument must be a timestamp, but got Boolean",
|
||||
)
|
||||
.await;
|
||||
run_cases_err(selector_first(), "selector_first").await;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -718,6 +627,53 @@ mod test {
|
|||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_with_other() {
|
||||
run_case(
|
||||
selector_last().call(vec![col("f64_value"), col("time"), col("bool_value"), col("f64_not_normal_3_value"), col("i64_2_value")]),
|
||||
vec![
|
||||
"+-------------------------------------------------------------------------------------------+",
|
||||
"| selector_last(t.f64_value,t.time,t.bool_value,t.f64_not_normal_3_value,t.i64_2_value) |",
|
||||
"+-------------------------------------------------------------------------------------------+",
|
||||
"| {value: 3.0, time: 1970-01-01T00:00:00.000006, other_1: false, other_2: NaN, other_3: 30} |",
|
||||
"+-------------------------------------------------------------------------------------------+",
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
run_case(
|
||||
selector_last().call(vec![col("u64_2_value"), col("time"), col("bool_value"), col("f64_not_normal_4_value"), col("i64_2_value")]),
|
||||
vec![
|
||||
"+------------------------------------------------------------------------------------------+",
|
||||
"| selector_last(t.u64_2_value,t.time,t.bool_value,t.f64_not_normal_4_value,t.i64_2_value) |",
|
||||
"+------------------------------------------------------------------------------------------+",
|
||||
"| {value: 50, time: 1970-01-01T00:00:00.000005, other_1: false, other_2: inf, other_3: 50} |",
|
||||
"+------------------------------------------------------------------------------------------+",
|
||||
],
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_time_tie_breaker() {
|
||||
run_case(
|
||||
selector_last().call(vec![col("f64_value"), col("time_dup")]),
|
||||
vec![
|
||||
"+------------------------------------------------+",
|
||||
"| selector_last(t.f64_value,t.time_dup) |",
|
||||
"+------------------------------------------------+",
|
||||
"| {value: 5.0, time: 1970-01-01T00:00:00.000003} |",
|
||||
"+------------------------------------------------+",
|
||||
],
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_err() {
|
||||
run_cases_err(selector_last(), "selector_last").await;
|
||||
}
|
||||
}
|
||||
|
||||
mod min {
|
||||
|
@ -845,6 +801,53 @@ mod test {
|
|||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_with_other() {
|
||||
run_case(
|
||||
selector_min().call(vec![col("u64_value"), col("time"), col("bool_value"), col("f64_not_normal_1_value"), col("i64_2_value")]),
|
||||
vec![
|
||||
"+---------------------------------------------------------------------------------------+",
|
||||
"| selector_min(t.u64_value,t.time,t.bool_value,t.f64_not_normal_1_value,t.i64_2_value) |",
|
||||
"+---------------------------------------------------------------------------------------+",
|
||||
"| {value: 10, time: 1970-01-01T00:00:00.000004, other_1: true, other_2: NaN, other_3: } |",
|
||||
"+---------------------------------------------------------------------------------------+",
|
||||
],
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_time_tie_breaker() {
|
||||
run_case(
|
||||
selector_min().call(vec![col("f64_not_normal_2_value"), col("time_dup")]),
|
||||
vec![
|
||||
"+---------------------------------------------------+",
|
||||
"| selector_min(t.f64_not_normal_2_value,t.time_dup) |",
|
||||
"+---------------------------------------------------+",
|
||||
"| {value: -inf, time: 1970-01-01T00:00:00.000001} |",
|
||||
"+---------------------------------------------------+",
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
run_case(
|
||||
selector_min().call(vec![col("bool_const"), col("time_dup")]),
|
||||
vec![
|
||||
"+-------------------------------------------------+",
|
||||
"| selector_min(t.bool_const,t.time_dup) |",
|
||||
"+-------------------------------------------------+",
|
||||
"| {value: true, time: 1970-01-01T00:00:00.000001} |",
|
||||
"+-------------------------------------------------+",
|
||||
],
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_err() {
|
||||
run_cases_err(selector_min(), "selector_min").await;
|
||||
}
|
||||
}
|
||||
|
||||
mod max {
|
||||
|
@ -972,6 +975,53 @@ mod test {
|
|||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_with_other() {
|
||||
run_case(
|
||||
selector_max().call(vec![col("u64_value"), col("time"), col("bool_value"), col("f64_not_normal_1_value"), col("i64_2_value")]),
|
||||
vec![
|
||||
"+------------------------------------------------------------------------------------------+",
|
||||
"| selector_max(t.u64_value,t.time,t.bool_value,t.f64_not_normal_1_value,t.i64_2_value) |",
|
||||
"+------------------------------------------------------------------------------------------+",
|
||||
"| {value: 50, time: 1970-01-01T00:00:00.000005, other_1: false, other_2: inf, other_3: 50} |",
|
||||
"+------------------------------------------------------------------------------------------+",
|
||||
],
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_time_tie_breaker() {
|
||||
run_case(
|
||||
selector_max().call(vec![col("f64_not_normal_2_value"), col("time_dup")]),
|
||||
vec![
|
||||
"+---------------------------------------------------+",
|
||||
"| selector_max(t.f64_not_normal_2_value,t.time_dup) |",
|
||||
"+---------------------------------------------------+",
|
||||
"| {value: inf, time: 1970-01-01T00:00:00.000002} |",
|
||||
"+---------------------------------------------------+",
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
run_case(
|
||||
selector_max().call(vec![col("bool_const"), col("time_dup")]),
|
||||
vec![
|
||||
"+-------------------------------------------------+",
|
||||
"| selector_max(t.bool_const,t.time_dup) |",
|
||||
"+-------------------------------------------------+",
|
||||
"| {value: true, time: 1970-01-01T00:00:00.000001} |",
|
||||
"+-------------------------------------------------+",
|
||||
],
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_err() {
|
||||
run_cases_err(selector_max(), "selector_max").await;
|
||||
}
|
||||
}
|
||||
|
||||
mod utils {
|
||||
|
@ -991,7 +1041,7 @@ mod test {
|
|||
);
|
||||
}
|
||||
|
||||
pub async fn run_case_err(expr: Expr, expected: &'static str) {
|
||||
pub async fn run_case_err(expr: Expr, expected: &str) {
|
||||
println!("Running error case for {expr}");
|
||||
|
||||
let (schema, input) = input();
|
||||
|
@ -1006,6 +1056,38 @@ mod test {
|
|||
);
|
||||
}
|
||||
|
||||
pub async fn run_cases_err(selector: AggregateUDF, name: &str) {
|
||||
run_case_err(
|
||||
selector.call(vec![]),
|
||||
&format!("Error during planning: {name} requires at least 2 arguments, got 0"),
|
||||
)
|
||||
.await;
|
||||
|
||||
run_case_err(
|
||||
selector.call(vec![col("f64_value")]),
|
||||
&format!("Error during planning: {name} requires at least 2 arguments, got 1"),
|
||||
)
|
||||
.await;
|
||||
|
||||
run_case_err(
|
||||
selector.call(vec![col("time"), col("f64_value")]),
|
||||
&format!("Error during planning: {name} second argument must be a timestamp, but got Float64"),
|
||||
)
|
||||
.await;
|
||||
|
||||
run_case_err(
|
||||
selector.call(vec![col("time"), col("f64_value"), col("bool_value")]),
|
||||
&format!("Error during planning: {name} second argument must be a timestamp, but got Float64"),
|
||||
)
|
||||
.await;
|
||||
|
||||
run_case_err(
|
||||
selector.call(vec![col("f64_value"), col("bool_value"), col("time")]),
|
||||
&format!("Error during planning: {name} second argument must be a timestamp, but got Boolean"),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
fn input() -> (SchemaRef, Vec<RecordBatch>) {
|
||||
// define a schema for input
|
||||
// (value) and timestamp
|
||||
|
@ -1019,9 +1101,12 @@ mod test {
|
|||
Field::new("i64_value", DataType::Int64, true),
|
||||
Field::new("i64_2_value", DataType::Int64, true),
|
||||
Field::new("u64_value", DataType::UInt64, true),
|
||||
Field::new("u64_2_value", DataType::UInt64, true),
|
||||
Field::new("string_value", DataType::Utf8, true),
|
||||
Field::new("bool_value", DataType::Boolean, true),
|
||||
Field::new("bool_const", DataType::Boolean, true),
|
||||
Field::new("time", TIME_DATA_TYPE(), true),
|
||||
Field::new("time_dup", TIME_DATA_TYPE(), true),
|
||||
]));
|
||||
|
||||
// define data in two partitions
|
||||
|
@ -1057,9 +1142,12 @@ mod test {
|
|||
Arc::new(Int64Array::from(vec![Some(20), Some(40), None])),
|
||||
Arc::new(Int64Array::from(vec![None, None, None])),
|
||||
Arc::new(UInt64Array::from(vec![Some(20), Some(40), None])),
|
||||
Arc::new(UInt64Array::from(vec![Some(20), Some(40), None])),
|
||||
Arc::new(StringArray::from(vec![Some("two"), Some("four"), None])),
|
||||
Arc::new(BooleanArray::from(vec![Some(true), Some(false), None])),
|
||||
Arc::new(BooleanArray::from(vec![Some(true), Some(true), Some(true)])),
|
||||
Arc::new(TimestampNanosecondArray::from(vec![1000, 2000, 3000])),
|
||||
Arc::new(TimestampNanosecondArray::from(vec![1000, 1000, 2000])),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
|
@ -1077,8 +1165,11 @@ mod test {
|
|||
Arc::new(Int64Array::from(vec![] as Vec<Option<i64>>)),
|
||||
Arc::new(Int64Array::from(vec![] as Vec<Option<i64>>)),
|
||||
Arc::new(UInt64Array::from(vec![] as Vec<Option<u64>>)),
|
||||
Arc::new(UInt64Array::from(vec![] as Vec<Option<u64>>)),
|
||||
Arc::new(StringArray::from(vec![] as Vec<Option<&str>>)),
|
||||
Arc::new(BooleanArray::from(vec![] as Vec<Option<bool>>)),
|
||||
Arc::new(BooleanArray::from(vec![] as Vec<Option<bool>>)),
|
||||
Arc::new(TimestampNanosecondArray::from(vec![] as Vec<i64>)),
|
||||
Arc::new(TimestampNanosecondArray::from(vec![] as Vec<i64>)),
|
||||
],
|
||||
) {
|
||||
|
@ -1118,6 +1209,7 @@ mod test {
|
|||
Arc::new(Int64Array::from(vec![Some(10), Some(50), Some(30)])),
|
||||
Arc::new(Int64Array::from(vec![None, Some(50), Some(30)])),
|
||||
Arc::new(UInt64Array::from(vec![Some(10), Some(50), Some(30)])),
|
||||
Arc::new(UInt64Array::from(vec![Some(10), Some(50), None])),
|
||||
Arc::new(StringArray::from(vec![
|
||||
Some("a_one"),
|
||||
Some("z_five"),
|
||||
|
@ -1128,7 +1220,9 @@ mod test {
|
|||
Some(false),
|
||||
Some(false),
|
||||
])),
|
||||
Arc::new(BooleanArray::from(vec![Some(true), Some(true), Some(true)])),
|
||||
Arc::new(TimestampNanosecondArray::from(vec![4000, 5000, 6000])),
|
||||
Arc::new(TimestampNanosecondArray::from(vec![2000, 3000, 3000])),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
|
|
|
@ -24,234 +24,30 @@ use datafusion::{
|
|||
|
||||
use super::type_handling::make_struct_scalar;
|
||||
|
||||
/// Implements the logic of the specific selector function (this is a
|
||||
/// cutdown version of the Accumulator DataFusion trait, to allow
|
||||
/// sharing between implementations)
|
||||
pub trait Selector: Debug + Send + Sync {
|
||||
/// return state in a form that DataFusion can store during execution
|
||||
fn datafusion_state(&self) -> DataFusionResult<Vec<ScalarValue>>;
|
||||
|
||||
/// produces the final value of this selector for the specified output type
|
||||
fn evaluate(&self) -> DataFusionResult<ScalarValue>;
|
||||
|
||||
/// Update this selector's state based on values in value_arr and time_arr
|
||||
fn update_batch(
|
||||
&mut self,
|
||||
value_arr: &ArrayRef,
|
||||
time_arr: &ArrayRef,
|
||||
other_arrs: &[ArrayRef],
|
||||
) -> DataFusionResult<()>;
|
||||
|
||||
/// Allocated size required for this selector, in bytes,
|
||||
/// including `Self`. Allocated means that for internal
|
||||
/// containers such as `Vec`, the `capacity` should be used not
|
||||
/// the `len`
|
||||
fn size(&self) -> usize;
|
||||
/// How to compare values/time.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Comparison {
|
||||
Min,
|
||||
Max,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FirstSelector {
|
||||
value: ScalarValue,
|
||||
time: Option<i64>,
|
||||
other: Box<[ScalarValue]>,
|
||||
}
|
||||
|
||||
impl FirstSelector {
|
||||
pub fn new<'a>(
|
||||
data_type: &'a DataType,
|
||||
other_types: impl IntoIterator<Item = &'a DataType>,
|
||||
) -> DataFusionResult<Self> {
|
||||
Ok(Self {
|
||||
value: ScalarValue::try_from(data_type)?,
|
||||
time: None,
|
||||
other: other_types
|
||||
.into_iter()
|
||||
.map(ScalarValue::try_from)
|
||||
.collect::<DataFusionResult<_>>()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Selector for FirstSelector {
|
||||
fn datafusion_state(&self) -> DataFusionResult<Vec<ScalarValue>> {
|
||||
Ok([
|
||||
self.value.clone(),
|
||||
ScalarValue::TimestampNanosecond(self.time, None),
|
||||
]
|
||||
.into_iter()
|
||||
.chain(self.other.iter().cloned())
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn evaluate(&self) -> DataFusionResult<ScalarValue> {
|
||||
Ok(make_struct_scalar(
|
||||
&self.value,
|
||||
&ScalarValue::TimestampNanosecond(self.time, None),
|
||||
self.other.iter(),
|
||||
))
|
||||
}
|
||||
|
||||
fn update_batch(
|
||||
&mut self,
|
||||
value_arr: &ArrayRef,
|
||||
time_arr: &ArrayRef,
|
||||
other_arrs: &[ArrayRef],
|
||||
) -> DataFusionResult<()> {
|
||||
// Only look for times where the array also has a non
|
||||
// null value (the time array should have no nulls itself)
|
||||
//
|
||||
// For example, for the following input, the correct
|
||||
// current min time is 200 (not 100)
|
||||
//
|
||||
// value | time
|
||||
// --------------
|
||||
// NULL | 100
|
||||
// A | 200
|
||||
// B | 300
|
||||
//
|
||||
let time_arr = arrow::compute::nullif(time_arr, &arrow::compute::is_null(&value_arr)?)?;
|
||||
|
||||
let time_arr = time_arr
|
||||
.as_any()
|
||||
.downcast_ref::<TimestampNanosecondArray>()
|
||||
// the input type arguments should be ensured by datafusion
|
||||
.expect("Second argument was time");
|
||||
let cur_min_time = array_min(time_arr);
|
||||
|
||||
let need_update = match (&self.time, &cur_min_time) {
|
||||
(Some(time), Some(cur_min_time)) => cur_min_time < time,
|
||||
// No existing minimum, so update needed
|
||||
(None, Some(_)) => true,
|
||||
// No actual minimum time found, so no update needed
|
||||
(_, None) => false,
|
||||
};
|
||||
|
||||
if need_update {
|
||||
let index = time_arr
|
||||
.iter()
|
||||
// arrow doesn't tell us what index had the
|
||||
// minimum, so need to find it ourselves see also
|
||||
// https://github.com/apache/arrow-datafusion/issues/600
|
||||
.enumerate()
|
||||
.find(|(_, time)| cur_min_time == *time)
|
||||
.map(|(idx, _)| idx)
|
||||
.unwrap(); // value always exists
|
||||
|
||||
// update all or nothing in case of an error
|
||||
let value_new = ScalarValue::try_from_array(&value_arr, index)?;
|
||||
let other_new = other_arrs
|
||||
.iter()
|
||||
.map(|arr| ScalarValue::try_from_array(arr, index))
|
||||
.collect::<DataFusionResult<_>>()?;
|
||||
|
||||
self.time = cur_min_time;
|
||||
self.value = value_new;
|
||||
self.other = other_new;
|
||||
impl Comparison {
|
||||
fn is_update<T>(&self, old: &T, new: &T) -> bool
|
||||
where
|
||||
T: PartialOrd,
|
||||
{
|
||||
match self {
|
||||
Self::Min => new < old,
|
||||
Self::Max => old < new,
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
std::mem::size_of_val(self) - std::mem::size_of_val(&self.value) + self.value.size()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LastSelector {
|
||||
value: ScalarValue,
|
||||
time: Option<i64>,
|
||||
}
|
||||
|
||||
impl LastSelector {
|
||||
pub fn new(data_type: &DataType) -> DataFusionResult<Self> {
|
||||
Ok(Self {
|
||||
value: ScalarValue::try_from(data_type)?,
|
||||
time: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Selector for LastSelector {
|
||||
fn datafusion_state(&self) -> DataFusionResult<Vec<ScalarValue>> {
|
||||
Ok(vec![
|
||||
self.value.clone(),
|
||||
ScalarValue::TimestampNanosecond(self.time, None),
|
||||
])
|
||||
}
|
||||
|
||||
fn evaluate(&self) -> DataFusionResult<ScalarValue> {
|
||||
Ok(make_struct_scalar(
|
||||
&self.value,
|
||||
&ScalarValue::TimestampNanosecond(self.time, None),
|
||||
[],
|
||||
))
|
||||
}
|
||||
|
||||
fn update_batch(
|
||||
&mut self,
|
||||
value_arr: &ArrayRef,
|
||||
time_arr: &ArrayRef,
|
||||
other_arrs: &[ArrayRef],
|
||||
) -> DataFusionResult<()> {
|
||||
if !other_arrs.is_empty() {
|
||||
return Err(DataFusionError::NotImplemented(
|
||||
"selector last w/ additional args".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Only look for times where the array also has a non
|
||||
// null value (the time array should have no nulls itself)
|
||||
//
|
||||
// For example, for the following input, the correct
|
||||
// current max time is 200 (not 300)
|
||||
//
|
||||
// value | time
|
||||
// --------------
|
||||
// A | 100
|
||||
// B | 200
|
||||
// NULL | 300
|
||||
//
|
||||
let time_arr = arrow::compute::nullif(time_arr, &arrow::compute::is_null(&value_arr)?)?;
|
||||
|
||||
let time_arr = time_arr
|
||||
.as_any()
|
||||
.downcast_ref::<TimestampNanosecondArray>()
|
||||
// the input type arguments should be ensured by datafusion
|
||||
.expect("Second argument was time");
|
||||
let cur_max_time = array_max(time_arr);
|
||||
|
||||
let need_update = match (&self.time, &cur_max_time) {
|
||||
(Some(time), Some(cur_max_time)) => time < cur_max_time,
|
||||
// No existing maximum, so update needed
|
||||
(None, Some(_)) => true,
|
||||
// No actual maximum value found, so no update needed
|
||||
(_, None) => false,
|
||||
};
|
||||
|
||||
if need_update {
|
||||
let index = time_arr
|
||||
.iter()
|
||||
// arrow doesn't tell us what index had the
|
||||
// maximum, so need to find it ourselves
|
||||
.enumerate()
|
||||
.find(|(_, time)| cur_max_time == *time)
|
||||
.map(|(idx, _)| idx)
|
||||
.unwrap(); // value always exists
|
||||
|
||||
// update all or nothing in case of an error
|
||||
let value_new = ScalarValue::try_from_array(&value_arr, index)?;
|
||||
|
||||
self.time = cur_max_time;
|
||||
self.value = value_new;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
std::mem::size_of_val(self) - std::mem::size_of_val(&self.value) + self.value.size()
|
||||
}
|
||||
/// What to compare?
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Target {
|
||||
Time,
|
||||
Value,
|
||||
}
|
||||
|
||||
/// Did we find a new min/max
|
||||
|
@ -270,6 +66,7 @@ impl ActionNeeded {
|
|||
Self::Nothing => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn update_time(&self) -> bool {
|
||||
match self {
|
||||
Self::UpdateValueAndTime => true,
|
||||
|
@ -279,184 +76,129 @@ impl ActionNeeded {
|
|||
}
|
||||
}
|
||||
|
||||
/// Common state implementation for different selectors.
|
||||
#[derive(Debug)]
|
||||
pub struct MinSelector {
|
||||
pub struct Selector {
|
||||
comp: Comparison,
|
||||
target: Target,
|
||||
value: ScalarValue,
|
||||
time: Option<i64>,
|
||||
other: Box<[ScalarValue]>,
|
||||
}
|
||||
|
||||
impl MinSelector {
|
||||
pub fn new(data_type: &DataType) -> DataFusionResult<Self> {
|
||||
impl Selector {
|
||||
pub fn new<'a>(
|
||||
comp: Comparison,
|
||||
target: Target,
|
||||
data_type: &'a DataType,
|
||||
other_types: impl IntoIterator<Item = &'a DataType>,
|
||||
) -> DataFusionResult<Self> {
|
||||
Ok(Self {
|
||||
comp,
|
||||
target,
|
||||
value: ScalarValue::try_from(data_type)?,
|
||||
time: None,
|
||||
other: other_types
|
||||
.into_iter()
|
||||
.map(ScalarValue::try_from)
|
||||
.collect::<DataFusionResult<_>>()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Selector for MinSelector {
|
||||
fn datafusion_state(&self) -> DataFusionResult<Vec<ScalarValue>> {
|
||||
Ok(vec![
|
||||
self.value.clone(),
|
||||
ScalarValue::TimestampNanosecond(self.time, None),
|
||||
])
|
||||
}
|
||||
|
||||
fn evaluate(&self) -> DataFusionResult<ScalarValue> {
|
||||
Ok(make_struct_scalar(
|
||||
&self.value,
|
||||
&ScalarValue::TimestampNanosecond(self.time, None),
|
||||
[],
|
||||
))
|
||||
}
|
||||
|
||||
fn update_batch(
|
||||
fn update_time_based(
|
||||
&mut self,
|
||||
value_arr: &ArrayRef,
|
||||
time_arr: &ArrayRef,
|
||||
other_arrs: &[ArrayRef],
|
||||
) -> DataFusionResult<()> {
|
||||
use ActionNeeded::*;
|
||||
|
||||
if !other_arrs.is_empty() {
|
||||
return Err(DataFusionError::NotImplemented(
|
||||
"selector min w/ additional args".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut min_accu = MinAccumulator::try_new(value_arr.data_type())?;
|
||||
min_accu.update_batch(&[Arc::clone(value_arr)])?;
|
||||
let cur_min_value = min_accu.evaluate()?;
|
||||
|
||||
let action_needed = match (self.value.is_null(), cur_min_value.is_null()) {
|
||||
(false, false) => {
|
||||
if cur_min_value < self.value {
|
||||
// new minimim found
|
||||
UpdateValueAndTime
|
||||
} else if cur_min_value == self.value {
|
||||
// same minimum found, time might need update
|
||||
UpdateTime
|
||||
} else {
|
||||
Nothing
|
||||
}
|
||||
}
|
||||
// No existing minimum time, so update needed
|
||||
(true, false) => UpdateValueAndTime,
|
||||
// No actual minimum time found, so no update needed
|
||||
(_, true) => Nothing,
|
||||
};
|
||||
|
||||
if action_needed.update_value() {
|
||||
self.value = cur_min_value;
|
||||
self.time = None; // ignore time associated with old value
|
||||
}
|
||||
|
||||
if action_needed.update_time() {
|
||||
// only keep values where we've found our current value.
|
||||
// Note: We MUST also mask-out NULLs in `value_arr`, otherwise we may easily select that!
|
||||
let time_arr = arrow::compute::nullif(
|
||||
time_arr,
|
||||
&arrow::compute::neq_dyn(&self.value.to_array_of_size(time_arr.len()), &value_arr)?,
|
||||
)?;
|
||||
let time_arr =
|
||||
arrow::compute::nullif(&time_arr, &arrow::compute::is_null(&value_arr)?)?;
|
||||
|
||||
let time_arr = time_arr
|
||||
.as_any()
|
||||
.downcast_ref::<TimestampNanosecondArray>()
|
||||
// the input type arguments should be ensured by datafusion
|
||||
.expect("Second argument was time");
|
||||
self.time = match (array_min(time_arr), self.time) {
|
||||
(Some(x), Some(y)) if x < y => Some(x),
|
||||
(Some(_), Some(x)) => Some(x),
|
||||
(None, Some(x)) => Some(x),
|
||||
(Some(x), None) => Some(x),
|
||||
(None, None) => None,
|
||||
};
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
std::mem::size_of_val(self) - std::mem::size_of_val(&self.value) + self.value.size()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MaxSelector {
|
||||
value: ScalarValue,
|
||||
time: Option<i64>,
|
||||
}
|
||||
|
||||
impl MaxSelector {
|
||||
pub fn new(data_type: &DataType) -> DataFusionResult<Self> {
|
||||
Ok(Self {
|
||||
value: ScalarValue::try_from(data_type)?,
|
||||
time: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Selector for MaxSelector {
|
||||
fn datafusion_state(&self) -> DataFusionResult<Vec<ScalarValue>> {
|
||||
Ok(vec![
|
||||
self.value.clone(),
|
||||
ScalarValue::TimestampNanosecond(self.time, None),
|
||||
])
|
||||
}
|
||||
|
||||
fn evaluate(&self) -> DataFusionResult<ScalarValue> {
|
||||
Ok(make_struct_scalar(
|
||||
&self.value,
|
||||
&ScalarValue::TimestampNanosecond(self.time, None),
|
||||
[],
|
||||
))
|
||||
}
|
||||
|
||||
fn update_batch(
|
||||
&mut self,
|
||||
value_arr: &ArrayRef,
|
||||
time_arr: &ArrayRef,
|
||||
other_arrs: &[ArrayRef],
|
||||
) -> DataFusionResult<()> {
|
||||
use ActionNeeded::*;
|
||||
|
||||
if !other_arrs.is_empty() {
|
||||
return Err(DataFusionError::NotImplemented(
|
||||
"selector max w/ additional args".to_string(),
|
||||
));
|
||||
}
|
||||
let time_arr = arrow::compute::nullif(time_arr, &arrow::compute::is_null(&value_arr)?)?;
|
||||
|
||||
let time_arr = time_arr
|
||||
.as_any()
|
||||
.downcast_ref::<TimestampNanosecondArray>()
|
||||
// the input type arguments should be ensured by datafusion
|
||||
.expect("Second argument was time");
|
||||
let cur_time = match self.comp {
|
||||
Comparison::Min => array_min(time_arr),
|
||||
Comparison::Max => array_max(time_arr),
|
||||
};
|
||||
|
||||
let mut max_accu = MaxAccumulator::try_new(value_arr.data_type())?;
|
||||
max_accu.update_batch(&[Arc::clone(value_arr)])?;
|
||||
let cur_max_value = max_accu.evaluate()?;
|
||||
let need_update = match (&self.time, &cur_time) {
|
||||
(Some(time), Some(cur_time)) => self.comp.is_update(time, cur_time),
|
||||
// No existing min/max, so update needed
|
||||
(None, Some(_)) => true,
|
||||
// No actual min/max time found, so no update needed
|
||||
(_, None) => false,
|
||||
};
|
||||
|
||||
let action_needed = match (&self.value.is_null(), &cur_max_value.is_null()) {
|
||||
if need_update {
|
||||
let index = time_arr
|
||||
.iter()
|
||||
// arrow doesn't tell us what index had the
|
||||
// min/max, so need to find it ourselves
|
||||
.enumerate()
|
||||
.filter(|(_, time)| cur_time == *time)
|
||||
.map(|(idx, _)| idx)
|
||||
// break tie: favor first value
|
||||
.next()
|
||||
.unwrap(); // value always exists
|
||||
|
||||
// update all or nothing in case of an error
|
||||
let value_new = ScalarValue::try_from_array(&value_arr, index)?;
|
||||
let other_new = other_arrs
|
||||
.iter()
|
||||
.map(|arr| ScalarValue::try_from_array(arr, index))
|
||||
.collect::<DataFusionResult<_>>()?;
|
||||
|
||||
self.time = cur_time;
|
||||
self.value = value_new;
|
||||
self.other = other_new;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn update_value_based(
|
||||
&mut self,
|
||||
value_arr: &ArrayRef,
|
||||
time_arr: &ArrayRef,
|
||||
other_arrs: &[ArrayRef],
|
||||
) -> DataFusionResult<()> {
|
||||
use ActionNeeded::*;
|
||||
|
||||
let cur_value = match self.comp {
|
||||
Comparison::Min => {
|
||||
let mut min_accu = MinAccumulator::try_new(value_arr.data_type())?;
|
||||
min_accu.update_batch(&[Arc::clone(value_arr)])?;
|
||||
min_accu.evaluate()?
|
||||
}
|
||||
Comparison::Max => {
|
||||
let mut max_accu = MaxAccumulator::try_new(value_arr.data_type())?;
|
||||
max_accu.update_batch(&[Arc::clone(value_arr)])?;
|
||||
max_accu.evaluate()?
|
||||
}
|
||||
};
|
||||
|
||||
let action_needed = match (&self.value.is_null(), &cur_value.is_null()) {
|
||||
(false, false) => {
|
||||
if self.value < cur_max_value {
|
||||
// new maximum found
|
||||
if self.comp.is_update(&self.value, &cur_value) {
|
||||
// new min/max found
|
||||
UpdateValueAndTime
|
||||
} else if cur_max_value == self.value {
|
||||
} else if cur_value == self.value {
|
||||
// same maximum found, time might need update
|
||||
UpdateTime
|
||||
} else {
|
||||
Nothing
|
||||
}
|
||||
}
|
||||
// No existing maxmimum value, so update needed
|
||||
// No existing min/max value, so update needed
|
||||
(true, false) => UpdateValueAndTime,
|
||||
// No actual maximum value found, so no update needed
|
||||
// No actual min/max value found, so no update needed
|
||||
(_, true) => Nothing,
|
||||
};
|
||||
|
||||
if action_needed.update_value() {
|
||||
self.value = cur_max_value;
|
||||
self.value = cur_value;
|
||||
self.time = None; // ignore time associated with old value
|
||||
}
|
||||
|
||||
|
@ -479,7 +221,7 @@ impl Selector for MaxSelector {
|
|||
// the input type arguments should be ensured by datafusion
|
||||
.expect("Second argument was time");
|
||||
|
||||
// Note: we still use the MINIMUM timestamp here even though this is the max VALUE aggregator.
|
||||
// Note: we still use the MINIMUM timestamp here even if this is the max VALUE aggregator.
|
||||
self.time = match (array_min(time_arr), self.time) {
|
||||
(Some(x), Some(y)) if x < y => Some(x),
|
||||
(Some(_), Some(x)) => Some(x),
|
||||
|
@ -487,11 +229,82 @@ impl Selector for MaxSelector {
|
|||
(Some(x), None) => Some(x),
|
||||
(None, None) => None,
|
||||
};
|
||||
|
||||
// update other if required
|
||||
if !self.other.is_empty() {
|
||||
let index = time_arr
|
||||
.iter()
|
||||
// arrow doesn't tell us what index had the
|
||||
// minimum, so need to find it ourselves
|
||||
.enumerate()
|
||||
.filter(|(_, time)| self.time == *time)
|
||||
.map(|(idx, _)| idx)
|
||||
// break tie: favor first value
|
||||
.next()
|
||||
.unwrap(); // value always exists
|
||||
|
||||
self.other = other_arrs
|
||||
.iter()
|
||||
.map(|arr| ScalarValue::try_from_array(arr, index))
|
||||
.collect::<DataFusionResult<_>>()?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Accumulator for Selector {
|
||||
fn state(&self) -> DataFusionResult<Vec<ScalarValue>> {
|
||||
Ok([
|
||||
self.value.clone(),
|
||||
ScalarValue::TimestampNanosecond(self.time, None),
|
||||
]
|
||||
.into_iter()
|
||||
.chain(self.other.iter().cloned())
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn update_batch(&mut self, values: &[ArrayRef]) -> DataFusionResult<()> {
|
||||
if values.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if values.len() < 2 {
|
||||
return Err(DataFusionError::Internal(format!(
|
||||
"Internal error: Expected at least 2 arguments passed to selector function but got {}",
|
||||
values.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let value_arr = &values[0];
|
||||
let time_arr = &values[1];
|
||||
let other_arrs = &values[2..];
|
||||
|
||||
match self.target {
|
||||
Target::Time => self.update_time_based(value_arr, time_arr, other_arrs)?,
|
||||
Target::Value => self.update_value_based(value_arr, time_arr, other_arrs)?,
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn merge_batch(&mut self, states: &[ArrayRef]) -> DataFusionResult<()> {
|
||||
// merge is the same operation as update for these selectors
|
||||
self.update_batch(states)
|
||||
}
|
||||
|
||||
fn evaluate(&self) -> DataFusionResult<ScalarValue> {
|
||||
Ok(make_struct_scalar(
|
||||
&self.value,
|
||||
&ScalarValue::TimestampNanosecond(self.time, None),
|
||||
self.other.iter(),
|
||||
))
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
std::mem::size_of_val(self) - std::mem::size_of_val(&self.value) + self.value.size()
|
||||
std::mem::size_of_val(self) - std::mem::size_of_val(&self.value)
|
||||
+ self.value.size()
|
||||
+ self.other.iter().map(|s| s.size()).sum::<usize>()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue