From c427f848486ecf8d26b3c05a75e64e56887ed4c7 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Fri, 24 Sep 2021 14:20:50 +0100 Subject: [PATCH] feat: correct pbjson serialization of Timestamp and Duration (#2620) Co-authored-by: kodiakhq[bot] <49736102+kodiakhq[bot]@users.noreply.github.com> --- Cargo.lock | 1 + google_types/Cargo.toml | 3 + google_types/build.rs | 1 + google_types/src/duration.rs | 159 ++++++++++++++++++++++++++++++++++ google_types/src/lib.rs | 45 +--------- google_types/src/timestamp.rs | 73 ++++++++++++++++ pbjson_build/src/lib.rs | 86 ++++++++++-------- 7 files changed, 292 insertions(+), 76 deletions(-) create mode 100644 google_types/src/duration.rs create mode 100644 google_types/src/timestamp.rs diff --git a/Cargo.lock b/Cargo.lock index 67cfc55ea5..cfe6333681 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1340,6 +1340,7 @@ dependencies = [ "prost", "prost-build", "serde", + "serde_json", ] [[package]] diff --git a/google_types/Cargo.toml b/google_types/Cargo.toml index 9ef46210ef..bf15dc7134 100644 --- a/google_types/Cargo.toml +++ b/google_types/Cargo.toml @@ -12,6 +12,9 @@ pbjson = { path = "../pbjson" } prost = "0.8" serde = { version = "1.0", features = ["derive"] } +[dev-dependencies] +serde_json = "1.0" + [build-dependencies] # In alphabetical order prost-build = "0.8" pbjson_build = { path = "../pbjson_build" } diff --git a/google_types/build.rs b/google_types/build.rs index 28818c1e0d..696e648171 100644 --- a/google_types/build.rs +++ b/google_types/build.rs @@ -28,6 +28,7 @@ fn main() -> Result<()> { let descriptor_set = std::fs::read(descriptor_path)?; pbjson_build::Builder::new() .register_descriptors(&descriptor_set)? + .exclude([".google.protobuf.Duration", ".google.protobuf.Timestamp"]) .build(&[".google"])?; Ok(()) diff --git a/google_types/src/duration.rs b/google_types/src/duration.rs new file mode 100644 index 0000000000..1506e4e31a --- /dev/null +++ b/google_types/src/duration.rs @@ -0,0 +1,159 @@ +use crate::protobuf::Duration; +use serde::{Deserialize, Serialize, Serializer}; +use std::convert::{TryFrom, TryInto}; + +impl TryFrom for std::time::Duration { + type Error = std::num::TryFromIntError; + + fn try_from(value: Duration) -> Result { + Ok(std::time::Duration::new( + value.seconds.try_into()?, + value.nanos.try_into()?, + )) + } +} + +impl From for Duration { + fn from(value: std::time::Duration) -> Self { + Self { + seconds: value.as_secs() as _, + nanos: value.subsec_nanos() as _, + } + } +} + +impl Serialize for Duration { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + if self.seconds != 0 && self.nanos != 0 && (self.nanos < 0) != (self.seconds < 0) { + return Err(serde::ser::Error::custom("Duration has inconsistent signs")); + } + + let mut s = if self.seconds == 0 { + if self.nanos < 0 { + "-0".to_string() + } else { + "0".to_string() + } + } else { + self.seconds.to_string() + }; + + if self.nanos != 0 { + s.push('.'); + let f = match split_nanos(self.nanos.abs() as u32) { + (millis, 0, 0) => format!("{:03}", millis), + (millis, micros, 0) => format!("{:03}{:03}", millis, micros), + (millis, micros, nanos) => format!("{:03}{:03}{:03}", millis, micros, nanos), + }; + s.push_str(&f); + } + + s.push('s'); + serializer.serialize_str(&s) + } +} + +impl<'de> serde::Deserialize<'de> for Duration { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s: &str = Deserialize::deserialize(deserializer)?; + let s = s + .strip_suffix('s') + .ok_or_else(|| serde::de::Error::custom("missing 's' suffix"))?; + let secs: f64 = s.parse().map_err(serde::de::Error::custom)?; + + if secs < 0. { + let negated = std::time::Duration::from_secs_f64(-secs); + Ok(Self { + seconds: -(negated.as_secs() as i64), + nanos: -(negated.subsec_nanos() as i32), + }) + } else { + Ok(std::time::Duration::from_secs_f64(secs).into()) + } + } +} + +/// Splits nanoseconds into whole milliseconds, microseconds, and nanoseconds +fn split_nanos(mut nanos: u32) -> (u32, u32, u32) { + let millis = nanos / 1_000_000; + nanos -= millis * 1_000_000; + let micros = nanos / 1_000; + nanos -= micros * 1_000; + (millis, micros, nanos) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_duration() { + let verify = |duration: &Duration, expected: &str| { + assert_eq!(serde_json::to_string(duration).unwrap().as_str(), expected); + assert_eq!( + &serde_json::from_str::(expected).unwrap(), + duration + ) + }; + + let duration = Duration { + seconds: 0, + nanos: 0, + }; + verify(&duration, "\"0s\""); + + let duration = Duration { + seconds: 0, + nanos: 123, + }; + verify(&duration, "\"0.000000123s\""); + + let duration = Duration { + seconds: 0, + nanos: 123456, + }; + verify(&duration, "\"0.000123456s\""); + + let duration = Duration { + seconds: 0, + nanos: 123456789, + }; + verify(&duration, "\"0.123456789s\""); + + let duration = Duration { + seconds: 0, + nanos: -67088, + }; + verify(&duration, "\"-0.000067088s\""); + + let duration = Duration { + seconds: 121, + nanos: 3454, + }; + verify(&duration, "\"121.000003454s\""); + + let duration = Duration { + seconds: -90, + nanos: -2456301, + }; + verify(&duration, "\"-90.002456301s\""); + + let duration = Duration { + seconds: -90, + nanos: 234, + }; + serde_json::to_string(&duration).unwrap_err(); + + let duration = Duration { + seconds: 90, + nanos: -234, + }; + serde_json::to_string(&duration).unwrap_err(); + } +} diff --git a/google_types/src/lib.rs b/google_types/src/lib.rs index 078916f103..e805653f6d 100644 --- a/google_types/src/lib.rs +++ b/google_types/src/lib.rs @@ -13,52 +13,13 @@ mod pb { pub mod google { pub mod protobuf { - use chrono::{NaiveDateTime, Utc}; - use std::convert::{TryFrom, TryInto}; - include!(concat!(env!("OUT_DIR"), "/google.protobuf.rs")); include!(concat!(env!("OUT_DIR"), "/google.protobuf.serde.rs")); - - impl TryFrom for std::time::Duration { - type Error = std::num::TryFromIntError; - - fn try_from(value: Duration) -> Result { - Ok(std::time::Duration::new( - value.seconds.try_into()?, - value.nanos.try_into()?, - )) - } - } - - impl From for Duration { - fn from(value: std::time::Duration) -> Self { - Self { - seconds: value.as_secs() as _, - nanos: value.subsec_nanos() as _, - } - } - } - - impl TryFrom for chrono::DateTime { - type Error = std::num::TryFromIntError; - fn try_from(value: Timestamp) -> Result { - let Timestamp { seconds, nanos } = value; - - let dt = NaiveDateTime::from_timestamp(seconds, nanos.try_into()?); - Ok(chrono::DateTime::::from_utc(dt, Utc)) - } - } - - impl From> for Timestamp { - fn from(value: chrono::DateTime) -> Self { - Self { - seconds: value.timestamp(), - nanos: value.timestamp_subsec_nanos() as i32, - } - } - } } } } +mod duration; +mod timestamp; + pub use pb::google::*; diff --git a/google_types/src/timestamp.rs b/google_types/src/timestamp.rs new file mode 100644 index 0000000000..4b2acdb25d --- /dev/null +++ b/google_types/src/timestamp.rs @@ -0,0 +1,73 @@ +use crate::protobuf::Timestamp; +use chrono::{DateTime, NaiveDateTime, Utc}; +use serde::{Deserialize, Serialize}; +use std::convert::{TryFrom, TryInto}; + +impl TryFrom for chrono::DateTime { + type Error = std::num::TryFromIntError; + fn try_from(value: Timestamp) -> Result { + let Timestamp { seconds, nanos } = value; + + let dt = NaiveDateTime::from_timestamp(seconds, nanos.try_into()?); + Ok(DateTime::::from_utc(dt, Utc)) + } +} + +impl From> for Timestamp { + fn from(value: DateTime) -> Self { + Self { + seconds: value.timestamp(), + nanos: value.timestamp_subsec_nanos() as i32, + } + } +} + +impl Serialize for Timestamp { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let t: DateTime = self.clone().try_into().map_err(serde::ser::Error::custom)?; + serializer.serialize_str(t.to_rfc3339().as_str()) + } +} + +impl<'de> serde::Deserialize<'de> for Timestamp { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s: &str = Deserialize::deserialize(deserializer)?; + let d = DateTime::parse_from_rfc3339(s).map_err(serde::de::Error::custom)?; + let d: DateTime = d.into(); + Ok(d.into()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use chrono::{FixedOffset, TimeZone}; + use serde::de::value::{BorrowedStrDeserializer, Error}; + + #[test] + fn test_date() { + let datetime = FixedOffset::east(5 * 3600) + .ymd(2016, 11, 8) + .and_hms(21, 7, 9); + let encoded = datetime.to_rfc3339(); + assert_eq!(&encoded, "2016-11-08T21:07:09+05:00"); + + let utc: DateTime = datetime.into(); + let utc_encoded = utc.to_rfc3339(); + assert_eq!(&utc_encoded, "2016-11-08T16:07:09+00:00"); + + let deserializer = BorrowedStrDeserializer::<'_, Error>::new(&encoded); + let a: Timestamp = Timestamp::deserialize(deserializer).unwrap(); + assert_eq!(a.seconds, utc.timestamp()); + assert_eq!(a.nanos, utc.timestamp_subsec_nanos() as i32); + + let encoded = serde_json::to_string(&a).unwrap(); + assert_eq!(encoded, format!("\"{}\"", utc_encoded)); + } +} diff --git a/pbjson_build/src/lib.rs b/pbjson_build/src/lib.rs index b985473ac4..a5c521fa34 100644 --- a/pbjson_build/src/lib.rs +++ b/pbjson_build/src/lib.rs @@ -21,6 +21,7 @@ mod message; #[derive(Debug, Default)] pub struct Builder { descriptors: descriptor::DescriptorSet, + exclude: Vec, out_dir: Option, } @@ -29,6 +30,7 @@ impl Builder { pub fn new() -> Self { Self { descriptors: DescriptorSet::new(), + exclude: Default::default(), out_dir: None, } } @@ -39,6 +41,15 @@ impl Builder { Ok(self) } + /// Don't generate code for the following type prefixes + pub fn exclude, I: IntoIterator>( + &mut self, + prefixes: I, + ) -> &mut Self { + self.exclude.extend(prefixes.into_iter().map(Into::into)); + self + } + /// Generates code for all registered types where `prefixes` contains a prefix of /// the fully-qualified path of the type pub fn build>(&mut self, prefixes: &[S]) -> Result<()> { @@ -63,51 +74,58 @@ impl Builder { Ok(BufWriter::new(file)) }; - let writers = generate(&self.descriptors, prefixes, write_factory)?; + let writers = self.generate(prefixes, write_factory)?; for (_, mut writer) in writers { writer.flush()?; } Ok(()) } -} -fn generate, W: Write, F: FnMut(&Package) -> Result>( - descriptors: &DescriptorSet, - prefixes: &[S], - mut write_factory: F, -) -> Result> { - let config = Config { - extern_types: Default::default(), - }; - - let iter = descriptors.iter().filter(move |(t, _)| { - prefixes - .iter() - .any(|prefix| t.matches_prefix(prefix.as_ref())) - }); - - // Exploit the fact descriptors is ordered to group together types from the same package - let mut ret: Vec<(Package, W)> = Vec::new(); - for (type_path, descriptor) in iter { - let writer = match ret.last_mut() { - Some((package, writer)) if package == type_path.package() => writer, - _ => { - let package = type_path.package(); - ret.push((package.clone(), write_factory(package)?)); - &mut ret.last_mut().unwrap().1 - } + fn generate, W: Write, F: FnMut(&Package) -> Result>( + &self, + prefixes: &[S], + mut write_factory: F, + ) -> Result> { + let config = Config { + extern_types: Default::default(), }; - match descriptor { - Descriptor::Enum(descriptor) => generate_enum(&config, type_path, descriptor, writer)?, - Descriptor::Message(descriptor) => { - if let Some(message) = resolve_message(descriptors, descriptor) { - generate_message(&config, &message, writer)? + let iter = self.descriptors.iter().filter(move |(t, _)| { + let exclude = self + .exclude + .iter() + .any(|prefix| t.matches_prefix(prefix.as_ref())); + let include = prefixes + .iter() + .any(|prefix| t.matches_prefix(prefix.as_ref())); + include && !exclude + }); + + // Exploit the fact descriptors is ordered to group together types from the same package + let mut ret: Vec<(Package, W)> = Vec::new(); + for (type_path, descriptor) in iter { + let writer = match ret.last_mut() { + Some((package, writer)) if package == type_path.package() => writer, + _ => { + let package = type_path.package(); + ret.push((package.clone(), write_factory(package)?)); + &mut ret.last_mut().unwrap().1 + } + }; + + match descriptor { + Descriptor::Enum(descriptor) => { + generate_enum(&config, type_path, descriptor, writer)? + } + Descriptor::Message(descriptor) => { + if let Some(message) = resolve_message(&self.descriptors, descriptor) { + generate_message(&config, &message, writer)? + } } } } - } - Ok(ret) + Ok(ret) + } }