Merge branch 'iox-repo'

pull/24376/head
Paul Dix 2023-09-21 09:22:15 -04:00
commit aa458ed166
1481 changed files with 384798 additions and 1177 deletions

25
.cargo/config Normal file
View File

@ -0,0 +1,25 @@
[build]
# enable tokio-console and some other goodies
rustflags = [
"--cfg", "tokio_unstable",
]
# sparse protocol opt-in
# See https://blog.rust-lang.org/2023/03/09/Rust-1.68.0.html#cargos-sparse-protocol
[registries.crates-io]
protocol = "sparse"
[target.x86_64-unknown-linux-gnu]
rustflags = [
# see above
"--cfg", "tokio_unstable",
# Faster linker.
"-C", "link-arg=-fuse-ld=lld",
# Fix `perf` as suggested by https://github.com/flamegraph-rs/flamegraph/blob/2d19a162df4066f37d58d5471634f0bd9f0f4a62/README.md?plain=1#L18
# Also see https://bugs.chromium.org/p/chromium/issues/detail?id=919499#c16
"-C", "link-arg=-Wl,--no-rosegment",
# Enable all features supported by CPUs more recent than haswell (2013)
"-C", "target-cpu=haswell",
# Enable framepointers because profiling and debugging is a nightmare w/o it and it is generally not considered a performance advantage on modern x86_64 CPUs.
"-C", "force-frame-pointers=yes",
]

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,31 @@
#!/usr/bin/env bash
#
# Build the production Docker image for a single IOx package.
#
# Usage: docker_build <package> <features> <tag>
set -euo pipefail

readonly PACKAGE="$1"
readonly FEATURES="$2"
readonly TAG="$3"

# Pin the image's Rust toolchain to the channel declared in rust-toolchain.toml.
RUST_VERSION="$(sed -E -ne 's/channel = "(.*)"/\1/p' rust-toolchain.toml)"

# Provenance metadata baked into the image's OCI labels.
COMMIT_SHA="$(git rev-parse HEAD)"
COMMIT_TS="$(env TZ=UTC0 git show --quiet --date='format-local:%Y-%m-%dT%H:%M:%SZ' --format="%cd" HEAD)"
NOW="$(date --utc --iso-8601=seconds)"
REPO_URL="https://github.com/influxdata/influxdb_iox"

# Assemble the buildx invocation as an argument array so each flag reads on
# its own line, then replace this shell with the docker process.
build_args=(
    --build-arg CARGO_INCREMENTAL="no"
    --build-arg CARGO_NET_GIT_FETCH_WITH_CLI="true"
    --build-arg FEATURES="$FEATURES"
    --build-arg RUST_VERSION="$RUST_VERSION"
    --build-arg PACKAGE="$PACKAGE"
    --label org.opencontainers.image.created="$NOW"
    --label org.opencontainers.image.url="$REPO_URL"
    --label org.opencontainers.image.revision="$COMMIT_SHA"
    --label org.opencontainers.image.vendor="InfluxData Inc."
    --label org.opencontainers.image.title="InfluxDB IOx, '$PACKAGE'"
    --label org.opencontainers.image.description="InfluxDB IOx production image for package '$PACKAGE'"
    --label com.influxdata.image.commit-date="$COMMIT_TS"
    --label com.influxdata.image.package="$PACKAGE"
    --progress plain
    --tag "$TAG"
)

exec docker buildx build "${build_args[@]}" .

3
.circleci/yamllint.yml Normal file
View File

@ -0,0 +1,3 @@
rules:
truthy:
check-keys: false

61
.config/hakari.toml Normal file
View File

@ -0,0 +1,61 @@
# This file contains settings for `cargo hakari`.
# See https://docs.rs/cargo-hakari/latest/cargo_hakari/config for a full list of options.
hakari-package = "workspace-hack"
# Format version for hakari's output. Version 4 requires cargo-hakari 0.9.22 or above.
dep-format-version = "4"
# Setting workspace.resolver = "2" in the root Cargo.toml is HIGHLY recommended.
# Hakari works much better with the new feature resolver.
# For more about the new feature resolver, see:
# https://blog.rust-lang.org/2021/03/25/Rust-1.51.0.html#cargos-new-feature-resolver
resolver = "2"
# Add triples corresponding to platforms commonly used by developers here.
# https://doc.rust-lang.org/rustc/platform-support.html
platforms = [
"x86_64-unknown-linux-gnu",
"x86_64-apple-darwin",
"aarch64-apple-darwin",
"x86_64-pc-windows-msvc",
]
# Write out exact versions rather than a semver range. (Defaults to false.)
# exact-versions = true
# Don't search in these crates for dependencies, and don't have these crates depend on the
# workspace-hack crate.
#
# Lists most bench- or test-only crates. Also lists optional object_store dependencies as those are
# usually off in development, and influxdb_line_protocol which is published to crates.io separately.
[traversal-excludes]
workspace-members = [
"influxdb_iox_client",
# influxdb_line_protocol is published as a standalone crate, so don't
# depend on workspace
"influxdb-line-protocol",
"influxdb2_client",
"iox_data_generator",
"mutable_batch_tests",
]
third-party = [
{ name = "criterion" },
{ name = "pprof" },
{ name = "tikv-jemalloc-sys" },
]
#
# Packages specified in final-excludes will be removed from the output at the very end.
# This means that any transitive dependencies of theirs will still be included.
#
# Workspace crates excluded from the final output will not depend on the workspace-hack crate,
# and cargo hakari manage-deps will remove dependency edges rather than adding them.
#
[final-excludes]
workspace-members = [
# We don't want trogging to depend on workspace-hack so it can be used directly from the git repository
# but we want its tracing-subscriber feature flags to propagate into the workspace-hack
"trogging"
]

View File

@ -1,11 +1,13 @@
# EditorConfig helps us maintain consistent formatting on non-source files.
# Visit https://editorconfig.org/ for details on how to configure your editor to respect these settings.
# This is the terminal .editorconfig in this repository.
root = true
# You can't change this to * without the `checkfmt` make target failing due to many
# files that don't adhere to this.
[*.yml]
[*]
end_of_line = lf
insert_final_newline = true
trim_trailing_whitespace = true
indent_style = space
[{Dockerfile*,*.proto}]
indent_size = 2
[{*.rs,*.toml,*.sh,*.bash}]
indent_size = 4

4
.gitattributes vendored Normal file
View File

@ -0,0 +1,4 @@
generated_types/protos/google/ linguist-generated=true
generated_types/protos/grpc/ linguist-generated=true
generated_types/src/wal_generated.rs linguist-generated=true
trace_exporters/src/thrift/ linguist-generated=true

View File

@ -5,36 +5,39 @@ about: Create a report to help us improve
<!--
Thank you for reporting a bug in InfluxDB.
Thank you for reporting a bug in InfluxDB IOx.
* Please ask usage questions on the Influx Community site.
* https://community.influxdata.com/
* Please add a :+1: or comment on a similar existing bug report instead of opening a new one.
* https://github.com/influxdata/influxdb/issues?utf8=%E2%9C%93&q=is%3Aissue+is%3Aopen+is%3Aclosed+sort%3Aupdated-desc+label%3Akind%2Fbug+
* Please check whether the bug can be reproduced with the latest release.
Have you read the contributing section of the README? Please do if you haven't.
https://github.com/influxdata/influxdb_iox/blob/main/README.md
* Please ask usage questions in the Influx Slack (there is an #influxdb-iox channel).
* https://influxdata.com/slack
* Please don't open duplicate issues; use the search. If there is an existing issue please don't add "+1" or "me too" comments; only add comments with new information.
* Please check whether the bug can be reproduced with tip of main.
* The fastest way to fix a bug is to open a Pull Request.
* https://github.com/influxdata/influxdb/pulls
-->
__Steps to reproduce:__
List the minimal actions needed to reproduce the behavior.
List the minimal actions needed to reproduce the behaviour.
1. ...
2. ...
3. ...
__Expected behavior:__
__Expected behaviour:__
Describe what you expected to happen.
__Actual behavior:__
__Actual behaviour:__
Describe what actually happened.
__Environment info:__
* System info: Run `uname -srm` and copy the output here
* InfluxDB version: Run `influxd version` and copy the output here
* Other relevant environment details: Container runtime, disk info, etc
* Please provide the command you used to build the project, including any `RUSTFLAGS`.
* System info: Run `uname -srm` or similar and copy the output here (we want to know your OS, architecture etc).
* If you're running IOx in a containerised environment then details about that would be helpful.
* Other relevant environment details: disk info, hardware setup etc.
__Config:__
Copy any non-default config values here or attach the full config as a gist or file.
@ -42,15 +45,6 @@ Copy any non-default config values here or attach the full config as a gist or f
<!-- The following sections are only required if relevant. -->
__Logs:__
Include snippet of errors in log.
__Performance:__
Generate profiles with the following commands for bugs related to performance, locking, out of memory (OOM), etc.
```sh
# Commands should be run when the bug is actively happening.
# Note: This command will run for ~30 seconds.
curl -o profiles.tar.gz "http://localhost:8086/debug/pprof/all?cpu=30s"
iostat -xd 1 30 > iostat.txt
# Attach the `profiles.tar.gz` and `iostat.txt` output files.
```
Include snippet of errors in logs or stack traces here.
Sometimes you can get useful information by running the program with the `RUST_BACKTRACE=full` environment variable.
Finally, the IOx server has a `-vv` for verbose logging.

View File

@ -0,0 +1,99 @@
---
name: Developer experience problem
about: Tell us about slow builds, tests, and code editing
---
<!--
Thank you for sharing your development experience concerns with the InfluxDB IOx team.
Have you read the contributing guide?
https://github.com/influxdata/influxdb_iox/blob/main/CONTRIBUTING.md
We welcome your thoughts and feedback on the IOx developer experience.
In particular:
- Is your code editor sluggish?
- How long does your edit-build or edit-test cycle take?
- Is your cardboard sword wearing out faster than usual?
We especially welcome your experience improving your development environment.
- Which hardware tradeoffs work best?
- Which code editor and plugins have been helpful?
- Suggest build configuration or code organization tweaks?
Thank you!
-->
## Describe Development Experience Issue:
<!--
Be descriptive, write a few sentences describing the difficult (or positive) developer experience.
-->
### Steps to reproduce:
<!--
For example:
1. git checkout deadbeef
2. cargo clean
3. cargo build --offline
4. ...
-->
1. ...
2. ...
3. ...
4. ...
### Desired result:
<!--
For example:
Build time less than 5 minutes
-->
### Actual result:
<!--
For example:
Build time 15m32s
-->
## Hardware Environment:
<!--
Describe the hardware you are developing on.
-->
- Package: ... <!-- laptop, workstation, cloud VM, etc -->
- CPU: ...
- Memory: ...
- Block Device: ...
## Operating System:
<!--
If Unix-like, then the command `uname -a` is ideal.
-->
## Code Editing Tool:
<!--
Describe your IDE or code editor, with version.
Some examples:
- vim 8.2
- IntelliJ CLion 2021.2.3
- VS Code 1.61.0
-->
## Build Environment:
[ ] I'm using sccache
<!--
Describe your build environment.
Some examples:
- IDE build trigger
- IDE debug trigger
- bash: cargo build
- zsh: cargo build --release
-->

View File

@ -5,26 +5,27 @@ about: Opening a feature request kicks off a discussion
<!--
Thank you for suggesting an idea to improve InfluxDB.
Thank you for suggesting an idea to improve InfluxDB IOx.
* Please ask usage questions on the Influx Community site.
* https://community.influxdata.com/
* Please add a :+1: or comment on a similar existing feature request instead of opening a new one.
* https://github.com/influxdata/influxdb/issues?utf8=%E2%9C%93&q=is%3Aissue+is%3Aopen+is%3Aclosed+sort%3Aupdated-desc+label%3A%22kind%2Ffeature+request%22+
Have you read the contributing section of the README? Please do if you haven't.
https://github.com/influxdata/influxdb_iox/blob/main/README.md
* Please ask usage questions in the Influx Slack (there is an #influxdb-iox channel).
* https://influxdata.com/slack
* If the feature you're interested in already has an open *or* closed ticket associated with it (please search) please don't add "+1" or "me too" comments. The Github reaction emojis are a way to indicate your support for something and will contribute to prioritisation.
-->
__Use case:__
Why is this important (helps with prioritizing requests)?
__Proposal:__
Short summary of the feature.
__Current behavior:__
__Current behaviour:__
Describe what currently happens.
__Desired behavior:__
__Desired behaviour:__
Describe what you want.
__Alternatives considered:__
Describe other solutions or features you considered.
__Use case:__
Why is this important (helps with prioritizing requests)?

View File

@ -1,30 +1,6 @@
- Closes #
### Required checklist
- [ ] Sample config files updated (both `/etc` folder and `NewDemoConfig` methods) (influxdb and plutonium)
- [ ] openapi swagger.yml updated (if modified API) - link openapi PR
- [ ] Signed [CLA](https://influxdata.com/community/cla/) (if not already signed)
Closes #
### Description
1-3 sentences describing the PR (or link to well written issue)
Describe your proposed changes here.
### Context
Why was this added? What value does it add? What are risks/best practices?
### Affected areas (delete section if not relevant):
List of user-visible changes. As a user, what would I need to see in docs?
Examples:
CLI commands, subcommands, and flags
API changes
Configuration (sample config blocks)
### Severity (delete section if not relevant)
i.e., ("recommend to upgrade immediately", "upgrade at your leisure", etc.)
### Note for reviewers:
Check the semantic commit type:
- Feat: a feature with user-visible changes
- Fix: a bug fix that we might tell a user “upgrade to get this fix for your issue”
- Chore: version bumps, internal doc (e.g. README) changes, code comment updates, code formatting fixes… must not be user facing (except dependency version changes)
- Build: build script changes, CI config changes, build tool updates
- Refactor: non-user-visible refactoring
- Check the PR title: we should be able to put this as a one-liner in the release notes
- [ ] I've read the contributing section of the project [README](https://github.com/influxdata/influxdb_iox/blob/main/README.md).
- [ ] Signed [CLA](https://influxdata.com/community/cla/) (if not already signed).

22
.github/dependabot.yml vendored Normal file
View File

@ -0,0 +1,22 @@
---
version: 2
updates:
- package-ecosystem: "cargo"
directory: "/"
schedule:
interval: daily
open-pull-requests-limit: 10
ignore:
# Thrift version needs to match the version of the thrift-compiler used to generate code,
# and therefore this dependency requires a more manual upgrade
#
# Additionally the thrift-compiler version available in standard repos tends to lag
# the latest release significantly, and so updating to the latest version adds friction
- dependency-name: "thrift"
# We want to update arrow and datafusion manually
- dependency-name: "arrow"
- dependency-name: "arrow-*"
- dependency-name: "parquet"
- dependency-name: "datafusion"
- dependency-name: "datafusion-*"

160
.gitignore vendored
View File

@ -1,145 +1,15 @@
# Keep editor-specific, non-project specific ignore rules in global .gitignore:
# https://help.github.com/articles/ignoring-files/#create-a-global-gitignore
vendor
.netrc
.vscode
.vs
.tern-project
.DS_Store
.idea
.cgo_ldflags
# binary databases
influxd.bolt
*.db
*.sqlite
# Files generated in CI
rustup-init.sh
private.key
# TLS keys generated for testing
test.crt
test.key
# Project distribution
/dist
# Project binaries.
/influx
/influxd
/fluxd
/transpilerd
/bin
/internal/cmd/kvmigrate/kvmigrate
# Project tools that you might install with go build.
/editorconfig-checker
/staticcheck
# Generated static assets
/static/data
/static/static_gen.go
/changelog_artifacts
# The below files are generated with make generate
# These are used with the assests go build tag.
chronograf/canned/bin_gen.go
chronograf/dist/dist_gen.go
chronograf/server/swagger_gen.go
# Ignore TSM/TSI testdata binary files
tsdb/tsi1/testdata
tsdb/testdata
# The rest of the file is the .gitignore from the original influxdb repository,
# copied here to prevent mistakenly checking in any binary files
# that may be present but previously ignored if you cloned/developed before v2.
*~
config.json
/bin/
/query/a.out*
# ignore generated files.
cmd/influxd/version.go
# executables
*.test
**/influx_tsm
!**/influx_tsm/
**/influx_stress
!**/influx_stress/
**/influxd
!**/influxd/
**/influx
!**/influx/
**/influxdb
!**/influxdb/
**/influx_inspect
!**/influx_inspect/
/benchmark-tool
/main
/benchmark-storage
godef
gosym
gocode
inspect-raft
# dependencies
out_rpm/
packages/
# autconf
autom4te.cache/
config.log
config.status
# log file
influxdb.log
benchmark.log
# config file
config.toml
# test data files
integration/migration_data/
test-logs/
# man outputs
man/*.xml
man/*.1
man/*.1.gz
# test outputs
/test-results.xml
junit-results
# profile data
/prof
# vendored files
/vendor
# DShell Ignores
.ash_history
.bash_history
.cache/
.cargo/
.dockerignore
.influxdbv2/
.profile
.rustup/
go/
goreleaser-install
**/target
**/*.rs.bk
.idea/
.env
.gdb_history
*.tsm
**/.DS_Store
**/.vscode
heaptrack.*
massif.out.*
perf.data*
perf.svg
perf.txt
valgrind-out.txt
*.pending-snap

2
.kodiak.toml Normal file
View File

@ -0,0 +1,2 @@
version = 1
merge.method = "squash"

8
.yamllint.yml Normal file
View File

@ -0,0 +1,8 @@
---
extends: default
ignore: |
target/
rules:
line-length: disable

200
CONTRIBUTING.md Normal file
View File

@ -0,0 +1,200 @@
# Contributing
Thank you for thinking of contributing! We very much welcome contributions from the community.
To make the process easier and more valuable for everyone involved we have a few rules and guidelines to follow.
Anyone with a Github account is free to file issues on the project.
However, if you want to contribute documentation or code then you will need to sign InfluxData's Individual Contributor License Agreement (CLA), which can be found with more information [on our website].
[on our website]: https://www.influxdata.com/legal/cla/
## Submitting Issues and Feature Requests
Before you file an [issue], please search existing issues in case the same or similar issues have already been filed.
If you find an existing open ticket covering your issue then please avoid adding "👍" or "me too" comments; Github notifications can cause a lot of noise for the project maintainers who triage the back-log.
However, if you have a new piece of information for an existing ticket and you think it may help the investigation or resolution, then please do add it as a comment!
You can signal to the team that you're experiencing an existing issue with one of Github's emoji reactions (these are a good way to add "weight" to an issue from a prioritisation perspective).
### Submitting an Issue
The [New Issue] page has templates for both bug reports and feature requests.
Please fill one of them out!
The issue templates provide details on what information we will find useful to help us fix an issue.
In short though, the more information you can provide us about your environment and what behaviour you're seeing, the easier we can fix the issue.
If you can push a PR with test cases that trigger a defect or bug, even better!
P.S., if you have never written a bug report before, or if you want to brush up on your bug reporting skills, we recommend reading Simon Tatham's essay [How to Report Bugs Effectively].
As well as bug reports we also welcome feature requests (there is a dedicated issue template for these).
Typically, the maintainers will periodically review community feature requests and make decisions about if we want to add them.
For features we don't plan to support we will close the feature request ticket (so, again, please check closed tickets for feature requests before submitting them).
[issue]: https://github.com/influxdata/influxdb_iox/issues/new
[New Issue]: https://github.com/influxdata/influxdb_iox/issues/new
[How to Report Bugs Effectively]: https://www.chiark.greenend.org.uk/~sgtatham/bugs.html
## Contributing Changes
InfluxDB IOx is written mostly in idiomatic Rust—please see the [Style Guide] for more details.
All code must adhere to the `rustfmt` format, and pass all of the `clippy` checks we run in CI (there are more details further down this document).
[Style Guide]: docs/style_guide.md
### Finding Issues To Work On
The [good first issue](https://github.com/influxdata/influxdb_iox/labels/good%20first%20issue) and the [help wanted](https://github.com/influxdata/influxdb_iox/labels/help%20wanted) labels are used to identify issues where we encourage community contributions.
They both indicate issues for which we would welcome independent community contributions, but the former indicates a sub-set of these that are especially good for first-time contributors.
If you want some clarifications or guidance for working on one of these issues, or you simply want to let others know that you're working on one, please leave a comment on the ticket.
[good first issue]: https://github.com/influxdata/influxdb_iox/labels/good%20first%20issue
[help wanted]: https://github.com/influxdata/influxdb_iox/labels/help%20wanted
### Bigger Changes
If you're planning to submit significant changes, even if it relates to existing tickets **please** talk to the project maintainers first!
The easiest way to do this is to open up a new ticket, describing the changes you plan to make and *why* you plan to make them. Changes that may seem obviously good to you, are not always obvious to everyone else.
Example of changes where we would encourage up-front communication:
* new IOx features;
* significant refactors that move code between modules/crates etc;
* performance improvements involving new concurrency patterns or the use of `unsafe` code;
* API-breaking changes, or changes that require a data migration;
* any changes that risk the durability or correctness of data.
We are always excited to have community involvement but we can't accept everything.
To avoid having your hard work rejected, the best approach is to start a discussion first.
Further, please don't expect us to accept significant changes without new test coverage, and/or in the case of performance changes benchmarks that show the improvements.
### Making a PR
To open a PR you will need to have a Github account.
Fork the `influxdb_iox` repo and work on a branch on your fork.
When you have completed your changes, or you want some incremental feedback make a Pull Request to InfluxDB IOx [here].
If you want to discuss some work in progress then please prefix `[WIP]` to the
PR title.
For PRs that you consider ready for review, verify the following locally before you submit it:
* you have a coherent set of logical commits, with messages conforming to the [Conventional Commits] specification;
* all the tests and/or benchmarks pass, including documentation tests;
* the code is correctly formatted and all `clippy` checks pass; and
* you haven't left any "code cruft" (commented out code blocks etc).
There are some tips on verifying the above in the [next section](#running-tests).
**After** submitting a PR, you should:
* verify that all CI status checks pass and the PR is 💚;
* ask for help on the PR if any of the status checks are 🔴, and you don't know why;
* wait patiently for one of the team to review your PR, which could take a few days.
[here]: https://github.com/influxdata/influxdb_iox/compare
[Conventional Commits]: https://www.conventionalcommits.org/en/v1.0.0/
## Running Tests
The `cargo` build tool runs tests as well. Run:
```shell
cargo test --workspace
```
### Enabling logging in tests
To enable logging to stderr during a run of `cargo test` set the Rust
`RUST_LOG` environment variable. For example, to see all INFO messages:
```shell
RUST_LOG=info cargo test --workspace
```
Since this feature uses
[`EnvFilter`](https://docs.rs/tracing-subscriber/0.2.15/tracing_subscriber/filter/struct.EnvFilter.html) internally, you
can use all the features of that crate. For example, to disable the
(somewhat noisy) logs in some h2 modules, you can use a value of
`RUST_LOG` such as:
```shell
RUST_LOG=debug,hyper::proto::h1=info,h2=info cargo test --workspace
```
See [logging.md](docs/logging.md) for more information on logging.
### End-to-End Tests
There are end-to-end tests that spin up a server and make requests via the client library and API. They can be found in `tests/end_to_end_cases`
These require additional setup as described in [testing.md](docs/testing.md).
### Visually showing explain plans
Some query plans are output in the log in [graphviz](https://graphviz.org/) format. To display them you can use the `tools/iplan` helper.
For example, if you want to display this plan:
```
// Begin DataFusion GraphViz Plan (see https://graphviz.org)
digraph {
subgraph cluster_1
{
graph[label="LogicalPlan"]
2[shape=box label="SchemaPivot"]
3[shape=box label="Projection: "]
2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]
4[shape=box label="Filter: Int64(0) LtEq #time And #time Lt Int64(10000) And #host Eq Utf8(_server01_)"]
3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]
5[shape=box label="TableScan: attributes projection=None"]
4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]
}
subgraph cluster_6
{
graph[label="Detailed LogicalPlan"]
7[shape=box label="SchemaPivot\nSchema: [non_null_column:Utf8]"]
8[shape=box label="Projection: \nSchema: []"]
7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]
9[shape=box label="Filter: Int64(0) LtEq #time And #time Lt Int64(10000) And #host Eq Utf8(_server01_)\nSchema: [color:Utf8;N, time:Int64]"]
8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]
10[shape=box label="TableScan: attributes projection=None\nSchema: [color:Utf8;N, time:Int64]"]
9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]
}
}
// End DataFusion GraphViz Plan
```
You can pipe it to `iplan` and render as a .pdf
## Running `rustfmt` and `clippy`
CI will check the code formatting with [`rustfmt`] and Rust best practices with [`clippy`].
To automatically format your code according to `rustfmt` style, first make sure `rustfmt` is installed using `rustup`:
```shell
rustup component add rustfmt
```
Then, whenever you make a change and want to reformat, run:
```shell
cargo fmt --all
```
Similarly with `clippy`, install with:
```shell
rustup component add clippy
```
And run with:
```shell
cargo clippy --all-targets --workspace -- -D warnings
```
[`rustfmt`]: https://github.com/rust-lang/rustfmt
[`clippy`]: https://github.com/rust-lang/rust-clippy
## Distributed Tracing
See [tracing.md](docs/tracing.md) for more information on the distributed tracing functionality within IOx.

7273
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

179
Cargo.toml Normal file
View File

@ -0,0 +1,179 @@
[workspace]
# In alphabetical order
members = [
"arrow_util",
"authz",
"backoff",
"cache_system",
"clap_blocks",
"client_util",
"compactor_test_utils",
"compactor",
"compactor_scheduler",
"data_types",
"datafusion_util",
"dml",
"executor",
"flightsql",
"garbage_collector",
"generated_types",
"gossip",
"gossip_compaction",
"gossip_parquet_file",
"gossip_schema",
"grpc-binary-logger-proto",
"grpc-binary-logger-test-proto",
"grpc-binary-logger",
"import_export",
"influxdb_influxql_parser",
"influxdb_iox_client",
"influxdb_iox",
"influxdb_line_protocol",
"influxdb_storage_client",
"influxdb_tsm",
"influxdb2_client",
"influxrpc_parser",
"ingest_structure",
"ingester_query_grpc",
"ingester_query_client",
"ingester_test_ctx",
"ingester",
"iox_catalog",
"iox_data_generator",
"iox_query_influxql",
"iox_query_influxrpc",
"iox_query",
"iox_tests",
"iox_time",
"ioxd_common",
"ioxd_compactor",
"ioxd_garbage_collector",
"ioxd_ingester",
"ioxd_querier",
"ioxd_router",
"ioxd_test",
"logfmt",
"metric_exporters",
"metric",
"mutable_batch_lp",
"mutable_batch_pb",
"mutable_batch_tests",
"mutable_batch",
"object_store_metrics",
"observability_deps",
"panic_logging",
"parquet_file",
"parquet_to_line_protocol",
"predicate",
"querier",
"query_functions",
"router",
"schema",
"service_common",
"service_grpc_catalog",
"service_grpc_flight",
"service_grpc_influxrpc",
"service_grpc_namespace",
"service_grpc_object_store",
"service_grpc_schema",
"service_grpc_table",
"service_grpc_testing",
"sharder",
"sqlx-hotswap-pool",
"test_helpers_end_to_end",
"test_helpers",
"tokio_metrics_bridge",
"trace_exporters",
"trace_http",
"trace",
"tracker",
"trogging",
"wal_inspect",
"wal",
"workspace-hack",
]
default-members = ["influxdb_iox"]
resolver = "2"
exclude = [
"*.md",
"*.txt",
".circleci/",
".editorconfig",
".git*",
".github/",
".kodiak.toml",
"Dockerfile*",
"LICENSE*",
"buf.yaml",
"docker/",
"docs/",
"massif.out.*",
"perf/",
"scripts/",
"test_bench/",
"test_fixtures/",
"tools/",
]
[workspace.package]
version = "0.1.0"
authors = ["IOx Project Developers"]
edition = "2021"
license = "MIT OR Apache-2.0"
[workspace.dependencies]
arrow = { version = "46.0.0" }
arrow-flight = { version = "46.0.0" }
datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "81f33b0e27f5694348cd953a937203d835b57178", default-features = false }
datafusion-proto = { git = "https://github.com/apache/arrow-datafusion.git", rev = "81f33b0e27f5694348cd953a937203d835b57178" }
hashbrown = { version = "0.14.0" }
object_store = { version = "0.7.0" }
parquet = { version = "46.0.0" }
tonic = { version = "0.9.2", features = ["tls", "tls-webpki-roots"] }
tonic-build = { version = "0.9.2" }
tonic-health = { version = "0.9.2" }
tonic-reflection = { version = "0.9.2" }
# This profile optimizes for runtime performance and small binary size at the expense of longer
# build times. It's most suitable for final release builds.
[profile.release]
codegen-units = 16
debug = true
lto = "thin"
[profile.bench]
debug = true
# This profile optimizes for short build times at the expense of larger binary size and slower
# runtime performance. It's most suitable for development iterations.
[profile.quick-release]
inherits = "release"
codegen-units = 16
lto = false
incremental = true
# Per insta docs: https://insta.rs/docs/quickstart/#optional-faster-runs
[profile.dev.package.insta]
opt-level = 3
[profile.dev.package.similar]
opt-level = 3
[patch.crates-io]
# Can remove after arrow 47 is released
# Pin to https://github.com/apache/arrow-rs/pull/4790
# To get fixes for
# - https://github.com/apache/arrow-rs/issues/4788,
# - https://github.com/apache/arrow-rs/pull/4799
arrow = { git = "https://github.com/alamb/arrow-rs.git", rev = "7c236c06bfb78c0c877055c1617d9373971511a5" }
arrow-array = { git = "https://github.com/alamb/arrow-rs.git", rev = "7c236c06bfb78c0c877055c1617d9373971511a5" }
arrow-buffer = { git = "https://github.com/alamb/arrow-rs.git", rev = "7c236c06bfb78c0c877055c1617d9373971511a5" }
arrow-schema = { git = "https://github.com/alamb/arrow-rs.git", rev = "7c236c06bfb78c0c877055c1617d9373971511a5" }
arrow-select = { git = "https://github.com/alamb/arrow-rs.git", rev = "7c236c06bfb78c0c877055c1617d9373971511a5" }
arrow-string = { git = "https://github.com/alamb/arrow-rs.git", rev = "7c236c06bfb78c0c877055c1617d9373971511a5" }
arrow-ord = { git = "https://github.com/alamb/arrow-rs.git", rev = "7c236c06bfb78c0c877055c1617d9373971511a5" }
arrow-flight = { git = "https://github.com/alamb/arrow-rs.git", rev = "7c236c06bfb78c0c877055c1617d9373971511a5" }
parquet = { git = "https://github.com/alamb/arrow-rs.git", rev = "7c236c06bfb78c0c877055c1617d9373971511a5" }

61
Dockerfile Normal file
View File

@ -0,0 +1,61 @@
#syntax=docker/dockerfile:1.2

# --- Stage 1: build the requested IOx package from source ---
# Default toolchain version; docker_build overrides this from rust-toolchain.toml.
ARG RUST_VERSION=1.57
FROM rust:${RUST_VERSION}-slim-bookworm as build

# cache mounts below may already exist and owned by root
USER root

# Native build prerequisites: linker/toolchain bits, OpenSSL headers, and
# protoc for generated gRPC code. The rm keeps the layer small.
RUN apt update \
&& apt install --yes binutils build-essential pkg-config libssl-dev clang lld git protobuf-compiler \
&& rm -rf /var/lib/{apt,dpkg,cache,log}

# Build influxdb_iox
COPY . /influxdb_iox
WORKDIR /influxdb_iox

# Build knobs; all can be overridden via --build-arg (see docker_build, which
# passes CARGO_INCREMENTAL=no and CARGO_NET_GIT_FETCH_WITH_CLI=true).
ARG CARGO_INCREMENTAL=yes
ARG CARGO_NET_GIT_FETCH_WITH_CLI=false
ARG PROFILE=release
ARG FEATURES=aws,gcp,azure,jemalloc_replacing_malloc
ARG PACKAGE=influxdb_iox
# Re-export the build args as environment variables so cargo sees them.
ENV CARGO_INCREMENTAL=$CARGO_INCREMENTAL \
CARGO_NET_GIT_FETCH_WITH_CLI=$CARGO_NET_GIT_FETCH_WITH_CLI \
PROFILE=$PROFILE \
FEATURES=$FEATURES \
PACKAGE=$PACKAGE

# Cache mounts persist the rustup toolchain, cargo registry/git caches, and the
# target dir across builds. The du calls log cache sizes before and after; the
# binary is copied out of the cache mount (to /root) so the next stage can
# COPY it after the mounts are gone.
RUN \
--mount=type=cache,id=influxdb_iox_rustup,sharing=locked,target=/usr/local/rustup \
--mount=type=cache,id=influxdb_iox_registry,sharing=locked,target=/usr/local/cargo/registry \
--mount=type=cache,id=influxdb_iox_git,sharing=locked,target=/usr/local/cargo/git \
--mount=type=cache,id=influxdb_iox_target,sharing=locked,target=/influxdb_iox/target \
du -cshx /usr/local/rustup /usr/local/cargo/registry /usr/local/cargo/git /influxdb_iox/target && \
cargo build --target-dir /influxdb_iox/target --package="$PACKAGE" --profile="$PROFILE" --no-default-features --features="$FEATURES" && \
objcopy --compress-debug-sections "target/$PROFILE/$PACKAGE" && \
cp "/influxdb_iox/target/$PROFILE/$PACKAGE" /root/$PACKAGE && \
du -cshx /usr/local/rustup /usr/local/cargo/registry /usr/local/cargo/git /influxdb_iox/target

# --- Stage 2: minimal runtime image ---
FROM debian:bookworm-slim

# Runtime deps only (TLS roots, libssl, envsubst) plus an unprivileged user.
RUN apt update \
&& apt install --yes ca-certificates gettext-base libssl3 --no-install-recommends \
&& rm -rf /var/lib/{apt,dpkg,cache,log} \
&& groupadd --gid 1500 iox \
&& useradd --uid 1500 --gid iox --shell /bin/bash --create-home iox

USER iox

# Writable state directory in the iox user's home.
RUN mkdir ~/.influxdb_iox

# PACKAGE must be re-declared: ARGs do not cross FROM boundaries.
ARG PACKAGE=influxdb_iox
ENV PACKAGE=$PACKAGE

COPY --from=build "/root/$PACKAGE" "/usr/bin/$PACKAGE"
COPY docker/entrypoint.sh /usr/bin/entrypoint.sh

EXPOSE 8080 8082

ENTRYPOINT ["/usr/bin/entrypoint.sh"]

CMD ["run"]

7
Dockerfile.dockerignore Normal file
View File

@ -0,0 +1,7 @@
.*/
target/
tests/
docker/
!.cargo/
!.git/
!docker/entrypoint.sh

201
LICENSE-APACHE Normal file
View File

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

25
LICENSE-MIT Normal file
View File

@ -0,0 +1,25 @@
Copyright (c) 2020 InfluxData
Permission is hereby granted, free of charge, to any
person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the
Software without restriction, including without
limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software
is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice
shall be included in all copies or substantial portions
of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.

371
README.md Normal file
View File

@ -0,0 +1,371 @@
# InfluxDB IOx
InfluxDB IOx (short for Iron Oxide, pronounced InfluxDB "eye-ox") is the core of InfluxDB, an open source time series database.
The name is in homage to Rust, the language this project is written in.
It is built using [Apache Arrow](https://arrow.apache.org/) and [DataFusion](https://arrow.apache.org/datafusion/) among other technologies.
InfluxDB IOx aims to be:
* The core of InfluxDB; providing industry standard SQL, InfluxQL, and Flux
* An in-memory columnar store using object storage for persistence
* A fast analytic database for structured and semi-structured events (like logs and tracing data)
* A system for defining replication (synchronous, asynchronous, push and pull) and partitioning rules for InfluxDB time series data and tabular analytics data
* A system supporting real-time subscriptions
* A processor that can transform and do arbitrary computation on time series and event data as it arrives
* An analytic database built for data science, supporting Apache Arrow Flight for fast data transfer
Persistence is through Parquet files in object storage.
It is a design goal to support integration with other big data systems through object storage and Parquet specifically.
For more details on the motivation behind the project and some of our goals, read through the [InfluxDB IOx announcement blog post](https://www.influxdata.com/blog/announcing-influxdb-iox/).
If you prefer a video that covers a little bit of InfluxDB history and high level goals for [InfluxDB IOx you can watch Paul Dix's announcement talk from InfluxDays NA 2020](https://www.youtube.com/watch?v=pnwkAAyMp18).
For more details on the motivation behind the selection of [Apache Arrow, Flight and Parquet, read this](https://www.influxdata.com/blog/apache-arrow-parquet-flight-and-their-ecosystem-are-a-game-changer-for-olap/).
## Platforms
Our current goal is that the following platforms will be able to run InfluxDB IOx.
* Linux x86 (`x86_64-unknown-linux-gnu`)
* Darwin x86 (`x86_64-apple-darwin`)
* Darwin arm (`aarch64-apple-darwin`)
## Project Status
This project is in active development, which is why we're not producing builds yet.
If you would like to contact the InfluxDB IOx developers,
join the [InfluxData Community Slack](https://influxdata.com/slack) and look for the #influxdb_iox channel.
We're also hosting monthly tech talks and community office hours on the project on the 2nd Wednesday of the month at 8:30 AM Pacific Time.
* [Signup for upcoming IOx tech talks](https://www.influxdata.com/community-showcase/influxdb-tech-talks)
* [Watch past IOx tech talks](https://www.youtube.com/playlist?list=PLYt2jfZorkDp-PKBS05kf2Yx2NrRyPAAz)
## Get started
1. [Install dependencies](#install-dependencies)
1. [Clone the repository](#clone-the-repository)
1. [Configure the server](#configure-the-server)
1. [Compiling and Running](#compiling-and-running)
(You can also [build a Docker image](#build-a-docker-image-optional) to run InfluxDB IOx.)
1. [Write and read data](#write-and-read-data)
1. [Use the CLI](#use-the-cli)
1. [Use InfluxDB 2.0 API compatibility](#use-influxdb-20-api-compatibility)
1. [Run health checks](#run-health-checks)
1. [Manually call the gRPC API](#manually-call-the-grpc-api)
### Install dependencies
To compile and run InfluxDB IOx from source, you'll need the following:
* [Rust](#rust)
* [Clang](#clang)
* [lld (on Linux)](#lld)
* [protoc (on Apple Silicon)](#protoc)
* [Postgres](#postgres)
#### Rust
The easiest way to install Rust is to use [`rustup`](https://rustup.rs/), a Rust version manager.
Follow the instructions for your operating system on the `rustup` site.
`rustup` will check the [`rust-toolchain`](./rust-toolchain.toml) file and automatically install and use the correct Rust version for you.
#### C/C++ Compiler
You need some C/C++ compiler for some non-Rust dependencies like [`zstd`](https://crates.io/crates/zstd).
#### lld
If you are building InfluxDB IOx on Linux then you will need to ensure you have installed the `lld` LLVM linker.
Check if you have already installed it by running `lld -version`.
```shell
lld -version
lld is a generic driver.
Invoke ld.lld (Unix), ld64.lld (macOS), lld-link (Windows), wasm-ld (WebAssembly) instead
```
If `lld` is not already present, it can typically be installed with the system package manager.
#### protoc
Prost no longer bundles a `protoc` binary.
For instructions on how to install `protoc`, refer to the [official gRPC documentation](https://grpc.io/docs/protoc-installation/).
IOx should then build correctly.
#### Postgres
The catalog is stored in Postgres (unless you're running in ephemeral mode). Postgres can be installed via Homebrew:
```shell
brew install postgresql
```
then follow the instructions for starting Postgres either at system startup or on-demand.
### Clone the repository
Clone this repository using `git`.
If you use the `git` command line, this looks like:
```shell
git clone git@github.com:influxdata/influxdb_iox.git
```
Then change into the directory containing the code:
```shell
cd influxdb_iox
```
The rest of these instructions assume you are in this directory.
### Configure the server
InfluxDB IOx can be configured using either environment variables or a configuration file,
making it suitable for deployment in containerized environments.
For a list of configuration options, run `influxdb_iox --help`, after installing IOx.
For configuration options for specific subcommands, run `influxdb_iox <subcommand> --help`.
To use a configuration file, use a `.env` file in the working directory.
See the provided [example configuration file](docs/env.example).
To use the example configuration file, run:
```shell
cp docs/env.example .env
```
### Compiling and Running
InfluxDB IOx is built using Cargo, Rust's package manager and build tool.
To compile for development, run:
```shell
cargo build
```
To compile for release and install the `influxdb_iox` binary in your path (so you can run `influxdb_iox` directly) do:
```shell
# from within the main `influxdb_iox` checkout
cargo install --path influxdb_iox
```
The `cargo build` command creates a development binary at `target/debug/influxdb_iox`.
### Build a Docker image (optional)
Building the Docker image requires:
* Docker 18.09+
* BuildKit
To [enable BuildKit] by default, set `{ "features": { "buildkit": true } }` in the Docker engine configuration,
or run `docker build` with `DOCKER_BUILDKIT=1`.
To build the Docker image:
```shell
DOCKER_BUILDKIT=1 docker build .
```
[Enable BuildKit]: https://docs.docker.com/develop/develop-images/build_enhancements/#to-enable-buildkit-builds
#### Local filesystem testing mode
InfluxDB IOx supports testing backed by the local filesystem.
> **Note**
>
> This mode should NOT be used for production systems: it will have poor performance and limited tuning knobs are available.
To run IOx in local testing mode, use:
```shell
./target/debug/influxdb_iox
# shorthand for
./target/debug/influxdb_iox run all-in-one
```
This will start an "all-in-one" IOx server with the following configuration:
1. File backed catalog (sqlite), object store, and write ahead log (wal) stored under `<HOMEDIR>/.influxdb_iox`
2. HTTP `v2` api server on port `8080`, querier gRPC server on port `8082` and several ports for other internal services.
You can also change the configuration in limited ways, such as choosing a different data directory:
```shell
./target/debug/influxdb_iox run all-in-one --data-dir=/tmp/iox_data
```
#### Compile and run
Rather than building and running the binary in `target`, you can also compile and run with one
command:
```shell
cargo run -- run all-in-one
```
#### Release mode for performance testing
To compile for performance testing, build in release mode then use the binary in `target/release`:
```shell
cargo build --release
./target/release/influxdb_iox run all-in-one
```
You can also compile and run in release mode with one step:
```shell
cargo run --release -- run all-in-one
```
#### Running tests
You can run tests using:
```shell
cargo test --all
```
See [docs/testing.md](docs/testing.md) for more information.
### Write and read data
Data can be written to InfluxDB IOx by sending [line protocol] format to the `/api/v2/write` endpoint or using the CLI.
For example, assuming you are running in local mode, this command will send data in the `test_fixtures/lineproto/metrics.lp` file to the `company_sensors` namespace.
```shell
./target/debug/influxdb_iox -vv write company_sensors test_fixtures/lineproto/metrics.lp --host http://localhost:8080
```
Note that `--host http://localhost:8080` is required as the `/api/v2` endpoint is hosted on port `8080` while the default is the querier gRPC port `8082`.
To query the data stored in the `company_sensors` namespace:
```shell
./target/debug/influxdb_iox query company_sensors "SELECT * FROM cpu LIMIT 10"
```
### Use the CLI
InfluxDB IOx is packaged as a binary with commands to start the IOx server,
as well as a CLI interface for interacting with and configuring such servers.
The CLI itself is documented via built-in help which you can access by running `influxdb_iox --help`
### Use InfluxDB 2.0 API compatibility
InfluxDB IOx allows seamless interoperability with InfluxDB 2.0.
Where InfluxDB 2.0 stores data in organizations and buckets,
InfluxDB IOx stores data in _namespaces_.
IOx maps `organization` and `bucket` pairs to namespaces with the two parts separated by an underscore (`_`):
`organization_bucket`.
Here's an example using [`curl`] to send data into the `company_sensors` namespace using the InfluxDB 2.0 `/api/v2/write` API:
```shell
curl -v "http://127.0.0.1:8080/api/v2/write?org=company&bucket=sensors" --data-binary @test_fixtures/lineproto/metrics.lp
```
[line protocol]: https://docs.influxdata.com/influxdb/v2.6/reference/syntax/line-protocol/
[`curl`]: https://curl.se/
### Run health checks
The HTTP API exposes a healthcheck endpoint at `/health`
```console
$ curl http://127.0.0.1:8080/health
OK
```
The gRPC API implements the [gRPC Health Checking Protocol](https://github.com/grpc/grpc/blob/master/doc/health-checking.md).
This can be tested with [`grpc-health-probe`](https://github.com/grpc-ecosystem/grpc-health-probe):
```console
$ grpc_health_probe -addr 127.0.0.1:8082 -service influxdata.platform.storage.Storage
status: SERVING
```
### Manually call the gRPC API
To manually invoke one of the gRPC APIs, use a gRPC CLI client such as [grpcurl](https://github.com/fullstorydev/grpcurl).
Because the gRPC server library in IOx doesn't provide service reflection, you need to pass the IOx `.proto` files to your client
when making requests.
After you install **grpcurl**, you can use the `./scripts/grpcurl` wrapper script to make requests that use the `.proto` files for you--for example:
Use the `list` command to list gRPC API services:
```console
./scripts/grpcurl -plaintext 127.0.0.1:8082 list
```
```console
google.longrunning.Operations
grpc.health.v1.Health
influxdata.iox.authz.v1.IoxAuthorizerService
influxdata.iox.catalog.v1.CatalogService
influxdata.iox.compactor.v1.CompactionService
influxdata.iox.delete.v1.DeleteService
influxdata.iox.ingester.v1.PartitionBufferService
influxdata.iox.ingester.v1.PersistService
influxdata.iox.ingester.v1.ReplicationService
influxdata.iox.ingester.v1.WriteInfoService
influxdata.iox.ingester.v1.WriteService
influxdata.iox.namespace.v1.NamespaceService
influxdata.iox.object_store.v1.ObjectStoreService
influxdata.iox.schema.v1.SchemaService
influxdata.platform.storage.IOxTesting
influxdata.platform.storage.Storage
```
Use the `describe` command to view methods for a service:
```console
./scripts/grpcurl -plaintext 127.0.0.1:8082 describe influxdata.iox.namespace.v1.NamespaceService
```
```console
service NamespaceService {
...
rpc GetNamespaces ( .influxdata.iox.namespace.v1.GetNamespacesRequest ) returns ( .influxdata.iox.namespace.v1.GetNamespacesResponse );
...
}
```
Invoke a method:
```console
./scripts/grpcurl -plaintext 127.0.0.1:8082 influxdata.iox.namespace.v1.NamespaceService.GetNamespaces
```
```console
{
"namespaces": [
{
"id": "1",
"name": "company_sensors"
}
]
}
```
## Contributing
We welcome community contributions from anyone!
Read our [Contributing Guide](CONTRIBUTING.md) for instructions on how to run tests and how to make your first contribution.
## Architecture and Technical Documentation
There are a variety of technical documents describing various parts of IOx in the [docs](docs) directory.

27
arrow_util/Cargo.toml Normal file
View File

@ -0,0 +1,27 @@
[package]
name = "arrow_util"
description = "Apache Arrow utilities"
version.workspace = true
authors.workspace = true
edition.workspace = true
license.workspace = true
[dependencies]
ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] }
# need dyn_cmp_dict feature for comparing dictionary arrays
arrow = { workspace = true, features = ["prettyprint", "dyn_cmp_dict"] }
# used by arrow anyway (needed for printing workaround)
chrono = { version = "0.4", default-features = false }
comfy-table = { version = "7.0", default-features = false }
hashbrown = { workspace = true }
num-traits = "0.2"
once_cell = { version = "1.18", features = ["parking_lot"] }
regex = "1.9.5"
snafu = "0.7"
uuid = "1"
workspace-hack = { version = "0.1", path = "../workspace-hack" }
[dev-dependencies]
datafusion = { workspace = true }
rand = "0.8.3"

588
arrow_util/src/bitset.rs Normal file
View File

@ -0,0 +1,588 @@
use arrow::buffer::{BooleanBuffer, Buffer};
use std::ops::Range;
/// An arrow-compatible mutable bitset implementation
///
/// Note: This currently operates on individual bytes at a time
/// it could be optimised to instead operate on usize blocks
#[derive(Debug, Default, Clone)]
pub struct BitSet {
    /// The underlying data
    ///
    /// Data is stored in the least significant bit of a byte first
    ///
    /// Invariant: bits at positions >= `len` in the final byte are kept
    /// zero — `append_set`, `append_bits` and `truncate` all mask the
    /// trailing bits after mutating the buffer.
    buffer: Vec<u8>,

    /// The length of this mask in bits
    len: usize,
}
impl BitSet {
    /// Creates a new, empty BitSet
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a new BitSet with `count` unset bits.
    pub fn with_size(count: usize) -> Self {
        let mut bitset = Self::default();
        bitset.append_unset(count);
        bitset
    }

    /// Reserve space for `count` further bits
    pub fn reserve(&mut self, count: usize) {
        let new_buf_len = (self.len + count + 7) >> 3;
        // `Vec::reserve` takes the number of *additional* elements, not a
        // total capacity; the previous code passed the total byte count
        // directly, over-reserving by the current buffer length.
        self.buffer
            .reserve(new_buf_len.saturating_sub(self.buffer.len()));
    }

    /// Appends `count` unset bits
    pub fn append_unset(&mut self, count: usize) {
        self.len += count;
        let new_buf_len = (self.len + 7) >> 3;
        // Newly added bytes are zero, i.e. all-unset bits.
        self.buffer.resize(new_buf_len, 0);
    }

    /// Appends `count` set bits
    pub fn append_set(&mut self, count: usize) {
        let new_len = self.len + count;
        let new_buf_len = (new_len + 7) >> 3;

        // If the current length is not byte aligned, fill the unused high
        // bits of the current final byte with ones.
        let skew = self.len & 7;
        if skew != 0 {
            *self.buffer.last_mut().unwrap() |= 0xFF << skew;
        }

        // Append whole bytes of ones.
        self.buffer.resize(new_buf_len, 0xFF);

        // Clear any bits in the final byte beyond `new_len`, maintaining the
        // invariant that trailing bits are zero.
        let rem = new_len & 7;
        if rem != 0 {
            *self.buffer.last_mut().unwrap() &= (1 << rem) - 1;
        }

        self.len = new_len;
    }

    /// Truncates the bitset to the provided length
    ///
    /// NOTE(review): `len` appears to be expected to be <= the current
    /// length — a larger value would extend `self.len` without backing
    /// data. TODO confirm against callers.
    pub fn truncate(&mut self, len: usize) {
        let new_buf_len = (len + 7) >> 3;
        self.buffer.truncate(new_buf_len);
        // Zero the bits in the final byte beyond the new length so trailing
        // bits remain unset.
        let overrun = len & 7;
        if overrun > 0 {
            *self.buffer.last_mut().unwrap() &= (1 << overrun) - 1;
        }
        self.len = len;
    }

    /// Extends this [`BitSet`] by the context of `other`
    pub fn extend_from(&mut self, other: &BitSet) {
        self.append_bits(other.len, &other.buffer)
    }

    /// Extends this [`BitSet`] by `range` elements in `other`
    pub fn extend_from_range(&mut self, other: &BitSet, range: Range<usize>) {
        let count = range.end - range.start;
        if count == 0 {
            return;
        }

        let start_byte = range.start >> 3;
        let end_byte = (range.end + 7) >> 3;
        let skew = range.start & 7;

        // `append_bits` requires the provided `to_set` to be byte aligned, therefore
        // if the range being copied is not byte aligned we must first append
        // the leading bits to reach a byte boundary
        if skew == 0 {
            // No skew can simply append bytes directly
            self.append_bits(count, &other.buffer[start_byte..end_byte])
        } else if start_byte + 1 == end_byte {
            // Append bits from single byte
            self.append_bits(count, &[other.buffer[start_byte] >> skew])
        } else {
            // Append trailing bits from first byte to reach byte boundary, then append
            // bits from the remaining byte-aligned mask
            let offset = 8 - skew;
            self.append_bits(offset, &[other.buffer[start_byte] >> skew]);
            self.append_bits(count - offset, &other.buffer[(start_byte + 1)..end_byte]);
        }
    }

    /// Appends `count` boolean values from the slice of packed bits
    pub fn append_bits(&mut self, count: usize, to_set: &[u8]) {
        assert_eq!((count + 7) >> 3, to_set.len());
        let new_len = self.len + count;
        let new_buf_len = (new_len + 7) >> 3;
        self.buffer.reserve(new_buf_len - self.buffer.len());

        let whole_bytes = count >> 3;
        let overrun = count & 7;

        let skew = self.len & 7;
        if skew == 0 {
            // Destination is byte aligned: source bytes copy over directly.
            self.buffer.extend_from_slice(&to_set[..whole_bytes]);
            if overrun > 0 {
                let masked = to_set[whole_bytes] & ((1 << overrun) - 1);
                self.buffer.push(masked)
            }
            self.len = new_len;
            debug_assert_eq!(self.buffer.len(), new_buf_len);
            return;
        }

        // Destination is not byte aligned: each source byte is split between
        // the high bits of the current final byte and a newly pushed byte.
        for to_set_byte in &to_set[..whole_bytes] {
            let low = *to_set_byte << skew;
            let high = *to_set_byte >> (8 - skew);
            *self.buffer.last_mut().unwrap() |= low;
            self.buffer.push(high);
        }

        if overrun > 0 {
            let masked = to_set[whole_bytes] & ((1 << overrun) - 1);
            let low = masked << skew;
            *self.buffer.last_mut().unwrap() |= low;

            // Only spill into an extra byte if the remaining bits don't fit
            // into the current final byte.
            if overrun > 8 - skew {
                let high = masked >> (8 - skew);
                self.buffer.push(high)
            }
        }

        self.len = new_len;
        debug_assert_eq!(self.buffer.len(), new_buf_len);
    }

    /// Sets a given bit
    ///
    /// # Panics
    ///
    /// Panics if `idx` is out of bounds.
    pub fn set(&mut self, idx: usize) {
        // NOTE: this bound was previously `idx <= self.len`, which permitted
        // writing one bit past the end — silently corrupting the "trailing
        // bits are zero" invariant when `len` is not byte aligned.
        assert!(idx < self.len);

        let byte_idx = idx >> 3;
        let bit_idx = idx & 7;
        self.buffer[byte_idx] |= 1 << bit_idx;
    }

    /// Returns if the given index is set
    ///
    /// # Panics
    ///
    /// Panics if `idx` is out of bounds.
    pub fn get(&self, idx: usize) -> bool {
        // NOTE: previously `idx <= self.len`, which allowed reading one bit
        // past the end (a phantom `false`, or a byte-index panic when `len`
        // was byte aligned).
        assert!(idx < self.len);

        let byte_idx = idx >> 3;
        let bit_idx = idx & 7;
        (self.buffer[byte_idx] >> bit_idx) & 1 != 0
    }

    /// Converts this BitSet to a buffer compatible with arrows boolean encoding
    pub fn to_arrow(&self) -> BooleanBuffer {
        let offset = 0;
        BooleanBuffer::new(Buffer::from(&self.buffer), offset, self.len)
    }

    /// Returns the number of values stored in the bitset
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns if this bitset is empty
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns the number of bytes used by this bitset
    pub fn byte_len(&self) -> usize {
        self.buffer.len()
    }

    /// Return the raw packed bytes used by this bitset
    pub fn bytes(&self) -> &[u8] {
        &self.buffer
    }

    /// Return `true` if all bits in the [`BitSet`] are currently set.
    pub fn is_all_set(&self) -> bool {
        // An empty bitmap has no set bits.
        if self.len == 0 {
            return false;
        }

        // Check every byte except the last, all of whose bits must be set.
        //
        // BUG FIX: this was previously `(self.len / 8).saturating_sub(1)`,
        // which undercounts by one whenever `len` is not a multiple of 8 and
        // therefore skipped checking one fully-populated byte (e.g. for
        // len == 9 the first byte was never examined, so buffer [0x00, 0x01]
        // wrongly reported all-set).
        let full_blocks = self.buffer.len().saturating_sub(1);
        if !self.buffer.iter().take(full_blocks).all(|&v| v == u8::MAX) {
            return false;
        }

        // Check the last byte of the bitmap that may only be partially part of
        // the bit set, and therefore need masking to check only the relevant
        // bits.
        let mask = match self.len % 8 {
            1..=7 => !(0xFF << (self.len % 8)), // LSB mask for the partial byte
            0 => 0xFF,                          // last byte is fully populated
            _ => unreachable!(),
        };
        *self.buffer.last().unwrap() == mask
    }

    /// Return `true` if all bits in the [`BitSet`] are currently unset.
    ///
    /// Note: an empty bitset reports `true` (all zero bytes), while
    /// [`Self::is_all_set`] reports `false` for an empty bitset.
    pub fn is_all_unset(&self) -> bool {
        // Relies on the invariant that trailing bits beyond `len` are zero.
        self.buffer.iter().all(|&v| v == 0)
    }
}
/// Returns an iterator over set bit positions in increasing order
///
/// Equivalent to [`iter_set_positions_with_offset`] starting at bit 0.
pub fn iter_set_positions(bytes: &[u8]) -> impl Iterator<Item = usize> + '_ {
    iter_set_positions_with_offset(bytes, 0)
}
/// Returns an iterator over set bit positions in increasing order starting
/// at the provided bit offset
///
/// Bits before `offset` are ignored; positions are absolute bit indices
/// into `bytes` (LSB-first within each byte).
pub fn iter_set_positions_with_offset(
    bytes: &[u8],
    offset: usize,
) -> impl Iterator<Item = usize> + '_ {
    let mut byte_idx = offset >> 3;
    // Load the byte containing `offset` (or 0 when out of range) and mask
    // away the bits below the offset within that byte.
    let mut current = bytes.get(byte_idx).copied().unwrap_or(0) & (0xFF_u8 << (offset & 7));

    std::iter::from_fn(move || {
        // Advance to the next byte with at least one set bit; terminate the
        // iterator once the slice is exhausted.
        while current == 0 {
            byte_idx += 1;
            current = *bytes.get(byte_idx)?;
        }

        let bit = current.trailing_zeros() as usize;
        // Clear the lowest set bit so the next call yields the next position.
        current &= current - 1;
        Some((byte_idx << 3) + bit)
    })
}
#[cfg(test)]
mod tests {
use super::*;
use arrow::array::BooleanBufferBuilder;
use rand::prelude::*;
use rand::rngs::OsRng;
/// Computes a compacted representation of a given bool array
fn compact_bools(bools: &[bool]) -> Vec<u8> {
bools
.chunks(8)
.map(|x| {
let mut collect = 0_u8;
for (idx, set) in x.iter().enumerate() {
if *set {
collect |= 1 << idx
}
}
collect
})
.collect()
}
fn iter_set_bools(bools: &[bool]) -> impl Iterator<Item = usize> + '_ {
bools.iter().enumerate().filter_map(|(x, y)| y.then(|| x))
}
#[test]
fn test_compact_bools() {
let bools = &[
false, false, true, true, false, false, true, false, true, false,
];
let collected = compact_bools(bools);
let indexes: Vec<_> = iter_set_bools(bools).collect();
assert_eq!(collected.as_slice(), &[0b01001100, 0b00000001]);
assert_eq!(indexes.as_slice(), &[2, 3, 6, 8])
}
#[test]
fn test_bit_mask() {
let mut mask = BitSet::new();
mask.append_bits(8, &[0b11111111]);
let d1 = mask.buffer.clone();
mask.append_bits(3, &[0b01010010]);
let d2 = mask.buffer.clone();
mask.append_bits(5, &[0b00010100]);
let d3 = mask.buffer.clone();
mask.append_bits(2, &[0b11110010]);
let d4 = mask.buffer.clone();
mask.append_bits(15, &[0b11011010, 0b01010101]);
let d5 = mask.buffer.clone();
assert_eq!(d1.as_slice(), &[0b11111111]);
assert_eq!(d2.as_slice(), &[0b11111111, 0b00000010]);
assert_eq!(d3.as_slice(), &[0b11111111, 0b10100010]);
assert_eq!(d4.as_slice(), &[0b11111111, 0b10100010, 0b00000010]);
assert_eq!(
d5.as_slice(),
&[0b11111111, 0b10100010, 0b01101010, 0b01010111, 0b00000001]
);
assert!(mask.get(0));
assert!(!mask.get(8));
assert!(mask.get(9));
assert!(mask.get(19));
}
fn make_rng() -> StdRng {
let seed = OsRng.next_u64();
println!("Seed: {seed}");
StdRng::seed_from_u64(seed)
}
#[test]
fn test_bit_mask_all_set() {
let mut mask = BitSet::new();
let mut all_bools = vec![];
let mut rng = make_rng();
for _ in 0..100 {
let mask_length = (rng.next_u32() % 50) as usize;
let bools: Vec<_> = std::iter::repeat(true).take(mask_length).collect();
let collected = compact_bools(&bools);
mask.append_bits(mask_length, &collected);
all_bools.extend_from_slice(&bools);
}
let collected = compact_bools(&all_bools);
assert_eq!(mask.buffer, collected);
let expected_indexes: Vec<_> = iter_set_bools(&all_bools).collect();
let actual_indexes: Vec<_> = iter_set_positions(&mask.buffer).collect();
assert_eq!(expected_indexes, actual_indexes);
}
#[test]
fn test_bit_mask_fuzz() {
let mut mask = BitSet::new();
let mut all_bools = vec![];
let mut rng = make_rng();
for _ in 0..100 {
let mask_length = (rng.next_u32() % 50) as usize;
let bools: Vec<_> = std::iter::from_fn(|| Some(rng.next_u32() & 1 == 0))
.take(mask_length)
.collect();
let collected = compact_bools(&bools);
mask.append_bits(mask_length, &collected);
all_bools.extend_from_slice(&bools);
}
let collected = compact_bools(&all_bools);
assert_eq!(mask.buffer, collected);
let expected_indexes: Vec<_> = iter_set_bools(&all_bools).collect();
let actual_indexes: Vec<_> = iter_set_positions(&mask.buffer).collect();
assert_eq!(expected_indexes, actual_indexes);
if !all_bools.is_empty() {
for _ in 0..10 {
let offset = rng.next_u32() as usize % all_bools.len();
let expected_indexes: Vec<_> = iter_set_bools(&all_bools[offset..])
.map(|x| x + offset)
.collect();
let actual_indexes: Vec<_> =
iter_set_positions_with_offset(&mask.buffer, offset).collect();
assert_eq!(expected_indexes, actual_indexes);
}
}
for index in actual_indexes {
assert!(mask.get(index));
}
}
#[test]
fn test_append_fuzz() {
let mut mask = BitSet::new();
let mut all_bools = vec![];
let mut rng = make_rng();
for _ in 0..100 {
let len = (rng.next_u32() % 32) as usize;
let set = rng.next_u32() & 1 == 0;
match set {
true => mask.append_set(len),
false => mask.append_unset(len),
}
all_bools.extend(std::iter::repeat(set).take(len));
let collected = compact_bools(&all_bools);
assert_eq!(mask.buffer, collected);
}
}
#[test]
fn test_truncate_fuzz() {
let mut mask = BitSet::new();
let mut all_bools = vec![];
let mut rng = make_rng();
for _ in 0..100 {
let mask_length = (rng.next_u32() % 32) as usize;
let bools: Vec<_> = std::iter::from_fn(|| Some(rng.next_u32() & 1 == 0))
.take(mask_length)
.collect();
let collected = compact_bools(&bools);
mask.append_bits(mask_length, &collected);
all_bools.extend_from_slice(&bools);
if !all_bools.is_empty() {
let truncate = rng.next_u32() as usize % all_bools.len();
mask.truncate(truncate);
all_bools.truncate(truncate);
}
let collected = compact_bools(&all_bools);
assert_eq!(mask.buffer, collected);
}
}
#[test]
fn test_extend_range_fuzz() {
let mut rng = make_rng();
let src_len = 32;
let src_bools: Vec<_> = std::iter::from_fn(|| Some(rng.next_u32() & 1 == 0))
.take(src_len)
.collect();
let mut src_mask = BitSet::new();
src_mask.append_bits(src_len, &compact_bools(&src_bools));
let mut dst_bools = Vec::new();
let mut dst_mask = BitSet::new();
for _ in 0..100 {
let a = rng.next_u32() as usize % src_len;
let b = rng.next_u32() as usize % src_len;
let start = a.min(b);
let end = a.max(b);
dst_bools.extend_from_slice(&src_bools[start..end]);
dst_mask.extend_from_range(&src_mask, start..end);
let collected = compact_bools(&dst_bools);
assert_eq!(dst_mask.buffer, collected);
}
}
#[test]
fn test_arrow_compat() {
let bools = &[
false, false, true, true, false, false, true, false, true, false, false, true,
];
let mut builder = BooleanBufferBuilder::new(bools.len());
builder.append_slice(bools);
let buffer = builder.finish();
let collected = compact_bools(bools);
let mut mask = BitSet::new();
mask.append_bits(bools.len(), &collected);
let mask_buffer = mask.to_arrow();
assert_eq!(collected.as_slice(), buffer.values());
assert_eq!(buffer.values(), mask_buffer.into_inner().as_slice());
}
#[test]
#[should_panic = "idx <= self.len"]
fn test_bitset_set_get_out_of_bounds() {
let mut v = BitSet::with_size(4);
// The bitset is of length 4, which is backed by a single byte with 8
// bits of storage capacity.
//
// Accessing bits past the 4 the bitset "contains" should not succeed.
v.get(5);
v.set(5);
}
#[test]
fn test_all_set_unset() {
for i in 1..100 {
let mut v = BitSet::new();
v.append_set(i);
assert!(v.is_all_set());
assert!(!v.is_all_unset());
}
}
#[test]
fn test_all_set_unset_multi_byte() {
let mut v = BitSet::new();
// Bitmap is composed of entirely set bits.
v.append_set(100);
assert!(v.is_all_set());
assert!(!v.is_all_unset());
// Now the bitmap is neither composed of entirely set, nor entirely
// unset bits.
v.append_unset(1);
assert!(!v.is_all_set());
assert!(!v.is_all_unset());
let mut v = BitSet::new();
// Bitmap is composed of entirely unset bits.
v.append_unset(100);
assert!(!v.is_all_set());
assert!(v.is_all_unset());
// And once again, it is neither all set, nor all unset.
v.append_set(1);
assert!(!v.is_all_set());
assert!(!v.is_all_unset());
}
#[test]
fn test_all_set_unset_single_byte() {
let mut v = BitSet::new();
// Bitmap is composed of entirely set bits.
v.append_set(2);
assert!(v.is_all_set());
assert!(!v.is_all_unset());
// Now the bitmap is neither composed of entirely set, nor entirely
// unset bits.
v.append_unset(1);
assert!(!v.is_all_set());
assert!(!v.is_all_unset());
let mut v = BitSet::new();
// Bitmap is composed of entirely unset bits.
v.append_unset(2);
assert!(!v.is_all_set());
assert!(v.is_all_unset());
// And once again, it is neither all set, nor all unset.
v.append_set(1);
assert!(!v.is_all_set());
assert!(!v.is_all_unset());
}
#[test]
fn test_all_set_unset_empty() {
let v = BitSet::new();
assert!(!v.is_all_set());
assert!(v.is_all_unset());
}
}

View File

@ -0,0 +1,302 @@
//! Contains a structure to map from strings to integer symbols based on
//! string interning.
use std::convert::TryFrom;
use arrow::array::{Array, ArrayDataBuilder, DictionaryArray};
use arrow::buffer::NullBuffer;
use arrow::datatypes::{DataType, Int32Type};
use hashbrown::HashMap;
use num_traits::{AsPrimitive, FromPrimitive, Zero};
use snafu::Snafu;
use crate::string::PackedStringArray;
/// Errors that can occur when building a [`StringDictionary`] from
/// pre-existing storage (see the `TryFrom<PackedStringArray<K>>` impl).
#[derive(Debug, Snafu)]
pub enum Error {
    #[snafu(display("duplicate key found {}", key))]
    DuplicateKeyFound { key: String },
}
/// A String dictionary that builds on top of `PackedStringArray` adding O(1)
/// index lookups for a given string
///
/// Heavily inspired by the string-interner crate
#[derive(Debug, Clone)]
pub struct StringDictionary<K> {
    /// Hasher state used to hash the *string values* (not the keys); shared
    /// by `lookup_value_or_insert` and `lookup_value` so hashes agree.
    hash: ahash::RandomState,
    /// Used to provide a lookup from string value to key type
    ///
    /// Note: K's hash implementation is not used, instead the raw entry
    /// API is used to store keys w.r.t the hash of the strings themselves
    ///
    dedup: HashMap<K, (), ()>,
    /// Used to store strings
    storage: PackedStringArray<K>,
}
impl<K: AsPrimitive<usize> + FromPrimitive + Zero> Default for StringDictionary<K> {
    /// Creates an empty dictionary with a randomly keyed hasher state.
    fn default() -> Self {
        Self {
            hash: ahash::RandomState::new(),
            dedup: Default::default(),
            storage: PackedStringArray::new(),
        }
    }
}
impl<K: AsPrimitive<usize> + FromPrimitive + Zero> StringDictionary<K> {
    /// Creates an empty dictionary.
    pub fn new() -> Self {
        Default::default()
    }

    /// Creates an empty dictionary pre-sized for `keys` entries totalling
    /// `values` bytes of string data.
    pub fn with_capacity(keys: usize, values: usize) -> StringDictionary<K> {
        Self {
            hash: Default::default(),
            dedup: HashMap::with_capacity_and_hasher(keys, ()),
            storage: PackedStringArray::with_capacity(keys, values),
        }
    }

    /// Returns the id corresponding to value, adding an entry for the
    /// id if it is not yet present in the dictionary.
    pub fn lookup_value_or_insert(&mut self, value: &str) -> K {
        use hashbrown::hash_map::RawEntryMut;

        let hasher = &self.hash;
        let storage = &mut self.storage;
        // Hash the string itself; the raw entry API lets us find/insert the
        // key whose stored string matches, bypassing K's own Hash impl.
        let hash = hash_str(hasher, value);

        let entry = self
            .dedup
            .raw_entry_mut()
            .from_hash(hash, |key| value == storage.get(key.as_()).unwrap());

        match entry {
            RawEntryMut::Occupied(entry) => *entry.into_key(),
            RawEntryMut::Vacant(entry) => {
                // Not interned yet: append to storage and record its index
                // as the new key.
                let index = storage.append(value);
                let key =
                    K::from_usize(index).expect("failed to fit string index into dictionary key");
                *entry
                    .insert_with_hasher(hash, key, (), |key| {
                        // Re-derive a key's hash from its stored string,
                        // keeping it consistent with `hash` above.
                        let string = storage.get(key.as_()).unwrap();
                        hash_str(hasher, string)
                    })
                    .0
            }
        }
    }

    /// Returns the ID in self.dictionary that corresponds to `value`, if any.
    pub fn lookup_value(&self, value: &str) -> Option<K> {
        let hash = hash_str(&self.hash, value);
        self.dedup
            .raw_entry()
            .from_hash(hash, |key| value == self.storage.get(key.as_()).unwrap())
            .map(|(&symbol, &())| symbol)
    }

    /// Returns the str in self.dictionary that corresponds to `id`
    pub fn lookup_id(&self, id: K) -> Option<&str> {
        self.storage.get(id.as_())
    }

    /// Approximate memory usage in bytes (string storage + key table).
    pub fn size(&self) -> usize {
        self.storage.size() + self.dedup.len() * std::mem::size_of::<K>()
    }

    /// Borrow the underlying string storage.
    pub fn values(&self) -> &PackedStringArray<K> {
        &self.storage
    }

    /// Consume self, returning the underlying string storage.
    pub fn into_inner(self) -> PackedStringArray<K> {
        self.storage
    }

    /// Truncates this dictionary removing all keys larger than `id`
    pub fn truncate(&mut self, id: K) {
        let id = id.as_();
        // Keys are append-order indexes, so dropping keys > id and keeping
        // the first id+1 strings stay in sync.
        self.dedup.retain(|k, _| k.as_() <= id);
        self.storage.truncate(id + 1)
    }

    /// Clears this dictionary removing all elements
    pub fn clear(&mut self) {
        self.storage.clear();
        self.dedup.clear()
    }
}
/// Hash `value` with the provided hasher state.
///
/// Keys in the dedup map are stored against the hash of the *string*
/// they intern, so this is the single place that hash is computed.
fn hash_str(hasher: &ahash::RandomState, value: &str) -> u64 {
    use std::hash::BuildHasher;
    // Equivalent to build_hasher() + Hash::hash() + finish().
    hasher.hash_one(value)
}
impl StringDictionary<i32> {
    /// Convert to an arrow representation with the provided set of
    /// keys and an optional null bitmask
    pub fn to_arrow<I>(&self, keys: I, nulls: Option<NullBuffer>) -> DictionaryArray<Int32Type>
    where
        I: IntoIterator<Item = i32>,
        I::IntoIter: ExactSizeIterator,
    {
        // the nulls are recorded in the keys array, the dictionary itself
        // is entirely non null
        let dictionary_nulls = None;
        let keys = keys.into_iter();

        let array_data = ArrayDataBuilder::new(DataType::Dictionary(
            Box::new(DataType::Int32),
            Box::new(DataType::Utf8),
        ))
        .len(keys.len())
        .add_buffer(keys.collect())
        .add_child_data(self.storage.to_arrow(dictionary_nulls).into_data())
        .nulls(nulls)
        // TODO consider skipping the validation checks by using
        // `build_unchecked()`
        .build()
        .expect("Valid array data");

        DictionaryArray::<Int32Type>::from(array_data)
    }
}
impl<K> TryFrom<PackedStringArray<K>> for StringDictionary<K>
where
    K: AsPrimitive<usize> + FromPrimitive + Zero,
{
    type Error = Error;

    /// Builds a dictionary over existing storage, assigning each string its
    /// positional index as key.
    ///
    /// Returns [`Error::DuplicateKeyFound`] if `storage` contains the same
    /// string more than once (the dedup map requires unique values).
    fn try_from(storage: PackedStringArray<K>) -> Result<Self, Error> {
        use hashbrown::hash_map::RawEntryMut;

        let hasher = ahash::RandomState::new();
        let mut dedup: HashMap<K, (), ()> = HashMap::with_capacity_and_hasher(storage.len(), ());
        for (idx, value) in storage.iter().enumerate() {
            // Hash the string value; see `lookup_value_or_insert` for the
            // matching raw-entry convention.
            let hash = hash_str(&hasher, value);
            let entry = dedup
                .raw_entry_mut()
                .from_hash(hash, |key| value == storage.get(key.as_()).unwrap());
            match entry {
                RawEntryMut::Occupied(_) => {
                    return Err(Error::DuplicateKeyFound {
                        key: value.to_string(),
                    })
                }
                RawEntryMut::Vacant(entry) => {
                    let key =
                        K::from_usize(idx).expect("failed to fit string index into dictionary key");
                    entry.insert_with_hasher(hash, key, (), |key| {
                        let string = storage.get(key.as_()).unwrap();
                        hash_str(&hasher, string)
                    });
                }
            }
        }
        Ok(Self {
            hash: hasher,
            dedup,
            storage,
        })
    }
}
#[cfg(test)]
mod test {
    use std::convert::TryInto;

    use super::*;

    // Interning the same value twice returns the same id; lookups and
    // reverse lookups agree; the arrow representation matches.
    #[test]
    fn test_dictionary() {
        let mut dictionary = StringDictionary::<i32>::new();
        let id1 = dictionary.lookup_value_or_insert("cupcake");
        let id2 = dictionary.lookup_value_or_insert("cupcake");
        let id3 = dictionary.lookup_value_or_insert("womble");
        let id4 = dictionary.lookup_value("cupcake").unwrap();
        let id5 = dictionary.lookup_value("womble").unwrap();
        let cupcake = dictionary.lookup_id(id4).unwrap();
        let womble = dictionary.lookup_id(id5).unwrap();
        let arrow_expected = arrow::array::StringArray::from(vec!["cupcake", "womble"]);
        let arrow_actual = dictionary.values().to_arrow(None);
        assert_eq!(id1, id2);
        assert_eq!(id1, id4);
        assert_ne!(id1, id3);
        assert_eq!(id3, id5);
        assert_eq!(cupcake, "cupcake");
        assert_eq!(womble, "womble");
        assert!(dictionary.lookup_value("foo").is_none());
        assert!(dictionary.lookup_id(-1).is_none());
        assert_eq!(arrow_expected, arrow_actual);
    }

    // TryFrom assigns positional ids to unique values.
    #[test]
    fn from_string_array() {
        let mut data = PackedStringArray::<u64>::new();
        data.append("cupcakes");
        data.append("foo");
        data.append("bingo");
        let dictionary: StringDictionary<_> = data.try_into().unwrap();
        assert_eq!(dictionary.lookup_value("cupcakes"), Some(0));
        assert_eq!(dictionary.lookup_value("foo"), Some(1));
        assert_eq!(dictionary.lookup_value("bingo"), Some(2));
        assert_eq!(dictionary.lookup_id(0), Some("cupcakes"));
        assert_eq!(dictionary.lookup_id(1), Some("foo"));
        assert_eq!(dictionary.lookup_id(2), Some("bingo"));
    }

    // TryFrom rejects storage containing duplicate strings.
    #[test]
    fn from_string_array_duplicates() {
        let mut data = PackedStringArray::<u64>::new();
        data.append("cupcakes");
        data.append("foo");
        data.append("bingo");
        data.append("cupcakes");
        let err = TryInto::<StringDictionary<_>>::try_into(data).expect_err("expected failure");
        assert!(matches!(err, Error::DuplicateKeyFound { key } if &key == "cupcakes"))
    }

    // Truncating to a key keeps that key and everything before it, drops the
    // rest, and allows re-inserting dropped values with fresh ids.
    #[test]
    fn test_truncate() {
        let mut dictionary = StringDictionary::<i32>::new();
        dictionary.lookup_value_or_insert("cupcake");
        dictionary.lookup_value_or_insert("cupcake");
        dictionary.lookup_value_or_insert("bingo");
        let bingo = dictionary.lookup_value_or_insert("bingo");
        let bongo = dictionary.lookup_value_or_insert("bongo");
        dictionary.lookup_value_or_insert("bingo");
        dictionary.lookup_value_or_insert("cupcake");
        dictionary.truncate(bingo);
        assert_eq!(dictionary.values().len(), 2);
        assert_eq!(dictionary.dedup.len(), 2);
        assert_eq!(dictionary.lookup_value("cupcake"), Some(0));
        assert_eq!(dictionary.lookup_value("bingo"), Some(1));
        assert!(dictionary.lookup_value("bongo").is_none());
        assert!(dictionary.lookup_id(bongo).is_none());
        dictionary.lookup_value_or_insert("bongo");
        assert_eq!(dictionary.lookup_value("bongo"), Some(2));
    }
}

206
arrow_util/src/display.rs Normal file
View File

@ -0,0 +1,206 @@
use arrow::array::{ArrayRef, DurationNanosecondArray, TimestampNanosecondArray};
use arrow::datatypes::{DataType, TimeUnit};
use arrow::error::{ArrowError, Result};
use arrow::record_batch::RecordBatch;
use comfy_table::{Cell, Table};
use chrono::prelude::*;
/// custom version of
/// [pretty_format_batches](arrow::util::pretty::pretty_format_batches)
/// that displays timestamps using RFC3339 format (e.g. `2021-07-20T23:28:50Z`)
///
/// Should be removed if/when the capability is added upstream to arrow:
/// <https://github.com/apache/arrow-rs/issues/599>
pub fn pretty_format_batches(results: &[RecordBatch]) -> Result<String> {
    let table = create_table(results)?;
    Ok(table.to_string())
}
/// Convert the value at `column[row]` to a String
///
/// Special cases printing Timestamps in RFC3339 for IOx, otherwise
/// falls back to Arrow's implementation
///
/// # Errors
///
/// Returns an error if the timestamp cannot be represented by `chrono`, or
/// if a duration value is negative.
fn array_value_to_string(column: &ArrayRef, row: usize) -> Result<String> {
    match column.data_type() {
        DataType::Timestamp(TimeUnit::Nanosecond, None) if column.is_valid(row) => {
            let ts_column = column
                .as_any()
                .downcast_ref::<TimestampNanosecondArray>()
                .unwrap();

            let ts_value = ts_column.value(row);
            const NANOS_IN_SEC: i64 = 1_000_000_000;
            // Use euclidean division/remainder so that pre-epoch (negative)
            // timestamps yield `0 <= nanos < NANOS_IN_SEC`. The previous
            // truncating arithmetic produced a negative remainder whose
            // `as u32` cast was out of range, making any pre-1970 timestamp
            // fail to format.
            let secs = ts_value.div_euclid(NANOS_IN_SEC);
            let nanos = ts_value.rem_euclid(NANOS_IN_SEC) as u32;
            let ts = NaiveDateTime::from_timestamp_opt(secs, nanos).ok_or_else(|| {
                ArrowError::ExternalError(
                    format!("Cannot process timestamp (secs={secs}, nanos={nanos})").into(),
                )
            })?;
            // treat as UTC
            let ts = DateTime::<Utc>::from_naive_utc_and_offset(ts, Utc);
            // convert to string in preferred influx format
            let use_z = true;
            Ok(ts.to_rfc3339_opts(SecondsFormat::AutoSi, use_z))
        }
        // TODO(edd): see https://github.com/apache/arrow-rs/issues/1168
        DataType::Duration(TimeUnit::Nanosecond) if column.is_valid(row) => {
            let dur_column = column
                .as_any()
                .downcast_ref::<DurationNanosecondArray>()
                .unwrap();

            let duration = std::time::Duration::from_nanos(
                dur_column
                    .value(row)
                    .try_into()
                    .map_err(|e| ArrowError::InvalidArgumentError(format!("{e:?}")))?,
            );
            Ok(format!("{duration:?}"))
        }
        _ => {
            // fallback to arrow's default printing for other types
            arrow::util::display::array_value_to_string(column, row)
        }
    }
}
/// Convert a series of record batches into a table
///
/// NB: COPIED FROM ARROW
///
/// Returns an error if the batches do not all share the first batch's
/// schema, or if a cell value fails to render.
fn create_table(results: &[RecordBatch]) -> Result<Table> {
    let mut table = Table::new();
    // comfy_table preset approximating arrow's pretty-print borders.
    table.load_preset("||--+-++| ++++++");

    if results.is_empty() {
        return Ok(table);
    }

    let schema = results[0].schema();

    let mut header = Vec::new();
    for field in schema.fields() {
        header.push(Cell::new(field.name()));
    }
    table.set_header(header);

    for (i, batch) in results.iter().enumerate() {
        // All batches must agree with the first batch's schema.
        if batch.schema() != schema {
            return Err(ArrowError::SchemaError(format!(
                "Batches have different schemas:\n\nFirst:\n{}\n\nBatch {}:\n{}",
                schema,
                i + 1,
                batch.schema()
            )));
        }
        for row in 0..batch.num_rows() {
            let mut cells = Vec::new();
            for col in 0..batch.num_columns() {
                let column = batch.column(col);
                cells.push(Cell::new(array_value_to_string(column, row)?));
            }
            table.add_row(cells);
        }
    }
    Ok(table)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use arrow::{
        array::{
            ArrayRef, BooleanArray, DictionaryArray, Float64Array, Int64Array, StringArray,
            UInt64Array,
        },
        datatypes::Int32Type,
    };
    use datafusion::common::assert_contains;

    #[test]
    fn test_formatting() {
        // tests formatting all of the Arrow array types used in IOx

        // tags use string dictionary
        let dict_array: ArrayRef = Arc::new(
            vec![Some("a"), None, Some("b")]
                .into_iter()
                .collect::<DictionaryArray<Int32Type>>(),
        );
        // field types
        let int64_array: ArrayRef =
            Arc::new([Some(-1), None, Some(2)].iter().collect::<Int64Array>());
        let uint64_array: ArrayRef =
            Arc::new([Some(1), None, Some(2)].iter().collect::<UInt64Array>());
        let float64_array: ArrayRef = Arc::new(
            [Some(1.0), None, Some(2.0)]
                .iter()
                .collect::<Float64Array>(),
        );
        let bool_array: ArrayRef = Arc::new(
            [Some(true), None, Some(false)]
                .iter()
                .collect::<BooleanArray>(),
        );
        let string_array: ArrayRef = Arc::new(
            vec![Some("foo"), None, Some("bar")]
                .into_iter()
                .collect::<StringArray>(),
        );
        // timestamp type
        let ts_array: ArrayRef = Arc::new(
            [None, Some(100), Some(1626823730000000000)]
                .iter()
                .collect::<TimestampNanosecondArray>(),
        );
        let batch = RecordBatch::try_from_iter(vec![
            ("dict", dict_array),
            ("int64", int64_array),
            ("uint64", uint64_array),
            ("float64", float64_array),
            ("bool", bool_array),
            ("string", string_array),
            ("time", ts_array),
        ])
        .unwrap();
        let table = pretty_format_batches(&[batch]).unwrap();
        // Timestamps render in RFC3339 with `Z`, nulls render as blanks.
        let expected = vec![
            "+------+-------+--------+---------+-------+--------+--------------------------------+",
            "| dict | int64 | uint64 | float64 | bool | string | time |",
            "+------+-------+--------+---------+-------+--------+--------------------------------+",
            "| a | -1 | 1 | 1.0 | true | foo | |",
            "| | | | | | | 1970-01-01T00:00:00.000000100Z |",
            "| b | 2 | 2 | 2.0 | false | bar | 2021-07-20T23:28:50Z |",
            "+------+-------+--------+---------+-------+--------+--------------------------------+",
        ];
        let actual: Vec<&str> = table.lines().collect();
        assert_eq!(
            expected, actual,
            "Expected:\n\n{expected:#?}\nActual:\n\n{actual:#?}\n"
        );
    }

    // Mismatched schemas across batches must surface as an error.
    #[test]
    fn test_pretty_format_batches_checks_schemas() {
        let int64_array: ArrayRef = Arc::new([Some(2)].iter().collect::<Int64Array>());
        let uint64_array: ArrayRef = Arc::new([Some(2)].iter().collect::<UInt64Array>());
        let batch1 = RecordBatch::try_from_iter(vec![("col", int64_array)]).unwrap();
        let batch2 = RecordBatch::try_from_iter(vec![("col", uint64_array)]).unwrap();
        let err = pretty_format_batches(&[batch1, batch2]).unwrap_err();
        assert_contains!(err.to_string(), "Batches have different schemas:");
    }
}

26
arrow_util/src/flight.rs Normal file
View File

@ -0,0 +1,26 @@
use std::sync::Arc;
use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
/// Prepare an arrow Schema for transport over the Arrow Flight protocol
///
/// Converts dictionary types to underlying types due to <https://github.com/apache/arrow-rs/issues/3389>
pub fn prepare_schema_for_flight(schema: SchemaRef) -> SchemaRef {
    let fields: Fields = schema
        .fields()
        .iter()
        .map(|field| {
            if let DataType::Dictionary(_, value_type) = field.data_type() {
                // Flatten the dictionary to its value type, preserving
                // name, nullability and metadata.
                let flattened = Field::new(
                    field.name(),
                    value_type.as_ref().clone(),
                    field.is_nullable(),
                )
                .with_metadata(field.metadata().clone());
                Arc::new(flattened)
            } else {
                Arc::clone(field)
            }
        })
        .collect();

    Arc::new(Schema::new(fields).with_metadata(schema.metadata().clone()))
}

27
arrow_util/src/lib.rs Normal file
View File

@ -0,0 +1,27 @@
//! Shared utilities for working with Apache Arrow data in IOx.
#![deny(rustdoc::broken_intra_doc_links, rustdoc::bare_urls, rust_2018_idioms)]
#![allow(clippy::clone_on_ref_ptr)]
#![warn(
    missing_copy_implementations,
    missing_debug_implementations,
    clippy::explicit_iter_loop,
    // See https://github.com/influxdata/influxdb_iox/pull/1671
    clippy::future_not_send,
    clippy::clone_on_ref_ptr,
    clippy::todo,
    clippy::dbg_macro,
    unused_crate_dependencies
)]

// Workaround for "unused crate" lint false positives.
use workspace_hack as _;

pub mod bitset;
pub mod dictionary;
pub mod display;
pub mod flight;
pub mod optimize;
pub mod string;
pub mod util;

/// This has a collection of testing helper functions
pub mod test_util;

299
arrow_util/src/optimize.rs Normal file
View File

@ -0,0 +1,299 @@
use std::collections::BTreeSet;
use std::sync::Arc;
use arrow::array::{Array, ArrayRef, DictionaryArray, StringArray};
use arrow::datatypes::{DataType, Int32Type};
use arrow::error::{ArrowError, Result};
use arrow::record_batch::RecordBatch;
use hashbrown::HashMap;
use crate::dictionary::StringDictionary;
/// Takes a record batch and returns a new record batch with dictionaries
/// optimized to contain no duplicate or unreferenced values
///
/// Where the input dictionaries are sorted, the output dictionaries
/// will also be
pub fn optimize_dictionaries(batch: &RecordBatch) -> Result<RecordBatch> {
    let schema = batch.schema();
    let mut new_columns = Vec::with_capacity(batch.num_columns());
    for (col, field) in batch.columns().iter().zip(schema.fields()) {
        // Only dictionary columns are rewritten; everything else is shared.
        let new_col = match field.data_type() {
            DataType::Dictionary(key, value) => optimize_dict_col(col, key, value)?,
            _ => Arc::clone(col),
        };
        new_columns.push(new_col);
    }

    RecordBatch::try_new(schema, new_columns)
}
/// Optimizes the dictionaries for a column
///
/// Only `Dictionary<Int32, Utf8>` columns are supported; any other key or
/// value type returns `ArrowError::NotYetImplemented`.
fn optimize_dict_col(
    col: &ArrayRef,
    key_type: &DataType,
    value_type: &DataType,
) -> Result<ArrayRef> {
    if key_type != &DataType::Int32 {
        return Err(ArrowError::NotYetImplemented(format!(
            "truncating non-Int32 dictionaries not supported: {key_type}"
        )));
    }

    if value_type != &DataType::Utf8 {
        return Err(ArrowError::NotYetImplemented(format!(
            "truncating non-string dictionaries not supported: {value_type}"
        )));
    }

    let col = col
        .as_any()
        .downcast_ref::<DictionaryArray<Int32Type>>()
        .expect("unexpected datatype");

    let keys = col.keys();
    let values = col.values();
    let values = values
        .as_any()
        .downcast_ref::<StringArray>()
        .expect("unexpected datatype");

    // The total length of the resulting values array
    let mut values_len = 0_usize;

    // Keys that appear in the values array
    // Use a BTreeSet to preserve the order of the dictionary
    let mut used_keys = BTreeSet::new();
    for key in keys.iter().flatten() {
        if used_keys.insert(key) {
            values_len += values.value_length(key as usize) as usize;
        }
    }

    // Then perform deduplication
    let mut new_dictionary = StringDictionary::with_capacity(used_keys.len(), values_len);
    let mut old_to_new_idx: HashMap<i32, i32> = HashMap::with_capacity(used_keys.len());
    for key in used_keys {
        let new_key = new_dictionary.lookup_value_or_insert(values.value(key as usize));
        old_to_new_idx.insert(key, new_key);
    }

    // Remap each original key to its deduplicated index; nulls become -1
    // placeholders and are marked null via the copied null buffer below.
    let new_keys = keys.iter().map(|x| match x {
        Some(x) => *old_to_new_idx.get(&x).expect("no mapping found"),
        None => -1,
    });

    let nulls = keys.nulls().cloned();
    Ok(Arc::new(new_dictionary.to_arrow(new_keys, nulls)))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate as arrow_util;
    use crate::assert_batches_eq;
    use arrow::array::{ArrayDataBuilder, DictionaryArray, Float64Array, Int32Array, StringArray};
    use arrow::compute::concat;
    use std::iter::FromIterator;

    // Duplicate and unreferenced dictionary values are removed while the
    // logical column contents are preserved.
    #[test]
    fn test_optimize_dictionaries() {
        let values = StringArray::from(vec![
            "duplicate",
            "duplicate",
            "foo",
            "boo",
            "unused",
            "duplicate",
        ]);
        let keys = Int32Array::from(vec![
            Some(0),
            Some(1),
            None,
            Some(1),
            Some(2),
            Some(5),
            Some(3),
        ]);
        let batch = RecordBatch::try_from_iter(vec![(
            "foo",
            Arc::new(build_dict(keys, values)) as ArrayRef,
        )])
        .unwrap();
        let optimized = optimize_dictionaries(&batch).unwrap();
        let col = optimized
            .column(0)
            .as_any()
            .downcast_ref::<DictionaryArray<Int32Type>>()
            .unwrap();
        let values = col.values();
        let values = values.as_any().downcast_ref::<StringArray>().unwrap();
        let values = values.iter().flatten().collect::<Vec<_>>();
        assert_eq!(values, vec!["duplicate", "foo", "boo"]);
        assert_batches_eq!(
            vec![
                "+-----------+",
                "| foo |",
                "+-----------+",
                "| duplicate |",
                "| duplicate |",
                "| |",
                "| duplicate |",
                "| foo |",
                "| duplicate |",
                "| boo |",
                "+-----------+",
            ],
            &[optimized]
        );
    }

    // Concatenating dictionary arrays produces redundant values; optimizing
    // afterwards dedups each column independently.
    #[test]
    fn test_optimize_dictionaries_concat() {
        let f1_1 = Float64Array::from(vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0)]);
        let t2_1 = DictionaryArray::<Int32Type>::from_iter(vec![
            Some("a"),
            Some("g"),
            Some("a"),
            Some("b"),
        ]);
        let t1_1 = DictionaryArray::<Int32Type>::from_iter(vec![
            Some("a"),
            Some("a"),
            Some("b"),
            Some("b"),
        ]);
        let f1_2 = Float64Array::from(vec![Some(1.0), Some(5.0), Some(2.0), Some(46.0)]);
        let t2_2 = DictionaryArray::<Int32Type>::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("a"),
            Some("a"),
        ]);
        let t1_2 = DictionaryArray::<Int32Type>::from_iter(vec![
            Some("a"),
            Some("d"),
            Some("a"),
            Some("b"),
        ]);
        let concat = RecordBatch::try_from_iter(vec![
            ("f1", concat(&[&f1_1, &f1_2]).unwrap()),
            ("t2", concat(&[&t2_1, &t2_2]).unwrap()),
            ("t1", concat(&[&t1_1, &t1_2]).unwrap()),
        ])
        .unwrap();
        let optimized = optimize_dictionaries(&concat).unwrap();
        let col = optimized
            .column(optimized.schema().column_with_name("t2").unwrap().0)
            .as_any()
            .downcast_ref::<DictionaryArray<Int32Type>>()
            .unwrap();
        let values = col.values();
        let values = values.as_any().downcast_ref::<StringArray>().unwrap();
        let values = values.iter().flatten().collect::<Vec<_>>();
        assert_eq!(values, vec!["a", "g", "b"]);
        let col = optimized
            .column(optimized.schema().column_with_name("t1").unwrap().0)
            .as_any()
            .downcast_ref::<DictionaryArray<Int32Type>>()
            .unwrap();
        let values = col.values();
        let values = values.as_any().downcast_ref::<StringArray>().unwrap();
        let values = values.iter().flatten().collect::<Vec<_>>();
        assert_eq!(values, vec!["a", "b", "d"]);
        assert_batches_eq!(
            vec![
                "+------+----+----+",
                "| f1 | t2 | t1 |",
                "+------+----+----+",
                "| 1.0 | a | a |",
                "| 2.0 | g | a |",
                "| 3.0 | a | b |",
                "| 4.0 | b | b |",
                "| 1.0 | a | a |",
                "| 5.0 | b | d |",
                "| 2.0 | a | a |",
                "| 46.0 | a | b |",
                "+------+----+----+",
            ],
            &[optimized]
        );
    }

    // Null keys survive optimization and still render as nulls.
    #[test]
    fn test_optimize_dictionaries_null() {
        let values = StringArray::from(vec!["bananas"]);
        let keys = Int32Array::from(vec![None, None, Some(0)]);
        let col = Arc::new(build_dict(keys, values)) as ArrayRef;
        let col = optimize_dict_col(&col, &DataType::Int32, &DataType::Utf8).unwrap();
        let batch = RecordBatch::try_from_iter(vec![("t", col)]).unwrap();
        assert_batches_eq!(
            vec![
                "+---------+",
                "| t |",
                "+---------+",
                "| |",
                "| |",
                "| bananas |",
                "+---------+",
            ],
            &[batch]
        );
    }

    // Optimization respects array slicing (offset + length).
    #[test]
    fn test_optimize_dictionaries_slice() {
        let values = StringArray::from(vec!["bananas"]);
        let keys = Int32Array::from(vec![None, Some(0), None]);
        let col = Arc::new(build_dict(keys, values)) as ArrayRef;
        let col = col.slice(1, 2);
        let col = optimize_dict_col(&col, &DataType::Int32, &DataType::Utf8).unwrap();
        let batch = RecordBatch::try_from_iter(vec![("t", col)]).unwrap();
        assert_batches_eq!(
            vec![
                "+---------+",
                "| t |",
                "+---------+",
                "| bananas |",
                "| |",
                "+---------+",
            ],
            &[batch]
        );
    }

    /// Assembles a `Dictionary<Int32, Utf8>` array from raw keys and values.
    fn build_dict(keys: Int32Array, values: StringArray) -> DictionaryArray<Int32Type> {
        let data = ArrayDataBuilder::new(DataType::Dictionary(
            Box::new(DataType::Int32),
            Box::new(DataType::Utf8),
        ))
        .len(keys.len())
        .add_buffer(keys.to_data().buffers()[0].clone())
        .nulls(keys.nulls().cloned())
        .add_child_data(values.into_data())
        .build()
        .unwrap();

        DictionaryArray::from(data)
    }
}

319
arrow_util/src/string.rs Normal file
View File

@ -0,0 +1,319 @@
use arrow::array::ArrayDataBuilder;
use arrow::array::StringArray;
use arrow::buffer::Buffer;
use arrow::buffer::NullBuffer;
use num_traits::{AsPrimitive, FromPrimitive, Zero};
use std::fmt::Debug;
use std::ops::Range;
/// A packed string array that stores start and end indexes into
/// a contiguous string slice.
///
/// The type parameter K alters the type used to store the offsets
///
/// Invariant: `offsets` always holds `len + 1` entries; string `i` occupies
/// `storage[offsets[i]..offsets[i + 1]]`.
#[derive(Debug, Clone)]
pub struct PackedStringArray<K> {
    /// The start and end offsets of strings stored in storage
    offsets: Vec<K>,
    /// A contiguous array of string data
    storage: String,
}
impl<K: Zero> Default for PackedStringArray<K> {
    /// An empty array: one sentinel offset (0) and no string data.
    fn default() -> Self {
        Self {
            offsets: vec![K::zero()],
            storage: String::new(),
        }
    }
}
impl<K: AsPrimitive<usize> + FromPrimitive + Zero> PackedStringArray<K> {
    /// Creates an empty array.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an array containing `len` empty strings.
    pub fn new_empty(len: usize) -> Self {
        Self {
            offsets: vec![K::zero(); len + 1],
            storage: String::new(),
        }
    }

    /// Creates an empty array pre-sized for `keys` strings totalling
    /// `values` bytes.
    pub fn with_capacity(keys: usize, values: usize) -> Self {
        let mut offsets = Vec::with_capacity(keys + 1);
        offsets.push(K::zero());
        Self {
            offsets,
            storage: String::with_capacity(values),
        }
    }

    /// Append a value
    ///
    /// Returns the index of the appended data
    pub fn append(&mut self, data: &str) -> usize {
        let id = self.offsets.len() - 1;

        // The new end offset is the storage length *after* the push below.
        let offset = self.storage.len() + data.len();
        let offset = K::from_usize(offset).expect("failed to fit into offset type");

        self.offsets.push(offset);
        self.storage.push_str(data);

        id
    }

    /// Extends this [`PackedStringArray`] by the contents of `other`
    pub fn extend_from(&mut self, other: &PackedStringArray<K>) {
        let offset = self.storage.len();
        self.storage.push_str(other.storage.as_str());
        // Copy offsets skipping the first element as this string start delimiter is already
        // provided by the end delimiter of the current offsets array
        self.offsets.extend(
            other
                .offsets
                .iter()
                .skip(1)
                .map(|x| K::from_usize(x.as_() + offset).expect("failed to fit into offset type")),
        )
    }

    /// Extends this [`PackedStringArray`] by `range` elements from `other`
    pub fn extend_from_range(&mut self, other: &PackedStringArray<K>, range: Range<usize>) {
        let first_offset: usize = other.offsets[range.start].as_();
        let end_offset: usize = other.offsets[range.end].as_();

        let insert_offset = self.storage.len();
        self.storage
            .push_str(&other.storage[first_offset..end_offset]);

        // Rebase the copied offsets relative to this array's storage,
        // skipping the range's first delimiter (already present as our
        // current trailing offset).
        self.offsets.extend(
            other.offsets[(range.start + 1)..(range.end + 1)]
                .iter()
                .map(|x| {
                    K::from_usize(x.as_() - first_offset + insert_offset)
                        .expect("failed to fit into offset type")
                }),
        )
    }

    /// Get the value at a given index
    pub fn get(&self, index: usize) -> Option<&str> {
        let start_offset = self.offsets.get(index)?.as_();
        let end_offset = self.offsets.get(index + 1)?.as_();

        Some(&self.storage[start_offset..end_offset])
    }

    /// Pads with empty strings to reach length
    pub fn extend(&mut self, len: usize) {
        // Repeating the current end offset yields zero-length strings.
        let offset = K::from_usize(self.storage.len()).expect("failed to fit into offset type");
        self.offsets.resize(self.offsets.len() + len, offset);
    }

    /// Truncates the array to the given length
    pub fn truncate(&mut self, len: usize) {
        self.offsets.truncate(len + 1);
        let last_idx = self.offsets.last().expect("offsets empty");
        self.storage.truncate(last_idx.as_());
    }

    /// Removes all elements from this array
    pub fn clear(&mut self) {
        // Keep the single sentinel offset.
        self.offsets.truncate(1);
        self.storage.clear();
    }

    /// Iterate over the stored strings in index order.
    pub fn iter(&self) -> PackedStringIterator<'_, K> {
        PackedStringIterator {
            array: self,
            index: 0,
        }
    }

    /// The number of strings in this array
    pub fn len(&self) -> usize {
        self.offsets.len() - 1
    }

    /// Returns `true` if this array contains no strings.
    pub fn is_empty(&self) -> bool {
        self.offsets.len() == 1
    }

    /// Return the amount of memory in bytes taken up by this array
    pub fn size(&self) -> usize {
        self.storage.capacity() + self.offsets.capacity() * std::mem::size_of::<K>()
    }

    /// Borrow the raw offsets and string storage.
    pub fn inner(&self) -> (&[K], &str) {
        (&self.offsets, &self.storage)
    }

    /// Consume self, returning the raw offsets and string storage.
    pub fn into_inner(self) -> (Vec<K>, String) {
        (self.offsets, self.storage)
    }
}
impl PackedStringArray<i32> {
    /// Convert to an arrow with an optional null bitmask
    pub fn to_arrow(&self, nulls: Option<NullBuffer>) -> StringArray {
        let num_strings = self.offsets.len() - 1;
        let offset_buffer = Buffer::from_slice_ref(&self.offsets);
        let value_buffer = Buffer::from(self.storage.as_bytes());

        // For the Utf8 layout the offsets buffer must be added before the
        // values buffer.
        let array_data = ArrayDataBuilder::new(arrow::datatypes::DataType::Utf8)
            .len(num_strings)
            .add_buffer(offset_buffer)
            .add_buffer(value_buffer)
            .nulls(nulls)
            .build()
            // TODO consider skipping the validation checks by using
            // `new_unchecked`
            .expect("Valid array data");

        StringArray::from(array_data)
    }
}
/// Iterator over the strings stored in a [`PackedStringArray`].
#[derive(Debug)]
pub struct PackedStringIterator<'a, K> {
    // Array being iterated over
    array: &'a PackedStringArray<K>,
    // Index of the next element to yield
    index: usize,
}
impl<'a, K: AsPrimitive<usize> + FromPrimitive + Zero> Iterator for PackedStringIterator<'a, K> {
    type Item = &'a str;

    fn next(&mut self) -> Option<Self::Item> {
        // `get` returns a reference with the array's lifetime `'a`, so it is
        // independent of this short-lived mutable borrow of the iterator.
        self.array.get(self.index).map(|value| {
            self.index += 1;
            value
        })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exact bound: `index` never exceeds the array length.
        let remaining = self.array.len() - self.index;
        (remaining, Some(remaining))
    }
}
#[cfg(test)]
mod tests {
    use crate::string::PackedStringArray;

    #[test]
    fn test_storage() {
        let mut array = PackedStringArray::<i32>::new();

        array.append("hello");
        array.append("world");
        array.append("cupcake");

        assert_eq!(array.get(0).unwrap(), "hello");
        assert_eq!(array.get(1).unwrap(), "world");
        assert_eq!(array.get(2).unwrap(), "cupcake");
        // Out-of-bounds lookups (including a wrapped-around "negative" index)
        // return None rather than panicking
        assert!(array.get(-1_i32 as usize).is_none());
        assert!(array.get(3).is_none());

        // `extend` appends empty strings
        array.extend(2);
        assert_eq!(array.get(3).unwrap(), "");
        assert_eq!(array.get(4).unwrap(), "");
        assert!(array.get(5).is_none());
    }

    #[test]
    fn test_empty() {
        // `new_empty` pre-populates the array with empty strings
        let array = PackedStringArray::<u8>::new_empty(20);
        assert_eq!(array.get(12).unwrap(), "");
        assert_eq!(array.get(9).unwrap(), "");
        assert_eq!(array.get(3).unwrap(), "");
    }

    #[test]
    fn test_truncate() {
        let mut array = PackedStringArray::<i32>::new();

        array.append("hello");
        array.append("world");
        array.append("cupcake");

        array.truncate(1);
        assert_eq!(array.len(), 1);
        assert_eq!(array.get(0).unwrap(), "hello");

        // Appending after truncation reuses the freed slots
        array.append("world");
        assert_eq!(array.len(), 2);
        assert_eq!(array.get(0).unwrap(), "hello");
        assert_eq!(array.get(1).unwrap(), "world");
    }

    #[test]
    fn test_extend_from() {
        let mut a = PackedStringArray::<i32>::new();

        a.append("hello");
        a.append("world");
        a.append("cupcake");
        a.append("");

        let mut b = PackedStringArray::<i32>::new();

        b.append("foo");
        b.append("bar");

        a.extend_from(&b);

        let a_content: Vec<_> = a.iter().collect();
        assert_eq!(
            a_content,
            vec!["hello", "world", "cupcake", "", "foo", "bar"]
        );
    }

    #[test]
    fn test_extend_from_range() {
        let mut a = PackedStringArray::<i32>::new();

        a.append("hello");
        a.append("world");
        a.append("cupcake");
        a.append("");

        let mut b = PackedStringArray::<i32>::new();

        b.append("foo");
        b.append("bar");
        b.append("");
        b.append("fiz");

        a.extend_from_range(&b, 1..3);
        assert_eq!(a.len(), 6);

        let a_content: Vec<_> = a.iter().collect();
        assert_eq!(a_content, vec!["hello", "world", "cupcake", "", "bar", ""]);

        // Should be a no-op
        a.extend_from_range(&b, 0..0);
        let a_content: Vec<_> = a.iter().collect();
        assert_eq!(a_content, vec!["hello", "world", "cupcake", "", "bar", ""]);

        a.extend_from_range(&b, 0..1);
        let a_content: Vec<_> = a.iter().collect();
        assert_eq!(
            a_content,
            vec!["hello", "world", "cupcake", "", "bar", "", "foo"]
        );

        a.extend_from_range(&b, 1..4);
        let a_content: Vec<_> = a.iter().collect();
        assert_eq!(
            a_content,
            vec!["hello", "world", "cupcake", "", "bar", "", "foo", "bar", "", "fiz"]
        );
    }
}

418
arrow_util/src/test_util.rs Normal file
View File

@ -0,0 +1,418 @@
//! A collection of testing functions for arrow based code
use std::sync::Arc;
use crate::display::pretty_format_batches;
use arrow::{
array::{new_null_array, ArrayRef, StringArray},
compute::kernels::sort::{lexsort, SortColumn, SortOptions},
datatypes::Schema,
error::ArrowError,
record_batch::RecordBatch,
};
use once_cell::sync::Lazy;
use regex::{Captures, Regex};
use std::{borrow::Cow, collections::HashMap};
use uuid::Uuid;
/// Compares the formatted output with the pretty formatted results of
/// record batches. This is a macro so errors appear on the correct line
///
/// Designed so that failure output can be directly copy/pasted
/// into the test code as expected results.
///
/// Expects to be called about like this:
///
/// `assert_batches_eq!(expected_lines: &[&str], chunks: &[RecordBatch])`
#[macro_export]
macro_rules! assert_batches_eq {
    ($EXPECTED_LINES: expr, $CHUNKS: expr) => {
        let expected_lines: Vec<String> =
            $EXPECTED_LINES.into_iter().map(|s| s.to_string()).collect();

        let actual_lines = arrow_util::test_util::batches_to_lines($CHUNKS);

        assert_eq!(
            expected_lines, actual_lines,
            "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
            expected_lines, actual_lines
        );
    };
}
/// Compares formatted output of a record batch with an expected
/// vector of strings in a way that order does not matter.
/// This is a macro so errors appear on the correct line
///
/// Designed so that failure output can be directly copy/pasted
/// into the test code as expected results.
///
/// Expects to be called about like this:
///
/// `assert_batches_sorted_eq!(expected_lines: &[&str], batches: &[RecordBatch])`
#[macro_export]
macro_rules! assert_batches_sorted_eq {
    ($EXPECTED_LINES: expr, $CHUNKS: expr) => {
        let expected_lines: Vec<String> = $EXPECTED_LINES.iter().map(|&s| s.into()).collect();
        // sort the expected lines the same way the actual output is sorted
        let expected_lines = arrow_util::test_util::sort_lines(expected_lines);

        let actual_lines = arrow_util::test_util::batches_to_sorted_lines($CHUNKS);

        assert_eq!(
            expected_lines, actual_lines,
            "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
            expected_lines, actual_lines
        );
    };
}
/// Converts the [`RecordBatch`]es into a pretty printed output suitable for
/// comparing in tests
///
/// Example:
///
/// ```text
/// "+-----+------+------+--------------------------------+",
/// "| foo | host | load | time |",
/// "+-----+------+------+--------------------------------+",
/// "| | a | 1.0 | 1970-01-01T00:00:00.000000011Z |",
/// "| | a | 14.0 | 1970-01-01T00:00:00.000010001Z |",
/// "| | a | 3.0 | 1970-01-01T00:00:00.000000033Z |",
/// "| | b | 5.0 | 1970-01-01T00:00:00.000000011Z |",
/// "| | z | 0.0 | 1970-01-01T00:00:00Z |",
/// "+-----+------+------+--------------------------------+",
/// ```
pub fn batches_to_lines(batches: &[RecordBatch]) -> Vec<String> {
    let rendered = crate::display::pretty_format_batches(batches).unwrap();
    rendered
        .trim()
        .lines()
        .map(|line| line.to_string())
        .collect()
}
/// Converts the [`RecordBatch`]es into a pretty printed output suitable for
/// comparing in tests where sorting does not matter.
///
/// See [`batches_to_lines`] and [`sort_lines`].
pub fn batches_to_sorted_lines(batches: &[RecordBatch]) -> Vec<String> {
    sort_lines(batches_to_lines(batches))
}
/// Sorts the lines (assumed to be the output of `batches_to_lines` for stable comparison)
pub fn sort_lines(mut lines: Vec<String>) -> Vec<String> {
    // Keep the first two lines (top border + header) and the final line
    // (bottom border) in place; only the body rows in between are sorted.
    // For tables too short to have body rows this is a no-op.
    let line_count = lines.len();
    if let Some(body) = lines.get_mut(2..line_count.saturating_sub(1)) {
        body.sort_unstable();
    }
    lines
}
/// Sort a record batch by all columns (to provide a stable output order for test
/// comparison)
pub fn sort_record_batch(batch: RecordBatch) -> RecordBatch {
    // Every column participates in the sort key: ascending, nulls last
    let sort_input: Vec<SortColumn> = batch
        .columns()
        .iter()
        .map(|col| SortColumn {
            values: Arc::clone(col),
            options: Some(SortOptions {
                descending: false,
                nulls_first: false,
            }),
        })
        .collect();
    let sort_output = lexsort(&sort_input, None).expect("Sorting to complete");
    RecordBatch::try_new(batch.schema(), sort_output).unwrap()
}
/// Return a new `StringArray` where each element had a normalization
/// function `norm` applied.
pub fn normalize_string_array<N>(arr: &StringArray, norm: N) -> ArrayRef
where
    N: Fn(&str) -> String,
{
    // Nulls stay null; present values are rewritten through `norm`.
    let rewritten = arr.iter().map(|value| value.map(&norm)).collect::<StringArray>();
    Arc::new(rewritten)
}
/// Return a new set of `RecordBatch`es where the function `norm` has
/// been applied to all `StringArray` rows.
pub fn normalize_batches<N>(batches: Vec<RecordBatch>, norm: N) -> Vec<RecordBatch>
where
    N: Fn(&str) -> String,
{
    // The idea here is to get a function that normalizes strings
    // and apply it to each StringArray element by element
    batches
        .into_iter()
        .map(|batch| {
            let new_columns: Vec<_> = batch
                .columns()
                .iter()
                .map(|array| {
                    // Only string columns are rewritten; all other column
                    // types are passed through unchanged
                    if let Some(array) = array.as_any().downcast_ref::<StringArray>() {
                        normalize_string_array(array, &norm)
                    } else {
                        Arc::clone(array)
                    }
                })
                .collect();

            RecordBatch::try_new(batch.schema(), new_columns)
                .expect("error occurred during normalization")
        })
        .collect()
}
/// Equalize batch schemas by creating NULL columns.
///
/// Every batch is projected onto the merged schema of all batches; columns a
/// batch lacks are filled with all-null arrays of the appropriate type.
pub fn equalize_batch_schemas(batches: Vec<RecordBatch>) -> Result<Vec<RecordBatch>, ArrowError> {
    // Merged superset schema of all input batches
    let common_schema = Arc::new(Schema::try_merge(
        batches.iter().map(|batch| batch.schema().as_ref().clone()),
    )?);

    Ok(batches
        .into_iter()
        .map(|batch| {
            let batch_schema = batch.schema();
            let columns = common_schema
                .fields()
                .iter()
                .map(|field| match batch_schema.index_of(field.name()) {
                    Ok(idx) => Arc::clone(batch.column(idx)),
                    // column missing from this batch: fill with nulls
                    Err(_) => new_null_array(field.data_type(), batch.num_rows()),
                })
                .collect();
            RecordBatch::try_new(Arc::clone(&common_schema), columns).unwrap()
        })
        .collect())
}
/// Match the parquet UUID
///
/// For example, given
/// `32/51/216/13452/1d325760-2b20-48de-ab48-2267b034133d.parquet`
///
/// matches `1d325760-2b20-48de-ab48-2267b034133d`
pub static REGEX_UUID: Lazy<Regex> = Lazy::new(|| {
    Regex::new("[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}").expect("UUID regex")
});

/// Match the parquet directory names
/// For example, given
/// `51/216/1a3f45021a3f45021a3f45021a3f45021a3f45021a3f45021a3f45021a3f4502/1d325760-2b20-48de-ab48-2267b034133d.parquet`
///
/// matches `51/216/1a3f45021a3f45021a3f45021a3f45021a3f45021a3f45021a3f45021a3f4502`
static REGEX_DIRS: Lazy<Regex> =
    Lazy::new(|| Regex::new(r#"[0-9]+/[0-9]+/[0-9a-f]+/"#).expect("directory regex"));

/// Replace table row separators of flexible width with fixed with. This is required
/// because the original timing values may differ in "printed width", so the table
/// cells have different widths and hence the separators / borders. E.g.:
///
/// `+--+--+` -> `----------`
/// `+--+------+` -> `----------`
///
/// Note that we're kinda inexact with our regex here, but it gets the job done.
static REGEX_LINESEP: Lazy<Regex> = Lazy::new(|| Regex::new(r#"[+-]{6,}"#).expect("linesep regex"));

/// Similar to the row separator issue above, the table columns are right-padded
/// with spaces. Due to the different "printed width" of the timing values, we need
/// to normalize this padding as well. E.g.:
///
/// ` |` -> ` |`
/// ` |` -> ` |`
static REGEX_COL: Lazy<Regex> = Lazy::new(|| Regex::new(r"\s+\|").expect("col regex"));

/// Matches lines like `metrics=[foo=1, bar=2]`
static REGEX_METRICS: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"metrics=\[([^\]]*)\]").expect("metrics regex"));

/// Matches things like `1s`, `1.2ms` and `10.2μs`
/// (the trailing `.s` intentionally matches any single character before `s`)
static REGEX_TIMING: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"[0-9]+(\.[0-9]+)?.s").expect("timing regex"));

/// Matches things like `FilterExec: .*` and `ParquetExec: .*`
///
/// Should be used in combination w/ [`REGEX_TIME_OP`].
static REGEX_FILTER: Lazy<Regex> = Lazy::new(|| {
    Regex::new("(?P<prefix>(FilterExec)|(ParquetExec): )(?P<expr>.*)").expect("filter regex")
});
/// Matches things like `time@3 < -9223372036854775808` and `time_min@2 > 1641031200399937022`
static REGEX_TIME_OP: Lazy<Regex> = Lazy::new(|| {
    // Fixed typo in the expect message ("time opt regex" -> "time op regex")
    Regex::new("(?P<prefix>time((_min)|(_max))?@[0-9]+ [<>=]=? )(?P<value>-?[0-9]+)")
        .expect("time op regex")
});
/// Collapse variable-width table borders and column padding so that outputs
/// whose cells differ only in printed width compare equal.
fn normalize_for_variable_width(s: Cow<'_, str>) -> String {
    let without_borders = REGEX_LINESEP.replace_all(&s, "----------");
    REGEX_COL.replace_all(&without_borders, " |").to_string()
}
/// Like [`normalize_for_variable_width`] but removes the column padding
/// entirely instead of collapsing it.
pub fn strip_table_lines(s: Cow<'_, str>) -> String {
    let without_borders = REGEX_LINESEP.replace_all(&s, "----------");
    REGEX_COL.replace_all(&without_borders, "").to_string()
}
/// Redact the numeric operand of time comparisons (e.g. `time@2 > 164...`)
/// so plans that differ only in concrete timestamps compare equal.
fn normalize_time_ops(s: &str) -> String {
    REGEX_TIME_OP
        .replace_all(s, |captures: &Captures<'_>| {
            let lhs = captures.name("prefix").expect("always captures").as_str();
            format!("{lhs}<REDACTED>")
        })
        .to_string()
}
/// A set of normalizations to apply to query results before comparison,
/// plus flags controlling how the results are rendered.
#[derive(Debug, PartialEq, Eq, Default, Clone, Copy)]
pub struct Normalizer {
    /// If true, results are sorted first
    pub sorted_compare: bool,

    /// If true, replace UUIDs with static placeholders.
    pub normalized_uuids: bool,

    /// If true, normalize timings in queries by replacing them with
    /// static placeholders, for example:
    ///
    /// `1s` -> `1.234ms`
    pub normalized_metrics: bool,

    /// if true, normalize filter predicates for explain plans
    /// `FilterExec: <REDACTED>`
    pub normalized_filters: bool,

    /// if `true`, render tables without borders.
    pub no_table_borders: bool,
}
impl Normalizer {
    /// Create a normalizer with no normalizations enabled.
    pub fn new() -> Self {
        Default::default()
    }

    /// Take the output of running the query and apply the specified normalizations to them
    pub fn normalize_results(&self, mut results: Vec<RecordBatch>) -> Vec<String> {
        // compare against sorted results, if requested
        if self.sorted_compare && !results.is_empty() {
            let schema = results[0].schema();
            let batch =
                arrow::compute::concat_batches(&schema, &results).expect("concatenating batches");
            results = vec![sort_record_batch(batch)];
        }

        let mut current_results = pretty_format_batches(&results)
            .unwrap()
            .trim()
            .lines()
            .map(|s| s.to_string())
            .collect::<Vec<_>>();

        // normalize UUIDs, if requested
        if self.normalized_uuids {
            // Maps each distinct UUID to a small sequential number so that
            // output is stable across runs
            let mut seen: HashMap<String, u128> = HashMap::new();
            current_results = current_results
                .into_iter()
                .map(|s| {
                    // Rewrite Parquet directory names like
                    // `51/216/1a3f45021a3f45021a3f45021a3f45021a3f45021a3f45021a3f45021a3f4502/1d325760-2b20-48de-ab48-2267b034133d.parquet`
                    //
                    // to:
                    // 1/1/1/00000000-0000-0000-0000-000000000000.parquet
                    let s = REGEX_UUID.replace_all(&s, |s: &Captures<'_>| {
                        let next = seen.len() as u128;
                        Uuid::from_u128(
                            *seen
                                .entry(s.get(0).unwrap().as_str().to_owned())
                                .or_insert(next),
                        )
                        .to_string()
                    });

                    let s = normalize_for_variable_width(s);
                    REGEX_DIRS.replace_all(&s, "1/1/1/").to_string()
                })
                .collect();
        }

        // normalize metrics, if requested
        if self.normalized_metrics {
            current_results = current_results
                .into_iter()
                .map(|s| {
                    // Replace timings with fixed value, e.g.:
                    //
                    // `1s` -> `1.234ms`
                    // `1.2ms` -> `1.234ms`
                    // `10.2μs` -> `1.234ms`
                    let s = REGEX_TIMING.replace_all(&s, "1.234ms");

                    let s = normalize_for_variable_width(s);

                    // Metrics are currently ordered by value (not by key), so different timings may
                    // reorder them. We "parse" the list and normalize the sorting. E.g.:
                    //
                    // `metrics=[]` => `metrics=[]`
                    // `metrics=[foo=1, bar=2]` => `metrics=[bar=2, foo=1]`
                    // `metrics=[foo=2, bar=1]` => `metrics=[bar=1, foo=2]`
                    REGEX_METRICS
                        .replace_all(&s, |c: &Captures<'_>| {
                            let mut metrics: Vec<_> = c[1].split(", ").collect();
                            metrics.sort();
                            format!("metrics=[{}]", metrics.join(", "))
                        })
                        .to_string()
                })
                .collect();
        }

        // normalize Filters, if requested
        //
        // Converts:
        // FilterExec: time@2 < -9223372036854775808 OR time@2 > 1640995204240217000
        // ParquetExec: limit=None, partitions={...}, predicate=time@2 > 1640995204240217000, pruning_predicate=time@2 > 1640995204240217000, output_ordering=[...], projection=[...]
        //
        // to
        // FilterExec: time@2 < <REDACTED> OR time@2 > <REDACTED>
        // ParquetExec: limit=None, partitions={...}, predicate=time@2 > <REDACTED>, pruning_predicate=time@2 > <REDACTED>, output_ordering=[...], projection=[...]
        if self.normalized_filters {
            current_results = current_results
                .into_iter()
                .map(|s| {
                    REGEX_FILTER
                        .replace_all(&s, |c: &Captures<'_>| {
                            // Fixed typo in the expect message ("always captrues")
                            let prefix = c.name("prefix").expect("always captures").as_str();
                            let expr = c.name("expr").expect("always captures").as_str();
                            let expr = normalize_time_ops(expr);
                            format!("{prefix}{expr}")
                        })
                        .to_string()
                })
                .collect();
        }

        current_results
    }

    /// Adds information on what normalizations were applied to the input
    pub fn add_description(&self, output: &mut Vec<String>) {
        if self.sorted_compare {
            output.push("-- Results After Sorting".into())
        }
        if self.normalized_uuids {
            output.push("-- Results After Normalizing UUIDs".into())
        }
        if self.normalized_metrics {
            output.push("-- Results After Normalizing Metrics".into())
        }
        if self.normalized_filters {
            output.push("-- Results After Normalizing Filters".into())
        }
        if self.no_table_borders {
            output.push("-- Results After No Table Borders".into())
        }
    }
}

57
arrow_util/src/util.rs Normal file
View File

@ -0,0 +1,57 @@
//! Utility functions for working with arrow
use std::iter::FromIterator;
use std::sync::Arc;
use arrow::{
array::{new_null_array, ArrayRef, StringArray},
datatypes::SchemaRef,
error::ArrowError,
record_batch::RecordBatch,
};
/// Returns a single column record batch of type Utf8 from the
/// contents of something that can be turned into an iterator over
/// `Option<&str>`
pub fn str_iter_to_batch<Ptr, I>(field_name: &str, iter: I) -> Result<RecordBatch, ArrowError>
where
    I: IntoIterator<Item = Option<Ptr>>,
    Ptr: AsRef<str>,
{
    // Collect the values (None => null) into a single Utf8 column
    let column: ArrayRef = Arc::new(StringArray::from_iter(iter));
    RecordBatch::try_from_iter(vec![(field_name, column)])
}
/// Ensures the record batch has the specified schema
///
/// Columns present in `output_schema` but missing from `batch` are filled
/// with all-null arrays; matching is done by column name.
pub fn ensure_schema(
    output_schema: &SchemaRef,
    batch: &RecordBatch,
) -> Result<RecordBatch, ArrowError> {
    let batch_schema = batch.schema();

    // Go over all columns of output_schema
    let batch_output_columns = output_schema
        .fields()
        .iter()
        .map(|output_field| {
            // See if the output_field is available in the batch
            let batch_field_index = batch_schema
                .fields()
                .iter()
                .enumerate()
                .find(|(_, batch_field)| output_field.name() == batch_field.name())
                .map(|(idx, _)| idx);

            if let Some(batch_field_index) = batch_field_index {
                // The column is available, use it
                Arc::clone(batch.column(batch_field_index))
            } else {
                // The column is not available, add it with all null values
                new_null_array(output_field.data_type(), batch.num_rows())
            }
        })
        .collect::<Vec<_>>();

    RecordBatch::try_new(Arc::clone(output_schema), batch_output_columns)
}

34
authz/Cargo.toml Normal file
View File

@ -0,0 +1,34 @@
[package]
name = "authz"
description = "Interface to authorization checking services"
version.workspace = true
authors.workspace = true
edition.workspace = true
license.workspace = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
backoff = { path = "../backoff" }
http = {version = "0.2.9", optional = true }
iox_time = { version = "0.1.0", path = "../iox_time" }
generated_types = { path = "../generated_types" }
metric = { version = "0.1.0", path = "../metric" }
observability_deps = { path = "../observability_deps" }
workspace-hack = { version = "0.1", path = "../workspace-hack" }
# crates.io dependencies in alphabetical order.
async-trait = "0.1"
base64 = "0.21.4"
snafu = "0.7"
tonic = { workspace = true }
[dev-dependencies]
assert_matches = "1.5.0"
parking_lot = "0.12.1"
paste = "1.0.14"
test_helpers_end_to_end = { path = "../test_helpers_end_to_end" }
tokio = "1.32.0"
[features]
http = ["dep:http"]

88
authz/src/authorizer.rs Normal file
View File

@ -0,0 +1,88 @@
use std::ops::ControlFlow;
use async_trait::async_trait;
use backoff::{Backoff, BackoffConfig};
use super::{Error, Permission};
/// An authorizer is used to validate a request
/// (+ associated permissions needed to fulfill the request)
/// with an authorization token that has been extracted from the request.
#[async_trait]
pub trait Authorizer: std::fmt::Debug + Send + Sync {
    /// Determine the permissions associated with a request token.
    ///
    /// The returned list of permissions is the intersection of the permissions
    /// requested and the permissions associated with the token.
    ///
    /// Implementations of this trait should return the specified errors under
    /// the following conditions:
    ///
    /// * [`Error::InvalidToken`]: the token is invalid / in an incorrect
    /// format / otherwise corrupt and a permission check cannot be
    /// performed
    ///
    /// * [`Error::NoToken`]: the token was not provided
    ///
    /// * [`Error::Forbidden`]: the token was well formed, but lacks
    /// authorisation to perform the requested action
    ///
    /// * [`Error::Verification`]: the token permissions were not possible
    /// to validate - an internal error has occurred
    async fn permissions(
        &self,
        token: Option<Vec<u8>>,
        perms: &[Permission],
    ) -> Result<Vec<Permission>, Error>;

    /// Make a test request that determines if end-to-end communication
    /// with the service is working.
    ///
    /// Test is performed during deployment, with ordering of availability not being guaranteed.
    async fn probe(&self) -> Result<(), Error> {
        Backoff::new(&BackoffConfig::default())
            .retry_with_backoff("probe iox-authz service", move || {
                async {
                    // Probe with an empty token and no permissions: any
                    // definitive answer (even an auth rejection) proves the
                    // service is reachable; only communication errors retry.
                    match self.permissions(Some(b"".to_vec()), &[]).await {
                        // got response from authorizer server
                        Ok(_)
                        | Err(Error::Forbidden)
                        | Err(Error::InvalidToken)
                        | Err(Error::NoToken) => ControlFlow::Break(Ok(())),
                        // communication error == Error::Verification
                        Err(e) => ControlFlow::<_, Error>::Continue(e),
                    }
                }
            })
            .await
            .expect("retry forever")
    }
}
/// Wrapped `Option<dyn Authorizer>`
/// Provides response to inner `IoxAuthorizer::permissions()`
#[async_trait]
impl<T: Authorizer> Authorizer for Option<T> {
    async fn permissions(
        &self,
        token: Option<Vec<u8>>,
        perms: &[Permission],
    ) -> Result<Vec<Permission>, Error> {
        if let Some(inner) = self {
            inner.permissions(token, perms).await
        } else {
            // no authz rpc service => return same perms requested. Used for testing.
            Ok(perms.to_vec())
        }
    }
}
/// Forward [`Authorizer`] calls through anything that can be borrowed as a
/// `dyn Authorizer` (e.g. smart-pointer wrappers).
#[async_trait]
impl<T: AsRef<dyn Authorizer> + std::fmt::Debug + Send + Sync> Authorizer for T {
    async fn permissions(
        &self,
        token: Option<Vec<u8>>,
        perms: &[Permission],
    ) -> Result<Vec<Permission>, Error> {
        self.as_ref().permissions(token, perms).await
    }
}

29
authz/src/http.rs Normal file
View File

@ -0,0 +1,29 @@
//! HTTP authorisation helpers.
use http::HeaderValue;
/// We strip off the "authorization" header from the request, to prevent it from being accidentally logged
/// and we put it in an extension of the request. Extensions are typed and this is the typed wrapper that
/// holds an (optional) authorization header value.
pub struct AuthorizationHeaderExtension(Option<HeaderValue>);

impl AuthorizationHeaderExtension {
    /// Construct new extension wrapper for a possible header value
    pub fn new(header: Option<HeaderValue>) -> Self {
        Self(header)
    }
}

impl std::fmt::Debug for AuthorizationHeaderExtension {
    // Deliberately redacts the wrapped value so credentials never appear in
    // debug/log output
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("AuthorizationHeaderExtension(...)")
    }
}

impl std::ops::Deref for AuthorizationHeaderExtension {
    type Target = Option<HeaderValue>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

View File

@ -0,0 +1,248 @@
use async_trait::async_trait;
use iox_time::{SystemProvider, TimeProvider};
use metric::{DurationHistogram, Metric, Registry};
use super::{Authorizer, Error, Permission};
/// Metric name of the authz permission-check latency histogram.
const AUTHZ_DURATION_METRIC: &str = "authz_permission_check_duration";

/// An instrumentation decorator over a [`Authorizer`] implementation.
///
/// This wrapper captures the latency distribution of the decorated
/// [`Authorizer::permissions()`] call, faceted by success/error result.
#[derive(Debug)]
pub struct AuthorizerInstrumentation<T, P = SystemProvider> {
    inner: T,
    time_provider: P,

    /// Permissions-check duration distribution for successful rpc, but not authorized.
    ioxauth_rpc_duration_success_unauth: DurationHistogram,

    /// Permissions-check duration distribution for successful rpc + authorized.
    ioxauth_rpc_duration_success_auth: DurationHistogram,

    /// Permissions-check duration distribution for errors.
    ioxauth_rpc_duration_error: DurationHistogram,
}
impl<T> AuthorizerInstrumentation<T> {
    /// Wrap `inner`, registering the [`AUTHZ_DURATION_METRIC`] histogram
    /// (faceted by result / auth state) in `registry`.
    pub fn new(registry: &Registry, inner: T) -> Self {
        let metric: Metric<DurationHistogram> =
            registry.register_metric(AUTHZ_DURATION_METRIC, "duration of authz permissions check");

        let ioxauth_rpc_duration_success_unauth =
            metric.recorder(&[("result", "success"), ("auth_state", "unauthorised")]);
        let ioxauth_rpc_duration_success_auth =
            metric.recorder(&[("result", "success"), ("auth_state", "authorised")]);
        let ioxauth_rpc_duration_error =
            metric.recorder(&[("result", "error"), ("auth_state", "unauthorised")]);

        Self {
            inner,
            time_provider: Default::default(),
            ioxauth_rpc_duration_success_unauth,
            ioxauth_rpc_duration_success_auth,
            ioxauth_rpc_duration_error,
        }
    }
}
#[async_trait]
impl<T> Authorizer for AuthorizerInstrumentation<T>
where
    T: Authorizer,
{
    async fn permissions(
        &self,
        token: Option<Vec<u8>>,
        perms: &[Permission],
    ) -> Result<Vec<Permission>, Error> {
        let t = self.time_provider.now();
        let res = self.inner.permissions(token, perms).await;

        // checked_duration_since returns None if the clock went backwards;
        // in that case no sample is recorded
        if let Some(delta) = self.time_provider.now().checked_duration_since(t) {
            match &res {
                // rpc made and token authorized
                Ok(_) => self.ioxauth_rpc_duration_success_auth.record(delta),
                // rpc made but token rejected
                Err(Error::Forbidden) | Err(Error::InvalidToken) => {
                    self.ioxauth_rpc_duration_success_unauth.record(delta)
                }
                // rpc-level failure
                Err(Error::Verification { .. }) => self.ioxauth_rpc_duration_error.record(delta),
                Err(Error::NoToken) => {} // rpc was not made
            };
        }

        res
    }
}
#[cfg(test)]
mod test {
    use std::collections::VecDeque;

    use metric::{assert_histogram, Attributes, Registry};
    use parking_lot::Mutex;

    use super::*;
    use crate::{Action, Resource};

    /// Queue of canned responses returned by [`MockAuthorizer`].
    #[derive(Debug, Default)]
    struct MockAuthorizerState {
        ret: VecDeque<Result<Vec<Permission>, Error>>,
    }

    /// An [`Authorizer`] that pops pre-scripted responses off a queue.
    #[derive(Debug, Default)]
    struct MockAuthorizer {
        state: Mutex<MockAuthorizerState>,
    }

    impl MockAuthorizer {
        pub(crate) fn with_permissions_return(
            self,
            ret: impl Into<VecDeque<Result<Vec<Permission>, Error>>>,
        ) -> Self {
            self.state.lock().ret = ret.into();
            self
        }
    }

    #[async_trait]
    impl Authorizer for MockAuthorizer {
        async fn permissions(
            &self,
            _token: Option<Vec<u8>>,
            _perms: &[Permission],
        ) -> Result<Vec<Permission>, Error> {
            // Panics if more calls are made than responses were scripted
            self.state
                .lock()
                .ret
                .pop_front()
                .expect("no mock sink value to return")
        }
    }

    /// Asserts the sample count of each of the three histogram facets
    /// (success/authorised, success/unauthorised, error/unauthorised).
    macro_rules! assert_metric_counts {
        (
            $metrics:ident,
            expected_success = $expected_success:expr,
            expected_rpc_success_unauth = $expected_rpc_success_unauth:expr,
            expected_rpc_failures = $expected_rpc_failures:expr,
        ) => {
            let histogram = &$metrics
                .get_instrument::<Metric<DurationHistogram>>(AUTHZ_DURATION_METRIC)
                .expect("failed to read metric");

            let success_labels =
                Attributes::from(&[("result", "success"), ("auth_state", "authorised")]);
            let histogram_success = &histogram
                .get_observer(&success_labels)
                .expect("failed to find metric with provided attributes")
                .fetch();

            assert_histogram!(
                $metrics,
                DurationHistogram,
                AUTHZ_DURATION_METRIC,
                labels = success_labels,
                samples = $expected_success,
                sum = histogram_success.total,
            );

            let success_unauth_labels =
                Attributes::from(&[("result", "success"), ("auth_state", "unauthorised")]);
            let histogram_success_unauth = &histogram
                .get_observer(&success_unauth_labels)
                .expect("failed to find metric with provided attributes")
                .fetch();

            assert_histogram!(
                $metrics,
                DurationHistogram,
                AUTHZ_DURATION_METRIC,
                labels = success_unauth_labels,
                samples = $expected_rpc_success_unauth,
                sum = histogram_success_unauth.total,
            );

            let rpc_error_labels =
                Attributes::from(&[("result", "error"), ("auth_state", "unauthorised")]);
            let histogram_rpc_error = &histogram
                .get_observer(&rpc_error_labels)
                .expect("failed to find metric with provided attributes")
                .fetch();

            assert_histogram!(
                $metrics,
                DurationHistogram,
                AUTHZ_DURATION_METRIC,
                labels = rpc_error_labels,
                samples = $expected_rpc_failures,
                sum = histogram_rpc_error.total,
            );
        };
    }

    /// Generates a test that drives [`AuthorizerInstrumentation`] with a
    /// single scripted rpc response and verifies the recorded metrics.
    macro_rules! test_authorizer_metric {
        (
            $name:ident,
            rpc_response = $rpc_response:expr,
            will_pass_auth = $will_pass_auth:expr,
            expected_success_cnt = $expected_success_cnt:expr,
            expected_success_unauth_cnt = $expected_success_unauth_cnt:expr,
            expected_error_cnt = $expected_error_cnt:expr,
        ) => {
            paste::paste! {
                #[tokio::test]
                async fn [<test_authorizer_metric_ $name>]() {
                    let metrics = Registry::default();
                    let decorated_authz = AuthorizerInstrumentation::new(
                        &metrics,
                        MockAuthorizer::default().with_permissions_return([$rpc_response])
                    );

                    let token = "any".as_bytes().to_vec();
                    let got = decorated_authz
                        .permissions(Some(token), &[])
                        .await;
                    assert_eq!(got.is_ok(), $will_pass_auth);
                    assert_metric_counts!(
                        metrics,
                        expected_success = $expected_success_cnt,
                        expected_rpc_success_unauth = $expected_success_unauth_cnt,
                        expected_rpc_failures = $expected_error_cnt,
                    );
                }
            }
        };
    }

    test_authorizer_metric!(
        ok,
        rpc_response = Ok(vec![Permission::ResourceAction(
            Resource::Database("foo".to_string()),
            Action::Write,
        )]),
        will_pass_auth = true,
        expected_success_cnt = 1,
        expected_success_unauth_cnt = 0,
        expected_error_cnt = 0,
    );

    test_authorizer_metric!(
        will_record_failure_if_rpc_fails,
        rpc_response = Err(Error::verification("test", "test error")),
        will_pass_auth = false,
        expected_success_cnt = 0,
        expected_success_unauth_cnt = 0,
        expected_error_cnt = 1,
    );

    test_authorizer_metric!(
        will_record_success_if_rpc_pass_but_auth_fails,
        rpc_response = Err(Error::InvalidToken),
        will_pass_auth = false,
        expected_success_cnt = 0,
        expected_success_unauth_cnt = 1,
        expected_error_cnt = 0,
    );
}

309
authz/src/iox_authorizer.rs Normal file
View File

@ -0,0 +1,309 @@
use async_trait::async_trait;
use generated_types::influxdata::iox::authz::v1::{self as proto, AuthorizeResponse};
use observability_deps::tracing::warn;
use snafu::Snafu;
use tonic::Response;
use super::{Authorizer, Permission};
/// Authorizer implementation using influxdata.iox.authz.v1 protocol.
#[derive(Clone, Debug)]
pub struct IoxAuthorizer {
    /// gRPC client used to issue authorization requests.
    client:
        proto::iox_authorizer_service_client::IoxAuthorizerServiceClient<tonic::transport::Channel>,
}
impl IoxAuthorizer {
    /// Attempt to create a new client by connecting to a given endpoint.
    ///
    /// The underlying channel is created with `connect_lazy`, so no network
    /// activity happens until the first request; only endpoint parsing can
    /// fail here.
    pub fn connect_lazy<D>(dst: D) -> Result<Self, Box<dyn std::error::Error>>
    where
        D: TryInto<tonic::transport::Endpoint> + Send,
        D::Error: Into<tonic::codegen::StdError>,
    {
        let ep = tonic::transport::Endpoint::new(dst)?;
        let client = proto::iox_authorizer_service_client::IoxAuthorizerServiceClient::new(
            ep.connect_lazy(),
        );
        Ok(Self { client })
    }

    /// Issue a raw `Authorize` rpc for `token`, asking for `requested_perms`.
    ///
    /// Permissions that cannot be represented in the protobuf schema are
    /// silently dropped from the request.
    async fn request(
        &self,
        token: Vec<u8>,
        requested_perms: &[Permission],
    ) -> Result<Response<AuthorizeResponse>, tonic::Status> {
        let req = proto::AuthorizeRequest {
            token,
            permissions: requested_perms
                .iter()
                .filter_map(|p| p.clone().try_into().ok())
                .collect(),
        };
        // Cloning the client allows issuing the (mutable) call from `&self`
        let mut client = self.client.clone();
        client.authorize(req).await
    }
}
#[async_trait]
impl Authorizer for IoxAuthorizer {
    async fn permissions(
        &self,
        token: Option<Vec<u8>>,
        requested_perms: &[Permission],
    ) -> Result<Vec<Permission>, Error> {
        let authz_rpc_result = self
            .request(token.ok_or(Error::NoToken)?, requested_perms)
            .await
            // rpc-level failures are surfaced as verification errors
            .map_err(|status| Error::Verification {
                msg: status.message().to_string(),
                source: Box::new(status),
            })?
            .into_inner();

        // The service rejected the token itself
        if !authz_rpc_result.valid {
            return Err(Error::InvalidToken);
        }

        // Keep only granted permissions representable in the local model
        let intersected_perms: Vec<Permission> = authz_rpc_result
            .permissions
            .into_iter()
            .filter_map(|p| match p.try_into() {
                Ok(p) => Some(p),
                Err(e) => {
                    warn!(error=%e, "authz service returned incompatible permission");
                    None
                }
            })
            .collect();

        // Valid token but no usable permissions => the action is not allowed
        if intersected_perms.is_empty() {
            return Err(Error::Forbidden);
        }
        Ok(intersected_perms)
    }
}
/// Authorization related error.
#[derive(Debug, Snafu)]
pub enum Error {
    /// Communication error when verifying a token.
    ///
    /// This indicates the authz service could not be consulted, not that the
    /// token was rejected.
    #[snafu(display("token verification not possible: {msg}"))]
    Verification {
        /// Message describing the error.
        msg: String,
        /// Source of the error.
        source: Box<dyn std::error::Error + Send + Sync + 'static>,
    },
    /// Token is invalid.
    #[snafu(display("invalid token"))]
    InvalidToken,
    /// The token's permissions do not allow the operation.
    #[snafu(display("forbidden"))]
    Forbidden,
    /// No token has been supplied, but is required.
    #[snafu(display("no token"))]
    NoToken,
}
impl Error {
    /// Construct an [`Error::Verification`] from a message and the
    /// underlying source error.
    pub fn verification(
        msg: impl Into<String>,
        source: impl Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
    ) -> Self {
        let msg = msg.into();
        let source = source.into();
        Self::Verification { msg, source }
    }
}
impl From<tonic::Status> for Error {
    /// Treat any gRPC status as a verification failure, keeping the status
    /// itself (cloned) as the error source.
    fn from(value: tonic::Status) -> Self {
        Self::verification(value.message(), value.clone())
    }
}
#[cfg(test)]
mod test {
    use std::{
        net::SocketAddr,
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        time::Duration,
    };
    use assert_matches::assert_matches;
    use test_helpers_end_to_end::Authorizer as AuthorizerServer;
    use tokio::{
        net::TcpListener,
        task::{spawn, JoinHandle},
    };
    use tonic::transport::{server::TcpIncoming, Server};
    use super::*;
    use crate::{Action, Authorizer, Permission, Resource};
    const NAMESPACE: &str = "bananas";
    /// Generates a test that spins up a real authz server, mints a token
    /// with `token_permissions`, and asserts that a `permissions()` call for
    /// `permissions_required` matches `want`.
    macro_rules! test_iox_authorizer {
        (
            $name:ident,
            token_permissions = $token_permissions:expr,
            permissions_required = $permissions_required:expr,
            want = $want:pat
        ) => {
            paste::paste! {
                #[tokio::test]
                async fn [<test_iox_authorizer_ $name>]() {
                    let mut authz_server = AuthorizerServer::create().await;
                    let authz = IoxAuthorizer::connect_lazy(authz_server.addr())
                        .expect("Failed to create IoxAuthorizer client.");
                    let token = authz_server.create_token_for(NAMESPACE, $token_permissions);
                    let got = authz.permissions(
                        Some(token.as_bytes().to_vec()),
                        $permissions_required
                    ).await;
                    assert_matches!(got, $want);
                }
            }
        };
    }
    test_iox_authorizer!(
        ok,
        token_permissions = &["ACTION_WRITE"],
        permissions_required = &[Permission::ResourceAction(
            Resource::Database(NAMESPACE.to_string()),
            Action::Write,
        )],
        want = Ok(_)
    );
    test_iox_authorizer!(
        insufficient_perms,
        token_permissions = &["ACTION_READ"],
        permissions_required = &[Permission::ResourceAction(
            Resource::Database(NAMESPACE.to_string()),
            Action::Write,
        )],
        want = Err(Error::Forbidden)
    );
    test_iox_authorizer!(
        any_of_required_perms,
        token_permissions = &["ACTION_WRITE"],
        permissions_required = &[
            Permission::ResourceAction(Resource::Database(NAMESPACE.to_string()), Action::Write,),
            Permission::ResourceAction(Resource::Database(NAMESPACE.to_string()), Action::Create,)
        ],
        want = Ok(_)
    );
    #[tokio::test]
    async fn test_invalid_token() {
        let authz_server = AuthorizerServer::create().await;
        let authz = IoxAuthorizer::connect_lazy(authz_server.addr())
            .expect("Failed to create IoxAuthorizer client.");
        let invalid_token = b"UGLY";
        let got = authz
            .permissions(
                Some(invalid_token.to_vec()),
                &[Permission::ResourceAction(
                    Resource::Database(NAMESPACE.to_string()),
                    Action::Read,
                )],
            )
            .await;
        assert_matches!(got, Err(Error::InvalidToken));
    }
    /// Verifies that `probe()` retries until a slow-starting authz sidecar
    /// becomes ready, instead of failing on the first error.
    #[tokio::test]
    async fn test_delayed_probe_response() {
        // Server stub that rejects requests until an internal countdown
        // (mocking sidecar startup delay) has elapsed.
        #[derive(Default, Debug)]
        struct DelayedAuthorizer(Arc<AtomicBool>);
        impl DelayedAuthorizer {
            fn start_countdown(&self) {
                let started = Arc::clone(&self.0);
                spawn(async move {
                    tokio::time::sleep(Duration::from_secs(2)).await;
                    started.store(true, Ordering::Relaxed);
                });
            }
        }
        #[async_trait]
        impl proto::iox_authorizer_service_server::IoxAuthorizerService for DelayedAuthorizer {
            async fn authorize(
                &self,
                _request: tonic::Request<proto::AuthorizeRequest>,
            ) -> Result<tonic::Response<AuthorizeResponse>, tonic::Status> {
                let startup_done = self.0.load(Ordering::Relaxed);
                if !startup_done {
                    return Err(tonic::Status::deadline_exceeded("startup not done"));
                }
                Ok(tonic::Response::new(AuthorizeResponse {
                    valid: true,
                    subject: None,
                    permissions: vec![],
                }))
            }
        }
        #[derive(Debug)]
        struct DelayedServer {
            addr: SocketAddr,
            handle: JoinHandle<Result<(), tonic::transport::Error>>,
        }
        impl DelayedServer {
            async fn create() -> Self {
                let listener = TcpListener::bind("localhost:0").await.unwrap();
                let addr = listener.local_addr().unwrap();
                let incoming = TcpIncoming::from_listener(listener, false, None).unwrap();
                // start countdown mocking startup delay of sidecar
                let authz = DelayedAuthorizer::default();
                authz.start_countdown();
                let router = Server::builder().add_service(
                    proto::iox_authorizer_service_server::IoxAuthorizerServiceServer::new(authz),
                );
                let handle = spawn(router.serve_with_incoming(incoming));
                Self { addr, handle }
            }
            fn addr(&self) -> String {
                format!("http://{}", self.addr)
            }
            fn close(self) {
                self.handle.abort();
            }
        }
        let authz_server = DelayedServer::create().await;
        let authz_client = IoxAuthorizer::connect_lazy(authz_server.addr())
            .expect("Failed to create IoxAuthorizer client.");
        assert_matches!(
            authz_client.probe().await,
            Ok(()),
            "authz probe should work even with delay"
        );
        authz_server.close();
    }
}

100
authz/src/lib.rs Normal file
View File

@ -0,0 +1,100 @@
//! IOx authorization client.
//!
//! Authorization client interface to be used by IOx components to
//! restrict access to authorized requests where required.
#![deny(rustdoc::broken_intra_doc_links, rust_2018_idioms)]
#![warn(
missing_copy_implementations,
missing_docs,
clippy::explicit_iter_loop,
// See https://github.com/influxdata/influxdb_iox/pull/1671
clippy::future_not_send,
clippy::use_self,
clippy::clone_on_ref_ptr,
clippy::todo,
clippy::dbg_macro,
unused_crate_dependencies
)]
#![allow(rustdoc::private_intra_doc_links)]
// Workaround for "unused crate" lint false positives.
use workspace_hack as _;
use base64::{prelude::BASE64_STANDARD, Engine};
use generated_types::influxdata::iox::authz::v1::{self as proto};
use observability_deps::tracing::warn;
mod authorizer;
pub use authorizer::Authorizer;
mod iox_authorizer;
pub use iox_authorizer::{Error, IoxAuthorizer};
mod instrumentation;
pub use instrumentation::AuthorizerInstrumentation;
mod permission;
pub use permission::{Action, Permission, Resource};
#[cfg(feature = "http")]
pub mod http;
/// Extract a token from an HTTP header or gRPC metadata value.
///
/// Supports `Token <t>` and `Bearer <t>` schemes (token returned verbatim)
/// as well as `Basic <base64(user:pass)>` (the password part is returned).
/// Returns `None` for unknown schemes, malformed input, or empty tokens.
pub fn extract_token<T: AsRef<[u8]> + ?Sized>(value: Option<&T>) -> Option<Vec<u8>> {
    // Split the value into "<scheme> <credentials>".
    let mut parts = value?.as_ref().splitn(2, |&v| v == b' ');
    let scheme = parts.next()?;
    let token = if scheme == b"Token" || scheme == b"Bearer" {
        parts.next()?.to_vec()
    } else if scheme == b"Basic" {
        // Decode "user:pass" and keep only the part after the first ':'.
        let decoded = BASE64_STANDARD.decode(parts.next()?).ok()?;
        let mut halves = decoded.splitn(2, |&v| v == b':');
        halves.nth(1)?.to_vec()
    } else {
        return None;
    };
    (!token.is_empty()).then_some(token)
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn verify_error_from_tonic_status() {
        let s = tonic::Status::resource_exhausted("test error");
        let e = Error::from(s);
        assert_eq!(
            "token verification not possible: test error",
            format!("{e}")
        )
    }
    #[test]
    fn test_extract_token() {
        // Missing value, unknown scheme, or empty credential => None.
        assert_eq!(None, extract_token::<&str>(None));
        assert_eq!(None, extract_token(Some("")));
        assert_eq!(None, extract_token(Some("Basic")));
        assert_eq!(None, extract_token(Some("Basic Og=="))); // ":"
        assert_eq!(None, extract_token(Some("Basic dXNlcm5hbWU6"))); // "username:"
        assert_eq!(None, extract_token(Some("Basic Og=="))); // ":" (duplicate of the case above)
        assert_eq!(
            Some(b"password".to_vec()),
            extract_token(Some("Basic OnBhc3N3b3Jk"))
        ); // ":password"
        assert_eq!(
            Some(b"password2".to_vec()),
            extract_token(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQy"))
        ); // "username:password2"
        assert_eq!(None, extract_token(Some("Bearer")));
        assert_eq!(None, extract_token(Some("Bearer ")));
        assert_eq!(Some(b"token".to_vec()), extract_token(Some("Bearer token")));
        assert_eq!(None, extract_token(Some("Token")));
        assert_eq!(None, extract_token(Some("Token ")));
        assert_eq!(
            Some(b"token2".to_vec()),
            extract_token(Some("Token token2"))
        );
    }
}

310
authz/src/permission.rs Normal file
View File

@ -0,0 +1,310 @@
use super::proto;
use snafu::Snafu;
/// Action is the type of operation being attempted on a resource.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Action {
    /// The create action is used when a new instance of the resource will
    /// be created.
    Create,
    /// The delete action is used when a resource will be deleted.
    Delete,
    /// The read action is used when the data contained by a resource will
    /// be read.
    Read,
    /// The read-schema action is used when only metadata about a resource
    /// will be read.
    ReadSchema,
    /// The write action is used when data is being written to the resource.
    Write,
}
impl TryFrom<proto::resource_action_permission::Action> for Action {
    type Error = IncompatiblePermissionError;
    /// Convert the protobuf action into the crate-level [`Action`].
    ///
    /// Unknown or `Unspecified` values yield an
    /// [`IncompatiblePermissionError`] (see the catch-all arm).
    fn try_from(value: proto::resource_action_permission::Action) -> Result<Self, Self::Error> {
        match value {
            proto::resource_action_permission::Action::ReadSchema => Ok(Self::ReadSchema),
            proto::resource_action_permission::Action::Read => Ok(Self::Read),
            proto::resource_action_permission::Action::Write => Ok(Self::Write),
            proto::resource_action_permission::Action::Create => Ok(Self::Create),
            proto::resource_action_permission::Action::Delete => Ok(Self::Delete),
            _ => Err(IncompatiblePermissionError {}),
        }
    }
}
impl From<Action> for proto::resource_action_permission::Action {
    /// Map the crate-level [`Action`] onto its protobuf counterpart.
    ///
    /// This direction is infallible: every [`Action`] has a wire
    /// representation.
    fn from(action: Action) -> Self {
        match action {
            Action::Read => Self::Read,
            Action::ReadSchema => Self::ReadSchema,
            Action::Write => Self::Write,
            Action::Create => Self::Create,
            Action::Delete => Self::Delete,
        }
    }
}
/// An incompatible-permission-error is the error that is returned if
/// there is an attempt to convert a permission into a form that is
/// unsupported. For the most part this should not cause an error to
/// be returned to the user, but more as a signal that the conversion
/// can never succeed and therefore the permission can never be granted.
/// This error will normally be silently dropped along with the source
/// permission that caused it.
#[derive(Clone, Copy, Debug, PartialEq, Snafu)]
#[snafu(display("incompatible permission"))]
pub struct IncompatiblePermissionError {}
/// A permission is an authorization that can be checked with an
/// authorizer. Not all authorizers necessarily support all forms of
/// permission. If an authorizer doesn't support a permission then it
/// is not an error, the permission will always be denied.
#[derive(Clone, Debug, PartialEq)]
pub enum Permission {
    /// ResourceAction is a permission in the form of a resource and an
    /// action.
    ResourceAction(Resource, Action),
}
impl TryFrom<proto::Permission> for Permission {
    type Error = IncompatiblePermissionError;
    /// Decode a protobuf permission, failing with
    /// [`IncompatiblePermissionError`] if the resource type, resource id, or
    /// action cannot be represented by the crate-level types.
    fn try_from(value: proto::Permission) -> Result<Self, Self::Error> {
        match value.permission_one_of {
            Some(proto::permission::PermissionOneOf::ResourceAction(ra)) => {
                // `from_i32` returns `None` for out-of-range enum values.
                let r = Resource::try_from_proto(
                    proto::resource_action_permission::ResourceType::from_i32(ra.resource_type)
                        .ok_or(IncompatiblePermissionError {})?,
                    ra.resource_id,
                )?;
                let a = Action::try_from(
                    proto::resource_action_permission::Action::from_i32(ra.action)
                        .ok_or(IncompatiblePermissionError {})?,
                )?;
                Ok(Self::ResourceAction(r, a))
            }
            // Absent or unknown oneof variant.
            _ => Err(IncompatiblePermissionError {}),
        }
    }
}
impl TryFrom<Permission> for proto::Permission {
    type Error = IncompatiblePermissionError;
    /// Encode a crate-level [`Permission`] into its protobuf form.
    ///
    /// Fails only if the contained [`Resource`] has no wire representation.
    fn try_from(value: Permission) -> Result<Self, Self::Error> {
        match value {
            Permission::ResourceAction(r, a) => {
                let (rt, ri) = r.try_into_proto()?;
                let a: proto::resource_action_permission::Action = a.into();
                Ok(Self {
                    permission_one_of: Some(proto::permission::PermissionOneOf::ResourceAction(
                        proto::ResourceActionPermission {
                            // prost represents enums as raw i32 on the wire.
                            resource_type: rt as i32,
                            resource_id: ri,
                            action: a as i32,
                        },
                    )),
                })
            }
        }
    }
}
/// A resource is the object that a request is trying to access.
#[derive(Clone, Debug, PartialEq)]
pub enum Resource {
    /// A database is a named IOx database.
    Database(String),
}
impl Resource {
    /// Build a [`Resource`] from its protobuf parts.
    ///
    /// Fails with [`IncompatiblePermissionError`] when the resource type is
    /// unknown/unspecified or the required resource id is missing.
    fn try_from_proto(
        rt: proto::resource_action_permission::ResourceType,
        ri: Option<String>,
    ) -> Result<Self, IncompatiblePermissionError> {
        if let (proto::resource_action_permission::ResourceType::Database, Some(name)) = (rt, ri) {
            Ok(Self::Database(name))
        } else {
            Err(IncompatiblePermissionError {})
        }
    }
    /// Decompose this [`Resource`] into its protobuf parts.
    fn try_into_proto(
        self,
    ) -> Result<
        (
            proto::resource_action_permission::ResourceType,
            Option<String>,
        ),
        IncompatiblePermissionError,
    > {
        // Irrefutable: `Database` is currently the only variant.
        let Self::Database(name) = self;
        Ok((
            proto::resource_action_permission::ResourceType::Database,
            Some(name),
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Round-trip and error-path tests for the proto <-> crate-type
    // conversions. Raw numeric values (e.g. resource_type: 1, action: 4)
    // mirror the wire encoding; 0 is the protobuf "unspecified" value.
    #[test]
    fn action_try_from_proto() {
        assert_eq!(
            Action::Create,
            Action::try_from(proto::resource_action_permission::Action::Create).unwrap(),
        );
        assert_eq!(
            Action::Delete,
            Action::try_from(proto::resource_action_permission::Action::Delete).unwrap(),
        );
        assert_eq!(
            Action::Read,
            Action::try_from(proto::resource_action_permission::Action::Read).unwrap(),
        );
        assert_eq!(
            Action::ReadSchema,
            Action::try_from(proto::resource_action_permission::Action::ReadSchema).unwrap(),
        );
        assert_eq!(
            Action::Write,
            Action::try_from(proto::resource_action_permission::Action::Write).unwrap(),
        );
        assert_eq!(
            IncompatiblePermissionError {},
            Action::try_from(proto::resource_action_permission::Action::Unspecified).unwrap_err(),
        );
    }
    #[test]
    fn action_into_proto() {
        assert_eq!(
            proto::resource_action_permission::Action::Create,
            proto::resource_action_permission::Action::from(Action::Create)
        );
        assert_eq!(
            proto::resource_action_permission::Action::Delete,
            proto::resource_action_permission::Action::from(Action::Delete)
        );
        assert_eq!(
            proto::resource_action_permission::Action::Read,
            proto::resource_action_permission::Action::from(Action::Read)
        );
        assert_eq!(
            proto::resource_action_permission::Action::ReadSchema,
            proto::resource_action_permission::Action::from(Action::ReadSchema)
        );
        assert_eq!(
            proto::resource_action_permission::Action::Write,
            proto::resource_action_permission::Action::from(Action::Write)
        );
    }
    #[test]
    fn resource_try_from_proto() {
        assert_eq!(
            Resource::Database("ns1".into()),
            Resource::try_from_proto(
                proto::resource_action_permission::ResourceType::Database,
                Some("ns1".into())
            )
            .unwrap()
        );
        // Missing resource id is an error.
        assert_eq!(
            IncompatiblePermissionError {},
            Resource::try_from_proto(
                proto::resource_action_permission::ResourceType::Database,
                None
            )
            .unwrap_err()
        );
        // Unspecified resource type is an error.
        assert_eq!(
            IncompatiblePermissionError {},
            Resource::try_from_proto(
                proto::resource_action_permission::ResourceType::Unspecified,
                Some("ns1".into())
            )
            .unwrap_err()
        );
    }
    #[test]
    fn resource_try_into_proto() {
        assert_eq!(
            (
                proto::resource_action_permission::ResourceType::Database,
                Some("ns1".into())
            ),
            Resource::Database("ns1".into()).try_into_proto().unwrap(),
        );
    }
    #[test]
    fn permission_try_from_proto() {
        assert_eq!(
            Permission::ResourceAction(Resource::Database("ns2".into()), Action::Create),
            Permission::try_from(proto::Permission {
                permission_one_of: Some(proto::permission::PermissionOneOf::ResourceAction(
                    proto::ResourceActionPermission {
                        resource_type: 1,
                        resource_id: Some("ns2".into()),
                        action: 4,
                    }
                ))
            })
            .unwrap()
        );
        assert_eq!(
            IncompatiblePermissionError {},
            Permission::try_from(proto::Permission {
                permission_one_of: Some(proto::permission::PermissionOneOf::ResourceAction(
                    proto::ResourceActionPermission {
                        resource_type: 0,
                        resource_id: Some("ns2".into()),
                        action: 4,
                    }
                ))
            })
            .unwrap_err()
        );
        assert_eq!(
            IncompatiblePermissionError {},
            Permission::try_from(proto::Permission {
                permission_one_of: Some(proto::permission::PermissionOneOf::ResourceAction(
                    proto::ResourceActionPermission {
                        resource_type: 1,
                        resource_id: Some("ns2".into()),
                        action: 0,
                    }
                ))
            })
            .unwrap_err()
        );
    }
    #[test]
    fn permission_try_into_proto() {
        assert_eq!(
            proto::Permission {
                permission_one_of: Some(proto::permission::PermissionOneOf::ResourceAction(
                    proto::ResourceActionPermission {
                        resource_type: 1,
                        resource_id: Some("ns3".into()),
                        action: 4,
                    }
                ))
            },
            proto::Permission::try_from(Permission::ResourceAction(
                Resource::Database("ns3".into()),
                Action::Create
            ))
            .unwrap()
        );
    }
}

13
backoff/Cargo.toml Normal file
View File

@ -0,0 +1,13 @@
[package]
name = "backoff"
version.workspace = true
authors.workspace = true
edition.workspace = true
license.workspace = true
[dependencies]
tokio = { version = "1.32", features = ["macros", "time"] }
observability_deps = { path = "../observability_deps" }
rand = "0.8"
snafu = "0.7"
workspace-hack = { version = "0.1", path = "../workspace-hack" }

395
backoff/src/lib.rs Normal file
View File

@ -0,0 +1,395 @@
//! Backoff functionality.
#![deny(rustdoc::broken_intra_doc_links, rustdoc::bare_urls, rust_2018_idioms)]
#![warn(
missing_copy_implementations,
missing_debug_implementations,
missing_docs,
clippy::explicit_iter_loop,
// See https://github.com/influxdata/influxdb_iox/pull/1671
clippy::future_not_send,
clippy::use_self,
clippy::clone_on_ref_ptr,
clippy::todo,
clippy::dbg_macro,
unused_crate_dependencies
)]
// Workaround for "unused crate" lint false positives.
use workspace_hack as _;
use observability_deps::tracing::warn;
use rand::prelude::*;
use snafu::Snafu;
use std::ops::ControlFlow;
use std::time::Duration;
/// Exponential backoff with jitter
///
/// See <https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/>
#[derive(Debug, Clone, PartialEq)]
#[allow(missing_copy_implementations)]
pub struct BackoffConfig {
    /// Initial backoff.
    pub init_backoff: Duration,
    /// Maximum backoff.
    pub max_backoff: Duration,
    /// Multiplier for each backoff round.
    pub base: f64,
    /// Timeout until we try to retry.
    ///
    /// Once the accumulated backoff time exceeds this, retrying stops
    /// (see `Iterator::next` for `Backoff`). `None` means retry forever.
    pub deadline: Option<Duration>,
}
impl Default for BackoffConfig {
    /// Defaults: start at 100ms, grow 3x per round, cap at 500s, and never
    /// give up (no deadline).
    fn default() -> Self {
        Self {
            deadline: None,
            base: 3.0,
            max_backoff: Duration::from_secs(500),
            init_backoff: Duration::from_millis(100),
        }
    }
}
/// Error after giving up retrying.
#[derive(Debug, Snafu, PartialEq, Eq)]
#[allow(missing_copy_implementations, missing_docs)]
pub enum BackoffError<E>
where
    E: std::error::Error + 'static,
{
    // NOTE(review): the message likely means "did not succeed within";
    // kept verbatim since callers may match on the exact text — confirm
    // before changing.
    #[snafu(display("Retry did not exceed within {deadline:?}: {source}"))]
    DeadlineExceeded { deadline: Duration, source: E },
}
/// Backoff result.
pub type BackoffResult<T, E> = Result<T, BackoffError<E>>;
/// [`Backoff`] can be created from a [`BackoffConfig`]
///
/// Consecutive calls to [`Backoff::next`] will return the next backoff interval
///
pub struct Backoff {
    // Lower bound of every jitter range, in seconds.
    init_backoff: f64,
    // Seed for the next jitter range's upper bound (multiplied by `base`).
    next_backoff_secs: f64,
    // Hard cap for any single backoff interval, in seconds.
    max_backoff_secs: f64,
    // Growth factor applied per round.
    base: f64,
    // Sum of all backoff intervals handed out so far, in seconds.
    total: f64,
    // Stop (yield `None`) once `total` reaches this, if set.
    deadline: Option<f64>,
    // Optional deterministic RNG; falls back to `thread_rng()` when `None`.
    rng: Option<Box<dyn RngCore + Sync + Send>>,
}
impl std::fmt::Debug for Backoff {
    // Manual impl because `rng` (a boxed trait object) is not `Debug`; the
    // rng field is intentionally omitted from the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Backoff")
            .field("init_backoff", &self.init_backoff)
            .field("next_backoff_secs", &self.next_backoff_secs)
            .field("max_backoff_secs", &self.max_backoff_secs)
            .field("base", &self.base)
            .field("total", &self.total)
            .field("deadline", &self.deadline)
            .finish()
    }
}
impl Backoff {
    /// Create a new [`Backoff`] from the provided [`BackoffConfig`].
    ///
    /// # Panics
    /// Panics if [`BackoffConfig::base`] is not finite or < 1.0.
    pub fn new(config: &BackoffConfig) -> Self {
        Self::new_with_rng(config, None)
    }
    /// Creates a new `Backoff` with the optional `rng`.
    ///
    /// Uses [`rand::thread_rng()`] if no rng provided.
    ///
    /// See [`new`](Self::new) for panic handling.
    pub fn new_with_rng(
        config: &BackoffConfig,
        rng: Option<Box<dyn RngCore + Sync + Send>>,
    ) -> Self {
        assert!(
            config.base.is_finite(),
            "Backoff base ({}) must be finite.",
            config.base,
        );
        assert!(
            config.base >= 1.0,
            "Backoff base ({}) must be greater or equal than 1.",
            config.base,
        );
        // Clamp the initial backoff so it can never exceed the maximum.
        let max_backoff = config.max_backoff.as_secs_f64();
        let init_backoff = config.init_backoff.as_secs_f64().min(max_backoff);
        Self {
            init_backoff,
            next_backoff_secs: init_backoff,
            max_backoff_secs: max_backoff,
            base: config.base,
            total: 0.0,
            deadline: config.deadline.map(|d| d.as_secs_f64()),
            rng,
        }
    }
    /// Fade this backoff over to a different backoff config.
    ///
    /// Keeps the current progress (`next_backoff_secs`, accumulated `total`,
    /// and the rng) but adopts the new config's bounds and deadline.
    pub fn fade_to(&mut self, config: &BackoffConfig) {
        // Note: `new` won't have the same RNG, but this doesn't matter
        let new = Self::new(config);
        *self = Self {
            init_backoff: new.init_backoff,
            next_backoff_secs: self.next_backoff_secs,
            max_backoff_secs: new.max_backoff_secs,
            base: new.base,
            total: self.total,
            deadline: new.deadline,
            rng: self.rng.take(),
        };
    }
    /// Perform an async operation that retries with a backoff
    ///
    /// `do_stuff` signals completion with [`ControlFlow::Break`] (final
    /// result) or requests a retry with [`ControlFlow::Continue`] (error).
    pub async fn retry_with_backoff<F, F1, B, E>(
        &mut self,
        task_name: &str,
        mut do_stuff: F,
    ) -> BackoffResult<B, E>
    where
        F: (FnMut() -> F1) + Send,
        F1: std::future::Future<Output = ControlFlow<B, E>> + Send,
        E: std::error::Error + Send + 'static,
    {
        loop {
            // first execute `F` and then use it, so we can avoid `F: Sync`.
            let do_stuff = do_stuff();
            let e = match do_stuff.await {
                ControlFlow::Break(r) => break Ok(r),
                ControlFlow::Continue(e) => e,
            };
            // `None` from the iterator means the deadline was exceeded.
            let backoff = match self.next() {
                Some(backoff) => backoff,
                None => {
                    return Err(BackoffError::DeadlineExceeded {
                        deadline: Duration::from_secs_f64(self.deadline.expect("deadline")),
                        source: e,
                    });
                }
            };
            warn!(
                error=%e,
                task_name,
                backoff_secs = backoff.as_secs(),
                "request encountered non-fatal error - backing off",
            );
            tokio::time::sleep(backoff).await;
        }
    }
    /// Retry all errors.
    ///
    /// Convenience wrapper around
    /// [`retry_with_backoff`](Self::retry_with_backoff) for operations
    /// returning `Result`: every `Err` triggers a retry.
    pub async fn retry_all_errors<F, F1, B, E>(
        &mut self,
        task_name: &str,
        mut do_stuff: F,
    ) -> BackoffResult<B, E>
    where
        F: (FnMut() -> F1) + Send,
        F1: std::future::Future<Output = Result<B, E>> + Send,
        E: std::error::Error + Send + 'static,
    {
        self.retry_with_backoff(task_name, move || {
            // first execute `F` and then use it, so we can avoid `F: Sync`.
            let do_stuff = do_stuff();
            async {
                match do_stuff.await {
                    Ok(b) => ControlFlow::Break(b),
                    Err(e) => ControlFlow::Continue(e),
                }
            }
        })
        .await
    }
}
impl Iterator for Backoff {
    type Item = Duration;
    /// Returns the next backoff duration to wait for, if any
    ///
    /// Draws a jittered value from `[init_backoff, next_backoff_secs * base]`,
    /// clamps it to the maximum, and yields the *previously* stored backoff.
    /// Returns `None` once the accumulated total reaches the configured
    /// deadline, or if the value cannot be represented as a `Duration`.
    fn next(&mut self) -> Option<Self::Item> {
        let range = self.init_backoff..=(self.next_backoff_secs * self.base);
        let rand_backoff = match self.rng.as_mut() {
            Some(rng) => rng.gen_range(range),
            None => thread_rng().gen_range(range),
        };
        let next_backoff = self.max_backoff_secs.min(rand_backoff);
        self.total += next_backoff;
        // Stash the freshly drawn value and hand out the previous one.
        let res = std::mem::replace(&mut self.next_backoff_secs, next_backoff);
        if let Some(deadline) = self.deadline {
            if self.total >= deadline {
                return None;
            }
        }
        duration_try_from_secs_f64(res)
    }
}
// Upper bound (in seconds) accepted by `duration_try_from_secs_f64`.
const MAX_F64_SECS: f64 = 1_000_000.0;
/// Try to get `Duration` from `f64` secs.
///
/// Returns `None` for NaN, infinities, negative values, and anything above
/// [`MAX_F64_SECS`].
///
/// This is required till <https://github.com/rust-lang/rust/issues/83400> is resolved.
fn duration_try_from_secs_f64(secs: f64) -> Option<Duration> {
    if secs.is_finite() && (0.0..=MAX_F64_SECS).contains(&secs) {
        Some(Duration::from_secs_f64(secs))
    } else {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::mock::StepRng;
    // `StepRng` makes the jitter deterministic: 0 => minimum of the range,
    // u64::MAX => maximum, u64::MAX / 2 => midpoint.
    #[test]
    fn test_backoff() {
        let init_backoff_secs = 1.;
        let max_backoff_secs = 500.;
        let base = 3.;
        let config = BackoffConfig {
            init_backoff: Duration::from_secs_f64(init_backoff_secs),
            max_backoff: Duration::from_secs_f64(max_backoff_secs),
            deadline: None,
            base,
        };
        let assert_fuzzy_eq = |a: f64, b: f64| assert!((b - a).abs() < 0.0001, "{a} != {b}");
        // Create a static rng that takes the minimum of the range
        let rng = Box::new(StepRng::new(0, 0));
        let mut backoff = Backoff::new_with_rng(&config, Some(rng));
        for _ in 0..20 {
            assert_eq!(backoff.next().unwrap().as_secs_f64(), init_backoff_secs);
        }
        // Create a static rng that takes the maximum of the range
        let rng = Box::new(StepRng::new(u64::MAX, 0));
        let mut backoff = Backoff::new_with_rng(&config, Some(rng));
        for i in 0..20 {
            let value = (base.powi(i) * init_backoff_secs).min(max_backoff_secs);
            assert_fuzzy_eq(backoff.next().unwrap().as_secs_f64(), value);
        }
        // Create a static rng that takes the mid point of the range
        let rng = Box::new(StepRng::new(u64::MAX / 2, 0));
        let mut backoff = Backoff::new_with_rng(&config, Some(rng));
        let mut value = init_backoff_secs;
        for _ in 0..20 {
            assert_fuzzy_eq(backoff.next().unwrap().as_secs_f64(), value);
            value =
                (init_backoff_secs + (value * base - init_backoff_secs) / 2.).min(max_backoff_secs);
        }
        // deadline
        let rng = Box::new(StepRng::new(u64::MAX, 0));
        let deadline = Duration::from_secs_f64(init_backoff_secs);
        let mut backoff = Backoff::new_with_rng(
            &BackoffConfig {
                deadline: Some(deadline),
                ..config
            },
            Some(rng),
        );
        assert_eq!(backoff.next(), None);
    }
    #[test]
    fn test_overflow() {
        // Durations beyond MAX_F64_SECS are rejected rather than panicking.
        let rng = Box::new(StepRng::new(u64::MAX, 0));
        let cfg = BackoffConfig {
            init_backoff: Duration::MAX,
            max_backoff: Duration::MAX,
            ..Default::default()
        };
        let mut backoff = Backoff::new_with_rng(&cfg, Some(rng));
        assert_eq!(backoff.next(), None);
    }
    #[test]
    fn test_duration_try_from_f64() {
        for d in [-0.1, f64::INFINITY, f64::NAN, MAX_F64_SECS + 0.1] {
            assert!(duration_try_from_secs_f64(d).is_none());
        }
        for d in [0.0, MAX_F64_SECS] {
            assert!(duration_try_from_secs_f64(d).is_some());
        }
    }
    #[test]
    fn test_max_backoff_smaller_init() {
        // init_backoff is clamped to max_backoff at construction time.
        let rng = Box::new(StepRng::new(u64::MAX, 0));
        let cfg = BackoffConfig {
            init_backoff: Duration::from_secs(2),
            max_backoff: Duration::from_secs(1),
            ..Default::default()
        };
        let mut backoff = Backoff::new_with_rng(&cfg, Some(rng));
        assert_eq!(backoff.next(), Some(Duration::from_secs(1)));
        assert_eq!(backoff.next(), Some(Duration::from_secs(1)));
    }
    #[test]
    #[should_panic(expected = "Backoff base (inf) must be finite.")]
    fn test_panic_inf_base() {
        let cfg = BackoffConfig {
            base: f64::INFINITY,
            ..Default::default()
        };
        Backoff::new(&cfg);
    }
    #[test]
    #[should_panic(expected = "Backoff base (NaN) must be finite.")]
    fn test_panic_nan_base() {
        let cfg = BackoffConfig {
            base: f64::NAN,
            ..Default::default()
        };
        Backoff::new(&cfg);
    }
    #[test]
    #[should_panic(expected = "Backoff base (0) must be greater or equal than 1.")]
    fn test_panic_zero_base() {
        let cfg = BackoffConfig {
            base: 0.0,
            ..Default::default()
        };
        Backoff::new(&cfg);
    }
    #[test]
    fn test_constant_backoff() {
        // base == 1.0 degenerates to a fixed interval.
        let rng = Box::new(StepRng::new(u64::MAX, 0));
        let cfg = BackoffConfig {
            init_backoff: Duration::from_secs(1),
            max_backoff: Duration::from_secs(1),
            base: 1.0,
            ..Default::default()
        };
        let mut backoff = Backoff::new_with_rng(&cfg, Some(rng));
        assert_eq!(backoff.next(), Some(Duration::from_secs(1)));
        assert_eq!(backoff.next(), Some(Duration::from_secs(1)));
    }
}

22
buf.yaml Normal file
View File

@ -0,0 +1,22 @@
---
version: v1beta1
build:
roots:
- generated_types/protos/
- ingester_query_grpc/protos/
lint:
allow_comment_ignores: true
ignore:
- google
- grpc
- com/github/influxdata/idpe/storage/read
- influxdata/platform
use:
- DEFAULT
- STYLE_DEFAULT
breaking:
use:
- WIRE
- WIRE_JSON

36
cache_system/Cargo.toml Normal file
View File

@ -0,0 +1,36 @@
[package]
name = "cache_system"
version.workspace = true
authors.workspace = true
edition.workspace = true
license.workspace = true
[dependencies]
async-trait = "0.1.73"
backoff = { path = "../backoff" }
futures = "0.3"
iox_time = { path = "../iox_time" }
metric = { path = "../metric" }
observability_deps = { path = "../observability_deps" }
ouroboros = "0.18"
parking_lot = { version = "0.12", features = ["arc_lock"] }
pdatastructs = { version = "0.7", default-features = false, features = ["fixedbitset"] }
rand = "0.8.3"
tokio = { version = "1.32", features = ["macros", "parking_lot", "rt-multi-thread", "sync", "time"] }
tokio-util = { version = "0.7.9" }
trace = { path = "../trace"}
workspace-hack = { version = "0.1", path = "../workspace-hack" }
[dev-dependencies]
criterion = { version = "0.5", default-features = false, features = ["rayon"]}
proptest = { version = "1", default-features = false, features = ["std"] }
test_helpers = { path = "../test_helpers" }
[lib]
# Allow --save-baseline to work
# https://github.com/bheisler/criterion.rs/issues/275
bench = false
[[bench]]
name = "addressable_heap"
harness = false

View File

@ -0,0 +1,420 @@
use std::mem::size_of;
use cache_system::addressable_heap::AddressableHeap;
use criterion::{
criterion_group, criterion_main, measurement::WallTime, AxisScale, BatchSize, BenchmarkGroup,
BenchmarkId, Criterion, PlotConfiguration, SamplingMode,
};
use rand::{prelude::SliceRandom, thread_rng, Rng};
/// Payload (`V`) for testing.
///
/// This is a 64bit-wide object which is enough to store a [`Box`] or a [`usize`].
#[derive(Debug, Clone, Default)]
struct Payload([u8; 8]);
// Compile-time checks that the payload really is 8 bytes and can hold the
// pointer-sized values the benches pretend to store.
const _: () = assert!(size_of::<Payload>() == 8);
const _: () = assert!(size_of::<Payload>() >= size_of::<Box<Vec<u32>>>());
const _: () = assert!(size_of::<Payload>() >= size_of::<usize>());
// Heap under test: u64 keys, 8-byte payloads, u64 ordering.
type TestHeap = AddressableHeap<u64, Payload, u64>;
// Element counts used by every benchmark group.
const TEST_SIZES: &[usize] = &[0, 1, 10, 100, 1_000, 10_000];
/// Random (key, order) pair used to populate the heap under test.
#[derive(Debug, Clone)]
struct Entry {
    // Heap key.
    k: u64,
    // Heap order.
    o: u64,
}
impl Entry {
    /// Create a single entry with a random key and order.
    ///
    /// Values are drawn from the middle half of the `u64` range so that
    /// there is headroom both above and below every generated value.
    ///
    /// BUGFIX: the original used `<< 1` / `<< 2`, which made
    /// `u64::MAX << 2` nearly `u64::MAX` and the addition overflow
    /// (panicking under debug assertions, silently wrapping otherwise),
    /// contradicting the intent stated in the comments. Shifting right
    /// produces the intended mid-range values without overflow.
    fn new_random<R>(rng: &mut R) -> Self
    where
        R: Rng,
    {
        Self {
            // leave some room at the top and bottom
            k: (rng.gen::<u64>() >> 1) + (u64::MAX >> 2),
            // leave some room at the top and bottom
            o: (rng.gen::<u64>() >> 1) + (u64::MAX >> 2),
        }
    }
    /// Create `n` random entries.
    fn new_random_n<R>(rng: &mut R, n: usize) -> Vec<Self>
    where
        R: Rng,
    {
        (0..n).map(|_| Self::new_random(rng)).collect()
    }
}
/// Build a heap pre-populated with `n` random entries, returning both the
/// heap and the entries that were inserted (for later lookup/removal).
fn create_filled_heap<R>(rng: &mut R, n: usize) -> (TestHeap, Vec<Entry>)
where
    R: Rng,
{
    let mut heap = TestHeap::default();
    let mut entries = Vec::with_capacity(n);
    for _ in 0..n {
        let entry = Entry::new_random(rng);
        heap.insert(entry.k, Payload::default(), entry.o);
        entries.push(entry);
    }
    (heap, entries)
}
fn setup_group(g: &mut BenchmarkGroup<'_, WallTime>) {
g.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
g.sampling_mode(SamplingMode::Flat);
}
/// Benchmark: insert `n` fresh random entries into an empty heap.
fn bench_insert_n_elements(c: &mut Criterion) {
    let mut g = c.benchmark_group("insert_n_elements");
    setup_group(&mut g);
    let mut rng = thread_rng();
    for n in TEST_SIZES {
        g.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &_n| {
            b.iter_batched(
                || (TestHeap::default(), Entry::new_random_n(&mut rng, *n)),
                |(mut heap, entries)| {
                    for entry in &entries {
                        heap.insert(entry.k, Payload::default(), entry.o);
                    }
                    // let criterion handle the drop
                    (heap, entries)
                },
                BatchSize::LargeInput,
            );
        });
    }
    g.finish();
}
/// Benchmark: peek at the top element of a heap holding `n` entries.
fn bench_peek_after_n_elements(c: &mut Criterion) {
    let mut g = c.benchmark_group("peek_after_n_elements");
    setup_group(&mut g);
    let mut rng = thread_rng();
    for n in TEST_SIZES {
        g.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &_n| {
            b.iter_batched(
                || create_filled_heap(&mut rng, *n).0,
                |heap| {
                    heap.peek();
                    // let criterion handle the drop
                    heap
                },
                BatchSize::LargeInput,
            );
        });
    }
    g.finish();
}
/// Benchmark: look up a key that is present in a heap of `n` entries.
/// `n == 0` is skipped because there is no existing key to look up.
fn bench_get_existing_after_n_elements(c: &mut Criterion) {
    let mut g = c.benchmark_group("get_existing_after_n_elements");
    setup_group(&mut g);
    let mut rng = thread_rng();
    for n in TEST_SIZES {
        if *n == 0 {
            continue;
        }
        g.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &_n| {
            b.iter_batched(
                || {
                    let (heap, entries) = create_filled_heap(&mut rng, *n);
                    let entry = entries.choose(&mut rng).unwrap().clone();
                    (heap, entry)
                },
                |(heap, entry)| {
                    heap.get(&entry.k);
                    // let criterion handle the drop
                    heap
                },
                BatchSize::LargeInput,
            );
        });
    }
    g.finish();
}
/// Benchmark: look up a key that is (almost certainly) absent from a heap
/// of `n` entries — the miss path.
fn bench_get_new_after_n_elements(c: &mut Criterion) {
    let mut g = c.benchmark_group("get_new_after_n_elements");
    setup_group(&mut g);
    let mut rng = thread_rng();
    for n in TEST_SIZES {
        g.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &_n| {
            b.iter_batched(
                || {
                    let (heap, _entries) = create_filled_heap(&mut rng, *n);
                    let entry = Entry::new_random(&mut rng);
                    (heap, entry)
                },
                |(heap, entry)| {
                    heap.get(&entry.k);
                    // let criterion handle the drop
                    heap
                },
                BatchSize::LargeInput,
            );
        });
    }
    g.finish();
}
/// Benchmark: drain a heap of `n` entries by popping them all.
fn bench_pop_n_elements(c: &mut Criterion) {
    let mut g = c.benchmark_group("pop_n_elements");
    setup_group(&mut g);
    let mut rng = thread_rng();
    for n in TEST_SIZES {
        g.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &_n| {
            b.iter_batched(
                || create_filled_heap(&mut rng, *n).0,
                |mut heap| {
                    for _ in 0..*n {
                        heap.pop();
                    }
                    // let criterion handle the drop
                    heap
                },
                BatchSize::LargeInput,
            );
        });
    }
    g.finish();
}
/// Benchmark: remove a key that is present in a heap of `n` entries.
/// `n == 0` is skipped because there is no existing key to remove.
fn bench_remove_existing_after_n_elements(c: &mut Criterion) {
    let mut g = c.benchmark_group("remove_existing_after_n_elements");
    setup_group(&mut g);
    let mut rng = thread_rng();
    for n in TEST_SIZES {
        if *n == 0 {
            continue;
        }
        g.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &_n| {
            b.iter_batched(
                || {
                    let (heap, entries) = create_filled_heap(&mut rng, *n);
                    let entry = entries.choose(&mut rng).unwrap().clone();
                    (heap, entry)
                },
                |(mut heap, entry)| {
                    heap.remove(&entry.k);
                    // let criterion handle the drop
                    heap
                },
                BatchSize::LargeInput,
            );
        });
    }
    g.finish();
}
/// Benchmark `remove` for a freshly generated (most likely absent) key.
fn bench_remove_new_after_n_elements(c: &mut Criterion) {
    let mut group = c.benchmark_group("remove_new_after_n_elements");
    setup_group(&mut group);
    let mut rng = thread_rng();
    for &n in TEST_SIZES {
        group.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, _| {
            b.iter_batched(
                || {
                    let (heap, _entries) = create_filled_heap(&mut rng, n);
                    (heap, Entry::new_random(&mut rng))
                },
                |(mut heap, probe)| {
                    heap.remove(&probe.k);
                    // hand the heap back so criterion accounts for the drop
                    heap
                },
                BatchSize::LargeInput,
            );
        });
    }
    group.finish();
}
/// Benchmark `insert` for a key that already exists, i.e. a replacement.
fn bench_replace_after_n_elements(c: &mut Criterion) {
    let mut group = c.benchmark_group("replace_after_n_elements");
    setup_group(&mut group);
    let mut rng = thread_rng();
    for &n in TEST_SIZES {
        // replacing requires at least one element
        if n == 0 {
            continue;
        }
        group.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, _| {
            b.iter_batched(
                || {
                    let (heap, entries) = create_filled_heap(&mut rng, n);
                    // keep an existing key but attach a fresh random order
                    let existing = entries.choose(&mut rng).unwrap().clone();
                    let probe = Entry {
                        k: existing.k,
                        o: Entry::new_random(&mut rng).o,
                    };
                    (heap, probe)
                },
                |(mut heap, probe)| {
                    heap.insert(probe.k, Payload::default(), probe.o);
                    // hand the heap back so criterion accounts for the drop
                    heap
                },
                BatchSize::LargeInput,
            );
        });
    }
    group.finish();
}
/// Benchmark `update_order` moving an existing key to a random new order.
fn bench_update_order_existing_to_random_after_n_elements(c: &mut Criterion) {
    let mut group = c.benchmark_group("update_order_existing_to_random_after_n_elements");
    setup_group(&mut group);
    let mut rng = thread_rng();
    for &n in TEST_SIZES {
        // picking an existing key requires at least one element
        if n == 0 {
            continue;
        }
        group.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, _| {
            b.iter_batched(
                || {
                    let (heap, entries) = create_filled_heap(&mut rng, n);
                    // keep an existing key but attach a fresh random order
                    let existing = entries.choose(&mut rng).unwrap().clone();
                    let probe = Entry {
                        k: existing.k,
                        o: Entry::new_random(&mut rng).o,
                    };
                    (heap, probe)
                },
                |(mut heap, probe)| {
                    heap.update_order(&probe.k, probe.o);
                    // hand the heap back so criterion accounts for the drop
                    heap
                },
                BatchSize::LargeInput,
            );
        });
    }
    group.finish();
}
/// Benchmark `update_order` moving an existing key to a fixed order value.
///
/// NOTE(review): the function name says "last" but the group name string says "first",
/// and `u64::MAX - (u64::MAX << 2)` evaluates to `3` (`u64::MAX << 2` discards the top
/// bits, yielding `2^64 - 4`), i.e. an order near the *front* of the queue. Either one
/// of the names or the expression (perhaps `>>` was intended) looks unintended — confirm.
fn bench_update_order_existing_to_last_after_n_elements(c: &mut Criterion) {
    let mut g = c.benchmark_group("update_order_existing_to_first_after_n_elements");
    setup_group(&mut g);
    let mut rng = thread_rng();
    for n in TEST_SIZES {
        // picking an existing key requires at least one element
        if *n == 0 {
            continue;
        }
        g.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &_n| {
            b.iter_batched(
                || {
                    let (heap, entries) = create_filled_heap(&mut rng, *n);
                    let entry = entries.choose(&mut rng).unwrap().clone();
                    // same key, fixed order value (== 3, see NOTE above)
                    let entry = Entry {
                        k: entry.k,
                        o: u64::MAX - (u64::MAX << 2),
                    };
                    (heap, entry)
                },
                |(mut heap, entry)| {
                    heap.update_order(&entry.k, entry.o);
                    // let criterion handle the drop
                    heap
                },
                BatchSize::LargeInput,
            );
        });
    }
    g.finish();
}
/// Benchmark `update_order` for a freshly generated (most likely absent) key.
fn bench_update_order_new_after_n_elements(c: &mut Criterion) {
    let mut group = c.benchmark_group("update_order_new_after_n_elements");
    setup_group(&mut group);
    let mut rng = thread_rng();
    for &n in TEST_SIZES {
        group.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, _| {
            b.iter_batched(
                || {
                    let (heap, _entries) = create_filled_heap(&mut rng, n);
                    (heap, Entry::new_random(&mut rng))
                },
                |(mut heap, probe)| {
                    heap.update_order(&probe.k, probe.o);
                    // hand the heap back so criterion accounts for the drop
                    heap
                },
                BatchSize::LargeInput,
            );
        });
    }
    group.finish();
}
// Register all benchmark targets with the default criterion configuration.
criterion_group! {
    name = benches;
    config = Criterion::default();
    targets =
        bench_insert_n_elements,
        bench_peek_after_n_elements,
        bench_get_existing_after_n_elements,
        bench_get_new_after_n_elements,
        bench_pop_n_elements,
        bench_remove_existing_after_n_elements,
        bench_remove_new_after_n_elements,
        bench_replace_after_n_elements,
        bench_update_order_existing_to_random_after_n_elements,
        bench_update_order_existing_to_last_after_n_elements,
        bench_update_order_new_after_n_elements,
}
criterion_main!(benches);

View File

@ -0,0 +1,609 @@
//! Implementation of an [`AddressableHeap`].
use std::{
collections::{hash_map, BTreeSet, HashMap},
hash::Hash,
};
/// Addressable heap.
///
/// Stores a value `V` together with a key `K` and an order `O`. Elements are sorted by `O` and the smallest element can
/// be peeked/popped. At the same time elements can be addressed via `K`.
///
/// Note that `K` requires the inner data structure to implement [`Ord`] as a tie breaker.
#[derive(Debug, Clone)]
pub struct AddressableHeap<K, V, O>
where
    K: Clone + Eq + Hash + Ord,
    O: Clone + Ord,
{
    /// Key to order and value.
    ///
    /// The order is required to lookup data within the queue.
    ///
    /// The value is stored here instead of the queue since HashMap entries are copied around less often than queue elements.
    ///
    /// Invariant: contains exactly the same `(K, O)` pairs as `queue` (checked via asserts below).
    key_to_order_and_value: HashMap<K, (V, O)>,
    /// Queue that handles the priorities.
    ///
    /// The order goes first, the key goes second.
    ///
    /// Note: This is not really a heap, but it fulfills the interface that we need.
    queue: BTreeSet<(O, K)>,
}
impl<K, V, O> AddressableHeap<K, V, O>
where
    K: Clone + Eq + Hash + Ord,
    O: Clone + Ord,
{
    /// Create new, empty heap.
    pub fn new() -> Self {
        Self {
            key_to_order_and_value: HashMap::new(),
            queue: BTreeSet::new(),
        }
    }
    /// Check if the heap is empty.
    pub fn is_empty(&self) -> bool {
        // Both structures must agree; a mismatch means the cross-structure invariant was broken.
        let res1 = self.key_to_order_and_value.is_empty();
        let res2 = self.queue.is_empty();
        assert_eq!(res1, res2, "data structures out of sync");
        res1
    }
    /// Insert element.
    ///
    /// If the element (compared by `K`) already exists, it will be returned.
    pub fn insert(&mut self, k: K, v: V, o: O) -> Option<(V, O)> {
        let (result, k) = match self.key_to_order_and_value.entry(k.clone()) {
            hash_map::Entry::Occupied(mut entry_o) => {
                // `entry_o.replace_entry(...)` is not stable yet, see https://github.com/rust-lang/rust/issues/44286
                // Swap the new (value, order) pair in and recover the previous one.
                let mut tmp = (v, o.clone());
                std::mem::swap(&mut tmp, entry_o.get_mut());
                let (v_old, o_old) = tmp;
                // Drop the stale queue entry; it is re-inserted with the new order below.
                let query = (o_old, k);
                let existed = self.queue.remove(&query);
                assert!(existed, "key was in key_to_order");
                // Move `k` back out of the query tuple to avoid an extra clone.
                let (o_old, k) = query;
                (Some((v_old, o_old)), k)
            }
            hash_map::Entry::Vacant(entry_v) => {
                entry_v.insert((v, o.clone()));
                (None, k)
            }
        };
        let inserted = self.queue.insert((o, k));
        assert!(inserted, "entry should have been removed by now");
        result
    }
    /// Peek first element (by smallest `O`).
    pub fn peek(&self) -> Option<(&K, &V, &O)> {
        self.iter().next()
    }
    /// Pop first element (by smallest `O`) from heap.
    pub fn pop(&mut self) -> Option<(K, V, O)> {
        if let Some((o, k)) = self.queue.pop_first() {
            let (v, o2) = self
                .key_to_order_and_value
                .remove(&k)
                .expect("value is in queue");
            // Orders stored in both structures must match.
            assert!(o == o2);
            Some((k, v, o))
        } else {
            None
        }
    }
    /// Iterate over elements in order of `O` (starting at smallest).
    ///
    /// This is equivalent to calling [`pop`](Self::pop) multiple times, but does NOT modify the collection.
    pub fn iter(&self) -> AddressableHeapIter<'_, K, V, O> {
        AddressableHeapIter {
            key_to_order_and_value: &self.key_to_order_and_value,
            queue_iter: self.queue.iter(),
        }
    }
    /// Get element by key.
    pub fn get(&self, k: &K) -> Option<(&V, &O)> {
        self.key_to_order_and_value.get(k).map(project_tuple)
    }
    /// Remove element by key.
    ///
    /// If the element exists within the heap (addressed via `K`), the value and order will be returned.
    pub fn remove(&mut self, k: &K) -> Option<(V, O)> {
        if let Some((k, (v, o))) = self.key_to_order_and_value.remove_entry(k) {
            // Build the queue key from the owned parts, remove it, then take the order back out.
            let query = (o, k);
            let existed = self.queue.remove(&query);
            assert!(existed, "key was in key_to_order");
            let (o, _k) = query;
            Some((v, o))
        } else {
            None
        }
    }
    /// Update order of a given key.
    ///
    /// Returns existing order if the key existed.
    pub fn update_order(&mut self, k: &K, o: O) -> Option<O> {
        match self.key_to_order_and_value.get_mut(k) {
            Some(entry) => {
                // Swap the new order into the map entry; `o_old` then holds the previous order.
                let mut o_old = o.clone();
                std::mem::swap(&mut entry.1, &mut o_old);
                // Re-position the key in the queue: remove under old order, insert under new one.
                let query = (o_old, k.clone());
                let existed = self.queue.remove(&query);
                assert!(existed, "key was in key_to_order");
                let (o_old, k) = query;
                let inserted = self.queue.insert((o, k));
                assert!(inserted, "entry should have been removed by now");
                Some(o_old)
            }
            None => None,
        }
    }
}
impl<K, V, O> Default for AddressableHeap<K, V, O>
where
    K: Clone + Eq + Hash + Ord,
    O: Clone + Ord,
{
    /// Same as [`AddressableHeap::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Turn a reference to a pair into a pair of references to its fields.
fn project_tuple<A, B>(t: &(A, B)) -> (&A, &B) {
    let (a, b) = t;
    (a, b)
}
/// Iterator of [`AddressableHeap::iter`].
pub struct AddressableHeapIter<'a, K, V, O>
where
    K: Clone + Eq + Hash + Ord,
    O: Clone + Ord,
{
    /// Map borrowed from the heap; used to resolve values for the keys yielded by `queue_iter`.
    key_to_order_and_value: &'a HashMap<K, (V, O)>,
    /// Ordered iterator over the heap's `(order, key)` queue; drives iteration order.
    queue_iter: std::collections::btree_set::Iter<'a, (O, K)>,
}
impl<'a, K, V, O> Iterator for AddressableHeapIter<'a, K, V, O>
where
    K: Clone + Eq + Hash + Ord,
    O: Clone + Ord,
{
    type Item = (&'a K, &'a V, &'a O);
    /// Yield the next element in ascending `(O, K)` order.
    fn next(&mut self) -> Option<Self::Item> {
        let (o, k) = self.queue_iter.next()?;
        let (v, o2) = self
            .key_to_order_and_value
            .get(k)
            .expect("value is in queue");
        // Orders stored in both structures must match.
        assert!(o == o2);
        Some((k, v, o))
    }
    /// The underlying set iterator knows the exact remaining length.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.queue_iter.size_hint()
    }
}
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use super::*;
    // --- peek() ---
    #[test]
    fn test_peek_empty() {
        let heap = AddressableHeap::<i32, &str, i32>::new();
        assert_eq!(heap.peek(), None);
    }
    #[test]
    fn test_peek_some() {
        let mut heap = AddressableHeap::new();
        heap.insert(1, "a", 4);
        heap.insert(2, "b", 3);
        heap.insert(3, "c", 5);
        assert_eq!(heap.peek(), Some((&2, &"b", &3)));
    }
    #[test]
    fn test_peek_tie() {
        // equal orders are broken by the smallest key
        let mut heap = AddressableHeap::new();
        heap.insert(3, "a", 1);
        heap.insert(1, "b", 1);
        heap.insert(2, "c", 1);
        assert_eq!(heap.peek(), Some((&1, &"b", &1)));
    }
    #[test]
    fn test_peek_after_remove() {
        let mut heap = AddressableHeap::new();
        heap.insert(1, "a", 4);
        heap.insert(2, "b", 3);
        heap.insert(3, "c", 5);
        assert_eq!(heap.peek(), Some((&2, &"b", &3)));
        heap.remove(&3);
        assert_eq!(heap.peek(), Some((&2, &"b", &3)));
        heap.remove(&2);
        assert_eq!(heap.peek(), Some((&1, &"a", &4)));
        heap.remove(&1);
        assert_eq!(heap.peek(), None);
    }
    #[test]
    fn test_peek_after_override() {
        // re-inserting key 1 replaces both its value and its order
        let mut heap = AddressableHeap::new();
        heap.insert(1, "a", 4);
        heap.insert(2, "b", 3);
        heap.insert(1, "c", 2);
        assert_eq!(heap.peek(), Some((&1, &"c", &2)));
    }
    // --- pop() ---
    #[test]
    fn test_pop_empty() {
        let mut heap = AddressableHeap::<i32, &str, i32>::new();
        assert_eq!(heap.pop(), None);
    }
    #[test]
    fn test_pop_all() {
        let mut heap = AddressableHeap::new();
        heap.insert(1, "a", 4);
        heap.insert(2, "b", 3);
        heap.insert(3, "c", 5);
        assert_eq!(heap.pop(), Some((2, "b", 3)));
        assert_eq!(heap.pop(), Some((1, "a", 4)));
        assert_eq!(heap.pop(), Some((3, "c", 5)));
        assert_eq!(heap.pop(), None);
    }
    #[test]
    fn test_pop_tie() {
        // equal orders pop in key order
        let mut heap = AddressableHeap::new();
        heap.insert(3, "a", 1);
        heap.insert(1, "b", 1);
        heap.insert(2, "c", 1);
        assert_eq!(heap.pop(), Some((1, "b", 1)));
        assert_eq!(heap.pop(), Some((2, "c", 1)));
        assert_eq!(heap.pop(), Some((3, "a", 1)));
        assert_eq!(heap.pop(), None);
    }
    #[test]
    fn test_pop_after_insert() {
        let mut heap = AddressableHeap::new();
        heap.insert(1, "a", 4);
        heap.insert(2, "b", 3);
        heap.insert(3, "c", 5);
        assert_eq!(heap.pop(), Some((2, "b", 3)));
        heap.insert(4, "d", 2);
        assert_eq!(heap.pop(), Some((4, "d", 2)));
        assert_eq!(heap.pop(), Some((1, "a", 4)));
    }
    #[test]
    fn test_pop_after_remove() {
        let mut heap = AddressableHeap::new();
        heap.insert(1, "a", 4);
        heap.insert(2, "b", 3);
        heap.insert(3, "c", 5);
        heap.remove(&2);
        assert_eq!(heap.pop(), Some((1, "a", 4)));
    }
    #[test]
    fn test_pop_after_override() {
        let mut heap = AddressableHeap::new();
        heap.insert(1, "a", 4);
        heap.insert(2, "b", 3);
        heap.insert(1, "c", 2);
        assert_eq!(heap.pop(), Some((1, "c", 2)));
        assert_eq!(heap.pop(), Some((2, "b", 3)));
        assert_eq!(heap.pop(), None);
    }
    // --- get() ---
    #[test]
    fn test_get_empty() {
        let heap = AddressableHeap::<i32, &str, i32>::new();
        assert_eq!(heap.get(&1), None);
    }
    #[test]
    fn test_get_multiple() {
        let mut heap = AddressableHeap::new();
        heap.insert(1, "a", 4);
        heap.insert(2, "b", 3);
        assert_eq!(heap.get(&1), Some((&"a", &4)));
        assert_eq!(heap.get(&2), Some((&"b", &3)));
    }
    #[test]
    fn test_get_after_remove() {
        let mut heap = AddressableHeap::new();
        heap.insert(1, "a", 4);
        heap.insert(2, "b", 3);
        heap.remove(&1);
        assert_eq!(heap.get(&1), None);
        assert_eq!(heap.get(&2), Some((&"b", &3)));
    }
    #[test]
    fn test_get_after_pop() {
        let mut heap = AddressableHeap::new();
        heap.insert(1, "a", 4);
        heap.insert(2, "b", 3);
        heap.pop();
        assert_eq!(heap.get(&1), Some((&"a", &4)));
        assert_eq!(heap.get(&2), None);
    }
    #[test]
    fn test_get_after_override() {
        let mut heap = AddressableHeap::new();
        heap.insert(1, "a", 4);
        heap.insert(1, "b", 3);
        assert_eq!(heap.get(&1), Some((&"b", &3)));
    }
    // --- remove() ---
    #[test]
    fn test_remove_empty() {
        let mut heap = AddressableHeap::<i32, &str, i32>::new();
        assert_eq!(heap.remove(&1), None);
    }
    #[test]
    fn test_remove_some() {
        let mut heap = AddressableHeap::new();
        heap.insert(1, "a", 4);
        heap.insert(2, "b", 3);
        assert_eq!(heap.remove(&1), Some(("a", 4)));
        assert_eq!(heap.remove(&2), Some(("b", 3)));
    }
    #[test]
    fn test_remove_twice() {
        let mut heap = AddressableHeap::new();
        heap.insert(1, "a", 4);
        assert_eq!(heap.remove(&1), Some(("a", 4)));
        assert_eq!(heap.remove(&1), None);
    }
    #[test]
    fn test_remove_after_pop() {
        let mut heap = AddressableHeap::new();
        heap.insert(1, "a", 4);
        heap.insert(2, "b", 3);
        heap.pop();
        assert_eq!(heap.remove(&1), Some(("a", 4)));
        assert_eq!(heap.remove(&2), None);
    }
    #[test]
    fn test_remove_after_override() {
        let mut heap = AddressableHeap::new();
        heap.insert(1, "a", 4);
        heap.insert(1, "b", 3);
        assert_eq!(heap.remove(&1), Some(("b", 3)));
        assert_eq!(heap.remove(&1), None);
    }
    // --- insert() override semantics ---
    #[test]
    fn test_override() {
        let mut heap = AddressableHeap::new();
        assert_eq!(heap.insert(1, "a", 4), None);
        assert_eq!(heap.insert(2, "b", 3), None);
        assert_eq!(heap.insert(1, "c", 5), Some(("a", 4)));
    }
    /// Simple version of [`AddressableHeap`] for testing.
    ///
    /// Serves as an O(n) oracle (plain `Vec` with linear scans) that the real
    /// implementation is compared against in the property test below.
    struct SimpleAddressableHeap {
        inner: Vec<(u8, String, i8)>,
    }
    impl SimpleAddressableHeap {
        fn new() -> Self {
            Self { inner: Vec::new() }
        }
        fn is_empty(&self) -> bool {
            self.inner.is_empty()
        }
        fn insert(&mut self, k: u8, v: String, o: i8) -> Option<(String, i8)> {
            // remove first so a duplicate key behaves like an override
            let res = self.remove(&k);
            self.inner.push((k, v, o));
            res
        }
        fn peek(&self) -> Option<(&u8, &String, &i8)> {
            self.inner
                .iter()
                .min_by_key(|(k, _v, o)| (o, k))
                .map(|(k, v, o)| (k, v, o))
        }
        fn dump_ordered(&self) -> Vec<(u8, String, i8)> {
            let mut inner = self.inner.clone();
            inner.sort_by_key(|(k, _v, o)| (*o, *k));
            inner
        }
        fn pop(&mut self) -> Option<(u8, String, i8)> {
            self.inner
                .iter()
                .enumerate()
                .min_by_key(|(_idx, (k, _v, o))| (o, k))
                .map(|(idx, _)| idx)
                .map(|idx| self.inner.remove(idx))
        }
        fn get(&self, k: &u8) -> Option<(&String, &i8)> {
            self.inner
                .iter()
                .find(|(k2, _v, _o)| k2 == k)
                .map(|(_k, v, o)| (v, o))
        }
        fn remove(&mut self, k: &u8) -> Option<(String, i8)> {
            self.inner
                .iter()
                .enumerate()
                .find(|(_idx, (k2, _v, _o))| k2 == k)
                .map(|(idx, _)| idx)
                .map(|idx| {
                    let (_k, v, o) = self.inner.remove(idx);
                    (v, o)
                })
        }
        fn update_order(&mut self, k: &u8, o: i8) -> Option<i8> {
            if let Some((v, o_old)) = self.remove(k) {
                self.insert(*k, v, o);
                Some(o_old)
            } else {
                None
            }
        }
    }
    /// One randomized operation applied to both the real heap and the oracle.
    #[derive(Debug, Clone)]
    enum Action {
        IsEmpty,
        Insert { k: u8, v: String, o: i8 },
        Peek,
        Iter,
        Pop,
        Get { k: u8 },
        Remove { k: u8 },
        UpdateOrder { k: u8, o: i8 },
    }
    // Use a hand-rolled strategy instead of `proptest-derive`, because the latter one is quite a heavy dependency.
    fn action() -> impl Strategy<Value = Action> {
        prop_oneof![
            Just(Action::IsEmpty),
            (any::<u8>(), ".*", any::<i8>()).prop_map(|(k, v, o)| Action::Insert { k, v, o }),
            Just(Action::Peek),
            Just(Action::Iter),
            Just(Action::Pop),
            any::<u8>().prop_map(|k| Action::Get { k }),
            any::<u8>().prop_map(|k| Action::Remove { k }),
            (any::<u8>(), any::<i8>()).prop_map(|(k, o)| Action::UpdateOrder { k, o }),
        ]
    }
    // Differential test: apply random action sequences to both implementations and
    // require that every operation returns identical results.
    proptest! {
        #[test]
        fn test_proptest(actions in prop::collection::vec(action(), 0..100)) {
            let mut heap = AddressableHeap::new();
            let mut sim = SimpleAddressableHeap::new();
            for action in actions {
                match action {
                    Action::IsEmpty => {
                        let res1 = heap.is_empty();
                        let res2 = sim.is_empty();
                        assert_eq!(res1, res2);
                    }
                    Action::Insert{k, v, o} => {
                        let res1 = heap.insert(k, v.clone(), o);
                        let res2 = sim.insert(k, v, o);
                        assert_eq!(res1, res2);
                    }
                    Action::Peek => {
                        let res1 = heap.peek();
                        let res2 = sim.peek();
                        assert_eq!(res1, res2);
                    }
                    Action::Iter => {
                        let res1 = heap.iter().map(|(k, v, o)| (*k, v.clone(), *o)).collect::<Vec<_>>();
                        let res2 = sim.dump_ordered();
                        assert_eq!(res1, res2);
                    }
                    Action::Pop => {
                        let res1 = heap.pop();
                        let res2 = sim.pop();
                        assert_eq!(res1, res2);
                    }
                    Action::Get{k} => {
                        let res1 = heap.get(&k);
                        let res2 = sim.get(&k);
                        assert_eq!(res1, res2);
                    }
                    Action::Remove{k} => {
                        let res1 = heap.remove(&k);
                        let res2 = sim.remove(&k);
                        assert_eq!(res1, res2);
                    }
                    Action::UpdateOrder{k, o} => {
                        let res1 = heap.update_order(&k, o);
                        let res2 = sim.update_order(&k, o);
                        assert_eq!(res1, res2);
                    }
                }
            }
        }
    }
}

View File

@ -0,0 +1,51 @@
//! Implements [`CacheBackend`] for [`HashMap`].
use std::{
any::Any,
collections::HashMap,
fmt::Debug,
hash::{BuildHasher, Hash},
};
use super::CacheBackend;
/// A plain [`HashMap`] acts as a trivial, never-evicting cache backend.
impl<K, V, S> CacheBackend for HashMap<K, V, S>
where
    K: Clone + Eq + Debug + Hash + Ord + Send + 'static,
    V: Clone + Debug + Send + 'static,
    S: BuildHasher + Send + 'static,
{
    type K = K;
    type V = V;
    fn get(&mut self, k: &Self::K) -> Option<Self::V> {
        // fully-qualified call disambiguates the inherent `HashMap::get` from this trait method
        Self::get(self, k).cloned()
    }
    fn set(&mut self, k: Self::K, v: Self::V) {
        self.insert(k, v);
    }
    fn remove(&mut self, k: &Self::K) {
        self.remove(k);
    }
    fn is_empty(&self) -> bool {
        self.is_empty()
    }
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Run the shared backend conformance suite against the `HashMap` implementation.
    #[test]
    fn test_generic() {
        use crate::backend::test_util::test_generic;
        test_generic(HashMap::new);
    }
}

View File

@ -0,0 +1,67 @@
//! Storage backends to keep and manage cached entries.
use std::{any::Any, fmt::Debug, hash::Hash};
pub mod hash_map;
pub mod policy;
#[cfg(test)]
mod test_util;
/// Backend to keep and manage stored entries.
///
/// A backend might remove entries at any point, e.g. due to memory pressure or expiration.
pub trait CacheBackend: Debug + Send + 'static {
    /// Cache key.
    type K: Clone + Eq + Hash + Ord + Debug + Send + 'static;
    /// Cached value.
    type V: Clone + Debug + Send + 'static;
    /// Get value for given key if it exists.
    fn get(&mut self, k: &Self::K) -> Option<Self::V>;
    /// Set value for given key.
    ///
    /// It is OK to set and override a key that already exists.
    fn set(&mut self, k: Self::K, v: Self::V);
    /// Remove value for given key.
    ///
    /// It is OK to remove a key even when it does not exist.
    fn remove(&mut self, k: &Self::K);
    /// Check if backend is empty.
    fn is_empty(&self) -> bool;
    /// Return backend as [`Any`] which can be used to downcast to a specific implementation.
    fn as_any(&self) -> &dyn Any;
}
/// Boxed trait objects are themselves backends, so backends can be type-erased.
impl<K, V> CacheBackend for Box<dyn CacheBackend<K = K, V = V>>
where
    K: Clone + Eq + Hash + Ord + Debug + Send + 'static,
    V: Clone + Debug + Send + 'static,
{
    type K = K;
    type V = V;
    // Delegation goes through `as_mut`/`as_ref` to reach the inner trait object;
    // a plain `self.get(k)` would presumably resolve to this impl again — confirm.
    fn get(&mut self, k: &Self::K) -> Option<Self::V> {
        self.as_mut().get(k)
    }
    fn set(&mut self, k: Self::K, v: Self::V) {
        self.as_mut().set(k, v)
    }
    fn remove(&mut self, k: &Self::K) {
        self.as_mut().remove(k)
    }
    fn is_empty(&self) -> bool {
        self.as_ref().is_empty()
    }
    // NOTE: this returns the *box* as `Any`, so downcasts must target the boxed type,
    // not the inner backend.
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }
}

View File

@ -0,0 +1,599 @@
//! Test integration between different policies.
use std::{collections::HashMap, sync::Arc, time::Duration};
use iox_time::{MockProvider, Time};
use parking_lot::Mutex;
use rand::rngs::mock::StepRng;
use test_helpers::maybe_start_logging;
use tokio::{runtime::Handle, sync::Notify};
use crate::{
backend::{
policy::refresh::test_util::{backoff_cfg, NotifyExt},
CacheBackend,
},
loader::test_util::TestLoader,
resource_consumption::{test_util::TestSize, ResourceEstimator},
};
use super::{
lru::{LruPolicy, ResourcePool},
refresh::{test_util::TestRefreshDurationProvider, RefreshPolicy},
remove_if::{RemoveIfHandle, RemoveIfPolicy},
ttl::{test_util::TestTtlProvider, TtlPolicy},
PolicyBackend,
};
/// A successful background refresh must reset the TTL, keeping the entry alive past its
/// original expiry; once refreshes stop, the entry finally expires.
#[tokio::test]
async fn test_refresh_can_prevent_expiration() {
    let TestStateTtlAndRefresh {
        mut backend,
        refresh_duration_provider,
        ttl_provider,
        time_provider,
        loader,
        notify_idle,
        ..
    } = TestStateTtlAndRefresh::new();
    // the refresh will replace value "a" with "foo"
    loader.mock_next(1, String::from("foo"));
    refresh_duration_provider.set_refresh_in(
        1,
        String::from("a"),
        Some(backoff_cfg(Duration::from_secs(1))),
    );
    ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(2)));
    refresh_duration_provider.set_refresh_in(1, String::from("foo"), None);
    ttl_provider.set_expires_in(1, String::from("foo"), Some(Duration::from_secs(2)));
    backend.set(1, String::from("a"));
    // perform refresh
    time_provider.inc(Duration::from_secs(1));
    notify_idle.notified_with_timeout().await;
    // no expired because refresh resets the timer
    time_provider.inc(Duration::from_secs(1));
    assert_eq!(backend.get(&1), Some(String::from("foo")));
    // we don't request a 2nd refresh (refresh duration is None), so this finally expires
    time_provider.inc(Duration::from_secs(1));
    assert_eq!(backend.get(&1), None);
}
/// The TTL must be measured from the moment the refresh load *completes*, not from when
/// the refresh was started (the loader is blocked on a barrier to create that gap).
#[tokio::test]
async fn test_refresh_sets_new_expiration_after_it_finishes() {
    let TestStateTtlAndRefresh {
        mut backend,
        refresh_duration_provider,
        ttl_provider,
        time_provider,
        loader,
        notify_idle,
        ..
    } = TestStateTtlAndRefresh::new();
    // loader blocks until the barrier is released, so the refresh stays in-flight
    let barrier = loader.block_next(1, String::from("foo"));
    refresh_duration_provider.set_refresh_in(
        1,
        String::from("a"),
        Some(backoff_cfg(Duration::from_secs(1))),
    );
    ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(3)));
    refresh_duration_provider.set_refresh_in(1, String::from("foo"), None);
    ttl_provider.set_expires_in(1, String::from("foo"), Some(Duration::from_secs(3)));
    backend.set(1, String::from("a"));
    // perform refresh
    time_provider.inc(Duration::from_secs(1));
    notify_idle.notified_with_timeout().await;
    time_provider.inc(Duration::from_secs(1));
    barrier.wait().await;
    notify_idle.notified_with_timeout().await;
    assert_eq!(backend.get(&1), Some(String::from("foo")));
    // no expired because refresh resets the timer after it was ready (now), not when it started (1s ago)
    time_provider.inc(Duration::from_secs(2));
    assert_eq!(backend.get(&1), Some(String::from("foo")));
    // we don't request a 2nd refresh (refresh duration is None), so this finally expires
    time_provider.inc(Duration::from_secs(1));
    assert_eq!(backend.get(&1), None);
}
/// A background refresh must NOT count as a "use" of the entry for LRU recency:
/// key 1 is refreshed after key 2 was added, yet key 1 is still the one evicted.
#[tokio::test]
async fn test_refresh_does_not_update_lru_time() {
    let TestStateLruAndRefresh {
        mut backend,
        refresh_duration_provider,
        size_estimator,
        time_provider,
        loader,
        notify_idle,
        pool,
        ..
    } = TestStateLruAndRefresh::new();
    size_estimator.mock_size(1, String::from("a"), TestSize(4));
    size_estimator.mock_size(1, String::from("foo"), TestSize(4));
    size_estimator.mock_size(2, String::from("b"), TestSize(4));
    size_estimator.mock_size(3, String::from("c"), TestSize(4));
    refresh_duration_provider.set_refresh_in(
        1,
        String::from("a"),
        Some(backoff_cfg(Duration::from_secs(1))),
    );
    refresh_duration_provider.set_refresh_in(1, String::from("foo"), None);
    refresh_duration_provider.set_refresh_in(2, String::from("b"), None);
    refresh_duration_provider.set_refresh_in(3, String::from("c"), None);
    let barrier = loader.block_next(1, String::from("foo"));
    backend.set(1, String::from("a"));
    pool.wait_converged().await;
    // trigger refresh
    time_provider.inc(Duration::from_secs(1));
    time_provider.inc(Duration::from_secs(1));
    backend.set(2, String::from("b"));
    pool.wait_converged().await;
    time_provider.inc(Duration::from_secs(1));
    notify_idle.notified_with_timeout().await;
    barrier.wait().await;
    notify_idle.notified_with_timeout().await;
    // add a third item to the cache, forcing LRU to evict one of the items
    backend.set(3, String::from("c"));
    pool.wait_converged().await;
    // Should evict `1` even though it was refreshed after `2` was added
    assert_eq!(backend.get(&1), None);
}
/// If an in-flight refresh does not finish before the TTL fires, the entry expires anyway,
/// and a late-finishing loader must not resurrect it.
///
/// NOTE(review): "to_slow" is likely a typo for "too_slow"; renaming would be safe since
/// test names are not referenced elsewhere — left unchanged here.
#[tokio::test]
async fn test_if_refresh_to_slow_then_expire() {
    let TestStateTtlAndRefresh {
        mut backend,
        refresh_duration_provider,
        ttl_provider,
        time_provider,
        loader,
        notify_idle,
        ..
    } = TestStateTtlAndRefresh::new();
    // the loader never finishes before the barrier is released
    let barrier = loader.block_next(1, String::from("foo"));
    refresh_duration_provider.set_refresh_in(
        1,
        String::from("a"),
        Some(backoff_cfg(Duration::from_secs(1))),
    );
    ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(2)));
    backend.set(1, String::from("a"));
    // perform refresh
    time_provider.inc(Duration::from_secs(1));
    notify_idle.notified_with_timeout().await;
    time_provider.inc(Duration::from_secs(1));
    notify_idle.not_notified().await;
    assert_eq!(backend.get(&1), None);
    // late loader finish will NOT bring the entry back
    barrier.wait().await;
    notify_idle.notified_with_timeout().await;
    assert_eq!(backend.get(&1), None);
}
/// When a refresh replaces a value with a much larger one (1 → 9 of a 10-unit pool),
/// the LRU policy must evict other entries to make room for the refreshed value.
#[tokio::test]
async fn test_refresh_can_trigger_lru_eviction() {
    maybe_start_logging();
    let TestStateLRUAndRefresh {
        mut backend,
        refresh_duration_provider,
        loader,
        size_estimator,
        time_provider,
        notify_idle,
        pool,
        ..
    } = TestStateLRUAndRefresh::new();
    assert_eq!(pool.limit(), TestSize(10));
    loader.mock_next(1, String::from("b"));
    refresh_duration_provider.set_refresh_in(
        1,
        String::from("a"),
        Some(backoff_cfg(Duration::from_secs(1))),
    );
    refresh_duration_provider.set_refresh_in(1, String::from("b"), None);
    refresh_duration_provider.set_refresh_in(2, String::from("c"), None);
    refresh_duration_provider.set_refresh_in(3, String::from("d"), None);
    // the refreshed value "b" is far larger than the value it replaces
    size_estimator.mock_size(1, String::from("a"), TestSize(1));
    size_estimator.mock_size(1, String::from("b"), TestSize(9));
    size_estimator.mock_size(2, String::from("c"), TestSize(1));
    size_estimator.mock_size(3, String::from("d"), TestSize(1));
    backend.set(1, String::from("a"));
    backend.set(2, String::from("c"));
    backend.set(3, String::from("d"));
    pool.wait_converged().await;
    assert_eq!(backend.get(&2), Some(String::from("c")));
    assert_eq!(backend.get(&3), Some(String::from("d")));
    time_provider.inc(Duration::from_millis(1));
    assert_eq!(backend.get(&1), Some(String::from("a")));
    // refresh
    time_provider.inc(Duration::from_secs(10));
    notify_idle.notified_with_timeout().await;
    pool.wait_converged().await;
    // needed to evict 2->"c"
    assert_eq!(backend.get(&1), Some(String::from("b")));
    assert_eq!(backend.get(&2), None);
    assert_eq!(backend.get(&3), Some(String::from("d")));
}
/// When the TTL policy expires an entry, the LRU pool accounting must shrink accordingly
/// so the freed capacity can be reused.
#[tokio::test]
async fn test_lru_learns_about_ttl_evictions() {
    let TestStateTtlAndLRU {
        mut backend,
        ttl_provider,
        size_estimator,
        time_provider,
        pool,
        ..
    } = TestStateTtlAndLRU::new().await;
    assert_eq!(pool.limit(), TestSize(10));
    ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(1)));
    ttl_provider.set_expires_in(2, String::from("b"), None);
    ttl_provider.set_expires_in(3, String::from("c"), None);
    size_estimator.mock_size(1, String::from("a"), TestSize(4));
    size_estimator.mock_size(2, String::from("b"), TestSize(4));
    size_estimator.mock_size(3, String::from("c"), TestSize(4));
    backend.set(1, String::from("a"));
    backend.set(2, String::from("b"));
    assert_eq!(pool.current(), TestSize(8));
    // evict
    time_provider.inc(Duration::from_secs(1));
    assert_eq!(backend.get(&1), None);
    // now there's space for 3->"c"
    assert_eq!(pool.current(), TestSize(4));
    backend.set(3, String::from("c"));
    assert_eq!(pool.current(), TestSize(8));
    assert_eq!(backend.get(&1), None);
    assert_eq!(backend.get(&2), Some(String::from("b")));
    assert_eq!(backend.get(&3), Some(String::from("c")));
}
/// Probing an entry via `remove_if` (with a predicate that keeps it) must not refresh its
/// LRU recency — the probed entry is still the least-recently-used one and gets evicted.
#[tokio::test]
async fn test_remove_if_check_does_not_extend_lifetime() {
    let TestStateLruAndRemoveIf {
        mut backend,
        size_estimator,
        time_provider,
        remove_if_handle,
        pool,
        ..
    } = TestStateLruAndRemoveIf::new().await;
    size_estimator.mock_size(1, String::from("a"), TestSize(4));
    size_estimator.mock_size(2, String::from("b"), TestSize(4));
    size_estimator.mock_size(3, String::from("c"), TestSize(4));
    backend.set(1, String::from("a"));
    pool.wait_converged().await;
    time_provider.inc(Duration::from_secs(1));
    backend.set(2, String::from("b"));
    pool.wait_converged().await;
    time_provider.inc(Duration::from_secs(1));
    // Checking remove_if should not count as a "use" of 1
    // for the "least recently used" calculation
    remove_if_handle.remove_if(&1, |_| false);
    backend.set(3, String::from("c"));
    pool.wait_converged().await;
    // adding "c" totals 12 size, but backend has room for only 10
    // so "least recently used" (in this case 1, not 2) should be removed
    assert_eq!(backend.get(&1), None);
    assert!(backend.get(&2).is_some());
}
/// Test setup that integrates the TTL policy with a refresh.
struct TestStateTtlAndRefresh {
    backend: PolicyBackend<u8, String>,
    ttl_provider: Arc<TestTtlProvider>,
    refresh_duration_provider: Arc<TestRefreshDurationProvider>,
    time_provider: Arc<MockProvider>,
    loader: Arc<TestLoader<u8, (), String>>,
    // signaled when the refresh background task becomes idle
    notify_idle: Arc<Notify>,
}
impl TestStateTtlAndRefresh {
    /// Wire a [`PolicyBackend`] with a refresh policy (added first) and a TTL policy.
    fn new() -> Self {
        let refresh_duration_provider = Arc::new(TestRefreshDurationProvider::new());
        let ttl_provider = Arc::new(TestTtlProvider::new());
        let time_provider = Arc::new(MockProvider::new(Time::MIN));
        let metric_registry = metric::Registry::new();
        let loader = Arc::new(TestLoader::default());
        let notify_idle = Arc::new(Notify::new());
        // set up "RNG" that always generates the maximum, so we can test things easier
        let rng_overwrite = StepRng::new(u64::MAX, 0);
        let mut backend = PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _);
        backend.add_policy(RefreshPolicy::new_inner(
            Arc::clone(&time_provider) as _,
            Arc::clone(&refresh_duration_provider) as _,
            Arc::clone(&loader) as _,
            "my_cache",
            &metric_registry,
            Arc::clone(&notify_idle),
            &Handle::current(),
            Some(rng_overwrite),
        ));
        backend.add_policy(TtlPolicy::new(
            Arc::clone(&ttl_provider) as _,
            "my_cache",
            &metric_registry,
        ));
        Self {
            backend,
            refresh_duration_provider,
            ttl_provider,
            time_provider,
            loader,
            notify_idle,
        }
    }
}
/// Test setup that integrates the LRU policy with a refresh.
///
/// NOTE(review): this is a near-duplicate of `TestStateLruAndRefresh` (note the `LRU` vs
/// `Lru` casing); both are used by different tests — consider consolidating.
struct TestStateLRUAndRefresh {
    backend: PolicyBackend<u8, String>,
    size_estimator: Arc<TestSizeEstimator>,
    refresh_duration_provider: Arc<TestRefreshDurationProvider>,
    time_provider: Arc<MockProvider>,
    loader: Arc<TestLoader<u8, (), String>>,
    pool: Arc<ResourcePool<TestSize>>,
    // signaled when the refresh background task becomes idle
    notify_idle: Arc<Notify>,
}
impl TestStateLRUAndRefresh {
    /// Wire a [`PolicyBackend`] with a refresh policy (added first) and an LRU policy
    /// backed by a 10-unit resource pool.
    fn new() -> Self {
        let refresh_duration_provider = Arc::new(TestRefreshDurationProvider::new());
        let size_estimator = Arc::new(TestSizeEstimator::default());
        let time_provider = Arc::new(MockProvider::new(Time::MIN));
        let metric_registry = Arc::new(metric::Registry::new());
        let loader = Arc::new(TestLoader::default());
        let notify_idle = Arc::new(Notify::new());
        // set up "RNG" that always generates the maximum, so we can test things easier
        let rng_overwrite = StepRng::new(u64::MAX, 0);
        let mut backend = PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _);
        backend.add_policy(RefreshPolicy::new_inner(
            Arc::clone(&time_provider) as _,
            Arc::clone(&refresh_duration_provider) as _,
            Arc::clone(&loader) as _,
            "my_cache",
            &metric_registry,
            Arc::clone(&notify_idle),
            &Handle::current(),
            Some(rng_overwrite),
        ));
        let pool = Arc::new(ResourcePool::new(
            "my_pool",
            TestSize(10),
            Arc::clone(&metric_registry),
            &Handle::current(),
        ));
        backend.add_policy(LruPolicy::new(
            Arc::clone(&pool),
            "my_cache",
            Arc::clone(&size_estimator) as _,
        ));
        Self {
            backend,
            refresh_duration_provider,
            size_estimator,
            time_provider,
            loader,
            pool,
            notify_idle,
        }
    }
}
/// Test setup that integrates the TTL policy with LRU.
struct TestStateTtlAndLRU {
    backend: PolicyBackend<u8, String>,
    ttl_provider: Arc<TestTtlProvider>,
    time_provider: Arc<MockProvider>,
    size_estimator: Arc<TestSizeEstimator>,
    pool: Arc<ResourcePool<TestSize>>,
}
impl TestStateTtlAndLRU {
    /// Wire a [`PolicyBackend`] with a TTL policy (added first) and an LRU policy backed
    /// by a 10-unit resource pool.
    ///
    /// NOTE(review): declared `async` although no `.await` appears in the body —
    /// presumably to guarantee a Tokio runtime context for `Handle::current()`; confirm.
    async fn new() -> Self {
        let ttl_provider = Arc::new(TestTtlProvider::new());
        let time_provider = Arc::new(MockProvider::new(Time::MIN));
        let metric_registry = Arc::new(metric::Registry::new());
        let size_estimator = Arc::new(TestSizeEstimator::default());
        let mut backend = PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _);
        backend.add_policy(TtlPolicy::new(
            Arc::clone(&ttl_provider) as _,
            "my_cache",
            &metric_registry,
        ));
        let pool = Arc::new(ResourcePool::new(
            "my_pool",
            TestSize(10),
            Arc::clone(&metric_registry),
            &Handle::current(),
        ));
        backend.add_policy(LruPolicy::new(
            Arc::clone(&pool),
            "my_cache",
            Arc::clone(&size_estimator) as _,
        ));
        Self {
            backend,
            ttl_provider,
            time_provider,
            size_estimator,
            pool,
        }
    }
}
/// Test setup that integrates the LRU policy with RemoveIf and max size of 10
struct TestStateLruAndRemoveIf {
    /// Backend under test, with the LRU and RemoveIf policies attached.
    backend: PolicyBackend<u8, String>,
    /// Controllable clock shared with the backend.
    time_provider: Arc<MockProvider>,
    /// Mockable per-entry size estimates consumed by the LRU policy.
    size_estimator: Arc<TestSizeEstimator>,
    /// Handle for explicitly evicting entries via a predicate.
    remove_if_handle: RemoveIfHandle<u8, String>,
    /// Resource pool (capacity `TestSize(10)`) backing the LRU policy.
    pool: Arc<ResourcePool<TestSize>>,
}
impl TestStateLruAndRemoveIf {
    /// Wire the LRU policy (registered first) and the RemoveIf policy
    /// (registered second) into a single hashmap-backed `PolicyBackend`.
    async fn new() -> Self {
        let clock = Arc::new(MockProvider::new(Time::MIN));
        let registry = Arc::new(metric::Registry::new());
        let sizes = Arc::new(TestSizeEstimator::default());

        let mut backend = PolicyBackend::hashmap_backed(Arc::clone(&clock) as _);

        // LRU policy with a maximum pool size of 10.
        let resource_pool = Arc::new(ResourcePool::new(
            "my_pool",
            TestSize(10),
            Arc::clone(&registry),
            &Handle::current(),
        ));
        backend.add_policy(LruPolicy::new(
            Arc::clone(&resource_pool),
            "my_cache",
            Arc::clone(&sizes) as _,
        ));

        // RemoveIf policy; keep the handle so tests can evict entries explicitly.
        let (constructor, remove_if_handle) =
            RemoveIfPolicy::create_constructor_and_handle("my_cache", &registry);
        backend.add_policy(constructor);

        Self {
            backend,
            time_provider: clock,
            size_estimator: sizes,
            remove_if_handle,
            pool: resource_pool,
        }
    }
}
/// Test setup that integrates the LRU policy with a refresh.
struct TestStateLruAndRefresh {
    /// Backend under test, with the refresh and LRU policies attached.
    backend: PolicyBackend<u8, String>,
    /// Mockable per-entry size estimates consumed by the LRU policy.
    size_estimator: Arc<TestSizeEstimator>,
    /// Mockable per-entry refresh durations.
    refresh_duration_provider: Arc<TestRefreshDurationProvider>,
    /// Controllable clock shared with the backend.
    time_provider: Arc<MockProvider>,
    /// Mockable loader used by the refresh policy.
    loader: Arc<TestLoader<u8, (), String>>,
    // Notify handle wired into the refresh policy; presumably signals when the
    // refresh background task goes idle — see `RefreshPolicy::new_inner`.
    notify_idle: Arc<Notify>,
    /// Resource pool (capacity `TestSize(10)`) backing the LRU policy.
    pool: Arc<ResourcePool<TestSize>>,
}
impl TestStateLruAndRefresh {
    /// Wire the refresh policy (registered first) and the LRU policy
    /// (registered second) into a single hashmap-backed `PolicyBackend`.
    fn new() -> Self {
        // Providers and instrumentation shared by both policies.
        let refresh = Arc::new(TestRefreshDurationProvider::new());
        let sizes = Arc::new(TestSizeEstimator::default());
        let clock = Arc::new(MockProvider::new(Time::MIN));
        let registry = Arc::new(metric::Registry::new());
        let test_loader = Arc::new(TestLoader::default());
        let idle = Arc::new(Notify::new());

        // "RNG" that always yields the maximum so refresh timing is deterministic.
        let rng = StepRng::new(u64::MAX, 0);

        let mut backend = PolicyBackend::hashmap_backed(Arc::clone(&clock) as _);

        // NOTE: policy registration order matters — refresh before LRU.
        backend.add_policy(RefreshPolicy::new_inner(
            Arc::clone(&clock) as _,
            Arc::clone(&refresh) as _,
            Arc::clone(&test_loader) as _,
            "my_cache",
            &registry,
            Arc::clone(&idle),
            &Handle::current(),
            Some(rng),
        ));

        let resource_pool = Arc::new(ResourcePool::new(
            "my_pool",
            TestSize(10),
            Arc::clone(&registry),
            &Handle::current(),
        ));
        backend.add_policy(LruPolicy::new(
            Arc::clone(&resource_pool),
            "my_cache",
            Arc::clone(&sizes) as _,
        ));

        Self {
            backend,
            refresh_duration_provider: refresh,
            size_estimator: sizes,
            time_provider: clock,
            loader: test_loader,
            notify_idle: idle,
            pool: resource_pool,
        }
    }
}
/// Resource estimator for testing that returns pre-mocked sizes.
#[derive(Debug, Default)]
struct TestSizeEstimator {
    // Mocked sizes keyed by (key, value) pair; guarded by a mutex so the
    // estimator can be shared via `Arc`. `consumption` panics for unmocked pairs.
    sizes: Mutex<HashMap<(u8, String), TestSize>>,
}
impl TestSizeEstimator {
    /// Register the size that shall be reported for the given key/value pair.
    fn mock_size(&self, k: u8, v: String, s: TestSize) {
        let mut sizes = self.sizes.lock();
        sizes.insert((k, v), s);
    }
}
impl ResourceEstimator for TestSizeEstimator {
    type K = u8;
    type V = String;
    type S = TestSize;

    /// Look up the mocked size; panics if the pair was never mocked.
    fn consumption(&self, k: &Self::K, v: &Self::V) -> Self::S {
        let sizes = self.sizes.lock();
        sizes.get(&(*k, v.clone())).copied().unwrap()
    }
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,288 @@
//! Backend that supports custom removal / expiry of keys
use metric::U64Counter;
use parking_lot::Mutex;
use std::{fmt::Debug, hash::Hash, marker::PhantomData, sync::Arc};
use crate::{
backend::policy::{CacheBackend, CallbackHandle, ChangeRequest, Subscriber},
cache::{Cache, CacheGetStatus},
};
/// Allows explicitly removing entries from the cache.
///
/// The policy type itself is inert; all removals go through the
/// [`RemoveIfHandle`] returned by [`RemoveIfPolicy::create_constructor_and_handle`].
#[derive(Debug, Clone)]
pub struct RemoveIfPolicy<K, V>
where
    K: Clone + Eq + Debug + Hash + Ord + Send + 'static,
    V: Clone + Debug + Send + 'static,
{
    // the policy itself doesn't do anything, the handles will do all the work
    _phantom: PhantomData<(K, V)>,
}
impl<K, V> RemoveIfPolicy<K, V>
where
    K: Clone + Eq + Debug + Hash + Ord + Send + 'static,
    V: Clone + Debug + Send + 'static,
{
    /// Create new policy.
    ///
    /// This returns the policy constructor which shall be passed to
    /// [`PolicyBackend::add_policy`] and a handle that can be used to remove entries.
    ///
    /// Note that as long as the policy constructor is NOT passed to [`PolicyBackend::add_policy`], the operations on
    /// the handle are essentially no-ops (i.e. they will not remove anything).
    ///
    /// [`PolicyBackend::add_policy`]: super::PolicyBackend::add_policy
    pub fn create_constructor_and_handle(
        name: &'static str,
        metric_registry: &metric::Registry,
    ) -> (
        impl FnOnce(CallbackHandle<K, V>) -> Self,
        RemoveIfHandle<K, V>,
    ) {
        // Counter of entries evicted through `RemoveIfHandle::remove_if`.
        let metric_removed_by_predicate = metric_registry
            .register_metric::<U64Counter>(
                "cache_removed_by_custom_condition",
                "Number of entries removed from a cache via a custom condition",
            )
            .recorder(&[("name", name)]);

        // The handle starts unlinked (`None`); it only becomes functional once
        // the constructor below runs and stores the callback handle.
        let handle = RemoveIfHandle {
            callback_handle: Arc::new(Mutex::new(None)),
            metric_removed_by_predicate,
        };
        let handle_captured = handle.clone();

        let policy_constructor = move |callback_handle| {
            // Link the (shared) handle to the backend.
            *handle_captured.callback_handle.lock() = Some(callback_handle);

            Self {
                _phantom: PhantomData,
            }
        };

        (policy_constructor, handle)
    }
}
// The policy subscribes to no backend events — it overrides none of the
// `Subscriber` hooks; all work happens through `RemoveIfHandle`.
impl<K, V> Subscriber for RemoveIfPolicy<K, V>
where
    K: Clone + Eq + Debug + Hash + Ord + Send + 'static,
    V: Clone + Debug + Send + 'static,
{
    type K = K;
    type V = V;
}
/// Handle created by [`RemoveIfPolicy`] that can be used to evict data from caches.
///
/// The handle can be cloned freely. All clones will refer to the same underlying backend.
#[derive(Debug, Clone)]
pub struct RemoveIfHandle<K, V>
where
    K: Clone + Eq + Debug + Hash + Ord + Send + 'static,
    V: Clone + Debug + Send + 'static,
{
    /// Shared link to the backend; `None` until the policy constructor has been
    /// passed to `PolicyBackend::add_policy`, in which case removals are no-ops.
    callback_handle: Arc<Mutex<Option<CallbackHandle<K, V>>>>,
    /// Counts entries removed via [`remove_if`](Self::remove_if).
    metric_removed_by_predicate: U64Counter,
}
impl<K, V> RemoveIfHandle<K, V>
where
    K: Clone + Eq + Debug + Hash + Ord + Send + 'static,
    V: Clone + Debug + Send + 'static,
{
    /// "remove" a key (aka remove it from the shared backend) if the
    /// specified predicate is true. If the key is removed return
    /// true, otherwise return false
    ///
    /// Note that the predicate function is called while the lock is
    /// held (and thus the inner backend can't be concurrently accessed).
    pub fn remove_if<P>(&self, k: &K, predicate: P) -> bool
    where
        P: FnOnce(V) -> bool,
    {
        let mut guard = self.callback_handle.lock();
        let handle = match guard.as_mut() {
            Some(handle) => handle,
            // Not linked to a backend yet (constructor never registered): no-op.
            None => return false,
        };

        let metric_removed_by_predicate = self.metric_removed_by_predicate.clone();
        let mut removed = false;
        let removed_captured = &mut removed;
        let k = k.clone();
        // Run check-and-remove atomically inside the backend to avoid races
        // with concurrent writers. `get_untracked` avoids touching LRU/refresh
        // bookkeeping for this read.
        handle.execute_requests(vec![ChangeRequest::from_fn(move |backend| {
            if let Some(v) = backend.get_untracked(&k) {
                if predicate(v) {
                    metric_removed_by_predicate.inc(1);
                    backend.remove(&k);
                    *removed_captured = true;
                }
            }
        })]);

        removed
    }

    /// Performs [`remove_if`](Self::remove_if) and [`GET`](Cache::get) in one go.
    ///
    /// Ensures that these two actions interact correctly.
    ///
    /// # Forward process
    /// This function only works if cache values evolve in one direction. This means that the predicate can only flip
    /// from `true` to `false` over time (i.e. it detects an outdated value and then an up-to-date value), NOT the
    /// other way around (i.e. data cannot get outdated under the same predicate).
    pub async fn remove_if_and_get_with_status<P, C, GetExtra>(
        &self,
        cache: &C,
        k: K,
        predicate: P,
        extra: GetExtra,
    ) -> (V, CacheGetStatus)
    where
        P: Fn(V) -> bool + Send,
        C: Cache<K = K, V = V, GetExtra = GetExtra>,
        GetExtra: Clone + Send,
    {
        // First evict a potentially outdated value, then (re-)load.
        let mut removed = self.remove_if(&k, &predicate);

        loop {
            // avoid some `Sync` bounds
            let k_for_get = k.clone();
            let extra_for_get = extra.clone();
            let (v, status) = cache.get_with_status(k_for_get, extra_for_get).await;

            match status {
                CacheGetStatus::Hit => {
                    // key existed and no other process loaded it => safe to use
                    return (v, status);
                }
                CacheGetStatus::Miss => {
                    // key didn't exist and we loaded it => safe to use
                    return (v, status);
                }
                CacheGetStatus::MissAlreadyLoading => {
                    if removed {
                        // key was outdated but there was some loading in process, this may have overlapped with our check
                        // so our check might have been incomplete => need to re-check
                        removed = self.remove_if(&k, &predicate);
                        if removed {
                            // removed again, so cannot use our result
                            continue;
                        } else {
                            // didn't remove => safe to use
                            return (v, status);
                        }
                    } else {
                        // there was a load action in process but the key was already up-to-date, so it's OK to use the new
                        // data as well (forward process)
                        return (v, status);
                    }
                }
            }
        }
    }

    /// Same as [`remove_if_and_get_with_status`](Self::remove_if_and_get_with_status) but without the status.
    pub async fn remove_if_and_get<P, C, GetExtra>(
        &self,
        cache: &C,
        k: K,
        predicate: P,
        extra: GetExtra,
    ) -> V
    where
        P: Fn(V) -> bool + Send,
        C: Cache<K = K, V = V, GetExtra = GetExtra>,
        GetExtra: Clone + Send,
    {
        self.remove_if_and_get_with_status(cache, k, predicate, extra)
            .await
            .0
    }
}
#[cfg(test)]
mod tests {
    use iox_time::{MockProvider, Time};
    use metric::{Observation, RawReporter};

    use crate::backend::{policy::PolicyBackend, CacheBackend};

    use super::*;

    #[test]
    fn test_generic_backend() {
        use crate::backend::test_util::test_generic;

        test_generic(|| {
            let metric_registry = metric::Registry::new();
            let time_provider = Arc::new(MockProvider::new(Time::MIN));
            let mut backend = PolicyBackend::hashmap_backed(time_provider);
            let (policy_constructor, _handle) =
                RemoveIfPolicy::create_constructor_and_handle("my_cache", &metric_registry);
            backend.add_policy(policy_constructor);
            backend
        });
    }

    #[test]
    fn test_remove_if() {
        let metric_registry = metric::Registry::new();
        let time_provider = Arc::new(MockProvider::new(Time::MIN));
        let mut backend: PolicyBackend<u8, String> = PolicyBackend::hashmap_backed(time_provider);
        let (policy_constructor, handle) =
            RemoveIfPolicy::create_constructor_and_handle("my_cache", &metric_registry);
        backend.add_policy(policy_constructor);
        backend.set(1, "foo".into());
        backend.set(2, "bar".into());

        assert_eq!(get_removed_metric(&metric_registry), 0);

        // predicate does not match -> nothing removed, metric unchanged
        assert!(!handle.remove_if(&1, |v| v == "zzz"));
        assert_eq!(backend.get(&1), Some("foo".into()));
        assert_eq!(backend.get(&2), Some("bar".into()));
        assert_eq!(get_removed_metric(&metric_registry), 0);

        // predicate matches -> entry 1 removed, metric bumped
        assert!(handle.remove_if(&1, |v| v == "foo"));
        assert_eq!(backend.get(&1), None);
        assert_eq!(backend.get(&2), Some("bar".into()));
        assert_eq!(get_removed_metric(&metric_registry), 1);

        // key already gone -> no-op
        assert!(!handle.remove_if(&1, |v| v == "bar"));
        assert_eq!(backend.get(&1), None);
        assert_eq!(backend.get(&2), Some("bar".into()));
        assert_eq!(get_removed_metric(&metric_registry), 1);
    }

    #[test]
    fn test_not_linked() {
        let metric_registry = metric::Registry::new();
        let (_policy_constructor, handle) =
            RemoveIfPolicy::<u8, String>::create_constructor_and_handle(
                "my_cache",
                &metric_registry,
            );

        assert_eq!(get_removed_metric(&metric_registry), 0);

        // constructor never registered with a backend -> handle is a no-op
        assert!(!handle.remove_if(&1, |v| v == "zzz"));
        assert_eq!(get_removed_metric(&metric_registry), 0);
    }

    /// Read the "removed by custom condition" counter for `my_cache`.
    fn get_removed_metric(metric_registry: &metric::Registry) -> u64 {
        let mut reporter = RawReporter::default();
        metric_registry.report(&mut reporter);
        let observation = reporter
            .metric("cache_removed_by_custom_condition")
            .unwrap()
            .observation(&[("name", "my_cache")])
            .unwrap();

        if let Observation::U64Counter(c) = observation {
            *c
        } else {
            panic!("Wrong observation type")
        }
    }
}

View File

@ -0,0 +1,755 @@
//! Time-to-live handling.
use std::{fmt::Debug, hash::Hash, marker::PhantomData, sync::Arc, time::Duration};
use iox_time::Time;
use metric::U64Counter;
use crate::addressable_heap::AddressableHeap;
use super::{CallbackHandle, ChangeRequest, Subscriber};
/// Interface to provide TTL (time to live) data for a key-value pair.
pub trait TtlProvider: std::fmt::Debug + Send + Sync + 'static {
    /// Cache key.
    type K;

    /// Cached value.
    type V;

    /// When should the given key-value pair expire?
    ///
    /// Return `None` for "never".
    ///
    /// The function is only called once for a newly cached key-value pair. This means:
    /// - There is no need to remember the time of a given pair (e.g. you can safely always return a constant).
    /// - You cannot change the TTL after the data was cached.
    ///
    /// Expiration is set to take place AT OR AFTER the provided duration.
    fn expires_in(&self, k: &Self::K, v: &Self::V) -> Option<Duration>;
}
/// [`TtlProvider`] that never expires.
///
/// Useful when a TTL policy is required structurally but entries shall live forever.
#[derive(Default)]
pub struct NeverTtlProvider<K, V>
where
    K: 'static,
    V: 'static,
{
    // phantom data that is Send and Sync, see https://stackoverflow.com/a/50201389
    _k: PhantomData<fn() -> K>,
    _v: PhantomData<fn() -> V>,
}
impl<K, V> std::fmt::Debug for NeverTtlProvider<K, V> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The phantom markers carry no data; render only the type name.
        let mut builder = f.debug_struct("NeverTtlProvider");
        builder.finish_non_exhaustive()
    }
}
impl<K, V> TtlProvider for NeverTtlProvider<K, V> {
    type K = K;
    type V = V;

    /// Always `None`: entries cached with this provider never expire.
    fn expires_in(&self, _k: &Self::K, _v: &Self::V) -> Option<Duration> {
        None
    }
}
/// [`TtlProvider`] that returns a constant value.
pub struct ConstantValueTtlProvider<K, V>
where
    K: 'static,
    V: 'static,
{
    // phantom data that is Send and Sync, see https://stackoverflow.com/a/50201389
    _k: PhantomData<fn() -> K>,
    _v: PhantomData<fn() -> V>,

    /// TTL returned for every key-value pair; `None` means "never expire".
    ttl: Option<Duration>,
}
impl<K, V> ConstantValueTtlProvider<K, V>
where
    K: 'static,
    V: 'static,
{
    /// Create new provider with the given TTL value.
    ///
    /// `None` means "never expire".
    pub fn new(ttl: Option<Duration>) -> Self {
        Self {
            ttl,
            _k: PhantomData,
            _v: PhantomData,
        }
    }
}
impl<K, V> std::fmt::Debug for ConstantValueTtlProvider<K, V> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the TTL carries information; the phantom markers are elided.
        let mut builder = f.debug_struct("ConstantValueTtlProvider");
        builder.field("ttl", &self.ttl);
        builder.finish_non_exhaustive()
    }
}
impl<K, V> TtlProvider for ConstantValueTtlProvider<K, V> {
    type K = K;
    type V = V;

    /// Returns the constant TTL configured at construction, regardless of key/value.
    fn expires_in(&self, _k: &Self::K, _v: &Self::V) -> Option<Duration> {
        self.ttl
    }
}
/// [`TtlProvider`] that returns different values for `None`/`Some(...)` values.
pub struct OptionalValueTtlProvider<K, V>
where
    K: 'static,
    V: 'static,
{
    // phantom data that is Send and Sync, see https://stackoverflow.com/a/50201389
    _k: PhantomData<fn() -> K>,
    _v: PhantomData<fn() -> V>,

    /// TTL used when the cached value is `None`.
    ttl_none: Option<Duration>,
    /// TTL used when the cached value is `Some(...)`.
    ttl_some: Option<Duration>,
}
impl<K, V> OptionalValueTtlProvider<K, V>
where
    K: 'static,
    V: 'static,
{
    /// Create new provider with the given TTL values for `None` and `Some(...)`.
    pub fn new(ttl_none: Option<Duration>, ttl_some: Option<Duration>) -> Self {
        Self {
            ttl_none,
            ttl_some,
            _k: PhantomData,
            _v: PhantomData,
        }
    }
}
impl<K, V> std::fmt::Debug for OptionalValueTtlProvider<K, V> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show both TTL variants; the phantom markers are elided.
        let mut builder = f.debug_struct("OptionalValueTtlProvider");
        builder.field("ttl_none", &self.ttl_none);
        builder.field("ttl_some", &self.ttl_some);
        builder.finish_non_exhaustive()
    }
}
impl<K, V> TtlProvider for OptionalValueTtlProvider<K, V> {
    type K = K;
    type V = Option<V>;

    /// Pick the TTL based on whether the cached value is present.
    fn expires_in(&self, _k: &Self::K, v: &Self::V) -> Option<Duration> {
        if v.is_some() {
            self.ttl_some
        } else {
            self.ttl_none
        }
    }
}
/// Cache policy that implements Time To Live.
///
/// # Cache Eviction
/// Every method ([`get`](Subscriber::get), [`set`](Subscriber::set), [`remove`](Subscriber::remove)) causes the
/// cache to check for expired keys. This may lead to certain delays, esp. when dropping the contained values takes a
/// long time.
#[derive(Debug)]
pub struct TtlPolicy<K, V>
where
    K: Clone + Eq + Debug + Hash + Ord + Send + 'static,
    V: Clone + Debug + Send + 'static,
{
    /// Supplies the per-entry TTL when an entry is set.
    ttl_provider: Arc<dyn TtlProvider<K = K, V = V>>,
    /// Expiration deadline per key; `evict_expired` peeks/pops the earliest entry.
    expiration: AddressableHeap<K, (), Time>,
    /// Counts entries evicted because their TTL elapsed.
    metric_expired: U64Counter,
}
impl<K, V> TtlPolicy<K, V>
where
    K: Clone + Eq + Debug + Hash + Ord + Send + 'static,
    V: Clone + Debug + Send + 'static,
{
    /// Create new TTL policy.
    ///
    /// Returns a constructor that is handed to `PolicyBackend::add_policy`. The
    /// constructor insists on an empty inner backend, because pre-existing
    /// entries would never have had a TTL recorded.
    pub fn new(
        ttl_provider: Arc<dyn TtlProvider<K = K, V = V>>,
        name: &'static str,
        metric_registry: &metric::Registry,
    ) -> impl FnOnce(CallbackHandle<K, V>) -> Self {
        let metric_expired = metric_registry
            .register_metric::<U64Counter>(
                "cache_ttl_expired",
                "Number of entries that expired via TTL.",
            )
            .recorder(&[("name", name)]);

        move |mut callback_handle| {
            // Fail fast if data was cached before this policy was attached.
            callback_handle.execute_requests(vec![ChangeRequest::ensure_empty()]);

            Self {
                ttl_provider,
                expiration: Default::default(),
                metric_expired,
            }
        }
    }

    /// Collect removal requests for every entry whose deadline is at or before `now`.
    fn evict_expired(&mut self, now: Time) -> Vec<ChangeRequest<'static, K, V>> {
        let mut requests = Vec::new();

        // Deadlines are popped in order; stop at the first one still in the future.
        while let Some(true) = self.expiration.peek().map(|(_k, _, t)| *t <= now) {
            let (k, _, _t) = self.expiration.pop().unwrap();
            self.metric_expired.inc(1);
            requests.push(ChangeRequest::remove(k));
        }

        requests
    }
}
impl<K, V> Subscriber for TtlPolicy<K, V>
where
    K: Clone + Eq + Debug + Hash + Ord + Send + 'static,
    V: Clone + Debug + Send + 'static,
{
    type K = K;
    type V = V;

    /// GET only triggers eviction of entries that are already expired.
    fn get(&mut self, _k: &Self::K, now: Time) -> Vec<ChangeRequest<'static, Self::K, Self::V>> {
        self.evict_expired(now)
    }

    /// SET evicts expired entries and records the new entry's deadline.
    fn set(
        &mut self,
        k: &Self::K,
        v: &Self::V,
        now: Time,
    ) -> Vec<ChangeRequest<'static, Self::K, Self::V>> {
        let mut requests = self.evict_expired(now);

        if let Some(ttl) = self.ttl_provider.expires_in(k, v) {
            if ttl.is_zero() {
                // Zero TTL: queue an immediate removal of the freshly set key.
                // (The deadline is still inserted below; the removal request will
                // clear it via `remove`.)
                requests.push(ChangeRequest::remove(k.clone()));
            }

            match now.checked_add(ttl) {
                Some(t) => {
                    self.expiration.insert(k.clone(), (), t);
                }
                None => {
                    // `now + ttl` overflows the time axis => effectively "never".
                    // Still need to ensure that any current expiration is disabled
                    self.expiration.remove(k);
                }
            }
        } else {
            // Still need to ensure that any current expiration is disabled
            self.expiration.remove(k);
        };

        requests
    }

    /// REMOVE drops the key's deadline and evicts other expired entries.
    fn remove(&mut self, k: &Self::K, now: Time) -> Vec<ChangeRequest<'static, Self::K, Self::V>> {
        self.expiration.remove(k);
        self.evict_expired(now)
    }
}
pub mod test_util {
    //! Test utils for TTL policy.
    use std::collections::HashMap;

    use parking_lot::Mutex;

    use super::*;

    /// [`TtlProvider`] for testing.
    #[derive(Debug, Default)]
    pub struct TestTtlProvider {
        // Mocked TTLs keyed by (key, value) pair; `expires_in` panics for
        // pairs that were never mocked.
        expires_in: Mutex<HashMap<(u8, String), Option<Duration>>>,
    }

    impl TestTtlProvider {
        /// Create new, empty provider.
        pub fn new() -> Self {
            Self::default()
        }

        /// Set TTL time for given key-value pair.
        ///
        /// Overwrites any previously mocked value for the same pair.
        pub fn set_expires_in(&self, k: u8, v: String, d: Option<Duration>) {
            self.expires_in.lock().insert((k, v), d);
        }
    }

    impl TtlProvider for TestTtlProvider {
        type K = u8;
        type V = String;

        fn expires_in(&self, k: &Self::K, v: &Self::V) -> Option<Duration> {
            *self
                .expires_in
                .lock()
                .get(&(*k, v.clone()))
                .expect("expires_in value not mocked")
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        #[should_panic(expected = "expires_in value not mocked")]
        fn test_panic_value_not_mocked() {
            TestTtlProvider::new().expires_in(&1, &String::from("foo"));
        }

        #[test]
        fn test_mocking() {
            let provider = TestTtlProvider::default();

            provider.set_expires_in(1, String::from("a"), None);
            provider.set_expires_in(1, String::from("b"), Some(Duration::from_secs(1)));
            provider.set_expires_in(2, String::from("a"), Some(Duration::from_secs(2)));

            assert_eq!(provider.expires_in(&1, &String::from("a")), None,);
            assert_eq!(
                provider.expires_in(&1, &String::from("b")),
                Some(Duration::from_secs(1)),
            );
            assert_eq!(
                provider.expires_in(&2, &String::from("a")),
                Some(Duration::from_secs(2)),
            );

            // replace
            provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(3)));
            assert_eq!(
                provider.expires_in(&1, &String::from("a")),
                Some(Duration::from_secs(3)),
            );
        }
    }
}
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, time::Duration};

    use iox_time::MockProvider;
    use metric::{Observation, RawReporter};

    use crate::backend::{policy::PolicyBackend, CacheBackend};

    use super::{test_util::TestTtlProvider, *};

    #[test]
    fn test_never_ttl_provider() {
        let provider = NeverTtlProvider::<u8, i8>::default();
        assert_eq!(provider.expires_in(&1, &2), None);
    }

    #[test]
    fn test_constant_value_ttl_provider() {
        let ttl = Some(Duration::from_secs(1));
        let provider = ConstantValueTtlProvider::<u8, i8>::new(ttl);
        assert_eq!(provider.expires_in(&1, &2), ttl);
    }

    #[test]
    fn test_optional_value_ttl_provider() {
        let ttl_none = Some(Duration::from_secs(1));
        let ttl_some = Some(Duration::from_secs(2));
        let provider = OptionalValueTtlProvider::<u8, i8>::new(ttl_none, ttl_some);
        assert_eq!(provider.expires_in(&1, &None), ttl_none);
        assert_eq!(provider.expires_in(&1, &Some(2)), ttl_some);
    }

    #[test]
    #[should_panic(expected = "inner backend is not empty")]
    fn test_panic_inner_not_empty() {
        // The TTL policy constructor requires an empty inner backend.
        let ttl_provider = Arc::new(TestTtlProvider::new());
        let metric_registry = metric::Registry::new();
        let time_provider = Arc::new(MockProvider::new(Time::MIN));
        let mut backend: PolicyBackend<u8, String> = PolicyBackend::hashmap_backed(time_provider);
        let policy_constructor =
            TtlPolicy::new(Arc::clone(&ttl_provider) as _, "my_cache", &metric_registry);
        backend.add_policy(|mut handle| {
            handle.execute_requests(vec![ChangeRequest::set(1, String::from("foo"))]);
            policy_constructor(handle)
        });
    }

    #[test]
    fn test_expires_single() {
        let TestState {
            mut backend,
            metric_registry,
            ttl_provider,
            time_provider,
        } = TestState::new();

        assert_eq!(get_expired_metric(&metric_registry), 0);

        ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(1)));
        backend.set(1, String::from("a"));
        assert_eq!(backend.get(&1), Some(String::from("a")));

        assert_eq!(get_expired_metric(&metric_registry), 0);

        // deadline is "at or after" the TTL, so advancing by exactly 1s expires
        time_provider.inc(Duration::from_secs(1));
        assert_eq!(backend.get(&1), None);

        assert_eq!(get_expired_metric(&metric_registry), 1);
    }

    #[test]
    fn test_overflow_expire() {
        let ttl_provider = Arc::new(TestTtlProvider::new());
        let metric_registry = metric::Registry::new();

        // init time provider at MAX!
        let time_provider = Arc::new(MockProvider::new(Time::MAX));
        let mut backend: PolicyBackend<u8, String> = PolicyBackend::hashmap_backed(time_provider);
        backend.add_policy(TtlPolicy::new(
            Arc::clone(&ttl_provider) as _,
            "my_cache",
            &metric_registry,
        ));

        // `Time::MAX + Duration::MAX` overflows => entry must never expire
        ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::MAX));
        backend.set(1, String::from("a"));
        assert_eq!(backend.get(&1), Some(String::from("a")));
    }

    #[test]
    fn test_never_expire() {
        let TestState {
            mut backend,
            ttl_provider,
            time_provider,
            ..
        } = TestState::new();

        ttl_provider.set_expires_in(1, String::from("a"), None);
        backend.set(1, String::from("a"));
        assert_eq!(backend.get(&1), Some(String::from("a")));

        time_provider.inc(Duration::from_secs(1));
        assert_eq!(backend.get(&1), Some(String::from("a")));
    }

    #[test]
    fn test_expiration_uses_key_and_value() {
        // the TTL for (1, "b") applies, not the one for (1, "a")
        let TestState {
            mut backend,
            ttl_provider,
            time_provider,
            ..
        } = TestState::new();

        ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(1)));
        ttl_provider.set_expires_in(1, String::from("b"), Some(Duration::from_secs(4)));
        ttl_provider.set_expires_in(2, String::from("a"), Some(Duration::from_secs(2)));
        backend.set(1, String::from("b"));

        time_provider.inc(Duration::from_secs(3));
        assert_eq!(backend.get(&1), Some(String::from("b")));
    }

    #[test]
    fn test_override_with_different_expiration() {
        let TestState {
            mut backend,
            ttl_provider,
            time_provider,
            ..
        } = TestState::new();

        ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(1)));
        backend.set(1, String::from("a"));
        assert_eq!(backend.get(&1), Some(String::from("a")));

        // re-SET with a longer TTL replaces the old deadline
        ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(3)));
        backend.set(1, String::from("a"));

        time_provider.inc(Duration::from_secs(2));
        assert_eq!(backend.get(&1), Some(String::from("a")));
    }

    #[test]
    fn test_override_with_no_expiration() {
        let TestState {
            mut backend,
            ttl_provider,
            time_provider,
            ..
        } = TestState::new();

        ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(1)));
        backend.set(1, String::from("a"));
        assert_eq!(backend.get(&1), Some(String::from("a")));

        // re-SET with `None` disables the previously recorded deadline
        ttl_provider.set_expires_in(1, String::from("a"), None);
        backend.set(1, String::from("a"));

        time_provider.inc(Duration::from_secs(2));
        assert_eq!(backend.get(&1), Some(String::from("a")));
    }

    #[test]
    fn test_override_with_some_expiration() {
        let TestState {
            mut backend,
            ttl_provider,
            time_provider,
            ..
        } = TestState::new();

        ttl_provider.set_expires_in(1, String::from("a"), None);
        backend.set(1, String::from("a"));
        assert_eq!(backend.get(&1), Some(String::from("a")));

        // re-SET with a finite TTL enables expiration for a formerly eternal entry
        ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(1)));
        backend.set(1, String::from("a"));

        time_provider.inc(Duration::from_secs(2));
        assert_eq!(backend.get(&1), None);
    }

    #[test]
    fn test_override_with_overflow() {
        let ttl_provider = Arc::new(TestTtlProvider::new());
        let metric_registry = metric::Registry::new();

        // init time provider at nearly MAX!
        let time_provider = Arc::new(MockProvider::new(Time::MAX - Duration::from_secs(2)));
        let mut backend: PolicyBackend<u8, String> =
            PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _);
        backend.add_policy(TtlPolicy::new(
            Arc::clone(&ttl_provider) as _,
            "my_cache",
            &metric_registry,
        ));

        ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(1)));
        backend.set(1, String::from("a"));
        assert_eq!(backend.get(&1), Some(String::from("a")));

        // re-SET with an overflowing TTL disables the old 1s deadline
        ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(u64::MAX)));
        backend.set(1, String::from("a"));

        time_provider.inc(Duration::from_secs(2));
        assert_eq!(backend.get(&1), Some(String::from("a")));
    }

    #[test]
    fn test_readd_with_different_expiration() {
        let TestState {
            mut backend,
            ttl_provider,
            time_provider,
            ..
        } = TestState::new();

        ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(1)));
        backend.set(1, String::from("a"));
        assert_eq!(backend.get(&1), Some(String::from("a")));

        // explicit remove + re-SET picks up the new TTL
        ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(3)));
        backend.remove(&1);
        backend.set(1, String::from("a"));

        time_provider.inc(Duration::from_secs(2));
        assert_eq!(backend.get(&1), Some(String::from("a")));
    }

    #[test]
    fn test_readd_with_no_expiration() {
        let TestState {
            mut backend,
            ttl_provider,
            time_provider,
            ..
        } = TestState::new();

        ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(1)));
        backend.set(1, String::from("a"));
        assert_eq!(backend.get(&1), Some(String::from("a")));

        // explicit remove + re-SET with `None` => entry never expires
        ttl_provider.set_expires_in(1, String::from("a"), None);
        backend.remove(&1);
        backend.set(1, String::from("a"));

        time_provider.inc(Duration::from_secs(2));
        assert_eq!(backend.get(&1), Some(String::from("a")));
    }

    #[test]
    fn test_update_cleans_multiple_keys() {
        let TestState {
            mut backend,
            ttl_provider,
            time_provider,
            ..
        } = TestState::new();

        ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(1)));
        ttl_provider.set_expires_in(2, String::from("b"), Some(Duration::from_secs(2)));
        ttl_provider.set_expires_in(3, String::from("c"), Some(Duration::from_secs(2)));
        ttl_provider.set_expires_in(4, String::from("d"), Some(Duration::from_secs(3)));
        backend.set(1, String::from("a"));
        backend.set(2, String::from("b"));
        backend.set(3, String::from("c"));
        backend.set(4, String::from("d"));
        assert_eq!(backend.get(&1), Some(String::from("a")));
        assert_eq!(backend.get(&2), Some(String::from("b")));
        assert_eq!(backend.get(&3), Some(String::from("c")));
        assert_eq!(backend.get(&4), Some(String::from("d")));

        time_provider.inc(Duration::from_secs(2));

        // a single GET evicts ALL expired keys (1, 2 and 3), not just the requested one
        assert_eq!(backend.get(&1), None);

        {
            let inner_ref = backend.inner_ref();
            let inner_backend = inner_ref
                .as_any()
                .downcast_ref::<HashMap<u8, String>>()
                .unwrap();
            assert!(!inner_backend.contains_key(&1));
            assert!(!inner_backend.contains_key(&2));
            assert!(!inner_backend.contains_key(&3));
            assert!(inner_backend.contains_key(&4));
        }

        assert_eq!(backend.get(&2), None);
        assert_eq!(backend.get(&3), None);
        assert_eq!(backend.get(&4), Some(String::from("d")));
    }

    #[test]
    fn test_remove_expired_key() {
        let TestState {
            mut backend,
            ttl_provider,
            time_provider,
            ..
        } = TestState::new();

        ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(1)));
        backend.set(1, String::from("a"));
        assert_eq!(backend.get(&1), Some(String::from("a")));

        time_provider.inc(Duration::from_secs(1));
        backend.remove(&1);
        assert_eq!(backend.get(&1), None);
    }

    #[test]
    fn test_expire_removed_key() {
        let TestState {
            mut backend,
            ttl_provider,
            time_provider,
            ..
        } = TestState::new();

        ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(1)));
        ttl_provider.set_expires_in(2, String::from("b"), Some(Duration::from_secs(2)));
        backend.set(1, String::from("a"));
        backend.remove(&1);

        // key 1 was removed before its deadline => later SETs/GETs must not panic
        time_provider.inc(Duration::from_secs(1));
        backend.set(2, String::from("b"));
        assert_eq!(backend.get(&1), None);
        assert_eq!(backend.get(&2), Some(String::from("b")));
    }

    #[test]
    fn test_expire_immediately() {
        let TestState {
            mut backend,
            ttl_provider,
            ..
        } = TestState::new();

        // zero TTL => entry is removed as part of the SET itself
        ttl_provider.set_expires_in(1, String::from("a"), Some(Duration::from_secs(0)));
        backend.set(1, String::from("a"));

        assert!(backend.is_empty());

        assert_eq!(backend.get(&1), None);
    }

    #[test]
    fn test_generic_backend() {
        use crate::backend::test_util::test_generic;

        test_generic(|| {
            let ttl_provider = Arc::new(NeverTtlProvider::default());
            let time_provider = Arc::new(MockProvider::new(Time::MIN));
            let metric_registry = metric::Registry::new();
            let mut backend = PolicyBackend::hashmap_backed(time_provider);
            backend.add_policy(TtlPolicy::new(
                Arc::clone(&ttl_provider) as _,
                "my_cache",
                &metric_registry,
            ));
            backend
        });
    }

    /// Common test fixture: hashmap-backed `PolicyBackend` with a TTL policy,
    /// a mockable TTL provider and a controllable clock.
    struct TestState {
        backend: PolicyBackend<u8, String>,
        metric_registry: metric::Registry,
        ttl_provider: Arc<TestTtlProvider>,
        time_provider: Arc<MockProvider>,
    }

    impl TestState {
        fn new() -> Self {
            let ttl_provider = Arc::new(TestTtlProvider::new());
            let time_provider = Arc::new(MockProvider::new(Time::MIN));
            let metric_registry = metric::Registry::new();

            let mut backend = PolicyBackend::hashmap_backed(Arc::clone(&time_provider) as _);
            backend.add_policy(TtlPolicy::new(
                Arc::clone(&ttl_provider) as _,
                "my_cache",
                &metric_registry,
            ));

            Self {
                backend,
                metric_registry,
                ttl_provider,
                time_provider,
            }
        }
    }

    /// Read the "expired via TTL" counter for `my_cache`.
    fn get_expired_metric(metric_registry: &metric::Registry) -> u64 {
        let mut reporter = RawReporter::default();
        metric_registry.report(&mut reporter);
        let observation = reporter
            .metric("cache_ttl_expired")
            .unwrap()
            .observation(&[("name", "my_cache")])
            .unwrap();

        if let Observation::U64Counter(c) = observation {
            *c
        } else {
            panic!("Wrong observation type")
        }
    }
}

View File

@ -0,0 +1,112 @@
use super::CacheBackend;
/// Generic test set for [`CacheBackend`].
///
/// The backend must NOT perform any pruning/deletions during the tests (even though backends are allowed to do that in
/// general).
pub fn test_generic<B, F>(constructor: F)
where
    B: CacheBackend<K = u8, V = String>,
    F: Fn() -> B,
{
    // Each sub-test gets its own fresh backend instance.
    test_get_empty(constructor());
    test_get_set(constructor());
    test_get_twice(constructor());
    test_override(constructor());
    test_set_remove_get(constructor());
    test_remove_empty(constructor());
    test_readd(constructor());
    test_is_empty(constructor());
}
/// Test GET on empty backend: an unknown key yields `None`.
fn test_get_empty<B>(mut backend: B)
where
    B: CacheBackend<K = u8, V = String>,
{
    assert_eq!(backend.get(&1), None);
}
/// Test GET and SET without any overrides.
fn test_get_set<B>(mut backend: B)
where
    B: CacheBackend<K = u8, V = String>,
{
    // Populate two distinct keys.
    backend.set(1, "a".to_owned());
    backend.set(2, "b".to_owned());

    // Both are retrievable; an unknown key yields nothing.
    assert_eq!(backend.get(&1), Some("a".to_owned()));
    assert_eq!(backend.get(&2), Some("b".to_owned()));
    assert_eq!(backend.get(&3), None);
}
/// Test that a value can be retrieved multiple times (GET does not consume the entry).
fn test_get_twice<B>(mut backend: B)
where
    B: CacheBackend<K = u8, V = String>,
{
    backend.set(1, String::from("a"));
    assert_eq!(backend.get(&1), Some(String::from("a")));
    assert_eq!(backend.get(&1), Some(String::from("a")));
}
/// Test that setting a value twice w/o deletion overrides the existing value.
fn test_override<B>(mut backend: B)
where
    B: CacheBackend<K = u8, V = String>,
{
    backend.set(1, String::from("a"));
    backend.set(1, String::from("b"));
    assert_eq!(backend.get(&1), Some(String::from("b")));
}
/// Test removal on an empty backend.
///
/// Removing a key that was never stored must be a harmless no-op.
fn test_remove_empty<B>(mut backend: B)
where
    B: CacheBackend<K = u8, V = String>,
{
    backend.remove(&1);
}
/// Test removal of existing values.
///
/// After a remove, a subsequent GET for the same key must yield nothing.
fn test_set_remove_get<B>(mut backend: B)
where
    B: CacheBackend<K = u8, V = String>,
{
    backend.set(1, "a".to_owned());
    backend.remove(&1);
    assert!(backend.get(&1).is_none());
}
/// Test setting a new value after removing it.
///
/// A removed key must accept a fresh value, and only the new value is visible.
fn test_readd<B>(mut backend: B)
where
    B: CacheBackend<K = u8, V = String>,
{
    backend.set(1, "a".to_owned());
    backend.remove(&1);
    backend.set(1, "b".to_owned());
    assert_eq!(backend.get(&1).as_deref(), Some("b"));
}
/// Test `is_empty` check.
///
/// The backend is empty only when ALL entries are gone: it must report
/// non-empty while any key remains, and empty again once the last one is removed.
fn test_is_empty<B>(mut backend: B)
where
    B: CacheBackend<K = u8, V = String>,
{
    assert!(backend.is_empty());
    backend.set(1, String::from("a"));
    backend.set(2, String::from("b"));
    assert!(!backend.is_empty());
    // one of two entries removed -> still non-empty
    backend.remove(&1);
    assert!(!backend.is_empty());
    backend.remove(&2);
    assert!(backend.is_empty());
}

442
cache_system/src/cache/driver.rs vendored Normal file
View File

@ -0,0 +1,442 @@
//! Main data structure, see [`CacheDriver`].
use crate::{
backend::CacheBackend,
cancellation_safe_future::{CancellationSafeFuture, CancellationSafeFutureReceiver},
loader::Loader,
};
use async_trait::async_trait;
use futures::{
channel::oneshot::{channel, Canceled, Sender},
future::{BoxFuture, Shared},
FutureExt, TryFutureExt,
};
use observability_deps::tracing::debug;
use parking_lot::Mutex;
use std::{collections::HashMap, fmt::Debug, future::Future, sync::Arc};
use super::{Cache, CacheGetStatus, CachePeekStatus};
/// Combine a [`CacheBackend`] and a [`Loader`] into a single [`Cache`]
#[derive(Debug)]
pub struct CacheDriver<B, L>
where
    B: CacheBackend,
    L: Loader<K = B::K, V = B::V>,
{
    /// Shared mutable state: the storage backend plus bookkeeping for in-flight
    /// loader queries. Guarded by a single mutex so backend and query state stay consistent.
    state: Arc<Mutex<CacheState<B>>>,
    /// Loader used to compute values for keys that are not yet cached.
    loader: Arc<L>,
}
impl<B, L> CacheDriver<B, L>
where
    B: CacheBackend,
    L: Loader<K = B::K, V = B::V>,
{
    /// Create new, empty cache with given loader function.
    pub fn new(loader: Arc<L>, backend: B) -> Self {
        Self {
            state: Arc::new(Mutex::new(CacheState {
                cached_entries: backend,
                running_queries: HashMap::new(),
                tag_counter: 0,
            })),
            loader,
        }
    }
    /// Register a new loader query for key `k` and build the future that drives it.
    ///
    /// Must be called while holding the state lock (`state` is the locked view,
    /// `state_captured` the shareable handle the future uses later). Returns the
    /// cancellation-safe future to poll plus a cloneable receiver for the result,
    /// and records a [`RunningQuery`] entry so concurrent requests can attach.
    fn start_new_query(
        state: &mut CacheState<B>,
        state_captured: Arc<Mutex<CacheState<B>>>,
        loader: Arc<L>,
        k: B::K,
        extra: L::Extra,
    ) -> (
        CancellationSafeFuture<impl Future<Output = ()>>,
        SharedReceiver<B::V>,
    ) {
        let (tx_main, rx_main) = channel();
        let receiver = rx_main
            .map_ok(|v| Arc::new(Mutex::new(v)))
            .map_err(Arc::new)
            .boxed()
            .shared();
        let (tx_set, rx_set) = channel();
        // generate unique tag
        let tag = state.tag_counter;
        state.tag_counter += 1;
        // need to wrap the query into a `CancellationSafeFuture` so that it doesn't get cancelled when
        // this very request is cancelled
        let join_handle_receiver = CancellationSafeFutureReceiver::default();
        let k_captured = k.clone();
        let fut = async move {
            let loader_fut = async move {
                let submitter = ResultSubmitter::new(state_captured, k_captured.clone(), tag);
                // execute the loader
                // If we panic here then `tx` will be dropped and the receivers will be
                // notified.
                let v = loader.load(k_captured, extra).await;
                // remove "running" state and store result
                let was_running = submitter.submit(v.clone());
                if !was_running {
                    // value was side-loaded, so we cannot populate `v`. Instead block this
                    // execution branch and wait for `rx_set` to deliver the side-loaded
                    // result.
                    loop {
                        tokio::task::yield_now().await;
                    }
                }
                v
            };
            // prefer the side-loader
            let v = futures::select_biased! {
                maybe_v = rx_set.fuse() => {
                    match maybe_v {
                        Ok(v) => {
                            // data get side-loaded via `Cache::set`. In this case, we do
                            // NOT modify the state because there would be a lock-gap. The
                            // `set` function will do that for us instead.
                            v
                        }
                        Err(_) => {
                            // sender side is gone, very likely the cache is shutting down
                            debug!(
                                "Sender for side-loading data into running query gone.",
                            );
                            return;
                        }
                    }
                }
                v = loader_fut.fuse() => v,
            };
            // broadcast result
            // It's OK if the receiver side is gone. This might happen during shutdown
            tx_main.send(v).ok();
        };
        let fut = CancellationSafeFuture::new(fut, join_handle_receiver.clone());
        state.running_queries.insert(
            k,
            RunningQuery {
                recv: receiver.clone(),
                set: tx_set,
                _join_handle: join_handle_receiver,
                tag,
            },
        );
        (fut, receiver)
    }
}
#[async_trait]
impl<B, L> Cache for CacheDriver<B, L>
where
    B: CacheBackend,
    L: Loader<K = B::K, V = B::V>,
{
    type K = B::K;
    type V = B::V;
    type GetExtra = L::Extra;
    type PeekExtra = ();
    async fn get_with_status(
        &self,
        k: Self::K,
        extra: Self::GetExtra,
    ) -> (Self::V, CacheGetStatus) {
        // place state locking into its own scope so it doesn't leak into the generator (async
        // function)
        let (fut, receiver, status) = {
            let mut state = self.state.lock();
            // check if the entry has already been cached
            if let Some(v) = state.cached_entries.get(&k) {
                return (v, CacheGetStatus::Hit);
            }
            // check if there is already a query for this key running
            if let Some(running_query) = state.running_queries.get(&k) {
                // attach to the in-flight query instead of starting a second loader
                (
                    None,
                    running_query.recv.clone(),
                    CacheGetStatus::MissAlreadyLoading,
                )
            } else {
                // requires new query
                let (fut, receiver) = Self::start_new_query(
                    &mut state,
                    Arc::clone(&self.state),
                    Arc::clone(&self.loader),
                    k,
                    extra,
                );
                (Some(fut), receiver, CacheGetStatus::Miss)
            }
        };
        // try to run the loader future in this very task context to avoid spawning tokio tasks (which adds latency and
        // overhead)
        if let Some(fut) = fut {
            fut.await;
        }
        let v = retrieve_from_shared(receiver).await;
        (v, status)
    }
    async fn peek_with_status(
        &self,
        k: Self::K,
        _extra: Self::PeekExtra,
    ) -> Option<(Self::V, CachePeekStatus)> {
        // place state locking into its own scope so it doesn't leak into the generator (async
        // function)
        let (receiver, status) = {
            let mut state = self.state.lock();
            // check if the entry has already been cached
            if let Some(v) = state.cached_entries.get(&k) {
                return Some((v, CachePeekStatus::Hit));
            }
            // check if there is already a query for this key running
            if let Some(running_query) = state.running_queries.get(&k) {
                (
                    running_query.recv.clone(),
                    CachePeekStatus::MissAlreadyLoading,
                )
            } else {
                // peek never starts a loader query — absent + not loading means `None`
                return None;
            }
        };
        let v = retrieve_from_shared(receiver).await;
        Some((v, status))
    }
    async fn set(&self, k: Self::K, v: Self::V) {
        let maybe_join_handle = {
            let mut state = self.state.lock();
            let maybe_recv = if let Some(running_query) = state.running_queries.remove(&k) {
                // it's OK when the receiver side is gone (likely panicked)
                running_query.set.send(v.clone()).ok();
                // When we side-load data into the running task, the task does NOT modify the
                // backend, so we have to do that. The reason for not letting the task feed the
                // side-loaded data back into `cached_entries` is that we would need to drop the
                // state lock here before the task could acquire it, leading to a lock gap.
                Some(running_query.recv)
            } else {
                None
            };
            state.cached_entries.set(k, v);
            maybe_recv
        };
        // drive running query (if any) to completion
        if let Some(recv) = maybe_join_handle {
            // we do not care if the query died (e.g. due to a panic)
            recv.await.ok();
        }
    }
}
impl<B, L> Drop for CacheDriver<B, L>
where
    B: CacheBackend,
    L: Loader<K = B::K, V = B::V>,
{
    fn drop(&mut self) {
        // Abort all in-flight loader queries: dropping each `RunningQuery` drops its
        // `CancellationSafeFutureReceiver` handle, which stops the associated future.
        // `clear()` is the idiomatic form of the former `for _ in ...drain() {}` loop
        // (clippy::clear_with_drain) and drops all entries just the same.
        self.state.lock().running_queries.clear();
    }
}
/// Helper to submit results of running queries.
///
/// Ensures that running query is removed when dropped (e.g. during panic).
struct ResultSubmitter<B>
where
    B: CacheBackend,
{
    /// Shared cache state; used to clear the "running" entry and store the result.
    state: Arc<Mutex<CacheState<B>>>,
    /// Tag identifying the query this submitter belongs to (guards against
    /// submitting into a newer query for the same key).
    tag: u64,
    /// Key of the query. `None` once the submitter has been finalized.
    k: Option<B::K>,
    /// Result value, if one was submitted before finalization.
    v: Option<B::V>,
}
impl<B> ResultSubmitter<B>
where
    B: CacheBackend,
{
    /// Create a submitter for the query identified by key `k` and `tag`.
    fn new(state: Arc<Mutex<CacheState<B>>>, k: B::K, tag: u64) -> Self {
        Self {
            state,
            tag,
            k: Some(k),
            v: None,
        }
    }
    /// Submit value.
    ///
    /// Returns `true` if this very query was running.
    fn submit(mut self, v: B::V) -> bool {
        assert!(self.v.is_none());
        self.v = Some(v);
        self.finalize()
    }
    /// Finalize request.
    ///
    /// Returns `true` if this very query was running.
    fn finalize(&mut self) -> bool {
        let k = self.k.take().expect("finalized twice");
        let mut state = self.state.lock();
        match state.running_queries.get(&k) {
            Some(running_query) if running_query.tag == self.tag => {
                state.running_queries.remove(&k);
                if let Some(v) = self.v.take() {
                    // this very query is in charge of the key, so store it in the
                    // underlying cache
                    state.cached_entries.set(k, v);
                }
                true
            }
            _ => {
                // This query is actually not really running any longer but got
                // shut down, e.g. due to side loading. Do NOT store the
                // generated value in the underlying cache.
                false
            }
        }
    }
}
impl<B> Drop for ResultSubmitter<B>
where
    B: CacheBackend,
{
    fn drop(&mut self) {
        // `k` is still present only if `submit`/`finalize` never ran, e.g. because the
        // loader panicked; clean up the "running" entry so later requests can retry.
        if self.k.is_some() {
            // not finalized yet
            self.finalize();
        }
    }
}
/// A [`tokio::sync::oneshot::Receiver`] that can be cloned.
///
/// Stored in [`RunningQuery`] so multiple concurrent requests for the same key
/// can all await the single in-flight result.
///
/// The types are:
///
/// - `Arc<Mutex<V>>`: Ensures that we can clone `V` without requiring `V: Sync`. At the same time
///   the reference to `V` (i.e. the `Arc`) must be cloneable for `Shared`
/// - `Arc<RecvError>`: Is required because `RecvError` is not `Clone` but `Shared` requires that.
/// - `BoxFuture`: The transformation from `Result<V, RecvError>` to `Result<Arc<Mutex<V>>,
///   Arc<RecvError>>` results in a kinda messy type and we wanna erase that.
/// - `Shared`: Allow the receiver to be cloned and be awaited from multiple places.
type SharedReceiver<V> = Shared<BoxFuture<'static, Result<Arc<Mutex<V>>, Arc<Canceled>>>>;
/// Retrieve data from shared receiver.
///
/// Awaits the shared oneshot result and hands back an owned clone of the value.
/// Panics if the sending side was dropped without a value, which happens when
/// the loader panicked.
async fn retrieve_from_shared<V>(receiver: SharedReceiver<V>) -> V
where
    V: Clone + Send,
{
    let received = receiver.await;
    let shared_value = received.expect("cache loader panicked, see logs");
    let guard = shared_value.lock();
    guard.clone()
}
/// State for coordinating the execution of a single running query.
#[derive(Debug)]
struct RunningQuery<V> {
    /// A receiver that can await the result as well.
    recv: SharedReceiver<V>,
    /// A sender that enables setting entries while the query is running.
    // NOTE(review): this field IS read — `CacheDriver::set` calls
    // `running_query.set.send(...)` — so the `allow(dead_code)` below looks
    // stale; confirm before removing.
    #[allow(dead_code)]
    set: Sender<V>,
    /// A handle for the task that is currently executing the query.
    ///
    /// The handle can be used to abort the running query, e.g. when dropping the cache.
    ///
    /// This is "dead code" because we only store it to keep the future alive. There's no direct interaction.
    _join_handle: CancellationSafeFutureReceiver<()>,
    /// Tag so that queries for the same key (e.g. when starting, side-loading, starting again) can
    /// be told apart.
    tag: u64,
}
/// Inner cache state that is usually guarded by a lock.
///
/// The state parts must be updated in a consistent manner, i.e. while using the same lock guard.
#[derive(Debug)]
struct CacheState<B>
where
    B: CacheBackend,
{
    /// Cached entries (i.e. queries completed).
    cached_entries: B,
    /// Currently running queries indexed by cache key.
    running_queries: HashMap<B::K, RunningQuery<B::V>>,
    /// Tag counter for running queries.
    tag_counter: u64,
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use crate::{
        cache::test_util::{run_test_generic, TestAdapter},
        loader::test_util::TestLoader,
    };
    use super::*;
    // Run the shared generic cache test suite against a bare `CacheDriver`
    // (HashMap backend, test loader).
    #[tokio::test]
    async fn test_generic() {
        run_test_generic(MyTestAdapter).await;
    }
    /// Adapter wiring `CacheDriver` into the generic test harness.
    struct MyTestAdapter;
    impl TestAdapter for MyTestAdapter {
        type GetExtra = bool;
        type PeekExtra = ();
        type Cache = CacheDriver<HashMap<u8, String>, TestLoader>;
        fn construct(&self, loader: Arc<TestLoader>) -> Arc<Self::Cache> {
            Arc::new(CacheDriver::new(Arc::clone(&loader) as _, HashMap::new()))
        }
        fn get_extra(&self, inner: bool) -> Self::GetExtra {
            inner
        }
        fn peek_extra(&self) -> Self::PeekExtra {}
    }
}

713
cache_system/src/cache/metrics.rs vendored Normal file
View File

@ -0,0 +1,713 @@
//! Metrics instrumentation for [`Cache`]s.
use std::{fmt::Debug, sync::Arc};
use async_trait::async_trait;
use iox_time::{Time, TimeProvider};
use metric::{Attributes, DurationHistogram, U64Counter};
use observability_deps::tracing::warn;
use trace::span::{Span, SpanRecorder};
use super::{Cache, CacheGetStatus, CachePeekStatus};
/// Struct containing all the metrics
#[derive(Debug)]
struct Metrics {
    /// Clock used to measure request durations.
    time_provider: Arc<dyn TimeProvider>,
    // GET duration histograms, one per outcome ("status" attribute).
    metric_get_hit: DurationHistogram,
    metric_get_miss: DurationHistogram,
    metric_get_miss_already_loading: DurationHistogram,
    metric_get_cancelled: DurationHistogram,
    // PEEK duration histograms, one per outcome ("status" attribute).
    metric_peek_hit: DurationHistogram,
    metric_peek_miss: DurationHistogram,
    metric_peek_miss_already_loading: DurationHistogram,
    metric_peek_cancelled: DurationHistogram,
    /// Counter of SET requests.
    metric_set: U64Counter,
}
impl Metrics {
    /// Register all cache metrics (`iox_cache_get`, `iox_cache_peek`, `iox_cache_set`)
    /// under the given cache `name` attribute and capture their recorders.
    fn new(
        name: &'static str,
        time_provider: Arc<dyn TimeProvider>,
        metric_registry: &metric::Registry,
    ) -> Self {
        let attributes = Attributes::from(&[("name", name)]);
        // GET histograms: same metric, distinguished by the "status" attribute.
        let mut attributes_get = attributes.clone();
        let metric_get = metric_registry
            .register_metric::<DurationHistogram>("iox_cache_get", "Cache GET requests");
        attributes_get.insert("status", "hit");
        let metric_get_hit = metric_get.recorder(attributes_get.clone());
        attributes_get.insert("status", "miss");
        let metric_get_miss = metric_get.recorder(attributes_get.clone());
        attributes_get.insert("status", "miss_already_loading");
        let metric_get_miss_already_loading = metric_get.recorder(attributes_get.clone());
        attributes_get.insert("status", "cancelled");
        let metric_get_cancelled = metric_get.recorder(attributes_get);
        // PEEK histograms, same scheme as GET.
        let mut attributes_peek = attributes.clone();
        let metric_peek = metric_registry
            .register_metric::<DurationHistogram>("iox_cache_peek", "Cache PEEK requests");
        attributes_peek.insert("status", "hit");
        let metric_peek_hit = metric_peek.recorder(attributes_peek.clone());
        attributes_peek.insert("status", "miss");
        let metric_peek_miss = metric_peek.recorder(attributes_peek.clone());
        attributes_peek.insert("status", "miss_already_loading");
        let metric_peek_miss_already_loading = metric_peek.recorder(attributes_peek.clone());
        attributes_peek.insert("status", "cancelled");
        let metric_peek_cancelled = metric_peek.recorder(attributes_peek);
        let metric_set = metric_registry
            .register_metric::<U64Counter>("iox_cache_set", "Cache SET requests.")
            .recorder(attributes);
        Self {
            time_provider,
            metric_get_hit,
            metric_get_miss,
            metric_get_miss_already_loading,
            metric_get_cancelled,
            metric_peek_hit,
            metric_peek_miss,
            metric_peek_miss_already_loading,
            metric_peek_cancelled,
            metric_set,
        }
    }
}
/// Wraps given cache with metrics.
#[derive(Debug)]
pub struct CacheWithMetrics<C>
where
    C: Cache,
{
    /// The wrapped cache that performs the actual work.
    inner: C,
    /// Recorders for GET/PEEK/SET instrumentation.
    metrics: Metrics,
}
impl<C> CacheWithMetrics<C>
where
    C: Cache,
{
    /// Create new metrics wrapper around given cache.
    ///
    /// `name` becomes the `name` attribute on all emitted metrics; the metrics
    /// themselves are registered in `metric_registry`.
    pub fn new(
        inner: C,
        name: &'static str,
        time_provider: Arc<dyn TimeProvider>,
        metric_registry: &metric::Registry,
    ) -> Self {
        Self {
            inner,
            metrics: Metrics::new(name, time_provider, metric_registry),
        }
    }
}
#[async_trait]
impl<C> Cache for CacheWithMetrics<C>
where
    C: Cache,
{
    type K = C::K;
    type V = C::V;
    // Extras are extended with an optional tracing span for the request.
    type GetExtra = (C::GetExtra, Option<Span>);
    type PeekExtra = (C::PeekExtra, Option<Span>);
    async fn get_with_status(
        &self,
        k: Self::K,
        extra: Self::GetExtra,
    ) -> (Self::V, CacheGetStatus) {
        let (extra, span) = extra;
        // The guard records the duration on drop — including when this future is
        // cancelled, in which case `status` stays `None` and counts as "cancelled".
        let mut set_on_drop = SetGetMetricOnDrop::new(&self.metrics, span);
        let (v, status) = self.inner.get_with_status(k, extra).await;
        set_on_drop.status = Some(status);
        (v, status)
    }
    async fn peek_with_status(
        &self,
        k: Self::K,
        extra: Self::PeekExtra,
    ) -> Option<(Self::V, CachePeekStatus)> {
        let (extra, span) = extra;
        // Same drop-based recording as GET; `Some(None)` marks a completed miss.
        let mut set_on_drop = SetPeekMetricOnDrop::new(&self.metrics, span);
        let res = self.inner.peek_with_status(k, extra).await;
        set_on_drop.status = Some(res.as_ref().map(|(_v, status)| *status));
        res
    }
    async fn set(&self, k: Self::K, v: Self::V) {
        self.inner.set(k, v).await;
        self.metrics.metric_set.inc(1);
    }
}
/// Helper that sets GET metrics on drop depending on the `status`.
///
/// A drop might happen due to completion (in which case the `status` should be set) or if the future is cancelled (in
/// which case the `status` is `None`).
struct SetGetMetricOnDrop<'a> {
    /// Metric recorders to report into.
    metrics: &'a Metrics,
    /// Time when the request started, for the duration measurement.
    t_start: Time,
    /// Outcome of the GET; `None` means the request was cancelled.
    status: Option<CacheGetStatus>,
    /// Optional tracing span covering the request.
    span_recorder: SpanRecorder,
}
impl<'a> SetGetMetricOnDrop<'a> {
    /// Start the measurement now; the caller fills in `status` on completion.
    fn new(metrics: &'a Metrics, span: Option<Span>) -> Self {
        let t_start = metrics.time_provider.now();
        Self {
            metrics,
            t_start,
            status: None,
            span_recorder: SpanRecorder::new(span),
        }
    }
}
impl<'a> Drop for SetGetMetricOnDrop<'a> {
    fn drop(&mut self) {
        let t_end = self.metrics.time_provider.now();
        // `checked_duration_since` returns `None` if the clock went backwards;
        // in that case we skip recording instead of panicking.
        match t_end.checked_duration_since(self.t_start) {
            Some(duration) => {
                // pick the histogram matching the outcome; `None` == cancelled request
                match self.status {
                    Some(CacheGetStatus::Hit) => &self.metrics.metric_get_hit,
                    Some(CacheGetStatus::Miss) => &self.metrics.metric_get_miss,
                    Some(CacheGetStatus::MissAlreadyLoading) => {
                        &self.metrics.metric_get_miss_already_loading
                    }
                    None => &self.metrics.metric_get_cancelled,
                }
                .record(duration);
            }
            None => {
                warn!("Clock went backwards, not recording cache GET duration");
            }
        }
        // only completed requests close the span as OK; cancelled ones leave it unset
        if let Some(status) = self.status {
            self.span_recorder.ok(status.name());
        }
    }
}
/// Helper that sets PEEK metrics on drop depending on the `status`.
///
/// A drop might happen due to completion (in which case the `status` should be set) or if the future is cancelled (in
/// which case the `status` is `None`).
struct SetPeekMetricOnDrop<'a> {
    /// Metric recorders to report into.
    metrics: &'a Metrics,
    /// Time when the request started, for the duration measurement.
    t_start: Time,
    /// Outer `None` = cancelled; `Some(None)` = completed but key absent (miss);
    /// `Some(Some(_))` = completed with a peek status.
    status: Option<Option<CachePeekStatus>>,
    /// Optional tracing span covering the request.
    span_recorder: SpanRecorder,
}
impl<'a> SetPeekMetricOnDrop<'a> {
    /// Start the measurement now; the caller fills in `status` on completion.
    fn new(metrics: &'a Metrics, span: Option<Span>) -> Self {
        let t_start = metrics.time_provider.now();
        Self {
            metrics,
            t_start,
            status: None,
            span_recorder: SpanRecorder::new(span),
        }
    }
}
impl<'a> Drop for SetPeekMetricOnDrop<'a> {
    fn drop(&mut self) {
        let t_end = self.metrics.time_provider.now();
        // `checked_duration_since` returns `None` if the clock went backwards;
        // skip recording in that case.
        match t_end.checked_duration_since(self.t_start) {
            Some(duration) => {
                match self.status {
                    Some(Some(CachePeekStatus::Hit)) => &self.metrics.metric_peek_hit,
                    Some(Some(CachePeekStatus::MissAlreadyLoading)) => {
                        &self.metrics.metric_peek_miss_already_loading
                    }
                    // completed, but peek returned `None` -> miss
                    Some(None) => &self.metrics.metric_peek_miss,
                    // never completed -> cancelled
                    None => &self.metrics.metric_peek_cancelled,
                }
                .record(duration);
            }
            None => {
                warn!("Clock went backwards, not recording cache PEEK duration");
            }
        }
        if let Some(status) = self.status {
            self.span_recorder
                .ok(status.map(|status| status.name()).unwrap_or("miss"));
        }
    }
}
#[cfg(test)]
mod tests {
use std::{collections::HashMap, time::Duration};
use futures::{stream::FuturesUnordered, StreamExt};
use iox_time::{MockProvider, Time};
use metric::{HistogramObservation, Observation, RawReporter};
use tokio::sync::Barrier;
use trace::{span::SpanStatus, RingBufferTraceCollector};
use crate::{
cache::{
driver::CacheDriver,
test_util::{run_test_generic, TestAdapter},
},
loader::test_util::TestLoader,
test_util::{AbortAndWaitExt, EnsurePendingExt},
};
use super::*;
// Run the shared generic cache test suite against the metrics wrapper to make
// sure instrumentation does not change cache semantics.
#[tokio::test]
async fn test_generic() {
    run_test_generic(MyTestAdapter).await;
}
/// Adapter wiring `CacheWithMetrics` into the generic test harness.
struct MyTestAdapter;
impl TestAdapter for MyTestAdapter {
    type GetExtra = (bool, Option<Span>);
    type PeekExtra = ((), Option<Span>);
    type Cache = CacheWithMetrics<CacheDriver<HashMap<u8, String>, TestLoader>>;
    fn construct(&self, loader: Arc<TestLoader>) -> Arc<Self::Cache> {
        TestMetricsCache::new_with_loader(loader).cache
    }
    fn get_extra(&self, inner: bool) -> Self::GetExtra {
        // no span attached in the generic suite
        (inner, None)
    }
    fn peek_extra(&self) -> Self::PeekExtra {
        ((), None)
    }
}
// End-to-end check of GET metrics and spans: exercises every status
// ("hit", "miss", "miss_already_loading", "cancelled") with a mock clock and
// verifies both histogram contents and recorded span counts.
#[tokio::test]
async fn test_get() {
    let test_cache = TestMetricsCache::new();
    let traces = Arc::new(RingBufferTraceCollector::new(1_000));
    // all histograms start out empty
    let mut reporter = RawReporter::default();
    test_cache.metric_registry.report(&mut reporter);
    for status in ["hit", "miss", "miss_already_loading", "cancelled"] {
        let hist = get_metric_cache_get(&reporter, status);
        assert_eq!(hist.sample_count(), 0);
        assert_eq!(hist.total, Duration::from_secs(0));
    }
    // block the loader so the first GET becomes a long-running "miss"
    test_cache.loader.block_global();
    let barrier_pending_1 = Arc::new(Barrier::new(2));
    let barrier_pending_1_captured = Arc::clone(&barrier_pending_1);
    let traces_captured = Arc::clone(&traces);
    let cache_captured = Arc::clone(&test_cache.cache);
    let join_handle_1 = tokio::task::spawn(async move {
        cache_captured
            .get(
                1,
                (
                    true,
                    Some(Span::root("miss", Arc::clone(&traces_captured) as _)),
                ),
            )
            .ensure_pending(barrier_pending_1_captured)
            .await
    });
    barrier_pending_1.wait().await;
    let d1 = Duration::from_secs(1);
    test_cache.time_provider.inc(d1);
    // concurrent GETs for the same key attach to the running query
    // ("miss_already_loading")
    let barrier_pending_2 = Arc::new(Barrier::new(2));
    let barrier_pending_2_captured = Arc::clone(&barrier_pending_2);
    let traces_captured = Arc::clone(&traces);
    let cache_captured = Arc::clone(&test_cache.cache);
    let n_miss_already_loading = 10;
    let join_handle_2 = tokio::task::spawn(async move {
        (0..n_miss_already_loading)
            .map(|_| {
                cache_captured.get(
                    1,
                    (
                        true,
                        Some(Span::root(
                            "miss_already_loading",
                            Arc::clone(&traces_captured) as _,
                        )),
                    ),
                )
            })
            .collect::<FuturesUnordered<_>>()
            .collect::<Vec<_>>()
            .ensure_pending(barrier_pending_2_captured)
            .await
    });
    barrier_pending_2.wait().await;
    let d2 = Duration::from_secs(3);
    test_cache.time_provider.inc(d2);
    // let the loader finish so both pending tasks complete
    test_cache.loader.mock_next(1, "v".into());
    test_cache.loader.unblock_global();
    join_handle_1.await.unwrap();
    join_handle_2.await.unwrap();
    test_cache.loader.block_global();
    test_cache.time_provider.inc(Duration::from_secs(10));
    // key 1 is now cached -> instant "hit"s
    let n_hit = 100;
    for _ in 0..n_hit {
        test_cache
            .cache
            .get(1, (true, Some(Span::root("hit", Arc::clone(&traces) as _))))
            .await;
    }
    // GETs for key 2 stay pending (loader blocked) and are then aborted ->
    // "cancelled"
    let n_cancelled = 200;
    let barrier_pending_3 = Arc::new(Barrier::new(2));
    let barrier_pending_3_captured = Arc::clone(&barrier_pending_3);
    let traces_captured = Arc::clone(&traces);
    let cache_captured = Arc::clone(&test_cache.cache);
    let join_handle_3 = tokio::task::spawn(async move {
        (0..n_cancelled)
            .map(|_| {
                cache_captured.get(
                    2,
                    (
                        true,
                        Some(Span::root("cancelled", Arc::clone(&traces_captured) as _)),
                    ),
                )
            })
            .collect::<FuturesUnordered<_>>()
            .collect::<Vec<_>>()
            .ensure_pending(barrier_pending_3_captured)
            .await
    });
    barrier_pending_3.wait().await;
    let d3 = Duration::from_secs(20);
    test_cache.time_provider.inc(d3);
    join_handle_3.abort_and_wait().await;
    // verify the recorded histograms
    let mut reporter = RawReporter::default();
    test_cache.metric_registry.report(&mut reporter);
    let hist = get_metric_cache_get(&reporter, "hit");
    assert_eq!(hist.sample_count(), n_hit);
    // "hit"s are instant because there's no lock contention
    assert_eq!(hist.total, Duration::from_secs(0));
    let hist = get_metric_cache_get(&reporter, "miss");
    let n = 1;
    assert_eq!(hist.sample_count(), n);
    assert_eq!(hist.total, (n as u32) * (d1 + d2));
    let hist = get_metric_cache_get(&reporter, "miss_already_loading");
    assert_eq!(hist.sample_count(), n_miss_already_loading);
    assert_eq!(hist.total, (n_miss_already_loading as u32) * d2);
    let hist = get_metric_cache_get(&reporter, "cancelled");
    assert_eq!(hist.sample_count(), n_cancelled);
    assert_eq!(hist.total, (n_cancelled as u32) * d3);
    // check spans
    assert_n_spans(&traces, "hit", SpanStatus::Ok, n_hit as usize);
    assert_n_spans(&traces, "miss", SpanStatus::Ok, 1);
    assert_n_spans(
        &traces,
        "miss_already_loading",
        SpanStatus::Ok,
        n_miss_already_loading as usize,
    );
    assert_n_spans(
        &traces,
        "cancelled",
        SpanStatus::Unknown,
        n_cancelled as usize,
    );
}
// End-to-end check of PEEK metrics and spans, mirroring `test_get`: "miss" is a
// completed peek that found neither a cached entry nor a running query, and
// "miss_already_loading" attaches to queries started by separate GETs.
#[tokio::test]
async fn test_peek() {
    let test_cache = TestMetricsCache::new();
    let traces = Arc::new(RingBufferTraceCollector::new(1_000));
    // all histograms start out empty
    let mut reporter = RawReporter::default();
    test_cache.metric_registry.report(&mut reporter);
    for status in ["hit", "miss", "miss_already_loading", "cancelled"] {
        let hist = get_metric_cache_peek(&reporter, status);
        assert_eq!(hist.sample_count(), 0);
        assert_eq!(hist.total, Duration::from_secs(0));
    }
    test_cache.loader.block_global();
    // nothing cached, nothing loading -> instant "miss"
    test_cache
        .cache
        .peek(1, ((), Some(Span::root("miss", Arc::clone(&traces) as _))))
        .await;
    // start a GET for key 1 that stays pending (loader blocked)
    let barrier_pending_1 = Arc::new(Barrier::new(2));
    let barrier_pending_1_captured = Arc::clone(&barrier_pending_1);
    let cache_captured = Arc::clone(&test_cache.cache);
    let join_handle_1 = tokio::task::spawn(async move {
        cache_captured
            .get(1, (true, None))
            .ensure_pending(barrier_pending_1_captured)
            .await
    });
    barrier_pending_1.wait().await;
    let d1 = Duration::from_secs(1);
    test_cache.time_provider.inc(d1);
    // peeks now attach to the running GET -> "miss_already_loading"
    let barrier_pending_2 = Arc::new(Barrier::new(2));
    let barrier_pending_2_captured = Arc::clone(&barrier_pending_2);
    let traces_captured = Arc::clone(&traces);
    let cache_captured = Arc::clone(&test_cache.cache);
    let n_miss_already_loading = 10;
    let join_handle_2 = tokio::task::spawn(async move {
        (0..n_miss_already_loading)
            .map(|_| {
                cache_captured.peek(
                    1,
                    (
                        (),
                        Some(Span::root(
                            "miss_already_loading",
                            Arc::clone(&traces_captured) as _,
                        )),
                    ),
                )
            })
            .collect::<FuturesUnordered<_>>()
            .collect::<Vec<_>>()
            .ensure_pending(barrier_pending_2_captured)
            .await
    });
    barrier_pending_2.wait().await;
    let d2 = Duration::from_secs(3);
    test_cache.time_provider.inc(d2);
    // let the loader finish so the pending tasks complete
    test_cache.loader.mock_next(1, "v".into());
    test_cache.loader.unblock_global();
    join_handle_1.await.unwrap();
    join_handle_2.await.unwrap();
    test_cache.loader.block_global();
    test_cache.time_provider.inc(Duration::from_secs(10));
    // key 1 is now cached -> instant "hit"s
    let n_hit = 100;
    for _ in 0..n_hit {
        test_cache
            .cache
            .peek(1, ((), Some(Span::root("hit", Arc::clone(&traces) as _))))
            .await;
    }
    // start a pending GET for key 2 so peeks on it attach, then abort them ->
    // "cancelled"
    let n_cancelled = 200;
    let barrier_pending_3 = Arc::new(Barrier::new(2));
    let barrier_pending_3_captured = Arc::clone(&barrier_pending_3);
    let cache_captured = Arc::clone(&test_cache.cache);
    tokio::task::spawn(async move {
        cache_captured
            .get(2, (true, None))
            .ensure_pending(barrier_pending_3_captured)
            .await
    });
    barrier_pending_3.wait().await;
    let barrier_pending_4 = Arc::new(Barrier::new(2));
    let barrier_pending_4_captured = Arc::clone(&barrier_pending_4);
    let traces_captured = Arc::clone(&traces);
    let cache_captured = Arc::clone(&test_cache.cache);
    let join_handle_3 = tokio::task::spawn(async move {
        (0..n_cancelled)
            .map(|_| {
                cache_captured.peek(
                    2,
                    (
                        (),
                        Some(Span::root("cancelled", Arc::clone(&traces_captured) as _)),
                    ),
                )
            })
            .collect::<FuturesUnordered<_>>()
            .collect::<Vec<_>>()
            .ensure_pending(barrier_pending_4_captured)
            .await
    });
    barrier_pending_4.wait().await;
    let d3 = Duration::from_secs(20);
    test_cache.time_provider.inc(d3);
    join_handle_3.abort_and_wait().await;
    // verify the recorded histograms
    let mut reporter = RawReporter::default();
    test_cache.metric_registry.report(&mut reporter);
    let hist = get_metric_cache_peek(&reporter, "hit");
    assert_eq!(hist.sample_count(), n_hit);
    // "hit"s are instant because there's no lock contention
    assert_eq!(hist.total, Duration::from_secs(0));
    let hist = get_metric_cache_peek(&reporter, "miss");
    let n = 1;
    assert_eq!(hist.sample_count(), n);
    // "miss"es are instant
    assert_eq!(hist.total, Duration::from_secs(0));
    let hist = get_metric_cache_peek(&reporter, "miss_already_loading");
    assert_eq!(hist.sample_count(), n_miss_already_loading);
    assert_eq!(hist.total, (n_miss_already_loading as u32) * d2);
    let hist = get_metric_cache_peek(&reporter, "cancelled");
    assert_eq!(hist.sample_count(), n_cancelled);
    assert_eq!(hist.total, (n_cancelled as u32) * d3);
    // check spans
    assert_n_spans(&traces, "hit", SpanStatus::Ok, n_hit as usize);
    assert_n_spans(&traces, "miss", SpanStatus::Ok, 1);
    assert_n_spans(
        &traces,
        "miss_already_loading",
        SpanStatus::Ok,
        n_miss_already_loading as usize,
    );
    assert_n_spans(
        &traces,
        "cancelled",
        SpanStatus::Unknown,
        n_cancelled as usize,
    );
}
// `iox_cache_set` counter starts at 0 and increments by one per SET.
#[tokio::test]
async fn test_set() {
    let test_cache = TestMetricsCache::new();
    let mut reporter = RawReporter::default();
    test_cache.metric_registry.report(&mut reporter);
    assert_eq!(
        reporter
            .metric("iox_cache_set")
            .unwrap()
            .observation(&[("name", "test")])
            .unwrap(),
        &Observation::U64Counter(0)
    );
    test_cache.cache.set(1, String::from("foo")).await;
    let mut reporter = RawReporter::default();
    test_cache.metric_registry.report(&mut reporter);
    assert_eq!(
        reporter
            .metric("iox_cache_set")
            .unwrap()
            .observation(&[("name", "test")])
            .unwrap(),
        &Observation::U64Counter(1)
    );
}
/// Test fixture bundling a metrics-wrapped cache with its loader, mock clock,
/// and metric registry so tests can drive and inspect all of them.
struct TestMetricsCache {
    loader: Arc<TestLoader>,
    time_provider: Arc<MockProvider>,
    metric_registry: metric::Registry,
    cache: Arc<CacheWithMetrics<CacheDriver<HashMap<u8, String>, TestLoader>>>,
}
impl TestMetricsCache {
    /// Fixture with a fresh default loader.
    fn new() -> Self {
        Self::new_with_loader(Arc::new(TestLoader::default()))
    }
    /// Fixture around the given loader; the cache is registered under the
    /// metric name "test" and uses a mock clock starting at t=0.
    fn new_with_loader(loader: Arc<TestLoader>) -> Self {
        let inner = CacheDriver::new(Arc::clone(&loader) as _, HashMap::new());
        let time_provider =
            Arc::new(MockProvider::new(Time::from_timestamp_millis(0).unwrap()));
        let metric_registry = metric::Registry::new();
        let cache = Arc::new(CacheWithMetrics::new(
            inner,
            "test",
            Arc::clone(&time_provider) as _,
            &metric_registry,
        ));
        Self {
            loader,
            time_provider,
            metric_registry,
            cache,
        }
    }
}
/// Fetch the `iox_cache_get` histogram for the cache named "test" with the
/// given `status` attribute. Panics if the metric is missing or not a
/// duration histogram.
fn get_metric_cache_get(
    reporter: &RawReporter,
    status: &'static str,
) -> HistogramObservation<Duration> {
    let observation = reporter
        .metric("iox_cache_get")
        .unwrap()
        .observation(&[("name", "test"), ("status", status)])
        .unwrap();
    match observation {
        Observation::DurationHistogram(hist) => hist.clone(),
        _ => panic!("Wrong observation type"),
    }
}
/// Fetch the `iox_cache_peek` histogram for the cache named "test" with the
/// given `status` attribute. Panics if the metric is missing or not a
/// duration histogram.
fn get_metric_cache_peek(
    reporter: &RawReporter,
    status: &'static str,
) -> HistogramObservation<Duration> {
    let observation = reporter
        .metric("iox_cache_peek")
        .unwrap()
        .observation(&[("name", "test"), ("status", status)])
        .unwrap();
    match observation {
        Observation::DurationHistogram(hist) => hist.clone(),
        _ => panic!("Wrong observation type"),
    }
}
/// Assert that exactly `expected` collected spans match both `name` and
/// `status`.
fn assert_n_spans(
    traces: &RingBufferTraceCollector,
    name: &'static str,
    status: SpanStatus,
    expected: usize,
) {
    let mut actual = 0;
    for span in traces.spans() {
        if span.name == name && span.status == status {
            actual += 1;
        }
    }
    assert_eq!(actual, expected);
}
}

167
cache_system/src/cache/mod.rs vendored Normal file
View File

@ -0,0 +1,167 @@
//! Top-level trait ([`Cache`]) that provides a fully functional cache.
//!
//! Caches usually combine a [backend](crate::backend) with a [loader](crate::loader). The easiest way to achieve that
//! is to use [`CacheDriver`](crate::cache::driver::CacheDriver). Caches might also wrap inner caches to provide certain
//! extra functionality like metrics.
use std::{fmt::Debug, hash::Hash};
use async_trait::async_trait;
pub mod driver;
pub mod metrics;
#[cfg(test)]
mod test_util;
/// Status of a [`Cache`] [GET](Cache::get_with_status) request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheGetStatus {
    /// The requested entry was present in the storage backend.
    Hit,
    /// The requested entry was NOT present in the storage backend and the loader had no previous query running.
    Miss,
    /// The requested entry was NOT present in the storage backend, but there was already a loader query running for
    /// this particular key.
    MissAlreadyLoading,
}
impl CacheGetStatus {
    /// Get human and machine readable name.
    ///
    /// Used e.g. as a metric/span label ("hit", "miss", "miss_already_loading").
    pub fn name(&self) -> &'static str {
        match *self {
            CacheGetStatus::Hit => "hit",
            CacheGetStatus::Miss => "miss",
            CacheGetStatus::MissAlreadyLoading => "miss_already_loading",
        }
    }
}
/// Status of a [`Cache`] [PEEK](Cache::peek_with_status) request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CachePeekStatus {
    /// The requested entry was present in the storage backend.
    Hit,
    /// The requested entry was NOT present in the storage backend, but there was already a loader query running for
    /// this particular key.
    MissAlreadyLoading,
}
impl CachePeekStatus {
    /// Get human and machine readable name.
    ///
    /// Used e.g. as a metric/span label ("hit", "miss_already_loading").
    pub fn name(&self) -> &'static str {
        match *self {
            CachePeekStatus::Hit => "hit",
            CachePeekStatus::MissAlreadyLoading => "miss_already_loading",
        }
    }
}
/// High-level cache implementation.
///
/// # Concurrency
///
/// Multiple cache requests for different keys can run at the same time. When data is requested for
/// the same key the underlying loader will only be polled once, even when the requests are made
/// while the loader is still running.
///
/// # Cancellation
///
/// Canceling a [`get`](Self::get) request will NOT cancel the underlying loader. The data will
/// still be cached.
///
/// # Panic
///
/// If the underlying loader panics, all currently running [`get`](Self::get) requests will panic.
/// The data will NOT be cached.
#[async_trait]
pub trait Cache: Debug + Send + Sync + 'static {
    /// Cache key.
    type K: Clone + Eq + Hash + Debug + Ord + Send + 'static;
    /// Cache value.
    type V: Clone + Debug + Send + 'static;
    /// Extra data that is provided during [`GET`](Self::get) but that is NOT part of the cache key.
    type GetExtra: Debug + Send + 'static;
    /// Extra data that is provided during [`PEEK`](Self::peek) but that is NOT part of the cache key.
    type PeekExtra: Debug + Send + 'static;
    /// Get value from cache.
    ///
    /// Convenience wrapper around [`get_with_status`](Self::get_with_status) that discards the status.
    ///
    /// Note that `extra` is only used if the key is missing from the storage backend and no loader query is running yet.
    async fn get(&self, k: Self::K, extra: Self::GetExtra) -> Self::V {
        self.get_with_status(k, extra).await.0
    }
    /// Get value from cache and the [status](CacheGetStatus).
    ///
    /// Note that `extra` is only used if the key is missing from the storage backend and no loader query is running yet.
    async fn get_with_status(&self, k: Self::K, extra: Self::GetExtra)
        -> (Self::V, CacheGetStatus);
    /// Peek value from cache.
    ///
    /// Convenience wrapper around [`peek_with_status`](Self::peek_with_status) that discards the status.
    ///
    /// In contrast to [`get`](Self::get) this will only return a value if there is a stored value or the value loading
    /// is already in progress. This will NOT start a new loading task.
    ///
    /// Note that `extra` is only used if the key is missing from the storage backend and no loader query is running yet.
    async fn peek(&self, k: Self::K, extra: Self::PeekExtra) -> Option<Self::V> {
        self.peek_with_status(k, extra).await.map(|(v, _status)| v)
    }
    /// Peek value from cache and the [status](CachePeekStatus).
    ///
    /// In contrast to [`get_with_status`](Self::get_with_status) this will only return a value if there is a stored
    /// value or the value loading is already in progress. This will NOT start a new loading task.
    ///
    /// Note that `extra` is only used if the key is missing from the storage backend and no loader query is running yet.
    async fn peek_with_status(
        &self,
        k: Self::K,
        extra: Self::PeekExtra,
    ) -> Option<(Self::V, CachePeekStatus)>;
    /// Side-load an entry into the cache.
    ///
    /// This will also complete a currently running request for this key.
    async fn set(&self, k: Self::K, v: Self::V);
}
#[async_trait]
impl<K, V, GetExtra, PeekExtra> Cache
    for Box<dyn Cache<K = K, V = V, GetExtra = GetExtra, PeekExtra = PeekExtra>>
where
    K: Clone + Eq + Hash + Debug + Ord + Send + 'static,
    V: Clone + Debug + Send + 'static,
    GetExtra: Debug + Send + 'static,
    PeekExtra: Debug + Send + 'static,
{
    type K = K;
    type V = V;
    type GetExtra = GetExtra;
    type PeekExtra = PeekExtra;

    // All methods simply delegate to the boxed trait object.
    async fn get_with_status(
        &self,
        k: Self::K,
        extra: Self::GetExtra,
    ) -> (Self::V, CacheGetStatus) {
        (**self).get_with_status(k, extra).await
    }

    async fn peek_with_status(
        &self,
        k: Self::K,
        extra: Self::PeekExtra,
    ) -> Option<(Self::V, CachePeekStatus)> {
        (**self).peek_with_status(k, extra).await
    }

    async fn set(&self, k: Self::K, v: Self::V) {
        (**self).set(k, v).await
    }
}

462
cache_system/src/cache/test_util.rs vendored Normal file
View File

@ -0,0 +1,462 @@
use std::{sync::Arc, time::Duration};
use tokio::sync::Barrier;
use crate::{
cache::{CacheGetStatus, CachePeekStatus},
loader::test_util::TestLoader,
test_util::{AbortAndWaitExt, EnsurePendingExt},
};
use super::Cache;
/// Interface between generic tests and a concrete cache type.
pub trait TestAdapter: Send + Sync + 'static {
    /// Extra information for GET.
    type GetExtra: Send;
    /// Extra information for PEEK.
    type PeekExtra: Send;
    /// Cache type.
    ///
    /// Keys are fixed to `u8` and values to `String` so the generic tests can construct inputs
    /// and expected outputs.
    type Cache: Cache<K = u8, V = String, GetExtra = Self::GetExtra, PeekExtra = Self::PeekExtra>;
    /// Create new cache with given loader.
    fn construct(&self, loader: Arc<TestLoader>) -> Arc<Self::Cache>;
    /// Build [`GetExtra`](Self::GetExtra).
    ///
    /// Must contain a [`bool`] payload that is later included into the value string for testing purposes.
    fn get_extra(&self, inner: bool) -> Self::GetExtra;
    /// Build [`PeekExtra`](Self::PeekExtra).
    fn peek_extra(&self) -> Self::PeekExtra;
}
/// Construct a cache under test together with its mock loader.
fn setup<T>(adapter: &T) -> (Arc<T::Cache>, Arc<TestLoader>)
where
    T: TestAdapter,
{
    let loader = Arc::new(TestLoader::default());
    (adapter.construct(Arc::clone(&loader)), loader)
}
/// Run the full generic test suite against the given adapter.
pub async fn run_test_generic<T>(adapter: T)
where
    T: TestAdapter,
{
    let adapter = Arc::new(adapter);

    test_answers_are_correct(Arc::clone(&adapter)).await;
    test_linear_memory(Arc::clone(&adapter)).await;
    test_concurrent_query_loads_once(Arc::clone(&adapter)).await;
    test_queries_are_parallelized(Arc::clone(&adapter)).await;
    test_cancel_request(Arc::clone(&adapter)).await;
    test_panic_request(Arc::clone(&adapter)).await;
    test_drop_cancels_loader(Arc::clone(&adapter)).await;
    test_set_before_request(Arc::clone(&adapter)).await;
    // last consumer, hand over ownership
    test_set_during_request(adapter).await;
}
/// Values returned by the cache must match what the loader produced, for both GET and PEEK.
async fn test_answers_are_correct<T>(adapter: Arc<T>)
where
    T: TestAdapter,
{
    let (cache, loader) = setup(adapter.as_ref());

    loader.mock_next(1, "res_1".to_owned());
    loader.mock_next(2, "res_2".to_owned());

    assert_eq!(
        cache.get(1, adapter.get_extra(true)).await,
        "res_1".to_owned(),
    );
    assert_eq!(
        cache.peek(1, adapter.peek_extra()).await,
        Some("res_1".to_owned()),
    );
    assert_eq!(
        cache.get(2, adapter.get_extra(false)).await,
        "res_2".to_owned(),
    );
    assert_eq!(
        cache.peek(2, adapter.peek_extra()).await,
        Some("res_2".to_owned()),
    );
}
/// GET/PEEK statuses must reflect the cache state over a linear sequence of requests, and the
/// loader must only be consulted for actual misses.
async fn test_linear_memory<T>(adapter: Arc<T>)
where
    T: TestAdapter,
{
    let (cache, loader) = setup(adapter.as_ref());

    loader.mock_next(1, "res_1".to_owned());
    loader.mock_next(2, "res_2".to_owned());

    // nothing stored yet => PEEK sees nothing
    assert_eq!(cache.peek_with_status(1, adapter.peek_extra()).await, None,);

    // first GET is a miss, subsequent ones are hits
    assert_eq!(
        cache.get_with_status(1, adapter.get_extra(true)).await,
        ("res_1".to_owned(), CacheGetStatus::Miss),
    );
    assert_eq!(
        cache.get_with_status(1, adapter.get_extra(false)).await,
        ("res_1".to_owned(), CacheGetStatus::Hit),
    );
    assert_eq!(
        cache.peek_with_status(1, adapter.peek_extra()).await,
        Some(("res_1".to_owned(), CachePeekStatus::Hit)),
    );

    // same pattern for the second key
    assert_eq!(
        cache.get_with_status(2, adapter.get_extra(false)).await,
        ("res_2".to_owned(), CacheGetStatus::Miss),
    );
    assert_eq!(
        cache.get_with_status(2, adapter.get_extra(false)).await,
        ("res_2".to_owned(), CacheGetStatus::Hit),
    );

    // the first key is still cached
    assert_eq!(
        cache.get_with_status(1, adapter.get_extra(true)).await,
        ("res_1".to_owned(), CacheGetStatus::Hit),
    );
    assert_eq!(
        cache.peek_with_status(1, adapter.peek_extra()).await,
        Some(("res_1".to_owned(), CachePeekStatus::Hit)),
    );

    // the loader was only driven by the two misses (with their respective `extra` payloads)
    assert_eq!(loader.loaded(), vec![(1, true), (2, false)]);
}
/// Concurrent GET/PEEK requests for the same key must share a single loader invocation.
///
/// The barrier choreography is order-critical: each spawned request must be observed as pending
/// before the next one is issued, otherwise the statuses would be ambiguous.
async fn test_concurrent_query_loads_once<T>(adapter: Arc<T>)
where
    T: TestAdapter,
{
    let (cache, loader) = setup(adapter.as_ref());
    // block all loader invocations so requests stay pending
    loader.block_global();
    let adapter_captured = Arc::clone(&adapter);
    let cache_captured = Arc::clone(&cache);
    let barrier_pending_1 = Arc::new(Barrier::new(2));
    let barrier_pending_1_captured = Arc::clone(&barrier_pending_1);
    let handle_1 = tokio::spawn(async move {
        cache_captured
            .get_with_status(1, adapter_captured.get_extra(true))
            .ensure_pending(barrier_pending_1_captured)
            .await
    });
    // wait until the first request is pending, i.e. its loader query has started
    barrier_pending_1.wait().await;
    let barrier_pending_2 = Arc::new(Barrier::new(3));
    let adapter_captured = Arc::clone(&adapter);
    let cache_captured = Arc::clone(&cache);
    let barrier_pending_2_captured = Arc::clone(&barrier_pending_2);
    let handle_2 = tokio::spawn(async move {
        // use a different `extra` here to prove that the first one was used
        cache_captured
            .get_with_status(1, adapter_captured.get_extra(false))
            .ensure_pending(barrier_pending_2_captured)
            .await
    });
    let barrier_pending_2_captured = Arc::clone(&barrier_pending_2);
    let handle_3 = tokio::spawn(async move {
        // use a different `extra` here to prove that the first one was used
        cache
            .peek_with_status(1, adapter.peek_extra())
            .ensure_pending(barrier_pending_2_captured)
            .await
    });
    barrier_pending_2.wait().await;
    loader.mock_next(1, "res_1".to_owned());
    // Shouldn't issue concurrent load requests for the same key
    let n_blocked = loader.unblock_global();
    assert_eq!(n_blocked, 1);
    // only the first request observes a plain `Miss`; the later ones joined the running query
    assert_eq!(
        handle_1.await.unwrap(),
        (String::from("res_1"), CacheGetStatus::Miss),
    );
    assert_eq!(
        handle_2.await.unwrap(),
        (String::from("res_1"), CacheGetStatus::MissAlreadyLoading),
    );
    assert_eq!(
        handle_3.await.unwrap(),
        Some((String::from("res_1"), CachePeekStatus::MissAlreadyLoading)),
    );
    // exactly one load, and it used the FIRST request's `extra`
    assert_eq!(loader.loaded(), vec![(1, true)]);
}
/// Requests for DIFFERENT keys must run their loader queries in parallel (here: two requests for
/// key 1 sharing one query, plus an independent query for key 2).
async fn test_queries_are_parallelized<T>(adapter: Arc<T>)
where
    T: TestAdapter,
{
    let (cache, loader) = setup(adapter.as_ref());
    // block all loader invocations so all three requests can be set up while pending
    loader.block_global();
    let barrier = Arc::new(Barrier::new(4));
    let adapter_captured = Arc::clone(&adapter);
    let cache_captured = Arc::clone(&cache);
    let barrier_captured = Arc::clone(&barrier);
    let handle_1 = tokio::spawn(async move {
        cache_captured
            .get(1, adapter_captured.get_extra(true))
            .ensure_pending(barrier_captured)
            .await
    });
    let adapter_captured = Arc::clone(&adapter);
    let cache_captured = Arc::clone(&cache);
    let barrier_captured = Arc::clone(&barrier);
    let handle_2 = tokio::spawn(async move {
        cache_captured
            .get(1, adapter_captured.get_extra(true))
            .ensure_pending(barrier_captured)
            .await
    });
    let barrier_captured = Arc::clone(&barrier);
    let handle_3 = tokio::spawn(async move {
        cache
            .get(2, adapter.get_extra(false))
            .ensure_pending(barrier_captured)
            .await
    });
    // wait until all three requests are pending
    barrier.wait().await;
    loader.mock_next(1, "res_1".to_owned());
    loader.mock_next(2, "res_2".to_owned());
    let n_blocked = loader.unblock_global();
    // two queries were running concurrently: one per distinct key
    assert_eq!(n_blocked, 2);
    assert_eq!(handle_1.await.unwrap(), String::from("res_1"));
    assert_eq!(handle_2.await.unwrap(), String::from("res_1"));
    assert_eq!(handle_3.await.unwrap(), String::from("res_2"));
    assert_eq!(loader.loaded(), vec![(1, true), (2, false)]);
}
/// Cancelling one GET must not cancel the shared loader query; a second request for the same key
/// still receives the result.
async fn test_cancel_request<T>(adapter: Arc<T>)
where
    T: TestAdapter,
{
    let (cache, loader) = setup(adapter.as_ref());
    loader.block_global();
    let barrier_pending_1 = Arc::new(Barrier::new(2));
    let barrier_pending_1_captured = Arc::clone(&barrier_pending_1);
    let adapter_captured = Arc::clone(&adapter);
    let cache_captured = Arc::clone(&cache);
    let handle_1 = tokio::spawn(async move {
        cache_captured
            .get(1, adapter_captured.get_extra(true))
            .ensure_pending(barrier_pending_1_captured)
            .await
    });
    // first request is pending, its loader query has started
    barrier_pending_1.wait().await;
    let barrier_pending_2 = Arc::new(Barrier::new(2));
    let barrier_pending_2_captured = Arc::clone(&barrier_pending_2);
    let handle_2 = tokio::spawn(async move {
        cache
            .get(1, adapter.get_extra(false))
            .ensure_pending(barrier_pending_2_captured)
            .await
    });
    barrier_pending_2.wait().await;
    // abort first handle
    handle_1.abort_and_wait().await;
    loader.mock_next(1, "res_1".to_owned());
    let n_blocked = loader.unblock_global();
    // the single shared query survived the cancellation
    assert_eq!(n_blocked, 1);
    assert_eq!(handle_2.await.unwrap(), String::from("res_1"));
    // the query that was issued for the cancelled request (extra=true) delivered the result
    assert_eq!(loader.loaded(), vec![(1, true)]);
}
/// A panicking loader must propagate the panic to ALL requests that share the loading status,
/// leave unrelated keys untouched, and must NOT poison the cache for later requests.
async fn test_panic_request<T>(adapter: Arc<T>)
where
    T: TestAdapter,
{
    let (cache, loader) = setup(adapter.as_ref());
    loader.block_global();
    // set up initial panicking request
    let barrier_pending_get_panic = Arc::new(Barrier::new(2));
    let barrier_pending_get_panic_captured = Arc::clone(&barrier_pending_get_panic);
    let adapter_captured = Arc::clone(&adapter);
    let cache_captured = Arc::clone(&cache);
    let handle_get_panic = tokio::spawn(async move {
        cache_captured
            .get(1, adapter_captured.get_extra(true))
            .ensure_pending(barrier_pending_get_panic_captured)
            .await
    });
    barrier_pending_get_panic.wait().await;
    // set up other requests
    let barrier_pending_others = Arc::new(Barrier::new(4));
    let barrier_pending_others_captured = Arc::clone(&barrier_pending_others);
    let adapter_captured = Arc::clone(&adapter);
    let cache_captured = Arc::clone(&cache);
    let handle_get_while_loading_panic = tokio::spawn(async move {
        cache_captured
            .get(1, adapter_captured.get_extra(false))
            .ensure_pending(barrier_pending_others_captured)
            .await
    });
    let barrier_pending_others_captured = Arc::clone(&barrier_pending_others);
    let adapter_captured = Arc::clone(&adapter);
    let cache_captured = Arc::clone(&cache);
    let handle_peek_while_loading_panic = tokio::spawn(async move {
        cache_captured
            .peek(1, adapter_captured.peek_extra())
            .ensure_pending(barrier_pending_others_captured)
            .await
    });
    let barrier_pending_others_captured = Arc::clone(&barrier_pending_others);
    let adapter_captured = Arc::clone(&adapter);
    let cache_captured = Arc::clone(&cache);
    let handle_get_other_key = tokio::spawn(async move {
        cache_captured
            .get(2, adapter_captured.get_extra(false))
            .ensure_pending(barrier_pending_others_captured)
            .await
    });
    barrier_pending_others.wait().await;
    // first query for key 1 panics, the retry later succeeds
    loader.panic_next(1);
    loader.mock_next(1, "res_1".to_owned());
    loader.mock_next(2, "res_2".to_owned());
    let n_blocked = loader.unblock_global();
    assert_eq!(n_blocked, 2);
    // panic of initial request
    handle_get_panic.await.unwrap_err();
    // requests that use the same loading status also panic
    handle_get_while_loading_panic.await.unwrap_err();
    handle_peek_while_loading_panic.await.unwrap_err();
    // unrelated request should succeed
    assert_eq!(handle_get_other_key.await.unwrap(), String::from("res_2"));
    // failing key was tried exactly once (and the other unrelated key as well)
    assert_eq!(loader.loaded(), vec![(1, true), (2, false)]);
    // loading after panic just works (no poisoning)
    assert_eq!(
        cache.get(1, adapter.get_extra(false)).await,
        String::from("res_1")
    );
    assert_eq!(loader.loaded(), vec![(1, true), (2, false), (1, false)]);
}
/// Dropping the last in-flight request must release all cache-internal references to the loader.
async fn test_drop_cancels_loader<T>(adapter: Arc<T>)
where
    T: TestAdapter,
{
    let (cache, loader) = setup(adapter.as_ref());
    loader.block_global();
    let barrier_pending = Arc::new(Barrier::new(2));
    let barrier_pending_captured = Arc::clone(&barrier_pending);
    // NOTE: `cache` and `adapter` are MOVED into the task, so aborting it drops them as well
    let handle = tokio::spawn(async move {
        cache
            .get(1, adapter.get_extra(true))
            .ensure_pending(barrier_pending_captured)
            .await
    });
    barrier_pending.wait().await;
    handle.abort_and_wait().await;
    // only this test's `loader` binding is left; the cache-side references are gone
    assert_eq!(Arc::strong_count(&loader), 1);
}
/// A side-loaded (SET) entry must satisfy a later GET without touching the loader.
async fn test_set_before_request<T>(adapter: Arc<T>)
where
    T: TestAdapter,
{
    let (cache, loader) = setup(adapter.as_ref());
    loader.block_global();

    cache.set(1, "foo".to_owned()).await;

    // blocked loader is not used; without the side-loaded value this GET would hang, hence
    // the short timeout
    let res = tokio::time::timeout(
        Duration::from_millis(10),
        cache.get(1, adapter.get_extra(false)),
    )
    .await
    .unwrap();
    assert_eq!(res, "foo".to_owned());

    // the loader was never invoked
    assert_eq!(loader.loaded(), Vec::<(u8, bool)>::new());
}
/// Side-loading (SET) while a loader query is running must complete that request immediately with
/// the side-loaded value, and the value must stay cached afterwards.
async fn test_set_during_request<T>(adapter: Arc<T>)
where
    T: TestAdapter,
{
    let (cache, loader) = setup(adapter.as_ref());
    loader.block_global();
    let adapter_captured = Arc::clone(&adapter);
    let cache_captured = Arc::clone(&cache);
    let barrier_pending = Arc::new(Barrier::new(2));
    let barrier_pending_captured = Arc::clone(&barrier_pending);
    let handle = tokio::spawn(async move {
        cache_captured
            .get(1, adapter_captured.get_extra(true))
            .ensure_pending(barrier_pending_captured)
            .await
    });
    // loader query for key 1 has started (and is blocked)
    barrier_pending.wait().await;
    cache.set(1, String::from("foo")).await;
    // request succeeds even though the loader is blocked
    let res = tokio::time::timeout(Duration::from_millis(10), handle)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(res, String::from("foo"));
    // the loader query was started before the SET and is recorded
    assert_eq!(loader.loaded(), vec![(1, true)]);
    // still cached
    let res = tokio::time::timeout(
        Duration::from_millis(10),
        cache.get(1, adapter.get_extra(false)),
    )
    .await
    .unwrap();
    assert_eq!(res, String::from("foo"));
    assert_eq!(loader.loaded(), vec![(1, true)]);
}

View File

@ -0,0 +1,184 @@
use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use futures::future::BoxFuture;
use parking_lot::Mutex;
use tokio::task::JoinHandle;
/// Receiver for [`CancellationSafeFuture`] join handles if the future was rescued from cancellation.
///
/// `T` is the [output type](Future::Output) of the wrapped future.
#[derive(Debug, Default, Clone)]
pub struct CancellationSafeFutureReceiver<T> {
    // shared with the wrapped `CancellationSafeFuture`; see `ReceiverInner::drop` for cleanup
    inner: Arc<ReceiverInner<T>>,
}
#[derive(Debug, Default)]
struct ReceiverInner<T> {
    // filled by `CancellationSafeFuture::drop` when the inner future is rescued into a tokio task
    slot: Mutex<Option<JoinHandle<T>>>,
}
impl<T> Drop for ReceiverInner<T> {
    /// Abort a still-running rescue task when the last receiver goes away.
    fn drop(&mut self) {
        if let Some(handle) = self.slot.lock().take() {
            handle.abort();
        }
    }
}
/// Wrapper around a future that cannot be cancelled.
///
/// When the future is dropped/cancelled, we'll spawn a tokio task to _rescue_ it.
pub struct CancellationSafeFuture<F>
where
    F: Future + Send + 'static,
    F::Output: Send,
{
    /// Mark if the inner future finished. If not, we must spawn a helper task on drop.
    done: bool,
    /// Inner future.
    ///
    /// Wrapped in an `Option` so we can extract it during drop. Inside that option however we also need a pinned
    /// box because once this wrapper is polled, it will be pinned in memory -- even during drop. Now the inner
    /// future does not necessarily implement `Unpin`, so we need a heap allocation to pin it in memory even when we
    /// move it out of this option.
    inner: Option<BoxFuture<'static, F::Output>>,
    /// Where to store the join handle on drop.
    ///
    /// Also used as a liveness signal: when no external party holds the receiver anymore, the
    /// future is NOT rescued on drop (see the `Drop` impl).
    receiver: CancellationSafeFutureReceiver<F::Output>,
}
// Rescue logic: if the wrapper is dropped before the inner future finished, spawn the inner
// future as a detached tokio task -- but only if somebody still holds the receiver.
impl<F> Drop for CancellationSafeFuture<F>
where
    F: Future + Send + 'static,
    F::Output: Send,
{
    fn drop(&mut self) {
        if !self.done {
            // acquire lock BEFORE checking the Arc
            let mut receiver = self.receiver.inner.slot.lock();
            // slot is only written here, and a future is dropped at most once
            assert!(receiver.is_none());
            // The Mutex is owned by the Arc and cannot be moved out of it. So after we acquired the lock we can safely
            // check if any external party still has access to the receiver state. If not, we assume there is no
            // interest in this future at all (e.g. during shutdown) and will NOT spawn it.
            if Arc::strong_count(&self.receiver.inner) > 1 {
                let inner = self.inner.take().expect("Double-drop?");
                let handle = tokio::task::spawn(inner);
                *receiver = Some(handle);
            }
        }
    }
}
impl<F> CancellationSafeFuture<F>
where
    F: Future + Send,
    F::Output: Send,
{
    /// Create a future that is protected from cancellation.
    ///
    /// If this wrapper is dropped before completion while some external party still holds the
    /// receiver state, the payload `fut` is driven to completion in a spawned task; otherwise
    /// `fut` is simply dropped.
    pub fn new(fut: F, receiver: CancellationSafeFutureReceiver<F::Output>) -> Self {
        Self {
            inner: Some(Box::pin(fut)),
            receiver,
            done: false,
        }
    }
}
impl<F> Future for CancellationSafeFuture<F>
where
    F: Future + Send,
    F::Output: Send,
{
    type Output = F::Output;

    /// Delegate to the inner future and remember completion so `Drop` won't rescue it.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        assert!(!self.done, "Polling future that already returned");

        let res = self.inner.as_mut().expect("not dropped").as_mut().poll(cx);
        if res.is_ready() {
            self.done = true;
        }
        res
    }
}
#[cfg(test)]
mod tests {
use std::{
sync::atomic::{AtomicBool, Ordering},
time::Duration,
};
use tokio::sync::Barrier;
use super::*;
#[tokio::test]
async fn test_happy_path() {
let done = Arc::new(AtomicBool::new(false));
let done_captured = Arc::clone(&done);
let receiver = Default::default();
let fut = CancellationSafeFuture::new(
async move {
done_captured.store(true, Ordering::SeqCst);
},
receiver,
);
fut.await;
assert!(done.load(Ordering::SeqCst));
}
#[tokio::test]
async fn test_cancel_future() {
let done = Arc::new(Barrier::new(2));
let done_captured = Arc::clone(&done);
let receiver = CancellationSafeFutureReceiver::default();
let fut = CancellationSafeFuture::new(
async move {
done_captured.wait().await;
},
receiver.clone(),
);
drop(fut);
tokio::time::timeout(Duration::from_secs(5), done.wait())
.await
.unwrap();
}
#[tokio::test]
async fn test_receiver_gone() {
let done = Arc::new(Barrier::new(2));
let done_captured = Arc::clone(&done);
let receiver = Default::default();
let fut = CancellationSafeFuture::new(
async move {
done_captured.wait().await;
},
receiver,
);
drop(fut);
assert_eq!(Arc::strong_count(&done), 1);
}
}

28
cache_system/src/lib.rs Normal file
View File

@ -0,0 +1,28 @@
//! Flexible and modular cache system.
#![deny(rustdoc::broken_intra_doc_links, rust_2018_idioms)]
#![warn(
    missing_copy_implementations,
    missing_docs,
    clippy::explicit_iter_loop,
    // See https://github.com/influxdata/influxdb_iox/pull/1671
    clippy::future_not_send,
    clippy::use_self,
    clippy::clone_on_ref_ptr,
    clippy::todo,
    clippy::dbg_macro,
    unused_crate_dependencies
)]
// Workaround for "unused crate" lint false positives.
#[cfg(test)]
use criterion as _;
use workspace_hack as _;
pub mod addressable_heap;
pub mod backend;
pub mod cache;
// internal helper, deliberately NOT part of the public API
mod cancellation_safe_future;
pub mod loader;
pub mod resource_consumption;
// shared helpers for the unit tests within this crate
#[cfg(test)]
mod test_util;

View File

@ -0,0 +1,496 @@
//! Batching of loader request.
use std::{
collections::HashMap,
fmt::Debug,
future::Future,
hash::Hash,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
task::Poll,
};
use async_trait::async_trait;
use futures::{
channel::oneshot::{channel, Sender},
FutureExt,
};
use observability_deps::tracing::trace;
use parking_lot::Mutex;
use crate::cancellation_safe_future::{CancellationSafeFuture, CancellationSafeFutureReceiver};
use super::Loader;
/// Batch [load](Loader::load) requests.
///
/// Requests against this loader will be [pending](std::task::Poll::Pending) until [flush](BatchLoaderFlusher::flush) is
/// called. To simplify the usage -- esp. in combination with [`Cache::get`] -- use [`BatchLoaderFlusherExt`].
///
///
/// [`Cache::get`]: crate::cache::Cache::get
#[derive(Debug)]
pub struct BatchLoader<K, Extra, V, L>
where
    K: Debug + Hash + Send + 'static,
    Extra: Debug + Send + 'static,
    V: Debug + Send + 'static,
    L: Loader<K = Vec<K>, Extra = Vec<Extra>, V = Vec<V>>,
{
    // `Arc` because the state is shared with the rescue tasks spawned by `flush`
    inner: Arc<BatchLoaderInner<K, Extra, V, L>>,
}
impl<K, Extra, V, L> BatchLoader<K, Extra, V, L>
where
    K: Debug + Hash + Send + 'static,
    Extra: Debug + Send + 'static,
    V: Debug + Send + 'static,
    L: Loader<K = Vec<K>, Extra = Vec<Extra>, V = Vec<V>>,
{
    /// Create a new batch loader wrapping a non-batched, vector-based one.
    pub fn new(inner: L) -> Self {
        let state = BatchLoaderInner {
            inner,
            pending: Default::default(),
            job_id_counter: Default::default(),
            job_handles: Default::default(),
        };
        Self {
            inner: Arc::new(state),
        }
    }
}
/// State of [`BatchLoader`].
///
/// This is an extra struct so it can be wrapped into an [`Arc`] and shared with the futures that are spawned into
/// [`CancellationSafeFuture`]
#[derive(Debug)]
struct BatchLoaderInner<K, Extra, V, L>
where
    K: Debug + Hash + Send + 'static,
    Extra: Debug + Send + 'static,
    V: Debug + Send + 'static,
    L: Loader<K = Vec<K>, Extra = Vec<Extra>, V = Vec<V>>,
{
    /// The wrapped loader that accepts whole vectors of requests.
    inner: L,
    /// Requests issued via [`Loader::load`] but not flushed yet (key, extra, result channel).
    pending: Mutex<Vec<(K, Extra, Sender<V>)>>,
    /// Source of unique IDs for in-flight flush jobs.
    job_id_counter: AtomicU64,
    /// Receivers of currently running flush jobs, keyed by job ID.
    job_handles: Mutex<HashMap<u64, CancellationSafeFutureReceiver<()>>>,
}
/// Flush interface for [`BatchLoader`].
///
/// This is a trait so you can [type-erase](https://en.wikipedia.org/wiki/Type_erasure) it by putting it into an
/// [`Arc`].
///
/// This trait is object-safe.
#[async_trait]
pub trait BatchLoaderFlusher: Debug + Send + Sync + 'static {
    /// Flush all batched requests.
    async fn flush(&self);
}
#[async_trait]
impl BatchLoaderFlusher for Arc<dyn BatchLoaderFlusher> {
    /// Delegate to the type-erased flusher.
    async fn flush(&self) {
        (**self).flush().await;
    }
}
#[async_trait]
impl<K, Extra, V, L> BatchLoaderFlusher for BatchLoader<K, Extra, V, L>
where
    K: Debug + Hash + Send + 'static,
    Extra: Debug + Send + 'static,
    V: Debug + Send + 'static,
    L: Loader<K = Vec<K>, Extra = Vec<Extra>, V = Vec<V>>,
{
    async fn flush(&self) {
        // Take ownership of ALL currently pending requests; requests arriving after this point
        // will be handled by the next flush.
        let pending: Vec<_> = {
            let mut pending = self.inner.pending.lock();
            std::mem::take(pending.as_mut())
        };
        if pending.is_empty() {
            return;
        }
        trace!(n_pending = pending.len(), "flush batch loader",);
        // Register a receiver under a fresh job ID so the load is driven to completion even when
        // this flush future is cancelled, see `CancellationSafeFuture`.
        let job_id = self.inner.job_id_counter.fetch_add(1, Ordering::SeqCst);
        let handle_recv = CancellationSafeFutureReceiver::default();
        {
            let mut job_handles = self.inner.job_handles.lock();
            job_handles.insert(job_id, handle_recv.clone());
        }
        let inner = Arc::clone(&self.inner);
        let fut = CancellationSafeFuture::new(
            async move {
                // split the pending triples into the parallel vectors the inner loader expects
                let mut keys = Vec::with_capacity(pending.len());
                let mut extras = Vec::with_capacity(pending.len());
                let mut senders = Vec::with_capacity(pending.len());
                for (k, extra, sender) in pending {
                    keys.push(k);
                    extras.push(extra);
                    senders.push(sender);
                }
                let values = inner.inner.load(keys, extras).await;
                assert_eq!(values.len(), senders.len());
                for (value, sender) in values.into_iter().zip(senders) {
                    // The receiver may be gone when the original `load` request was cancelled
                    // (this future may be running as a detached rescue task by then). The batch
                    // result is still valid for the remaining requests, so do NOT panic here.
                    sender.send(value).ok();
                }
                // de-register the job so the receiver map does not grow unboundedly
                let mut job_handles = inner.job_handles.lock();
                job_handles.remove(&job_id);
            },
            handle_recv,
        );
        fut.await;
    }
}
#[async_trait]
impl<K, Extra, V, L> Loader for BatchLoader<K, Extra, V, L>
where
    K: Debug + Hash + Send + 'static,
    Extra: Debug + Send + 'static,
    V: Debug + Send + 'static,
    L: Loader<K = Vec<K>, Extra = Vec<Extra>, V = Vec<V>>,
{
    type K = K;
    type Extra = Extra;
    type V = V;

    /// Queue the request and wait for a flush to deliver its value.
    async fn load(&self, k: Self::K, extra: Self::Extra) -> Self::V {
        let (tx, rx) = channel();

        // the lock guard is a temporary and dropped at the end of this statement,
        // i.e. BEFORE we await
        self.inner.pending.lock().push((k, extra, tx));

        rx.await.unwrap()
    }
}
/// Extension trait for [`BatchLoaderFlusher`] because the methods on this extension trait are not object safe.
#[async_trait]
pub trait BatchLoaderFlusherExt {
    /// Try to poll all given futures and automatically [flush](BatchLoaderFlusher) if any of them end up in a pending state.
    ///
    /// This guarantees that the order of the results is identical to the order of the futures.
    async fn auto_flush<F>(&self, futures: Vec<F>) -> Vec<F::Output>
    where
        F: Future + Send,
        F::Output: Send;
}
#[async_trait]
impl<B> BatchLoaderFlusherExt for B
where
    B: BatchLoaderFlusher,
{
    async fn auto_flush<F>(&self, futures: Vec<F>) -> Vec<F::Output>
    where
        F: Future + Send,
        F::Output: Send,
    {
        // keep the original index with each future so results can be placed in order
        let mut futures = futures
            .into_iter()
            .map(|f| f.boxed())
            .enumerate()
            .collect::<Vec<_>>();
        let mut output: Vec<Option<F::Output>> = (0..futures.len()).map(|_| None).collect();
        // wait loop: poll every remaining future once, flush if any are still pending, repeat
        while !futures.is_empty() {
            let mut pending = Vec::with_capacity(futures.len());
            for (idx, mut f) in futures.into_iter() {
                match futures::poll!(&mut f) {
                    Poll::Ready(res) => {
                        output[idx] = Some(res);
                    }
                    Poll::Pending => {
                        pending.push((idx, f));
                    }
                }
            }
            if !pending.is_empty() {
                self.flush().await;
                // prevent hot-looping:
                // It seems that in some cases the underlying loader is ready but the data is not available via the
                // cache driver yet. This is likely due to the signalling system within the cache driver that prevents
                // cancelation, but also allows side-loading and at the same time prevents that the same key is loaded
                // multiple times. Tokio doesn't know that this method here is basically a wait loop. So we yield back
                // to the tokio worker and to allow it to make some progress. Since flush+load take some time anyways,
                // this yield here is not overall performance critical.
                tokio::task::yield_now().await;
            }
            futures = pending;
        }
        // every future resolved at some point, so every slot is filled
        output
            .into_iter()
            .map(|o| o.expect("all futures finished"))
            .collect()
    }
}
#[cfg(test)]
mod tests {
use tokio::sync::Barrier;
use crate::{
cache::{driver::CacheDriver, Cache},
loader::test_util::TestLoader,
test_util::EnsurePendingExt,
};
use super::*;
type TestLoaderT = Arc<TestLoader<Vec<u8>, Vec<bool>, Vec<String>>>;
#[tokio::test]
async fn test_flush_empty() {
let (inner, batch) = setup();
batch.flush().await;
assert_eq!(inner.loaded(), vec![],);
}
#[tokio::test]
async fn test_flush_manual() {
let (inner, batch) = setup();
let pending_barrier_1 = Arc::new(Barrier::new(2));
let pending_barrier_1_captured = Arc::clone(&pending_barrier_1);
let batch_captured = Arc::clone(&batch);
let handle_1 = tokio::spawn(async move {
batch_captured
.load(1, true)
.ensure_pending(pending_barrier_1_captured)
.await
});
pending_barrier_1.wait().await;
let pending_barrier_2 = Arc::new(Barrier::new(2));
let pending_barrier_2_captured = Arc::clone(&pending_barrier_2);
let batch_captured = Arc::clone(&batch);
let handle_2 = tokio::spawn(async move {
batch_captured
.load(2, false)
.ensure_pending(pending_barrier_2_captured)
.await
});
pending_barrier_2.wait().await;
inner.mock_next(vec![1, 2], vec![String::from("foo"), String::from("bar")]);
batch.flush().await;
assert_eq!(inner.loaded(), vec![(vec![1, 2], vec![true, false])],);
assert_eq!(handle_1.await.unwrap(), String::from("foo"));
assert_eq!(handle_2.await.unwrap(), String::from("bar"));
}
/// Simulate the following scenario:
///
/// 1. load `1`, flush it, inner load starts processing `[1]`
/// 2. load `2`, flush it, inner load starts processing `[2]`
/// 3. inner loader returns result for `[2]`, batch loader returns that result as well
/// 4. inner loader returns result for `[1]`, batch loader returns that result as well
#[tokio::test]
async fn test_concurrent_load() {
let (inner, batch) = setup();
let load_barrier_1 = inner.block_next(vec![1], vec![String::from("foo")]);
inner.mock_next(vec![2], vec![String::from("bar")]);
// set up first load
let pending_barrier_1 = Arc::new(Barrier::new(2));
let pending_barrier_1_captured = Arc::clone(&pending_barrier_1);
let batch_captured = Arc::clone(&batch);
let handle_1 = tokio::spawn(async move {
batch_captured
.load(1, true)
.ensure_pending(pending_barrier_1_captured)
.await
});
pending_barrier_1.wait().await;
// flush first load, this is blocked by the load barrier
let pending_barrier_2 = Arc::new(Barrier::new(2));
let pending_barrier_2_captured = Arc::clone(&pending_barrier_2);
let batch_captured = Arc::clone(&batch);
let handle_2 = tokio::spawn(async move {
batch_captured
.flush()
.ensure_pending(pending_barrier_2_captured)
.await;
});
pending_barrier_2.wait().await;
// set up second load
let pending_barrier_3 = Arc::new(Barrier::new(2));
let pending_barrier_3_captured = Arc::clone(&pending_barrier_3);
let batch_captured = Arc::clone(&batch);
let handle_3 = tokio::spawn(async move {
batch_captured
.load(2, false)
.ensure_pending(pending_barrier_3_captured)
.await
});
pending_barrier_3.wait().await;
// flush 2nd load and get result
batch.flush().await;
assert_eq!(handle_3.await.unwrap(), String::from("bar"));
// flush 1st load and get result
load_barrier_1.wait().await;
handle_2.await.unwrap();
assert_eq!(handle_1.await.unwrap(), String::from("foo"));
assert_eq!(
inner.loaded(),
vec![(vec![1], vec![true]), (vec![2], vec![false])],
);
}
#[tokio::test]
async fn test_cancel_flush() {
let (inner, batch) = setup();
let load_barrier_1 = inner.block_next(vec![1], vec![String::from("foo")]);
// set up load
let pending_barrier_1 = Arc::new(Barrier::new(2));
let pending_barrier_1_captured = Arc::clone(&pending_barrier_1);
let batch_captured = Arc::clone(&batch);
let handle_1 = tokio::spawn(async move {
batch_captured
.load(1, true)
.ensure_pending(pending_barrier_1_captured)
.await
});
pending_barrier_1.wait().await;
// flush load, this is blocked by the load barrier
let pending_barrier_2 = Arc::new(Barrier::new(2));
let pending_barrier_2_captured = Arc::clone(&pending_barrier_2);
let batch_captured = Arc::clone(&batch);
let handle_2 = tokio::spawn(async move {
batch_captured
.flush()
.ensure_pending(pending_barrier_2_captured)
.await;
});
pending_barrier_2.wait().await;
// abort flush
handle_2.abort();
// flush load and get result
load_barrier_1.wait().await;
assert_eq!(handle_1.await.unwrap(), String::from("foo"));
assert_eq!(inner.loaded(), vec![(vec![1], vec![true])],);
}
#[tokio::test]
async fn test_cancel_load_and_flush() {
let (inner, batch) = setup();
let load_barrier_1 = inner.block_next(vec![1], vec![String::from("foo")]);
// set up load
let pending_barrier_1 = Arc::new(Barrier::new(2));
let pending_barrier_1_captured = Arc::clone(&pending_barrier_1);
let batch_captured = Arc::clone(&batch);
let handle_1 = tokio::spawn(async move {
batch_captured
.load(1, true)
.ensure_pending(pending_barrier_1_captured)
.await
});
pending_barrier_1.wait().await;
// flush load, this is blocked by the load barrier
let pending_barrier_2 = Arc::new(Barrier::new(2));
let pending_barrier_2_captured = Arc::clone(&pending_barrier_2);
let batch_captured = Arc::clone(&batch);
let handle_2 = tokio::spawn(async move {
batch_captured
.flush()
.ensure_pending(pending_barrier_2_captured)
.await;
});
pending_barrier_2.wait().await;
// abort load and flush
handle_1.abort();
handle_2.abort();
// unblock
load_barrier_1.wait().await;
// load was still driven to completion
assert_eq!(inner.loaded(), vec![(vec![1], vec![true])],);
}
#[tokio::test]
async fn test_auto_flush_with_loader() {
let (inner, batch) = setup();
inner.mock_next(vec![1, 2], vec![String::from("foo"), String::from("bar")]);
assert_eq!(
batch
.auto_flush(vec![batch.load(1, true), batch.load(2, false)])
.await,
vec![String::from("foo"), String::from("bar")],
);
assert_eq!(inner.loaded(), vec![(vec![1, 2], vec![true, false])],);
}
#[tokio::test]
async fn test_auto_flush_integration_with_cache_driver() {
    // End-to-end: `auto_flush` interacting with a `CacheDriver` on top of the
    // batch loader, including a cache hit on the second round.
    let (inner, batch) = setup();
    let cache = CacheDriver::new(Arc::clone(&batch), HashMap::new());
    inner.mock_next(vec![1, 2], vec![String::from("foo"), String::from("bar")]);
    inner.mock_next(vec![3], vec![String::from("baz")]);

    // first round: both keys are new, batched into one loader call
    let first_round = batch
        .auto_flush(vec![cache.get(1, true), cache.get(2, false)])
        .await;
    assert_eq!(first_round, vec![String::from("foo"), String::from("bar")]);

    // second round: key 2 is cached, so only key 3 reaches the loader
    let second_round = batch
        .auto_flush(vec![cache.get(2, true), cache.get(3, true)])
        .await;
    assert_eq!(second_round, vec![String::from("bar"), String::from("baz")]);

    assert_eq!(
        inner.loaded(),
        vec![(vec![1, 2], vec![true, false]), (vec![3], vec![true])],
    );
}
/// Build a fresh test loader plus a batch loader wrapping a clone of it.
fn setup() -> (TestLoaderT, Arc<BatchLoader<u8, bool, String, TestLoaderT>>) {
    let loader = TestLoaderT::default();
    let batched = BatchLoader::new(Arc::clone(&loader));
    (loader, Arc::new(batched))
}
}

View File

@ -0,0 +1,247 @@
//! Metrics for [`Loader`].
use std::sync::Arc;
use async_trait::async_trait;
use iox_time::TimeProvider;
use metric::{DurationHistogram, U64Counter};
use observability_deps::tracing::warn;
use parking_lot::Mutex;
use pdatastructs::filters::{bloomfilter::BloomFilter, Filter};
use super::Loader;
/// Wraps a [`Loader`] and adds metrics.
pub struct MetricsLoader<L>
where
    L: Loader,
{
    /// The wrapped loader that performs the actual load.
    inner: L,
    /// Clock used to measure the duration of each load call.
    time_provider: Arc<dyn TimeProvider>,
    /// Counts loads for keys that were (probably) never loaded before.
    metric_calls_new: U64Counter,
    /// Counts loads for keys the bloom filter claims to have seen already.
    metric_calls_probably_reloaded: U64Counter,
    /// Histogram of load durations.
    metric_duration: DurationHistogram,
    /// Approximate "already loaded" key set; may report false positives.
    seen: Mutex<BloomFilter<L::K>>,
}
impl<L> MetricsLoader<L>
where
    L: Loader,
{
    /// Create new wrapper.
    ///
    /// Registers the `cache_load_function_calls` counters and the
    /// `cache_load_function_duration` histogram under the given `name` label.
    ///
    /// # Testing
    /// If `testing` is set, the "seen" metrics will NOT be processed correctly because the underlying data structure is
    /// too expensive to create many times a second in an un-optimized debug build.
    pub fn new(
        inner: L,
        name: &'static str,
        time_provider: Arc<dyn TimeProvider>,
        metric_registry: &metric::Registry,
        testing: bool,
    ) -> Self {
        let metric_calls = metric_registry.register_metric::<U64Counter>(
            "cache_load_function_calls",
            "Count how often a cache loader was called.",
        );
        let metric_calls_new = metric_calls.recorder(&[("name", name), ("status", "new")]);
        let metric_calls_probably_reloaded =
            metric_calls.recorder(&[("name", name), ("status", "probably_reloaded")]);
        let metric_duration = metric_registry
            .register_metric::<DurationHistogram>(
                "cache_load_function_duration",
                "Time taken by cache load function calls",
            )
            .recorder(&[("name", name)]);
        let seen = if testing {
            // Degenerate 1-bit/1-hash filter: cheap to construct, but useless
            // for telling "new" apart from "probably reloaded".
            BloomFilter::with_params(1, 1)
        } else {
            // Set up bloom filter for "probably reloaded" test:
            //
            // - input size: we expect 10M elements
            // - reliability: probability of false positives should be <= 1%
            // - CPU efficiency: number of hash functions should be <= 10
            // - RAM efficiency: size should be <= 15MB
            //
            //
            // A bloom filter was chosen here because of the following properties:
            //
            // - memory bound: The storage size is bound even when the set of "probably reloaded" entries approaches
            //   infinite sizes.
            // - memory efficiency: We do not need to store the actual keys.
            // - infallible: Inserting new data into the filter never fails (in contrast to for example a CuckooFilter or
            //   QuotientFilter).
            //
            // The fact that a filter can produce false positives (i.e. it classifies an actual new entry as "probably
            // reloaded") is considered to be OK since the metric is more of an estimate and a guide for cache tuning. We
            // might want to use a more efficient (i.e. more modern) filter design at one point though.
            let seen = BloomFilter::with_properties(10_000_000, 1.0 / 100.0);
            const BOUND_HASH_FUNCTIONS: usize = 10;
            assert!(
                seen.k() <= BOUND_HASH_FUNCTIONS,
                "number of hash functions for bloom filter should be <= {} but is {}",
                BOUND_HASH_FUNCTIONS,
                seen.k(),
            );
            const BOUND_SIZE_BYTES: usize = 15_000_000;
            // `m()` is the filter size in bits; round up to whole bytes.
            let size_bytes = (seen.m() + 7) / 8;
            assert!(
                size_bytes <= BOUND_SIZE_BYTES,
                "size of bloom filter should be <= {BOUND_SIZE_BYTES} bytes but is {size_bytes} bytes",
            );
            seen
        };
        Self {
            inner,
            time_provider,
            metric_calls_new,
            metric_calls_probably_reloaded,
            metric_duration,
            seen: Mutex::new(seen),
        }
    }
}
impl<L> std::fmt::Debug for MetricsLoader<L>
where
    L: Loader,
{
    /// Opaque debug output; neither the inner loader nor the filter state is exposed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("MetricsLoader");
        builder.finish_non_exhaustive()
    }
}
#[async_trait]
impl<L> Loader for MetricsLoader<L>
where
    L: Loader,
{
    type K = L::K;
    type V = L::V;
    type Extra = L::Extra;

    /// Load via the inner loader while recording call-count and duration metrics.
    async fn load(&self, k: Self::K, extra: Self::Extra) -> Self::V {
        {
            // Scope the guard so the filter lock is released before awaiting the inner load.
            let mut seen_guard = self.seen.lock();

            // `insert` appears to report whether the key was freshly added
            // (NOTE(review): assumed from usage — confirm against pdatastructs docs).
            if seen_guard.insert(&k).expect("bloom filter cannot fail") {
                &self.metric_calls_new
            } else {
                &self.metric_calls_probably_reloaded
            }
            .inc(1);
        }

        let t_start = self.time_provider.now();
        let v = self.inner.load(k, extra).await;
        let t_end = self.time_provider.now();

        // A mock or adjusted clock may move backwards; skip recording instead of panicking.
        match t_end.checked_duration_since(t_start) {
            Some(duration) => {
                self.metric_duration.record(duration);
            }
            None => {
                warn!("Clock went backwards, not recording loader duration");
            }
        }

        v
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use iox_time::{MockProvider, Time};
    use metric::{Observation, RawReporter};

    use crate::loader::FunctionLoader;

    use super::*;

    #[tokio::test]
    async fn test_metrics() {
        let time_provider = Arc::new(MockProvider::new(Time::from_timestamp_millis(0).unwrap()));
        let metric_registry = Arc::new(metric::Registry::new());
        let time_provider_captured = Arc::clone(&time_provider);
        let d = Duration::from_secs(10);
        // the inner loader advances the mock clock by `d` on every call
        let inner_loader = FunctionLoader::new(move |x: u64, _extra: ()| {
            let time_provider_captured = Arc::clone(&time_provider_captured);
            async move {
                time_provider_captured.inc(d);
                x.to_string()
            }
        });
        let loader = MetricsLoader::new(
            inner_loader,
            "my_loader",
            time_provider,
            &metric_registry,
            false,
        );

        // before any load: both counters and the histogram must be zero
        let mut reporter = RawReporter::default();
        metric_registry.report(&mut reporter);
        for status in ["new", "probably_reloaded"] {
            assert_eq!(
                reporter
                    .metric("cache_load_function_calls")
                    .unwrap()
                    .observation(&[("name", "my_loader"), ("status", status)])
                    .unwrap(),
                &Observation::U64Counter(0)
            );
        }
        if let Observation::DurationHistogram(hist) = reporter
            .metric("cache_load_function_duration")
            .unwrap()
            .observation(&[("name", "my_loader")])
            .unwrap()
        {
            assert_eq!(hist.sample_count(), 0);
            assert_eq!(hist.total, Duration::from_secs(0));
        } else {
            panic!("Wrong observation type");
        }

        // 42 is loaded twice: the second call must count as "probably_reloaded"
        assert_eq!(loader.load(42, ()).await, String::from("42"));
        assert_eq!(loader.load(42, ()).await, String::from("42"));
        assert_eq!(loader.load(1337, ()).await, String::from("1337"));

        let mut reporter = RawReporter::default();
        metric_registry.report(&mut reporter);
        assert_eq!(
            reporter
                .metric("cache_load_function_calls")
                .unwrap()
                .observation(&[("name", "my_loader"), ("status", "new")])
                .unwrap(),
            &Observation::U64Counter(2)
        );
        assert_eq!(
            reporter
                .metric("cache_load_function_calls")
                .unwrap()
                .observation(&[("name", "my_loader"), ("status", "probably_reloaded")])
                .unwrap(),
            &Observation::U64Counter(1)
        );
        // 3 loads at 10s each driven by the mock clock
        if let Observation::DurationHistogram(hist) = reporter
            .metric("cache_load_function_duration")
            .unwrap()
            .observation(&[("name", "my_loader")])
            .unwrap()
        {
            assert_eq!(hist.sample_count(), 3);
            assert_eq!(hist.total, 3 * d);
        } else {
            panic!("Wrong observation type");
        }
    }
}

View File

@ -0,0 +1,151 @@
//! How to load new cache entries.
use async_trait::async_trait;
use std::{fmt::Debug, future::Future, hash::Hash, marker::PhantomData, sync::Arc};
pub mod batch;
pub mod metrics;
#[cfg(test)]
pub(crate) mod test_util;
/// Loader for missing [`Cache`](crate::cache::Cache) entries.
///
/// Implementations are shared across threads (`Send + Sync`) and must own
/// their data (`'static`).
#[async_trait]
pub trait Loader: std::fmt::Debug + Send + Sync + 'static {
    /// Cache key.
    type K: Debug + Hash + Send + 'static;

    /// Extra data needed when loading a missing entry. Specify `()` if not needed.
    type Extra: Debug + Send + 'static;

    /// Cache value.
    type V: Debug + Send + 'static;

    /// Load value for given key, using the extra data if needed.
    async fn load(&self, k: Self::K, extra: Self::Extra) -> Self::V;
}
#[async_trait]
impl<K, V, Extra> Loader for Box<dyn Loader<K = K, V = V, Extra = Extra>>
where
    K: Debug + Hash + Send + 'static,
    V: Debug + Send + 'static,
    Extra: Debug + Send + 'static,
{
    type K = K;
    type V = V;
    type Extra = Extra;

    /// Forward to the boxed trait object.
    async fn load(&self, k: Self::K, extra: Self::Extra) -> Self::V {
        (**self).load(k, extra).await
    }
}
#[async_trait]
impl<K, V, Extra, L> Loader for Arc<L>
where
    K: Debug + Hash + Send + 'static,
    V: Debug + Send + 'static,
    Extra: Debug + Send + 'static,
    L: Loader<K = K, V = V, Extra = Extra>,
{
    type K = K;
    type V = V;
    type Extra = Extra;

    /// Forward to the loader behind the `Arc`.
    async fn load(&self, k: Self::K, extra: Self::Extra) -> Self::V {
        (**self).load(k, extra).await
    }
}
/// Simple-to-use wrapper for async functions to act as a [`Loader`].
///
/// # Typing
/// Semantically this wrapper has only one degree of freedom: `T`, which is the async loader function. However until
/// [`fn_traits`] are stable, there is no way to extract the parameters and return value from a function via associated
/// types. So we need to add additional type parameters for the special `Fn(...) -> ...` handling.
///
/// It is likely that `T` will be a closure, e.g.:
///
/// ```
/// use cache_system::loader::FunctionLoader;
///
/// let my_loader = FunctionLoader::new(|k: u8, _extra: ()| async move {
///     format!("{k}")
/// });
/// ```
///
/// There is no way to spell out the exact type of `my_loader` in the above example, because the closure has an
/// anonymous type. If you need the type signature of [`FunctionLoader`], you have to
/// [erase the type](https://en.wikipedia.org/wiki/Type_erasure) by putting the [`FunctionLoader`] into a [`Box`],
/// e.g.:
///
/// ```
/// use cache_system::loader::{Loader, FunctionLoader};
///
/// let my_loader = FunctionLoader::new(|k: u8, _extra: ()| async move {
///     format!("{k}")
/// });
/// let my_loader: Box<dyn Loader<K = u8, V = String, Extra = ()>> = Box::new(my_loader);
/// ```
///
///
/// [`fn_traits`]: https://doc.rust-lang.org/beta/unstable-book/library-features/fn-traits.html
pub struct FunctionLoader<T, F, K, Extra>
where
    T: Fn(K, Extra) -> F + Send + 'static,
    F: Future + Send + 'static,
    K: Debug + Send + 'static,
    F::Output: Debug + Send + 'static,
    Extra: Debug + Send + 'static,
{
    /// The wrapped async function.
    loader: T,
    /// Anchors the otherwise-unused type parameters without affecting auto traits.
    _phantom: PhantomData<dyn Fn() -> (F, K, Extra) + Send + Sync + 'static>,
}
impl<T, F, K, Extra> FunctionLoader<T, F, K, Extra>
where
T: Fn(K, Extra) -> F + Send + Sync + 'static,
F: Future + Send + 'static,
K: Debug + Send + 'static,
F::Output: Debug + Send + 'static,
Extra: Debug + Send + 'static,
{
/// Create loader from function.
pub fn new(loader: T) -> Self {
Self {
loader,
_phantom: PhantomData,
}
}
}
impl<T, F, K, Extra> std::fmt::Debug for FunctionLoader<T, F, K, Extra>
where
    T: Fn(K, Extra) -> F + Send + Sync + 'static,
    F: Future + Send + 'static,
    K: Debug + Send + 'static,
    F::Output: Debug + Send + 'static,
    Extra: Debug + Send + 'static,
{
    /// Opaque debug output; the wrapped function cannot be printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("FunctionLoader");
        builder.finish_non_exhaustive()
    }
}
#[async_trait]
impl<T, F, K, Extra> Loader for FunctionLoader<T, F, K, Extra>
where
    T: Fn(K, Extra) -> F + Send + Sync + 'static,
    F: Future + Send + 'static,
    K: Debug + Hash + Send + 'static,
    F::Output: Debug + Send + 'static,
    Extra: Debug + Send + 'static,
{
    type K = K;
    type V = F::Output;
    type Extra = Extra;

    /// Invoke the wrapped function and await its future.
    async fn load(&self, k: Self::K, extra: Self::Extra) -> Self::V {
        let fut = (self.loader)(k, extra);
        fut.await
    }
}

View File

@ -0,0 +1,239 @@
use std::{collections::HashMap, fmt::Debug, hash::Hash, sync::Arc};
use async_trait::async_trait;
use parking_lot::Mutex;
use tokio::sync::{Barrier, Notify};
use super::Loader;
/// A single mocked response for [`TestLoader`], consumed FIFO per key.
#[derive(Debug)]
enum TestLoaderResponse<V> {
    /// Return `v`; if `block` is set, wait on the barrier before returning.
    Answer { v: V, block: Option<Arc<Barrier>> },
    /// Panic when this response is consumed.
    Panic,
}
/// An easy-to-mock [`Loader`].
#[derive(Debug, Default)]
pub struct TestLoader<K = u8, Extra = bool, V = String>
where
    K: Clone + Debug + Eq + Hash + Send + 'static,
    Extra: Clone + Debug + Send + 'static,
    V: Clone + Debug + Send + 'static,
{
    /// Queued responses per key, consumed in FIFO order by `load`.
    responses: Mutex<HashMap<K, Vec<TestLoaderResponse<V>>>>,
    /// When `Some`, every load waits on this notifier before answering.
    blocked: Mutex<Option<Arc<Notify>>>,
    /// Record of all `(key, extra)` pairs passed to `load`, in call order.
    loaded: Mutex<Vec<(K, Extra)>>,
}
impl<K, V, Extra> TestLoader<K, Extra, V>
where
    K: Clone + Debug + Eq + Hash + Send + 'static,
    Extra: Clone + Debug + Send + 'static,
    V: Clone + Debug + Send + 'static,
{
    /// Mock next value for given key-value pair.
    pub fn mock_next(&self, k: K, v: V) {
        self.mock_inner(k, TestLoaderResponse::Answer { v, block: None });
    }

    /// Block on next load for given key-value pair.
    ///
    /// Return a barrier that can be used to unblock the load.
    #[must_use]
    pub fn block_next(&self, k: K, v: V) -> Arc<Barrier> {
        let block = Arc::new(Barrier::new(2));
        self.mock_inner(
            k,
            TestLoaderResponse::Answer {
                v,
                block: Some(Arc::clone(&block)),
            },
        );
        block
    }

    /// Panic when loading value for `k`.
    ///
    /// If this is used together with [`block_global`](Self::block_global), the panic will occur AFTER
    /// blocking.
    pub fn panic_next(&self, k: K) {
        self.mock_inner(k, TestLoaderResponse::Panic);
    }

    /// Queue `response` for key `k` (FIFO per key).
    fn mock_inner(&self, k: K, response: TestLoaderResponse<V>) {
        let mut responses = self.responses.lock();
        responses.entry(k).or_default().push(response);
    }

    /// Block all [`load`](Loader::load) requests until [`unblock_global`](Self::unblock_global) is called.
    ///
    /// If this is used together with [`panic_next`](Self::panic_next), the panic will occur
    /// AFTER blocking.
    pub fn block_global(&self) {
        let mut blocked = self.blocked.lock();
        assert!(blocked.is_none());
        *blocked = Some(Arc::new(Notify::new()));
    }

    /// Unblock all requests.
    ///
    /// Returns number of requests that were blocked.
    pub fn unblock_global(&self) -> usize {
        let handle = self.blocked.lock().take().unwrap();
        // every blocked `load` holds one clone of the `Notify`; subtract our own handle
        let blocked_count = Arc::strong_count(&handle) - 1;
        handle.notify_waiters();
        blocked_count
    }

    /// List all keys that were loaded.
    ///
    /// Contains duplicates if keys were loaded multiple times.
    pub fn loaded(&self) -> Vec<(K, Extra)> {
        self.loaded.lock().clone()
    }
}
impl<K, Extra, V> Drop for TestLoader<K, Extra, V>
where
    K: Clone + Debug + Eq + Hash + Send + 'static,
    Extra: Clone + Debug + Send + 'static,
    V: Clone + Debug + Send + 'static,
{
    /// Verify on drop that every mocked response was consumed.
    fn drop(&mut self) {
        // Skip the check while unwinding to prevent a double panic (which would abort).
        if std::thread::panicking() {
            return;
        }
        for entries in self.responses.lock().values() {
            assert!(entries.is_empty(), "mocked response left");
        }
    }
}
#[async_trait]
impl<K, V, Extra> Loader for TestLoader<K, Extra, V>
where
    K: Clone + Debug + Eq + Hash + Send + 'static,
    Extra: Clone + Debug + Send + 'static,
    V: Clone + Debug + Send + 'static,
{
    type K = K;
    type Extra = Extra;
    type V = V;

    /// Record the call, honor global blocking, then pop and act on the next mocked response for `k`.
    async fn load(&self, k: Self::K, extra: Self::Extra) -> Self::V {
        self.loaded.lock().push((k.clone(), extra));

        // need to capture the cloned notify handle, otherwise the lock guard leaks into the
        // generator
        let maybe_block = self.blocked.lock().clone();
        if let Some(block) = maybe_block {
            block.notified().await;
        }

        // Pop the response inside a short-lived lock scope; a panic here means the
        // test forgot to mock this key (or mocked too few responses).
        let response = {
            let mut guard = self.responses.lock();
            let entries = guard.get_mut(&k).expect("entry not mocked");
            assert!(!entries.is_empty(), "no mocked response left");
            entries.remove(0)
        };

        match response {
            TestLoaderResponse::Answer { v, block } => {
                // per-response barrier blocks AFTER the response was consumed
                if let Some(block) = block {
                    block.wait().await;
                }
                v
            }
            TestLoaderResponse::Panic => {
                panic!("test")
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use futures::FutureExt;

    use super::*;

    #[tokio::test]
    #[should_panic(expected = "entry not mocked")]
    async fn test_loader_panic_entry_unknown() {
        let loader = TestLoader::<u8, (), String>::default();
        loader.load(1, ()).await;
    }

    #[tokio::test]
    #[should_panic(expected = "no mocked response left")]
    async fn test_loader_panic_no_mocked_reponse_left() {
        let loader = TestLoader::default();
        loader.mock_next(1, String::from("foo"));
        loader.load(1, ()).await;
        // the single mocked response is already consumed -> panic
        loader.load(1, ()).await;
    }

    #[test]
    #[should_panic(expected = "mocked response left")]
    fn test_loader_panic_requests_left() {
        // dropping with an unconsumed response must panic (Drop impl)
        let loader = TestLoader::<u8, (), String>::default();
        loader.mock_next(1, String::from("foo"));
    }

    #[test]
    #[should_panic(expected = "panic-by-choice")]
    fn test_loader_no_double_panic() {
        // Drop must NOT assert while already panicking, otherwise this test would abort.
        let loader = TestLoader::<u8, (), String>::default();
        loader.mock_next(1, String::from("foo"));
        panic!("panic-by-choice");
    }

    #[tokio::test]
    async fn test_loader_nonblocking_mock() {
        let loader = TestLoader::default();
        loader.mock_next(1, String::from("foo"));
        loader.mock_next(1, String::from("bar"));
        loader.mock_next(2, String::from("baz"));
        // responses for the same key are consumed in FIFO order
        assert_eq!(loader.load(1, ()).await, String::from("foo"));
        assert_eq!(loader.load(2, ()).await, String::from("baz"));
        assert_eq!(loader.load(1, ()).await, String::from("bar"));
    }

    #[tokio::test]
    async fn test_loader_blocking_mock() {
        let loader = Arc::new(TestLoader::default());
        let loader_barrier = loader.block_next(1, String::from("foo"));
        loader.mock_next(2, String::from("bar"));
        let is_blocked_barrier = Arc::new(Barrier::new(2));
        let loader_captured = Arc::clone(&loader);
        let is_blocked_barrier_captured = Arc::clone(&is_blocked_barrier);
        let handle = tokio::task::spawn(async move {
            let mut fut_load = loader_captured.load(1, ()).fuse();
            // prove the load is pending before signalling the test body
            futures::select_biased! {
                _ = fut_load => {
                    panic!("should not finish");
                }
                _ = is_blocked_barrier_captured.wait().fuse() => {}
            }
            fut_load.await
        });
        is_blocked_barrier.wait().await;
        // can still load other entries
        assert_eq!(loader.load(2, ()).await, String::from("bar"));
        // unblock load
        loader_barrier.wait().await;
        assert_eq!(handle.await.unwrap(), String::from("foo"));
    }
}

View File

@ -0,0 +1,195 @@
//! Reasoning about resource consumption of cached data.
use std::{
fmt::Debug,
marker::PhantomData,
ops::{Add, Sub},
};
/// Strongly-typed resource consumption.
///
/// Can be used to represent in-RAM memory as well as on-disc memory.
///
/// The `Add`/`Sub` bounds allow consumers to accumulate and release consumption;
/// `Into<u64>` allows reporting the value as a plain number.
pub trait Resource:
    Add<Output = Self>
    + Copy
    + Debug
    + Into<u64>
    + Ord
    + PartialOrd
    + Send
    + Sync
    + Sub<Output = Self>
    + 'static
{
    /// Create resource consumption of zero.
    fn zero() -> Self;

    /// Unit name.
    ///
    /// This must be a single lowercase word.
    fn unit() -> &'static str;
}
/// An estimator of [`Resource`] consumption for a given key-value pair.
pub trait ResourceEstimator: Debug + Send + Sync + 'static {
    /// Cache key.
    type K;

    /// Cached value.
    type V;

    /// Size that can be estimated.
    type S: Resource;

    /// Estimate size of given key-value pair.
    fn consumption(&self, k: &Self::K, v: &Self::V) -> Self::S;
}
/// A simple function-based [`ResourceEstimator`].
///
/// # Typing
/// Semantically this wrapper has only one degree of freedom: `F`, which is the estimator function. However until
/// [`fn_traits`] are stable, there is no way to extract the parameters and return value from a function via associated
/// types. So we need to add additional type parameters for the special `Fn(...) -> ...` handling.
///
/// It is likely that `F` will be a closure, e.g.:
///
/// ```
/// use cache_system::resource_consumption::{
///     FunctionEstimator,
///     test_util::TestSize,
/// };
///
/// let my_estimator = FunctionEstimator::new(|_k: &u8, v: &String| -> TestSize {
///     TestSize(std::mem::size_of::<(u8, String)>() + v.capacity())
/// });
/// ```
///
/// There is no way to spell out the exact type of `my_estimator` in the above example, because the closure has an
/// anonymous type. If you need the type signature of [`FunctionEstimator`], you have to
/// [erase the type](https://en.wikipedia.org/wiki/Type_erasure) by putting the [`FunctionEstimator`] into a [`Box`],
/// e.g.:
///
/// ```
/// use cache_system::resource_consumption::{
///     FunctionEstimator,
///     ResourceEstimator,
///     test_util::TestSize,
/// };
///
/// let my_estimator = FunctionEstimator::new(|_k: &u8, v: &String| -> TestSize {
///     TestSize(std::mem::size_of::<(u8, String)>() + v.capacity())
/// });
/// let my_estimator: Box<dyn ResourceEstimator<K = u8, V = String, S = TestSize>> = Box::new(my_estimator);
/// ```
///
///
/// [`fn_traits`]: https://doc.rust-lang.org/beta/unstable-book/library-features/fn-traits.html
pub struct FunctionEstimator<F, K, V, S>
where
    F: Fn(&K, &V) -> S + Send + Sync + 'static,
    K: 'static,
    V: 'static,
    S: Resource,
{
    /// The wrapped estimator function.
    estimator: F,
    /// Anchors the otherwise-unused type parameters without affecting auto traits.
    _phantom: PhantomData<dyn Fn() -> (K, V, S) + Send + Sync + 'static>,
}
impl<F, K, V, S> FunctionEstimator<F, K, V, S>
where
F: Fn(&K, &V) -> S + Send + Sync + 'static,
K: 'static,
V: 'static,
S: Resource,
{
/// Create new resource estimator from given function.
pub fn new(f: F) -> Self {
Self {
estimator: f,
_phantom: PhantomData,
}
}
}
impl<F, K, V, S> std::fmt::Debug for FunctionEstimator<F, K, V, S>
where
    F: Fn(&K, &V) -> S + Send + Sync + 'static,
    K: 'static,
    V: 'static,
    S: Resource,
{
    /// Opaque debug output; the wrapped function cannot be printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("FunctionEstimator");
        builder.finish_non_exhaustive()
    }
}
impl<F, K, V, S> ResourceEstimator for FunctionEstimator<F, K, V, S>
where
    F: Fn(&K, &V) -> S + Send + Sync + 'static,
    K: 'static,
    V: 'static,
    S: Resource,
{
    type K = K;
    type V = V;
    type S = S;

    /// Delegate the estimate to the wrapped function.
    fn consumption(&self, k: &Self::K, v: &Self::V) -> Self::S {
        let estimate = &self.estimator;
        estimate(k, v)
    }
}
pub mod test_util {
    //! Helpers to test resource consumption-based algorithms.
    use super::*;

    /// Simple resource type for testing.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
    pub struct TestSize(pub usize);

    impl Resource for TestSize {
        fn zero() -> Self {
            TestSize(0)
        }

        fn unit() -> &'static str {
            "bytes"
        }
    }

    impl From<TestSize> for u64 {
        fn from(s: TestSize) -> Self {
            s.0 as u64
        }
    }

    impl Add for TestSize {
        type Output = Self;

        // checked add: a test that overflows should fail loudly, not wrap
        fn add(self, rhs: Self) -> Self::Output {
            let sum = self.0.checked_add(rhs.0).expect("overflow");
            TestSize(sum)
        }
    }

    impl Sub for TestSize {
        type Output = Self;

        // checked sub: a test that underflows should fail loudly, not wrap
        fn sub(self, rhs: Self) -> Self::Output {
            let diff = self.0.checked_sub(rhs.0).expect("underflow");
            TestSize(diff)
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::resource_consumption::test_util::TestSize;

    use super::*;

    #[test]
    fn test_function_estimator() {
        // encode key in the tens digit and value in the ones digit
        let f = |k: &u8, v: &u16| TestSize(usize::from(*k) * 10 + usize::from(*v));
        let estimator = FunctionEstimator::new(f);
        assert_eq!(estimator.consumption(&3, &2), TestSize(32));
    }
}

View File

@ -0,0 +1,62 @@
use std::{future::Future, sync::Arc, time::Duration};
use async_trait::async_trait;
use futures::FutureExt;
use tokio::{sync::Barrier, task::JoinHandle};
/// Test helper: assert a future is pending before letting the test proceed.
#[async_trait]
pub trait EnsurePendingExt {
    /// Output of the wrapped future.
    type Out;

    /// Ensure that the future is pending. In the pending case, try to pass the given barrier. Afterwards await the future again.
    ///
    /// This is helpful to ensure a future is in a pending state before continuing with the test setup.
    async fn ensure_pending(self, barrier: Arc<Barrier>) -> Self::Out;
}
#[async_trait]
impl<F> EnsurePendingExt for F
where
    F: Future + Send + Unpin,
{
    type Out = F::Output;

    async fn ensure_pending(self, barrier: Arc<Barrier>) -> Self::Out {
        let mut fut = self.fuse();
        // `select_biased!` polls `fut` first: if it is already ready we panic;
        // otherwise the barrier side completes, proving `fut` was pending.
        futures::select_biased! {
            _ = fut => panic!("fut should be pending"),
            _ = barrier.wait().fuse() => (),
        }
        // now drive the (previously pending) future to completion
        fut.await
    }
}
#[async_trait]
pub trait AbortAndWaitExt {
    /// Abort handle and wait for completion.
    ///
    /// Note that this is NOT just a "wait with timeout or panic". This extension is specific to [`JoinHandle`] and will:
    ///
    /// 1. Call [`JoinHandle::abort`].
    /// 2. Await the [`JoinHandle`] with a timeout (or panic if the timeout is reached).
    /// 3. Check that the handle returned a [`JoinError`](tokio::task::JoinError) that signals that the tracked task was
    ///    indeed cancelled and didn't exit otherwise (either by finishing or by panicking).
    async fn abort_and_wait(self);
}
#[async_trait]
impl<T> AbortAndWaitExt for JoinHandle<T>
where
    T: std::fmt::Debug + Send,
{
    /// Abort the task and assert it ended in a "cancelled" [`JoinError`](tokio::task::JoinError).
    async fn abort_and_wait(self) {
        // `abort` only needs `&self` and the handle is then moved into `timeout`,
        // so `self` need not be `mut` (the former `mut self` tripped `unused_mut`).
        self.abort();

        let join_err = tokio::time::timeout(Duration::from_secs(1), self)
            .await
            .expect("no timeout")
            .expect_err("handle was aborted and therefore MUST fail");

        // the task must have been cancelled; finishing or panicking would fail here
        assert!(join_err.is_cancelled());
    }
}

31
clap_blocks/Cargo.toml Normal file
View File

@ -0,0 +1,31 @@
# Crate manifest for `clap_blocks` — presumably reusable clap-based CLI config
# blocks shared by the workspace binaries (name-derived; confirm in workspace docs).
[package]
name = "clap_blocks"
version.workspace = true
authors.workspace = true
edition.workspace = true
license.workspace = true

[dependencies]
clap = { version = "4", features = ["derive", "env"] }
futures = "0.3"
http = "0.2.9"
humantime = "2.1.0"
iox_catalog = { path = "../iox_catalog" }
metric = { path = "../metric" }
object_store = { workspace = true }
observability_deps = { path = "../observability_deps" }
snafu = "0.7"
sysinfo = "0.29.10"
trace_exporters = { path = "../trace_exporters" }
trogging = { path = "../trogging", default-features = false, features = ["clap"] }
uuid = { version = "1", features = ["v4"] }
workspace-hack = { version = "0.1", path = "../workspace-hack" }

[dev-dependencies]
tempfile = "3.8.0"
test_helpers = { path = "../test_helpers" }

# Optional object-store backends, gated per cloud provider.
[features]
azure = ["object_store/azure"] # Optional Azure Object store support
gcp = ["object_store/gcp"] # Optional GCP object store support
aws = ["object_store/aws"] # Optional AWS / S3 object store support

View File

@ -0,0 +1,162 @@
//! Catalog-DSN-related configs.
use iox_catalog::sqlite::{SqliteCatalog, SqliteConnectionOptions};
use iox_catalog::{
interface::Catalog,
mem::MemCatalog,
postgres::{PostgresCatalog, PostgresConnectionOptions},
};
use observability_deps::tracing::*;
use snafu::{ResultExt, Snafu};
use std::{sync::Arc, time::Duration};
/// Errors that can occur while resolving a catalog DSN.
#[derive(Debug, Snafu)]
#[allow(missing_docs)]
pub enum Error {
    /// The DSN scheme matched no supported catalog type.
    #[snafu(display("Unknown Catalog DSN {dsn}. Expected a string like 'postgresql://postgres@localhost:5432/postgres' or 'sqlite:///tmp/catalog.sqlite'"))]
    UnknownCatalogDsn { dsn: String },

    /// No DSN was provided at all.
    #[snafu(display("Catalog DSN not specified. Expected a string like 'postgresql://postgres@localhost:5432/postgres' or 'sqlite:///tmp/catalog.sqlite'"))]
    DsnNotSpecified {},

    /// Connecting to the catalog backend failed.
    #[snafu(display("A catalog error occurred: {}", source))]
    Catalog {
        source: iox_catalog::interface::Error,
    },
}
/// Default for `--catalog-max-connections` as a `&'static str` (clap wants a static string).
///
/// The formatted value is intentionally leaked; this runs once during CLI construction,
/// so the leak is bounded.
fn default_max_connections() -> &'static str {
    let s = PostgresConnectionOptions::DEFAULT_MAX_CONNS.to_string();
    Box::leak(Box::new(s))
}
/// Default for `--catalog-connect-timeout` as a `&'static str` (intentionally leaked, see above pattern).
fn default_connect_timeout() -> &'static str {
    let s =
        humantime::format_duration(PostgresConnectionOptions::DEFAULT_CONNECT_TIMEOUT).to_string();
    Box::leak(Box::new(s))
}
/// Default for `--catalog-idle-timeout` as a `&'static str` (intentionally leaked; built once at CLI construction).
fn default_idle_timeout() -> &'static str {
    let s = humantime::format_duration(PostgresConnectionOptions::DEFAULT_IDLE_TIMEOUT).to_string();
    Box::leak(Box::new(s))
}
/// Default for `--catalog-hotswap-poll-interval` as a `&'static str` (intentionally leaked; built once at CLI construction).
fn default_hotswap_poll_interval_timeout() -> &'static str {
    let s = humantime::format_duration(PostgresConnectionOptions::DEFAULT_HOTSWAP_POLL_INTERVAL)
        .to_string();
    Box::leak(Box::new(s))
}
/// CLI config for catalog DSN.
#[derive(Debug, Clone, Default, clap::Parser)]
pub struct CatalogDsnConfig {
    /// Catalog connection string.
    ///
    /// The dsn determines the type of catalog used.
    ///
    /// PostgreSQL: `postgresql://postgres@localhost:5432/postgres`
    ///
    /// Sqlite (a local filename /tmp/foo.sqlite): `sqlite:///tmp/foo.sqlite`
    ///
    /// Memory (ephemeral, only useful for testing): `memory`
    ///
    #[clap(long = "catalog-dsn", env = "INFLUXDB_IOX_CATALOG_DSN", action)]
    pub dsn: Option<String>,

    /// Maximum number of connections allowed to the catalog at any one time.
    #[clap(
        long = "catalog-max-connections",
        env = "INFLUXDB_IOX_CATALOG_MAX_CONNECTIONS",
        default_value = default_max_connections(),
        action,
    )]
    pub max_catalog_connections: u32,

    /// Schema name for PostgreSQL-based catalogs.
    #[clap(
        long = "catalog-postgres-schema-name",
        env = "INFLUXDB_IOX_CATALOG_POSTGRES_SCHEMA_NAME",
        default_value = PostgresConnectionOptions::DEFAULT_SCHEMA_NAME,
        action,
    )]
    pub postgres_schema_name: String,

    /// Set the amount of time to attempt connecting to the database.
    #[clap(
        long = "catalog-connect-timeout",
        env = "INFLUXDB_IOX_CATALOG_CONNECT_TIMEOUT",
        default_value = default_connect_timeout(),
        value_parser = humantime::parse_duration,
    )]
    pub connect_timeout: Duration,

    /// Set a maximum idle duration for individual connections.
    #[clap(
        long = "catalog-idle-timeout",
        env = "INFLUXDB_IOX_CATALOG_IDLE_TIMEOUT",
        default_value = default_idle_timeout(),
        value_parser = humantime::parse_duration,
    )]
    pub idle_timeout: Duration,

    /// If the DSN points to a file (i.e. starts with `dsn-file://`), this sets the interval how often the file
    /// should be polled for updates.
    ///
    /// If an update is encountered, the underlying connection pool will be hot-swapped.
    #[clap(
        long = "catalog-hotswap-poll-interval",
        env = "INFLUXDB_IOX_CATALOG_HOTSWAP_POLL_INTERVAL",
        default_value = default_hotswap_poll_interval_timeout(),
        value_parser = humantime::parse_duration,
    )]
    pub hotswap_poll_interval: Duration,
}
impl CatalogDsnConfig {
    /// Get config-dependent catalog.
    ///
    /// Dispatches on the DSN prefix:
    /// - `postgres…` or `dsn-file://` -> Postgres catalog
    /// - `memory` -> in-memory catalog
    /// - `sqlite://<path>` -> SQLite catalog
    ///
    /// # Errors
    /// Returns [`Error::DsnNotSpecified`] when no DSN is configured,
    /// [`Error::UnknownCatalogDsn`] for an unrecognized scheme, and
    /// [`Error::Catalog`] when connecting fails.
    pub async fn get_catalog(
        &self,
        app_name: &'static str,
        metrics: Arc<metric::Registry>,
    ) -> Result<Arc<dyn Catalog>, Error> {
        let Some(dsn) = self.dsn.as_ref() else {
            return Err(Error::DsnNotSpecified {});
        };

        // `starts_with("postgres")` covers both `postgres://` and `postgresql://`;
        // `dsn-file://` points at a file containing a Postgres DSN (hot-swappable).
        if dsn.starts_with("postgres") || dsn.starts_with("dsn-file://") {
            // do not log entire postgres dsn as it may contain credentials
            info!(postgres_schema_name=%self.postgres_schema_name, "Catalog: Postgres");
            let options = PostgresConnectionOptions {
                app_name: app_name.to_string(),
                schema_name: self.postgres_schema_name.clone(),
                dsn: dsn.clone(),
                max_conns: self.max_catalog_connections,
                connect_timeout: self.connect_timeout,
                idle_timeout: self.idle_timeout,
                hotswap_poll_interval: self.hotswap_poll_interval,
            };
            Ok(Arc::new(
                PostgresCatalog::connect(options, metrics)
                    .await
                    .context(CatalogSnafu)?,
            ))
        } else if dsn == "memory" {
            info!("Catalog: In-memory");
            let mem = MemCatalog::new(metrics);
            Ok(Arc::new(mem))
        } else if let Some(file_path) = dsn.strip_prefix("sqlite://") {
            info!(file_path, "Catalog: Sqlite");
            let options = SqliteConnectionOptions {
                file_path: file_path.to_string(),
            };
            Ok(Arc::new(
                SqliteCatalog::connect(options, metrics)
                    .await
                    .context(CatalogSnafu)?,
            ))
        } else {
            Err(Error::UnknownCatalogDsn {
                dsn: dsn.to_string(),
            })
        }
    }
}

View File

@ -0,0 +1,253 @@
//! CLI config for compactor-related commands
use std::num::NonZeroUsize;
use crate::{gossip::GossipConfig, memory_size::MemorySize};
use super::compactor_scheduler::CompactorSchedulerConfig;
/// CLI config for compactor
#[derive(Debug, Clone, clap::Parser)]
pub struct CompactorConfig {
/// Gossip config.
#[clap(flatten)]
pub gossip_config: GossipConfig,
/// Configuration for the compactor scheduler
#[clap(flatten)]
pub compactor_scheduler_config: CompactorSchedulerConfig,
/// Number of partitions that should be compacted in parallel.
///
/// This should usually be larger than the compaction job
/// concurrency since one partition can spawn multiple compaction
/// jobs.
#[clap(
long = "compaction-partition-concurrency",
env = "INFLUXDB_IOX_COMPACTION_PARTITION_CONCURRENCY",
default_value = "100",
action
)]
pub compaction_partition_concurrency: NonZeroUsize,
/// Number of concurrent compaction jobs scheduled to DataFusion.
///
/// This should usually be smaller than the partition concurrency
/// since one partition can spawn multiple DF compaction jobs.
#[clap(
long = "compaction-df-concurrency",
env = "INFLUXDB_IOX_COMPACTION_DF_CONCURRENCY",
default_value = "10",
action
)]
pub compaction_df_concurrency: NonZeroUsize,
/// Number of jobs PER PARTITION that move files in and out of the
/// scratchpad.
#[clap(
long = "compaction-partition-scratchpad-concurrency",
env = "INFLUXDB_IOX_COMPACTION_PARTITION_SCRATCHPAD_CONCURRENCY",
default_value = "10",
action
)]
pub compaction_partition_scratchpad_concurrency: NonZeroUsize,
/// Number of threads to use for the compactor query execution,
/// compaction and persistence.
/// If not specified, defaults to one less than the number of cores on the system
#[clap(
long = "query-exec-thread-count",
env = "INFLUXDB_IOX_QUERY_EXEC_THREAD_COUNT",
action
)]
pub query_exec_thread_count: Option<NonZeroUsize>,
/// Size of memory pool used during compaction plan execution, in
/// bytes.
///
/// If compaction plans attempt to allocate more than this many
/// bytes during execution, they will error with
/// "ResourcesExhausted".
///
/// Can be given as absolute value or in percentage of the total available memory (e.g. `10%`).
#[clap(
long = "exec-mem-pool-bytes",
env = "INFLUXDB_IOX_EXEC_MEM_POOL_BYTES",
default_value = "8589934592", // 8GB
action
)]
pub exec_mem_pool_bytes: MemorySize,
/// Desired max size of compacted parquet files.
///
/// Note this is a target desired value, rather than a guarantee.
/// 1024 * 1024 * 100 = 104,857,600
#[clap(
long = "compaction-max-desired-size-bytes",
env = "INFLUXDB_IOX_COMPACTION_MAX_DESIRED_FILE_SIZE_BYTES",
default_value = "104857600",
action
)]
pub max_desired_file_size_bytes: u64,
/// Percentage of desired max file size for "leading edge split"
/// optimization.
///
/// This setting controls the estimated output file size at which
/// the compactor will apply the "leading edge" optimization.
///
/// When compacting files together, if the output size is
/// estimated to be greater than the following quantity, the
/// "leading edge split" optimization will be applied:
///
/// percentage_max_file_size * max_desired_file_size_bytes
///
/// This value must be between (0, 100)
///
/// Default is 20
#[clap(
long = "compaction-percentage-max-file_size",
env = "INFLUXDB_IOX_COMPACTION_PERCENTAGE_MAX_FILE_SIZE",
default_value = "20",
action
)]
pub percentage_max_file_size: u16,
/// Split file percentage for "leading edge split"
///
/// To reduce the likelihood of recompacting the same data too many
/// times, the compactor uses the "leading edge split"
/// optimization for the common case where the new data written
/// into a partition also has the most recent timestamps.
///
/// When compacting multiple files together, if the compactor
/// estimates the resulting file will be large enough (see
/// `percentage_max_file_size`) it creates two output files
/// rather than one, split by time, like this:
///
/// `|-------------- older_data -----------------||---- newer_data ----|`
///
/// In the common case, the file containing `older_data` is less
/// likely to overlap with new data written in.
///
/// This setting controls what percentage of data is placed into
/// the `older_data` portion.
///
/// Increasing this value increases the average size of compacted
/// files after the first round of compaction. However, doing so
/// also increase the likelihood that late arriving data will
/// overlap with larger existing files, necessitating additional
/// compaction rounds.
///
/// This value must be between (0, 100)
#[clap(
long = "compaction-split-percentage",
env = "INFLUXDB_IOX_COMPACTION_SPLIT_PERCENTAGE",
default_value = "80",
action
)]
pub split_percentage: u16,
/// Maximum duration of the per-partition compaction task in seconds.
#[clap(
long = "compaction-partition-timeout-secs",
env = "INFLUXDB_IOX_COMPACTION_PARTITION_TIMEOUT_SECS",
default_value = "1800",
action
)]
pub partition_timeout_secs: u64,
/// Shadow mode.
///
/// This will NOT write / commit any output to the object store or catalog.
///
/// This is mostly useful for debugging.
#[clap(
long = "compaction-shadow-mode",
env = "INFLUXDB_IOX_COMPACTION_SHADOW_MODE",
action
)]
pub shadow_mode: bool,
/// Enable scratchpad.
///
/// This allows disabling the scratchpad in production.
///
/// Disabling this is useful for testing performance and memory consequences of the scratchpad.
#[clap(
long = "compaction-enable-scratchpad",
env = "INFLUXDB_IOX_COMPACTION_ENABLE_SCRATCHPAD",
default_value = "true",
action
)]
pub enable_scratchpad: bool,
/// Maximum number of files that the compactor will try and
/// compact in a single plan.
///
/// The higher this setting is the fewer compactor plans are run
/// and thus fewer resources over time are consumed by the
/// compactor. Increasing this setting also increases the peak
/// memory used for each compaction plan, and thus if it is set
/// too high, the compactor plans may exceed available memory.
#[clap(
long = "compaction-max-num-files-per-plan",
env = "INFLUXDB_IOX_COMPACTION_MAX_NUM_FILES_PER_PLAN",
default_value = "20",
action
)]
pub max_num_files_per_plan: usize,
/// Minimum number of L1 files to compact to L2.
///
/// If there are more than this many L1 (by definition non
/// overlapping) files in a partition, the compactor will compact
/// them together into one or more larger L2 files.
///
/// Setting this value higher in general results in fewer overall
/// resources spent on compaction but more files per partition (and
/// thus less optimal compression and query performance).
#[clap(
long = "compaction-min-num-l1-files-to-compact",
env = "INFLUXDB_IOX_COMPACTION_MIN_NUM_L1_FILES_TO_COMPACT",
default_value = "10",
action
)]
pub min_num_l1_files_to_compact: usize,
/// Only process all discovered partitions once.
///
/// By default the compactor will continuously loop over all
/// partitions looking for work. Setting this option results in
/// exiting the loop after the one iteration.
#[clap(
long = "compaction-process-once",
env = "INFLUXDB_IOX_COMPACTION_PROCESS_ONCE",
action
)]
pub process_once: bool,
/// Maximum number of columns in a table of a partition that
/// will be able to considered to get compacted
///
/// If a table has more than this many columns, the compactor will
/// not compact it, to avoid large memory use.
#[clap(
long = "compaction-max-num-columns-per-table",
env = "INFLUXDB_IOX_COMPACTION_MAX_NUM_COLUMNS_PER_TABLE",
default_value = "10000",
action
)]
pub max_num_columns_per_table: usize,
/// Limit the number of partition fetch queries to at most the specified
/// number of queries per second.
///
/// Queries are smoothed over the full second.
#[clap(
long = "max-partition-fetch-queries-per-second",
env = "INFLUXDB_IOX_MAX_PARTITION_FETCH_QUERIES_PER_SECOND",
action
)]
pub max_partition_fetch_queries_per_second: Option<usize>,
}

View File

@ -0,0 +1,159 @@
//! Compactor-Scheduler-related configs.
/// Compaction Scheduler type.
///
/// Selected via the `--compactor-scheduler` flag (possible values: `local`,
/// `remote`); defaults to [`Self::Local`].
#[derive(Debug, Default, Clone, Copy, PartialEq, clap::ValueEnum)]
pub enum CompactorSchedulerType {
    /// Perform scheduling decisions locally.
    #[default]
    Local,
    /// Perform scheduling decisions remotely.
    Remote,
}
/// Shard config used by the local compactor scheduler.
///
/// Sharding splits the partition space across multiple compactor instances;
/// it is disabled unless `shard_count` is set.
#[derive(Debug, Clone, Default, clap::Parser)]
pub struct ShardConfigForLocalScheduler {
    /// Number of shards.
    ///
    /// If this is set then the shard ID MUST also be set. If both are not provided, sharding is disabled.
    /// (shard ID can be provided by the host name)
    #[clap(
        long = "compaction-shard-count",
        env = "INFLUXDB_IOX_COMPACTION_SHARD_COUNT",
        action
    )]
    pub shard_count: Option<usize>,

    /// Shard ID.
    ///
    /// Starts at 0, must be smaller than the number of shards.
    ///
    /// If this is set then the shard count MUST also be set. If both are not provided, sharding is disabled.
    // `requires` is one-directional on purpose: a shard ID needs a count to be
    // meaningful, but a count alone is valid because the ID may be derived
    // from the hostname (see `hostname` below).
    #[clap(
        long = "compaction-shard-id",
        env = "INFLUXDB_IOX_COMPACTION_SHARD_ID",
        requires("shard_count"),
        action
    )]
    pub shard_id: Option<usize>,

    /// Host Name
    ///
    /// comprised of leading text (e.g. 'iox-shared-compactor-'), ending with shard_id (e.g. '0').
    /// When shard_count is specified, but shard_id is not specified, the id is extracted from hostname.
    #[clap(env = "HOSTNAME")]
    pub hostname: Option<String>,
}
/// CLI config for partitions_source used by the scheduler.
///
/// Controls which partitions the (local) scheduler offers to the compactor.
#[derive(Debug, Clone, Default, clap::Parser)]
pub struct PartitionSourceConfigForLocalScheduler {
    /// The compactor will only consider compacting partitions that
    /// have new Parquet files created within this many minutes.
    //
    // Historically this flag used underscores, unlike every other
    // `compaction-*` flag which uses dashes. The underscore spelling stays
    // the primary name for backwards compatibility; the consistent dashed
    // spelling is additionally accepted via a (hidden) alias.
    #[clap(
        long = "compaction_partition_minute_threshold",
        alias = "compaction-partition-minute-threshold",
        env = "INFLUXDB_IOX_COMPACTION_PARTITION_MINUTE_THRESHOLD",
        default_value = "10",
        action
    )]
    pub compaction_partition_minute_threshold: u64,

    /// Filter partitions to the given set of IDs.
    ///
    /// This is mostly useful for debugging.
    #[clap(
        long = "compaction-partition-filter",
        env = "INFLUXDB_IOX_COMPACTION_PARTITION_FILTER",
        action
    )]
    pub partition_filter: Option<Vec<i64>>,

    /// Compact all partitions found in the catalog, no matter if/when
    /// they received writes.
    #[clap(
        long = "compaction-process-all-partitions",
        env = "INFLUXDB_IOX_COMPACTION_PROCESS_ALL_PARTITIONS",
        default_value = "false",
        action
    )]
    pub process_all_partitions: bool,

    /// Ignores "partition marked w/ error and shall be skipped" entries in the catalog.
    ///
    /// This is mostly useful for debugging.
    #[clap(
        long = "compaction-ignore-partition-skip-marker",
        env = "INFLUXDB_IOX_COMPACTION_IGNORE_PARTITION_SKIP_MARKER",
        action
    )]
    pub ignore_partition_skip_marker: bool,
}
/// CLI config for compactor scheduler.
#[derive(Debug, Clone, Default, clap::Parser)]
pub struct CompactorSchedulerConfig {
    /// Scheduler type to use.
    #[clap(
        value_enum,
        long = "compactor-scheduler",
        env = "INFLUXDB_IOX_COMPACTION_SCHEDULER",
        default_value = "local",
        action
    )]
    pub compactor_scheduler_type: CompactorSchedulerType,

    /// Partition source config used by the local scheduler.
    // `flatten` merges the nested struct's flags into this command's flag set.
    #[clap(flatten)]
    pub partition_source_config: PartitionSourceConfigForLocalScheduler,

    /// Shard config used by the local scheduler.
    #[clap(flatten)]
    pub shard_config: ShardConfigForLocalScheduler,
}
#[cfg(test)]
mod tests {
    use super::*;
    use clap::Parser;
    use test_helpers::assert_contains;

    /// Helper: run the clap parser over the given CLI tokens.
    fn parse(args: &[&str]) -> Result<CompactorSchedulerConfig, clap::Error> {
        CompactorSchedulerConfig::try_parse_from(args)
    }

    #[test]
    fn default_compactor_scheduler_type_is_local() {
        let parsed = parse(&["my_binary"]).unwrap();
        assert_eq!(
            parsed.compactor_scheduler_type,
            CompactorSchedulerType::Local
        );
    }

    #[test]
    fn can_specify_local() {
        let parsed = parse(&["my_binary", "--compactor-scheduler", "local"]).unwrap();
        assert_eq!(
            parsed.compactor_scheduler_type,
            CompactorSchedulerType::Local
        );
    }

    #[test]
    fn any_other_scheduler_type_string_is_invalid() {
        let msg = parse(&["my_binary", "--compactor-scheduler", "hello"])
            .unwrap_err()
            .to_string();
        // clap should both reject the value and list the valid alternatives.
        assert_contains!(
            &msg,
            "invalid value 'hello' for '--compactor-scheduler <COMPACTOR_SCHEDULER_TYPE>'"
        );
        assert_contains!(&msg, "[possible values: local, remote]");
    }
}

View File

@ -0,0 +1,85 @@
//! Garbage Collector configuration
use clap::Parser;
use humantime::parse_duration;
use std::{fmt::Debug, time::Duration};
/// Configuration specific to the object store garbage collector
#[derive(Debug, Clone, Parser, Copy)]
pub struct GarbageCollectorConfig {
    /// If this flag is specified, don't delete the files in object storage. Only print the files
    /// that would be deleted if this flag wasn't specified.
    #[clap(long, env = "INFLUXDB_IOX_GC_DRY_RUN")]
    pub dry_run: bool,

    /// Items in the object store that are older than this duration that are not referenced in the
    /// catalog will be deleted.
    /// Parsed with <https://docs.rs/humantime/latest/humantime/fn.parse_duration.html>
    ///
    /// If not specified, defaults to 14 days ago.
    #[clap(
        long,
        default_value = "14d",
        value_parser = parse_duration,
        env = "INFLUXDB_IOX_GC_OBJECTSTORE_CUTOFF"
    )]
    pub objectstore_cutoff: Duration,

    /// Number of concurrent object store deletion tasks
    #[clap(
        long,
        default_value_t = 5,
        env = "INFLUXDB_IOX_GC_OBJECTSTORE_CONCURRENT_DELETES"
    )]
    pub objectstore_concurrent_deletes: usize,

    /// Number of minutes to sleep between iterations of the objectstore list loop.
    /// This is the sleep between entirely fresh list operations.
    /// Defaults to 30 minutes.
    #[clap(
        long,
        default_value_t = 30,
        env = "INFLUXDB_IOX_GC_OBJECTSTORE_SLEEP_INTERVAL_MINUTES"
    )]
    pub objectstore_sleep_interval_minutes: u64,

    /// Number of milliseconds to sleep between listing consecutive chunks of objectstore files.
    /// Object store listing is processed in batches; this is the sleep between batches.
    /// Defaults to 1000 milliseconds.
    #[clap(
        long,
        default_value_t = 1000,
        env = "INFLUXDB_IOX_GC_OBJECTSTORE_SLEEP_INTERVAL_BATCH_MILLISECONDS"
    )]
    pub objectstore_sleep_interval_batch_milliseconds: u64,

    /// Parquet file rows in the catalog flagged for deletion before this duration will be deleted.
    /// Parsed with <https://docs.rs/humantime/latest/humantime/fn.parse_duration.html>
    ///
    /// If not specified, defaults to 14 days ago.
    #[clap(
        long,
        default_value = "14d",
        value_parser = parse_duration,
        env = "INFLUXDB_IOX_GC_PARQUETFILE_CUTOFF"
    )]
    pub parquetfile_cutoff: Duration,

    /// Number of minutes to sleep between iterations of the parquet file deletion loop.
    /// Defaults to 30 minutes.
    #[clap(
        long,
        default_value_t = 30,
        env = "INFLUXDB_IOX_GC_PARQUETFILE_SLEEP_INTERVAL_MINUTES"
    )]
    pub parquetfile_sleep_interval_minutes: u64,

    /// Number of minutes to sleep between iterations of the retention code.
    /// Defaults to 35 minutes to reduce incidence of it running at the same time as the parquet
    /// file deleter.
    #[clap(
        long,
        default_value_t = 35,
        env = "INFLUXDB_IOX_GC_RETENTION_SLEEP_INTERVAL_MINUTES"
    )]
    pub retention_sleep_interval_minutes: u64,
}

49
clap_blocks/src/gossip.rs Normal file
View File

@ -0,0 +1,49 @@
//! CLI config for cluster gossip communication.
use crate::socket_addr::SocketAddr;
/// Configuration parameters for the cluster gossip communication mechanism.
///
/// The two fields are mutually required (`requires` both ways): specifying
/// either the seed list or the bind address without the other is a CLI error.
#[derive(Debug, Clone, clap::Parser)]
#[allow(missing_copy_implementations)]
pub struct GossipConfig {
    /// A comma-delimited set of seed gossip peer addresses.
    ///
    /// Example: "10.0.0.1:4242,10.0.0.2:4242"
    ///
    /// These seeds will be used to discover all other peers that talk to the
    /// same seeds. Typically all nodes in the cluster should use the same set
    /// of seeds.
    #[clap(
        long = "gossip-seed-list",
        env = "INFLUXDB_IOX_GOSSIP_SEED_LIST",
        required = false,
        num_args=1..,
        value_delimiter = ',',
        requires = "gossip_bind_address", // Field name, not flag
    )]
    pub seed_list: Vec<String>,

    /// The UDP socket address IOx will use for gossip communication between
    /// peers.
    ///
    /// Example: "0.0.0.0:4242"
    ///
    /// If not provided, the gossip sub-system is disabled.
    #[clap(
        long = "gossip-bind-address",
        env = "INFLUXDB_IOX_GOSSIP_BIND_ADDR",
        requires = "seed_list", // Field name, not flag
        action
    )]
    pub gossip_bind_address: Option<SocketAddr>,
}
impl GossipConfig {
    /// Initialise the gossip config to be disabled.
    ///
    /// "Disabled" means no seed peers and no bind address; per the field docs,
    /// the gossip sub-system only starts when a bind address is provided.
    pub fn disabled() -> Self {
        Self {
            seed_list: Vec::new(),
            gossip_bind_address: None,
        }
    }
}

View File

@ -0,0 +1,88 @@
//! CLI config for the ingester using the RPC write path
use std::{num::NonZeroUsize, path::PathBuf};
use crate::gossip::GossipConfig;
/// CLI config for the ingester using the RPC write path
#[derive(Debug, Clone, clap::Parser)]
#[allow(missing_copy_implementations)]
pub struct IngesterConfig {
    /// Gossip config.
    // Flattened: the gossip flags appear directly on the ingester command.
    #[clap(flatten)]
    pub gossip_config: GossipConfig,

    /// Where this ingester instance should store its write-ahead log files. Each ingester instance
    /// must have its own directory.
    #[clap(long = "wal-directory", env = "INFLUXDB_IOX_WAL_DIRECTORY", action)]
    pub wal_directory: PathBuf,

    /// Specify the maximum allowed incoming RPC write message size sent by the
    /// Router.
    #[clap(
        long = "rpc-write-max-incoming-bytes",
        env = "INFLUXDB_IOX_RPC_WRITE_MAX_INCOMING_BYTES",
        default_value = "104857600", // 100MiB
    )]
    pub rpc_write_max_incoming_bytes: usize,

    /// The number of seconds between WAL file rotations.
    #[clap(
        long = "wal-rotation-period-seconds",
        env = "INFLUXDB_IOX_WAL_ROTATION_PERIOD_SECONDS",
        default_value = "300",
        action
    )]
    pub wal_rotation_period_seconds: u64,

    /// Sets how many queries the ingester will handle simultaneously before
    /// rejecting further incoming requests.
    #[clap(
        long = "concurrent-query-limit",
        env = "INFLUXDB_IOX_CONCURRENT_QUERY_LIMIT",
        default_value = "20",
        action
    )]
    pub concurrent_query_limit: usize,

    /// The maximum number of persist tasks that can run simultaneously.
    #[clap(
        long = "persist-max-parallelism",
        env = "INFLUXDB_IOX_PERSIST_MAX_PARALLELISM",
        default_value = "5",
        action
    )]
    pub persist_max_parallelism: usize,

    /// The maximum number of persist tasks that can be queued at any one time.
    ///
    /// Once this limit is reached, ingest is blocked until the persist backlog
    /// is reduced.
    #[clap(
        long = "persist-queue-depth",
        env = "INFLUXDB_IOX_PERSIST_QUEUE_DEPTH",
        default_value = "250",
        action
    )]
    pub persist_queue_depth: usize,

    /// The limit at which a partition's estimated persistence cost causes it to
    /// be queued for persistence.
    #[clap(
        long = "persist-hot-partition-cost",
        env = "INFLUXDB_IOX_PERSIST_HOT_PARTITION_COST",
        default_value = "20000000", // 20,000,000
        action
    )]
    pub persist_hot_partition_cost: usize,

    /// Limit the number of partitions that may be buffered in a single
    /// namespace (across all tables) at any one time.
    ///
    /// This limit is disabled by default.
    #[clap(
        long = "max-partitions-per-namespace",
        env = "INFLUXDB_IOX_MAX_PARTITIONS_PER_NAMESPACE"
    )]
    pub max_partitions_per_namespace: Option<NonZeroUsize>,
}

View File

@ -0,0 +1,308 @@
//! Shared configuration and tests for accepting ingester addresses as arguments.
use http::uri::{InvalidUri, InvalidUriParts, Uri};
use snafu::Snafu;
use std::{fmt::Display, str::FromStr};
/// An address to an ingester's gRPC API. Created by using `IngesterAddress::from_str`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IngesterAddress {
    /// Normalized URI: always carries an explicit port, and `http` is assumed
    /// when the input had no scheme (see the `FromStr` impl).
    uri: Uri,
}
/// Why a specified ingester address might be invalid
#[allow(missing_docs)]
#[derive(Snafu, Debug)]
pub enum Error {
    /// The string could not be parsed as a URI at all.
    #[snafu(context(false))]
    Invalid { source: InvalidUri },

    /// The URI parsed but contained no port; ports are mandatory for
    /// ingester addresses.
    #[snafu(display("Port is required; no port found in `{value}`"))]
    MissingPort { value: String },

    /// Re-assembling a URI from modified parts failed.
    // NOTE(review): no constructor for this variant is visible in this file's
    // code — presumably used by other helpers; confirm before removing.
    #[snafu(context(false))]
    InvalidParts { source: InvalidUriParts },
}
impl FromStr for IngesterAddress {
    type Err = Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Reject anything that is not a URI at all, then anything without an
        // explicit port — ports are mandatory for ingester addresses.
        let parsed = Uri::from_str(s)?;
        if parsed.port().is_none() {
            return MissingPortSnafu { value: s }.fail();
        }

        // Scheme-less inputs (e.g. "127.0.0.1:8080") default to HTTP by
        // re-parsing with an "http://" prefix.
        let uri = match parsed.scheme() {
            Some(_) => parsed,
            None => Uri::from_str(&format!("http://{s}"))?,
        };

        Ok(Self { uri })
    }
}
impl Display for IngesterAddress {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Delegate to the inner URI's `Display`; note `Uri` renders an
        // explicit path, e.g. "http://example.com:1234/" (see tests).
        Display::fmt(&self.uri, f)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use clap::{error::ErrorKind, Parser};
    use std::env;
    use test_helpers::{assert_contains, assert_error};

    /// Applications such as the router MUST have valid ingester addresses.
    #[derive(Debug, Clone, clap::Parser)]
    struct RouterConfig {
        #[clap(
            long = "ingester-addresses",
            env = "TEST_INFLUXDB_IOX_INGESTER_ADDRESSES",
            required = true,
            num_args=1..,
            value_delimiter = ','
        )]
        pub ingester_addresses: Vec<IngesterAddress>,
    }

    #[test]
    fn error_if_not_specified_when_required() {
        assert_error!(
            RouterConfig::try_parse_from(["my_binary"]),
            ref e if e.kind() == ErrorKind::MissingRequiredArgument
        );
    }

    /// Applications such as the querier might not have any ingester addresses, but if they have
    /// any, they should be valid.
    #[derive(Debug, Clone, clap::Parser)]
    struct QuerierConfig {
        #[clap(
            long = "ingester-addresses",
            env = "TEST_INFLUXDB_IOX_INGESTER_ADDRESSES",
            required = false,
            num_args=0..,
            value_delimiter = ','
        )]
        pub ingester_addresses: Vec<IngesterAddress>,
    }

    #[test]
    fn empty_if_not_specified_when_optional() {
        assert!(QuerierConfig::try_parse_from(["my_binary"])
            .unwrap()
            .ingester_addresses
            .is_empty());
    }

    /// Assert that `args` parse into `expected` addresses for BOTH the
    /// required (router) and optional (querier) config flavors.
    fn both_types_valid(args: &[&'static str], expected: &[&'static str]) {
        let router = RouterConfig::try_parse_from(args).unwrap();
        let actual: Vec<_> = router
            .ingester_addresses
            .iter()
            .map(ToString::to_string)
            .collect();
        assert_eq!(actual, expected);

        let querier = QuerierConfig::try_parse_from(args).unwrap();
        let actual: Vec<_> = querier
            .ingester_addresses
            .iter()
            .map(ToString::to_string)
            .collect();
        assert_eq!(actual, expected);
    }

    /// Assert that `args` fail to parse, with the given message, for BOTH
    /// config flavors.
    fn both_types_error(args: &[&'static str], expected_error_message: &'static str) {
        assert_contains!(
            RouterConfig::try_parse_from(args).unwrap_err().to_string(),
            expected_error_message
        );
        assert_contains!(
            QuerierConfig::try_parse_from(args).unwrap_err().to_string(),
            expected_error_message
        );
    }

    #[test]
    fn accepts_one() {
        let args = [
            "my_binary",
            "--ingester-addresses",
            "http://example.com:1234",
        ];
        // Note the trailing `/`: `Uri`'s Display renders an explicit path.
        let expected = ["http://example.com:1234/"];
        both_types_valid(&args, &expected);
    }

    #[test]
    fn accepts_two() {
        let args = [
            "my_binary",
            "--ingester-addresses",
            "http://example.com:1234,http://example.com:5678",
        ];
        let expected = ["http://example.com:1234/", "http://example.com:5678/"];
        both_types_valid(&args, &expected);
    }

    #[test]
    fn rejects_any_invalid_uri() {
        let args = [
            "my_binary",
            "--ingester-addresses",
            "http://example.com:1234,", // note the trailing comma; empty URIs are invalid
        ];
        let expected = "error: invalid value '' for '--ingester-addresses";
        both_types_error(&args, expected);
    }

    #[test]
    fn rejects_uri_without_port() {
        let args = [
            "my_binary",
            "--ingester-addresses",
            "example.com,http://example.com:1234",
        ];
        let expected = "Port is required; no port found in `example.com`";
        both_types_error(&args, expected);
    }

    #[test]
    fn no_scheme_assumes_http() {
        let args = [
            "my_binary",
            "--ingester-addresses",
            "http://example.com:1234,somescheme://0.0.0.0:1000,127.0.0.1:8080",
        ];
        let expected = [
            "http://example.com:1234/",
            "somescheme://0.0.0.0:1000/",
            "http://127.0.0.1:8080/",
        ];
        both_types_valid(&args, &expected);
    }

    #[test]
    fn specifying_flag_multiple_times_works() {
        let args = [
            "my_binary",
            "--ingester-addresses",
            "http://example.com:1234",
            "--ingester-addresses",
            "somescheme://0.0.0.0:1000",
            "--ingester-addresses",
            "127.0.0.1:8080",
        ];
        let expected = [
            "http://example.com:1234/",
            "somescheme://0.0.0.0:1000/",
            "http://127.0.0.1:8080/",
        ];
        both_types_valid(&args, &expected);
    }

    #[test]
    fn specifying_flag_multiple_times_and_using_commas_works() {
        let args = [
            "my_binary",
            "--ingester-addresses",
            "http://example.com:1234",
            "--ingester-addresses",
            "somescheme://0.0.0.0:1000,127.0.0.1:8080",
        ];
        let expected = [
            "http://example.com:1234/",
            "somescheme://0.0.0.0:1000/",
            "http://127.0.0.1:8080/",
        ];
        both_types_valid(&args, &expected);
    }

    /// Use an environment variable name not shared with any other config to avoid conflicts when
    /// setting the var in tests.
    /// Applications such as the router MUST have valid ingester addresses.
    #[derive(Debug, Clone, clap::Parser)]
    struct EnvRouterConfig {
        #[clap(
            long = "ingester-addresses",
            env = "NO_CONFLICT_ROUTER_TEST_INFLUXDB_IOX_INGESTER_ADDRESSES",
            required = true,
            num_args=1..,
            value_delimiter = ','
        )]
        pub ingester_addresses: Vec<IngesterAddress>,
    }

    #[test]
    fn required_and_specified_via_environment_variable() {
        env::set_var(
            "NO_CONFLICT_ROUTER_TEST_INFLUXDB_IOX_INGESTER_ADDRESSES",
            "http://example.com:1234,somescheme://0.0.0.0:1000,127.0.0.1:8080",
        );
        let args = ["my_binary"];
        let expected = [
            "http://example.com:1234/",
            "somescheme://0.0.0.0:1000/",
            "http://127.0.0.1:8080/",
        ];

        let router = EnvRouterConfig::try_parse_from(args).unwrap();
        let actual: Vec<_> = router
            .ingester_addresses
            .iter()
            .map(ToString::to_string)
            .collect();
        assert_eq!(actual, expected);
    }

    /// Use an environment variable name not shared with any other config to avoid conflicts when
    /// setting the var in tests.
    /// Applications such as the querier might not have any ingester addresses, but if they have
    /// any, they should be valid.
    #[derive(Debug, Clone, clap::Parser)]
    struct EnvQuerierConfig {
        #[clap(
            long = "ingester-addresses",
            env = "NO_CONFLICT_QUERIER_TEST_INFLUXDB_IOX_INGESTER_ADDRESSES",
            required = false,
            num_args=0..,
            value_delimiter = ','
        )]
        pub ingester_addresses: Vec<IngesterAddress>,
    }

    #[test]
    fn optional_and_specified_via_environment_variable() {
        env::set_var(
            "NO_CONFLICT_QUERIER_TEST_INFLUXDB_IOX_INGESTER_ADDRESSES",
            "http://example.com:1234,somescheme://0.0.0.0:1000,127.0.0.1:8080",
        );
        let args = ["my_binary"];
        let expected = [
            "http://example.com:1234/",
            "somescheme://0.0.0.0:1000/",
            "http://127.0.0.1:8080/",
        ];

        let querier = EnvQuerierConfig::try_parse_from(args).unwrap();
        let actual: Vec<_> = querier
            .ingester_addresses
            .iter()
            .map(ToString::to_string)
            .collect();
        assert_eq!(actual, expected);
    }
}

34
clap_blocks/src/lib.rs Normal file
View File

@ -0,0 +1,34 @@
//! Building blocks for [`clap`]-driven configs.
//!
//! They can easily be re-used using `#[clap(flatten)]`.
#![deny(rustdoc::broken_intra_doc_links, rust_2018_idioms)]
#![warn(
    missing_copy_implementations,
    missing_docs,
    clippy::explicit_iter_loop,
    // See https://github.com/influxdata/influxdb_iox/pull/1671
    clippy::future_not_send,
    clippy::use_self,
    clippy::clone_on_ref_ptr,
    clippy::todo,
    clippy::dbg_macro,
    unused_crate_dependencies
)]

// Workaround for "unused crate" lint false positives.
use workspace_hack as _;

// One module per configurable sub-system; each exposes a struct meant to be
// embedded into a binary's top-level CLI config.
pub mod catalog_dsn;
pub mod compactor;
pub mod compactor_scheduler;
pub mod garbage_collector;
pub mod gossip;
pub mod ingester;
pub mod ingester_address;
pub mod memory_size;
pub mod object_store;
pub mod querier;
pub mod router;
pub mod run_config;
pub mod single_tenant;
pub mod socket_addr;

View File

@ -0,0 +1,108 @@
//! Helper types to express memory size.
use std::{str::FromStr, sync::OnceLock};
use sysinfo::{RefreshKind, System, SystemExt};
/// Memory size.
///
/// # Parsing
/// This can be parsed from strings in one of the following formats:
///
/// - **absolute:** just use a non-negative number to specify the absolute bytes, e.g. `1024`
/// - **relative:** use percentage between 0 and 100 (both inclusive) to specify a relative amount of the totally
///   available memory size, e.g. `50%`
// Newtype over a plain byte count; `Debug`/`Display` are implemented manually
// below to print the bare number.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct MemorySize(usize);
impl MemorySize {
    /// Number of bytes.
    ///
    /// Made `const` so the value can be used in constant contexts; this is a
    /// backward-compatible strengthening of the plain getter.
    pub const fn bytes(&self) -> usize {
        self.0
    }
}
impl std::fmt::Debug for MemorySize {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Debug output deliberately matches `Display`: the raw byte count.
        std::fmt::Display::fmt(&self.0, f)
    }
}
impl std::fmt::Display for MemorySize {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Render the plain byte count with no unit suffix.
        f.write_fmt(format_args!("{}", self.0))
    }
}
impl FromStr for MemorySize {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Two accepted forms: "<bytes>" (absolute) or "<percent>%" (relative
        // to the total system memory).
        match s.strip_suffix('%') {
            Some(s) => {
                // Relative: parse and range-check the percentage.
                let percentage = u64::from_str(s).map_err(|e| e.to_string())?;
                if percentage > 100 {
                    return Err(format!(
                        "relative memory size must be in [0, 100] but is {percentage}"
                    ));
                }
                // Inspect total system memory only once per process; cached in
                // the `TOTAL_MEM_BYTES` global below.
                let total = *TOTAL_MEM_BYTES.get_or_init(|| {
                    let sys = System::new_with_specifics(RefreshKind::new().with_memory());
                    sys.total_memory() as usize
                });
                // Round to the nearest byte.
                let bytes = (percentage as f64 / 100f64 * total as f64).round() as usize;
                Ok(Self(bytes))
            }
            None => {
                // Absolute: a plain byte count.
                let bytes = usize::from_str(s).map_err(|e| e.to_string())?;
                Ok(Self(bytes))
            }
        }
    }
}
/// Totally available memory size in bytes.
///
/// Keep this in a global state so that we only need to inspect the system once during IOx startup.
// Lazily initialised on first use of a percentage-based `MemorySize`; see
// `<MemorySize as FromStr>::from_str` above.
static TOTAL_MEM_BYTES: OnceLock<usize> = OnceLock::new();
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `input` parses successfully to exactly `want` bytes.
    #[track_caller]
    fn assert_ok(input: &'static str, want: usize) {
        let size = MemorySize::from_str(input).unwrap();
        assert_eq!(size.bytes(), want);
    }

    /// Assert that `input` parses to a strictly positive byte count (used for
    /// percentages, whose absolute value is machine-dependent).
    #[track_caller]
    fn assert_gt_zero(input: &'static str) {
        let size = MemorySize::from_str(input).unwrap();
        assert!(size.bytes() > 0);
    }

    /// Assert that parsing `input` fails with exactly `want` as the message.
    #[track_caller]
    fn assert_err(input: &'static str, want: &'static str) {
        assert_eq!(MemorySize::from_str(input).unwrap_err(), want);
    }

    #[test]
    fn test_parse() {
        // Absolute byte counts.
        assert_ok("0", 0);
        assert_ok("1", 1);
        assert_ok("1024", 1024);

        // Relative sizes.
        assert_ok("0%", 0);
        assert_gt_zero("50%");

        // Invalid inputs.
        assert_err("-1", "invalid digit found in string");
        assert_err("foo", "invalid digit found in string");
        assert_err("-1%", "invalid digit found in string");
        assert_err("101%", "relative memory size must be in [0, 100] but is 101");
    }
}

View File

@ -0,0 +1,617 @@
//! CLI handling for object store config (via CLI arguments and environment variables).
use futures::TryStreamExt;
use object_store::memory::InMemory;
use object_store::path::Path;
use object_store::throttle::ThrottledStore;
use object_store::{throttle::ThrottleConfig, DynObjectStore};
use observability_deps::tracing::{info, warn};
use snafu::{ResultExt, Snafu};
use std::sync::Arc;
use std::{fs, num::NonZeroUsize, path::PathBuf, time::Duration};
use uuid::Uuid;
#[derive(Debug, Snafu)]
#[allow(missing_docs)]
pub enum ParseError {
    /// Creating the local database directory (`--data-dir`) failed.
    #[snafu(display("Unable to create database directory {:?}: {}", path, source))]
    CreatingDatabaseDirectory {
        path: PathBuf,
        source: std::io::Error,
    },

    /// Constructing the local-filesystem object store failed.
    #[snafu(display("Unable to create local store {:?}: {}", path, source))]
    CreateLocalFileSystem {
        path: PathBuf,
        source: object_store::Error,
    },

    /// A cloud object store was selected but a required CLI option is absent.
    #[snafu(display(
        "Specified {:?} for the object store, required configuration missing for {}",
        object_store,
        missing
    ))]
    MissingObjectStoreConfig {
        object_store: ObjectStoreType,
        missing: String,
    },

    // Creating a new S3 object store can fail if the region is *specified* but
    // not *parseable* as a rusoto `Region`. The other object store constructors
    // don't return `Result`.
    #[snafu(display("Error configuring Amazon S3: {}", source))]
    InvalidS3Config { source: object_store::Error },

    #[snafu(display("Error configuring GCS: {}", source))]
    InvalidGCSConfig { source: object_store::Error },

    #[snafu(display("Error configuring Microsoft Azure: {}", source))]
    InvalidAzureConfig { source: object_store::Error },
}
/// The AWS region to use for Amazon S3 based object storage if none is
/// specified.
// Used as the `default_value` of the `--aws-default-region` flag below.
pub const FALLBACK_AWS_REGION: &str = "us-east-1";
/// CLI config for object stores.
#[derive(Debug, Clone, clap::Parser)]
pub struct ObjectStoreConfig {
/// Which object storage to use. If not specified, defaults to memory.
///
/// Possible values (case insensitive):
///
/// * memory (default): Effectively no object persistence.
/// * memorythrottled: Like `memory` but with latency and throughput that somewhat resamble a cloud
/// object store. Useful for testing and benchmarking.
/// * file: Stores objects in the local filesystem. Must also set `--data-dir`.
/// * s3: Amazon S3. Must also set `--bucket`, `--aws-access-key-id`, `--aws-secret-access-key`, and
/// possibly `--aws-default-region`.
/// * google: Google Cloud Storage. Must also set `--bucket` and `--google-service-account`.
/// * azure: Microsoft Azure blob storage. Must also set `--bucket`, `--azure-storage-account`,
/// and `--azure-storage-access-key`.
#[clap(
value_enum,
long = "object-store",
env = "INFLUXDB_IOX_OBJECT_STORE",
ignore_case = true,
action
)]
pub object_store: Option<ObjectStoreType>,
/// Name of the bucket to use for the object store. Must also set
/// `--object-store` to a cloud object storage to have any effect.
///
/// If using Google Cloud Storage for the object store, this item as well
/// as `--google-service-account` must be set.
///
/// If using S3 for the object store, must set this item as well
/// as `--aws-access-key-id` and `--aws-secret-access-key`. Can also set
/// `--aws-default-region` if not using the fallback region.
///
/// If using Azure for the object store, set this item to the name of a
/// container you've created in the associated storage account, under
/// Blob Service > Containers. Must also set `--azure-storage-account` and
/// `--azure-storage-access-key`.
#[clap(long = "bucket", env = "INFLUXDB_IOX_BUCKET", action)]
pub bucket: Option<String>,
/// The location InfluxDB IOx will use to store files locally.
#[clap(long = "data-dir", env = "INFLUXDB_IOX_DB_DIR", action)]
pub database_directory: Option<PathBuf>,
/// When using Amazon S3 as the object store, set this to an access key that
/// has permission to read from and write to the specified S3 bucket.
///
/// Must also set `--object-store=s3`, `--bucket`, and
/// `--aws-secret-access-key`. Can also set `--aws-default-region` if not
/// using the fallback region.
///
/// Prefer the environment variable over the command line flag in shared
/// environments.
#[clap(long = "aws-access-key-id", env = "AWS_ACCESS_KEY_ID", action)]
pub aws_access_key_id: Option<String>,
/// When using Amazon S3 as the object store, set this to the secret access
/// key that goes with the specified access key ID.
///
/// Must also set `--object-store=s3`, `--bucket`, `--aws-access-key-id`.
/// Can also set `--aws-default-region` if not using the fallback region.
///
/// Prefer the environment variable over the command line flag in shared
/// environments.
#[clap(long = "aws-secret-access-key", env = "AWS_SECRET_ACCESS_KEY", action)]
pub aws_secret_access_key: Option<String>,
/// When using Amazon S3 as the object store, set this to the region
/// that goes with the specified bucket if different from the fallback
/// value.
///
/// Must also set `--object-store=s3`, `--bucket`, `--aws-access-key-id`,
/// and `--aws-secret-access-key`.
#[clap(
long = "aws-default-region",
env = "AWS_DEFAULT_REGION",
default_value = FALLBACK_AWS_REGION,
action,
)]
pub aws_default_region: String,
/// When using Amazon S3 compatibility storage service, set this to the
/// endpoint.
///
/// Must also set `--object-store=s3`, `--bucket`. Can also set `--aws-default-region`
/// if not using the fallback region.
///
/// Prefer the environment variable over the command line flag in shared
/// environments.
#[clap(long = "aws-endpoint", env = "AWS_ENDPOINT", action)]
pub aws_endpoint: Option<String>,
/// When using Amazon S3 as an object store, set this to the session token. This is handy when using a federated
/// login / SSO and you fetch credentials via the UI.
///
/// Is it assumed that the session is valid as long as the IOx server is running.
///
/// Prefer the environment variable over the command line flag in shared
/// environments.
#[clap(long = "aws-session-token", env = "AWS_SESSION_TOKEN", action)]
pub aws_session_token: Option<String>,
/// Allow unencrypted HTTP connection to AWS.
#[clap(long = "aws-allow-http", env = "AWS_ALLOW_HTTP", action)]
pub aws_allow_http: bool,
/// When using Google Cloud Storage as the object store, set this to the
/// path to the JSON file that contains the Google credentials.
///
/// Must also set `--object-store=google` and `--bucket`.
#[clap(
long = "google-service-account",
env = "GOOGLE_SERVICE_ACCOUNT",
action
)]
pub google_service_account: Option<String>,
/// When using Microsoft Azure as the object store, set this to the
/// name you see when going to All Services > Storage accounts > `[name]`.
///
/// Must also set `--object-store=azure`, `--bucket`, and
/// `--azure-storage-access-key`.
#[clap(long = "azure-storage-account", env = "AZURE_STORAGE_ACCOUNT", action)]
pub azure_storage_account: Option<String>,
/// When using Microsoft Azure as the object store, set this to one of the
/// Key values in the Storage account's Settings > Access keys.
///
/// Must also set `--object-store=azure`, `--bucket`, and
/// `--azure-storage-account`.
///
/// Prefer the environment variable over the command line flag in shared
/// environments.
#[clap(
long = "azure-storage-access-key",
env = "AZURE_STORAGE_ACCESS_KEY",
action
)]
pub azure_storage_access_key: Option<String>,
/// When using a network-based object store, limit the number of connection to this value.
#[clap(
long = "object-store-connection-limit",
env = "OBJECT_STORE_CONNECTION_LIMIT",
default_value = "16",
action
)]
pub object_store_connection_limit: NonZeroUsize,
}
impl ObjectStoreConfig {
    /// Create a new instance for all-in-one mode, only allowing some arguments.
    ///
    /// Only the data directory is honoured: when one is given the store type
    /// is forced to [`ObjectStoreType::File`], otherwise it is left unset
    /// (which resolves to the in-memory store). Every cloud-provider field is
    /// left at its default.
    pub fn new(database_directory: Option<PathBuf>) -> Self {
        // Log which backing store the all-in-one server will use.
        if let Some(dir) = &database_directory {
            info!("Object store: File-based in `{}`", dir.display());
        } else {
            info!("Object store: In-memory");
        }

        let object_store = database_directory.as_ref().map(|_| ObjectStoreType::File);

        Self {
            aws_access_key_id: Default::default(),
            aws_allow_http: Default::default(),
            aws_default_region: Default::default(),
            aws_endpoint: Default::default(),
            aws_secret_access_key: Default::default(),
            aws_session_token: Default::default(),
            azure_storage_access_key: Default::default(),
            azure_storage_account: Default::default(),
            bucket: Default::default(),
            database_directory,
            google_service_account: Default::default(),
            object_store,
            // Same default as the `--object-store-connection-limit` CLI flag.
            object_store_connection_limit: NonZeroUsize::new(16).expect("16 is non-zero"),
        }
    }
}
/// Object-store type.
///
/// Selects which backend [`make_object_store`] instantiates; chosen on the
/// command line via `--object-store`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, clap::ValueEnum)]
pub enum ObjectStoreType {
    /// In-memory.
    Memory,

    /// In-memory with additional throttling applied for testing
    MemoryThrottled,

    /// Filesystem.
    File,

    /// AWS S3.
    S3,

    /// GCS.
    Google,

    /// Azure object store.
    Azure,
}
#[cfg(feature = "gcp")]
fn new_gcs(config: &ObjectStoreConfig) -> Result<Arc<DynObjectStore>, ParseError> {
    use object_store::gcp::GoogleCloudStorageBuilder;
    use object_store::limit::LimitStore;

    info!(bucket=?config.bucket, object_store_type="GCS", "Object Store");

    // Layer each optional setting onto the builder only when it was supplied.
    let builder = GoogleCloudStorageBuilder::new();
    let builder = match &config.bucket {
        Some(bucket) => builder.with_bucket_name(bucket),
        None => builder,
    };
    let builder = match &config.google_service_account {
        Some(account) => builder.with_service_account_path(account),
        None => builder,
    };

    // Cap concurrent requests to the configured connection limit.
    let store = builder.build().context(InvalidGCSConfigSnafu)?;
    Ok(Arc::new(LimitStore::new(
        store,
        config.object_store_connection_limit.get(),
    )))
}
/// Stub used when the crate is built without the `gcp` feature.
///
/// Selecting a GCS store in that build is a configuration error, so this
/// aborts loudly instead of returning a recoverable error.
#[cfg(not(feature = "gcp"))]
fn new_gcs(_: &ObjectStoreConfig) -> Result<Arc<DynObjectStore>, ParseError> {
    panic!("GCS support not enabled, recompile with the gcp feature enabled")
}
#[cfg(feature = "aws")]
fn new_s3(config: &ObjectStoreConfig) -> Result<Arc<DynObjectStore>, ParseError> {
    use object_store::aws::AmazonS3Builder;
    use object_store::limit::LimitStore;

    info!(bucket=?config.bucket, endpoint=?config.aws_endpoint, object_store_type="S3", "Object Store");

    // Settings that are always applied.
    let builder = AmazonS3Builder::new()
        .with_allow_http(config.aws_allow_http)
        .with_region(&config.aws_default_region)
        .with_imdsv1_fallback();

    // Optional settings, each layered on only when supplied.
    let builder = match &config.bucket {
        Some(bucket) => builder.with_bucket_name(bucket),
        None => builder,
    };
    let builder = match &config.aws_access_key_id {
        Some(key_id) => builder.with_access_key_id(key_id),
        None => builder,
    };
    let builder = match &config.aws_session_token {
        Some(token) => builder.with_token(token),
        None => builder,
    };
    let builder = match &config.aws_secret_access_key {
        Some(secret) => builder.with_secret_access_key(secret),
        None => builder,
    };
    let builder = match &config.aws_endpoint {
        Some(endpoint) => builder.with_endpoint(endpoint),
        None => builder,
    };

    // Cap concurrent requests to the configured connection limit.
    let store = builder.build().context(InvalidS3ConfigSnafu)?;
    Ok(Arc::new(LimitStore::new(
        store,
        config.object_store_connection_limit.get(),
    )))
}
/// Stub used when the crate is built without the `aws` feature.
///
/// Selecting an S3 store in that build is a configuration error, so this
/// aborts loudly instead of returning a recoverable error.
#[cfg(not(feature = "aws"))]
fn new_s3(_: &ObjectStoreConfig) -> Result<Arc<DynObjectStore>, ParseError> {
    panic!("S3 support not enabled, recompile with the aws feature enabled")
}
#[cfg(feature = "azure")]
fn new_azure(config: &ObjectStoreConfig) -> Result<Arc<DynObjectStore>, ParseError> {
    use object_store::azure::MicrosoftAzureBuilder;
    use object_store::limit::LimitStore;

    info!(bucket=?config.bucket, account=?config.azure_storage_account,
          object_store_type="Azure", "Object Store");

    // Layer each optional setting onto the builder only when it was supplied.
    let builder = MicrosoftAzureBuilder::new();
    let builder = match &config.bucket {
        Some(bucket) => builder.with_container_name(bucket),
        None => builder,
    };
    let builder = match &config.azure_storage_account {
        Some(account) => builder.with_account(account),
        None => builder,
    };
    let builder = match &config.azure_storage_access_key {
        Some(key) => builder.with_access_key(key),
        None => builder,
    };

    // Cap concurrent requests to the configured connection limit.
    let store = builder.build().context(InvalidAzureConfigSnafu)?;
    Ok(Arc::new(LimitStore::new(
        store,
        config.object_store_connection_limit.get(),
    )))
}
/// Stub used when the crate is built without the `azure` feature.
///
/// Selecting an Azure store in that build is a configuration error, so this
/// aborts loudly instead of returning a recoverable error.
#[cfg(not(feature = "azure"))]
fn new_azure(_: &ObjectStoreConfig) -> Result<Arc<DynObjectStore>, ParseError> {
    panic!("Azure blob storage support not enabled, recompile with the azure feature enabled")
}
/// Create config-dependent object store.
///
/// Dispatches on `config.object_store`; no selection defaults to the
/// in-memory store.
///
/// # Errors
///
/// Fails when the `file` store is selected without `--data-dir`, when the
/// data directory cannot be created or opened, or when a backend-specific
/// builder rejects its configuration.
pub fn make_object_store(config: &ObjectStoreConfig) -> Result<Arc<DynObjectStore>, ParseError> {
    // Warn (but do not fail) when --data-dir was supplied but will be ignored
    // because a non-file store was chosen.
    if let Some(data_dir) = &config.database_directory {
        if !matches!(&config.object_store, Some(ObjectStoreType::File)) {
            warn!(?data_dir, object_store_type=?config.object_store,
                  "--data-dir / `INFLUXDB_IOX_DB_DIR` ignored. It only affects 'file' object stores");
        }
    }

    match &config.object_store {
        Some(ObjectStoreType::Memory) | None => {
            info!(object_store_type = "Memory", "Object Store");
            Ok(Arc::new(InMemory::new()))
        }
        Some(ObjectStoreType::MemoryThrottled) => {
            // In-memory store wrapped with artificial latency/bandwidth
            // limits, for testing.
            let config = ThrottleConfig {
                // for every call: assume a 100ms latency
                wait_delete_per_call: Duration::from_millis(100),
                wait_get_per_call: Duration::from_millis(100),
                wait_list_per_call: Duration::from_millis(100),
                wait_list_with_delimiter_per_call: Duration::from_millis(100),
                wait_put_per_call: Duration::from_millis(100),

                // for list operations: assume we need 1 call per 1k entries at 100ms
                wait_list_per_entry: Duration::from_millis(100) / 1_000,
                wait_list_with_delimiter_per_entry: Duration::from_millis(100) / 1_000,

                // for upload/download: assume 1GByte/s
                wait_get_per_byte: Duration::from_secs(1) / 1_000_000_000,
            };

            info!(?config, object_store_type = "Memory", "Object Store");
            Ok(Arc::new(ThrottledStore::new(InMemory::new(), config)))
        }

        Some(ObjectStoreType::Google) => new_gcs(config),
        Some(ObjectStoreType::S3) => new_s3(config),
        Some(ObjectStoreType::Azure) => new_azure(config),

        Some(ObjectStoreType::File) => match config.database_directory.as_ref() {
            Some(db_dir) => {
                info!(?db_dir, object_store_type = "Directory", "Object Store");
                // Ensure the directory exists before handing it to the store.
                fs::create_dir_all(db_dir)
                    .context(CreatingDatabaseDirectorySnafu { path: db_dir })?;

                let store = object_store::local::LocalFileSystem::new_with_prefix(db_dir)
                    .context(CreateLocalFileSystemSnafu { path: db_dir })?;
                Ok(Arc::new(store))
            }
            // `file` selected but no directory given: a hard error.
            None => MissingObjectStoreConfigSnafu {
                object_store: ObjectStoreType::File,
                missing: "data-dir",
            }
            .fail(),
        },
    }
}
/// Errors raised by [`check_object_store`].
#[derive(Debug, Snafu)]
#[allow(missing_docs)]
pub enum CheckError {
    #[snafu(display("Cannot read from object store: {}", source))]
    CannotReadObjectStore { source: object_store::Error },
}
/// Check if the object store is properly configured and accepts reads.
///
/// Lists a random (almost certainly nonexistent) prefix and polls the
/// resulting stream once, so hardly any actual data is transferred.
///
/// Note: This does NOT test if the object store is writable!
pub async fn check_object_store(object_store: &DynObjectStore) -> Result<(), CheckError> {
    // Use some prefix that will very likely end in an empty result, so we don't pull too much actual data here.
    let uuid = Uuid::new_v4().to_string();
    let prefix = Path::from_iter([uuid]);

    // create stream (this might fail if the store is not readable)
    let mut stream = object_store
        .list(Some(&prefix))
        .await
        .context(CannotReadObjectStoreSnafu)?;

    // ... but sometimes it fails only if we use the resulting stream, so try that once
    stream
        .try_next()
        .await
        .context(CannotReadObjectStoreSnafu)?;

    // store seems to be readable
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use clap::Parser;
    use std::env;
    use tempfile::TempDir;

    #[test]
    fn default_object_store_is_memory() {
        // No --object-store flag at all falls back to the in-memory store.
        let config = ObjectStoreConfig::try_parse_from(["server"]).unwrap();

        let object_store = make_object_store(&config).unwrap();

        assert_eq!(&object_store.to_string(), "InMemory")
    }

    #[test]
    fn explicitly_set_object_store_to_memory() {
        let config =
            ObjectStoreConfig::try_parse_from(["server", "--object-store", "memory"]).unwrap();

        let object_store = make_object_store(&config).unwrap();

        assert_eq!(&object_store.to_string(), "InMemory")
    }

    #[test]
    #[cfg(feature = "aws")]
    fn valid_s3_config() {
        // Credentials are fake; only builder plumbing is exercised here, no
        // network calls are made.
        let config = ObjectStoreConfig::try_parse_from([
            "server",
            "--object-store",
            "s3",
            "--bucket",
            "mybucket",
            "--aws-access-key-id",
            "NotARealAWSAccessKey",
            "--aws-secret-access-key",
            "NotARealAWSSecretAccessKey",
        ])
        .unwrap();

        let object_store = make_object_store(&config).unwrap();

        assert_eq!(&object_store.to_string(), "AmazonS3(mybucket)")
    }

    #[test]
    #[cfg(feature = "aws")]
    fn s3_config_missing_params() {
        let mut config =
            ObjectStoreConfig::try_parse_from(["server", "--object-store", "s3"]).unwrap();

        // clean out eventual leaks via env variables
        config.bucket = None;

        let err = make_object_store(&config).unwrap_err().to_string();

        assert_eq!(
            err,
            "Specified S3 for the object store, required configuration missing for bucket"
        );
    }

    #[test]
    #[cfg(feature = "gcp")]
    fn valid_google_config() {
        // The service-account path is fake; the builder stores it unchecked.
        let config = ObjectStoreConfig::try_parse_from([
            "server",
            "--object-store",
            "google",
            "--bucket",
            "mybucket",
            "--google-service-account",
            "~/Not/A/Real/path.json",
        ])
        .unwrap();

        let object_store = make_object_store(&config).unwrap();

        assert_eq!(&object_store.to_string(), "GoogleCloudStorage(mybucket)")
    }

    #[test]
    #[cfg(feature = "gcp")]
    fn google_config_missing_params() {
        let mut config =
            ObjectStoreConfig::try_parse_from(["server", "--object-store", "google"]).unwrap();

        // clean out eventual leaks via env variables
        config.bucket = None;

        let err = make_object_store(&config).unwrap_err().to_string();

        assert_eq!(
            err,
            "Specified Google for the object store, required configuration missing for \
            bucket, google-service-account"
        );
    }

    #[test]
    #[cfg(feature = "azure")]
    fn valid_azure_config() {
        let config = ObjectStoreConfig::try_parse_from([
            "server",
            "--object-store",
            "azure",
            "--bucket",
            "mybucket",
            "--azure-storage-account",
            "NotARealStorageAccount",
            "--azure-storage-access-key",
            "NotARealKey",
        ])
        .unwrap();

        let object_store = make_object_store(&config).unwrap();

        assert_eq!(&object_store.to_string(), "MicrosoftAzure(mybucket)")
    }

    #[test]
    #[cfg(feature = "azure")]
    fn azure_config_missing_params() {
        let mut config =
            ObjectStoreConfig::try_parse_from(["server", "--object-store", "azure"]).unwrap();

        // clean out eventual leaks via env variables
        config.bucket = None;

        let err = make_object_store(&config).unwrap_err().to_string();

        assert_eq!(
            err,
            "Specified Azure for the object store, required configuration missing for \
            bucket, azure-storage-account, azure-storage-access-key"
        );
    }

    #[test]
    fn valid_file_config() {
        let root = TempDir::new().unwrap();
        let root_path = root.path().to_str().unwrap();

        let config = ObjectStoreConfig::try_parse_from([
            "server",
            "--object-store",
            "file",
            "--data-dir",
            root_path,
        ])
        .unwrap();

        let object_store = make_object_store(&config).unwrap().to_string();

        // LocalFileSystem's Display includes the path, so only check the prefix.
        assert!(
            object_store.starts_with("LocalFileSystem"),
            "{}",
            object_store
        )
    }

    #[test]
    fn file_config_missing_params() {
        // this test tests for failure to configure the object store because of data-dir configuration missing
        // if the INFLUXDB_IOX_DB_DIR env variable is set, the test fails because the configuration is
        // actually present.
        // NOTE(review): mutating process env here can race other tests that
        // read the same variable if tests run in parallel.
        env::remove_var("INFLUXDB_IOX_DB_DIR");
        let config =
            ObjectStoreConfig::try_parse_from(["server", "--object-store", "file"]).unwrap();

        let err = make_object_store(&config).unwrap_err().to_string();

        assert_eq!(
            err,
            "Specified File for the object store, required configuration missing for \
            data-dir"
        );
    }
}

256
clap_blocks/src/querier.rs Normal file
View File

@ -0,0 +1,256 @@
//! Querier-related configs.
use crate::{
ingester_address::IngesterAddress,
memory_size::MemorySize,
single_tenant::{CONFIG_AUTHZ_ENV_NAME, CONFIG_AUTHZ_FLAG},
};
use std::{collections::HashMap, num::NonZeroUsize};
/// CLI config for querier configuration
#[derive(Debug, Clone, PartialEq, Eq, clap::Parser)]
pub struct QuerierConfig {
    /// Addr for connection to authz
    #[clap(long = CONFIG_AUTHZ_FLAG, env = CONFIG_AUTHZ_ENV_NAME)]
    pub authz_address: Option<String>,

    /// The number of threads to use for queries.
    ///
    /// If not specified, defaults to the number of cores on the system
    #[clap(
        long = "num-query-threads",
        env = "INFLUXDB_IOX_NUM_QUERY_THREADS",
        action
    )]
    pub num_query_threads: Option<NonZeroUsize>,

    /// Size of memory pool used during query exec, in bytes.
    ///
    /// If queries attempt to allocate more than this many bytes
    /// during execution, they will error with "ResourcesExhausted".
    ///
    /// Can be given as absolute value or in percentage of the total available memory (e.g. `10%`).
    #[clap(
        long = "exec-mem-pool-bytes",
        env = "INFLUXDB_IOX_EXEC_MEM_POOL_BYTES",
        default_value = "8589934592", // 8GB
        action
    )]
    pub exec_mem_pool_bytes: MemorySize,

    /// gRPC address for the router to talk with the ingesters. For
    /// example:
    ///
    /// "http://127.0.0.1:8083"
    ///
    /// or
    ///
    /// "http://10.10.10.1:8083,http://10.10.10.2:8083"
    ///
    /// for multiple addresses.
    // `required = false` + `num_args = 0..`: a querier may run without any
    // ingester (see `test_default` below).
    #[clap(
        long = "ingester-addresses",
        env = "INFLUXDB_IOX_INGESTER_ADDRESSES",
        required = false,
        num_args = 0..,
        value_delimiter = ','
    )]
    pub ingester_addresses: Vec<IngesterAddress>,

    /// Size of the RAM cache used to store catalog metadata information in bytes.
    ///
    /// Can be given as absolute value or in percentage of the total available memory (e.g. `10%`).
    #[clap(
        long = "ram-pool-metadata-bytes",
        env = "INFLUXDB_IOX_RAM_POOL_METADATA_BYTES",
        default_value = "134217728", // 128MB
        action
    )]
    pub ram_pool_metadata_bytes: MemorySize,

    /// Size of the RAM cache used to store data in bytes.
    ///
    /// Can be given as absolute value or in percentage of the total available memory (e.g. `10%`).
    #[clap(
        long = "ram-pool-data-bytes",
        env = "INFLUXDB_IOX_RAM_POOL_DATA_BYTES",
        default_value = "1073741824", // 1GB
        action
    )]
    pub ram_pool_data_bytes: MemorySize,

    /// Limit the number of concurrent queries.
    #[clap(
        long = "max-concurrent-queries",
        env = "INFLUXDB_IOX_MAX_CONCURRENT_QUERIES",
        default_value = "10",
        action
    )]
    pub max_concurrent_queries: usize,

    /// After how many ingester query errors should the querier enter circuit breaker mode?
    ///
    /// The querier normally contacts the ingester for any unpersisted data during query planning.
    /// However, when the ingester can not be contacted for some reason, the querier will begin
    /// returning results that do not include unpersisted data and enter "circuit breaker mode"
    /// to avoid continually retrying the failing connection on subsequent queries.
    ///
    /// If circuits are open, the querier will NOT contact the ingester and no unpersisted data
    /// will be presented to the user.
    ///
    /// Circuits will switch to "half open" after some jittered timeout and the querier will try to
    /// use the ingester in question again. If this succeeds, we are back to normal, otherwise it
    /// will back off exponentially before trying again (and again ...).
    ///
    /// In a production environment the `ingester_circuit_state` metric should be monitored.
    #[clap(
        long = "ingester-circuit-breaker-threshold",
        env = "INFLUXDB_IOX_INGESTER_CIRCUIT_BREAKER_THRESHOLD",
        default_value = "10",
        action
    )]
    pub ingester_circuit_breaker_threshold: u64,

    /// DataFusion config.
    // Comma-separated `KEY:VALUE` pairs; see `parse_datafusion_config`.
    #[clap(
        long = "datafusion-config",
        env = "INFLUXDB_IOX_DATAFUSION_CONFIG",
        default_value = "",
        value_parser = parse_datafusion_config,
        action
    )]
    pub datafusion_config: HashMap<String, String>,
}
/// Parse a comma-separated list of `KEY:VALUE` pairs into a map.
///
/// Whitespace around the whole string, around each pair, and around each key
/// and value is trimmed. Only the first `:` in a pair separates key from
/// value, so values may themselves contain `:` (e.g. `x:y:z` maps `x` to
/// `y:z`). An empty or all-whitespace input yields an empty map.
///
/// # Errors
///
/// Returns an error if a pair lacks a `:` separator or if the same key is
/// passed more than once.
fn parse_datafusion_config(
    s: &str,
) -> Result<HashMap<String, String>, Box<dyn std::error::Error + Send + Sync + 'static>> {
    let s = s.trim();
    if s.is_empty() {
        return Ok(HashMap::new());
    }

    let mut out = HashMap::new();
    for part in s.split(',') {
        // `split_once` splits on the first `:` only, keeping any further
        // colons inside the value.
        let Some((key, value)) = part.trim().split_once(':') else {
            return Err(
                format!("Invalid key value pair - expected 'KEY:VALUE' got '{s}'").into(),
            );
        };
        let existed = out
            .insert(key.trim().to_owned(), value.trim().to_owned())
            .is_some();
        if existed {
            return Err(format!("key '{key}' passed multiple times").into());
        }
    }

    Ok(out)
}
#[cfg(test)]
mod tests {
    use super::*;
    use clap::Parser;
    use test_helpers::assert_contains;

    #[test]
    fn test_default() {
        // Parsing with no flags yields the documented defaults.
        let actual = QuerierConfig::try_parse_from(["my_binary"]).unwrap();

        assert_eq!(actual.num_query_threads, None);
        assert!(actual.ingester_addresses.is_empty());
        assert!(actual.datafusion_config.is_empty());
    }

    #[test]
    fn test_num_threads() {
        let actual =
            QuerierConfig::try_parse_from(["my_binary", "--num-query-threads", "42"]).unwrap();

        assert_eq!(
            actual.num_query_threads,
            Some(NonZeroUsize::new(42).unwrap())
        );
    }

    #[test]
    fn test_ingester_addresses_list() {
        // A comma-separated list is split into individual addresses; note the
        // trailing `/` added when the parsed URIs are re-rendered.
        let querier = QuerierConfig::try_parse_from([
            "my_binary",
            "--ingester-addresses",
            "http://ingester-0:8082,http://ingester-1:8082",
        ])
        .unwrap();

        let actual: Vec<_> = querier
            .ingester_addresses
            .iter()
            .map(ToString::to_string)
            .collect();

        let expected = vec!["http://ingester-0:8082/", "http://ingester-1:8082/"];

        assert_eq!(actual, expected);
    }

    #[test]
    fn bad_ingester_addresses_list() {
        let actual = QuerierConfig::try_parse_from([
            "my_binary",
            "--ingester-addresses",
            "\\ingester-0:8082",
        ])
        .unwrap_err()
        .to_string();

        assert_contains!(
            actual,
            "error: \
            invalid value '\\ingester-0:8082' \
            for '--ingester-addresses [<INGESTER_ADDRESSES>...]': \
            Invalid: invalid uri character"
        );
    }

    #[test]
    fn test_datafusion_config() {
        // Whitespace is trimmed and only the first `:` splits key from value.
        let actual = QuerierConfig::try_parse_from([
            "my_binary",
            "--datafusion-config= foo : bar , x:y:z ",
        ])
        .unwrap();

        assert_eq!(
            actual.datafusion_config,
            HashMap::from([
                (String::from("foo"), String::from("bar")),
                (String::from("x"), String::from("y:z")),
            ]),
        );
    }

    #[test]
    fn bad_datafusion_config() {
        // Missing `:` separator.
        let actual = QuerierConfig::try_parse_from(["my_binary", "--datafusion-config=foo"])
            .unwrap_err()
            .to_string();
        assert_contains!(
            actual,
            "error: invalid value 'foo' for '--datafusion-config <DATAFUSION_CONFIG>': Invalid key value pair - expected 'KEY:VALUE' got 'foo'"
        );

        // Duplicate key.
        let actual =
            QuerierConfig::try_parse_from(["my_binary", "--datafusion-config=foo:bar,baz:1,foo:2"])
                .unwrap_err()
                .to_string();
        assert_contains!(
            actual,
            "error: invalid value 'foo:bar,baz:1,foo:2' for '--datafusion-config <DATAFUSION_CONFIG>': key 'foo' passed multiple times"
        );
    }
}

149
clap_blocks/src/router.rs Normal file
View File

@ -0,0 +1,149 @@
//! CLI config for the router using the RPC write path
use crate::{
gossip::GossipConfig,
ingester_address::IngesterAddress,
single_tenant::{
CONFIG_AUTHZ_ENV_NAME, CONFIG_AUTHZ_FLAG, CONFIG_CST_ENV_NAME, CONFIG_CST_FLAG,
},
};
use std::{
num::{NonZeroUsize, ParseIntError},
time::Duration,
};
/// CLI config for the router using the RPC write path
#[derive(Debug, Clone, clap::Parser)]
#[allow(missing_copy_implementations)]
pub struct RouterConfig {
    /// Gossip config.
    #[clap(flatten)]
    pub gossip_config: GossipConfig,

    /// Addr for connection to authz
    // Only valid together with `--single-tenancy` (see `requires`), and
    // `requires_if` below makes it mandatory when single-tenancy is enabled.
    #[clap(
        long = CONFIG_AUTHZ_FLAG,
        env = CONFIG_AUTHZ_ENV_NAME,
        requires("single_tenant_deployment"),
    )]
    pub authz_address: Option<String>,

    /// Differential handling based upon deployment to CST vs MT.
    ///
    /// At minimum, differs in supports of v1 endpoint. But also includes
    /// differences in namespace handling, etc.
    #[clap(
        long = CONFIG_CST_FLAG,
        env = CONFIG_CST_ENV_NAME,
        default_value = "false",
        requires_if("true", "authz_address")
    )]
    pub single_tenant_deployment: bool,

    /// The maximum number of simultaneous requests the HTTP server is
    /// configured to accept.
    ///
    /// This number of requests, multiplied by the maximum request body size the
    /// HTTP server is configured with gives the rough amount of memory a HTTP
    /// server will use to buffer request bodies in memory.
    ///
    /// A default maximum of 200 requests, multiplied by the default 10MiB
    /// maximum for HTTP request bodies == ~2GiB.
    #[clap(
        long = "max-http-requests",
        env = "INFLUXDB_IOX_MAX_HTTP_REQUESTS",
        default_value = "200",
        action
    )]
    pub http_request_limit: usize,

    /// gRPC address for the router to talk with the ingesters. For
    /// example:
    ///
    /// "http://127.0.0.1:8083"
    ///
    /// or
    ///
    /// "http://10.10.10.1:8083,http://10.10.10.2:8083"
    ///
    /// for multiple addresses.
    // Unlike the querier, the router requires at least one ingester address.
    #[clap(
        long = "ingester-addresses",
        env = "INFLUXDB_IOX_INGESTER_ADDRESSES",
        required = true,
        num_args=1..,
        value_delimiter = ','
    )]
    pub ingester_addresses: Vec<IngesterAddress>,

    /// Retention period to use when auto-creating namespaces.
    /// For infinite retention, leave this unset and it will default to `None`.
    /// Setting it to zero will not make it infinite.
    /// Ignored if namespace-autocreation-enabled is set to false.
    #[clap(
        long = "new-namespace-retention-hours",
        env = "INFLUXDB_IOX_NEW_NAMESPACE_RETENTION_HOURS",
        action
    )]
    pub new_namespace_retention_hours: Option<u64>,

    /// When writing data to a non-existent namespace, should the router auto-create the namespace
    /// or reject the write? Set to false to disable namespace autocreation.
    #[clap(
        long = "namespace-autocreation-enabled",
        env = "INFLUXDB_IOX_NAMESPACE_AUTOCREATION_ENABLED",
        default_value = "true",
        action
    )]
    pub namespace_autocreation_enabled: bool,

    /// Specify the timeout in seconds for a single RPC write request to an
    /// ingester.
    // Parsed from whole seconds via `parse_duration`.
    #[clap(
        long = "rpc-write-timeout-seconds",
        env = "INFLUXDB_IOX_RPC_WRITE_TIMEOUT_SECONDS",
        default_value = "3",
        value_parser = parse_duration
    )]
    pub rpc_write_timeout_seconds: Duration,

    /// Specify the maximum allowed outgoing RPC write message size when
    /// communicating with the Ingester.
    #[clap(
        long = "rpc-write-max-outgoing-bytes",
        env = "INFLUXDB_IOX_RPC_WRITE_MAX_OUTGOING_BYTES",
        default_value = "104857600", // 100MiB
    )]
    pub rpc_write_max_outgoing_bytes: usize,

    /// Enable optional replication for each RPC write.
    ///
    /// This value specifies the total number of copies of data after
    /// replication, defaulting to 1.
    ///
    /// If the desired replication level is not achieved, a partial write error
    /// will be returned to the user. The write MAY be queryable after a partial
    /// write failure.
    #[clap(
        long = "rpc-write-replicas",
        env = "INFLUXDB_IOX_RPC_WRITE_REPLICAS",
        default_value = "1"
    )]
    pub rpc_write_replicas: NonZeroUsize,

    /// Specify the maximum number of probe requests to be sent per second.
    ///
    /// At least 20% of these requests must succeed within a second for the
    /// endpoint to be considered healthy.
    #[clap(
        long = "rpc-write-health-num-probes",
        env = "INFLUXDB_IOX_RPC_WRITE_HEALTH_NUM_PROBES",
        default_value = "10"
    )]
    pub rpc_write_health_num_probes: u64,
}
/// Map a string containing an integer number of seconds into a [`Duration`].
///
/// Fractions are not supported; the whole string must parse as a `u64`.
fn parse_duration(input: &str) -> Result<Duration, ParseIntError> {
    let seconds: u64 = input.parse()?;
    Ok(Duration::from_secs(seconds))
}

View File

@ -0,0 +1,107 @@
//! Common config for all `run` commands.
use trace_exporters::TracingConfig;
use trogging::cli::LoggingConfig;
use crate::{object_store::ObjectStoreConfig, socket_addr::SocketAddr};
/// The default bind address for the HTTP API.
///
/// Loopback-only, so a server started with defaults is not reachable from
/// other hosts.
pub const DEFAULT_API_BIND_ADDR: &str = "127.0.0.1:8080";

/// The default bind address for the gRPC.
///
/// Loopback-only, like [`DEFAULT_API_BIND_ADDR`], on the gRPC port.
pub const DEFAULT_GRPC_BIND_ADDR: &str = "127.0.0.1:8082";
/// Common config for all `run` commands.
#[derive(Debug, Clone, clap::Parser)]
pub struct RunConfig {
    /// logging options
    // Crate-private; exposed read-only via `logging_config()`.
    #[clap(flatten)]
    pub(crate) logging_config: LoggingConfig,

    /// tracing options
    // Crate-private; exposed via `tracing_config()` / `tracing_config_mut()`.
    #[clap(flatten)]
    pub(crate) tracing_config: TracingConfig,

    /// The address on which IOx will serve HTTP API requests.
    #[clap(
        long = "api-bind",
        env = "INFLUXDB_IOX_BIND_ADDR",
        default_value = DEFAULT_API_BIND_ADDR,
        action,
    )]
    pub http_bind_address: SocketAddr,

    /// The address on which IOx will serve Storage gRPC API requests.
    #[clap(
        long = "grpc-bind",
        env = "INFLUXDB_IOX_GRPC_BIND_ADDR",
        default_value = DEFAULT_GRPC_BIND_ADDR,
        action,
    )]
    pub grpc_bind_address: SocketAddr,

    /// Maximum size of HTTP requests.
    #[clap(
        long = "max-http-request-size",
        env = "INFLUXDB_IOX_MAX_HTTP_REQUEST_SIZE",
        default_value = "10485760", // 10 MiB
        action,
    )]
    pub max_http_request_size: usize,

    /// object store config
    // Crate-private; exposed read-only via `object_store_config()`.
    #[clap(flatten)]
    pub(crate) object_store_config: ObjectStoreConfig,
}
impl RunConfig {
    /// Create a new instance for all-in-one mode, only allowing some arguments.
    pub fn new(
        logging_config: LoggingConfig,
        tracing_config: TracingConfig,
        http_bind_address: SocketAddr,
        grpc_bind_address: SocketAddr,
        max_http_request_size: usize,
        object_store_config: ObjectStoreConfig,
    ) -> Self {
        Self {
            logging_config,
            tracing_config,
            http_bind_address,
            grpc_bind_address,
            max_http_request_size,
            object_store_config,
        }
    }

    /// Borrow the logging configuration.
    pub fn logging_config(&self) -> &LoggingConfig {
        &self.logging_config
    }

    /// Borrow the tracing configuration.
    pub fn tracing_config(&self) -> &TracingConfig {
        &self.tracing_config
    }

    /// Mutably borrow the tracing configuration.
    pub fn tracing_config_mut(&mut self) -> &mut TracingConfig {
        &mut self.tracing_config
    }

    /// Borrow the object store configuration.
    pub fn object_store_config(&self) -> &ObjectStoreConfig {
        &self.object_store_config
    }

    /// Builder-style override of the HTTP bind address.
    pub fn with_http_bind_address(mut self, http_bind_address: SocketAddr) -> Self {
        self.http_bind_address = http_bind_address;
        self
    }

    /// Builder-style override of the gRPC bind address.
    pub fn with_grpc_bind_address(mut self, grpc_bind_address: SocketAddr) -> Self {
        self.grpc_bind_address = grpc_bind_address;
        self
    }
}

View File

@ -0,0 +1,11 @@
//! CLI config for request authorization.
// Shared by the router and querier configs so the flag/env names stay in sync.

/// Env var providing authz address
pub const CONFIG_AUTHZ_ENV_NAME: &str = "INFLUXDB_IOX_AUTHZ_ADDR";

/// CLI flag for authz address
pub const CONFIG_AUTHZ_FLAG: &str = "authz-addr";

/// Env var for single tenancy deployments
pub const CONFIG_CST_ENV_NAME: &str = "INFLUXDB_IOX_SINGLE_TENANCY";

/// CLI flag for single tenancy deployments
pub const CONFIG_CST_FLAG: &str = "single-tenancy";

View File

@ -0,0 +1,77 @@
//! Config for socket addresses.
use std::{net::ToSocketAddrs, ops::Deref};
/// Parsable socket address.
///
/// Newtype around [`std::net::SocketAddr`] whose `FromStr` impl also resolves
/// host names (e.g. `localhost:8080`) via [`ToSocketAddrs`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SocketAddr(std::net::SocketAddr);
// Transparent read access to the wrapped `std::net::SocketAddr`.
impl Deref for SocketAddr {
    type Target = std::net::SocketAddr;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
// Render exactly like the wrapped `std::net::SocketAddr`.
impl std::fmt::Display for SocketAddr {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
impl std::str::FromStr for SocketAddr {
    type Err = String;

    /// Parse (and, for host names, resolve) a socket address, keeping the
    /// first address returned by the resolver.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mut addrs = s
            .to_socket_addrs()
            .map_err(|e| format!("Cannot parse socket address '{s}': {e}"))?;

        addrs
            .next()
            .map(Self)
            .ok_or_else(|| format!("Found no addresses for '{s}'"))
    }
}
// Unwrap the newtype into the underlying `std::net` address.
impl From<SocketAddr> for std::net::SocketAddr {
    fn from(addr: SocketAddr) -> Self {
        addr.0
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6},
        str::FromStr,
    };

    #[test]
    fn test_socketaddr() {
        // Literal IP:port parses directly.
        let addr: std::net::SocketAddr = SocketAddr::from_str("127.0.0.1:1234").unwrap().into();
        assert_eq!(addr, std::net::SocketAddr::from(([127, 0, 0, 1], 1234)),);

        // Host names go through the resolver.
        let addr: std::net::SocketAddr = SocketAddr::from_str("localhost:1234").unwrap().into();

        // depending on where the test runs, localhost will either resolve to a ipv4 or
        // an ipv6 addr.
        match addr {
            std::net::SocketAddr::V4(so) => {
                assert_eq!(so, SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1234))
            }
            std::net::SocketAddr::V6(so) => assert_eq!(
                so,
                SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 1234, 0, 0)
            ),
        };

        // Unparsable input surfaces the underlying resolver error message.
        assert_eq!(
            SocketAddr::from_str("!@INv_a1d(ad0/resp_!").unwrap_err(),
            "Cannot parse socket address '!@INv_a1d(ad0/resp_!': invalid socket address",
        );
    }
}

19
client_util/Cargo.toml Normal file
View File

@ -0,0 +1,19 @@
[package]
name = "client_util"
description = "Shared code for IOx clients"
# Version/authors/edition/license are inherited from the workspace root.
version.workspace = true
authors.workspace = true
edition.workspace = true
license.workspace = true

[dependencies]
http = "0.2.9"
# Default features disabled; TLS provided via rustls.
reqwest = { version = "0.11", default-features = false, features = ["stream", "rustls-tls"] }
thiserror = "1.0.48"
tonic = { workspace = true }
tower = "0.4"
workspace-hack = { version = "0.1", path = "../workspace-hack" }

[dev-dependencies]
tokio = { version = "1.32", features = ["macros", "parking_lot", "rt-multi-thread"] }
mockito = { version = "1.2", default-features = false }

View File

@ -0,0 +1,295 @@
use crate::tower::{SetRequestHeadersLayer, SetRequestHeadersService};
use http::header::HeaderName;
use http::HeaderMap;
use http::{uri::InvalidUri, HeaderValue, Uri};
use std::convert::TryInto;
use std::time::Duration;
use thiserror::Error;
use tonic::transport::{Channel, Endpoint};
use tower::make::MakeConnection;
/// The connection type used for clients. Use [`Builder`] to create
/// instances of [`Connection`] objects
#[derive(Debug, Clone)]
pub struct Connection {
    /// gRPC half; extracted with [`Connection::into_grpc_connection`].
    grpc_connection: GrpcConnection,
    /// HTTP half; extracted with [`Connection::into_http_connection`].
    http_connection: HttpConnection,
}
impl Connection {
    /// Bundle the gRPC and HTTP halves into a single `Connection`.
    fn new(grpc_connection: GrpcConnection, http_connection: HttpConnection) -> Self {
        Self {
            grpc_connection,
            http_connection,
        }
    }

    /// Consume `self`, yielding the [`HttpConnection`] half (suitable for
    /// making calls to /api/v2 endpoints).
    pub fn into_http_connection(self) -> HttpConnection {
        self.http_connection
    }

    /// Consume `self`, yielding the [`GrpcConnection`] half (suitable for use
    /// in tonic clients).
    pub fn into_grpc_connection(self) -> GrpcConnection {
        self.grpc_connection
    }
}
/// The type used to make tonic (gRPC) requests.
///
/// A tonic transport [`Channel`] wrapped in [`SetRequestHeadersService`]
/// (see `crate::tower`).
pub type GrpcConnection = SetRequestHeadersService<tonic::transport::Channel>;
/// The type used to make raw http requests.
///
/// Obtained via [`Connection::into_http_connection`].
#[derive(Debug, Clone)]
pub struct HttpConnection {
    /// The base uri of the IOx http API endpoint
    uri: Uri,
    /// http client connection
    http_client: reqwest::Client,
}
impl HttpConnection {
    /// Bundle the base uri with the reqwest client that talks to it.
    fn new(uri: Uri, http_client: reqwest::Client) -> Self {
        Self { uri, http_client }
    }

    /// Return a reference to the underlying http client
    pub fn client(&self) -> &reqwest::Client {
        &self.http_client
    }

    /// Return a reference to the base uri of the IOx http API endpoint
    pub fn uri(&self) -> &Uri {
        &self.uri
    }
}
/// The default User-Agent header sent by the HTTP client.
pub const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
/// The default connection timeout
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(1);
/// The default request timeout
pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
/// Errors returned by the ConnectionBuilder
#[derive(Debug, Error)]
pub enum Error {
    /// Failed to establish the underlying gRPC transport connection.
    ///
    /// `details` carries the stringified source of the tonic error, which
    /// tonic's own `Display` impl omits (see the `From` impl below).
    #[error("Connection error: {}{}", source, details)]
    TransportError {
        /// underlying [`tonic::transport::Error`]
        source: tonic::transport::Error,
        /// stringified version of the tonic error's source
        details: String,
    },
    /// The provided destination could not be parsed as a [`Uri`]
    #[error("Invalid URI: {}", .0)]
    InvalidUri(#[from] InvalidUri),
}
// Custom conversion so the rendered error includes the underlying cause,
// which tonic's transport error does not expose through its own Display.
impl From<tonic::transport::Error> for Error {
    fn from(source: tonic::transport::Error) -> Self {
        use std::error::Error;
        // Render the underlying cause (if any) as " (cause)"; empty otherwise.
        let details = match source.source() {
            Some(cause) => format!(" ({cause})"),
            None => String::new(),
        };
        Self::TransportError { source, details }
    }
}
/// Result type for the ConnectionBuilder
pub type Result<T, E = Error> = std::result::Result<T, E>;
/// A builder that produces a connection that can be used with any of the gRPC
/// clients
///
/// ```no_run
/// #[tokio::main]
/// # async fn main() {
/// use client_util::connection::Builder;
/// use std::time::Duration;
///
/// let connection = Builder::new()
/// .timeout(Duration::from_secs(42))
/// .user_agent("my_awesome_client")
/// .build("http://127.0.0.1:8082/")
/// .await
/// .expect("connection must succeed");
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct Builder {
user_agent: String,
headers: Vec<(HeaderName, HeaderValue)>,
connect_timeout: Duration,
timeout: Duration,
}
impl std::default::Default for Builder {
fn default() -> Self {
Self {
user_agent: USER_AGENT.into(),
connect_timeout: DEFAULT_CONNECT_TIMEOUT,
timeout: DEFAULT_TIMEOUT,
headers: Default::default(),
}
}
}
impl Builder {
    /// Create a new default builder
    pub fn new() -> Self {
        Default::default()
    }
    /// Construct the [`Connection`] instance using the specified base URL.
    ///
    /// Eagerly establishes the gRPC channel; fails with [`Error`] if the
    /// endpoint cannot be reached within the configured connect timeout.
    pub async fn build<D>(self, dst: D) -> Result<Connection>
    where
        D: TryInto<Uri, Error = InvalidUri> + Send,
    {
        let endpoint = self.create_endpoint(dst)?;
        let channel = endpoint.connect().await?;
        Ok(self.compose_middleware(channel, endpoint))
    }
    /// Construct the [`Connection`] instance using the specified base URL and custom connector.
    ///
    /// Like [`build`](Self::build), but transports gRPC traffic over the
    /// provided connector (e.g. for unix sockets or in-process transports).
    pub async fn build_with_connector<D, C>(self, dst: D, connector: C) -> Result<Connection>
    where
        D: TryInto<Uri, Error = InvalidUri> + Send,
        C: MakeConnection<Uri> + Send + 'static,
        C::Connection: Unpin + Send + 'static,
        C::Future: Send + 'static,
        Box<dyn std::error::Error + Send + Sync>: From<C::Error> + Send + 'static,
    {
        let endpoint = self.create_endpoint(dst)?;
        let channel = endpoint.connect_with_connector(connector).await?;
        Ok(self.compose_middleware(channel, endpoint))
    }
    // Build the tonic endpoint carrying the user agent and both timeouts.
    fn create_endpoint<D>(&self, dst: D) -> Result<Endpoint>
    where
        D: TryInto<Uri, Error = InvalidUri> + Send,
    {
        let endpoint = Endpoint::from(dst.try_into()?)
            .user_agent(&self.user_agent)?
            .connect_timeout(self.connect_timeout)
            .timeout(self.timeout);
        Ok(endpoint)
    }
    // Wrap the raw channel with the header-setting middleware and construct
    // the companion reqwest client sharing the same default headers.
    //
    // NOTE(review): `self.user_agent` is applied only to the gRPC endpoint
    // (in `create_endpoint`), not to the reqwest client built here — confirm
    // whether the HTTP side should also send it.
    fn compose_middleware(self, channel: Channel, endpoint: Endpoint) -> Connection {
        let headers_map: HeaderMap = self.headers.iter().cloned().collect();
        // Compose channel with new tower middleware stack
        let grpc_connection = tower::ServiceBuilder::new()
            .layer(SetRequestHeadersLayer::new(self.headers))
            .service(channel);
        let http_client = reqwest::Client::builder()
            .connection_verbose(true)
            .default_headers(headers_map)
            .build()
            .expect("reqwest::Client should have built");
        let http_connection = HttpConnection::new(endpoint.uri().clone(), http_client);
        Connection::new(grpc_connection, http_connection)
    }
    /// Set the `User-Agent` header sent by this client.
    pub fn user_agent(self, user_agent: impl Into<String>) -> Self {
        Self {
            user_agent: user_agent.into(),
            ..self
        }
    }
    /// Sets a header to be included on all requests
    pub fn header(self, header: impl Into<HeaderName>, value: impl Into<HeaderValue>) -> Self {
        let mut headers = self.headers;
        headers.push((header.into(), value.into()));
        Self { headers, ..self }
    }
    /// Sets the maximum duration of time the client will wait for the IOx
    /// server to accept the TCP connection before aborting the request.
    ///
    /// Note this does not bound the request duration - see
    /// [`timeout`][Self::timeout].
    pub fn connect_timeout(self, timeout: Duration) -> Self {
        Self {
            connect_timeout: timeout,
            ..self
        }
    }
    /// Bounds the total amount of time a single client HTTP request may take
    /// before being aborted.
    ///
    /// This timeout includes:
    ///
    /// - Establishing the TCP connection (see [`connect_timeout`])
    /// - Sending the HTTP request
    /// - Waiting for, and receiving the entire HTTP response
    ///
    /// [`connect_timeout`]: Self::connect_timeout
    pub fn timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use reqwest::Method;
#[test]
fn test_builder_cloneable() {
// Clone is used by Conductor.
fn assert_clone<T: Clone>(_t: T) {}
assert_clone(Builder::default())
}
#[tokio::test(flavor = "multi_thread")]
async fn headers_are_set() {
let mut mock_server = mockito::Server::new_async().await;
let url = mock_server.url();
let http_connection = Builder::new()
.header(
HeaderName::from_static("foo"),
HeaderValue::from_static("bar"),
)
.build(&url)
.await
.unwrap()
.into_http_connection();
let url = format!("{url}/the_api");
println!("Sending to {url}");
let m = mock_server
.mock("POST", "/the_api")
.with_status(201)
.with_body("world")
.match_header("FOO", "bar")
.create_async()
.await;
http_connection
.client()
.request(Method::POST, &url)
.send()
.await
.expect("Error making http request");
m.assert_async().await;
}
}

32
client_util/src/lib.rs Normal file
View File

@ -0,0 +1,32 @@
//! Shared InfluxDB IOx API client functionality
#![deny(
    rustdoc::broken_intra_doc_links,
    rustdoc::bare_urls,
    rust_2018_idioms,
    missing_debug_implementations,
    unreachable_pub
)]
// `clippy::todo` and `clippy::dbg_macro` were previously listed twice; the
// duplicates have been removed.
#![warn(
    missing_docs,
    clippy::todo,
    clippy::dbg_macro,
    clippy::clone_on_ref_ptr,
    // See https://github.com/influxdata/influxdb_iox/pull/1671
    clippy::future_not_send,
    unused_crate_dependencies
)]
#![allow(clippy::missing_docs_in_private_items)]
// Workaround for "unused crate" lint false positives.
use workspace_hack as _;
/// Builder for constructing connections for use with the various gRPC clients
pub mod connection;
/// Helper to set client headers.
pub mod tower;
/// Namespace <--> org/bucket utilities
pub mod namespace_translation;

View File

@ -0,0 +1,90 @@
//! Contains logic to map namespace back/forth to org/bucket
use thiserror::Error;
/// Errors returned by namespace parsing
#[allow(missing_docs)]
#[derive(Debug, Error)]
pub enum Error {
#[error("Invalid namespace '{namespace}': {reason}")]
InvalidNamespace { namespace: String, reason: String },
}
impl Error {
fn new(namespace: impl Into<String>, reason: impl Into<String>) -> Self {
Self::InvalidNamespace {
namespace: namespace.into(),
reason: reason.into(),
}
}
}
/// Splits up the namespace name into org_id and bucket_id
pub fn split_namespace(namespace: &str) -> Result<(&str, &str), Error> {
let mut iter = namespace.split('_');
let org_id = iter.next().ok_or_else(|| Error::new(namespace, "empty"))?;
if org_id.is_empty() {
return Err(Error::new(namespace, "No org_id found"));
}
let bucket_id = iter
.next()
.ok_or_else(|| Error::new(namespace, "Could not find '_'"))?;
if bucket_id.is_empty() {
return Err(Error::new(namespace, "No bucket_id found"));
}
if iter.next().is_some() {
return Err(Error::new(namespace, "More than one '_'"));
}
Ok((org_id, bucket_id))
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn split_good() {
assert_eq!(split_namespace("foo_bar").unwrap(), ("foo", "bar"));
}
#[test]
#[should_panic(expected = "No org_id found")]
fn split_bad_empty() {
split_namespace("").unwrap();
}
#[test]
#[should_panic(expected = "No org_id found")]
fn split_bad_only_underscore() {
split_namespace("_").unwrap();
}
#[test]
#[should_panic(expected = "No org_id found")]
fn split_bad_empty_org_id() {
split_namespace("_ff").unwrap();
}
#[test]
#[should_panic(expected = "No bucket_id found")]
fn split_bad_empty_bucket_id() {
split_namespace("ff_").unwrap();
}
#[test]
#[should_panic(expected = "More than one '_'")]
fn split_too_many() {
split_namespace("ff_bf_").unwrap();
}
#[test]
#[should_panic(expected = "More than one '_'")]
fn split_way_too_many() {
split_namespace("ff_bf_dfd_3_f").unwrap();
}
}

79
client_util/src/tower.rs Normal file
View File

@ -0,0 +1,79 @@
use http::header::HeaderName;
use http::{HeaderValue, Request, Response};
use std::sync::Arc;
use std::task::{Context, Poll};
use tower::{Layer, Service};
/// `SetRequestHeadersLayer` sets the provided headers on all requests flowing
/// through it, replacing any value already present for the same header name
/// (the wrapped service applies them with `HeaderMap::insert`).
#[derive(Debug, Clone)]
pub(crate) struct SetRequestHeadersLayer {
    // Shared so every layered service clone reuses the same header list.
    headers: Arc<Vec<(HeaderName, HeaderValue)>>,
}
impl SetRequestHeadersLayer {
    pub(crate) fn new(headers: Vec<(HeaderName, HeaderValue)>) -> Self {
        Self {
            headers: Arc::new(headers),
        }
    }
}
impl<S> Layer<S> for SetRequestHeadersLayer {
type Service = SetRequestHeadersService<S>;
fn layer(&self, service: S) -> Self::Service {
SetRequestHeadersService {
service,
headers: Arc::clone(&self.headers),
}
}
}
/// SetRequestHeadersService wraps an inner tower::Service and sets the provided
/// headers on requests flowing through it
#[derive(Debug, Clone)]
pub struct SetRequestHeadersService<S> {
service: S,
headers: Arc<Vec<(HeaderName, HeaderValue)>>,
}
impl<S> SetRequestHeadersService<S> {
    /// Create service from inner service and headers.
    pub fn new(service: S, headers: Vec<(HeaderName, HeaderValue)>) -> Self {
        Self {
            service,
            headers: Arc::new(headers),
        }
    }
    /// De-construct service into parts.
    ///
    /// These can be used to call [`new`](Self::new) again (note the headers
    /// must first be extracted from the [`Arc`], e.g. via `Arc::try_unwrap`).
    pub fn into_parts(self) -> (S, Arc<Vec<(HeaderName, HeaderValue)>>) {
        let SetRequestHeadersService { service, headers } = self;
        (service, headers)
    }
}
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for SetRequestHeadersService<S>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
{
    type Response = Response<ResBody>;
    type Error = S::Error;
    type Future = S::Future;
    // Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(cx)
    }
    fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
        let headers = request.headers_mut();
        for (name, value) in self.headers.iter() {
            // `insert` replaces any value the request already carried for
            // this header name; configured headers always win.
            headers.insert(name, value.clone());
        }
        self.service.call(request)
    }
}

44
compactor/Cargo.toml Normal file
View File

@ -0,0 +1,44 @@
[package]
name = "compactor"
version.workspace = true
authors.workspace = true
edition.workspace = true
license.workspace = true
[dependencies]
async-trait = "0.1.73"
backoff = { path = "../backoff" }
bytes = "1.5"
chrono = { version = "0.4", default-features = false }
compactor_scheduler = { path = "../compactor_scheduler" }
data_types = { path = "../data_types" }
datafusion = { workspace = true }
futures = "0.3"
generated_types = { version = "0.1.0", path = "../generated_types" }
gossip = { version = "0.1.0", path = "../gossip" }
gossip_compaction = { version = "0.1.0", path = "../gossip_compaction" }
iox_catalog = { path = "../iox_catalog" }
iox_query = { path = "../iox_query" }
iox_time = { path = "../iox_time" }
itertools = "0.11.0"
metric = { path = "../metric" }
object_store = { workspace = true }
observability_deps = { path = "../observability_deps" }
parking_lot = "0.12.1"
parquet_file = { path = "../parquet_file" }
rand = "0.8.3"
schema = { path = "../schema" }
tokio = { version = "1", features = ["macros", "rt", "sync"] }
tokio-util = { version = "0.7.9" }
trace = { version = "0.1.0", path = "../trace" }
tracker = { path = "../tracker" }
uuid = { version = "1", features = ["v4"] }
workspace-hack = { version = "0.1", path = "../workspace-hack" }
[dev-dependencies]
arrow_util = { path = "../arrow_util" }
assert_matches = "1"
compactor_test_utils = { path = "../compactor_test_utils" }
iox_tests = { path = "../iox_tests" }
test_helpers = { path = "../test_helpers" }
insta = { version = "1.32.0", features = ["yaml"] }

45
compactor/img/driver.svg Normal file

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 210 KiB

121
compactor/src/compactor.rs Normal file
View File

@ -0,0 +1,121 @@
//! Main compactor entry point.
use std::sync::Arc;
use futures::{
future::{BoxFuture, Shared},
FutureExt, TryFutureExt,
};
use generated_types::influxdata::iox::gossip::{v1::CompactionEvent, Topic};
use gossip::{NopDispatcher, TopicInterests};
use observability_deps::tracing::{info, warn};
use tokio::task::{JoinError, JoinHandle};
use tokio_util::sync::CancellationToken;
use tracker::AsyncSemaphoreMetrics;
use crate::{
components::{
hardcoded::hardcoded_components,
report::{log_components, log_config},
},
config::Config,
driver::compact,
};
/// A [`JoinHandle`] that can be cloned
type SharedJoinHandle = Shared<BoxFuture<'static, Result<(), Arc<JoinError>>>>;
/// Convert a [`JoinHandle`] into a [`SharedJoinHandle`].
fn shared_handle(handle: JoinHandle<()>) -> SharedJoinHandle {
handle.map_err(Arc::new).boxed().shared()
}
/// Main compactor driver.
#[derive(Debug)]
pub struct Compactor {
shutdown: CancellationToken,
worker: SharedJoinHandle,
}
impl Compactor {
    /// Start compactor.
    ///
    /// Spawns the main compaction loop on a background tokio task; stop it
    /// with [`shutdown`](Self::shutdown) followed by [`join`](Self::join).
    pub async fn start(config: Config) -> Self {
        info!("compactor starting");
        log_config(&config);
        let shutdown = CancellationToken::new();
        let shutdown_captured = shutdown.clone();
        let components = hardcoded_components(&config);
        log_components(&components);
        // Bound concurrent DataFusion jobs with a metrics-instrumented semaphore.
        let semaphore_metrics = Arc::new(AsyncSemaphoreMetrics::new(
            &config.metric_registry,
            &[("semaphore", "job")],
        ));
        let df_semaphore = Arc::new(semaphore_metrics.new_semaphore(config.df_concurrency.get()));
        // Initialise the gossip subsystem, if configured.
        let gossip = match config.gossip_bind_address {
            Some(bind) => {
                // Initialise the gossip subsystem.
                let handle = gossip::Builder::<_, Topic>::new(
                    config.gossip_seeds,
                    NopDispatcher,
                    Arc::clone(&config.metric_registry),
                )
                // Configure the compactor to subscribe to no topics - it
                // currently only sends events.
                .with_topic_filter(TopicInterests::default())
                .bind(bind)
                .await
                .expect("failed to start gossip reactor");
                let event_tx =
                    gossip_compaction::tx::CompactionEventTx::<CompactionEvent>::new(handle);
                Some(Arc::new(event_tx))
            }
            None => None,
        };
        // Run the compaction loop until it completes or shutdown is requested,
        // whichever comes first.
        let worker = tokio::spawn(async move {
            tokio::select! {
                _ = shutdown_captured.cancelled() => {}
                _ = async {
                    compact(
                        config.trace_collector,
                        config.partition_concurrency,
                        config.partition_timeout,
                        Arc::clone(&df_semaphore),
                        &components,
                        gossip,
                    ).await;
                    info!("compactor done");
                } => {}
            }
        });
        let worker = shared_handle(worker);
        Self { shutdown, worker }
    }
    /// Trigger shutdown. You should [join](Self::join) afterwards.
    pub fn shutdown(&self) {
        info!("compactor shutting down");
        self.shutdown.cancel();
    }
    /// Wait until the compactor finishes.
    ///
    /// The handle is shared, so this may be awaited from multiple places.
    pub async fn join(&self) -> Result<(), Arc<JoinError>> {
        self.worker.clone().await
    }
}
impl Drop for Compactor {
    fn drop(&mut self) {
        // `now_or_never` yields `None` while the worker future is still
        // pending, i.e. nobody ran `shutdown` + `join` before dropping.
        if self.worker.clone().now_or_never().is_none() {
            warn!("Compactor was not shut down properly");
        }
    }
}

View File

@ -0,0 +1,33 @@
use std::fmt::Display;
use super::{ChangedFilesFilter, SavedParquetFileState};
use async_trait::async_trait;
use observability_deps::tracing::info;
/// Stateless filter that merely logs concurrent file modifications.
#[derive(Debug, Default, Copy, Clone)]
pub struct LoggingChangedFiles {}

impl LoggingChangedFiles {
    /// Create a new logging-only filter.
    pub fn new() -> Self {
        Self::default()
    }
}

impl Display for LoggingChangedFiles {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("logging_changed_files")
    }
}
#[async_trait]
impl ChangedFilesFilter for LoggingChangedFiles {
    /// Log (at info level) any `(id, level)` pairs that were modified by a
    /// concurrent process, but always return `false` so compaction proceeds.
    async fn apply(&self, old: &SavedParquetFileState, new: &SavedParquetFileState) -> bool {
        if old.existing_files_modified(new) {
            let modified_ids_and_levels = old.modified_ids_and_levels(new);
            info!(?modified_ids_and_levels, "Concurrent modification detected");
        }
        false // we're ignoring the return value anyway for the moment
    }
}

View File

@ -0,0 +1,213 @@
use std::{
collections::HashSet,
fmt::{Debug, Display},
};
use async_trait::async_trait;
use data_types::{CompactionLevel, ParquetFile, ParquetFileId};
pub mod logging;
/// Returns `true` if the files in the saved state have been changed according to the current state.
#[async_trait]
pub trait ChangedFilesFilter: Debug + Display + Send + Sync {
/// Return `true` if some other process modified the files in `old` such that they don't appear or appear with a
/// different compaction level than `new`, and thus we should stop compacting.
async fn apply(&self, old: &SavedParquetFileState, new: &SavedParquetFileState) -> bool;
}
/// Saved snapshot of a partition's Parquet files' IDs and compaction levels. Save this state at the beginning of a
/// compaction operation, then just before committing ask for the catalog state again. If the ID+compaction level pairs
/// in the initial saved state still appear in the latest catalog state (disregarding any new files that may appear in
/// the latest catalog state) we assume no other compactor instance has compacted the relevant files and this compactor
/// instance should commit its work. If any old ID+compaction level pairs are missing from the latest catalog state
/// (and thus show up in a set difference operation of `old - current`), throw away the work and do not commit as the
/// relevant Parquet files have been changed by some other process while this compactor instance was working.
#[derive(Debug, Clone)]
pub struct SavedParquetFileState {
ids_and_levels: HashSet<(ParquetFileId, CompactionLevel)>,
}
impl<'a, T> From<T> for SavedParquetFileState
where
T: IntoIterator<Item = &'a ParquetFile>,
{
fn from(parquet_files: T) -> Self {
let ids_and_levels = parquet_files
.into_iter()
.map(|pf| (pf.id, pf.compaction_level))
.collect();
Self { ids_and_levels }
}
}
impl SavedParquetFileState {
    /// Pairs present in `self` (the older snapshot) but absent from `new`.
    ///
    /// Pairs that appear only in `new` are deliberately ignored — new files do
    /// not count as modifications of the old state.
    fn missing<'a>(
        &'a self,
        new: &'a Self,
    ) -> impl Iterator<Item = &'a (ParquetFileId, CompactionLevel)> {
        let old = self;
        old.ids_and_levels.difference(&new.ids_and_levels)
    }
    /// Returns `true` if any `(id, level)` pair from this snapshot no longer
    /// appears in `new`.
    pub fn existing_files_modified(&self, new: &Self) -> bool {
        let mut missing = self.missing(new);
        // If there are any `(ParquetFileId, CompactionLevel)` pairs in `self` that are not present in `new`, that
        // means some files were marked to delete or had their compaction level changed by some other process.
        missing.next().is_some()
    }
    /// The concrete `(id, level)` pairs that disappeared between the two
    /// snapshots; used for diagnostics/logging.
    pub fn modified_ids_and_levels(&self, new: &Self) -> Vec<(ParquetFileId, CompactionLevel)> {
        self.missing(new).cloned().collect()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use iox_tests::ParquetFileBuilder;
#[test]
fn saved_state_sorts_by_parquet_file_id() {
let pf_id1_level_0 = ParquetFileBuilder::new(1)
.with_compaction_level(CompactionLevel::Initial)
.build();
let pf_id2_level_2 = ParquetFileBuilder::new(2)
.with_compaction_level(CompactionLevel::Final)
.build();
let pf_id3_level_1 = ParquetFileBuilder::new(3)
.with_compaction_level(CompactionLevel::FileNonOverlapped)
.build();
let saved_state_1 =
SavedParquetFileState::from([&pf_id1_level_0, &pf_id2_level_2, &pf_id3_level_1]);
let saved_state_2 =
SavedParquetFileState::from([&pf_id3_level_1, &pf_id1_level_0, &pf_id2_level_2]);
assert!(!saved_state_1.existing_files_modified(&saved_state_2));
assert!(saved_state_1
.modified_ids_and_levels(&saved_state_2)
.is_empty());
}
#[test]
fn both_empty_parquet_files() {
let saved_state_1 = SavedParquetFileState::from([]);
let saved_state_2 = SavedParquetFileState::from([]);
assert!(!saved_state_1.existing_files_modified(&saved_state_2));
assert!(saved_state_1
.modified_ids_and_levels(&saved_state_2)
.is_empty());
}
#[test]
fn missing_files_indicates_modifications() {
let pf_id1_level_0 = ParquetFileBuilder::new(1)
.with_compaction_level(CompactionLevel::Initial)
.build();
let saved_state_1 = SavedParquetFileState::from([&pf_id1_level_0]);
let saved_state_2 = SavedParquetFileState::from([]);
assert!(saved_state_1.existing_files_modified(&saved_state_2));
assert_eq!(
saved_state_1.modified_ids_and_levels(&saved_state_2),
&[(ParquetFileId::new(1), CompactionLevel::Initial)]
);
}
#[test]
fn disregard_new_files() {
let pf_id1_level_0 = ParquetFileBuilder::new(1)
.with_compaction_level(CompactionLevel::Initial)
.build();
// New files of any level don't affect whether the old saved state is considered modified
let pf_id2_level_2 = ParquetFileBuilder::new(2)
.with_compaction_level(CompactionLevel::Final)
.build();
let pf_id3_level_1 = ParquetFileBuilder::new(3)
.with_compaction_level(CompactionLevel::FileNonOverlapped)
.build();
let pf_id4_level_0 = ParquetFileBuilder::new(4)
.with_compaction_level(CompactionLevel::Initial)
.build();
let saved_state_1 = SavedParquetFileState::from([&pf_id1_level_0]);
let saved_state_2 = SavedParquetFileState::from([&pf_id1_level_0, &pf_id2_level_2]);
assert!(!saved_state_1.existing_files_modified(&saved_state_2));
assert!(saved_state_1
.modified_ids_and_levels(&saved_state_2)
.is_empty());
let saved_state_2 = SavedParquetFileState::from([&pf_id1_level_0, &pf_id3_level_1]);
assert!(!saved_state_1.existing_files_modified(&saved_state_2));
assert!(saved_state_1
.modified_ids_and_levels(&saved_state_2)
.is_empty());
let saved_state_2 = SavedParquetFileState::from([&pf_id1_level_0, &pf_id4_level_0]);
assert!(!saved_state_1.existing_files_modified(&saved_state_2));
assert!(saved_state_1
.modified_ids_and_levels(&saved_state_2)
.is_empty());
let saved_state_2 = SavedParquetFileState::from([
&pf_id1_level_0,
&pf_id2_level_2,
&pf_id4_level_0,
&pf_id4_level_0,
]);
assert!(!saved_state_1.existing_files_modified(&saved_state_2));
assert!(saved_state_1
.modified_ids_and_levels(&saved_state_2)
.is_empty());
}
#[test]
fn changed_compaction_level_indicates_modification() {
let pf_id1_level_0 = ParquetFileBuilder::new(1)
.with_compaction_level(CompactionLevel::Initial)
.build();
let pf_id1_level_1 = ParquetFileBuilder::new(1)
.with_compaction_level(CompactionLevel::FileNonOverlapped)
.build();
let pf_id2_level_2 = ParquetFileBuilder::new(2)
.with_compaction_level(CompactionLevel::Final)
.build();
let saved_state_1 = SavedParquetFileState::from([&pf_id1_level_0, &pf_id2_level_2]);
let saved_state_2 = SavedParquetFileState::from([&pf_id1_level_1, &pf_id2_level_2]);
assert!(saved_state_1.existing_files_modified(&saved_state_2));
assert_eq!(
saved_state_1.modified_ids_and_levels(&saved_state_2),
&[(ParquetFileId::new(1), CompactionLevel::Initial)]
);
}
#[test]
fn same_number_of_files_different_ids_indicates_modification() {
let pf_id1_level_0 = ParquetFileBuilder::new(1)
.with_compaction_level(CompactionLevel::Initial)
.build();
let pf_id2_level_0 = ParquetFileBuilder::new(2)
.with_compaction_level(CompactionLevel::Initial)
.build();
let pf_id3_level_2 = ParquetFileBuilder::new(3)
.with_compaction_level(CompactionLevel::Final)
.build();
let saved_state_1 = SavedParquetFileState::from([&pf_id1_level_0, &pf_id3_level_2]);
let saved_state_2 = SavedParquetFileState::from([&pf_id2_level_0, &pf_id3_level_2]);
assert!(saved_state_1.existing_files_modified(&saved_state_2));
assert_eq!(
saved_state_1.modified_ids_and_levels(&saved_state_2),
&[(ParquetFileId::new(1), CompactionLevel::Initial)]
);
}
}

View File

@ -0,0 +1,46 @@
use std::{fmt::Display, sync::Arc};
use async_trait::async_trait;
use backoff::{Backoff, BackoffConfig};
use data_types::{Column, TableId};
use iox_catalog::interface::Catalog;
use super::ColumnsSource;
#[derive(Debug)]
pub struct CatalogColumnsSource {
backoff_config: BackoffConfig,
catalog: Arc<dyn Catalog>,
}
impl CatalogColumnsSource {
pub fn new(backoff_config: BackoffConfig, catalog: Arc<dyn Catalog>) -> Self {
Self {
backoff_config,
catalog,
}
}
}
impl Display for CatalogColumnsSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "catalog")
}
}
#[async_trait]
impl ColumnsSource for CatalogColumnsSource {
    /// Fetch all columns of `table` from the catalog.
    ///
    /// Retries all errors indefinitely, so this only returns once the catalog
    /// call succeeds; callers rely on this method performing its own retries.
    async fn fetch(&self, table: TableId) -> Vec<Column> {
        Backoff::new(&self.backoff_config)
            // Operation label fixed: this lists *columns* by table id (the old
            // label "table_of_given_table_id" was copy-pasted from a table
            // source and made backoff logs misleading).
            .retry_all_errors("columns_of_given_table_id", || async {
                self.catalog
                    .repositories()
                    .await
                    .columns()
                    .list_by_table_id(table)
                    .await
            })
            .await
            .expect("retry forever")
    }
}

View File

@ -0,0 +1,71 @@
use std::{collections::HashMap, fmt::Display};
use async_trait::async_trait;
use data_types::{Column, TableId};
use super::ColumnsSource;
/// In-memory [`ColumnsSource`] for tests, answering from a fixed
/// table-id -> columns map.
#[derive(Debug)]
pub struct MockColumnsSource {
    tables: HashMap<TableId, Vec<Column>>,
}

impl MockColumnsSource {
    /// Create a mock source that serves the given map.
    #[allow(dead_code)] // not used anywhere
    pub fn new(tables: HashMap<TableId, Vec<Column>>) -> Self {
        Self { tables }
    }
}

impl Display for MockColumnsSource {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("mock")
    }
}

#[async_trait]
impl ColumnsSource for MockColumnsSource {
    async fn fetch(&self, table: TableId) -> Vec<Column> {
        // Unknown tables yield an empty list; fetching never drains the map.
        match self.tables.get(&table) {
            Some(columns) => columns.clone(),
            None => Vec::new(),
        }
    }
}
#[cfg(test)]
mod tests {
use data_types::ColumnType;
use iox_tests::{ColumnBuilder, TableBuilder};
use super::*;
#[test]
fn test_display() {
assert_eq!(
MockColumnsSource::new(HashMap::default()).to_string(),
"mock",
)
}
#[tokio::test]
async fn test_fetch() {
// // t_1 has one column and t_2 has no column
let t1 = TableBuilder::new(1).with_name("table1").build();
let t1_c1 = ColumnBuilder::new(1, t1.id.get())
.with_name("time")
.with_column_type(ColumnType::Time)
.build();
let t2 = TableBuilder::new(2).with_name("table2").build();
let tables = HashMap::from([(t1.id, vec![t1_c1.clone()]), (t2.id, vec![])]);
let source = MockColumnsSource::new(tables);
// different tables
assert_eq!(source.fetch(t1.id).await, vec![t1_c1.clone()],);
assert_eq!(source.fetch(t2.id).await, vec![]);
// fetching does not drain
assert_eq!(source.fetch(t1.id).await, vec![t1_c1],);
// unknown table => empty result
assert_eq!(source.fetch(TableId::new(3)).await, vec![]);
}
}

View File

@ -0,0 +1,15 @@
use std::fmt::{Debug, Display};
use async_trait::async_trait;
use data_types::{Column, TableId};
pub mod catalog;
pub mod mock;
#[async_trait]
pub trait ColumnsSource: Debug + Display + Send + Sync {
/// Get Columns for a given table
///
/// This method performs retries.
async fn fetch(&self, table: TableId) -> Vec<Column>;
}

View File

@ -0,0 +1,51 @@
use std::sync::Arc;
use compactor_scheduler::{
CommitUpdate, CompactionJob, CompactionJobStatus, CompactionJobStatusResponse,
CompactionJobStatusVariant, Scheduler,
};
use data_types::{CompactionLevel, ParquetFile, ParquetFileId, ParquetFileParams};
/// Adapter that reports a finished compaction job's file changes to the
/// [`Scheduler`] as a commit update.
#[derive(Debug)]
pub struct CommitToScheduler {
    scheduler: Arc<dyn Scheduler>,
}
impl CommitToScheduler {
    /// Create a new adapter wrapping the given scheduler.
    pub fn new(scheduler: Arc<dyn Scheduler>) -> Self {
        Self { scheduler }
    }
    /// Commit the outcome of `job`: files to `delete`, files to `upgrade` to
    /// `target_level`, and new files to `create`.
    ///
    /// Returns the IDs of the newly created parquet files.
    ///
    /// # Errors
    /// Propagates any scheduler error. Panics via `unreachable!` if the
    /// scheduler answers a commit update with a plain `Ack`, which would
    /// violate the scheduler contract.
    pub async fn commit(
        &self,
        job: CompactionJob,
        delete: &[ParquetFile],
        upgrade: &[ParquetFile],
        create: &[ParquetFileParams],
        target_level: CompactionLevel,
    ) -> Result<Vec<ParquetFileId>, crate::DynError> {
        match self
            .scheduler
            .update_job_status(CompactionJobStatus {
                job: job.clone(),
                status: CompactionJobStatusVariant::Update(CommitUpdate::new(
                    job.partition_id,
                    delete.into(),
                    upgrade.into(),
                    create.into(),
                    target_level,
                )),
            })
            .await?
        {
            CompactionJobStatusResponse::CreatedParquetFiles(ids) => Ok(ids),
            CompactionJobStatusResponse::Ack => unreachable!("scheduler should not ack"),
        }
    }
}
impl std::fmt::Display for CommitToScheduler {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "CommitToScheduler")
    }
}

View File

@ -0,0 +1,161 @@
use std::{collections::HashSet, fmt::Display, sync::Arc};
use async_trait::async_trait;
use compactor_scheduler::{
CompactionJob, CompactionJobStatus, CompactionJobStatusResponse, CompactionJobStatusVariant,
ErrorKind as SchedulerErrorKind, Scheduler,
};
use crate::error::{DynError, ErrorKind, ErrorKindExt};
use super::CompactionJobDoneSink;
#[derive(Debug)]
pub struct ErrorKindCompactionJobDoneSinkWrapper<T>
where
T: CompactionJobDoneSink,
{
kind: HashSet<ErrorKind>,
inner: T,
scheduler: Arc<dyn Scheduler>,
}
impl<T> ErrorKindCompactionJobDoneSinkWrapper<T>
where
T: CompactionJobDoneSink,
{
pub fn new(inner: T, kind: HashSet<ErrorKind>, scheduler: Arc<dyn Scheduler>) -> Self {
Self {
kind,
inner,
scheduler,
}
}
}
impl<T> Display for ErrorKindCompactionJobDoneSinkWrapper<T>
where
T: CompactionJobDoneSink,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut kinds = self.kind.iter().copied().collect::<Vec<_>>();
kinds.sort();
write!(f, "kind({:?}, {})", kinds, self.inner)
}
}
#[async_trait]
impl<T> CompactionJobDoneSink for ErrorKindCompactionJobDoneSinkWrapper<T>
where
T: CompactionJobDoneSink,
{
async fn record(&self, job: CompactionJob, res: Result<(), DynError>) -> Result<(), DynError> {
match res {
Ok(()) => self.inner.record(job, Ok(())).await,
Err(e) if self.kind.contains(&e.classify()) => {
let scheduler_error = match SchedulerErrorKind::from(e.classify()) {
SchedulerErrorKind::OutOfMemory => SchedulerErrorKind::OutOfMemory,
SchedulerErrorKind::ObjectStore => SchedulerErrorKind::ObjectStore,
SchedulerErrorKind::Timeout => SchedulerErrorKind::Timeout,
SchedulerErrorKind::Unknown(_) => SchedulerErrorKind::Unknown(e.to_string()),
};
match self
.scheduler
.update_job_status(CompactionJobStatus {
job: job.clone(),
status: CompactionJobStatusVariant::Error(scheduler_error),
})
.await?
{
CompactionJobStatusResponse::Ack => {}
CompactionJobStatusResponse::CreatedParquetFiles(_) => {
unreachable!("scheduler should not created parquet files")
}
}
self.inner.record(job, Err(e)).await
}
Err(e) => {
// contract of this abstraction,
// where we do not pass to `self.inner` if not in `self.kind`
Err(e)
}
}
}
}
#[cfg(test)]
mod tests {
use std::{collections::HashMap, sync::Arc};
use compactor_scheduler::create_test_scheduler;
use data_types::PartitionId;
use datafusion::error::DataFusionError;
use iox_tests::TestCatalog;
use iox_time::{MockProvider, Time};
use object_store::Error as ObjectStoreError;
use super::{super::mock::MockCompactionJobDoneSink, *};
#[test]
fn test_display() {
let sink = ErrorKindCompactionJobDoneSinkWrapper::new(
MockCompactionJobDoneSink::new(),
HashSet::from([ErrorKind::ObjectStore, ErrorKind::OutOfMemory]),
create_test_scheduler(
TestCatalog::new().catalog(),
Arc::new(MockProvider::new(Time::MIN)),
None,
),
);
assert_eq!(sink.to_string(), "kind([ObjectStore, OutOfMemory], mock)");
}
#[tokio::test]
async fn test_record() {
let inner = Arc::new(MockCompactionJobDoneSink::new());
let sink = ErrorKindCompactionJobDoneSinkWrapper::new(
Arc::clone(&inner),
HashSet::from([ErrorKind::ObjectStore, ErrorKind::OutOfMemory]),
create_test_scheduler(
TestCatalog::new().catalog(),
Arc::new(MockProvider::new(Time::MIN)),
None,
),
);
let cj_1 = CompactionJob::new(PartitionId::new(1));
let cj_2 = CompactionJob::new(PartitionId::new(2));
let cj_3 = CompactionJob::new(PartitionId::new(3));
let cj_4 = CompactionJob::new(PartitionId::new(4));
sink.record(
cj_1.clone(),
Err(Box::new(ObjectStoreError::NotImplemented)),
)
.await
.expect("record failed");
sink.record(
cj_2.clone(),
Err(Box::new(DataFusionError::ResourcesExhausted(String::from(
"foo",
)))),
)
.await
.expect("record failed");
sink.record(cj_3, Err("foo".into())).await.unwrap_err();
sink.record(cj_4.clone(), Ok(()))
.await
.expect("record failed");
assert_eq!(
inner.results(),
HashMap::from([
(cj_1, Err(String::from("Operation not yet implemented.")),),
(cj_2, Err(String::from("Resources exhausted: foo")),),
(cj_4, Ok(()),),
]),
);
}
}

View File

@ -0,0 +1,126 @@
use std::fmt::Display;
use async_trait::async_trait;
use compactor_scheduler::CompactionJob;
use observability_deps::tracing::{error, info};
use crate::error::{DynError, ErrorKindExt};
use super::CompactionJobDoneSink;
/// Wrapper around a [`CompactionJobDoneSink`] that logs every recorded job
/// outcome before delegating to the wrapped sink.
#[derive(Debug)]
pub struct LoggingCompactionJobDoneSinkWrapper<T>
where
    T: CompactionJobDoneSink,
{
    /// Inner sink that receives the result after it has been logged.
    inner: T,
}
impl<T> LoggingCompactionJobDoneSinkWrapper<T>
where
    T: CompactionJobDoneSink,
{
    /// Create a new wrapper that logs results before forwarding them to `inner`.
    pub fn new(inner: T) -> Self {
        Self { inner }
    }
}
impl<T> Display for LoggingCompactionJobDoneSinkWrapper<T>
where
    T: CompactionJobDoneSink,
{
    /// Renders as `logging(<inner>)` so stacked wrappers form a readable chain.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("logging({})", self.inner))
    }
}
#[async_trait]
impl<T> CompactionJobDoneSink for LoggingCompactionJobDoneSinkWrapper<T>
where
    T: CompactionJobDoneSink,
{
    async fn record(&self, job: CompactionJob, res: Result<(), DynError>) -> Result<(), DynError> {
        // Emit one log line describing the outcome, then always delegate to
        // the inner sink with the untouched result.
        if let Err(e) = &res {
            // log compactor errors, classified by compactor ErrorKind
            error!(
                %e,
                kind=e.classify().name(),
                partition_id = job.partition_id.get(),
                job_uuid = job.uuid().to_string(),
                "Error while compacting partition",
            );
        } else {
            info!(
                partition_id = job.partition_id.get(),
                job_uuid = job.uuid().to_string(),
                "Finished compaction job",
            );
        }
        self.inner.record(job, res).await
    }
}
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, sync::Arc};

    use data_types::PartitionId;
    use object_store::Error as ObjectStoreError;
    use test_helpers::tracing::TracingCapture;

    use super::{super::mock::MockCompactionJobDoneSink, *};

    #[test]
    fn test_display() {
        // Display shows the wrapper name around the inner sink's name.
        let sink = LoggingCompactionJobDoneSinkWrapper::new(MockCompactionJobDoneSink::new());
        assert_eq!(sink.to_string(), "logging(mock)");
    }

    #[tokio::test]
    async fn test_record() {
        // Every result is both logged and forwarded to the inner sink.
        let inner = Arc::new(MockCompactionJobDoneSink::new());
        let sink = LoggingCompactionJobDoneSinkWrapper::new(Arc::clone(&inner));

        let capture = TracingCapture::new();

        let cj_1 = CompactionJob::new(PartitionId::new(1));
        let cj_2 = CompactionJob::new(PartitionId::new(2));
        let cj_3 = CompactionJob::new(PartitionId::new(3));

        sink.record(cj_1.clone(), Err("msg 1".into()))
            .await
            .expect("record failed");
        sink.record(cj_2.clone(), Err("msg 2".into()))
            .await
            .expect("record failed");
        // Second failure for cj_1: classified as "object_store".
        sink.record(
            cj_1.clone(),
            Err(Box::new(ObjectStoreError::NotImplemented)),
        )
        .await
        .expect("record failed");
        sink.record(cj_3.clone(), Ok(()))
            .await
            .expect("record failed");

        // Captured log output must match exactly: errors logged at ERROR with
        // their classified kind, success logged at INFO; job UUIDs vary per run.
        assert_eq!(
            capture.to_string(),
            format!("level = ERROR; message = Error while compacting partition; e = msg 1; kind = \"unknown\"; partition_id = 1; job_uuid = {:?}; \n\
level = ERROR; message = Error while compacting partition; e = msg 2; kind = \"unknown\"; partition_id = 2; job_uuid = {:?}; \n\
level = ERROR; message = Error while compacting partition; e = Operation not yet implemented.; kind = \"object_store\"; partition_id = 1; job_uuid = {:?}; \n\
level = INFO; message = Finished compaction job; partition_id = 3; job_uuid = {:?}; ", cj_1.uuid().to_string(), cj_2.uuid().to_string(), cj_1.uuid().to_string(), cj_3.uuid().to_string()),
        );

        // The inner mock keeps the latest result per job, stringified.
        assert_eq!(
            inner.results(),
            HashMap::from([
                (cj_1, Err(String::from("Operation not yet implemented.")),),
                (cj_2, Err(String::from("msg 2"))),
                (cj_3, Ok(())),
            ]),
        );
    }
}

View File

@ -0,0 +1,164 @@
use std::{collections::HashMap, fmt::Display};
use async_trait::async_trait;
use compactor_scheduler::CompactionJob;
use metric::{Registry, U64Counter};
use crate::error::{DynError, ErrorKind, ErrorKindExt};
use super::CompactionJobDoneSink;
/// Metric name for the completed-partition counter, labelled by result
/// ("ok" or "error" plus error kind).
const METRIC_NAME_PARTITION_COMPLETE_COUNT: &str = "iox_compactor_partition_complete_count";
/// Wrapper around a [`CompactionJobDoneSink`] that counts job outcomes —
/// successes and errors by [`ErrorKind`] — before delegating to the wrapped sink.
#[derive(Debug)]
pub struct MetricsCompactionJobDoneSinkWrapper<T>
where
    T: CompactionJobDoneSink,
{
    /// Counter incremented once per successfully completed job.
    ok_counter: U64Counter,
    /// One counter per [`ErrorKind`]; populated for every variant in `new`,
    /// so lookups in `record` never miss.
    error_counter: HashMap<ErrorKind, U64Counter>,
    /// Inner sink that receives the result after metrics are updated.
    inner: T,
}
impl<T> MetricsCompactionJobDoneSinkWrapper<T>
where
T: CompactionJobDoneSink,
{
pub fn new(inner: T, registry: &Registry) -> Self {
let metric = registry.register_metric::<U64Counter>(
METRIC_NAME_PARTITION_COMPLETE_COUNT,
"Number of completed partitions",
);
let ok_counter = metric.recorder(&[("result", "ok")]);
let error_counter = ErrorKind::variants()
.iter()
.map(|kind| {
(
*kind,
metric.recorder(&[("result", "error"), ("kind", kind.name())]),
)
})
.collect();
Self {
ok_counter,
error_counter,
inner,
}
}
}
impl<T> Display for MetricsCompactionJobDoneSinkWrapper<T>
where
    T: CompactionJobDoneSink,
{
    /// Renders as `metrics(<inner>)` so stacked wrappers form a readable chain.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("metrics({})", self.inner))
    }
}
#[async_trait]
impl<T> CompactionJobDoneSink for MetricsCompactionJobDoneSinkWrapper<T>
where
    T: CompactionJobDoneSink,
{
    async fn record(&self, job: CompactionJob, res: Result<(), DynError>) -> Result<(), DynError> {
        // Bump the counter matching the outcome, then always delegate with the
        // untouched result.
        match &res {
            Err(e) => {
                // classify and track counts of compactor ErrorKind
                let counter = self
                    .error_counter
                    .get(&e.classify())
                    .expect("all kinds constructed");
                counter.inc(1);
            }
            Ok(()) => {
                self.ok_counter.inc(1);
            }
        }
        self.inner.record(job, res).await
    }
}
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, sync::Arc};

    use data_types::PartitionId;
    use metric::{assert_counter, Attributes};
    use object_store::Error as ObjectStoreError;

    use super::{super::mock::MockCompactionJobDoneSink, *};

    #[test]
    fn test_display() {
        // Display shows the wrapper name around the inner sink's name.
        let registry = Registry::new();
        let sink =
            MetricsCompactionJobDoneSinkWrapper::new(MockCompactionJobDoneSink::new(), &registry);
        assert_eq!(sink.to_string(), "metrics(mock)");
    }

    #[tokio::test]
    async fn test_record() {
        let registry = Registry::new();
        let inner = Arc::new(MockCompactionJobDoneSink::new());
        let sink = MetricsCompactionJobDoneSinkWrapper::new(Arc::clone(&inner), &registry);

        // All counters start at zero.
        assert_ok_counter(&registry, 0);
        assert_error_counter(&registry, "unknown", 0);
        assert_error_counter(&registry, "object_store", 0);

        let cj_1 = CompactionJob::new(PartitionId::new(1));
        let cj_2 = CompactionJob::new(PartitionId::new(2));
        let cj_3 = CompactionJob::new(PartitionId::new(3));

        // Two string errors classified as "unknown".
        sink.record(cj_1.clone(), Err("msg 1".into()))
            .await
            .expect("record failed");
        sink.record(cj_2.clone(), Err("msg 2".into()))
            .await
            .expect("record failed");
        // One ObjectStore error classified as "object_store".
        sink.record(
            cj_1.clone(),
            Err(Box::new(ObjectStoreError::NotImplemented)),
        )
        .await
        .expect("record failed");
        // One success.
        sink.record(cj_3.clone(), Ok(()))
            .await
            .expect("record failed");

        // Counters reflect the classified outcomes.
        assert_ok_counter(&registry, 1);
        assert_error_counter(&registry, "unknown", 2);
        assert_error_counter(&registry, "object_store", 1);

        // The inner mock keeps the latest result per job, stringified.
        assert_eq!(
            inner.results(),
            HashMap::from([
                (cj_1, Err(String::from("Operation not yet implemented.")),),
                (cj_2, Err(String::from("msg 2"))),
                (cj_3, Ok(())),
            ]),
        );
    }

    // Assert the value of the "result=ok" counter in `registry`.
    fn assert_ok_counter(registry: &Registry, value: u64) {
        assert_counter!(
            registry,
            U64Counter,
            METRIC_NAME_PARTITION_COMPLETE_COUNT,
            labels = Attributes::from(&[("result", "ok")]),
            value = value,
        );
    }

    // Assert the value of the "result=error" counter for the given error kind.
    fn assert_error_counter(registry: &Registry, kind: &'static str, value: u64) {
        assert_counter!(
            registry,
            U64Counter,
            METRIC_NAME_PARTITION_COMPLETE_COUNT,
            labels = Attributes::from(&[("result", "error"), ("kind", kind)]),
            value = value,
        );
    }
}

Some files were not shown because too many files have changed in this diff Show More