diff --git a/rust/derive/tests/base.rs b/rust/derive/tests/base.rs index cbb40ed..9d1a8c4 100644 --- a/rust/derive/tests/base.rs +++ b/rust/derive/tests/base.rs @@ -194,7 +194,7 @@ fn enum_custom_tags() -> common::Result { Five {}, } - impl StrictSerialize for Assoc {} + impl StrictSerialize<256> for Assoc {} assert_eq!(Assoc::ALL_VARIANTS, &[ (0, "one"), @@ -205,13 +205,13 @@ fn enum_custom_tags() -> common::Result { ]); let assoc = Assoc::Two(0, 1, 2); - assert_eq!(assoc.to_strict_serialized::<256>().unwrap().as_slice(), &[2, 0, 1, 0, 2, 0, 0, 0]); + assert_eq!(assoc.to_strict_vec().unwrap().as_slice(), &[2, 0, 1, 0, 2, 0, 0, 0]); let assoc = Assoc::One { hash: [0u8; 32], ord: 0, }; - assert_eq!(assoc.to_strict_serialized::<256>().unwrap().as_slice(), &[0u8; 34]); + assert_eq!(assoc.to_strict_vec().unwrap().as_slice(), &[0u8; 34]); Ok(()) } diff --git a/rust/derive/tests/type.rs b/rust/derive/tests/type.rs index 0c4ec4a..4cbaf1e 100644 --- a/rust/derive/tests/type.rs +++ b/rust/derive/tests/type.rs @@ -137,7 +137,7 @@ fn skip_field() -> common::Result { must_camelize: 2, wrong_name: 3, }; - assert_eq!(val.to_strict_serialized::<{ usize::MAX }>().unwrap().as_slice(), &[2]); + assert_eq!(val.to_strict_vec().unwrap().as_slice(), &[2]); let val = Struct { must_camelize: 2, wrong_name: 0, diff --git a/rust/src/traits.rs b/rust/src/traits.rs index 8433bda..a1f50c7 100644 --- a/rust/src/traits.rs +++ b/rust/src/traits.rs @@ -381,34 +381,54 @@ impl StrictDecode for PhantomData { fn strict_decode(_reader: &mut impl TypedRead) -> Result { Ok(default!()) } } -// TODO: Provide max length as a trait-level const -pub trait StrictSerialize: StrictEncode { - fn strict_serialized_len(&self) -> io::Result { +pub trait StrictSerialize: StrictEncode { + fn strict_serialize(&self, write: impl io::Write) -> Result<(), io::Error> { + let writer = StreamWriter::new::(write); + self.strict_write(writer) + } + + fn to_strict_vec(&self) -> Result, 0, MAX_SERIALIZED_LEN>, SerializeError> { + let ast_data = StrictWriter::in_memory::(); + let data = self.strict_encode(ast_data)?.unbox().unconfine(); + Confined::, 0, MAX_SERIALIZED_LEN>::try_from(data).map_err(SerializeError::from) + } + + //#[cfg(feature = "std")] + fn strict_serialize_to_path( + &self, + path: impl AsRef, + overwrite: bool, + ) -> Result<(), SerializeError> { + let file = if overwrite { fs::File::create(path)? } else { fs::File::create_new(path)? }; + self.strict_serialize(file)?; + Ok(()) + } + + #[deprecated(since = "3.0.0", note = "use `StrictEncode::strict_len` instead")] + fn strict_serialized_len(&self) -> io::Result + where Self: StrictSerialize { let counter = StrictWriter::counter::(); Ok(self.strict_encode(counter)?.unbox().unconfine().count) } + #[deprecated(since = "3.0.0", note = "use `to_strict_vec` instead")] fn to_strict_serialized( &self, - ) -> Result, 0, MAX>, SerializeError> { - let ast_data = StrictWriter::in_memory::(); - let data = self.strict_encode(ast_data)?.unbox().unconfine(); - Confined::, 0, MAX>::try_from(data).map_err(SerializeError::from) - } - - fn strict_serialize(&self, write: impl io::Write) -> Result<(), io::Error> { - let writer = StreamWriter::new::(write); - self.strict_write(writer) + ) -> Result, 0, MAX>, SerializeError> + where Self: StrictSerialize { + self.to_strict_vec() } + #[deprecated(since = "3.0.0", note = "use `strict_serialize_to_path` instead")] fn strict_serialize_to_file( &self, path: impl AsRef, - ) -> Result<(), SerializeError> { + ) -> Result<(), SerializeError> + where + Self: StrictSerialize, + { let file = fs::File::create(path)?; - // TODO: Do FileWriter - let file = StrictWriter::with(StreamWriter::new::(file)); - self.strict_encode(file)?; + StrictSerialize::::strict_serialize(self, file)?; Ok(()) } }