Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions der/tests/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,29 @@
// TODO: fix needless_question_mark in the derive crate
#![allow(clippy::bool_assert_comparison, clippy::needless_question_mark)]

#[derive(Debug)]
#[allow(dead_code)]
pub struct CustomError(der::Error);

impl From<der::Error> for CustomError {
fn from(value: der::Error) -> Self {
Self(value)
}
}

impl From<std::convert::Infallible> for CustomError {
fn from(_value: std::convert::Infallible) -> Self {
unreachable!()
}
}

/// Custom derive test cases for the `Choice` macro.
mod choice {
use super::CustomError;

/// `Choice` with `EXPLICIT` tagging.
mod explicit {
use super::CustomError;
use der::{
asn1::{GeneralizedTime, UtcTime},
Choice, Decode, Encode, SliceWriter,
Expand Down Expand Up @@ -50,6 +69,13 @@ mod choice {
}
}

#[derive(Choice)]
#[asn1(error = CustomError)]
pub enum WithCustomError {
#[asn1(type = "GeneralizedTime")]
Foo(GeneralizedTime),
}

const UTC_TIMESTAMP_DER: &[u8] = &hex!("17 0d 39 31 30 35 30 36 32 33 34 35 34 30 5a");
const GENERAL_TIMESTAMP_DER: &[u8] =
&hex!("18 0f 31 39 39 31 30 35 30 36 32 33 34 35 34 30 5a");
Expand All @@ -61,6 +87,10 @@ mod choice {

let general_time = Time::from_der(GENERAL_TIMESTAMP_DER).unwrap();
assert_eq!(general_time.to_unix_duration().as_secs(), 673573540);

let WithCustomError::Foo(with_custom_error) =
WithCustomError::from_der(GENERAL_TIMESTAMP_DER).unwrap();
assert_eq!(with_custom_error.to_unix_duration().as_secs(), 673573540);
}

#[test]
Expand Down Expand Up @@ -154,6 +184,7 @@ mod choice {

/// Custom derive test cases for the `Enumerated` macro.
mod enumerated {
use super::CustomError;
use der::{Decode, Encode, Enumerated, SliceWriter};
use hex_literal::hex;

Expand All @@ -176,13 +207,24 @@ mod enumerated {
const UNSPECIFIED_DER: &[u8] = &hex!("0a 01 00");
const KEY_COMPROMISE_DER: &[u8] = &hex!("0a 01 01");

#[derive(Enumerated, Copy, Clone, Eq, PartialEq, Debug)]
#[asn1(error = CustomError)]
#[repr(u32)]
pub enum EnumWithCustomError {
Unspecified = 0,
Specified = 1,
}

#[test]
fn decode() {
let unspecified = CrlReason::from_der(UNSPECIFIED_DER).unwrap();
assert_eq!(CrlReason::Unspecified, unspecified);

let key_compromise = CrlReason::from_der(KEY_COMPROMISE_DER).unwrap();
assert_eq!(CrlReason::KeyCompromise, key_compromise);

let custom_error_enum = EnumWithCustomError::from_der(UNSPECIFIED_DER).unwrap();
assert_eq!(custom_error_enum, EnumWithCustomError::Unspecified);
}

#[test]
Expand All @@ -202,6 +244,7 @@ mod enumerated {
/// Custom derive test cases for the `Sequence` macro.
#[cfg(feature = "oid")]
mod sequence {
use super::CustomError;
use core::marker::PhantomData;
use der::{
asn1::{AnyRef, ObjectIdentifier, SetOf},
Expand Down Expand Up @@ -383,6 +426,12 @@ mod sequence {
pub typed_context_specific_optional: Option<&'a [u8]>,
}

#[derive(Sequence)]
#[asn1(error = CustomError)]
pub struct TypeWithCustomError {
pub simple: bool,
}

#[test]
fn idp_test() {
let idp = IssuingDistributionPointExample::from_der(&hex!("30038101FF")).unwrap();
Expand Down Expand Up @@ -444,6 +493,9 @@ mod sequence {
PRIME256V1_OID,
ObjectIdentifier::try_from(algorithm_identifier.parameters.unwrap()).unwrap()
);

let t = TypeWithCustomError::from_der(&hex!("30030101FF")).unwrap();
assert!(t.simple);
}

#[test]
Expand Down
72 changes: 52 additions & 20 deletions der_derive/src/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,33 @@

use crate::{Asn1Type, Tag, TagMode, TagNumber};
use proc_macro2::{Span, TokenStream};
use quote::quote;
use quote::{quote, ToTokens};
use std::{fmt::Debug, str::FromStr};
use syn::punctuated::Punctuated;
use syn::{parse::Parse, parse::ParseStream, Attribute, Ident, LitStr, Path, Token};

/// Error type used by the structure
#[derive(Debug, Clone, Default, Eq, PartialEq)]
pub(crate) enum ErrorType {
/// Represents the ::der::Error type
#[default]
Der,
/// Represents an error designed by Path
Custom(Path),
}

impl ToTokens for ErrorType {
fn to_tokens(&self, tokens: &mut TokenStream) {
match self {
Self::Der => {
let err = quote! { ::der::Error };
err.to_tokens(tokens)
}
Self::Custom(path) => path.to_tokens(tokens),
}
}
}

/// Attribute name.
pub(crate) const ATTR_NAME: &str = "asn1";

Expand All @@ -18,37 +40,47 @@ pub(crate) struct TypeAttrs {
///
/// The default value is `EXPLICIT`.
pub tag_mode: TagMode,
pub error: ErrorType,
}

impl TypeAttrs {
/// Parse attributes from a struct field or enum variant.
pub fn parse(attrs: &[Attribute]) -> syn::Result<Self> {
let mut tag_mode = None;
let mut error = None;

let mut parsed_attrs = Vec::new();
AttrNameValue::from_attributes(attrs, &mut parsed_attrs)?;

for attr in parsed_attrs {
// `tag_mode = "..."` attribute
let mode = attr.parse_value("tag_mode")?.ok_or_else(|| {
syn::Error::new_spanned(
&attr.name,
"invalid `asn1` attribute (valid options are `tag_mode`)",
)
})?;

if tag_mode.is_some() {
return Err(syn::Error::new_spanned(
&attr.name,
"duplicate ASN.1 `tag_mode` attribute",
));
attrs.iter().try_for_each(|attr| {
if !attr.path().is_ident(ATTR_NAME) {
return Ok(());
}

tag_mode = Some(mode);
}
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("tag_mode") {
if tag_mode.is_some() {
abort!(attr, "duplicate ASN.1 `tag_mode` attribute");
}

tag_mode = Some(meta.value()?.parse()?);
} else if meta.path.is_ident("error") {
if error.is_some() {
abort!(attr, "duplicate ASN.1 `error` attribute");
}

error = Some(ErrorType::Custom(meta.value()?.parse()?));
} else {
return Err(syn::Error::new_spanned(
attr,
"invalid `asn1` attribute (valid options are `tag_mode` and `error`)",
));
}

Ok(())
})
})?;

Ok(Self {
tag_mode: tag_mode.unwrap_or_default(),
error: error.unwrap_or_default(),
})
}
}
Expand Down
27 changes: 18 additions & 9 deletions der_derive/src/choice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
mod variant;

use self::variant::ChoiceVariant;
use crate::{default_lifetime, TypeAttrs};
use crate::{default_lifetime, ErrorType, TypeAttrs};
use proc_macro2::TokenStream;
use quote::quote;
use quote::{quote, ToTokens};
use syn::{DeriveInput, GenericParam, Generics, Ident, LifetimeParam};

/// Derive the `Choice` trait for an enum.
Expand All @@ -20,6 +20,9 @@ pub(crate) struct DeriveChoice {

/// Variants of this `Choice`.
variants: Vec<ChoiceVariant>,

/// Error type for `DecodeValue` implementation.
error: ErrorType,
}

impl DeriveChoice {
Expand All @@ -44,6 +47,7 @@ impl DeriveChoice {
ident: input.ident,
generics: input.generics.clone(),
variants,
error: type_attrs.error.clone(),
})
}

Expand Down Expand Up @@ -84,6 +88,8 @@ impl DeriveChoice {
tagged_body.push(variant.to_tagged_tokens());
}

let error = self.error.to_token_stream();

quote! {
impl #impl_generics ::der::Choice<#lifetime> for #ident #ty_generics #where_clause {
fn can_decode(tag: ::der::Tag) -> bool {
Expand All @@ -92,17 +98,20 @@ impl DeriveChoice {
}

impl #impl_generics ::der::Decode<#lifetime> for #ident #ty_generics #where_clause {
type Error = ::der::Error;
type Error = #error;

fn decode<R: ::der::Reader<#lifetime>>(reader: &mut R) -> ::der::Result<Self> {
fn decode<R: ::der::Reader<#lifetime>>(reader: &mut R) -> ::core::result::Result<Self, #error> {
use der::Reader as _;
match ::der::Tag::peek(reader)? {
#(#decode_body)*
actual => Err(der::ErrorKind::TagUnexpected {
expected: None,
actual
}
.into()),
actual => Err(::der::Error::new(
::der::ErrorKind::TagUnexpected {
expected: None,
actual
},
reader.position()
).into()
),
}
}
}
Expand Down
51 changes: 34 additions & 17 deletions der_derive/src/enumerated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
//! the purposes of decoding/encoding ASN.1 `ENUMERATED` types as mapped to
//! enum variants.

use crate::attributes::AttrNameValue;
use crate::{default_lifetime, ATTR_NAME};
use proc_macro2::TokenStream;
use quote::quote;
use syn::{DeriveInput, Expr, ExprLit, Ident, Lit, LitInt, Variant};
use quote::{quote, ToTokens};
use syn::{DeriveInput, Expr, ExprLit, Ident, Lit, LitInt, LitStr, Path, Variant};

/// Valid options for the `#[repr]` attribute on `Enumerated` types.
const REPR_TYPES: &[&str] = &["u8", "u16", "u32"];
Expand All @@ -24,6 +23,9 @@ pub(crate) struct DeriveEnumerated {

/// Variants of this enum.
variants: Vec<EnumeratedVariant>,

/// Error type for `DecodeValue` implementation.
error: Option<Path>,
}

impl DeriveEnumerated {
Expand All @@ -40,22 +42,30 @@ impl DeriveEnumerated {
// Reject `asn1` attributes, parse the `repr` attribute
let mut repr: Option<Ident> = None;
let mut integer = false;
let mut error: Option<Path> = None;

for attr in &input.attrs {
if attr.path().is_ident(ATTR_NAME) {
let kvs = match AttrNameValue::parse_attribute(attr) {
Ok(kvs) => kvs,
Err(e) => abort!(attr, e),
};
for anv in kvs {
if anv.name.is_ident("type") {
match anv.value.value().as_str() {
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("type") {
let value: LitStr = meta.value()?.parse()?;
match value.value().as_str() {
"ENUMERATED" => integer = false,
"INTEGER" => integer = true,
s => abort!(anv.value, format_args!("`type = \"{s}\"` is unsupported")),
s => abort!(value, format_args!("`type = \"{s}\"` is unsupported")),
}
} else if meta.path.is_ident("error") {
let path: Path = meta.value()?.parse()?;
error = Some(path);
} else {
return Err(syn::Error::new_spanned(
&meta.path,
"invalid `asn1` attribute (valid options are `type` and `error`)",
));
}
}

Ok(())
})?;
} else if attr.path().is_ident("repr") {
if repr.is_some() {
abort!(
Expand Down Expand Up @@ -97,6 +107,7 @@ impl DeriveEnumerated {
})?,
variants,
integer,
error,
})
}

Expand All @@ -115,14 +126,20 @@ impl DeriveEnumerated {
try_from_body.push(variant.to_try_from_tokens());
}

let error = self
.error
.as_ref()
.map(ToTokens::to_token_stream)
.unwrap_or_else(|| quote! { ::der::Error });

quote! {
impl<#default_lifetime> ::der::DecodeValue<#default_lifetime> for #ident {
type Error = ::der::Error;
type Error = #error;

fn decode_value<R: ::der::Reader<#default_lifetime>>(
reader: &mut R,
header: ::der::Header
) -> ::der::Result<Self> {
) -> ::core::result::Result<Self, #error> {
<#repr as ::der::DecodeValue>::decode_value(reader, header)?.try_into()
}
}
Expand All @@ -142,12 +159,12 @@ impl DeriveEnumerated {
}

impl TryFrom<#repr> for #ident {
type Error = ::der::Error;
type Error = #error;

fn try_from(n: #repr) -> ::der::Result<Self> {
fn try_from(n: #repr) -> ::core::result::Result<Self, #error> {
match n {
#(#try_from_body)*
_ => Err(#tag.value_error())
_ => Err(#tag.value_error().into())
}
}
}
Expand Down
Loading