feat: add pbjson support (#2468)

* feat: add pbjson support

* chore: fix test
pull/24376/head
Raphael Taylor-Davies 2021-09-16 08:33:27 +01:00 committed by GitHub
parent 185f45f56b
commit 1d55d9a1b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 2189 additions and 21 deletions

37
Cargo.lock generated
View File

@ -1318,13 +1318,13 @@ dependencies = [
"data_types",
"futures",
"google_types",
"influxdb_line_protocol",
"num_cpus",
"observability_deps",
"pbjson",
"pbjson_build",
"proc-macro2",
"prost",
"prost-build",
"prost-types",
"regex",
"serde",
"serde_json",
@ -1392,6 +1392,8 @@ version = "0.1.0"
dependencies = [
"bytes",
"chrono",
"pbjson",
"pbjson_build",
"prost",
"prost-build",
"serde",
@ -2896,6 +2898,37 @@ version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acbf547ad0c65e31259204bd90935776d1c693cec2f4ff7abb7a1bbbd40dfe58"
[[package]]
name = "pbjson"
version = "0.1.0"
dependencies = [
"serde",
]
[[package]]
name = "pbjson_build"
version = "0.1.0"
dependencies = [
"heck",
"itertools 0.10.1",
"pbjson_test",
"prost",
"prost-types",
"tempfile",
]
[[package]]
name = "pbjson_test"
version = "0.1.0"
dependencies = [
"pbjson",
"pbjson_build",
"prost",
"prost-build",
"serde",
"serde_json",
]
[[package]]
name = "peeking_take_while"
version = "0.1.2"

View File

@ -57,6 +57,9 @@ members = [
"observability_deps",
"packers",
"panic_logging",
"pbjson",
"pbjson_build",
"pbjson_test",
"persistence_windows",
"predicate",
"query",

View File

@ -7,16 +7,12 @@ edition = "2018"
[dependencies] # In alphabetical order
bytes = { version = "1.0", features = ["serde"] }
data_types = { path = "../data_types" }
# See docs/regenerating_flatbuffers.md about updating generated code when updating the
# version of the flatbuffers crate
#flatbuffers = "2"
futures = "0.3"
google_types = { path = "../google_types" }
influxdb_line_protocol = { path = "../influxdb_line_protocol" }
observability_deps = { path = "../observability_deps" }
num_cpus = "1.13.0"
observability_deps = { path = "../observability_deps" }
pbjson = { path = "../pbjson" }
prost = "0.8"
prost-types = "0.8"
regex = "1.4"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.67"
@ -31,3 +27,4 @@ chrono = { version = "0.4", features = ["serde"] }
proc-macro2 = "=1.0.27"
tonic-build = "0.5"
prost-build = "0.8"
pbjson_build = { path = "../pbjson_build" }

View File

@ -64,12 +64,6 @@ fn generate_grpc_types(root: &Path) -> Result<()> {
config
.compile_well_known_types()
.disable_comments(&[".google"])
// approximates jsonpb. This is still not enough to deal with the special cases like Any
// Tracking issue for proper jsonpb support in prost: https://github.com/danburkert/prost/issues/277
.type_attribute(
".",
"#[derive(serde::Serialize,serde::Deserialize)] #[serde(rename_all = \"camelCase\")]",
)
.extern_path(".google.protobuf", "::google_types::protobuf")
.bytes(&[".influxdata.iox.catalog.v1.AddParquet.metadata"]);
@ -79,5 +73,11 @@ fn generate_grpc_types(root: &Path) -> Result<()> {
.format(true)
.compile_with_config(config, &proto_files, &[root.into()])?;
let descriptor_set = std::fs::read(descriptor_path)?;
pbjson_build::Builder::new()
.register_descriptors(&descriptor_set)?
.build(&[".influxdata", ".google.longrunning", ".google.rpc"])?;
Ok(())
}

View File

@ -5,10 +5,12 @@ pub use google_types::*;
pub mod rpc {
include!(concat!(env!("OUT_DIR"), "/google.rpc.rs"));
include!(concat!(env!("OUT_DIR"), "/google.rpc.serde.rs"));
}
pub mod longrunning {
include!(concat!(env!("OUT_DIR"), "/google.longrunning.rs"));
include!(concat!(env!("OUT_DIR"), "/google.longrunning.serde.rs"));
impl Operation {
/// Return the IOx operation `id`. This `id` can

View File

@ -10,6 +10,10 @@ pub mod influxdata {
pub mod platform {
pub mod storage {
include!(concat!(env!("OUT_DIR"), "/influxdata.platform.storage.rs"));
include!(concat!(
env!("OUT_DIR"),
"/influxdata.platform.storage.serde.rs"
));
// Can't implement `Default` because `prost::Message` implements `Default`
impl TimestampRange {
@ -27,6 +31,10 @@ pub mod influxdata {
pub mod catalog {
pub mod v1 {
include!(concat!(env!("OUT_DIR"), "/influxdata.iox.catalog.v1.rs"));
include!(concat!(
env!("OUT_DIR"),
"/influxdata.iox.catalog.v1.serde.rs"
));
}
}
@ -37,12 +45,20 @@ pub mod influxdata {
"influxdata.iox.management.v1.OperationMetadata";
include!(concat!(env!("OUT_DIR"), "/influxdata.iox.management.v1.rs"));
include!(concat!(
env!("OUT_DIR"),
"/influxdata.iox.management.v1.serde.rs"
));
}
}
pub mod write {
pub mod v1 {
include!(concat!(env!("OUT_DIR"), "/influxdata.iox.write.v1.rs"));
include!(concat!(
env!("OUT_DIR"),
"/influxdata.iox.write.v1.serde.rs"
));
}
}
}
@ -50,6 +66,7 @@ pub mod influxdata {
pub mod pbdata {
pub mod v1 {
include!(concat!(env!("OUT_DIR"), "/influxdata.pbdata.v1.rs"));
include!(concat!(env!("OUT_DIR"), "/influxdata.pbdata.v1.serde.rs"));
}
}
}

View File

@ -8,8 +8,10 @@ edition = "2018"
[dependencies] # In alphabetical order
bytes = { version = "1.0", features = ["serde"] }
chrono = "0.4"
pbjson = { path = "../pbjson" }
prost = "0.8"
serde = { version = "1.0", features = ["derive"] }
[build-dependencies] # In alphabetical order
prost-build = "0.8"
pbjson_build = { path = "../pbjson_build" }

View File

@ -1,6 +1,7 @@
//! Compiles Protocol Buffers and FlatBuffers schema definitions into
//! native Rust types.
use std::env;
use std::path::PathBuf;
type Error = Box<dyn std::error::Error>;
@ -16,16 +17,18 @@ fn main() -> Result<()> {
println!("cargo:rerun-if-changed={}", proto_file.display());
}
let descriptor_path = PathBuf::from(env::var("OUT_DIR").unwrap()).join("proto_descriptor.bin");
prost_build::Config::new()
.file_descriptor_set_path(&descriptor_path)
.compile_well_known_types()
.disable_comments(&["."])
// approximates jsonpb. This is still not enough to deal with the special cases like Any.
.type_attribute(
".google",
"#[derive(serde::Serialize,serde::Deserialize)] #[serde(rename_all = \"camelCase\")]",
)
.bytes(&[".google"])
.compile_protos(&proto_files, &[root])?;
let descriptor_set = std::fs::read(descriptor_path)?;
pbjson_build::Builder::new()
.register_descriptors(&descriptor_set)?
.build(&[".google"])?;
Ok(())
}

View File

@ -17,6 +17,7 @@ mod pb {
use std::convert::{TryFrom, TryInto};
include!(concat!(env!("OUT_DIR"), "/google.protobuf.rs"));
include!(concat!(env!("OUT_DIR"), "/google.protobuf.serde.rs"));
impl TryFrom<Duration> for std::time::Duration {
type Error = std::num::TryFromIntError;

12
pbjson/Cargo.toml Normal file
View File

@ -0,0 +1,12 @@
[package]
name = "pbjson"
version = "0.1.0"
authors = ["Raphael Taylor-Davies <r.taylordavies@googlemail.com>"]
edition = "2018"
description = "Utilities for pbjson conversion"
[dependencies]
serde = { version = "1.0", features = ["derive"] }
[dev-dependencies]

41
pbjson/src/lib.rs Normal file
View File

@ -0,0 +1,41 @@
#![deny(rustdoc::broken_intra_doc_links, rustdoc::bare_urls, rust_2018_idioms)]
#![warn(
missing_debug_implementations,
clippy::explicit_iter_loop,
clippy::use_self,
clippy::clone_on_ref_ptr,
clippy::future_not_send
)]
#[doc(hidden)]
pub mod private {
    //! Implementation details used by pbjson-generated code; not part of the
    //! public API surface.
    use std::str::FromStr;

    /// Used to parse a number from either a string or its raw representation
    ///
    /// The proto3 JSON mapping encodes 64-bit integers as strings, so a
    /// conforming decoder must accept both `"42"` and `42` for numeric fields.
    #[derive(Debug, Copy, Clone, PartialOrd, PartialEq, Hash, Ord, Eq)]
    pub struct NumberDeserialize<T>(pub T);

    // Untagged helper: serde tries variants in declaration order, so a JSON
    // string is captured by `Str` and anything else is parsed as `T`.
    // NOTE(review): `Str(&'a str)` borrows from the input, so this requires a
    // deserializer that can hand out borrowed strings (e.g. serde_json from a
    // slice) — confirm against intended usage.
    #[derive(serde::Deserialize)]
    #[serde(untagged)]
    enum Content<'a, T> {
        Str(&'a str),
        Number(T),
    }

    impl<'de, T> serde::Deserialize<'de> for NumberDeserialize<T>
    where
        T: FromStr + serde::Deserialize<'de>,
        <T as FromStr>::Err: std::error::Error,
    {
        #[allow(deprecated)]
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            // Accept either representation: parse strings via FromStr,
            // pass numbers through unchanged.
            let content = Content::deserialize(deserializer)?;
            Ok(Self(match content {
                Content::Str(v) => v.parse().map_err(serde::de::Error::custom)?,
                Content::Number(v) => v,
            }))
        }
    }
}

17
pbjson_build/Cargo.toml Normal file
View File

@ -0,0 +1,17 @@
[package]
name = "pbjson_build"
version = "0.1.0"
authors = ["Raphael Taylor-Davies <r.taylordavies@googlemail.com>"]
edition = "2018"
description = "Generates Serialize and Deserialize implementations for prost message types"
[dependencies]
heck = "0.3"
prost = "0.8"
prost-types = "0.8"
itertools = "0.10"
[dev-dependencies]
tempfile = "3.1"
pbjson_test = { path = "../pbjson_test" }

View File

@ -0,0 +1,260 @@
//! This module contains code to parse and extract the protobuf descriptor
//! format for use by the rest of the codebase
use std::collections::btree_map::Entry;
use std::collections::BTreeMap;
use std::fmt::{Display, Formatter};
use std::io::{Error, ErrorKind, Result};
use itertools::{EitherOrBoth, Itertools};
use prost_types::{
DescriptorProto, EnumDescriptorProto, EnumValueDescriptorProto, FieldDescriptorProto,
FileDescriptorSet, MessageOptions, OneofDescriptorProto,
};
/// A dotted protobuf package name, e.g. `influxdata.iox.catalog.v1`,
/// stored without a leading `.`.
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub struct Package(String);

impl Display for Package {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        Display::fmt(&self.0, f)
    }
}

impl Package {
    /// Wraps `s` as a package name.
    ///
    /// Panics if `s` starts with `.` — package names are stored without the
    /// fully-qualified prefix.
    pub fn new(s: impl Into<String>) -> Self {
        let name = s.into();
        assert!(
            !name.starts_with('.'),
            "package cannot start with \'.\', got \"{}\"",
            name
        );
        Self(name)
    }
}
/// A single, unqualified protobuf type name (message or enum) — one segment
/// of a dotted type path.
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub struct TypeName(String);

impl Display for TypeName {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        self.0.fmt(f)
    }
}

impl TypeName {
    /// Wraps `s`, panicking if it contains a `.` — qualified paths are
    /// modelled by `TypePath`, not by a single name.
    pub fn new(s: impl Into<String>) -> Self {
        let s = s.into();
        assert!(
            !s.contains('.'),
            "type name cannot contain \'.\', got \"{}\"",
            s
        );
        Self(s)
    }

    /// e.g. `FooBar` -> `foo_bar` (module naming for nested types).
    pub fn to_snake_case(&self) -> String {
        use heck::SnakeCase;
        self.0.to_snake_case()
    }

    /// e.g. `foo_bar` -> `FooBar` (Rust type naming).
    pub fn to_camel_case(&self) -> String {
        use heck::CamelCase;
        self.0.to_camel_case()
    }

    /// e.g. `FooBar` -> `FOO_BAR` (protobuf enum variant prefix convention).
    pub fn to_shouty_snake_case(&self) -> String {
        use heck::ShoutySnakeCase;
        self.0.to_shouty_snake_case()
    }
}
/// A fully-qualified protobuf type path: a package plus zero or more nested
/// type names, rendered as `package.Outer.Inner`.
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub struct TypePath {
    package: Package,
    path: Vec<TypeName>,
}

impl Display for TypePath {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        self.package.fmt(f)?;
        for element in &self.path {
            write!(f, ".{}", element)?;
        }
        Ok(())
    }
}

impl TypePath {
    /// A path at the root of `package`, with no type components yet.
    pub fn new(package: Package) -> Self {
        Self {
            package,
            path: Default::default(),
        }
    }

    pub fn package(&self) -> &Package {
        &self.package
    }

    pub fn path(&self) -> &[TypeName] {
        self.path.as_slice()
    }

    /// Returns a new path with `name` appended, used when descending into
    /// nested message/enum declarations. `self` is left unchanged.
    pub fn child(&self, name: TypeName) -> Self {
        let path = self
            .path
            .iter()
            .cloned()
            .chain(std::iter::once(name))
            .collect();
        Self {
            package: self.package.clone(),
            path,
        }
    }

    /// Whether this path falls under `prefix`, which must be fully qualified
    /// (i.e. start with `.`), e.g. `.influxdata.iox`.
    pub fn matches_prefix(&self, prefix: &str) -> bool {
        // Prefixes without the leading '.' are not fully qualified and never match.
        let prefix = match prefix.strip_prefix('.') {
            Some(prefix) => prefix,
            None => return false,
        };

        if prefix.len() <= self.package.0.len() {
            // NOTE(review): plain string prefix test — ".influx" would match
            // package "influxdata"; confirm partial-segment matches are intended.
            return self.package.0.starts_with(prefix);
        }

        match prefix.strip_prefix(&self.package.0) {
            Some(prefix) => {
                // Remainder looks like ".Outer.Inner"; skip(1) drops the empty
                // segment produced before the leading '.'.
                let split = prefix.split('.').skip(1);
                for zipped in self.path.iter().zip_longest(split) {
                    match zipped {
                        // Segment agrees — keep comparing.
                        EitherOrBoth::Both(a, b) if a.0.as_str() == b => continue,
                        // Prefix exhausted before the path: prefix is an ancestor.
                        EitherOrBoth::Left(_) => return true,
                        // Mismatched segment, or prefix longer than the path.
                        _ => return false,
                    }
                }
                true
            }
            None => false,
        }
    }
}
/// An index of message and enum descriptors keyed by fully-qualified type
/// path, built by decoding one or more encoded `FileDescriptorSet`s.
#[derive(Debug, Clone, Default)]
pub struct DescriptorSet {
    descriptors: BTreeMap<TypePath, Descriptor>,
}

impl DescriptorSet {
    pub fn new() -> Self {
        Self::default()
    }

    /// Decodes `encoded` as a protobuf `FileDescriptorSet` (the format written
    /// by prost-build's `file_descriptor_set_path`) and registers every message
    /// and enum it declares, recursing into nested types.
    ///
    /// Returns `InvalidData` when decoding fails. Panics on an unknown syntax
    /// string or a file without a package — acceptable for a build-time tool.
    pub fn register_encoded(&mut self, encoded: &[u8]) -> Result<()> {
        let descriptors: FileDescriptorSet =
            prost::Message::decode(encoded).map_err(|e| Error::new(ErrorKind::InvalidData, e))?;

        for file in descriptors.file {
            // A missing syntax field defaults to proto2.
            let syntax = match file.syntax.as_deref() {
                None | Some("proto2") => Syntax::Proto2,
                Some("proto3") => Syntax::Proto3,
                Some(s) => panic!("unknown syntax: {}", s),
            };

            let package = Package::new(file.package.expect("expected package"));
            let path = TypePath::new(package);

            for descriptor in file.message_type {
                self.register_message(&path, descriptor, syntax)
            }

            for descriptor in file.enum_type {
                self.register_enum(&path, descriptor)
            }
        }
        Ok(())
    }

    /// Iterates all registered descriptors in type-path order.
    pub fn iter(&self) -> impl Iterator<Item = (&TypePath, &Descriptor)> {
        self.descriptors.iter()
    }

    /// Registers `descriptor` at `path`'s child, first recursing into its
    /// nested enums and messages so children are registered too.
    fn register_message(&mut self, path: &TypePath, descriptor: DescriptorProto, syntax: Syntax) {
        let name = TypeName::new(descriptor.name.expect("expected name"));
        let child_path = path.child(name);

        for child_descriptor in descriptor.enum_type {
            self.register_enum(&child_path, child_descriptor)
        }

        for child_descriptor in descriptor.nested_type {
            self.register_message(&child_path, child_descriptor, syntax)
        }

        self.register_descriptor(
            child_path.clone(),
            Descriptor::Message(MessageDescriptor {
                path: child_path,
                options: descriptor.options,
                one_of: descriptor.oneof_decl,
                fields: descriptor.field,
                syntax,
            }),
        );
    }

    fn register_enum(&mut self, path: &TypePath, descriptor: EnumDescriptorProto) {
        let name = TypeName::new(descriptor.name.expect("expected name"));
        self.register_descriptor(
            path.child(name),
            Descriptor::Enum(EnumDescriptor {
                values: descriptor.value,
            }),
        );
    }

    /// Inserts `descriptor` at `path`, panicking on a duplicate registration.
    fn register_descriptor(&mut self, path: TypePath, descriptor: Descriptor) {
        match self.descriptors.entry(path) {
            Entry::Occupied(o) => panic!("descriptor already registered for {}", o.key()),
            Entry::Vacant(v) => v.insert(descriptor),
        };
    }
}
/// The protobuf syntax edition a file was declared with.
#[derive(Debug, Clone, Copy)]
pub enum Syntax {
    Proto2,
    Proto3,
}

/// A registered type: either an enum or a message.
#[derive(Debug, Clone)]
pub enum Descriptor {
    Enum(EnumDescriptor),
    Message(MessageDescriptor),
}

/// The declared values of an enum.
#[derive(Debug, Clone)]
pub struct EnumDescriptor {
    pub values: Vec<EnumValueDescriptorProto>,
}

/// A message declaration together with the syntax of its defining file.
#[derive(Debug, Clone)]
pub struct MessageDescriptor {
    pub path: TypePath,
    pub options: Option<MessageOptions>,
    pub one_of: Vec<OneofDescriptorProto>,
    pub fields: Vec<FieldDescriptorProto>,
    pub syntax: Syntax,
}

impl MessageDescriptor {
    /// Whether this is an auto-generated type for the map field
    pub fn is_map(&self) -> bool {
        // Synthetic map-entry messages are flagged via MessageOptions.map_entry.
        self.options
            .as_ref()
            .and_then(|options| options.map_entry)
            .unwrap_or(false)
    }
}

View File

@ -0,0 +1,26 @@
//! Contains code to escape strings to avoid collisions with reserved Rust keywords
//!
//! (Fix: was `///!`, an *outer* doc comment that attached the text — with a
//! stray leading `!` — to `escape_ident` instead of documenting the module.)

/// Escapes `ident` so it can be used as a Rust identifier in generated code.
///
/// Strict and reserved keywords become raw identifiers (`type` -> `r#type`);
/// the four keywords that cannot be raw identifiers (`self`, `super`,
/// `extern`, `crate`) get a trailing underscore instead. Anything else is
/// returned unchanged. This mirrors prost-build's escaping so the generated
/// serde impls name the same fields prost generates.
pub fn escape_ident(mut ident: String) -> String {
    // Copied from prost-build::ident
    //
    // Use a raw identifier if the identifier matches a Rust keyword:
    // https://doc.rust-lang.org/reference/keywords.html.
    match ident.as_str() {
        // 2015 strict keywords.
        | "as" | "break" | "const" | "continue" | "else" | "enum" | "false"
        | "fn" | "for" | "if" | "impl" | "in" | "let" | "loop" | "match" | "mod" | "move" | "mut"
        | "pub" | "ref" | "return" | "static" | "struct" | "trait" | "true"
        | "type" | "unsafe" | "use" | "where" | "while"
        // 2018 strict keywords.
        | "dyn"
        // 2015 reserved keywords.
        | "abstract" | "become" | "box" | "do" | "final" | "macro" | "override" | "priv" | "typeof"
        | "unsized" | "virtual" | "yield"
        // 2018 reserved keywords.
        | "async" | "await" | "try" => ident.insert_str(0, "r#"),
        // the following keywords are not supported as raw identifiers and are therefore suffixed with an underscore.
        "self" | "super" | "extern" | "crate" => ident += "_",
        _ => (),
    };
    ident
}

View File

@ -0,0 +1,129 @@
//! This module contains the actual code generation logic
use std::collections::BTreeMap;
use std::fmt::{Display, Formatter};
use std::io::{Result, Write};
use crate::descriptor::TypePath;
mod enumeration;
mod message;
pub use enumeration::generate_enum;
pub use message::generate_message;
/// Renders `4 * n` spaces via `Display`; used to indent generated code by
/// nesting level.
#[derive(Debug, Clone, Copy)]
struct Indent(usize);

impl Display for Indent {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut remaining = self.0;
        while remaining > 0 {
            f.write_str("    ")?;
            remaining -= 1;
        }
        Ok(())
    }
}
/// Code-generation configuration shared by the message and enum generators.
#[derive(Debug)]
pub struct Config {
    // Maps fully-qualified protobuf type paths to externally-provided Rust
    // types (e.g. well-known types), overriding the default derivation below.
    pub extern_types: BTreeMap<TypePath, String>,
}

impl Config {
    /// Resolves the Rust type for `path`: an extern override if registered,
    /// otherwise `parent_one::parent_two::TypeName` — snake_case modules for
    /// enclosing messages, CamelCase for the type itself, matching the module
    /// layout prost generates for nested types.
    fn rust_type(&self, path: &TypePath) -> String {
        if let Some(t) = self.extern_types.get(path) {
            return t.clone();
        }

        let mut ret = String::new();
        let path = path.path();
        assert!(!path.is_empty(), "path cannot be empty");

        // Enclosing message names become snake_case module segments.
        for i in &path[..(path.len() - 1)] {
            ret.push_str(i.to_snake_case().as_str());
            ret.push_str("::");
        }
        ret.push_str(path.last().unwrap().to_camel_case().as_str());
        ret
    }

    /// Maps a SHOUTY_SNAKE_CASE protobuf enum variant to its Rust variant
    /// name: the enum-name prefix is stripped (as prost does), then the
    /// remainder is CamelCased.
    fn rust_variant(&self, enumeration: &TypePath, variant: &str) -> String {
        use heck::CamelCase;
        // Only SHOUTY_SNAKE_CASE variant names are supported.
        assert!(
            variant
                .chars()
                .all(|c| matches!(c, '0'..='9' | 'A'..='Z' | '_')),
            "illegal variant - {}",
            variant
        );

        // TODO: Config to disable stripping prefix
        let enumeration_name = enumeration.path().last().unwrap().to_shouty_snake_case();
        let variant = match variant.strip_prefix(&enumeration_name) {
            // Variant name equals the enum name exactly: stripping would leave
            // nothing, so keep the whole name.
            Some("") => variant,
            Some(stripped) => stripped,
            None => variant,
        };
        variant.to_camel_case()
    }
}
/// Writes a `FIELDS` string-array constant listing `variants`, used by the
/// generated deserializers for error reporting and unknown-field handling.
fn write_fields_array<'a, W: Write, I: Iterator<Item = &'a str>>(
    writer: &mut W,
    indent: usize,
    variants: I,
) -> Result<()> {
    writeln!(writer, "{}const FIELDS: &[&str] = &[", Indent(indent))?;
    for name in variants {
        writeln!(writer, "{}\"{}\",", Indent(indent + 1), name)?;
    }
    writeln!(writer, "{}];", Indent(indent))?;
    writeln!(writer)
}

/// Opens `impl serde::Serialize` and the `serialize` fn for `rust_type`;
/// paired with `write_serialize_end`. Doubled braces (`{{`) are literal
/// braces in the generated code.
fn write_serialize_start<W: Write>(indent: usize, rust_type: &str, writer: &mut W) -> Result<()> {
    writeln!(
        writer,
        r#"{indent}impl serde::Serialize for {rust_type} {{
{indent}    #[allow(deprecated)]
{indent}    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
{indent}    where
{indent}        S: serde::Serializer,
{indent}    {{"#,
        indent = Indent(indent),
        rust_type = rust_type
    )
}

/// Closes the fn and impl opened by `write_serialize_start`.
fn write_serialize_end<W: Write>(indent: usize, writer: &mut W) -> Result<()> {
    writeln!(
        writer,
        r#"{indent}    }}
{indent}}}"#,
        indent = Indent(indent),
    )
}

/// Opens `impl serde::Deserialize` and the `deserialize` fn for `rust_type`;
/// paired with `write_deserialize_end`.
fn write_deserialize_start<W: Write>(indent: usize, rust_type: &str, writer: &mut W) -> Result<()> {
    writeln!(
        writer,
        r#"{indent}impl<'de> serde::Deserialize<'de> for {rust_type} {{
{indent}    #[allow(deprecated)]
{indent}    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
{indent}    where
{indent}        D: serde::Deserializer<'de>,
{indent}    {{"#,
        indent = Indent(indent),
        rust_type = rust_type
    )
}

/// Closes the fn and impl opened by `write_deserialize_start`.
fn write_deserialize_end<W: Write>(indent: usize, writer: &mut W) -> Result<()> {
    writeln!(
        writer,
        r#"{indent}    }}
{indent}}}"#,
        indent = Indent(indent),
    )
}

View File

@ -0,0 +1,138 @@
//! This module contains the code to generate Serialize and Deserialize
//! implementations for enumeration type
//!
//! An enumeration should be decode-able from the full string variant name
//! or its integer tag number, and should encode to the string representation
use super::{
write_deserialize_end, write_deserialize_start, write_serialize_end, write_serialize_start,
Config, Indent,
};
use crate::descriptor::{EnumDescriptor, TypePath};
use crate::generator::write_fields_array;
use std::io::{Result, Write};
/// Generates `serde::Serialize` and `serde::Deserialize` impls for the enum
/// at `path`, writing the Rust source to `writer`.
///
/// Serialization emits the protobuf variant name as a string; deserialization
/// accepts either the variant name or its integer tag.
pub fn generate_enum<W: Write>(
    config: &Config,
    path: &TypePath,
    descriptor: &EnumDescriptor,
    writer: &mut W,
) -> Result<()> {
    let rust_type = config.rust_type(path);

    // Pair each protobuf variant name with its derived Rust variant name.
    let variants: Vec<_> = descriptor
        .values
        .iter()
        .map(|variant| {
            let variant_name = variant.name.clone().unwrap();
            let rust_variant = config.rust_variant(path, &variant_name);
            (variant_name, rust_variant)
        })
        .collect();

    // Generate Serialize
    write_serialize_start(0, &rust_type, writer)?;
    writeln!(writer, "{}let variant = match self {{", Indent(2))?;
    for (variant_name, rust_variant) in &variants {
        writeln!(
            writer,
            "{}Self::{} => \"{}\",",
            Indent(3),
            rust_variant,
            variant_name
        )?;
    }
    writeln!(writer, "{}}};", Indent(2))?;
    writeln!(writer, "{}serializer.serialize_str(variant)", Indent(2))?;
    write_serialize_end(0, writer)?;

    // Generate Deserialize
    write_deserialize_start(0, &rust_type, writer)?;
    write_fields_array(writer, 2, variants.iter().map(|(name, _)| name.as_str()))?;
    write_visitor(writer, 2, &rust_type, &variants)?;

    // Use deserialize_any to allow users to provide integers or strings
    writeln!(
        writer,
        "{}deserializer.deserialize_any(GeneratedVisitor)",
        Indent(2)
    )?;
    write_deserialize_end(0, writer)?;
    Ok(())
}
/// Writes the serde visitor for the generated enum deserializer: accepts the
/// variant name via `visit_str`, or the integer tag via `visit_i64` /
/// `visit_u64` (converted through `from_i32`).
fn write_visitor<W: Write>(
    writer: &mut W,
    indent: usize,
    rust_type: &str,
    variants: &[(String, String)],
) -> Result<()> {
    // Protobuf supports deserialization of enumerations both from string and integer values
    writeln!(
        writer,
        r#"{indent}struct GeneratedVisitor;
{indent}impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {{
{indent}    type Value = {rust_type};
{indent}    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{
{indent}        write!(formatter, "expected one of: {{:?}}", &FIELDS)
{indent}    }}
{indent}    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
{indent}    where
{indent}        E: serde::de::Error,
{indent}    {{
{indent}        use std::convert::TryFrom;
{indent}        i32::try_from(v)
{indent}            .ok()
{indent}            .and_then({rust_type}::from_i32)
{indent}            .ok_or_else(|| {{
{indent}                serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self)
{indent}            }})
{indent}    }}
{indent}    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
{indent}    where
{indent}        E: serde::de::Error,
{indent}    {{
{indent}        use std::convert::TryFrom;
{indent}        i32::try_from(v)
{indent}            .ok()
{indent}            .and_then({rust_type}::from_i32)
{indent}            .ok_or_else(|| {{
{indent}                serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self)
{indent}            }})
{indent}    }}
{indent}    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
{indent}    where
{indent}        E: serde::de::Error,
{indent}    {{"#,
        indent = Indent(indent),
        rust_type = rust_type,
    )?;

    // Match on variant-name strings; unknown names are a serde error.
    writeln!(writer, "{}match value {{", Indent(indent + 2))?;
    for (variant_name, rust_variant) in variants {
        writeln!(
            writer,
            "{}\"{}\" => Ok({}::{}),",
            Indent(indent + 3),
            variant_name,
            rust_type,
            rust_variant
        )?;
    }
    writeln!(
        writer,
        "{indent}_ => Err(serde::de::Error::unknown_variant(value, FIELDS)),",
        indent = Indent(indent + 3)
    )?;
    writeln!(writer, "{}}}", Indent(indent + 2))?;
    // Close visit_str, then the visitor impl.
    writeln!(writer, "{}}}", Indent(indent + 1))?;
    writeln!(writer, "{}}}", Indent(indent))
}

View File

@ -0,0 +1,754 @@
//! This module contains the code to generate Serialize and Deserialize
//! implementations for message types
//!
//! The implementation follows the proto3 [JSON mapping][1] with the default options
//!
//! Importantly:
//! - numeric types can be decoded from either a string or number
//! - 32-bit integers and floats are encoded as numbers
//! - 64-bit integers are encoded as strings
//! - repeated fields are encoded as arrays
//! - bytes are base64 encoded (NOT CURRENTLY SUPPORTED)
//! - messages and maps are encoded as objects
//! - fields are lowerCamelCase except where overridden by the proto definition
//! - default values are not emitted on encode
//! - unrecognised fields error on decode
//!
//! Note: This will not generate code to correctly serialize/deserialize well-known-types
//! such as google.protobuf.Any, google.protobuf.Duration, etc... conversions for these
//! special-cased messages will need to be manually implemented. Once done so, however,
//! any messages containing these types will serialize/deserialize correctly
//!
//! [1]: https://developers.google.com/protocol-buffers/docs/proto3#json
use std::io::{Result, Write};
use crate::message::{Field, FieldModifier, FieldType, Message, OneOf, ScalarType};
use super::{
write_deserialize_end, write_deserialize_start, write_serialize_end, write_serialize_start,
Config, Indent,
};
use crate::descriptor::TypePath;
use crate::generator::write_fields_array;
/// Generates `serde::Serialize` and `serde::Deserialize` impls for `message`,
/// writing the Rust source to `writer`.
pub fn generate_message<W: Write>(
    config: &Config,
    message: &Message,
    writer: &mut W,
) -> Result<()> {
    let rust_type = config.rust_type(&message.path);

    // Generate Serialize
    write_serialize_start(0, &rust_type, writer)?;
    write_message_serialize(config, 2, message, writer)?;
    write_serialize_end(0, writer)?;

    // Generate Deserialize
    write_deserialize_start(0, &rust_type, writer)?;
    write_deserialize_message(config, 2, message, &rust_type, writer)?;
    write_deserialize_end(0, writer)?;
    Ok(())
}

/// Writes a boolean expression that is true when `member` should be emitted —
/// i.e. it is not at its default value (defaults are skipped on encode, per
/// the proto3 JSON mapping).
fn write_field_empty_predicate<W: Write>(member: &Field, writer: &mut W) -> Result<()> {
    match (&member.field_type, &member.field_modifier) {
        // Required fields are always serialized; callers never ask for them here.
        (_, FieldModifier::Required) => unreachable!(),
        // Collections, strings and bytes: default is empty.
        (_, FieldModifier::Repeated)
        | (FieldType::Map(_, _), _)
        | (FieldType::Scalar(ScalarType::String), FieldModifier::UseDefault)
        | (FieldType::Scalar(ScalarType::Bytes), FieldModifier::UseDefault) => {
            write!(writer, "!self.{}.is_empty()", member.rust_field_name())
        }
        // Optional fields and message fields: Option, default is None.
        (_, FieldModifier::Optional) | (FieldType::Message(_), _) => {
            write!(writer, "self.{}.is_some()", member.rust_field_name())
        }
        // Floats: default is 0.
        (FieldType::Scalar(ScalarType::F64), FieldModifier::UseDefault)
        | (FieldType::Scalar(ScalarType::F32), FieldModifier::UseDefault) => {
            write!(writer, "self.{} != 0.", member.rust_field_name())
        }
        // Bools: default is false.
        (FieldType::Scalar(ScalarType::Bool), FieldModifier::UseDefault) => {
            write!(writer, "self.{}", member.rust_field_name())
        }
        // Integers and enums (enums are stored as i32 tags — see from_i32
        // usage elsewhere): default is 0.
        (FieldType::Enum(_), FieldModifier::UseDefault)
        | (FieldType::Scalar(ScalarType::I64), FieldModifier::UseDefault)
        | (FieldType::Scalar(ScalarType::I32), FieldModifier::UseDefault)
        | (FieldType::Scalar(ScalarType::U32), FieldModifier::UseDefault)
        | (FieldType::Scalar(ScalarType::U64), FieldModifier::UseDefault) => {
            write!(writer, "self.{} != 0", member.rust_field_name())
        }
    }
}
/// Writes the body of the generated `serialize` fn: computes the field count,
/// opens the struct serializer, emits each field and oneof, then closes it.
fn write_message_serialize<W: Write>(
    config: &Config,
    indent: usize,
    message: &Message,
    writer: &mut W,
) -> Result<()> {
    write_struct_serialize_start(indent, message, writer)?;

    for field in &message.fields {
        write_serialize_field(config, indent, field, writer)?;
    }

    for one_of in &message.one_ofs {
        write_serialize_one_of(indent, config, one_of, writer)?;
    }

    write_struct_serialize_end(indent, writer)
}

/// Emits the length computation and the `serializer.serialize_struct(..)`
/// call. Required fields always count towards the length; everything else is
/// counted at runtime via the emptiness predicates, since default values are
/// not serialized.
fn write_struct_serialize_start<W: Write>(
    indent: usize,
    message: &Message,
    writer: &mut W,
) -> Result<()> {
    writeln!(writer, "{}use serde::ser::SerializeStruct;", Indent(indent))?;

    let required_len = message
        .fields
        .iter()
        .filter(|member| member.field_modifier.is_required())
        .count();

    // `len` only needs `mut` when some fields are conditionally counted.
    if required_len != message.fields.len() || !message.one_ofs.is_empty() {
        writeln!(writer, "{}let mut len = {};", Indent(indent), required_len)?;
    } else {
        writeln!(writer, "{}let len = {};", Indent(indent), required_len)?;
    }

    // Count each non-required field that is actually set.
    for field in &message.fields {
        if field.field_modifier.is_required() {
            continue;
        }
        write!(writer, "{}if ", Indent(indent))?;
        write_field_empty_predicate(field, writer)?;
        writeln!(writer, " {{")?;
        writeln!(writer, "{}len += 1;", Indent(indent + 1))?;
        writeln!(writer, "{}}}", Indent(indent))?;
    }

    // A set oneof contributes exactly one field.
    for one_of in &message.one_ofs {
        writeln!(
            writer,
            "{}if self.{}.is_some() {{",
            Indent(indent),
            one_of.rust_field_name()
        )?;
        writeln!(writer, "{}len += 1;", Indent(indent + 1))?;
        writeln!(writer, "{}}}", Indent(indent))?;
    }

    // `struct_ser` only needs `mut` when fields will be added to it.
    if !message.fields.is_empty() || !message.one_ofs.is_empty() {
        writeln!(
            writer,
            "{}let mut struct_ser = serializer.serialize_struct(\"{}\", len)?;",
            Indent(indent),
            message.path
        )?;
    } else {
        writeln!(
            writer,
            "{}let struct_ser = serializer.serialize_struct(\"{}\", len)?;",
            Indent(indent),
            message.path
        )?;
    }
    Ok(())
}

/// Closes the struct serializer opened by `write_struct_serialize_start`.
fn write_struct_serialize_end<W: Write>(indent: usize, writer: &mut W) -> Result<()> {
    writeln!(writer, "{}struct_ser.end()", Indent(indent))
}
/// Writes an expression converting the i32 tag `value` into the Rust enum at
/// `path` via `from_i32`, producing a serde serialization error for unknown
/// tags. `{{}}` in the format string renders literal braces into the
/// generated code.
fn write_decode_variant<W: Write>(
    config: &Config,
    indent: usize,
    value: &str,
    path: &TypePath,
    writer: &mut W,
) -> Result<()> {
    writeln!(writer, "{}::from_i32({})", config.rust_type(path), value)?;
    write!(
        writer,
        "{}.ok_or_else(|| serde::ser::Error::custom(format!(\"Invalid variant {{}}\", {})))",
        Indent(indent),
        value
    )
}

/// Depending on the type of the field different ways of accessing field's value
/// are needed - this allows decoupling the type serialization logic from the logic
/// that manipulates its container e.g. Vec, Option, HashMap
struct Variable<'a> {
    /// A reference to the field's value
    as_ref: &'a str,
    /// The field's value
    as_unref: &'a str,
    /// The field without any leading "&" or "*"
    raw: &'a str,
}
/// Emits serialization of a single value according to its protobuf type,
/// following the proto3 JSON mapping:
/// - 64-bit integers are encoded as strings (repeated: a Vec of strings)
/// - enum tags are converted to their Rust variant via `from_i32`
/// - maps whose values are 64-bit integers or enums are rebuilt with
///   converted values first
/// - everything else is handed to serde directly
fn write_serialize_variable<W: Write>(
    config: &Config,
    indent: usize,
    field: &Field,
    variable: Variable<'_>,
    writer: &mut W,
) -> Result<()> {
    match &field.field_type {
        // 64-bit integers: JSON numbers cannot represent the full range, so
        // the JSON mapping encodes them as strings via ToString.
        FieldType::Scalar(ScalarType::I64) | FieldType::Scalar(ScalarType::U64) => {
            match field.field_modifier {
                FieldModifier::Repeated => {
                    writeln!(
                        writer,
                        "{}struct_ser.serialize_field(\"{}\", &{}.iter().map(ToString::to_string).collect::<Vec<_>>())?;",
                        Indent(indent),
                        field.json_name(),
                        variable.raw
                    )
                }
                _ => {
                    writeln!(
                        writer,
                        "{}struct_ser.serialize_field(\"{}\", {}.to_string().as_str())?;",
                        Indent(indent),
                        field.json_name(),
                        variable.raw
                    )
                }
            }
        }
        // Enums: decode the stored i32 tag(s) into variant value(s) before
        // serialization, erroring on unknown tags.
        FieldType::Enum(path) => {
            write!(writer, "{}let v = ", Indent(indent))?;
            match field.field_modifier {
                FieldModifier::Repeated => {
                    writeln!(writer, "{}.iter().cloned().map(|v| {{", variable.raw)?;
                    write!(writer, "{}", Indent(indent + 1))?;
                    write_decode_variant(config, indent + 2, "v", path, writer)?;
                    writeln!(writer)?;
                    write!(
                        writer,
                        "{}}}).collect::<Result<Vec<_>, _>>()",
                        Indent(indent + 1)
                    )
                }
                _ => write_decode_variant(config, indent + 1, variable.as_unref, path, writer),
            }?;
            writeln!(writer, "?;")?;
            writeln!(
                writer,
                "{}struct_ser.serialize_field(\"{}\", &v)?;",
                Indent(indent),
                field.json_name()
            )
        }
        // Maps whose values need conversion (64-bit integers or enums):
        // collect into a fresh HashMap with converted values first.
        FieldType::Map(_, value_type)
            if matches!(
                value_type.as_ref(),
                FieldType::Scalar(ScalarType::I64)
                    | FieldType::Scalar(ScalarType::U64)
                    | FieldType::Enum(_)
            ) =>
        {
            writeln!(
                writer,
                "{}let v: std::collections::HashMap<_, _> = {}.iter()",
                Indent(indent),
                variable.raw
            )?;
            match value_type.as_ref() {
                FieldType::Scalar(ScalarType::I64) | FieldType::Scalar(ScalarType::U64) => {
                    writeln!(
                        writer,
                        "{}.map(|(k, v)| (k, v.to_string())).collect();",
                        Indent(indent + 1)
                    )?;
                }
                FieldType::Enum(path) => {
                    writeln!(writer, "{}.map(|(k, v)| {{", Indent(indent + 1))?;
                    write!(writer, "{}let v = ", Indent(indent + 2))?;
                    write_decode_variant(config, indent + 3, "*v", path, writer)?;
                    writeln!(writer, "?;")?;
                    writeln!(writer, "{}Ok((k, v))", Indent(indent + 2))?;
                    writeln!(
                        writer,
                        "{}}}).collect::<Result<_,_>>()?;",
                        Indent(indent + 1)
                    )?;
                }
                // Excluded by the guard above.
                _ => unreachable!(),
            }
            writeln!(
                writer,
                "{}struct_ser.serialize_field(\"{}\", &v)?;",
                Indent(indent),
                field.json_name()
            )
        }
        // Everything else serializes natively through serde.
        _ => {
            writeln!(
                writer,
                "{}struct_ser.serialize_field(\"{}\", {})?;",
                Indent(indent),
                field.json_name(),
                variable.as_ref
            )
        }
    }
}
/// Emits serialization for a single field, wrapping the per-type logic of
/// `write_serialize_variable` in the appropriate container handling
/// (always / if-Some / if-not-default).
fn write_serialize_field<W: Write>(
    config: &Config,
    indent: usize,
    field: &Field,
    writer: &mut W,
) -> Result<()> {
    let as_ref = format!("&self.{}", field.rust_field_name());
    // `as_unref`/`raw` drop the leading '&' of `&self.field`.
    let variable = Variable {
        as_ref: as_ref.as_str(),
        as_unref: &as_ref.as_str()[1..],
        raw: &as_ref.as_str()[1..],
    };

    match &field.field_modifier {
        // Required fields are serialized unconditionally.
        FieldModifier::Required => {
            write_serialize_variable(config, indent, field, variable, writer)?;
        }
        // Optional fields are Option: serialize only when Some.
        FieldModifier::Optional => {
            writeln!(
                writer,
                "{}if let Some(v) = {}.as_ref() {{",
                Indent(indent),
                variable.as_unref
            )?;
            let variable = Variable {
                as_ref: "v",
                as_unref: "*v",
                raw: "v",
            };
            write_serialize_variable(config, indent + 1, field, variable, writer)?;
            writeln!(writer, "{}}}", Indent(indent))?;
        }
        // Skip serialization when the field equals its default / is empty.
        FieldModifier::Repeated | FieldModifier::UseDefault => {
            write!(writer, "{}if ", Indent(indent))?;
            write_field_empty_predicate(field, writer)?;
            writeln!(writer, " {{")?;
            write_serialize_variable(config, indent + 1, field, variable, writer)?;
            writeln!(writer, "{}}}", Indent(indent))?;
        }
    }
    Ok(())
}

/// Emits serialization for a oneof: when set, match on the enum variant and
/// serialize the contained field into the parent object.
fn write_serialize_one_of<W: Write>(
    indent: usize,
    config: &Config,
    one_of: &OneOf,
    writer: &mut W,
) -> Result<()> {
    writeln!(
        writer,
        "{}if let Some(v) = self.{}.as_ref() {{",
        Indent(indent),
        one_of.rust_field_name()
    )?;

    writeln!(writer, "{}match v {{", Indent(indent + 1))?;
    for field in &one_of.fields {
        writeln!(
            writer,
            "{}{}::{}(v) => {{",
            Indent(indent + 2),
            config.rust_type(&one_of.path),
            field.rust_type_name(),
        )?;
        // Inside the match arm the value is already a reference binding `v`.
        let variable = Variable {
            as_ref: "v",
            as_unref: "*v",
            raw: "v",
        };
        write_serialize_variable(config, indent + 3, field, variable, writer)?;
        writeln!(writer, "{}}}", Indent(indent + 2))?;
    }
    writeln!(writer, "{}}}", Indent(indent + 1),)?;
    writeln!(writer, "{}}}", Indent(indent))
}
/// Write the body of a `serde::Deserialize` implementation for `message`.
///
/// Emits the field-name deserializer plus a `GeneratedVisitor` implementing
/// `visit_map`, which accumulates each field into an `Option` and then
/// assembles the final struct, erroring on missing required fields and
/// defaulting unset repeated / use-default ones.
fn write_deserialize_message<W: Write>(
    config: &Config,
    indent: usize,
    message: &Message,
    rust_type: &str,
    writer: &mut W,
) -> Result<()> {
    write_deserialize_field_name(2, message, writer)?;
    writeln!(writer, "{}struct GeneratedVisitor;", Indent(indent))?;
    writeln!(
        writer,
        r#"{indent}impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {{
{indent}    type Value = {rust_type};
{indent}    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{
{indent}        formatter.write_str("struct {name}")
{indent}    }}
{indent}    fn visit_map<V>(self, mut map: V) -> Result<{rust_type}, V::Error>
{indent}    where
{indent}        V: serde::de::MapAccess<'de>,
{indent}    {{"#,
        indent = Indent(indent),
        name = message.path,
        rust_type = rust_type,
    )?;
    // One `Option` accumulator per field / one-of, filled as keys arrive
    for field in &message.fields {
        writeln!(
            writer,
            "{}let mut {} = None;",
            Indent(indent + 2),
            field.rust_field_name(),
        )?;
    }
    for one_of in &message.one_ofs {
        writeln!(
            writer,
            "{}let mut {} = None;",
            Indent(indent + 2),
            one_of.rust_field_name(),
        )?;
    }
    if !message.fields.is_empty() || !message.one_ofs.is_empty() {
        writeln!(
            writer,
            "{}while let Some(k) = map.next_key()? {{",
            Indent(indent + 2)
        )?;
        writeln!(writer, "{}match k {{", Indent(indent + 3))?;
        for field in &message.fields {
            write_deserialize_field(config, indent + 4, field, None, writer)?;
        }
        for one_of in &message.one_ofs {
            for field in &one_of.fields {
                write_deserialize_field(config, indent + 4, field, Some(one_of), writer)?;
            }
        }
        writeln!(writer, "{}}}", Indent(indent + 3))?;
        writeln!(writer, "{}}}", Indent(indent + 2))?;
    } else {
        // Messages with no fields still need to drain (and reject) any keys
        writeln!(
            writer,
            "{}while map.next_key::<GeneratedField>()?.is_some() {{}}",
            Indent(indent + 2)
        )?;
    }
    writeln!(writer, "{}Ok({} {{", Indent(indent + 2), rust_type)?;
    for field in &message.fields {
        match field.field_modifier {
            FieldModifier::Required => {
                writeln!(
                    writer,
                    "{indent}{field}: {field}.ok_or_else(|| serde::de::Error::missing_field(\"{json_name}\"))?,",
                    indent = Indent(indent + 3),
                    field = field.rust_field_name(),
                    json_name = field.json_name()
                )?;
            }
            FieldModifier::UseDefault | FieldModifier::Repeated => {
                // Note: this currently does not hydrate optional proto2 fields with defaults
                writeln!(
                    writer,
                    "{indent}{field}: {field}.unwrap_or_default(),",
                    indent = Indent(indent + 3),
                    field = field.rust_field_name()
                )?;
            }
            _ => {
                writeln!(
                    writer,
                    "{indent}{field},",
                    indent = Indent(indent + 3),
                    field = field.rust_field_name()
                )?;
            }
        }
    }
    for one_of in &message.one_ofs {
        writeln!(
            writer,
            "{indent}{field},",
            indent = Indent(indent + 3),
            field = one_of.rust_field_name(),
        )?;
    }
    writeln!(writer, "{}}})", Indent(indent + 2))?;
    writeln!(writer, "{}}}", Indent(indent + 1))?;
    writeln!(writer, "{}}}", Indent(indent))?;
    writeln!(
        writer,
        "{}deserializer.deserialize_struct(\"{}\", FIELDS, GeneratedVisitor)",
        Indent(indent),
        message.path
    )
}
/// Write the `FIELDS` array, the `GeneratedField` enum, and a `Deserialize`
/// impl for it that maps incoming JSON field names onto enum variants,
/// rejecting unknown names.
fn write_deserialize_field_name<W: Write>(
    indent: usize,
    message: &Message,
    writer: &mut W,
) -> Result<()> {
    // (json name, rust variant name) for every field, including one-of members
    let fields: Vec<_> = message
        .all_fields()
        .map(|field| (field.json_name(), field.rust_type_name()))
        .collect();
    write_fields_array(writer, indent, fields.iter().map(|(name, _)| name.as_str()))?;
    write_fields_enum(writer, indent, fields.iter().map(|(_, name)| name.as_str()))?;
    writeln!(
        writer,
        r#"{indent}impl<'de> serde::Deserialize<'de> for GeneratedField {{
{indent}    fn deserialize<D>(deserializer: D) -> Result<GeneratedField, D::Error>
{indent}    where
{indent}        D: serde::Deserializer<'de>,
{indent}    {{
{indent}        struct GeneratedVisitor;
{indent}        impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {{
{indent}            type Value = GeneratedField;
{indent}            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{
{indent}                write!(formatter, "expected one of: {{:?}}", &FIELDS)
{indent}            }}
{indent}            fn visit_str<E>(self, value: &str) -> Result<GeneratedField, E>
{indent}            where
{indent}                E: serde::de::Error,
{indent}            {{"#,
        indent = Indent(indent)
    )?;
    if !fields.is_empty() {
        writeln!(writer, "{}match value {{", Indent(indent + 4))?;
        for (json_name, type_name) in &fields {
            writeln!(
                writer,
                "{}\"{}\" => Ok(GeneratedField::{}),",
                Indent(indent + 5),
                json_name,
                type_name
            )?;
        }
        writeln!(
            writer,
            "{}_ => Err(serde::de::Error::unknown_field(value, FIELDS)),",
            Indent(indent + 5)
        )?;
        writeln!(writer, "{}}}", Indent(indent + 4))?;
    } else {
        // No fields at all: every key is unknown
        writeln!(
            writer,
            "{}Err(serde::de::Error::unknown_field(value, FIELDS))",
            Indent(indent + 4)
        )?;
    }
    writeln!(
        writer,
        r#"{indent}            }}
{indent}        }}
{indent}        deserializer.deserialize_identifier(GeneratedVisitor)
{indent}    }}
{indent}}}"#,
        indent = Indent(indent)
    )
}
/// Write the `GeneratedField` enum with one variant per identifier yielded
/// by `fields`, in iteration order. The clippy allow is emitted because
/// generated variant names can trigger `enum_variant_names`.
fn write_fields_enum<'a, W: Write, I: Iterator<Item = &'a str>>(
    writer: &mut W,
    indent: usize,
    fields: I,
) -> Result<()> {
    let outer = Indent(indent);
    let inner = Indent(indent + 1);
    writeln!(writer, "{}#[allow(clippy::enum_variant_names)]", outer)?;
    writeln!(writer, "{}enum GeneratedField {{", outer)?;
    fields.try_for_each(|variant| writeln!(writer, "{}{},", inner, variant))?;
    writeln!(writer, "{}}}", outer)
}
/// Write the deserialization match arm for a single field.
///
/// Handles the protobuf JSON quirks visible in this generator: numeric
/// values may arrive as either JSON numbers or strings (via
/// `NumberDeserialize`), enums are stored as `i32`, and map keys / values
/// may need the same conversions.
fn write_deserialize_field<W: Write>(
    config: &Config,
    indent: usize,
    field: &Field,
    one_of: Option<&OneOf>,
    writer: &mut W,
) -> Result<()> {
    // Members of a one-of all accumulate into the single one-of member
    let field_name = match one_of {
        Some(one_of) => one_of.rust_field_name(),
        None => field.rust_field_name(),
    };
    let json_name = field.json_name();
    writeln!(
        writer,
        "{}GeneratedField::{} => {{",
        Indent(indent),
        field.rust_type_name()
    )?;
    writeln!(
        writer,
        "{}if {}.is_some() {{",
        Indent(indent + 1),
        field_name
    )?;
    // Note: this will report duplicate field if multiple value are specified for a one of
    writeln!(
        writer,
        "{}return Err(serde::de::Error::duplicate_field(\"{}\"));",
        Indent(indent + 2),
        json_name
    )?;
    writeln!(writer, "{}}}", Indent(indent + 1))?;
    write!(writer, "{}{} = Some(", Indent(indent + 1), field_name)?;
    if let Some(one_of) = one_of {
        // Wrap the parsed value in the one-of enum variant
        write!(
            writer,
            "{}::{}(",
            config.rust_type(&one_of.path),
            field.rust_type_name()
        )?;
    }
    match &field.field_type {
        FieldType::Scalar(scalar) if scalar.is_numeric() => {
            // NumberDeserialize accepts both numbers and string-encoded numbers
            writeln!(writer)?;
            match field.field_modifier {
                FieldModifier::Repeated => {
                    writeln!(
                        writer,
                        "{}map.next_value::<Vec<::pbjson::private::NumberDeserialize<{}>>>()?",
                        Indent(indent + 2),
                        scalar.rust_type()
                    )?;
                    writeln!(
                        writer,
                        "{}.into_iter().map(|x| x.0).collect()",
                        Indent(indent + 3)
                    )?;
                }
                _ => {
                    writeln!(
                        writer,
                        "{}map.next_value::<::pbjson::private::NumberDeserialize<{}>>()?.0",
                        Indent(indent + 2),
                        scalar.rust_type()
                    )?;
                }
            }
            write!(writer, "{}", Indent(indent + 1))?;
        }
        FieldType::Enum(path) => match field.field_modifier {
            FieldModifier::Repeated => {
                // prost stores enums as i32, so cast each parsed variant back
                write!(
                    writer,
                    "map.next_value::<Vec<{}>>()?.into_iter().map(|x| x as i32).collect()",
                    config.rust_type(path)
                )?;
            }
            _ => {
                write!(
                    writer,
                    "map.next_value::<{}>()? as i32",
                    config.rust_type(path)
                )?;
            }
        },
        FieldType::Map(key, value) => {
            writeln!(writer)?;
            write!(
                writer,
                "{}map.next_value::<std::collections::HashMap<",
                Indent(indent + 2),
            )?;
            // `map_k` / `map_v` are the expressions used to convert each
            // deserialized entry back to its stored representation
            let map_k = match key.is_numeric() {
                true => {
                    write!(
                        writer,
                        "::pbjson::private::NumberDeserialize<{}>",
                        key.rust_type()
                    )?;
                    "k.0"
                }
                false => {
                    write!(writer, "_")?;
                    "k"
                }
            };
            write!(writer, ", ")?;
            let map_v = match value.as_ref() {
                FieldType::Scalar(scalar) if scalar.is_numeric() => {
                    write!(
                        writer,
                        "::pbjson::private::NumberDeserialize<{}>",
                        scalar.rust_type()
                    )?;
                    "v.0"
                }
                FieldType::Enum(path) => {
                    write!(writer, "{}", config.rust_type(path))?;
                    "v as i32"
                }
                FieldType::Map(_, _) => panic!("protobuf disallows nested maps"),
                _ => {
                    write!(writer, "_")?;
                    "v"
                }
            };
            writeln!(writer, ">>()?")?;
            // Only emit the conversion pass when some conversion is needed
            if map_k != "k" || map_v != "v" {
                writeln!(
                    writer,
                    "{}.into_iter().map(|(k,v)| ({}, {})).collect()",
                    Indent(indent + 3),
                    map_k,
                    map_v,
                )?;
            }
            write!(writer, "{}", Indent(indent + 1))?;
        }
        _ => {
            write!(writer, "map.next_value()?",)?;
        }
    };
    if one_of.is_some() {
        write!(writer, ")")?;
    }
    writeln!(writer, ");")?;
    writeln!(writer, "{}}}", Indent(indent))
}

113
pbjson_build/src/lib.rs Normal file
View File

@ -0,0 +1,113 @@
#![deny(rustdoc::broken_intra_doc_links, rustdoc::bare_urls, rust_2018_idioms)]
#![warn(
missing_debug_implementations,
clippy::explicit_iter_loop,
clippy::use_self,
clippy::clone_on_ref_ptr,
clippy::future_not_send
)]
use crate::descriptor::{Descriptor, DescriptorSet, Package};
use crate::generator::{generate_enum, generate_message, Config};
use crate::message::resolve_message;
use std::io::{BufWriter, Error, ErrorKind, Result, Write};
use std::path::PathBuf;
mod descriptor;
mod escape;
mod generator;
mod message;
/// A code generator that registers encoded protobuf descriptors and writes
/// serde `Serialize`/`Deserialize` implementations for the matching types
#[derive(Debug, Default)]
pub struct Builder {
    /// The registered protobuf descriptors to generate code for
    descriptors: descriptor::DescriptorSet,
    /// Output directory override; falls back to `OUT_DIR` when `None`
    out_dir: Option<PathBuf>,
}
impl Builder {
    /// Create a new `Builder`
    pub fn new() -> Self {
        Self {
            descriptors: DescriptorSet::new(),
            out_dir: None,
        }
    }

    /// Register an encoded `FileDescriptorSet` with this `Builder`
    pub fn register_descriptors(&mut self, descriptors: &[u8]) -> Result<&mut Self> {
        self.descriptors.register_encoded(descriptors)?;
        Ok(self)
    }

    /// Configure the output directory where generated files will be written.
    ///
    /// If unset, defaults to the `OUT_DIR` environment variable, which is
    /// set by cargo when executing build scripts. (Without this setter the
    /// `out_dir` field could never be populated.)
    pub fn out_dir<P: Into<PathBuf>>(&mut self, path: P) -> &mut Self {
        self.out_dir = Some(path.into());
        self
    }

    /// Generates code for all registered types where `prefixes` contains a prefix of
    /// the fully-qualified path of the type
    pub fn build<S: AsRef<str>>(&mut self, prefixes: &[S]) -> Result<()> {
        // Resolve the output directory, falling back to cargo's OUT_DIR
        let mut output: PathBuf = self.out_dir.clone().map(Ok).unwrap_or_else(|| {
            std::env::var_os("OUT_DIR")
                .ok_or_else(|| {
                    Error::new(ErrorKind::Other, "OUT_DIR environment variable is not set")
                })
                .map(Into::into)
        })?;
        // Placeholder file name, replaced per-package by `set_file_name` below
        output.push("FILENAME");

        let write_factory = move |package: &Package| {
            output.set_file_name(format!("{}.serde.rs", package));
            let file = std::fs::OpenOptions::new()
                .write(true)
                .truncate(true)
                .create(true)
                .open(&output)?;
            Ok(BufWriter::new(file))
        };

        let writers = generate(&self.descriptors, prefixes, write_factory)?;
        // Flush explicitly so write errors surface instead of being swallowed on drop
        for (_, mut writer) in writers {
            writer.flush()?;
        }
        Ok(())
    }
}
/// Generate serde implementations for every registered type matching one of
/// `prefixes`, obtaining a writer for each package via `write_factory`.
///
/// Returns the per-package writers so the caller can flush / close them.
fn generate<S: AsRef<str>, W: Write, F: FnMut(&Package) -> Result<W>>(
    descriptors: &DescriptorSet,
    prefixes: &[S],
    mut write_factory: F,
) -> Result<Vec<(Package, W)>> {
    let config = Config {
        extern_types: Default::default(),
    };
    // Only consider types whose fully-qualified path matches a requested prefix
    let iter = descriptors.iter().filter(move |(t, _)| {
        prefixes
            .iter()
            .any(|prefix| t.matches_prefix(prefix.as_ref()))
    });
    // Exploit the fact descriptors is ordered to group together types from the same package
    let mut ret: Vec<(Package, W)> = Vec::new();
    for (type_path, descriptor) in iter {
        // Reuse the most recent writer if it belongs to the same package,
        // otherwise open a fresh one via `write_factory`
        let writer = match ret.last_mut() {
            Some((package, writer)) if package == type_path.package() => writer,
            _ => {
                let package = type_path.package();
                ret.push((package.clone(), write_factory(package)?));
                &mut ret.last_mut().unwrap().1
            }
        };
        match descriptor {
            Descriptor::Enum(descriptor) => generate_enum(&config, type_path, descriptor, writer)?,
            Descriptor::Message(descriptor) => {
                // `resolve_message` returns None for synthesised map entries
                if let Some(message) = resolve_message(descriptors, descriptor) {
                    generate_message(&config, &message, writer)?
                }
            }
        }
    }
    Ok(ret)
}

275
pbjson_build/src/message.rs Normal file
View File

@ -0,0 +1,275 @@
//! The raw descriptor format is not very easy to work with, a fact not aided
//! by prost making almost all members of proto2 syntax message optional
//!
//! This module therefore extracts a slightly less obtuse representation of a
//! message that can be used by the code generation logic
use prost_types::{
field_descriptor_proto::{Label, Type},
FieldDescriptorProto,
};
use crate::descriptor::{Descriptor, DescriptorSet, MessageDescriptor, Syntax, TypeName, TypePath};
use crate::escape::escape_ident;
/// The protobuf scalar (non-message, non-enum) types supported by pbjson
#[derive(Debug, Clone)]
pub enum ScalarType {
    F64,
    F32,
    I32,
    I64,
    U32,
    U64,
    Bool,
    String,
    Bytes,
}
impl ScalarType {
pub fn rust_type(&self) -> &'static str {
match self {
ScalarType::F64 => "f64",
ScalarType::F32 => "f32",
ScalarType::I32 => "i32",
ScalarType::I64 => "i64",
ScalarType::U32 => "u32",
ScalarType::U64 => "u64",
ScalarType::Bool => "bool",
ScalarType::String => "String",
ScalarType::Bytes => "Vec<u8>",
}
}
pub fn is_numeric(&self) -> bool {
matches!(
self,
ScalarType::F64
| ScalarType::F32
| ScalarType::I32
| ScalarType::I64
| ScalarType::U32
| ScalarType::U64
)
}
}
/// The resolved type of a protobuf field
#[derive(Debug, Clone)]
pub enum FieldType {
    /// A scalar such as `i32` or `String`
    Scalar(ScalarType),
    /// An enumeration, identified by its fully-qualified path
    Enum(TypePath),
    /// A message, identified by its fully-qualified path
    Message(TypePath),
    /// A protobuf map; keys must be scalar, and values may not themselves be maps
    Map(ScalarType, Box<FieldType>),
}
/// How a field's presence / repetition is modelled in the generated rust
#[derive(Debug, Clone)]
pub enum FieldModifier {
    /// proto2 `required`: must be present when deserializing
    Required,
    /// Wrapped in `Option`, distinguishing unset from the default value
    Optional,
    /// proto3 singular scalar/enum: unset is indistinguishable from default
    UseDefault,
    /// `repeated` fields, including maps
    Repeated,
}
impl FieldModifier {
    /// Returns true if this is `FieldModifier::Required` (proto2 `required`)
    pub fn is_required(&self) -> bool {
        matches!(self, Self::Required)
    }
}
/// A resolved message field, less obtuse than the raw `FieldDescriptorProto`
#[derive(Debug, Clone)]
pub struct Field {
    /// The field name as declared in the proto file
    pub name: String,
    /// The explicit `json_name` from the descriptor, if any
    pub json_name: Option<String>,
    /// How presence / repetition of this field is modelled
    pub field_modifier: FieldModifier,
    /// The field's resolved type
    pub field_type: FieldType,
}
impl Field {
    /// The CamelCase identifier used for this field's generated enum variant
    pub fn rust_type_name(&self) -> String {
        use heck::CamelCase;
        self.name.as_str().to_camel_case()
    }

    /// The snake_case rust struct member name, escaped via `escape_ident`
    pub fn rust_field_name(&self) -> String {
        use heck::SnakeCase;
        let snake = self.name.to_snake_case();
        escape_ident(snake)
    }

    /// The JSON field name: the descriptor's explicit `json_name` when
    /// present, otherwise the mixedCase form of the proto name
    pub fn json_name(&self) -> String {
        use heck::MixedCase;
        match &self.json_name {
            Some(name) => name.clone(),
            None => self.name.to_mixed_case(),
        }
    }
}
/// A resolved `oneof` block and the fields it contains
#[derive(Debug, Clone)]
pub struct OneOf {
    /// The one-of name as declared in the proto file
    pub name: String,
    /// The fully-qualified path of the one-of (child of the message path)
    pub path: TypePath,
    /// The fields belonging to this one-of
    pub fields: Vec<Field>,
}
impl OneOf {
    /// The snake_case rust struct member name for this one-of, escaped via
    /// `escape_ident` (presumably to avoid rust keyword collisions)
    pub fn rust_field_name(&self) -> String {
        use heck::SnakeCase;
        escape_ident(self.name.to_snake_case())
    }
}
/// A resolved protobuf message: its path, plain fields and one-of blocks
#[derive(Debug, Clone)]
pub struct Message {
    /// The fully-qualified path of this message
    pub path: TypePath,
    /// Fields not belonging to any (non-synthetic) one-of
    pub fields: Vec<Field>,
    /// The message's one-of blocks with their member fields
    pub one_ofs: Vec<OneOf>,
}
impl Message {
    /// Iterates over every field of this message: the plain fields first,
    /// followed by the fields nested within each one-of block
    pub fn all_fields(&self) -> impl Iterator<Item = &Field> + '_ {
        let one_of_fields = self.one_ofs.iter().flat_map(|one_of| &one_of.fields);
        self.fields.iter().chain(one_of_fields)
    }
}
/// Resolve the provided message descriptor into a slightly less obtuse representation
///
/// Returns None if the provided message is auto-generated (i.e. a map entry)
pub fn resolve_message(
    descriptors: &DescriptorSet,
    message: &MessageDescriptor,
) -> Option<Message> {
    // Map entries are synthesised messages handled via `FieldType::Map`
    if message.is_map() {
        return None;
    }
    let mut fields = Vec::new();
    // One bucket of member fields per declared one-of on the message
    let mut one_of_fields = vec![Vec::new(); message.one_of.len()];
    for field in &message.fields {
        let field_type = field_type(descriptors, field);
        let field_modifier = field_modifier(message, field, &field_type);
        let resolved = Field {
            name: field.name.clone().expect("expected field to have name"),
            json_name: field.json_name.clone(),
            field_type,
            field_modifier,
        };
        // Treat synthetic one-of as normal
        let proto3_optional = field.proto3_optional.unwrap_or(false);
        match (field.oneof_index, proto3_optional) {
            (Some(idx), false) => one_of_fields[idx as usize].push(resolved),
            _ => fields.push(resolved),
        }
    }
    let mut one_ofs = Vec::new();
    for (fields, descriptor) in one_of_fields.into_iter().zip(&message.one_of) {
        // Might be empty in the event of a synthetic one-of
        if !fields.is_empty() {
            let name = descriptor.name.clone().expect("oneof with no name");
            let path = message.path.child(TypeName::new(&name));
            one_ofs.push(OneOf { name, path, fields })
        }
    }
    Some(Message {
        path: message.path.clone(),
        fields,
        one_ofs,
    })
}
/// Determine the `FieldModifier` for `field`, i.e. how its presence and
/// repetition should be modelled, based on label, syntax and field type
fn field_modifier(
    message: &MessageDescriptor,
    field: &FieldDescriptorProto,
    field_type: &FieldType,
) -> FieldModifier {
    let label = Label::from_i32(field.label.expect("expected label")).expect("valid label");
    // proto3 explicit `optional` fields arrive as synthetic one-ofs, but are
    // modelled as plain optional fields
    if field.proto3_optional.unwrap_or(false) {
        assert_eq!(label, Label::Optional);
        return FieldModifier::Optional;
    }
    // Members of a (real) one-of are always optional
    if field.oneof_index.is_some() {
        assert_eq!(label, Label::Optional);
        return FieldModifier::Optional;
    }
    // Maps are encoded on the wire as repeated key/value entry messages
    if matches!(field_type, FieldType::Map(_, _)) {
        assert_eq!(label, Label::Repeated);
        return FieldModifier::Repeated;
    }
    match label {
        Label::Optional => match message.syntax {
            Syntax::Proto2 => FieldModifier::Optional,
            Syntax::Proto3 => match field_type {
                // In proto3, singular message fields keep explicit presence;
                // scalar/enum fields fall back to the type's default
                FieldType::Message(_) => FieldModifier::Optional,
                _ => FieldModifier::UseDefault,
            },
        },
        Label::Required => FieldModifier::Required,
        Label::Repeated => FieldModifier::Repeated,
    }
}
/// Determine the `FieldType` of `field`, resolving named (message/enum)
/// types against the registered descriptors
fn field_type(descriptors: &DescriptorSet, field: &FieldDescriptorProto) -> FieldType {
    match field.type_name.as_ref() {
        // A type name is present for message / enum fields
        Some(type_name) => resolve_type(descriptors, type_name.as_str()),
        None => {
            let scalar =
                match Type::from_i32(field.r#type.expect("expected type")).expect("valid type") {
                    Type::Double => ScalarType::F64,
                    Type::Float => ScalarType::F32,
                    // The signed / zigzag / fixed-width wire variants all
                    // share a single rust representation
                    Type::Int64 | Type::Sfixed64 | Type::Sint64 => ScalarType::I64,
                    Type::Int32 | Type::Sfixed32 | Type::Sint32 => ScalarType::I32,
                    Type::Uint64 | Type::Fixed64 => ScalarType::U64,
                    Type::Uint32 | Type::Fixed32 => ScalarType::U32,
                    Type::Bool => ScalarType::Bool,
                    Type::String => ScalarType::String,
                    Type::Bytes => ScalarType::Bytes,
                    Type::Message | Type::Enum | Type::Group => panic!("no type name specified"),
                };
            FieldType::Scalar(scalar)
        }
    }
}
/// Resolve a fully-qualified type name (e.g. `.test.Foo`) to a `FieldType`,
/// unwrapping synthesised map-entry messages into `FieldType::Map`
///
/// Panics if the name is relative or cannot be found.
fn resolve_type(descriptors: &DescriptorSet, type_name: &str) -> FieldType {
    assert!(
        type_name.starts_with('.'),
        "pbjson does not currently support resolving relative types"
    );
    // NOTE(review): uses `matches_prefix` with the full type name — assumes
    // this yields the exact type and not merely a path sharing the prefix;
    // verify `TypePath::matches_prefix` semantics
    let maybe_descriptor = descriptors
        .iter()
        .find(|(path, _)| path.matches_prefix(type_name));
    match maybe_descriptor {
        Some((path, Descriptor::Enum(_))) => FieldType::Enum(path.clone()),
        Some((path, Descriptor::Message(descriptor))) => match descriptor.is_map() {
            true => {
                // Map entries are synthesised messages with exactly two
                // fields named "key" and "value"
                assert_eq!(descriptor.fields.len(), 2, "expected map to have 2 fields");
                let key = &descriptor.fields[0];
                let value = &descriptor.fields[1];
                assert_eq!("key", key.name());
                assert_eq!("value", value.name());
                let key_type = match field_type(descriptors, key) {
                    FieldType::Scalar(scalar) => scalar,
                    _ => panic!("non scalar map key"),
                };
                let value_type = field_type(descriptors, value);
                FieldType::Map(key_type, Box::new(value_type))
            }
            // Note: This may actually be a group but it is non-trivial to detect this,
            // they're deprecated, and pbjson doesn't need to be able to distinguish
            false => FieldType::Message(path.clone()),
        },
        None => panic!("failed to resolve type: {}", type_name),
    }
}

18
pbjson_test/Cargo.toml Normal file
View File

@ -0,0 +1,18 @@
[package]
name = "pbjson_test"
version = "0.1.0"
authors = ["Raphael Taylor-Davies <r.taylordavies@googlemail.com>"]
edition = "2018"
description = "Test resources for pbjson conversion"
[dependencies]
prost = "0.8"
pbjson = { path = "../pbjson" }
serde = { version = "1.0", features = ["derive"] }
[dev-dependencies]
serde_json = "1.0"
[build-dependencies]
prost-build = "0.8"
pbjson_build = { path = "../pbjson_build" }

33
pbjson_test/build.rs Normal file
View File

@ -0,0 +1,33 @@
//! Compiles Protocol Buffers definitions into native Rust types
use std::env;
use std::path::PathBuf;
// Convenience error / result aliases for this build script
type Error = Box<dyn std::error::Error>;
type Result<T, E = Error> = std::result::Result<T, E>;

/// Compile the test protos with prost, then generate matching serde
/// implementations with pbjson_build from the emitted descriptor set
fn main() -> Result<()> {
    let root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("protos");
    let proto_files = vec![root.join("syntax3.proto")];
    // Tell cargo to recompile if any of these proto files are changed
    for proto_file in &proto_files {
        println!("cargo:rerun-if-changed={}", proto_file.display());
    }
    // prost writes a FileDescriptorSet here, which pbjson_build consumes below
    let descriptor_path = PathBuf::from(env::var("OUT_DIR").unwrap()).join("proto_descriptor.bin");
    prost_build::Config::new()
        .file_descriptor_set_path(&descriptor_path)
        .compile_well_known_types()
        .disable_comments(&["."])
        .bytes(&[".test"])
        .compile_protos(&proto_files, &[root])?;
    // Generate serde impls for everything under the `.test` package
    let descriptor_set = std::fs::read(descriptor_path)?;
    pbjson_build::Builder::new()
        .register_descriptors(&descriptor_set)?
        .build(&[".test"])?;
    Ok(())
}

View File

@ -0,0 +1,70 @@
syntax = "proto3";
package test.syntax3;
message Empty {}
// A message exercising every supported field shape: scalars, optionals,
// repeated fields, enums, maps, one-ofs and strings
message KitchenSink {
  // Standard enum
  enum Value {
    VALUE_UNKNOWN = 0;
    VALUE_A = 45;
    VALUE_B = 63;
  }

  // An enumeration without prefixed variants
  enum Prefix {
    UNKNOWN = 0;
    A = 66;
    B = 20;
  }

  int32 i32 = 1;
  optional int32 optional_i32 = 2;
  repeated int32 repeated_i32 = 3;

  uint32 u32 = 4;
  optional uint32 optional_u32 = 5;
  repeated uint32 repeated_u32 = 6;

  int64 i64 = 7;
  optional int64 optional_i64 = 8;
  repeated int64 repeated_i64 = 9;

  uint64 u64 = 10;
  optional uint64 optional_u64 = 11;
  repeated uint64 repeated_u64 = 12;

  Value value = 13;
  optional Value optional_value = 14;
  repeated Value repeated_value = 15;

  Prefix prefix = 16;
  Empty empty = 17;

  // Maps covering string / integer keys and scalar / message / enum values
  map<string, string> string_dict = 18;
  map<string, Empty> message_dict = 19;
  map<string, Prefix> enum_dict = 20;
  map<int64, Prefix> int64_dict = 21;
  map<int32, Prefix> int32_dict = 22;
  map<int32, uint64> integer_dict = 23;

  bool bool = 24;
  optional bool optional_bool = 25;
  repeated bool repeated_bool = 26;

  oneof one_of {
    int32 one_of_i32 = 27;
    bool one_of_bool = 28;
    Value one_of_value = 29;
    Empty one_of_message = 30;
  }

  // Bytes support is currently broken
  // bytes bytes = 31;
  // optional bytes optional_bytes = 32;
  // map<string, bytes> bytes_dict = 35;

  string string = 33;
  optional string optional_string = 34;
}

224
pbjson_test/src/lib.rs Normal file
View File

@ -0,0 +1,224 @@
include!(concat!(env!("OUT_DIR"), "/test.syntax3.rs"));
include!(concat!(env!("OUT_DIR"), "/test.syntax3.serde.rs"));
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_empty() {
    // Round-trip a message with no fields
    let message = Empty {};
    let encoded = serde_json::to_string(&message).unwrap();
    let _decoded: Empty = serde_json::from_str(&encoded).unwrap();

    // Deserializing from a non-object must fail with a type error
    let err = serde_json::from_str::<Empty>("343").unwrap_err();
    assert_eq!(
        err.to_string().as_str(),
        "invalid type: integer `343`, expected struct test.syntax3.Empty at line 1 column 3"
    );

    // Unknown fields are rejected
    let err = serde_json::from_str::<Empty>("{\"foo\": \"bar\"}").unwrap_err();
    assert_eq!(
        err.to_string().as_str(),
        "unknown field `foo`, there are no fields at line 1 column 6"
    );
}
#[test]
fn test_kitchen_sink() {
    // Start from an all-default message and progressively set fields,
    // checking the JSON encoding and decoding after each change
    let mut decoded: KitchenSink = serde_json::from_str("{}").unwrap();
    let verify_encode = |decoded: &KitchenSink, expected: &str| {
        assert_eq!(serde_json::to_string(&decoded).unwrap().as_str(), expected);
    };
    let verify_decode = |decoded: &KitchenSink, expected: &str| {
        assert_eq!(decoded, &serde_json::from_str(expected).unwrap());
    };
    let verify = |decoded: &KitchenSink, expected: &str| {
        verify_encode(decoded, expected);
        verify_decode(decoded, expected);
    };

    verify(&decoded, "{}");

    // Default-valued singular fields are omitted from the output
    decoded.i32 = 24;
    verify(&decoded, r#"{"i32":24}"#);
    decoded.i32 = 0;
    verify_decode(&decoded, "{}");

    // Explicit optional fields can distinguish between no value and default value
    decoded.optional_i32 = Some(2);
    verify(&decoded, r#"{"optionalI32":2}"#);
    decoded.optional_i32 = Some(0);
    verify(&decoded, r#"{"optionalI32":0}"#);
    // Can also decode from string
    verify_decode(&decoded, r#"{"optionalI32":"0"}"#);
    decoded.optional_i32 = None;
    verify_decode(&decoded, "{}");

    // 64-bit integers are encoded as strings
    decoded.i64 = 123125;
    verify(&decoded, r#"{"i64":"123125"}"#);
    decoded.i64 = 0;
    verify_decode(&decoded, "{}");

    decoded.optional_i64 = Some(532);
    verify(&decoded, r#"{"optionalI64":"532"}"#);
    decoded.optional_i64 = Some(0);
    verify(&decoded, r#"{"optionalI64":"0"}"#);
    // Can also decode from non-string
    verify_decode(&decoded, r#"{"optionalI64":0}"#);
    decoded.optional_i64 = None;
    verify_decode(&decoded, "{}");

    decoded.u64 = 34346;
    decoded.u32 = 567094456;
    decoded.optional_u32 = Some(0);
    decoded.optional_u64 = Some(3);
    verify(
        &decoded,
        r#"{"u32":567094456,"optionalU32":0,"u64":"34346","optionalU64":"3"}"#,
    );
    decoded.u64 = 0;
    decoded.u32 = 0;
    decoded.optional_u32 = None;
    decoded.optional_u64 = None;
    verify_decode(&decoded, "{}");

    decoded.repeated_i32 = vec![0, 23, 5, 6, 2, 34];
    verify(&decoded, r#"{"repeatedI32":[0,23,5,6,2,34]}"#);
    // Can also mix in some strings
    verify_decode(&decoded, r#"{"repeatedI32":[0,"23",5,6,"2",34]}"#);
    decoded.repeated_i32 = vec![];
    verify_decode(&decoded, "{}");

    decoded.repeated_u64 = vec![0, 532, 2];
    verify(&decoded, r#"{"repeatedU64":["0","532","2"]}"#);
    // Can also mix in some non-strings
    verify_decode(&decoded, r#"{"repeatedU64":["0",532,"2"]}"#);
    decoded.repeated_u64 = vec![];
    verify_decode(&decoded, "{}");

    // Enumerations should be encoded as strings
    decoded.value = kitchen_sink::Value::A as i32;
    verify(&decoded, r#"{"value":"VALUE_A"}"#);
    // Can also use variant number
    verify_decode(&decoded, r#"{"value":45}"#);
    decoded.value = kitchen_sink::Value::Unknown as i32;
    verify_decode(&decoded, "{}");

    decoded.optional_value = Some(kitchen_sink::Value::Unknown as i32);
    verify(&decoded, r#"{"optionalValue":"VALUE_UNKNOWN"}"#);
    // Can also use variant number
    verify_decode(&decoded, r#"{"optionalValue":0}"#);
    decoded.optional_value = None;
    verify_decode(&decoded, "{}");

    decoded
        .string_dict
        .insert("foo".to_string(), "bar".to_string());
    verify(&decoded, r#"{"stringDict":{"foo":"bar"}}"#);
    decoded.string_dict = Default::default();
    verify_decode(&decoded, "{}");

    decoded
        .int32_dict
        .insert(343, kitchen_sink::Prefix::A as i32);
    // Dictionary keys should always be strings
    // Enum dictionary values should be encoded as strings
    verify(&decoded, r#"{"int32Dict":{"343":"A"}}"#);
    // Enum dictionary values can be decoded from integers
    verify_decode(&decoded, r#"{"int32Dict":{"343":66}}"#);
    decoded.int32_dict = Default::default();
    verify_decode(&decoded, "{}");

    // 64-bit dictionary values should be encoded as strings
    decoded.integer_dict.insert(12, 13);
    verify(&decoded, r#"{"integerDict":{"12":"13"}}"#);
    // 64-bit dictionary values can be decoded from numeric types
    verify_decode(&decoded, r#"{"integerDict":{"12":13}}"#);
    decoded.integer_dict = Default::default();
    verify_decode(&decoded, "{}");

    decoded.one_of = Some(kitchen_sink::OneOf::OneOfI32(0));
    verify(&decoded, r#"{"oneOfI32":0}"#);
    // Can also specify string
    verify_decode(&decoded, r#"{"oneOfI32":"0"}"#);
    decoded.one_of = Some(kitchen_sink::OneOf::OneOfI32(12));
    verify(&decoded, r#"{"oneOfI32":12}"#);
    decoded.one_of = Some(kitchen_sink::OneOf::OneOfBool(false));
    verify(&decoded, r#"{"oneOfBool":false}"#);
    decoded.one_of = Some(kitchen_sink::OneOf::OneOfBool(true));
    verify(&decoded, r#"{"oneOfBool":true}"#);
    decoded.one_of = Some(kitchen_sink::OneOf::OneOfValue(
        kitchen_sink::Value::B as i32,
    ));
    verify(&decoded, r#"{"oneOfValue":"VALUE_B"}"#);
    // Can also specify enum variant
    verify_decode(&decoded, r#"{"oneOfValue":63}"#);
    decoded.one_of = None;
    verify_decode(&decoded, "{}");

    decoded.repeated_value = vec![
        kitchen_sink::Value::B as i32,
        kitchen_sink::Value::B as i32,
        kitchen_sink::Value::A as i32,
    ];
    verify(
        &decoded,
        r#"{"repeatedValue":["VALUE_B","VALUE_B","VALUE_A"]}"#,
    );
    verify_decode(&decoded, r#"{"repeatedValue":[63,"VALUE_B","VALUE_A"]}"#);
    decoded.repeated_value = Default::default();
    verify_decode(&decoded, "{}");

    // Bytes support currently broken
    // decoded.bytes = prost::bytes::Bytes::from_static(b"kjkjkj");
    // verify(&decoded, r#"{"bytes":"a2pramtqCg=="}"#);
    //
    // decoded.repeated_value = Default::default();
    // verify_decode(&decoded, "{}");
    //
    // decoded.bytes_dict.insert(
    //     "test".to_string(),
    //     prost::bytes::Bytes::from_static(b"asdf"),
    // );
    // verify(&decoded, r#"{"bytesDict":{"test":"YXNkZgo="}}"#);
    //
    // decoded.bytes_dict = Default::default();
    // verify_decode(&decoded, "{}");

    decoded.string = "test".to_string();
    verify(&decoded, r#"{"string":"test"}"#);
    decoded.string = Default::default();
    verify_decode(&decoded, "{}");

    decoded.optional_string = Some(String::new());
    verify(&decoded, r#"{"optionalString":""}"#);
}
}

View File

@ -112,7 +112,7 @@ async fn test_create_database() {
.and(predicate::str::contains(format!(r#""name": "{}"#, db)))
// validate the defaults have been set reasonably
.and(predicate::str::contains("%Y-%m-%d %H:00:00"))
.and(predicate::str::contains(r#""bufferSizeHard": 104857600"#))
.and(predicate::str::contains(r#""bufferSizeHard": "104857600""#))
.and(predicate::str::contains("lifecycleRules")),
);
}
@ -147,7 +147,7 @@ async fn test_create_database_size() {
.assert()
.success()
.stdout(
predicate::str::contains(r#""bufferSizeHard": 1000"#)
predicate::str::contains(r#""bufferSizeHard": "1000""#)
.and(predicate::str::contains("lifecycleRules")),
);
}