Merge branch 'main' into ntran/chunk_types

pull/24376/head
kodiakhq[bot] 2021-09-28 18:52:01 +00:00 committed by GitHub
commit 0beab3e7e1
31 changed files with 105 additions and 2811 deletions

Cargo.lock (generated)

@@ -1259,11 +1259,11 @@ dependencies = [
"bytes",
"chrono",
"data_types",
"google_types",
"num_cpus",
"observability_deps",
"pbjson",
"pbjson_build",
"pbjson-build",
"pbjson-types",
"proc-macro2",
"prost",
"prost-build",
@@ -1329,20 +1329,6 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574"
[[package]]
name = "google_types"
version = "0.1.0"
dependencies = [
"bytes",
"chrono",
"pbjson",
"pbjson_build",
"prost",
"prost-build",
"serde",
"serde_json",
]
[[package]]
name = "grpc-router"
version = "0.1.0"
@@ -1700,7 +1686,6 @@ dependencies = [
"rdkafka",
"read_buffer",
"reqwest",
"routerify",
"rustyline",
"serde",
"serde_json",
@@ -2714,7 +2699,6 @@ dependencies = [
"datafusion_util",
"futures",
"generated_types",
"google_types",
"internal_types",
"iox_object_store",
"metric",
@@ -2723,6 +2707,7 @@ dependencies = [
"parking_lot",
"parquet",
"parquet-format",
"pbjson-types",
"persistence_windows",
"predicate",
"prost",
@@ -2745,34 +2730,38 @@ checksum = "acbf547ad0c65e31259204bd90935776d1c693cec2f4ff7abb7a1bbbd40dfe58"
[[package]]
name = "pbjson"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbde1ff33a3aac655835d29a136e02935d4fa15f969268dbc97936159eb81d98"
dependencies = [
"base64 0.13.0",
"bytes",
"serde",
]
[[package]]
name = "pbjson_build"
name = "pbjson-build"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64eb98c480df4f6ea0f4befb7b6b61be8de3098337016acec79d3dd7d6d219ff"
dependencies = [
"heck",
"itertools",
"pbjson_test",
"prost",
"prost-types",
"tempfile",
]
[[package]]
name = "pbjson_test"
name = "pbjson-types"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89d20cbdd9993508f3c3420a6133bddd267c540fbe23821887dc5af1d53a0d52"
dependencies = [
"bytes",
"chrono",
"pbjson",
"pbjson_build",
"pbjson-build",
"prost",
"prost-build",
"serde",
"serde_json",
]
[[package]]
@@ -3510,19 +3499,6 @@ dependencies = [
"winapi",
]
[[package]]
name = "routerify"
version = "2.0.0-beta-2"
source = "git+https://github.com/influxdata/routerify?rev=274e250#274e250e556b968ad02259282c0433e445db4200"
dependencies = [
"http",
"hyper",
"lazy_static",
"percent-encoding",
"regex",
"thiserror",
]
[[package]]
name = "rusoto_core"
version = "0.47.0"


@@ -39,7 +39,6 @@ members = [
"datafusion_util",
"entry",
"generated_types",
"google_types",
"influxdb2_client",
"influxdb_iox_client",
"influxdb_line_protocol",
@@ -57,9 +56,6 @@ members = [
"observability_deps",
"packers",
"panic_logging",
"pbjson",
"pbjson_build",
"pbjson_test",
"persistence_windows",
"predicate",
"query",
@@ -142,8 +138,6 @@ prettytable-rs = "0.8"
pprof = { version = "^0.5", default-features = false, features = ["flamegraph", "protobuf"], optional = true }
prost = "0.8"
rustyline = { version = "9.0", default-features = false }
# Forked to upgrade hyper and tokio
routerify = { git = "https://github.com/influxdata/routerify", rev = "274e250" }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.67"
serde_urlencoded = "0.7.0"
@@ -200,3 +194,4 @@ aws = ["object_store/aws"] # Optional AWS / S3 object store support
# Cargo cannot currently implement mutually exclusive features so let's force every build
# to pick either heappy or jemalloc_replacing_malloc feature at least until we figure out something better.
jemalloc_replacing_malloc = ["tikv-jemalloc-sys"]


@@ -7,9 +7,9 @@ edition = "2018"
[dependencies] # In alphabetical order
bytes = "1.0"
data_types = { path = "../data_types" }
google_types = { path = "../google_types" }
observability_deps = { path = "../observability_deps" }
pbjson = { path = "../pbjson" }
pbjson = "0.1"
pbjson-types = "0.1"
prost = "0.8"
regex = "1.4"
serde = { version = "1.0", features = ["derive"] }
@@ -25,4 +25,4 @@ num_cpus = "1.13.0"
proc-macro2 = "=1.0.27"
tonic-build = "0.5"
prost-build = "0.8"
pbjson_build = { path = "../pbjson_build" }
pbjson-build = "0.1"


@@ -64,7 +64,7 @@ fn generate_grpc_types(root: &Path) -> Result<()> {
config
.compile_well_known_types()
.disable_comments(&[".google"])
.extern_path(".google.protobuf", "::google_types::protobuf")
.extern_path(".google.protobuf", "::pbjson_types")
.bytes(&[".influxdata.iox.catalog.v1.AddParquet.metadata"])
.btree_map(&[
".influxdata.iox.catalog.v1.DatabaseCheckpoint.sequencer_numbers",


@@ -74,19 +74,18 @@ impl TryFrom<management::Chunk> for ChunkSummary {
type Error = FieldViolation;
fn try_from(proto: management::Chunk) -> Result<Self, Self::Error> {
let convert_timestamp = |t: google_types::protobuf::Timestamp, field: &'static str| {
let convert_timestamp = |t: pbjson_types::Timestamp, field: &'static str| {
t.try_into().map_err(|_| FieldViolation {
field: field.to_string(),
description: "Timestamp must be positive".to_string(),
})
};
let timestamp = |t: Option<google_types::protobuf::Timestamp>, field: &'static str| {
let timestamp = |t: Option<pbjson_types::Timestamp>, field: &'static str| {
t.map(|t| convert_timestamp(t, field)).transpose()
};
let required_timestamp = |t: Option<google_types::protobuf::Timestamp>,
field: &'static str| {
let required_timestamp = |t: Option<pbjson_types::Timestamp>, field: &'static str| {
t.ok_or_else(|| FieldViolation {
field: field.to_string(),
description: "Timestamp is required".to_string(),
@@ -186,7 +185,7 @@ mod test {
time_of_first_write: Some(now.into()),
time_of_last_write: Some(now.into()),
time_closed: None,
time_of_last_access: Some(google_types::protobuf::Timestamp {
time_of_last_access: Some(pbjson_types::Timestamp {
seconds: 50,
nanos: 7,
}),
@@ -250,7 +249,7 @@ mod test {
time_of_first_write: Some(now.into()),
time_of_last_write: Some(now.into()),
time_closed: None,
time_of_last_access: Some(google_types::protobuf::Timestamp {
time_of_last_access: Some(pbjson_types::Timestamp {
seconds: 12,
nanos: 100_007,
}),


@@ -129,7 +129,7 @@ mod tests {
management::lifecycle_rules::MaxActiveCompactionsCfg::MaxActiveCompactions(8),
),
catalog_transactions_until_checkpoint: 10,
catalog_transaction_prune_age: Some(google_types::protobuf::Duration {
catalog_transaction_prune_age: Some(pbjson_types::Duration {
seconds: 11,
nanos: 22,
}),


@@ -1,7 +1,9 @@
//! Protobuf types for errors from the google standards and
//! conversions to `tonic::Status`
pub use google_types::*;
pub mod protobuf {
pub use pbjson_types::*;
}
pub mod rpc {
include!(concat!(env!("OUT_DIR"), "/google.rpc.rs"));
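
Since the old `pub use google_types::*;` exposed these types under a `protobuf`
module, wrapping the `pbjson_types` re-export in `pub mod protobuf` keeps the public
path unchanged for callers. A small usage sketch (illustrative, and assuming this
module lives in the generated_types crate):

    use generated_types::google::protobuf::Timestamp; // now backed by pbjson_types

    let ts = Timestamp { seconds: 50, nanos: 7 };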


@@ -1,20 +0,0 @@
[package]
name = "google_types"
version = "0.1.0"
authors = ["Raphael Taylor-Davies <r.taylordavies@googlemail.com>"]
description = "Standard Protobuf definitions - extracted into separate crate to workaround https://github.com/hyperium/tonic/issues/521"
edition = "2018"
[dependencies] # In alphabetical order
bytes = "1.0"
chrono = "0.4"
pbjson = { path = "../pbjson" }
prost = "0.8"
serde = { version = "1.0", features = ["derive"] }
[dev-dependencies]
serde_json = "1.0"
[build-dependencies] # In alphabetical order
prost-build = "0.8"
pbjson_build = { path = "../pbjson_build" }


@@ -1,35 +0,0 @@
//! Compiles Protocol Buffers and FlatBuffers schema definitions into
//! native Rust types.
use std::env;
use std::path::PathBuf;
type Error = Box<dyn std::error::Error>;
type Result<T, E = Error> = std::result::Result<T, E>;
fn main() -> Result<()> {
let root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("protos");
let proto_files = vec![root.join("google/protobuf/types.proto")];
// Tell cargo to recompile if any of these proto files are changed
for proto_file in &proto_files {
println!("cargo:rerun-if-changed={}", proto_file.display());
}
let descriptor_path = PathBuf::from(env::var("OUT_DIR").unwrap()).join("proto_descriptor.bin");
prost_build::Config::new()
.file_descriptor_set_path(&descriptor_path)
.compile_well_known_types()
.disable_comments(&["."])
.bytes(&[".google"])
.compile_protos(&proto_files, &[root])?;
let descriptor_set = std::fs::read(descriptor_path)?;
pbjson_build::Builder::new()
.register_descriptors(&descriptor_set)?
.exclude([".google.protobuf.Duration", ".google.protobuf.Timestamp"])
.build(&[".google"])?;
Ok(())
}


@@ -1,15 +0,0 @@
syntax = "proto3";
package google.protobuf;
import "google/protobuf/any.proto";
import "google/protobuf/api.proto";
import "google/protobuf/descriptor.proto";
import "google/protobuf/duration.proto";
import "google/protobuf/empty.proto";
import "google/protobuf/field_mask.proto";
import "google/protobuf/source_context.proto";
import "google/protobuf/struct.proto";
import "google/protobuf/timestamp.proto";
import "google/protobuf/type.proto";
import "google/protobuf/wrappers.proto";


@@ -1,159 +0,0 @@
use crate::protobuf::Duration;
use serde::{Deserialize, Serialize, Serializer};
use std::convert::{TryFrom, TryInto};
impl TryFrom<Duration> for std::time::Duration {
type Error = std::num::TryFromIntError;
fn try_from(value: Duration) -> Result<Self, Self::Error> {
Ok(std::time::Duration::new(
value.seconds.try_into()?,
value.nanos.try_into()?,
))
}
}
impl From<std::time::Duration> for Duration {
fn from(value: std::time::Duration) -> Self {
Self {
seconds: value.as_secs() as _,
nanos: value.subsec_nanos() as _,
}
}
}
impl Serialize for Duration {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
if self.seconds != 0 && self.nanos != 0 && (self.nanos < 0) != (self.seconds < 0) {
return Err(serde::ser::Error::custom("Duration has inconsistent signs"));
}
let mut s = if self.seconds == 0 {
if self.nanos < 0 {
"-0".to_string()
} else {
"0".to_string()
}
} else {
self.seconds.to_string()
};
if self.nanos != 0 {
s.push('.');
let f = match split_nanos(self.nanos.abs() as u32) {
(millis, 0, 0) => format!("{:03}", millis),
(millis, micros, 0) => format!("{:03}{:03}", millis, micros),
(millis, micros, nanos) => format!("{:03}{:03}{:03}", millis, micros, nanos),
};
s.push_str(&f);
}
s.push('s');
serializer.serialize_str(&s)
}
}
impl<'de> serde::Deserialize<'de> for Duration {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s: &str = Deserialize::deserialize(deserializer)?;
let s = s
.strip_suffix('s')
.ok_or_else(|| serde::de::Error::custom("missing 's' suffix"))?;
let secs: f64 = s.parse().map_err(serde::de::Error::custom)?;
if secs < 0. {
let negated = std::time::Duration::from_secs_f64(-secs);
Ok(Self {
seconds: -(negated.as_secs() as i64),
nanos: -(negated.subsec_nanos() as i32),
})
} else {
Ok(std::time::Duration::from_secs_f64(secs).into())
}
}
}
/// Splits nanoseconds into whole milliseconds, microseconds, and nanoseconds
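/// (for example, `split_nanos(1_234_567)` returns `(1, 234, 567)`)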
fn split_nanos(mut nanos: u32) -> (u32, u32, u32) {
let millis = nanos / 1_000_000;
nanos -= millis * 1_000_000;
let micros = nanos / 1_000;
nanos -= micros * 1_000;
(millis, micros, nanos)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_duration() {
let verify = |duration: &Duration, expected: &str| {
assert_eq!(serde_json::to_string(duration).unwrap().as_str(), expected);
assert_eq!(
&serde_json::from_str::<Duration>(expected).unwrap(),
duration
)
};
let duration = Duration {
seconds: 0,
nanos: 0,
};
verify(&duration, "\"0s\"");
let duration = Duration {
seconds: 0,
nanos: 123,
};
verify(&duration, "\"0.000000123s\"");
let duration = Duration {
seconds: 0,
nanos: 123456,
};
verify(&duration, "\"0.000123456s\"");
let duration = Duration {
seconds: 0,
nanos: 123456789,
};
verify(&duration, "\"0.123456789s\"");
let duration = Duration {
seconds: 0,
nanos: -67088,
};
verify(&duration, "\"-0.000067088s\"");
let duration = Duration {
seconds: 121,
nanos: 3454,
};
verify(&duration, "\"121.000003454s\"");
let duration = Duration {
seconds: -90,
nanos: -2456301,
};
verify(&duration, "\"-90.002456301s\"");
let duration = Duration {
seconds: -90,
nanos: 234,
};
serde_json::to_string(&duration).unwrap_err();
let duration = Duration {
seconds: 90,
nanos: -234,
};
serde_json::to_string(&duration).unwrap_err();
}
}


@@ -1,25 +0,0 @@
// This crate deliberately does not use the same linting rules as the other
// crates because of all the generated code it contains that we don't have much
// control over.
#![deny(rustdoc::broken_intra_doc_links, rustdoc::bare_urls, rust_2018_idioms)]
#![allow(
unused_imports,
clippy::redundant_static_lifetimes,
clippy::redundant_closure,
clippy::redundant_field_names,
clippy::clone_on_ref_ptr
)]
mod pb {
pub mod google {
pub mod protobuf {
include!(concat!(env!("OUT_DIR"), "/google.protobuf.rs"));
include!(concat!(env!("OUT_DIR"), "/google.protobuf.serde.rs"));
}
}
}
mod duration;
mod timestamp;
pub use pb::google::*;


@@ -1,73 +0,0 @@
use crate::protobuf::Timestamp;
use chrono::{DateTime, NaiveDateTime, Utc};
use serde::{Deserialize, Serialize};
use std::convert::{TryFrom, TryInto};
impl TryFrom<Timestamp> for chrono::DateTime<Utc> {
type Error = std::num::TryFromIntError;
fn try_from(value: Timestamp) -> Result<Self, Self::Error> {
let Timestamp { seconds, nanos } = value;
let dt = NaiveDateTime::from_timestamp(seconds, nanos.try_into()?);
Ok(DateTime::<Utc>::from_utc(dt, Utc))
}
}
impl From<DateTime<Utc>> for Timestamp {
fn from(value: DateTime<Utc>) -> Self {
Self {
seconds: value.timestamp(),
nanos: value.timestamp_subsec_nanos() as i32,
}
}
}
impl Serialize for Timestamp {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let t: DateTime<Utc> = self.clone().try_into().map_err(serde::ser::Error::custom)?;
serializer.serialize_str(t.to_rfc3339().as_str())
}
}
impl<'de> serde::Deserialize<'de> for Timestamp {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s: &str = Deserialize::deserialize(deserializer)?;
let d = DateTime::parse_from_rfc3339(s).map_err(serde::de::Error::custom)?;
let d: DateTime<Utc> = d.into();
Ok(d.into())
}
}
#[cfg(test)]
mod tests {
use super::*;
use chrono::{FixedOffset, TimeZone};
use serde::de::value::{BorrowedStrDeserializer, Error};
#[test]
fn test_date() {
let datetime = FixedOffset::east(5 * 3600)
.ymd(2016, 11, 8)
.and_hms(21, 7, 9);
let encoded = datetime.to_rfc3339();
assert_eq!(&encoded, "2016-11-08T21:07:09+05:00");
let utc: DateTime<Utc> = datetime.into();
let utc_encoded = utc.to_rfc3339();
assert_eq!(&utc_encoded, "2016-11-08T16:07:09+00:00");
let deserializer = BorrowedStrDeserializer::<'_, Error>::new(&encoded);
let a: Timestamp = Timestamp::deserialize(deserializer).unwrap();
assert_eq!(a.seconds, utc.timestamp());
assert_eq!(a.nanos, utc.timestamp_subsec_nanos() as i32);
let encoded = serde_json::to_string(&a).unwrap();
assert_eq!(encoded, format!("\"{}\"", utc_encoded));
}
}


@@ -14,7 +14,6 @@ datafusion = { path = "../datafusion" }
datafusion_util = { path = "../datafusion_util" }
futures = "0.3.7"
generated_types = { path = "../generated_types" }
google_types = { path = "../google_types" }
internal_types = { path = "../internal_types" }
iox_object_store = { path = "../iox_object_store" }
metric = { path = "../metric" }
@@ -23,6 +22,7 @@ observability_deps = { path = "../observability_deps" }
parquet = "5.0"
parquet-format = "2.6"
parking_lot = "0.11.1"
pbjson-types = "0.1"
persistence_windows = { path = "../persistence_windows" }
predicate = { path = "../predicate" }
prost = "0.8"


@@ -441,7 +441,7 @@ impl IoxMetadata {
}
fn decode_timestamp_from_field(
value: Option<google_types::protobuf::Timestamp>,
value: Option<pbjson_types::Timestamp>,
field: &'static str,
) -> Result<DateTime<Utc>> {
value


@@ -1,14 +0,0 @@
[package]
name = "pbjson"
version = "0.1.0"
authors = ["Raphael Taylor-Davies <r.taylordavies@googlemail.com>"]
edition = "2018"
description = "Utilities for pbjson converion"
[dependencies]
serde = { version = "1.0", features = ["derive"] }
base64 = "0.13"
[dev-dependencies]
bytes = "1.0"


@@ -1,82 +0,0 @@
#![deny(rustdoc::broken_intra_doc_links, rustdoc::bare_urls, rust_2018_idioms)]
#![warn(
missing_debug_implementations,
clippy::explicit_iter_loop,
clippy::use_self,
clippy::clone_on_ref_ptr,
clippy::future_not_send
)]
#[doc(hidden)]
pub mod private {
/// Re-export base64
pub use base64;
use serde::Deserialize;
use std::str::FromStr;
/// Used to parse a number from either a string or its raw representation
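/// (e.g. for a u64 field, both the JSON string "42" and the JSON number 42
/// deserialize to the same value)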
#[derive(Debug, Copy, Clone, PartialOrd, PartialEq, Hash, Ord, Eq)]
pub struct NumberDeserialize<T>(pub T);
#[derive(Deserialize)]
#[serde(untagged)]
enum Content<'a, T> {
Str(&'a str),
Number(T),
}
impl<'de, T> serde::Deserialize<'de> for NumberDeserialize<T>
where
T: FromStr + serde::Deserialize<'de>,
<T as FromStr>::Err: std::error::Error,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let content = Content::deserialize(deserializer)?;
Ok(Self(match content {
Content::Str(v) => v.parse().map_err(serde::de::Error::custom)?,
Content::Number(v) => v,
}))
}
}
#[derive(Debug, Copy, Clone, PartialOrd, PartialEq, Hash, Ord, Eq)]
pub struct BytesDeserialize<T>(pub T);
impl<'de, T> Deserialize<'de> for BytesDeserialize<T>
where
T: From<Vec<u8>>,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s: &str = Deserialize::deserialize(deserializer)?;
let decoded = base64::decode(s).map_err(serde::de::Error::custom)?;
Ok(Self(decoded.into()))
}
}
#[cfg(test)]
mod tests {
use super::*;
use bytes::Bytes;
use serde::de::value::{BorrowedStrDeserializer, Error};
#[test]
fn test_bytes() {
let raw = vec![2, 5, 62, 2, 5, 7, 8, 43, 5, 8, 4, 23, 5, 7, 7, 3, 2, 5, 196];
let encoded = base64::encode(&raw);
let deserializer = BorrowedStrDeserializer::<'_, Error>::new(&encoded);
let a: Bytes = BytesDeserialize::deserialize(deserializer).unwrap().0;
let b: Vec<u8> = BytesDeserialize::deserialize(deserializer).unwrap().0;
assert_eq!(raw.as_slice(), &a);
assert_eq!(raw.as_slice(), &b);
}
}
}


@@ -1,17 +0,0 @@
[package]
name = "pbjson_build"
version = "0.1.0"
authors = ["Raphael Taylor-Davies <r.taylordavies@googlemail.com>"]
edition = "2018"
description = "Generates Serialize and Deserialize implementations for prost message types"
[dependencies]
heck = "0.3"
prost = "0.8"
prost-types = "0.8"
itertools = "0.10"
[dev-dependencies]
tempfile = "3.1"
pbjson_test = { path = "../pbjson_test" }


@@ -1,260 +0,0 @@
//! This module contains code to parse and extract the protobuf descriptor
//! format for use by the rest of the codebase
use std::collections::btree_map::Entry;
use std::collections::BTreeMap;
use std::fmt::{Display, Formatter};
use std::io::{Error, ErrorKind, Result};
use itertools::{EitherOrBoth, Itertools};
use prost_types::{
DescriptorProto, EnumDescriptorProto, EnumValueDescriptorProto, FieldDescriptorProto,
FileDescriptorSet, MessageOptions, OneofDescriptorProto,
};
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub struct Package(String);
impl Display for Package {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl Package {
pub fn new(s: impl Into<String>) -> Self {
let s = s.into();
assert!(
!s.starts_with('.'),
"package cannot start with \'.\', got \"{}\"",
s
);
Self(s)
}
}
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub struct TypeName(String);
impl Display for TypeName {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl TypeName {
pub fn new(s: impl Into<String>) -> Self {
let s = s.into();
assert!(
!s.contains('.'),
"type name cannot contain \'.\', got \"{}\"",
s
);
Self(s)
}
pub fn to_snake_case(&self) -> String {
use heck::SnakeCase;
self.0.to_snake_case()
}
pub fn to_camel_case(&self) -> String {
use heck::CamelCase;
self.0.to_camel_case()
}
pub fn to_shouty_snake_case(&self) -> String {
use heck::ShoutySnakeCase;
self.0.to_shouty_snake_case()
}
}
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub struct TypePath {
package: Package,
path: Vec<TypeName>,
}
impl Display for TypePath {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
self.package.fmt(f)?;
for element in &self.path {
write!(f, ".{}", element)?;
}
Ok(())
}
}
impl TypePath {
pub fn new(package: Package) -> Self {
Self {
package,
path: Default::default(),
}
}
pub fn package(&self) -> &Package {
&self.package
}
pub fn path(&self) -> &[TypeName] {
self.path.as_slice()
}
pub fn child(&self, name: TypeName) -> Self {
let path = self
.path
.iter()
.cloned()
.chain(std::iter::once(name))
.collect();
Self {
package: self.package.clone(),
path,
}
}
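/// Whether `prefix` (a fully-qualified prefix starting with '.') matches this
/// type's path, e.g. ".google.protobuf" matches the path of
/// google.protobuf.Timestamp (illustrative example)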
pub fn matches_prefix(&self, prefix: &str) -> bool {
let prefix = match prefix.strip_prefix('.') {
Some(prefix) => prefix,
None => return false,
};
if prefix.len() <= self.package.0.len() {
return self.package.0.starts_with(prefix);
}
match prefix.strip_prefix(&self.package.0) {
Some(prefix) => {
let split = prefix.split('.').skip(1);
for zipped in self.path.iter().zip_longest(split) {
match zipped {
EitherOrBoth::Both(a, b) if a.0.as_str() == b => continue,
EitherOrBoth::Left(_) => return true,
_ => return false,
}
}
true
}
None => false,
}
}
}
#[derive(Debug, Clone, Default)]
pub struct DescriptorSet {
descriptors: BTreeMap<TypePath, Descriptor>,
}
impl DescriptorSet {
pub fn new() -> Self {
Self::default()
}
pub fn register_encoded(&mut self, encoded: &[u8]) -> Result<()> {
let descriptors: FileDescriptorSet =
prost::Message::decode(encoded).map_err(|e| Error::new(ErrorKind::InvalidData, e))?;
for file in descriptors.file {
let syntax = match file.syntax.as_deref() {
None | Some("proto2") => Syntax::Proto2,
Some("proto3") => Syntax::Proto3,
Some(s) => panic!("unknown syntax: {}", s),
};
let package = Package::new(file.package.expect("expected package"));
let path = TypePath::new(package);
for descriptor in file.message_type {
self.register_message(&path, descriptor, syntax)
}
for descriptor in file.enum_type {
self.register_enum(&path, descriptor)
}
}
Ok(())
}
pub fn iter(&self) -> impl Iterator<Item = (&TypePath, &Descriptor)> {
self.descriptors.iter()
}
fn register_message(&mut self, path: &TypePath, descriptor: DescriptorProto, syntax: Syntax) {
let name = TypeName::new(descriptor.name.expect("expected name"));
let child_path = path.child(name);
for child_descriptor in descriptor.enum_type {
self.register_enum(&child_path, child_descriptor)
}
for child_descriptor in descriptor.nested_type {
self.register_message(&child_path, child_descriptor, syntax)
}
self.register_descriptor(
child_path.clone(),
Descriptor::Message(MessageDescriptor {
path: child_path,
options: descriptor.options,
one_of: descriptor.oneof_decl,
fields: descriptor.field,
syntax,
}),
);
}
fn register_enum(&mut self, path: &TypePath, descriptor: EnumDescriptorProto) {
let name = TypeName::new(descriptor.name.expect("expected name"));
self.register_descriptor(
path.child(name),
Descriptor::Enum(EnumDescriptor {
values: descriptor.value,
}),
);
}
fn register_descriptor(&mut self, path: TypePath, descriptor: Descriptor) {
match self.descriptors.entry(path) {
Entry::Occupied(o) => panic!("descriptor already registered for {}", o.key()),
Entry::Vacant(v) => v.insert(descriptor),
};
}
}
#[derive(Debug, Clone, Copy)]
pub enum Syntax {
Proto2,
Proto3,
}
#[derive(Debug, Clone)]
pub enum Descriptor {
Enum(EnumDescriptor),
Message(MessageDescriptor),
}
#[derive(Debug, Clone)]
pub struct EnumDescriptor {
pub values: Vec<EnumValueDescriptorProto>,
}
#[derive(Debug, Clone)]
pub struct MessageDescriptor {
pub path: TypePath,
pub options: Option<MessageOptions>,
pub one_of: Vec<OneofDescriptorProto>,
pub fields: Vec<FieldDescriptorProto>,
pub syntax: Syntax,
}
impl MessageDescriptor {
/// Whether this is an auto-generated type for the map field
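/// (protobuf expands a `map<K, V>` field into a nested `*Entry` message with
/// the `map_entry` option set, which is what this checks for)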
pub fn is_map(&self) -> bool {
self.options
.as_ref()
.and_then(|options| options.map_entry)
.unwrap_or(false)
}
}


@@ -1,26 +0,0 @@
//! Contains code to escape strings to avoid collisions with reserved Rust keywords
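//!
//! For example, `escape_ident("type".to_string())` returns "r#type", while
//! `escape_ident("self".to_string())` returns "self_"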
pub fn escape_ident(mut ident: String) -> String {
// Copied from prost-build::ident
//
// Use a raw identifier if the identifier matches a Rust keyword:
// https://doc.rust-lang.org/reference/keywords.html.
match ident.as_str() {
// 2015 strict keywords.
| "as" | "break" | "const" | "continue" | "else" | "enum" | "false"
| "fn" | "for" | "if" | "impl" | "in" | "let" | "loop" | "match" | "mod" | "move" | "mut"
| "pub" | "ref" | "return" | "static" | "struct" | "trait" | "true"
| "type" | "unsafe" | "use" | "where" | "while"
// 2018 strict keywords.
| "dyn"
// 2015 reserved keywords.
| "abstract" | "become" | "box" | "do" | "final" | "macro" | "override" | "priv" | "typeof"
| "unsized" | "virtual" | "yield"
// 2018 reserved keywords.
| "async" | "await" | "try" => ident.insert_str(0, "r#"),
// the following keywords are not supported as raw identifiers and are therefore suffixed with an underscore.
"self" | "super" | "extern" | "crate" => ident += "_",
_ => (),
};
ident
}


@@ -1,129 +0,0 @@
//! This module contains the actual code generation logic
use std::collections::BTreeMap;
use std::fmt::{Display, Formatter};
use std::io::{Result, Write};
use crate::descriptor::TypePath;
mod enumeration;
mod message;
pub use enumeration::generate_enum;
pub use message::generate_message;
#[derive(Debug, Clone, Copy)]
struct Indent(usize);
impl Display for Indent {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
for _ in 0..self.0 {
write!(f, " ")?;
}
Ok(())
}
}
#[derive(Debug)]
pub struct Config {
pub extern_types: BTreeMap<TypePath, String>,
}
impl Config {
fn rust_type(&self, path: &TypePath) -> String {
if let Some(t) = self.extern_types.get(path) {
return t.clone();
}
let mut ret = String::new();
let path = path.path();
assert!(!path.is_empty(), "path cannot be empty");
for i in &path[..(path.len() - 1)] {
ret.push_str(i.to_snake_case().as_str());
ret.push_str("::");
}
ret.push_str(path.last().unwrap().to_camel_case().as_str());
ret
}
fn rust_variant(&self, enumeration: &TypePath, variant: &str) -> String {
use heck::CamelCase;
assert!(
variant
.chars()
.all(|c| matches!(c, '0'..='9' | 'A'..='Z' | '_')),
"illegal variant - {}",
variant
);
// TODO: Config to disable stripping prefix
let enumeration_name = enumeration.path().last().unwrap().to_shouty_snake_case();
let variant = match variant.strip_prefix(&enumeration_name) {
Some("") => variant,
Some(stripped) => stripped,
None => variant,
};
variant.to_camel_case()
}
}
fn write_fields_array<'a, W: Write, I: Iterator<Item = &'a str>>(
writer: &mut W,
indent: usize,
variants: I,
) -> Result<()> {
writeln!(writer, "{}const FIELDS: &[&str] = &[", Indent(indent))?;
for name in variants {
writeln!(writer, "{}\"{}\",", Indent(indent + 1), name)?;
}
writeln!(writer, "{}];", Indent(indent))?;
writeln!(writer)
}
fn write_serialize_start<W: Write>(indent: usize, rust_type: &str, writer: &mut W) -> Result<()> {
writeln!(
writer,
r#"{indent}impl serde::Serialize for {rust_type} {{
{indent} #[allow(deprecated)]
{indent} fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
{indent} where
{indent} S: serde::Serializer,
{indent} {{"#,
indent = Indent(indent),
rust_type = rust_type
)
}
fn write_serialize_end<W: Write>(indent: usize, writer: &mut W) -> Result<()> {
writeln!(
writer,
r#"{indent} }}
{indent}}}"#,
indent = Indent(indent),
)
}
fn write_deserialize_start<W: Write>(indent: usize, rust_type: &str, writer: &mut W) -> Result<()> {
writeln!(
writer,
r#"{indent}impl<'de> serde::Deserialize<'de> for {rust_type} {{
{indent} #[allow(deprecated)]
{indent} fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
{indent} where
{indent} D: serde::Deserializer<'de>,
{indent} {{"#,
indent = Indent(indent),
rust_type = rust_type
)
}
fn write_deserialize_end<W: Write>(indent: usize, writer: &mut W) -> Result<()> {
writeln!(
writer,
r#"{indent} }}
{indent}}}"#,
indent = Indent(indent),
)
}


@@ -1,138 +0,0 @@
//! This module contains the code to generate Serialize and Deserialize
//! implementations for enumeration types
//!
//! An enumeration should be decodable from the full string variant name
//! or its integer tag number, and should encode to the string representation
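//!
//! For example (illustrative), a proto enum value `FOO_BAR = 1` deserializes
//! from both the string "FOO_BAR" and the integer 1, and serializes back to "FOO_BAR"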
use super::{
write_deserialize_end, write_deserialize_start, write_serialize_end, write_serialize_start,
Config, Indent,
};
use crate::descriptor::{EnumDescriptor, TypePath};
use crate::generator::write_fields_array;
use std::io::{Result, Write};
pub fn generate_enum<W: Write>(
config: &Config,
path: &TypePath,
descriptor: &EnumDescriptor,
writer: &mut W,
) -> Result<()> {
let rust_type = config.rust_type(path);
let variants: Vec<_> = descriptor
.values
.iter()
.map(|variant| {
let variant_name = variant.name.clone().unwrap();
let rust_variant = config.rust_variant(path, &variant_name);
(variant_name, rust_variant)
})
.collect();
// Generate Serialize
write_serialize_start(0, &rust_type, writer)?;
writeln!(writer, "{}let variant = match self {{", Indent(2))?;
for (variant_name, rust_variant) in &variants {
writeln!(
writer,
"{}Self::{} => \"{}\",",
Indent(3),
rust_variant,
variant_name
)?;
}
writeln!(writer, "{}}};", Indent(2))?;
writeln!(writer, "{}serializer.serialize_str(variant)", Indent(2))?;
write_serialize_end(0, writer)?;
// Generate Deserialize
write_deserialize_start(0, &rust_type, writer)?;
write_fields_array(writer, 2, variants.iter().map(|(name, _)| name.as_str()))?;
write_visitor(writer, 2, &rust_type, &variants)?;
// Use deserialize_any to allow users to provide integers or strings
writeln!(
writer,
"{}deserializer.deserialize_any(GeneratedVisitor)",
Indent(2)
)?;
write_deserialize_end(0, writer)?;
Ok(())
}
fn write_visitor<W: Write>(
writer: &mut W,
indent: usize,
rust_type: &str,
variants: &[(String, String)],
) -> Result<()> {
// Protobuf supports deserialization of enumerations both from string and integer values
writeln!(
writer,
r#"{indent}struct GeneratedVisitor;
{indent}impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {{
{indent} type Value = {rust_type};
{indent} fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{
{indent} write!(formatter, "expected one of: {{:?}}", &FIELDS)
{indent} }}
{indent} fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
{indent} where
{indent} E: serde::de::Error,
{indent} {{
{indent} use std::convert::TryFrom;
{indent} i32::try_from(v)
{indent} .ok()
{indent} .and_then({rust_type}::from_i32)
{indent} .ok_or_else(|| {{
{indent} serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self)
{indent} }})
{indent} }}
{indent} fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
{indent} where
{indent} E: serde::de::Error,
{indent} {{
{indent} use std::convert::TryFrom;
{indent} i32::try_from(v)
{indent} .ok()
{indent} .and_then({rust_type}::from_i32)
{indent} .ok_or_else(|| {{
{indent} serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self)
{indent} }})
{indent} }}
{indent} fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
{indent} where
{indent} E: serde::de::Error,
{indent} {{"#,
indent = Indent(indent),
rust_type = rust_type,
)?;
writeln!(writer, "{}match value {{", Indent(indent + 2))?;
for (variant_name, rust_variant) in variants {
writeln!(
writer,
"{}\"{}\" => Ok({}::{}),",
Indent(indent + 3),
variant_name,
rust_type,
rust_variant
)?;
}
writeln!(
writer,
"{indent}_ => Err(serde::de::Error::unknown_variant(value, FIELDS)),",
indent = Indent(indent + 3)
)?;
writeln!(writer, "{}}}", Indent(indent + 2))?;
writeln!(writer, "{}}}", Indent(indent + 1))?;
writeln!(writer, "{}}}", Indent(indent))
}


@@ -1,809 +0,0 @@
//! This module contains the code to generate Serialize and Deserialize
//! implementations for message types
//!
//! The implementation follows the proto3 [JSON mapping][1] with the default options
//!
//! Importantly:
//! - numeric types can be decoded from either a string or number
//! - 32-bit integers and floats are encoded as numbers
//! - 64-bit integers are encoded as strings
//! - repeated fields are encoded as arrays
//! - bytes are base64 encoded (NOT CURRENTLY SUPPORTED)
//! - messages and maps are encoded as objects
//! - fields are lowerCamelCase except where overridden by the proto definition
//! - default values are not emitted on encode
//! - unrecognised fields error on decode
//!
//! Note: This will not generate code to correctly serialize/deserialize well-known-types
//! such as google.protobuf.Any, google.protobuf.Duration, etc.; conversions for these
//! special-cased messages will need to be implemented manually. Once that is done,
//! any messages containing these types will serialize/deserialize correctly
//!
//! [1]: https://developers.google.com/protocol-buffers/docs/proto3#json
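//!
//! For example (illustrative), a message with fields `uint64 id = 1` and
//! `repeated string names = 2` encodes as {"id": "42", "names": ["a", "b"]};
//! the 64-bit integer is encoded as a string, the field names are lowerCamelCase,
//! and fields still set to their default values are omitted entirely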
use std::io::{Result, Write};
use crate::message::{Field, FieldModifier, FieldType, Message, OneOf, ScalarType};
use super::{
write_deserialize_end, write_deserialize_start, write_serialize_end, write_serialize_start,
Config, Indent,
};
use crate::descriptor::TypePath;
use crate::generator::write_fields_array;
pub fn generate_message<W: Write>(
config: &Config,
message: &Message,
writer: &mut W,
) -> Result<()> {
let rust_type = config.rust_type(&message.path);
// Generate Serialize
write_serialize_start(0, &rust_type, writer)?;
write_message_serialize(config, 2, message, writer)?;
write_serialize_end(0, writer)?;
// Generate Deserialize
write_deserialize_start(0, &rust_type, writer)?;
write_deserialize_message(config, 2, message, &rust_type, writer)?;
write_deserialize_end(0, writer)?;
Ok(())
}
fn write_field_empty_predicate<W: Write>(member: &Field, writer: &mut W) -> Result<()> {
match (&member.field_type, &member.field_modifier) {
(_, FieldModifier::Required) => unreachable!(),
(_, FieldModifier::Repeated)
| (FieldType::Map(_, _), _)
| (FieldType::Scalar(ScalarType::String), FieldModifier::UseDefault)
| (FieldType::Scalar(ScalarType::Bytes), FieldModifier::UseDefault) => {
write!(writer, "!self.{}.is_empty()", member.rust_field_name())
}
(_, FieldModifier::Optional) | (FieldType::Message(_), _) => {
write!(writer, "self.{}.is_some()", member.rust_field_name())
}
(FieldType::Scalar(ScalarType::F64), FieldModifier::UseDefault)
| (FieldType::Scalar(ScalarType::F32), FieldModifier::UseDefault) => {
write!(writer, "self.{} != 0.", member.rust_field_name())
}
(FieldType::Scalar(ScalarType::Bool), FieldModifier::UseDefault) => {
write!(writer, "self.{}", member.rust_field_name())
}
(FieldType::Enum(_), FieldModifier::UseDefault)
| (FieldType::Scalar(ScalarType::I64), FieldModifier::UseDefault)
| (FieldType::Scalar(ScalarType::I32), FieldModifier::UseDefault)
| (FieldType::Scalar(ScalarType::U32), FieldModifier::UseDefault)
| (FieldType::Scalar(ScalarType::U64), FieldModifier::UseDefault) => {
write!(writer, "self.{} != 0", member.rust_field_name())
}
}
}
fn write_message_serialize<W: Write>(
config: &Config,
indent: usize,
message: &Message,
writer: &mut W,
) -> Result<()> {
write_struct_serialize_start(indent, message, writer)?;
for field in &message.fields {
write_serialize_field(config, indent, field, writer)?;
}
for one_of in &message.one_ofs {
write_serialize_one_of(indent, config, one_of, writer)?;
}
write_struct_serialize_end(indent, writer)
}
fn write_struct_serialize_start<W: Write>(
indent: usize,
message: &Message,
writer: &mut W,
) -> Result<()> {
writeln!(writer, "{}use serde::ser::SerializeStruct;", Indent(indent))?;
let required_len = message
.fields
.iter()
.filter(|member| member.field_modifier.is_required())
.count();
if required_len != message.fields.len() || !message.one_ofs.is_empty() {
writeln!(writer, "{}let mut len = {};", Indent(indent), required_len)?;
} else {
writeln!(writer, "{}let len = {};", Indent(indent), required_len)?;
}
for field in &message.fields {
if field.field_modifier.is_required() {
continue;
}
write!(writer, "{}if ", Indent(indent))?;
write_field_empty_predicate(field, writer)?;
writeln!(writer, " {{")?;
writeln!(writer, "{}len += 1;", Indent(indent + 1))?;
writeln!(writer, "{}}}", Indent(indent))?;
}
for one_of in &message.one_ofs {
writeln!(
writer,
"{}if self.{}.is_some() {{",
Indent(indent),
one_of.rust_field_name()
)?;
writeln!(writer, "{}len += 1;", Indent(indent + 1))?;
writeln!(writer, "{}}}", Indent(indent))?;
}
if !message.fields.is_empty() || !message.one_ofs.is_empty() {
writeln!(
writer,
"{}let mut struct_ser = serializer.serialize_struct(\"{}\", len)?;",
Indent(indent),
message.path
)?;
} else {
writeln!(
writer,
"{}let struct_ser = serializer.serialize_struct(\"{}\", len)?;",
Indent(indent),
message.path
)?;
}
Ok(())
}
fn write_struct_serialize_end<W: Write>(indent: usize, writer: &mut W) -> Result<()> {
writeln!(writer, "{}struct_ser.end()", Indent(indent))
}
fn write_decode_variant<W: Write>(
config: &Config,
indent: usize,
value: &str,
path: &TypePath,
writer: &mut W,
) -> Result<()> {
writeln!(writer, "{}::from_i32({})", config.rust_type(path), value)?;
write!(
writer,
"{}.ok_or_else(|| serde::ser::Error::custom(format!(\"Invalid variant {{}}\", {})))",
Indent(indent),
value
)
}
/// Depending on the type of the field, different ways of accessing the field's value
/// are needed; this allows decoupling the type serialization logic from the logic
/// that manipulates its container, e.g. Vec, Option, HashMap
struct Variable<'a> {
/// A reference to the field's value
as_ref: &'a str,
/// The field's value
as_unref: &'a str,
/// The field without any leading "&" or "*"
raw: &'a str,
}
fn write_serialize_variable<W: Write>(
config: &Config,
indent: usize,
field: &Field,
variable: Variable<'_>,
writer: &mut W,
) -> Result<()> {
match &field.field_type {
FieldType::Scalar(scalar) => write_serialize_scalar_variable(
indent,
*scalar,
field.field_modifier,
variable,
field.json_name(),
writer,
),
FieldType::Enum(path) => {
write!(writer, "{}let v = ", Indent(indent))?;
match field.field_modifier {
FieldModifier::Repeated => {
writeln!(writer, "{}.iter().cloned().map(|v| {{", variable.raw)?;
write!(writer, "{}", Indent(indent + 1))?;
write_decode_variant(config, indent + 2, "v", path, writer)?;
writeln!(writer)?;
write!(
writer,
"{}}}).collect::<Result<Vec<_>, _>>()",
Indent(indent + 1)
)
}
_ => write_decode_variant(config, indent + 1, variable.as_unref, path, writer),
}?;
writeln!(writer, "?;")?;
writeln!(
writer,
"{}struct_ser.serialize_field(\"{}\", &v)?;",
Indent(indent),
field.json_name()
)
}
FieldType::Map(_, value_type)
if matches!(
value_type.as_ref(),
FieldType::Scalar(ScalarType::I64)
| FieldType::Scalar(ScalarType::U64)
| FieldType::Enum(_)
) =>
{
writeln!(
writer,
"{}let v: std::collections::HashMap<_, _> = {}.iter()",
Indent(indent),
variable.raw
)?;
match value_type.as_ref() {
FieldType::Scalar(ScalarType::I64) | FieldType::Scalar(ScalarType::U64) => {
writeln!(
writer,
"{}.map(|(k, v)| (k, v.to_string())).collect();",
Indent(indent + 1)
)?;
}
FieldType::Enum(path) => {
writeln!(writer, "{}.map(|(k, v)| {{", Indent(indent + 1))?;
write!(writer, "{}let v = ", Indent(indent + 2))?;
write_decode_variant(config, indent + 3, "*v", path, writer)?;
writeln!(writer, "?;")?;
writeln!(writer, "{}Ok((k, v))", Indent(indent + 2))?;
writeln!(
writer,
"{}}}).collect::<Result<_,_>>()?;",
Indent(indent + 1)
)?;
}
_ => unreachable!(),
}
writeln!(
writer,
"{}struct_ser.serialize_field(\"{}\", &v)?;",
Indent(indent),
field.json_name()
)
}
_ => {
writeln!(
writer,
"{}struct_ser.serialize_field(\"{}\", {})?;",
Indent(indent),
field.json_name(),
variable.as_ref
)
}
}
}
fn write_serialize_scalar_variable<W: Write>(
indent: usize,
scalar: ScalarType,
field_modifier: FieldModifier,
variable: Variable<'_>,
json_name: String,
writer: &mut W,
) -> Result<()> {
let conversion = match scalar {
ScalarType::I64 | ScalarType::U64 => "ToString::to_string",
ScalarType::Bytes => "pbjson::private::base64::encode",
_ => {
return writeln!(
writer,
"{}struct_ser.serialize_field(\"{}\", {})?;",
Indent(indent),
json_name,
variable.as_ref
)
}
};
match field_modifier {
FieldModifier::Repeated => {
writeln!(
writer,
"{}struct_ser.serialize_field(\"{}\", &{}.iter().map({}).collect::<Vec<_>>())?;",
Indent(indent),
json_name,
variable.raw,
conversion
)
}
_ => {
writeln!(
writer,
"{}struct_ser.serialize_field(\"{}\", {}(&{}).as_str())?;",
Indent(indent),
json_name,
conversion,
variable.raw,
)
}
}
}
fn write_serialize_field<W: Write>(
config: &Config,
indent: usize,
field: &Field,
writer: &mut W,
) -> Result<()> {
let as_ref = format!("&self.{}", field.rust_field_name());
let variable = Variable {
as_ref: as_ref.as_str(),
as_unref: &as_ref.as_str()[1..],
raw: &as_ref.as_str()[1..],
};
match &field.field_modifier {
FieldModifier::Required => {
write_serialize_variable(config, indent, field, variable, writer)?;
}
FieldModifier::Optional => {
writeln!(
writer,
"{}if let Some(v) = {}.as_ref() {{",
Indent(indent),
variable.as_unref
)?;
let variable = Variable {
as_ref: "v",
as_unref: "*v",
raw: "v",
};
write_serialize_variable(config, indent + 1, field, variable, writer)?;
writeln!(writer, "{}}}", Indent(indent))?;
}
FieldModifier::Repeated | FieldModifier::UseDefault => {
write!(writer, "{}if ", Indent(indent))?;
write_field_empty_predicate(field, writer)?;
writeln!(writer, " {{")?;
write_serialize_variable(config, indent + 1, field, variable, writer)?;
writeln!(writer, "{}}}", Indent(indent))?;
}
}
Ok(())
}
fn write_serialize_one_of<W: Write>(
indent: usize,
config: &Config,
one_of: &OneOf,
writer: &mut W,
) -> Result<()> {
writeln!(
writer,
"{}if let Some(v) = self.{}.as_ref() {{",
Indent(indent),
one_of.rust_field_name()
)?;
writeln!(writer, "{}match v {{", Indent(indent + 1))?;
for field in &one_of.fields {
writeln!(
writer,
"{}{}::{}(v) => {{",
Indent(indent + 2),
config.rust_type(&one_of.path),
field.rust_type_name(),
)?;
let variable = Variable {
as_ref: "v",
as_unref: "*v",
raw: "v",
};
write_serialize_variable(config, indent + 3, field, variable, writer)?;
writeln!(writer, "{}}}", Indent(indent + 2))?;
}
writeln!(writer, "{}}}", Indent(indent + 1),)?;
writeln!(writer, "{}}}", Indent(indent))
}
fn write_deserialize_message<W: Write>(
config: &Config,
indent: usize,
message: &Message,
rust_type: &str,
writer: &mut W,
) -> Result<()> {
write_deserialize_field_name(2, message, writer)?;
writeln!(writer, "{}struct GeneratedVisitor;", Indent(indent))?;
writeln!(
writer,
r#"{indent}impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {{
{indent} type Value = {rust_type};
{indent} fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{
{indent} formatter.write_str("struct {name}")
{indent} }}
{indent} fn visit_map<V>(self, mut map: V) -> Result<{rust_type}, V::Error>
{indent} where
{indent} V: serde::de::MapAccess<'de>,
{indent} {{"#,
indent = Indent(indent),
name = message.path,
rust_type = rust_type,
)?;
for field in &message.fields {
writeln!(
writer,
"{}let mut {} = None;",
Indent(indent + 2),
field.rust_field_name(),
)?;
}
for one_of in &message.one_ofs {
writeln!(
writer,
"{}let mut {} = None;",
Indent(indent + 2),
one_of.rust_field_name(),
)?;
}
if !message.fields.is_empty() || !message.one_ofs.is_empty() {
writeln!(
writer,
"{}while let Some(k) = map.next_key()? {{",
Indent(indent + 2)
)?;
writeln!(writer, "{}match k {{", Indent(indent + 3))?;
for field in &message.fields {
write_deserialize_field(config, indent + 4, field, None, writer)?;
}
for one_of in &message.one_ofs {
for field in &one_of.fields {
write_deserialize_field(config, indent + 4, field, Some(one_of), writer)?;
}
}
writeln!(writer, "{}}}", Indent(indent + 3))?;
writeln!(writer, "{}}}", Indent(indent + 2))?;
} else {
writeln!(
writer,
"{}while map.next_key::<GeneratedField>()?.is_some() {{}}",
Indent(indent + 2)
)?;
}
writeln!(writer, "{}Ok({} {{", Indent(indent + 2), rust_type)?;
for field in &message.fields {
match field.field_modifier {
FieldModifier::Required => {
writeln!(
writer,
"{indent}{field}: {field}.ok_or_else(|| serde::de::Error::missing_field(\"{json_name}\"))?,",
indent=Indent(indent + 3),
field= field.rust_field_name(),
json_name= field.json_name()
)?;
}
FieldModifier::UseDefault | FieldModifier::Repeated => {
// Note: this currently does not hydrate optional proto2 fields with defaults
writeln!(
writer,
"{indent}{field}: {field}.unwrap_or_default(),",
indent = Indent(indent + 3),
field = field.rust_field_name()
)?;
}
_ => {
writeln!(
writer,
"{indent}{field},",
indent = Indent(indent + 3),
field = field.rust_field_name()
)?;
}
}
}
for one_of in &message.one_ofs {
writeln!(
writer,
"{indent}{field},",
indent = Indent(indent + 3),
field = one_of.rust_field_name(),
)?;
}
writeln!(writer, "{}}})", Indent(indent + 2))?;
writeln!(writer, "{}}}", Indent(indent + 1))?;
writeln!(writer, "{}}}", Indent(indent))?;
writeln!(
writer,
"{}deserializer.deserialize_struct(\"{}\", FIELDS, GeneratedVisitor)",
Indent(indent),
message.path
)
}
fn write_deserialize_field_name<W: Write>(
indent: usize,
message: &Message,
writer: &mut W,
) -> Result<()> {
let fields: Vec<_> = message
.all_fields()
.map(|field| (field.json_name(), field.rust_type_name()))
.collect();
write_fields_array(writer, indent, fields.iter().map(|(name, _)| name.as_str()))?;
write_fields_enum(writer, indent, fields.iter().map(|(_, name)| name.as_str()))?;
writeln!(
writer,
r#"{indent}impl<'de> serde::Deserialize<'de> for GeneratedField {{
{indent} fn deserialize<D>(deserializer: D) -> Result<GeneratedField, D::Error>
{indent} where
{indent} D: serde::Deserializer<'de>,
{indent} {{
{indent} struct GeneratedVisitor;
{indent} impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {{
{indent} type Value = GeneratedField;
{indent} fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{
{indent} write!(formatter, "expected one of: {{:?}}", &FIELDS)
{indent} }}
{indent} fn visit_str<E>(self, value: &str) -> Result<GeneratedField, E>
{indent} where
{indent} E: serde::de::Error,
{indent} {{"#,
indent = Indent(indent)
)?;
if !fields.is_empty() {
writeln!(writer, "{}match value {{", Indent(indent + 4))?;
for (json_name, type_name) in &fields {
writeln!(
writer,
"{}\"{}\" => Ok(GeneratedField::{}),",
Indent(indent + 5),
json_name,
type_name
)?;
}
writeln!(
writer,
"{}_ => Err(serde::de::Error::unknown_field(value, FIELDS)),",
Indent(indent + 5)
)?;
writeln!(writer, "{}}}", Indent(indent + 4))?;
} else {
writeln!(
writer,
"{}Err(serde::de::Error::unknown_field(value, FIELDS))",
Indent(indent + 4)
)?;
}
writeln!(
writer,
r#"{indent} }}
{indent} }}
{indent} deserializer.deserialize_identifier(GeneratedVisitor)
{indent} }}
{indent}}}"#,
indent = Indent(indent)
)
}
fn write_fields_enum<'a, W: Write, I: Iterator<Item = &'a str>>(
writer: &mut W,
indent: usize,
fields: I,
) -> Result<()> {
writeln!(
writer,
"{}#[allow(clippy::enum_variant_names)]",
Indent(indent)
)?;
writeln!(writer, "{}enum GeneratedField {{", Indent(indent))?;
for type_name in fields {
writeln!(writer, "{}{},", Indent(indent + 1), type_name)?;
}
writeln!(writer, "{}}}", Indent(indent))
}
fn write_deserialize_field<W: Write>(
config: &Config,
indent: usize,
field: &Field,
one_of: Option<&OneOf>,
writer: &mut W,
) -> Result<()> {
let field_name = match one_of {
Some(one_of) => one_of.rust_field_name(),
None => field.rust_field_name(),
};
let json_name = field.json_name();
writeln!(
writer,
"{}GeneratedField::{} => {{",
Indent(indent),
field.rust_type_name()
)?;
writeln!(
writer,
"{}if {}.is_some() {{",
Indent(indent + 1),
field_name
)?;
// Note: this will report a duplicate field if multiple values are specified for a oneof
writeln!(
writer,
"{}return Err(serde::de::Error::duplicate_field(\"{}\"));",
Indent(indent + 2),
json_name
)?;
writeln!(writer, "{}}}", Indent(indent + 1))?;
write!(writer, "{}{} = Some(", Indent(indent + 1), field_name)?;
if let Some(one_of) = one_of {
write!(
writer,
"{}::{}(",
config.rust_type(&one_of.path),
field.rust_type_name()
)?;
}
match &field.field_type {
FieldType::Scalar(scalar) => {
write_encode_scalar_field(indent + 1, *scalar, field.field_modifier, writer)?;
}
FieldType::Enum(path) => match field.field_modifier {
FieldModifier::Repeated => {
write!(
writer,
"map.next_value::<Vec<{}>>()?.into_iter().map(|x| x as i32).collect()",
config.rust_type(path)
)?;
}
_ => {
write!(
writer,
"map.next_value::<{}>()? as i32",
config.rust_type(path)
)?;
}
},
FieldType::Map(key, value) => {
writeln!(writer)?;
write!(
writer,
"{}map.next_value::<std::collections::HashMap<",
Indent(indent + 2),
)?;
let map_k = match key {
ScalarType::Bytes => {
// https://github.com/tokio-rs/prost/issues/531
panic!("bytes are not currently supported as map keys")
}
_ if key.is_numeric() => {
write!(
writer,
"::pbjson::private::NumberDeserialize<{}>",
key.rust_type()
)?;
"k.0"
}
_ => {
write!(writer, "_")?;
"k"
}
};
write!(writer, ", ")?;
let map_v = match value.as_ref() {
FieldType::Scalar(scalar) if scalar.is_numeric() => {
write!(
writer,
"::pbjson::private::NumberDeserialize<{}>",
scalar.rust_type()
)?;
"v.0"
}
FieldType::Scalar(ScalarType::Bytes) => {
// https://github.com/tokio-rs/prost/issues/531
panic!("bytes are not currently supported as map values")
}
FieldType::Enum(path) => {
write!(writer, "{}", config.rust_type(path))?;
"v as i32"
}
FieldType::Map(_, _) => panic!("protobuf disallows nested maps"),
_ => {
write!(writer, "_")?;
"v"
}
};
writeln!(writer, ">>()?")?;
if map_k != "k" || map_v != "v" {
writeln!(
writer,
"{}.into_iter().map(|(k,v)| ({}, {})).collect()",
Indent(indent + 3),
map_k,
map_v,
)?;
}
write!(writer, "{}", Indent(indent + 1))?;
}
_ => {
write!(writer, "map.next_value()?",)?;
}
};
if one_of.is_some() {
write!(writer, ")")?;
}
writeln!(writer, ");")?;
writeln!(writer, "{}}}", Indent(indent))
}
fn write_encode_scalar_field<W: Write>(
indent: usize,
scalar: ScalarType,
field_modifier: FieldModifier,
writer: &mut W,
) -> Result<()> {
let deserializer = match scalar {
ScalarType::Bytes => "BytesDeserialize",
_ if scalar.is_numeric() => "NumberDeserialize",
_ => return write!(writer, "map.next_value()?",),
};
writeln!(writer)?;
match field_modifier {
FieldModifier::Repeated => {
writeln!(
writer,
"{}map.next_value::<Vec<::pbjson::private::{}<_>>>()?",
Indent(indent + 1),
deserializer
)?;
writeln!(
writer,
"{}.into_iter().map(|x| x.0).collect()",
Indent(indent + 2)
)?;
}
_ => {
writeln!(
writer,
"{}map.next_value::<::pbjson::private::{}<_>>()?.0",
Indent(indent + 1),
deserializer
)?;
}
}
write!(writer, "{}", Indent(indent))
}


@@ -1,131 +0,0 @@
#![deny(rustdoc::broken_intra_doc_links, rustdoc::bare_urls, rust_2018_idioms)]
#![warn(
missing_debug_implementations,
clippy::explicit_iter_loop,
clippy::use_self,
clippy::clone_on_ref_ptr,
clippy::future_not_send
)]
use crate::descriptor::{Descriptor, DescriptorSet, Package};
use crate::generator::{generate_enum, generate_message, Config};
use crate::message::resolve_message;
use std::io::{BufWriter, Error, ErrorKind, Result, Write};
use std::path::PathBuf;
mod descriptor;
mod escape;
mod generator;
mod message;
#[derive(Debug, Default)]
pub struct Builder {
descriptors: descriptor::DescriptorSet,
exclude: Vec<String>,
out_dir: Option<PathBuf>,
}
impl Builder {
/// Create a new `Builder`
pub fn new() -> Self {
Self {
descriptors: DescriptorSet::new(),
exclude: Default::default(),
out_dir: None,
}
}
/// Register an encoded `FileDescriptorSet` with this `Builder`
pub fn register_descriptors(&mut self, descriptors: &[u8]) -> Result<&mut Self> {
self.descriptors.register_encoded(descriptors)?;
Ok(self)
}
/// Don't generate code for the following type prefixes
pub fn exclude<S: Into<String>, I: IntoIterator<Item = S>>(
&mut self,
prefixes: I,
) -> &mut Self {
self.exclude.extend(prefixes.into_iter().map(Into::into));
self
}
/// Generates code for all registered types where `prefixes` contains a prefix of
/// the fully-qualified path of the type
pub fn build<S: AsRef<str>>(&mut self, prefixes: &[S]) -> Result<()> {
let mut output: PathBuf = self.out_dir.clone().map(Ok).unwrap_or_else(|| {
std::env::var_os("OUT_DIR")
.ok_or_else(|| {
Error::new(ErrorKind::Other, "OUT_DIR environment variable is not set")
})
.map(Into::into)
})?;
output.push("FILENAME");
let write_factory = move |package: &Package| {
output.set_file_name(format!("{}.serde.rs", package));
let file = std::fs::OpenOptions::new()
.write(true)
.truncate(true)
.create(true)
.open(&output)?;
Ok(BufWriter::new(file))
};
let writers = self.generate(prefixes, write_factory)?;
for (_, mut writer) in writers {
writer.flush()?;
}
Ok(())
}
fn generate<S: AsRef<str>, W: Write, F: FnMut(&Package) -> Result<W>>(
&self,
prefixes: &[S],
mut write_factory: F,
) -> Result<Vec<(Package, W)>> {
let config = Config {
extern_types: Default::default(),
};
let iter = self.descriptors.iter().filter(move |(t, _)| {
let exclude = self
.exclude
.iter()
.any(|prefix| t.matches_prefix(prefix.as_ref()));
let include = prefixes
.iter()
.any(|prefix| t.matches_prefix(prefix.as_ref()));
include && !exclude
});
// Exploit the fact that `descriptors` is ordered to group together types from the same package
let mut ret: Vec<(Package, W)> = Vec::new();
for (type_path, descriptor) in iter {
let writer = match ret.last_mut() {
Some((package, writer)) if package == type_path.package() => writer,
_ => {
let package = type_path.package();
ret.push((package.clone(), write_factory(package)?));
&mut ret.last_mut().unwrap().1
}
};
match descriptor {
Descriptor::Enum(descriptor) => {
generate_enum(&config, type_path, descriptor, writer)?
}
Descriptor::Message(descriptor) => {
if let Some(message) = resolve_message(&self.descriptors, descriptor) {
generate_message(&config, &message, writer)?
}
}
}
}
Ok(ret)
}
}


@@ -1,275 +0,0 @@
//! The raw descriptor format is not very easy to work with, a fact not aided
//! by prost making almost all members of proto2 syntax messages optional
//!
//! This module therefore extracts a slightly less obtuse representation of a
//! message that can be used by the code generation logic
use prost_types::{
field_descriptor_proto::{Label, Type},
FieldDescriptorProto,
};
use crate::descriptor::{Descriptor, DescriptorSet, MessageDescriptor, Syntax, TypeName, TypePath};
use crate::escape::escape_ident;
#[derive(Debug, Clone, Copy)]
pub enum ScalarType {
F64,
F32,
I32,
I64,
U32,
U64,
Bool,
String,
Bytes,
}
impl ScalarType {
pub fn rust_type(&self) -> &'static str {
match self {
ScalarType::F64 => "f64",
ScalarType::F32 => "f32",
ScalarType::I32 => "i32",
ScalarType::I64 => "i64",
ScalarType::U32 => "u32",
ScalarType::U64 => "u64",
ScalarType::Bool => "bool",
ScalarType::String => "String",
ScalarType::Bytes => "Vec<u8>",
}
}
pub fn is_numeric(&self) -> bool {
matches!(
self,
ScalarType::F64
| ScalarType::F32
| ScalarType::I32
| ScalarType::I64
| ScalarType::U32
| ScalarType::U64
)
}
}
#[derive(Debug, Clone)]
pub enum FieldType {
Scalar(ScalarType),
Enum(TypePath),
Message(TypePath),
Map(ScalarType, Box<FieldType>),
}
#[derive(Debug, Clone, Copy)]
pub enum FieldModifier {
Required,
Optional,
UseDefault,
Repeated,
}
impl FieldModifier {
pub fn is_required(&self) -> bool {
matches!(self, Self::Required)
}
}
#[derive(Debug, Clone)]
pub struct Field {
pub name: String,
pub json_name: Option<String>,
pub field_modifier: FieldModifier,
pub field_type: FieldType,
}
impl Field {
pub fn rust_type_name(&self) -> String {
use heck::CamelCase;
self.name.to_camel_case()
}
pub fn rust_field_name(&self) -> String {
use heck::SnakeCase;
escape_ident(self.name.to_snake_case())
}
pub fn json_name(&self) -> String {
use heck::MixedCase;
self.json_name
.clone()
.unwrap_or_else(|| self.name.to_mixed_case())
}
}
#[derive(Debug, Clone)]
pub struct OneOf {
pub name: String,
pub path: TypePath,
pub fields: Vec<Field>,
}
impl OneOf {
pub fn rust_field_name(&self) -> String {
use heck::SnakeCase;
escape_ident(self.name.to_snake_case())
}
}
#[derive(Debug, Clone)]
pub struct Message {
pub path: TypePath,
pub fields: Vec<Field>,
pub one_ofs: Vec<OneOf>,
}
impl Message {
pub fn all_fields(&self) -> impl Iterator<Item = &Field> + '_ {
self.fields
.iter()
.chain(self.one_ofs.iter().flat_map(|one_of| one_of.fields.iter()))
}
}
/// Resolve the provided message descriptor into a slightly less obtuse representation
///
/// Returns None if the provided message is auto-generated
pub fn resolve_message(
descriptors: &DescriptorSet,
message: &MessageDescriptor,
) -> Option<Message> {
if message.is_map() {
return None;
}
let mut fields = Vec::new();
let mut one_of_fields = vec![Vec::new(); message.one_of.len()];
for field in &message.fields {
let field_type = field_type(descriptors, field);
let field_modifier = field_modifier(message, field, &field_type);
let resolved = Field {
name: field.name.clone().expect("expected field to have name"),
json_name: field.json_name.clone(),
field_type,
field_modifier,
};
// Treat synthetic one-of as normal
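// (proto3 `optional` fields are encoded as synthetic single-field one-ofs;
// flatten them back into regular optional fields rather than one-of variants)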
let proto3_optional = field.proto3_optional.unwrap_or(false);
match (field.oneof_index, proto3_optional) {
(Some(idx), false) => one_of_fields[idx as usize].push(resolved),
_ => fields.push(resolved),
}
}
let mut one_ofs = Vec::new();
for (fields, descriptor) in one_of_fields.into_iter().zip(&message.one_of) {
// Might be empty in the event of a synthetic one-of
if !fields.is_empty() {
let name = descriptor.name.clone().expect("oneof with no name");
let path = message.path.child(TypeName::new(&name));
one_ofs.push(OneOf { name, path, fields })
}
}
Some(Message {
path: message.path.clone(),
fields,
one_ofs,
})
}
fn field_modifier(
message: &MessageDescriptor,
field: &FieldDescriptorProto,
field_type: &FieldType,
) -> FieldModifier {
let label = Label::from_i32(field.label.expect("expected label")).expect("valid label");
if field.proto3_optional.unwrap_or(false) {
assert_eq!(label, Label::Optional);
return FieldModifier::Optional;
}
if field.oneof_index.is_some() {
assert_eq!(label, Label::Optional);
return FieldModifier::Optional;
}
if matches!(field_type, FieldType::Map(_, _)) {
assert_eq!(label, Label::Repeated);
return FieldModifier::Repeated;
}
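// Remaining singular fields: proto2 `optional` tracks explicit presence,
// whereas proto3 gives un-annotated scalars implicit presence (UseDefault)
// and message fields explicit presence (Optional).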
match label {
Label::Optional => match message.syntax {
Syntax::Proto2 => FieldModifier::Optional,
Syntax::Proto3 => match field_type {
FieldType::Message(_) => FieldModifier::Optional,
_ => FieldModifier::UseDefault,
},
},
Label::Required => FieldModifier::Required,
Label::Repeated => FieldModifier::Repeated,
}
}
fn field_type(descriptors: &DescriptorSet, field: &FieldDescriptorProto) -> FieldType {
match field.type_name.as_ref() {
Some(type_name) => resolve_type(descriptors, type_name.as_str()),
None => {
let scalar =
match Type::from_i32(field.r#type.expect("expected type")).expect("valid type") {
Type::Double => ScalarType::F64,
Type::Float => ScalarType::F32,
Type::Int64 | Type::Sfixed64 | Type::Sint64 => ScalarType::I64,
Type::Int32 | Type::Sfixed32 | Type::Sint32 => ScalarType::I32,
Type::Uint64 | Type::Fixed64 => ScalarType::U64,
Type::Uint32 | Type::Fixed32 => ScalarType::U32,
Type::Bool => ScalarType::Bool,
Type::String => ScalarType::String,
Type::Bytes => ScalarType::Bytes,
Type::Message | Type::Enum | Type::Group => panic!("no type name specified"),
};
FieldType::Scalar(scalar)
}
}
}
fn resolve_type(descriptors: &DescriptorSet, type_name: &str) -> FieldType {
assert!(
type_name.starts_with('.'),
"pbjson does not currently support resolving relative types"
);
let maybe_descriptor = descriptors
.iter()
.find(|(path, _)| path.matches_prefix(type_name));
match maybe_descriptor {
Some((path, Descriptor::Enum(_))) => FieldType::Enum(path.clone()),
Some((path, Descriptor::Message(descriptor))) => match descriptor.is_map() {
true => {
assert_eq!(descriptor.fields.len(), 2, "expected map to have 2 fields");
let key = &descriptor.fields[0];
let value = &descriptor.fields[1];
assert_eq!("key", key.name());
assert_eq!("value", value.name());
let key_type = match field_type(descriptors, key) {
FieldType::Scalar(scalar) => scalar,
_ => panic!("non-scalar map key"),
};
let value_type = field_type(descriptors, value);
FieldType::Map(key_type, Box::new(value_type))
}
// Note: this may actually be a group, but detecting that is non-trivial,
// groups are deprecated, and pbjson doesn't need to distinguish them
false => FieldType::Message(path.clone()),
},
None => panic!("failed to resolve type: {}", type_name),
}
}

View File

@ -1,18 +0,0 @@
[package]
name = "pbjson_test"
version = "0.1.0"
authors = ["Raphael Taylor-Davies <r.taylordavies@googlemail.com>"]
edition = "2018"
description = "Test resources for pbjson conversion"
[dependencies]
prost = "0.8"
pbjson = { path = "../pbjson" }
serde = { version = "1.0", features = ["derive"] }
[dev-dependencies]
serde_json = "1.0"
[build-dependencies]
prost-build = "0.8"
pbjson_build = { path = "../pbjson_build" }

View File

@ -1,33 +0,0 @@
//! Compiles Protocol Buffers definitions into native Rust types
use std::env;
use std::path::PathBuf;
type Error = Box<dyn std::error::Error>;
type Result<T, E = Error> = std::result::Result<T, E>;
fn main() -> Result<()> {
let root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("protos");
let proto_files = vec![root.join("syntax3.proto")];
// Tell cargo to recompile if any of these proto files are changed
for proto_file in &proto_files {
println!("cargo:rerun-if-changed={}", proto_file.display());
}
let descriptor_path = PathBuf::from(env::var("OUT_DIR").unwrap()).join("proto_descriptor.bin");
prost_build::Config::new()
.file_descriptor_set_path(&descriptor_path)
.compile_well_known_types()
.disable_comments(&["."])
.bytes(&[".test"])
.compile_protos(&proto_files, &[root])?;
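// Feed the descriptor set written by prost back into pbjson_build, which
// generates serde Serialize/Deserialize impls alongside the prost types.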
let descriptor_set = std::fs::read(descriptor_path)?;
pbjson_build::Builder::new()
.register_descriptors(&descriptor_set)?
.build(&[".test"])?;
Ok(())
}
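// The generated code can then be pulled in from OUT_DIR, as this crate's
// lib.rs does further below:
// include!(concat!(env!("OUT_DIR"), "/test.syntax3.rs"));
// include!(concat!(env!("OUT_DIR"), "/test.syntax3.serde.rs"));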

View File

@ -1,72 +0,0 @@
syntax = "proto3";
package test.syntax3;
message Empty {}
message KitchenSink {
// Standard enum
enum Value {
VALUE_UNKNOWN = 0;
VALUE_A = 45;
VALUE_B = 63;
}
// An enumeration without prefixed variants
enum Prefix {
UNKNOWN = 0;
A = 66;
B = 20;
}
int32 i32 = 1;
optional int32 optional_i32 = 2;
repeated int32 repeated_i32 = 3;
uint32 u32 = 4;
optional uint32 optional_u32 = 5;
repeated uint32 repeated_u32 = 6;
int64 i64 = 7;
optional int64 optional_i64 = 8;
repeated int64 repeated_i64 = 9;
uint64 u64 = 10;
optional uint64 optional_u64 = 11;
repeated uint64 repeated_u64 = 12;
Value value = 13;
optional Value optional_value = 14;
repeated Value repeated_value = 15;
Prefix prefix = 16;
Empty empty = 17;
map<string, string> string_dict = 18;
map<string, Empty> message_dict = 19;
map<string, Prefix> enum_dict = 20;
map<int64, Prefix> int64_dict = 21;
map<int32, Prefix> int32_dict = 22;
map<int32, uint64> integer_dict = 23;
bool bool = 24;
optional bool optional_bool = 25;
repeated bool repeated_bool = 26;
oneof one_of {
int32 one_of_i32 = 27;
bool one_of_bool = 28;
Value one_of_value = 29;
Empty one_of_message = 30;
}
bytes bytes = 31;
optional bytes optional_bytes = 32;
repeated bytes repeated_bytes = 33;
// Bytes support is currently broken - https://github.com/tokio-rs/prost/issues/531
// map<string, bytes> bytes_dict = 34;
string string = 35;
optional string optional_string = 36;
}

View File

@ -1,241 +0,0 @@
include!(concat!(env!("OUT_DIR"), "/test.syntax3.rs"));
include!(concat!(env!("OUT_DIR"), "/test.syntax3.serde.rs"));
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_empty() {
let message = Empty {};
let encoded = serde_json::to_string(&message).unwrap();
let _decoded: Empty = serde_json::from_str(&encoded).unwrap();
let err = serde_json::from_str::<Empty>("343").unwrap_err();
assert_eq!(
err.to_string().as_str(),
"invalid type: integer `343`, expected struct test.syntax3.Empty at line 1 column 3"
);
let err = serde_json::from_str::<Empty>("{\"foo\": \"bar\"}").unwrap_err();
assert_eq!(
err.to_string().as_str(),
"unknown field `foo`, there are no fields at line 1 column 6"
);
}
#[test]
fn test_kitchen_sink() {
let mut decoded: KitchenSink = serde_json::from_str("{}").unwrap();
let verify_encode = |decoded: &KitchenSink, expected: &str| {
assert_eq!(serde_json::to_string(&decoded).unwrap().as_str(), expected);
};
let verify_decode = |decoded: &KitchenSink, expected: &str| {
assert_eq!(decoded, &serde_json::from_str(expected).unwrap());
};
let verify = |decoded: &KitchenSink, expected: &str| {
verify_encode(decoded, expected);
verify_decode(decoded, expected);
};
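// Each case below round-trips: `verify` encodes the struct and checks the
// JSON string, then decodes that same string and checks it equals the struct.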
verify(&decoded, "{}");
decoded.i32 = 24;
verify(&decoded, r#"{"i32":24}"#);
decoded.i32 = 0;
verify_decode(&decoded, "{}");
// Explicit optional fields can distinguish between no value and default value
decoded.optional_i32 = Some(2);
verify(&decoded, r#"{"optionalI32":2}"#);
decoded.optional_i32 = Some(0);
verify(&decoded, r#"{"optionalI32":0}"#);
// Can also decode from string
verify_decode(&decoded, r#"{"optionalI32":"0"}"#);
decoded.optional_i32 = None;
verify_decode(&decoded, "{}");
// 64-bit integers are encoded as strings
decoded.i64 = 123125;
verify(&decoded, r#"{"i64":"123125"}"#);
decoded.i64 = 0;
verify_decode(&decoded, "{}");
decoded.optional_i64 = Some(532);
verify(&decoded, r#"{"optionalI64":"532"}"#);
decoded.optional_i64 = Some(0);
verify(&decoded, r#"{"optionalI64":"0"}"#);
// Can also decode from non-string
verify_decode(&decoded, r#"{"optionalI64":0}"#);
decoded.optional_i64 = None;
verify_decode(&decoded, "{}");
decoded.u64 = 34346;
decoded.u32 = 567094456;
decoded.optional_u32 = Some(0);
decoded.optional_u64 = Some(3);
verify(
&decoded,
r#"{"u32":567094456,"optionalU32":0,"u64":"34346","optionalU64":"3"}"#,
);
decoded.u64 = 0;
decoded.u32 = 0;
decoded.optional_u32 = None;
decoded.optional_u64 = None;
verify_decode(&decoded, "{}");
decoded.repeated_i32 = vec![0, 23, 5, 6, 2, 34];
verify(&decoded, r#"{"repeatedI32":[0,23,5,6,2,34]}"#);
// Can also mix in some strings
verify_decode(&decoded, r#"{"repeatedI32":[0,"23",5,6,"2",34]}"#);
decoded.repeated_i32 = vec![];
verify_decode(&decoded, "{}");
decoded.repeated_u64 = vec![0, 532, 2];
verify(&decoded, r#"{"repeatedU64":["0","532","2"]}"#);
// Can also mix in some non-strings
verify_decode(&decoded, r#"{"repeatedU64":["0",532,"2"]}"#);
decoded.repeated_u64 = vec![];
verify_decode(&decoded, "{}");
// Enumerations should be encoded as strings
decoded.value = kitchen_sink::Value::A as i32;
verify(&decoded, r#"{"value":"VALUE_A"}"#);
// Can also use variant number
verify_decode(&decoded, r#"{"value":45}"#);
decoded.value = kitchen_sink::Value::Unknown as i32;
verify_decode(&decoded, "{}");
decoded.optional_value = Some(kitchen_sink::Value::Unknown as i32);
verify(&decoded, r#"{"optionalValue":"VALUE_UNKNOWN"}"#);
// Can also use variant number
verify_decode(&decoded, r#"{"optionalValue":0}"#);
decoded.optional_value = None;
verify_decode(&decoded, "{}");
decoded
.string_dict
.insert("foo".to_string(), "bar".to_string());
verify(&decoded, r#"{"stringDict":{"foo":"bar"}}"#);
decoded.string_dict = Default::default();
verify_decode(&decoded, "{}");
decoded
.int32_dict
.insert(343, kitchen_sink::Prefix::A as i32);
// Dictionary keys should always be strings
// Enum dictionary values should be encoded as strings
verify(&decoded, r#"{"int32Dict":{"343":"A"}}"#);
// Enum dictionary values can be decoded from integers
verify_decode(&decoded, r#"{"int32Dict":{"343":66}}"#);
decoded.int32_dict = Default::default();
verify_decode(&decoded, "{}");
// 64-bit dictionary values should be encoded as strings
decoded.integer_dict.insert(12, 13);
verify(&decoded, r#"{"integerDict":{"12":"13"}}"#);
// 64-bit dictionary values can be decoded from numeric types
verify_decode(&decoded, r#"{"integerDict":{"12":13}}"#);
decoded.integer_dict = Default::default();
verify_decode(&decoded, "{}");
decoded.one_of = Some(kitchen_sink::OneOf::OneOfI32(0));
verify(&decoded, r#"{"oneOfI32":0}"#);
// Can also specify string
verify_decode(&decoded, r#"{"oneOfI32":"0"}"#);
decoded.one_of = Some(kitchen_sink::OneOf::OneOfI32(12));
verify(&decoded, r#"{"oneOfI32":12}"#);
decoded.one_of = Some(kitchen_sink::OneOf::OneOfBool(false));
verify(&decoded, r#"{"oneOfBool":false}"#);
decoded.one_of = Some(kitchen_sink::OneOf::OneOfBool(true));
verify(&decoded, r#"{"oneOfBool":true}"#);
decoded.one_of = Some(kitchen_sink::OneOf::OneOfValue(
kitchen_sink::Value::B as i32,
));
verify(&decoded, r#"{"oneOfValue":"VALUE_B"}"#);
// Can also specify enum variant
verify_decode(&decoded, r#"{"oneOfValue":63}"#);
decoded.one_of = None;
verify_decode(&decoded, "{}");
decoded.repeated_value = vec![
kitchen_sink::Value::B as i32,
kitchen_sink::Value::B as i32,
kitchen_sink::Value::A as i32,
];
verify(
&decoded,
r#"{"repeatedValue":["VALUE_B","VALUE_B","VALUE_A"]}"#,
);
verify_decode(&decoded, r#"{"repeatedValue":[63,"VALUE_B","VALUE_A"]}"#);
decoded.repeated_value = Default::default();
verify_decode(&decoded, "{}");
decoded.bytes = prost::bytes::Bytes::from_static(b"kjkjkj");
verify(&decoded, r#"{"bytes":"a2pramtq"}"#);
decoded.bytes = Default::default();
verify_decode(&decoded, "{}");
decoded.optional_bytes = Some(prost::bytes::Bytes::from_static(b"kjkjkj"));
verify(&decoded, r#"{"optionalBytes":"a2pramtq"}"#);
decoded.optional_bytes = Some(Default::default());
verify(&decoded, r#"{"optionalBytes":""}"#);
decoded.optional_bytes = None;
verify_decode(&decoded, "{}");
decoded.repeated_bytes = vec![
prost::bytes::Bytes::from_static(b"sdfsd"),
prost::bytes::Bytes::from_static(b"fghfg"),
];
verify(&decoded, r#"{"repeatedBytes":["c2Rmc2Q=","ZmdoZmc="]}"#);
decoded.repeated_bytes = Default::default();
verify_decode(&decoded, "{}");
// decoded.bytes_dict.insert(
// "test".to_string(),
// prost::bytes::Bytes::from_static(b"asdf"),
// );
// verify(&decoded, r#"{"bytesDict":{"test":"YXNkZgo="}}"#);
//
// decoded.bytes_dict = Default::default();
// verify_decode(&decoded, "{}");
decoded.string = "test".to_string();
verify(&decoded, r#"{"string":"test"}"#);
decoded.string = Default::default();
verify_decode(&decoded, "{}");
decoded.optional_string = Some(String::new());
verify(&decoded, r#"{"optionalString":""}"#);
}
}

View File

@ -16,8 +16,6 @@ mod heappy;
#[cfg(feature = "pprof")]
mod pprof;
mod tower;
mod metrics;
// Influx crates
@ -38,13 +36,13 @@ use futures::{self, StreamExt};
use http::header::{CONTENT_ENCODING, CONTENT_TYPE};
use hyper::{http::HeaderValue, Body, Method, Request, Response, StatusCode};
use observability_deps::tracing::{debug, error};
use routerify::{prelude::*, Middleware, RequestInfo, Router, RouterError};
use serde::Deserialize;
use snafu::{OptionExt, ResultExt, Snafu};
use trace_http::ctx::TraceHeaderParser;
use crate::influxdb_ioxd::http::metrics::LineProtocolMetrics;
use hyper::server::conn::AddrIncoming;
use hyper::server::conn::{AddrIncoming, AddrStream};
use std::convert::Infallible;
use std::num::NonZeroI32;
use std::{
fmt::Debug,
@ -52,7 +50,9 @@ use std::{
sync::Arc,
};
use tokio_util::sync::CancellationToken;
use tower::Layer;
use trace::TraceCollector;
use trace_http::tower::TraceLayer;
/// Constants used in API error codes.
///
@ -350,79 +350,56 @@ impl From<server::Error> for ApplicationError {
}
}
#[derive(Debug)]
struct Server<M>
where
M: ConnectionManager + Send + Sync + Debug + 'static,
{
application: Arc<ApplicationState>,
app_server: Arc<AppServer<M>>,
lp_metrics: Arc<LineProtocolMetrics>,
max_request_size: usize,
}
fn router<M>(
application: Arc<ApplicationState>,
app_server: Arc<AppServer<M>>,
max_request_size: usize,
) -> Router<Body, ApplicationError>
async fn route_request<M>(
server: Arc<Server<M>>,
mut req: Request<Body>,
) -> Result<Response<Body>, Infallible>
where
M: ConnectionManager + Send + Sync + Debug + 'static,
{
let server = Server {
app_server,
max_request_size,
lp_metrics: Arc::new(LineProtocolMetrics::new(
application.metric_registry().as_ref(),
)),
};
// We don't need the authorization header, and we don't want to accidentally log it.
req.headers_mut().remove("authorization");
debug!(request = ?req, "Processing request");
// Create a router and specify the handlers.
Router::builder()
.data(server)
.data(application)
.middleware(Middleware::pre(|mut req| async move {
// We don't need the authorization header, and we don't want to accidentally log it.
req.headers_mut().remove("authorization");
debug!(request = ?req, "Processing request");
Ok(req)
}))
.middleware(Middleware::post(|res| async move {
debug!(response = ?res, "Successfully processed request");
Ok(res)
})) // this endpoint is for API backward compatibility with InfluxDB 2.x
.post("/api/v2/write", write::<M>)
.get("/health", health::<M>)
.get("/metrics", handle_metrics::<M>)
.get("/api/v3/query", query::<M>)
.get("/debug/pprof", pprof_home::<M>)
.get("/debug/pprof/profile", pprof_profile::<M>)
.get("/debug/pprof/allocs", pprof_heappy_profile::<M>)
// Specify the error handler to handle any errors caused by
// a route or any middleware.
.err_handler_with_info(error_handler)
.build()
.unwrap()
}
// The API-global error handler handles ApplicationErrors originating from
// individual routes and middlewares, along with errors from the router itself
async fn error_handler(err: RouterError<ApplicationError>, req: RequestInfo) -> Response<Body> {
let method = req.method().clone();
let uri = req.uri().clone();
let span_id = req.headers().get("x-b3-spanid");
let content_length = req.headers().get("content-length");
error!(error = ?err, error_message = ?err.to_string(), method = ?method, uri = ?uri, ?span_id, ?content_length, "Error while handling request");
let content_length = req.headers().get("content-length").cloned();
match err {
RouterError::HandleRequest(e, _)
| RouterError::HandlePreMiddlewareRequest(e)
| RouterError::HandlePostMiddlewareWithInfoRequest(e)
| RouterError::HandlePostMiddlewareWithoutInfoRequest(e) => e.response(),
_ => {
let json = serde_json::json!({"error": err.to_string()}).to_string();
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from(json))
.unwrap()
let response = match (method.clone(), uri.path()) {
(Method::GET, "/health") => health(),
(Method::GET, "/metrics") => handle_metrics(server.application.as_ref()),
(Method::POST, "/api/v2/write") => write(req, server.as_ref()).await,
(Method::GET, "/api/v3/query") => query(req, server.as_ref()).await,
(Method::GET, "/debug/pprof") => pprof_home(req).await,
(Method::GET, "/debug/pprof/profile") => pprof_profile(req).await,
(Method::GET, "/debug/pprof/allocs") => pprof_heappy_profile(req).await,
(method, path) => Err(ApplicationError::RouteNotFound {
method,
path: path.to_string(),
}),
};
// TODO: Move logging to TraceLayer
match response {
Ok(response) => {
debug!(?response, "Successfully processed request");
Ok(response)
}
Err(error) => {
error!(%error, %method, %uri, ?content_length, "Error while handling request");
Ok(error.response())
}
}
}
@ -486,7 +463,10 @@ async fn parse_body(req: hyper::Request<Body>, max_size: usize) -> Result<Bytes,
}
}
async fn write<M>(req: Request<Body>) -> Result<Response<Body>, ApplicationError>
async fn write<M>(
req: Request<Body>,
server: &Server<M>,
) -> Result<Response<Body>, ApplicationError>
where
M: ConnectionManager + Send + Sync + Debug + 'static,
{
@ -494,7 +474,8 @@ where
app_server: server,
lp_metrics,
max_request_size,
} = req.data::<Server<M>>().expect("server state");
..
} = server;
let max_request_size = *max_request_size;
let server = Arc::clone(server);
@ -575,8 +556,9 @@ fn default_format() -> String {
async fn query<M: ConnectionManager + Send + Sync + Debug + 'static>(
req: Request<Body>,
server: &Server<M>,
) -> Result<Response<Body>, ApplicationError> {
let server = Arc::clone(&req.data::<Server<M>>().expect("server state").app_server);
let server = &server.app_server;
let uri_query = req.uri().query().context(ExpectedQueryString {})?;
@ -617,20 +599,12 @@ async fn query<M: ConnectionManager + Send + Sync + Debug + 'static>(
Ok(response)
}
async fn health<M: ConnectionManager + Send + Sync + Debug + 'static>(
_req: Request<Body>,
) -> Result<Response<Body>, ApplicationError> {
fn health() -> Result<Response<Body>, ApplicationError> {
let response_body = "OK";
Ok(Response::new(Body::from(response_body.to_string())))
}
async fn handle_metrics<M: ConnectionManager + Send + Sync + Debug + 'static>(
req: Request<Body>,
) -> Result<Response<Body>, ApplicationError> {
let application = req
.data::<Arc<ApplicationState>>()
.expect("application state");
fn handle_metrics(application: &ApplicationState) -> Result<Response<Body>, ApplicationError> {
let mut body: Vec<u8> = Default::default();
let mut reporter = metric_exporters::PrometheusTextEncoder::new(&mut body);
application.metric_registry().report(&mut reporter);
@ -638,18 +612,7 @@ async fn handle_metrics<M: ConnectionManager + Send + Sync + Debug + 'static>(
Ok(Response::new(Body::from(body)))
}
#[derive(Deserialize, Debug)]
/// Arguments in the query string of the request to /snapshot
struct SnapshotInfo {
org: String,
bucket: String,
partition: String,
table_name: String,
}
async fn pprof_home<M: ConnectionManager + Send + Sync + Debug + 'static>(
req: Request<Body>,
) -> Result<Response<Body>, ApplicationError> {
async fn pprof_home(req: Request<Body>) -> Result<Response<Body>, ApplicationError> {
let default_host = HeaderValue::from_static("localhost");
let host = req
.headers()
@ -715,9 +678,7 @@ impl PProfAllocsArgs {
}
#[cfg(feature = "pprof")]
async fn pprof_profile<M: ConnectionManager + Send + Sync + Debug + 'static>(
req: Request<Body>,
) -> Result<Response<Body>, ApplicationError> {
async fn pprof_profile(req: Request<Body>) -> Result<Response<Body>, ApplicationError> {
use ::pprof::protos::Message;
let query_string = req.uri().query().unwrap_or_default();
let query: PProfArgs =
@ -758,17 +719,13 @@ async fn pprof_profile<M: ConnectionManager + Send + Sync + Debug + 'static>(
}
#[cfg(not(feature = "pprof"))]
async fn pprof_profile<M: ConnectionManager + Send + Sync + Debug + 'static>(
_req: Request<Body>,
) -> Result<Response<Body>, ApplicationError> {
async fn pprof_profile(_req: Request<Body>) -> Result<Response<Body>, ApplicationError> {
PProfIsNotCompiled {}.fail()
}
// If heappy support is enabled, call it
#[cfg(feature = "heappy")]
async fn pprof_heappy_profile<M: ConnectionManager + Send + Sync + Debug + 'static>(
req: Request<Body>,
) -> Result<Response<Body>, ApplicationError> {
async fn pprof_heappy_profile(req: Request<Body>) -> Result<Response<Body>, ApplicationError> {
let query_string = req.uri().query().unwrap_or_default();
let query: PProfAllocsArgs =
serde_urlencoded::from_str(query_string).context(InvalidQueryString { query_string })?;
@ -802,16 +759,14 @@ async fn pprof_heappy_profile<M: ConnectionManager + Send + Sync + Debug + 'stat
// Return error if heappy not enabled
#[cfg(not(feature = "heappy"))]
async fn pprof_heappy_profile<M: ConnectionManager + Send + Sync + Debug + 'static>(
_req: Request<Body>,
) -> Result<Response<Body>, ApplicationError> {
async fn pprof_heappy_profile(_req: Request<Body>) -> Result<Response<Body>, ApplicationError> {
HeappyIsNotCompiled {}.fail()
}
pub async fn serve<M>(
addr: AddrIncoming,
application: Arc<ApplicationState>,
server: Arc<AppServer<M>>,
app_server: Arc<AppServer<M>>,
shutdown: CancellationToken,
max_request_size: usize,
trace_header_parser: TraceHeaderParser,
@ -821,16 +776,29 @@ where
M: ConnectionManager + Send + Sync + Debug + 'static,
{
let metric_registry = Arc::clone(application.metric_registry());
let router = router(application, server, max_request_size);
let new_service = tower::MakeService::new(
router,
trace_header_parser,
trace_collector,
metric_registry,
);
let trace_layer = TraceLayer::new(trace_header_parser, metric_registry, trace_collector, false);
let lp_metrics = Arc::new(LineProtocolMetrics::new(
application.metric_registry().as_ref(),
));
let server = Arc::new(Server {
application,
app_server,
lp_metrics,
max_request_size,
});
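// make_service_fn runs once per inbound connection; each connection gets its
// own service_fn closure holding an Arc to the shared server state, wrapped
// in the tracing layer.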
hyper::Server::builder(addr)
.serve(new_service)
.serve(hyper::service::make_service_fn(|_conn: &AddrStream| {
let server = Arc::clone(&server);
let service = hyper::service::service_fn(move |request: Request<_>| {
route_request(Arc::clone(&server), request)
});
let service = trace_layer.layer(service);
futures::future::ready(Ok::<_, Infallible>(service))
}))
.with_graceful_shutdown(shutdown.cancelled())
.await
}

View File

@ -1,74 +0,0 @@
use std::convert::Infallible;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use futures::future::BoxFuture;
use futures::{ready, FutureExt};
use hyper::server::conn::AddrStream;
use hyper::Body;
use routerify::{RequestService, Router, RouterService};
use trace::TraceCollector;
use trace_http::ctx::TraceHeaderParser;
use trace_http::tower::{TraceLayer, TraceService};
use super::ApplicationError;
use tower::Layer;
/// `MakeService` can be thought of as a hyper-compatible connection factory
///
/// Specifically, it implements the trait required by `hyper::server::Builder::serve`
pub struct MakeService {
inner: RouterService<Body, ApplicationError>,
trace_layer: trace_http::tower::TraceLayer,
}
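// Illustrative wiring, mirroring the old `serve` path (a sketch, not the
// exact call sites):
//
// let make_service = MakeService::new(router, trace_header_parser,
//     trace_collector, metric_registry);
// hyper::Server::builder(addr).serve(make_service).await?;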
impl MakeService {
pub fn new(
router: Router<Body, ApplicationError>,
trace_header_parser: TraceHeaderParser,
collector: Option<Arc<dyn TraceCollector>>,
metric_registry: Arc<metric::Registry>,
) -> Self {
Self {
inner: RouterService::new(router).unwrap(),
trace_layer: TraceLayer::new(trace_header_parser, metric_registry, collector, false),
}
}
}
impl tower::Service<&AddrStream> for MakeService {
type Response = Service;
type Error = Infallible;
type Future = MakeServiceFuture;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, conn: &AddrStream) -> Self::Future {
MakeServiceFuture {
inner: self.inner.call(conn),
trace_layer: self.trace_layer.clone(),
}
}
}
/// A future produced by `MakeService` that resolves to a `Service`
pub struct MakeServiceFuture {
inner: BoxFuture<'static, Result<RequestService<Body, ApplicationError>, Infallible>>,
trace_layer: trace_http::tower::TraceLayer,
}
impl Future for MakeServiceFuture {
type Output = Result<Service, Infallible>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let maybe_service = ready!(self.inner.poll_unpin(cx));
Poll::Ready(maybe_service.map(|service| self.trace_layer.layer(service)))
}
}
pub type Service = TraceService<RequestService<Body, ApplicationError>>;