feat: add pbjson bytes support (#2560)

* feat: add pbjson bytes support

* chore: fix lint

* chore: review feedback

Co-authored-by: kodiakhq[bot] <49736102+kodiakhq[bot]@users.noreply.github.com>
pull/24376/head
Raphael Taylor-Davies 2021-09-16 17:46:12 +01:00 committed by GitHub
parent 315cbb8105
commit f34eab70b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 188 additions and 69 deletions

2
Cargo.lock generated
View File

@ -2902,6 +2902,8 @@ checksum = "acbf547ad0c65e31259204bd90935776d1c693cec2f4ff7abb7a1bbbd40dfe58"
name = "pbjson"
version = "0.1.0"
dependencies = [
"base64 0.13.0",
"bytes",
"serde",
]

View File

@ -8,5 +8,7 @@ description = "Utilities for pbjson converion"
[dependencies]
serde = { version = "1.0", features = ["derive"] }
base64 = "0.13"
[dev-dependencies]
bytes = "1.0"

View File

@ -9,12 +9,17 @@
#[doc(hidden)]
pub mod private {
/// Re-export base64
pub use base64;
use serde::Deserialize;
use std::str::FromStr;
/// Used to parse a number from either a string or its raw representation
#[derive(Debug, Copy, Clone, PartialOrd, PartialEq, Hash, Ord, Eq)]
pub struct NumberDeserialize<T>(pub T);
#[derive(serde::Deserialize)]
#[derive(Deserialize)]
#[serde(untagged)]
enum Content<'a, T> {
Str(&'a str),
@ -26,7 +31,6 @@ pub mod private {
T: FromStr + serde::Deserialize<'de>,
<T as FromStr>::Err: std::error::Error,
{
#[allow(deprecated)]
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
@ -38,4 +42,41 @@ pub mod private {
}))
}
}
#[derive(Debug, Copy, Clone, PartialOrd, PartialEq, Hash, Ord, Eq)]
pub struct BytesDeserialize<T>(pub T);
impl<'de, T> Deserialize<'de> for BytesDeserialize<T>
where
T: From<Vec<u8>>,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s: &str = Deserialize::deserialize(deserializer)?;
let decoded = base64::decode(s).map_err(serde::de::Error::custom)?;
Ok(Self(decoded.into()))
}
}
#[cfg(test)]
mod tests {
use super::*;
use bytes::Bytes;
use serde::de::value::{BorrowedStrDeserializer, Error};
#[test]
fn test_bytes() {
let raw = vec![2, 5, 62, 2, 5, 7, 8, 43, 5, 8, 4, 23, 5, 7, 7, 3, 2, 5, 196];
let encoded = base64::encode(&raw);
let deserializer = BorrowedStrDeserializer::<'_, Error>::new(&encoded);
let a: Bytes = BytesDeserialize::deserialize(deserializer).unwrap().0;
let b: Vec<u8> = BytesDeserialize::deserialize(deserializer).unwrap().0;
assert_eq!(raw.as_slice(), &a);
assert_eq!(raw.as_slice(), &b);
}
}
}

View File

@ -198,28 +198,14 @@ fn write_serialize_variable<W: Write>(
writer: &mut W,
) -> Result<()> {
match &field.field_type {
FieldType::Scalar(ScalarType::I64) | FieldType::Scalar(ScalarType::U64) => {
match field.field_modifier {
FieldModifier::Repeated => {
writeln!(
writer,
"{}struct_ser.serialize_field(\"{}\", &{}.iter().map(ToString::to_string).collect::<Vec<_>>())?;",
Indent(indent),
field.json_name(),
variable.raw
)
}
_ => {
writeln!(
writer,
"{}struct_ser.serialize_field(\"{}\", {}.to_string().as_str())?;",
Indent(indent),
field.json_name(),
variable.raw
)
}
}
}
FieldType::Scalar(scalar) => write_serialize_scalar_variable(
indent,
*scalar,
field.field_modifier,
variable,
field.json_name(),
writer,
),
FieldType::Enum(path) => {
write!(writer, "{}let v = ", Indent(indent))?;
match field.field_modifier {
@ -301,6 +287,52 @@ fn write_serialize_variable<W: Write>(
}
}
fn write_serialize_scalar_variable<W: Write>(
indent: usize,
scalar: ScalarType,
field_modifier: FieldModifier,
variable: Variable<'_>,
json_name: String,
writer: &mut W,
) -> Result<()> {
let conversion = match scalar {
ScalarType::I64 | ScalarType::U64 => "ToString::to_string",
ScalarType::Bytes => "pbjson::private::base64::encode",
_ => {
return writeln!(
writer,
"{}struct_ser.serialize_field(\"{}\", {})?;",
Indent(indent),
json_name,
variable.as_ref
)
}
};
match field_modifier {
FieldModifier::Repeated => {
writeln!(
writer,
"{}struct_ser.serialize_field(\"{}\", &{}.iter().map({}).collect::<Vec<_>>())?;",
Indent(indent),
json_name,
variable.raw,
conversion
)
}
_ => {
writeln!(
writer,
"{}struct_ser.serialize_field(\"{}\", {}(&{}).as_str())?;",
Indent(indent),
json_name,
conversion,
variable.raw,
)
}
}
}
fn write_serialize_field<W: Write>(
config: &Config,
indent: usize,
@ -641,33 +673,8 @@ fn write_deserialize_field<W: Write>(
}
match &field.field_type {
FieldType::Scalar(scalar) if scalar.is_numeric() => {
writeln!(writer)?;
match field.field_modifier {
FieldModifier::Repeated => {
writeln!(
writer,
"{}map.next_value::<Vec<::pbjson::private::NumberDeserialize<{}>>>()?",
Indent(indent + 2),
scalar.rust_type()
)?;
writeln!(
writer,
"{}.into_iter().map(|x| x.0).collect()",
Indent(indent + 3)
)?;
}
_ => {
writeln!(
writer,
"{}map.next_value::<::pbjson::private::NumberDeserialize<{}>>()?.0",
Indent(indent + 2),
scalar.rust_type()
)?;
}
}
write!(writer, "{}", Indent(indent + 1))?;
FieldType::Scalar(scalar) => {
write_encode_scalar_field(indent + 1, *scalar, field.field_modifier, writer)?;
}
FieldType::Enum(path) => match field.field_modifier {
FieldModifier::Repeated => {
@ -693,8 +700,12 @@ fn write_deserialize_field<W: Write>(
Indent(indent + 2),
)?;
let map_k = match key.is_numeric() {
true => {
let map_k = match key {
ScalarType::Bytes => {
// https://github.com/tokio-rs/prost/issues/531
panic!("bytes are not currently supported as map keys")
}
_ if key.is_numeric() => {
write!(
writer,
"::pbjson::private::NumberDeserialize<{}>",
@ -702,7 +713,7 @@ fn write_deserialize_field<W: Write>(
)?;
"k.0"
}
false => {
_ => {
write!(writer, "_")?;
"k"
}
@ -717,6 +728,10 @@ fn write_deserialize_field<W: Write>(
)?;
"v.0"
}
FieldType::Scalar(ScalarType::Bytes) => {
// https://github.com/tokio-rs/prost/issues/531
panic!("bytes are not currently supported as map values")
}
FieldType::Enum(path) => {
write!(writer, "{}", config.rust_type(path))?;
"v as i32"
@ -752,3 +767,43 @@ fn write_deserialize_field<W: Write>(
writeln!(writer, ");")?;
writeln!(writer, "{}}}", Indent(indent))
}
fn write_encode_scalar_field<W: Write>(
indent: usize,
scalar: ScalarType,
field_modifier: FieldModifier,
writer: &mut W,
) -> Result<()> {
let deserializer = match scalar {
ScalarType::Bytes => "BytesDeserialize",
_ if scalar.is_numeric() => "NumberDeserialize",
_ => return write!(writer, "map.next_value()?",),
};
writeln!(writer)?;
match field_modifier {
FieldModifier::Repeated => {
writeln!(
writer,
"{}map.next_value::<Vec<::pbjson::private::{}<_>>>()?",
Indent(indent + 1),
deserializer
)?;
writeln!(
writer,
"{}.into_iter().map(|x| x.0).collect()",
Indent(indent + 2)
)?;
}
_ => {
writeln!(
writer,
"{}map.next_value::<::pbjson::private::{}<_>>()?.0",
Indent(indent + 1),
deserializer
)?;
}
}
write!(writer, "{}", Indent(indent))
}

View File

@ -12,7 +12,7 @@ use prost_types::{
use crate::descriptor::{Descriptor, DescriptorSet, MessageDescriptor, Syntax, TypeName, TypePath};
use crate::escape::escape_ident;
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Copy)]
pub enum ScalarType {
F64,
F32,
@ -61,7 +61,7 @@ pub enum FieldType {
Map(ScalarType, Box<FieldType>),
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Copy)]
pub enum FieldModifier {
Required,
Optional,

View File

@ -60,11 +60,13 @@ message KitchenSink {
Empty one_of_message = 30;
}
// Bytes support is currently broken
// bytes bytes = 31;
// optional bytes optional_bytes = 32;
// map<string, bytes> bytes_dict = 35;
bytes bytes = 31;
optional bytes optional_bytes = 32;
repeated bytes repeated_bytes = 33;
string string = 33;
optional string optional_string = 34;
// Bytes support is currently broken - https://github.com/tokio-rs/prost/issues/531
// map<string, bytes> bytes_dict = 34;
string string = 35;
optional string optional_string = 36;
}

View File

@ -196,13 +196,30 @@ mod tests {
decoded.repeated_value = Default::default();
verify_decode(&decoded, "{}");
// Bytes support currently broken
// decoded.bytes = prost::bytes::Bytes::from_static(b"kjkjkj");
// verify(&decoded, r#"{"bytes":"a2pramtqCg=="}"#);
//
// decoded.repeated_value = Default::default();
// verify_decode(&decoded, "{}");
//
decoded.bytes = prost::bytes::Bytes::from_static(b"kjkjkj");
verify(&decoded, r#"{"bytes":"a2pramtq"}"#);
decoded.bytes = Default::default();
verify_decode(&decoded, "{}");
decoded.optional_bytes = Some(prost::bytes::Bytes::from_static(b"kjkjkj"));
verify(&decoded, r#"{"optionalBytes":"a2pramtq"}"#);
decoded.optional_bytes = Some(Default::default());
verify(&decoded, r#"{"optionalBytes":""}"#);
decoded.optional_bytes = None;
verify_decode(&decoded, "{}");
decoded.repeated_bytes = vec![
prost::bytes::Bytes::from_static(b"sdfsd"),
prost::bytes::Bytes::from_static(b"fghfg"),
];
verify(&decoded, r#"{"repeatedBytes":["c2Rmc2Q=","ZmdoZmc="]}"#);
decoded.repeated_bytes = Default::default();
verify_decode(&decoded, "{}");
// decoded.bytes_dict.insert(
// "test".to_string(),
// prost::bytes::Bytes::from_static(b"asdf"),