diff --git a/Cargo-minimal.lock b/Cargo-minimal.lock index acea0a635..6dc6927f9 100644 --- a/Cargo-minimal.lock +++ b/Cargo-minimal.lock @@ -479,6 +479,15 @@ dependencies = [ "url", ] +[[package]] +name = "bhttp" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d305a54bcb99974213b4c78a486c34091e83c5d6d6572f7f4331c904ea9d127" +dependencies = [ + "thiserror 2.0.18", +] + [[package]] name = "bip39" version = "2.2.2" @@ -2761,11 +2770,12 @@ checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3" name = "payjoin" version = "0.25.0" dependencies = [ - "bhttp", + "bhttp 0.7.2", "bitcoin 0.32.8", "bitcoin-hpke", "bitcoin-ohttp", "bitcoin_uri", + "getrandom 0.3.4", "http", "once_cell", "payjoin-test-utils", @@ -2850,7 +2860,7 @@ dependencies = [ "anyhow", "axum", "axum-server", - "bhttp", + "bhttp 0.6.1", "bitcoin 0.32.8", "bitcoin-ohttp", "byteorder", diff --git a/Cargo-recent.lock b/Cargo-recent.lock index 2f6135d72..63f4bb9e2 100644 --- a/Cargo-recent.lock +++ b/Cargo-recent.lock @@ -479,6 +479,15 @@ dependencies = [ "url", ] +[[package]] +name = "bhttp" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d305a54bcb99974213b4c78a486c34091e83c5d6d6572f7f4331c904ea9d127" +dependencies = [ + "thiserror 2.0.18", +] + [[package]] name = "bip39" version = "2.2.0" @@ -1665,21 +1674,21 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi", "wasm-bindgen", ] [[package]] name = "getrandom" -version = "0.3.3" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", "js-sys", "libc", "r-efi", - "wasi 0.14.2+wasi-0.2.4", + "wasip2", "wasm-bindgen", ] @@ -1869,12 +1878,11 @@ dependencies = [ [[package]] name = "http" -version = "1.3.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" dependencies = [ "bytes", - "fnv", "itoa", ] @@ -2203,7 +2211,7 @@ version = "0.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" dependencies = [ - "getrandom 0.3.3", + "getrandom 0.3.4", "libc", ] @@ -2430,7 +2438,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c" dependencies = [ "libc", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi", "windows-sys 0.59.0", ] @@ -2729,11 +2737,12 @@ checksum = "8835116a5c179084a830efb3adc117ab007512b535bc1a21c991d3b32a6b44dd" name = "payjoin" version = "0.25.0" dependencies = [ - "bhttp", + "bhttp 0.7.2", "bitcoin 0.32.8", "bitcoin-hpke", "bitcoin-ohttp", "bitcoin_uri", + "getrandom 0.3.4", "http", "once_cell", "payjoin-test-utils", @@ -2818,7 +2827,7 @@ dependencies = [ "anyhow", "axum", "axum-server", - "bhttp", + "bhttp 0.6.1", "bitcoin 0.32.8", "bitcoin-ohttp", "byteorder", @@ -3134,7 +3143,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49df843a9161c85bb8aae55f101bc0bac8bcafd637a620d9122fd7e0b2f7422e" dependencies = [ "bytes", - "getrandom 0.3.3", + "getrandom 0.3.4", "lru-slab", "rand 0.9.1", "ring", @@ -3255,7 +3264,7 @@ version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" dependencies = [ - "getrandom 0.3.3", + "getrandom 0.3.4", ] [[package]] @@ -4100,7 +4109,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1" dependencies = [ "fastrand", - "getrandom 0.3.3", + "getrandom 0.3.4", "once_cell", "rustix 1.0.7", "windows-sys 0.59.0", @@ -4927,7 +4936,7 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2e054861b4bd027cd373e18e8d8d8e6548085000e41290d95ce0c373a654b4a" dependencies = [ - "getrandom 0.3.3", + "getrandom 0.3.4", "js-sys", "rand 0.9.1", "wasm-bindgen", @@ -4967,12 +4976,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] -name = "wasi" -version = "0.14.2+wasi-0.2.4" +name = "wasip2" +version = "1.0.2+wasi-0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" dependencies = [ - "wit-bindgen-rt", + "wit-bindgen", ] [[package]] @@ -5397,13 +5406,10 @@ dependencies = [ ] [[package]] -name = "wit-bindgen-rt" -version = "0.39.0" +name = "wit-bindgen" +version = "0.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" -dependencies = [ - "bitflags 2.6.0", -] +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" [[package]] name = "writeable" diff --git a/payjoin-ffi/src/receive/mod.rs b/payjoin-ffi/src/receive/mod.rs index 3625362ee..29b260a06 100644 --- a/payjoin-ffi/src/receive/mod.rs +++ b/payjoin-ffi/src/receive/mod.rs @@ -206,7 +206,11 @@ impl From for ReceiverSessionHistory { #[uniffi::export] impl ReceiverSessionHistory { /// Receiver session Payjoin URI - pub fn pj_uri(&self) -> Arc { Arc::new(self.0.pj_uri().into()) } + pub fn pj_uri(&self) -> Arc { + // SAFETY: pj_uri() returns PjUri<'static> — label and message are None + let uri: payjoin::PjUri<'static> = unsafe { core::mem::transmute(self.0.pj_uri()) }; + Arc::new(uri.into()) + } /// Fallback transaction from the session if present pub fn fallback_tx(&self) -> Option> { diff --git a/payjoin/src/bech32.rs b/payjoin/src/bech32.rs index 0b7d80012..617309201 100644 --- a/payjoin/src/bech32.rs +++ b/payjoin/src/bech32.rs @@ -1,3 +1,8 @@ +#[cfg(feature = "alloc")] +use alloc::string::String; +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; + use bitcoin::bech32::primitives::decode::{CheckedHrpstring, CheckedHrpstringError}; use bitcoin::bech32::{self, EncodeError, Hrp, NoChecksum}; diff --git a/payjoin/src/core/error.rs b/payjoin/src/core/error.rs index 8c74f54e5..04e7140ec 100644 --- a/payjoin/src/core/error.rs +++ b/payjoin/src/core/error.rs @@ -1,6 +1,29 @@ -use std::fmt::Debug; -use std::{error, fmt}; +#[cfg(not(feature = "std"))] +use alloc::boxed::Box; +#[cfg(not(feature = "std"))] +use core::error; +use core::fmt::{self, Debug}; +#[cfg(feature = "std")] +use std::error; +#[derive(Debug)] +pub struct StdRequiredError; + +impl fmt::Display for StdRequiredError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "std is required for this operation") + } +} + +#[cfg(feature = "std")] +impl std::error::Error for StdRequiredError {} + +#[cfg(not(feature = "std"))] +impl core::error::Error for StdRequiredError {} + +impl ImplementationError { + pub fn std_required() -> Self { ImplementationError(Box::new(StdRequiredError)) } +} #[derive(Debug)] pub struct ImplementationError(Box); @@ -11,7 +34,7 @@ impl ImplementationError { } impl fmt::Display for ImplementationError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { std::fmt::Display::fmt(&self.0, f) } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::Display::fmt(&self.0, f) } } impl error::Error for ImplementationError { @@ -34,16 +57,17 @@ impl From<&str> for ImplementationError { ImplementationError::from(error) } } + /// Errors that can occur when replaying a session event log #[cfg(feature = "v2")] #[derive(Debug)] pub struct ReplayError(InternalReplayError); #[cfg(feature = "v2")] -impl std::fmt::Display +impl fmt::Display for ReplayError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use InternalReplayError::*; match &self.0 { NoEvents => write!(f, "No events found in session"), @@ -56,8 +80,9 @@ impl std::fmt::Display } } } + #[cfg(feature = "v2")] -impl std::error::Error +impl error::Error for ReplayError { } @@ -71,6 +96,7 @@ impl From { /// No events in the event log NoEvents, diff --git a/payjoin/src/core/hpke.rs b/payjoin/src/core/hpke.rs index aa12756af..5a4ab521c 100644 --- a/payjoin/src/core/hpke.rs +++ b/payjoin/src/core/hpke.rs @@ -1,6 +1,11 @@ +#![cfg(any(feature = "v2", feature = "v2-ohttp"))] +use alloc::vec::Vec; +#[cfg(not(feature = "std"))] +use core::error; use core::fmt; +use core::ops::Deref; +#[cfg(feature = "std")] use std::error; -use std::ops::Deref; use bitcoin::key::constants::{ELLSWIFT_ENCODING_SIZE, PUBLIC_KEY_SIZE}; use bitcoin::secp256k1; @@ -195,13 +200,13 @@ pub fn decrypt_message_a( message_a: &[u8], receiver_sk: &HpkeSecretKey, ) -> Result<(Vec, HpkePublicKey), HpkeError> { - use std::io::{Cursor, Read}; + if message_a.len() < ELLSWIFT_ENCODING_SIZE { + return Err(HpkeError::PayloadTooShort); + } - let mut cursor = Cursor::new(message_a); + let (enc_part, ciphertext_part) = message_a.split_at(ELLSWIFT_ENCODING_SIZE); - let mut enc_bytes = [0u8; ELLSWIFT_ENCODING_SIZE]; - cursor.read_exact(&mut enc_bytes).map_err(|_| HpkeError::PayloadTooShort)?; - let enc = encapped_key_from_ellswift_bytes(&enc_bytes)?; + let enc = encapped_key_from_ellswift_bytes(enc_part)?; let mut decryption_ctx = hpke::setup_receiver::< ChaCha20Poly1305, @@ -209,15 +214,16 @@ pub fn decrypt_message_a( SecpK256HkdfSha256, >(&OpModeR::Base, &receiver_sk.0, &enc, INFO_A)?; - let mut ciphertext = Vec::new(); - cursor.read_to_end(&mut ciphertext).map_err(|_| HpkeError::PayloadTooShort)?; - let plaintext = decryption_ctx.open(&ciphertext, &[])?; + let plaintext = decryption_ctx.open(ciphertext_part, &[])?; - let reply_pk = pubkey_from_compressed_bytes(&plaintext[..PUBLIC_KEY_SIZE])?; + if plaintext.len() < PUBLIC_KEY_SIZE { + return Err(HpkeError::PayloadTooShort); + } - let body = &plaintext[PUBLIC_KEY_SIZE..]; + let reply_pk = pubkey_from_compressed_bytes(&plaintext[..PUBLIC_KEY_SIZE])?; + let body = plaintext[PUBLIC_KEY_SIZE..].to_vec(); - Ok((body.to_vec(), reply_pk)) + Ok((body, reply_pk)) } /// Message B is sent from the receiver to the sender containing a Payjoin PSBT payload or an error @@ -243,6 +249,8 @@ pub fn encrypt_message_b( Ok(message_b) } +#[cfg(feature = "std")] +#[allow(dead_code)] pub fn decrypt_message_b( message_b: &[u8], receiver_pk: HpkePublicKey, @@ -581,4 +589,60 @@ mod test { check_uniformity(messages_a); check_uniformity(messages_b); } + + #[test] + fn decrypt_message_a_payload_too_short() { + let receiver_keypair = HpkeKeyPair::gen_keypair(); + // Empty payload + let result = decrypt_message_a(&[], receiver_keypair.secret_key()); + assert_eq!(result, Err(HpkeError::PayloadTooShort)); + // Payload smaller than ELLSWIFT_ENCODING_SIZE + let short_payload = vec![0u8; ELLSWIFT_ENCODING_SIZE - 1]; + let result = decrypt_message_a(&short_payload, receiver_keypair.secret_key()); + assert_eq!(result, Err(HpkeError::PayloadTooShort)); + } + + #[test] + fn decrypt_message_a_does_not_treat_exact_ellswift_size_as_too_short() { + let receiver = HpkeKeyPair::gen_keypair(); + let message_a = vec![0u8; ELLSWIFT_ENCODING_SIZE]; + + let result = decrypt_message_a(&message_a, receiver.secret_key()); + + assert!(!matches!(result, Err(HpkeError::PayloadTooShort)), "result = {:?}", result); + } + + #[test] + fn decrypt_message_a_accepts_plaintext_exactly_public_key_size() { + let receiver = HpkeKeyPair::gen_keypair(); + let reply = HpkeKeyPair::gen_keypair(); + + let expected_reply_pk = reply.public_key().clone(); + let plaintext = expected_reply_pk.to_compressed_bytes().to_vec(); + + let mut rng = bitcoin::key::rand::thread_rng(); + + let (enc, mut encryption_ctx) = + hpke::setup_sender::( + &OpModeS::Base, + &receiver.public_key().0, + INFO_A, + &mut rng, + ) + .expect("setup_sender should succeed"); + + let ciphertext = encryption_ctx.seal(&plaintext, &[]).expect("seal should succeed"); + + let enc_part = ellswift_bytes_from_encapped_key(&enc) + .expect("should convert encapped key to ellswift bytes"); + + let mut message_a = enc_part.to_vec(); + message_a.extend_from_slice(&ciphertext); + + let (body, reply_pk) = + decrypt_message_a(&message_a, receiver.secret_key()).expect("should decrypt"); + + assert!(body.is_empty(), "body = {:?}", body); + assert_eq!(reply_pk, expected_reply_pk); + } } diff --git a/payjoin/src/core/into_url.rs b/payjoin/src/core/into_url.rs index fc4537bf5..fff9d2ee4 100644 --- a/payjoin/src/core/into_url.rs +++ b/payjoin/src/core/into_url.rs @@ -1,13 +1,15 @@ -use url::{ParseError, Url}; +use alloc::string::String; +use core::{error, fmt}; +use url::{ParseError, Url}; #[derive(Debug, PartialEq, Eq)] pub enum Error { BadScheme, ParseError(ParseError), } -impl std::fmt::Display for Error { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use Error::*; match self { @@ -17,7 +19,7 @@ impl std::fmt::Display for Error { } } -impl std::error::Error for Error {} +impl error::Error for Error {} impl From for Error { fn from(err: ParseError) -> Error { Error::ParseError(err) } diff --git a/payjoin/src/core/io.rs b/payjoin/src/core/io.rs index d335a3fce..cd7f73f60 100644 --- a/payjoin/src/core/io.rs +++ b/payjoin/src/core/io.rs @@ -1,4 +1,5 @@ //! IO-related types and functions. Specifically, fetching OHTTP keys from a payjoin directory. +#[cfg(feature = "std")] use std::time::Duration; use http::header::ACCEPT; diff --git a/payjoin/src/core/ohttp.rs b/payjoin/src/core/ohttp.rs index 2036a9f5a..06a8e15bc 100644 --- a/payjoin/src/core/ohttp.rs +++ b/payjoin/src/core/ohttp.rs @@ -1,10 +1,17 @@ -use std::ops::{Deref, DerefMut}; -use std::{error, fmt}; +use alloc::vec; +use alloc::vec::Vec; +#[cfg(not(feature = "std"))] +use core::error; +use core::fmt; +use core::ops::{Deref, DerefMut}; +#[cfg(feature = "std")] +use std::error; use bitcoin::bech32::{self, EncodeError}; use bitcoin::key::constants::UNCOMPRESSED_PUBLIC_KEY_SIZE; use hpke::rand_core::{OsRng, RngCore}; +use crate::alloc::string::ToString; use crate::directory::ENCAPSULATED_MESSAGE_BYTES; const N_ENC: usize = UNCOMPRESSED_PUBLIC_KEY_SIZE; @@ -19,7 +26,7 @@ pub(crate) fn ohttp_encapsulate( target_resource: &str, body: Option<&[u8]>, ) -> Result<([u8; ENCAPSULATED_MESSAGE_BYTES], ohttp::ClientResponse), OhttpEncapsulationError> { - use std::fmt::Write; + use core::fmt::Write; let mut ohttp_keys = ohttp_keys.clone(); let ctx = ohttp::ClientRequest::from_config(&mut ohttp_keys)?; @@ -53,10 +60,22 @@ pub(crate) fn ohttp_encapsulate( Ok((buffer, ohttp_ctx)) } +#[cfg(feature = "std")] +pub(crate) fn ohttp_decapsulate_bytes( + ohttp_ctx: ohttp::ClientResponse, + response: Vec, +) -> Result, ohttp::Error> { + let plaintext = ohttp_ctx.decapsulate(&response)?; + Ok(plaintext) +} + #[derive(Debug)] pub enum DirectoryResponseError { + #[allow(dead_code)] InvalidSize(usize), - OhttpDecapsulation(OhttpEncapsulationError), + #[allow(dead_code)] + OhttpDecapsulation(ohttp::Error), + #[allow(dead_code)] UnexpectedStatusCode(http::StatusCode), } @@ -101,6 +120,8 @@ impl error::Error for DirectoryResponseError { } } +#[cfg(feature = "std")] +#[allow(dead_code)] pub(crate) fn process_get_res( res: &[u8], ohttp_context: ohttp::ClientResponse, @@ -113,6 +134,8 @@ pub(crate) fn process_get_res( } } +#[cfg(feature = "std")] +#[allow(dead_code)] pub(crate) fn process_post_res( res: &[u8], ohttp_context: ohttp::ClientResponse, @@ -124,25 +147,31 @@ pub(crate) fn process_post_res( } } +#[cfg(feature = "std")] fn process_ohttp_res( res: &[u8], ohttp_context: ohttp::ClientResponse, ) -> Result>, DirectoryResponseError> { let response_array: &[u8; crate::directory::ENCAPSULATED_MESSAGE_BYTES] = res.try_into().map_err(|_| DirectoryResponseError::InvalidSize(res.len()))?; - tracing::trace!("decapsulating directory response"); - let res = ohttp_decapsulate(ohttp_context, response_array) - .map_err(DirectoryResponseError::OhttpDecapsulation)?; - Ok(res) + ohttp_decapsulate(ohttp_context, response_array).map_err(|e| match e { + OhttpEncapsulationError::Ohttp(ohttp_err) => + DirectoryResponseError::OhttpDecapsulation(ohttp_err), + _ => DirectoryResponseError::InvalidSize(0), + }) } /// decapsulate ohttp, bhttp response and return http response body and status code +#[cfg(all(feature = "std", feature = "v2-ohttp"))] +#[allow(dead_code)] pub(crate) fn ohttp_decapsulate( res_ctx: ohttp::ClientResponse, ohttp_body: &[u8; ENCAPSULATED_MESSAGE_BYTES], ) -> Result>, OhttpEncapsulationError> { + use std::io::Cursor; + let bhttp_body = res_ctx.decapsulate(ohttp_body)?; - let mut r = std::io::Cursor::new(bhttp_body); + let mut r = Cursor::new(bhttp_body); let m: bhttp::Message = bhttp::Message::read_bhttp(&mut r)?; let mut builder = http::Response::builder(); for field in m.header().iter() { @@ -348,8 +377,8 @@ pub enum ParseOhttpKeysError { InvalidFormat, } -impl std::fmt::Display for ParseOhttpKeysError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl fmt::Display for ParseOhttpKeysError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use ParseOhttpKeysError::*; match self { IncorrectLength(l) => write!(f, "Invalid length, got {l} expected 34"), @@ -361,8 +390,8 @@ impl std::fmt::Display for ParseOhttpKeysError { } } -impl std::error::Error for ParseOhttpKeysError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { +impl error::Error for ParseOhttpKeysError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { use ParseOhttpKeysError::*; match self { DecodeKeyConfig(e) => Some(e), diff --git a/payjoin/src/core/persist.rs b/payjoin/src/core/persist.rs index f9a164abb..9c2a3e889 100644 --- a/payjoin/src/core/persist.rs +++ b/payjoin/src/core/persist.rs @@ -22,7 +22,9 @@ //! version of the state machine. New sessions which do contain this event will //! not be interpretable by the old code. -use std::fmt; +use alloc::boxed::Box; +use alloc::fmt; +use core::{convert, iter, marker}; /// Representation of the actions that the persister should take, if any. pub(crate) enum PersistActions { @@ -50,6 +52,7 @@ impl PersistActions { Ok(()) } + #[cfg(feature = "std")] pub async fn execute_async

(self, persister: &P) -> Result<(), P::InternalStorageError> where P: AsyncSessionPersister, @@ -76,7 +79,7 @@ pub struct MaybeSuccessTransitionWithNoResults MaybeSuccessTransitionWithNoResults where - Err: std::error::Error, + Err: core::error::Error, { pub(crate) fn fatal(event: Event, error: Err) -> Self { MaybeSuccessTransitionWithNoResults(Err(Rejection::fatal(event, error))) @@ -98,6 +101,7 @@ where } #[allow(clippy::type_complexity)] + #[allow(dead_code)] pub(crate) fn deconstruct( self, ) -> ( @@ -129,12 +133,12 @@ where > where P: SessionPersister, + Err: core::error::Error, { - let (actions, outcome) = self.deconstruct(); - actions.execute(persister).map_err(InternalPersistedError::Storage)?; - Ok(outcome.map_err(InternalPersistedError::Api)?) + persister.save_maybe_no_results_success_transition(self) } + #[cfg(feature = "std")] pub async fn save_async

( self, persister: &P, @@ -144,7 +148,7 @@ where > where P: AsyncSessionPersister, - Err: Send, + Err: core::error::Error + Send, SuccessValue: Send, CurrentState: Send, Event: Send, @@ -163,7 +167,7 @@ pub struct MaybeFatalTransitionWithNoResults MaybeFatalTransitionWithNoResults where - Err: std::error::Error, + Err: core::error::Error, { pub(crate) fn fatal(event: Event, error: Err) -> Self { MaybeFatalTransitionWithNoResults(Err(Rejection::fatal(event, error))) @@ -184,6 +188,7 @@ where } #[allow(clippy::type_complexity)] + #[allow(dead_code)] pub(crate) fn deconstruct( self, ) -> ( @@ -213,12 +218,12 @@ where > where P: SessionPersister, + Err: core::error::Error, { - let (actions, outcome) = self.deconstruct(); - actions.execute(persister).map_err(InternalPersistedError::Storage)?; - Ok(outcome.map_err(InternalPersistedError::Api)?) + persister.save_maybe_no_results_transition(self) } + #[cfg(feature = "std")] pub async fn save_async

( self, persister: &P, @@ -228,7 +233,7 @@ where > where P: AsyncSessionPersister, - Err: Send, + Err: core::error::Error + Send, NextState: Send, CurrentState: Send, Event: Send, @@ -246,7 +251,7 @@ pub struct MaybeFatalTransition( impl MaybeFatalTransition where - Err: std::error::Error, + Err: core::error::Error, ErrorState: fmt::Debug, { pub(crate) fn fatal(event: Event, error: Err) -> Self { @@ -265,6 +270,7 @@ where MaybeFatalTransition(Err(Rejection::replyable_error(event, error_state, error))) } + #[allow(dead_code)] pub(crate) fn deconstruct( self, ) -> (PersistActions, Result>) { @@ -285,21 +291,21 @@ where ) -> Result> where P: SessionPersister, + Err: core::error::Error, { - let (actions, outcome) = self.deconstruct(); - actions.execute(persister).map_err(InternalPersistedError::Storage)?; - Ok(outcome.map_err(InternalPersistedError::Api)?) + persister.save_maybe_fatal_error_transition(self) } + #[cfg(feature = "std")] pub async fn save_async

( self, persister: &P, ) -> Result> where P: AsyncSessionPersister, - Err: Send, - ErrorState: Send, + Err: core::error::Error + Send, NextState: Send, + ErrorState: Send, Event: Send, { let (actions, outcome) = self.deconstruct(); @@ -316,7 +322,7 @@ pub struct MaybeTransientTransition( impl MaybeTransientTransition where - Err: std::error::Error, + Err: core::error::Error, { pub(crate) fn success(event: Event, next_state: NextState) -> Self { MaybeTransientTransition(Ok(AcceptNextState(event, next_state))) @@ -326,6 +332,7 @@ where MaybeTransientTransition(Err(RejectTransient(error))) } + #[allow(dead_code)] pub(crate) fn deconstruct(self) -> (PersistActions, Result>) { match self.0 { Ok(AcceptNextState(event, next_state)) => (PersistActions::Save(event), Ok(next_state)), @@ -339,19 +346,19 @@ where ) -> Result> where P: SessionPersister, + Err: core::error::Error, { - let (actions, outcome) = self.deconstruct(); - actions.execute(persister).map_err(InternalPersistedError::Storage)?; - Ok(outcome.map_err(InternalPersistedError::Api)?) + persister.save_maybe_transient_error_transition(self) } + #[cfg(feature = "std")] pub async fn save_async

( self, persister: &P, ) -> Result> where P: AsyncSessionPersister, - Err: Send, + Err: core::error::Error + Send, NextState: Send, Event: Send, { @@ -367,9 +374,10 @@ pub struct MaybeSuccessTransition( Result, Rejection>, ); +#[allow(dead_code)] impl MaybeSuccessTransition where - Err: std::error::Error, + Err: core::error::Error, { pub(crate) fn success(event: Event, success_value: SuccessValue) -> Self { MaybeSuccessTransition(Ok(AcceptNextState(event, success_value))) @@ -383,6 +391,7 @@ where MaybeSuccessTransition(Err(Rejection::fatal(event, error))) } + #[allow(dead_code)] pub(crate) fn deconstruct( self, ) -> (PersistActions, Result>) { @@ -410,6 +419,7 @@ where Ok(outcome.map_err(InternalPersistedError::Api)?) } + #[cfg(feature = "std")] pub async fn save_async

( self, persister: &P, @@ -434,6 +444,7 @@ impl NextStateTransition { NextStateTransition(AcceptNextState(event, next_state)) } + #[allow(dead_code)] pub(crate) fn deconstruct(self) -> (PersistActions, NextState) { let AcceptNextState(event, next_state) = self.0; (PersistActions::Save(event), next_state) @@ -448,6 +459,7 @@ impl NextStateTransition { Ok(next_state) } + #[cfg(feature = "std")] pub async fn save_async

(self, persister: &P) -> Result where P: AsyncSessionPersister, @@ -471,7 +483,7 @@ pub enum MaybeFatalOrSuccessTransition { impl MaybeFatalOrSuccessTransition where - Err: std::error::Error, + Err: core::error::Error, { pub(crate) fn success(event: Event) -> Self { MaybeFatalOrSuccessTransition::Success(event) } @@ -489,6 +501,7 @@ where } #[allow(clippy::type_complexity)] + #[allow(dead_code)] pub(crate) fn deconstruct( self, ) -> (PersistActions, Result, ApiError>) @@ -514,12 +527,12 @@ where > where P: SessionPersister, + Err: core::error::Error, { - let (actions, outcome) = self.deconstruct(); - actions.execute(persister).map_err(InternalPersistedError::Storage)?; - Ok(outcome.map_err(InternalPersistedError::Api)?) + persister.save_maybe_fatal_or_success_transition(self) } + #[cfg(feature = "std")] pub async fn save_async

( self, persister: &P, @@ -529,7 +542,7 @@ where > where P: AsyncSessionPersister, - Err: Send, + Err: core::error::Error + Send, CurrentState: Send, Event: Send, { @@ -583,8 +596,8 @@ pub struct RejectReplyableError( /// The wrapper contains the error and should be returned to the caller. pub struct RejectBadInitInputs(Err); -impl fmt::Display for RejectTransient { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl fmt::Display for RejectTransient { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let RejectTransient(err) = self; write!(f, "{err}") } @@ -593,15 +606,15 @@ impl fmt::Display for RejectTransient { /// Error type that represents all possible errors that can be returned when processing a state transition #[derive(Debug)] pub struct PersistedError< - ApiError: std::error::Error, - StorageError: std::error::Error, + ApiError: core::error::Error, + StorageError: core::error::Error, ErrorState: fmt::Debug = (), >(InternalPersistedError); impl PersistedError where - StorageErr: std::error::Error, - ApiErr: std::error::Error, + StorageErr: core::error::Error, + ApiErr: core::error::Error, ErrorState: fmt::Debug, { #[allow(dead_code)] @@ -645,7 +658,7 @@ where } } -impl +impl From> for PersistedError { @@ -654,21 +667,21 @@ impl - std::error::Error for PersistedError +impl + core::error::Error for PersistedError { } -impl - fmt::Display for PersistedError +impl + fmt::Display for PersistedError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match &self.0 { - InternalPersistedError::Api(ApiError::Transient(err)) => - write!(f, "Transient error: {err}"), - InternalPersistedError::Api( - ApiError::Fatal(err) | ApiError::FatalWithState(err, _), - ) => write!(f, "Fatal error: {err}"), + InternalPersistedError::Api(api_err) => match api_err { + ApiError::Transient(err) => write!(f, "Transient error: {err}"), + ApiError::Fatal(err) | ApiError::FatalWithState(err, _) => + write!(f, "Fatal error: {err}"), + }, InternalPersistedError::Storage(err) => write!(f, "Storage error: {err}"), } } @@ -687,8 +700,8 @@ pub(crate) enum ApiError { #[derive(Debug)] pub(crate) enum InternalPersistedError where - ApiErr: std::error::Error, - StorageErr: std::error::Error, + ApiErr: core::error::Error, + StorageErr: core::error::Error, ErrorState: fmt::Debug, { /// Error indicating that the session failed to progress to the next success state. @@ -700,8 +713,8 @@ where impl From> for InternalPersistedError where - Err: std::error::Error, - StorageErr: std::error::Error, + Err: core::error::Error, + StorageErr: core::error::Error, ErrorState: fmt::Debug, { fn from(api: ApiError) -> Self { InternalPersistedError::Api(api) } @@ -721,7 +734,7 @@ pub enum OptionalTransitionOutcome { /// The events can be replayed from the log to reconstruct the state machine's state. pub trait SessionPersister { /// Errors that may arise from implementers storage layer - type InternalStorageError: std::error::Error + Send + Sync + 'static; + type InternalStorageError: core::error::Error + Send + Sync + 'static; /// Session events types that we are persisting type SessionEvent; @@ -739,24 +752,18 @@ pub trait SessionPersister { fn close(&self) -> Result<(), Self::InternalStorageError>; } -/// Async version of [`SessionPersister`] for use in async contexts. -// -// Methods use `impl Future<...> + Send` instead of `async fn` because `async fn` in traits -// doesn't guarantee the returned future is `Send`. This triggers the `async_fn_in_trait` lint. -// https://doc.rust-lang.org/stable/nightly-rustc/rustc_lint/async_fn_in_trait/static.ASYNC_FN_IN_TRAIT.html -pub trait AsyncSessionPersister: Send + Sync { - /// Errors that may arise from implementers storage layer - type InternalStorageError: std::error::Error + Send + Sync + 'static; - /// Session events types that we are persisting - type SessionEvent: Send; +/// Async version of [SessionPersister] for use with async runtimes. +/// Only available with the `std` feature. +#[cfg(feature = "std")] +pub trait AsyncSessionPersister { + type InternalStorageError: core::error::Error + Send + Sync + 'static; + type SessionEvent; - /// Appends to list of session updates, Receives generic events fn save_event( &self, event: Self::SessionEvent, ) -> impl std::future::Future> + Send; - /// Loads all the events from the session in the same order they were saved fn load( &self, ) -> impl std::future::Future< @@ -766,28 +773,212 @@ pub trait AsyncSessionPersister: Send + Sync { >, > + Send; - /// Marks the session as closed, no more events will be appended. - /// This is invoked when the session is terminated due to a fatal error - /// or when the session is closed due to a success state fn close( &self, ) -> impl std::future::Future> + Send; } +/// Internal logic for processing specific state transitions. Each method is strongly typed to the state transition type. +/// Methods are not meant to be called directly, but are invoked through a state transition object's `save` method. +trait InternalSessionPersister: SessionPersister { + fn save_maybe_fatal_or_success_transition( + &self, + state_transition: MaybeFatalOrSuccessTransition, + ) -> Result< + OptionalTransitionOutcome<(), CurrentState>, + PersistedError, + > + where + Err: core::error::Error, + { + match state_transition { + MaybeFatalOrSuccessTransition::Success(event) => { + // Success value here would be the something to save + self.save_event(event).map_err(InternalPersistedError::Storage)?; + self.close().map_err(InternalPersistedError::Storage)?; + Ok(OptionalTransitionOutcome::Progress(())) + } + MaybeFatalOrSuccessTransition::NoResults(current_state) => + Ok(OptionalTransitionOutcome::Stasis(current_state)), + MaybeFatalOrSuccessTransition::Fatal(reject_fatal) => + Err(self.handle_fatal_reject(reject_fatal).into()), + MaybeFatalOrSuccessTransition::Transient(RejectTransient(err)) => + Err(PersistedError(InternalPersistedError::Api(ApiError::Transient(err)))), + } + } + + /// Persists the outcome of a state transition that may result in one of the following: + /// - A successful state transition, in which case the success value is returned and the session is closed. + /// - No state change (stasis), where the current state is retained and nothing is persisted. + /// - A transient error, which does not affect persistent storage and is returned to the caller. + /// - A fatal error, which is persisted and returned to the caller. + fn save_maybe_no_results_success_transition( + &self, + state_transition: MaybeSuccessTransitionWithNoResults< + Self::SessionEvent, + SuccessValue, + CurrentState, + Err, + >, + ) -> Result< + OptionalTransitionOutcome, + PersistedError, + > + where + Err: core::error::Error, + { + match state_transition.0 { + Ok(AcceptOptionalTransition::Success(AcceptNextState(event, success_value))) => { + self.save_event(event).map_err(InternalPersistedError::Storage)?; + self.close().map_err(InternalPersistedError::Storage)?; + Ok(OptionalTransitionOutcome::Progress(success_value)) + } + Ok(AcceptOptionalTransition::NoResults(current_state)) => + Ok(OptionalTransitionOutcome::Stasis(current_state)), + Err(Rejection::Fatal(reject_fatal)) => + Err(self.handle_fatal_reject(reject_fatal).into()), + Err(Rejection::Transient(RejectTransient(err))) => + Err(PersistedError(InternalPersistedError::Api(ApiError::Transient(err)))), + Err(Rejection::ReplyableError(reject_replyable_error)) => + Err(self.handle_replyable_error_reject(reject_replyable_error).into()), + } + } + /// Save a transition that can result in: + /// - A successful state transition + /// - No state change (no results) + /// - A transient error + /// - A fatal error + fn save_maybe_no_results_transition( + &self, + state_transition: MaybeFatalTransitionWithNoResults< + Self::SessionEvent, + NextState, + CurrentState, + Err, + >, + ) -> Result< + OptionalTransitionOutcome, + PersistedError, + > + where + Err: core::error::Error, + { + match state_transition.0 { + Ok(AcceptOptionalTransition::Success(AcceptNextState(event, next_state))) => { + self.save_event(event).map_err(InternalPersistedError::Storage)?; + Ok(OptionalTransitionOutcome::Progress(next_state)) + } + Ok(AcceptOptionalTransition::NoResults(current_state)) => + Ok(OptionalTransitionOutcome::Stasis(current_state)), + Err(Rejection::Fatal(reject_fatal)) => + Err(self.handle_fatal_reject(reject_fatal).into()), + Err(Rejection::Transient(RejectTransient(err))) => + Err(PersistedError(InternalPersistedError::Api(ApiError::Transient(err)))), + Err(Rejection::ReplyableError(reject_replyable_error)) => + Err(self.handle_replyable_error_reject(reject_replyable_error).into()), + } + } + + /// Save a transition that can be a transient error or a state transition + fn save_maybe_transient_error_transition( + &self, + state_transition: MaybeTransientTransition, + ) -> Result> + where + Err: core::error::Error, + { + match state_transition.0 { + Ok(AcceptNextState(event, next_state)) => { + self.save_event(event).map_err(InternalPersistedError::Storage)?; + Ok(next_state) + } + Err(RejectTransient(err)) => + Err(PersistedError(InternalPersistedError::Api(ApiError::Transient(err)))), + } + } + + /// Save a transition that can be a fatal error, transient error or a state transition + fn save_maybe_fatal_error_transition( + &self, + state_transition: MaybeFatalTransition, + ) -> Result> + where + Err: core::error::Error, + ErrorState: fmt::Debug, + { + match state_transition.0 { + Ok(AcceptNextState(event, next_state)) => { + self.save_event(event).map_err(InternalPersistedError::Storage)?; + Ok(next_state) + } + Err(e) => { + match e { + Rejection::Fatal(reject_fatal) => + Err(self.handle_fatal_reject(reject_fatal).into()), + Rejection::Transient(RejectTransient(err)) => { + // No event to store for transient errors + Err(PersistedError(InternalPersistedError::Api(ApiError::Transient(err)))) + } + Rejection::ReplyableError(reject_replyable_error) => + Err(self.handle_replyable_error_reject(reject_replyable_error).into()), + } + } + } + } + + fn handle_fatal_reject( + &self, + reject_fatal: RejectFatal, + ) -> InternalPersistedError + where + Err: core::error::Error, + ErrorState: fmt::Debug, + { + let RejectFatal(event, error) = reject_fatal; + if let Err(e) = self.save_event(event) { + return InternalPersistedError::Storage(e); + } + // Session is in a terminal state, close it + if let Err(e) = self.close() { + return InternalPersistedError::Storage(e); + } + + InternalPersistedError::Api(ApiError::Fatal(error)) + } + + fn handle_replyable_error_reject( + &self, + reject_replyable_error: RejectReplyableError, + ) -> InternalPersistedError + where + Err: core::error::Error, + ErrorState: fmt::Debug, + { + let RejectReplyableError(event, error_state, error) = reject_replyable_error; + if let Err(e) = self.save_event(event) { + return InternalPersistedError::Storage(e); + } + // For replyable errors, don't close the session - keep it open for error response + InternalPersistedError::Api(ApiError::FatalWithState(error, error_state)) + } +} + +impl InternalSessionPersister for T {} + /// A persister that does nothing /// This persister cannot be used to replay a session #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct NoopPersisterEvent; #[derive(Debug, Clone)] -pub struct NoopSessionPersister(std::marker::PhantomData); +pub struct NoopSessionPersister(marker::PhantomData); impl Default for NoopSessionPersister { - fn default() -> Self { Self(std::marker::PhantomData) } + fn default() -> Self { Self(marker::PhantomData) } } impl SessionPersister for NoopSessionPersister { - type InternalStorageError = std::convert::Infallible; + type InternalStorageError = convert::Infallible; type SessionEvent = E; fn save_event(&self, _event: Self::SessionEvent) -> Result<(), Self::InternalStorageError> { @@ -797,7 +988,7 @@ impl SessionPersister for NoopSessionPersister { fn load( &self, ) -> Result>, Self::InternalStorageError> { - Ok(Box::new(std::iter::empty())) + Ok(Box::new(iter::empty())) } fn close(&self) -> Result<(), Self::InternalStorageError> { Ok(()) } @@ -833,7 +1024,7 @@ pub mod test_utils { where V: Clone + 'static, { - type InternalStorageError = std::convert::Infallible; + type InternalStorageError = core::convert::Infallible; type SessionEvent = V; fn save_event(&self, event: Self::SessionEvent) -> Result<(), Self::InternalStorageError> { @@ -848,9 +1039,7 @@ pub mod test_utils { { let inner = self.inner.read().expect("Lock should not be poisoned"); let events = std::sync::Arc::clone(&inner.events); - Ok(Box::new( - std::sync::Arc::try_unwrap(events).unwrap_or_else(|arc| (*arc).clone()).into_iter(), - )) + Ok(Box::new(Arc::try_unwrap(events).unwrap_or_else(|arc| (*arc).clone()).into_iter())) } fn close(&self) -> Result<(), Self::InternalStorageError> { @@ -874,7 +1063,7 @@ pub mod test_utils { } } - #[cfg(test)] + #[cfg(all(test, feature = "std"))] impl crate::persist::AsyncSessionPersister for InMemoryAsyncTestPersister where V: Clone + Send + Sync + 'static, @@ -909,11 +1098,12 @@ pub mod test_utils { } #[cfg(test)] +#[allow(clippy::type_complexity)] mod tests { use serde::{Deserialize, Serialize}; use super::*; - use crate::persist::test_utils::{InMemoryAsyncTestPersister, InMemoryTestPersister}; + use crate::persist::test_utils::InMemoryAsyncTestPersister; type InMemoryTestState = String; @@ -927,9 +1117,7 @@ mod tests { impl std::error::Error for InMemoryTestError {} impl fmt::Display for InMemoryTestError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "InMemoryTestError") - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "InMemoryTestError") } } struct TestCase { @@ -948,35 +1136,6 @@ mod tests { success: Option, } - fn verify_sync( - persister: &InMemoryTestPersister, - result: Result, - expected_result: &ExpectedResult, - ) { - let events = persister.load().expect("Persister should not fail").collect::>(); - assert_eq!(events.len(), expected_result.events.len()); - for (event, expected_event) in events.iter().zip(expected_result.events.iter()) { - assert_eq!(event.0, expected_event.0); - } - - assert_eq!( - persister.inner.read().expect("Lock should not be poisoned").is_closed, - expected_result.is_closed - ); - - match (&result, &expected_result.error) { - (Ok(actual), None) => { - assert_eq!(Some(actual), expected_result.success.as_ref()); - } - (Err(actual), Some(expected)) => { - // TODO: replace .to_string() with .eq(). This would introduce a trait bound on the internal API error type - // And not all internal API errors implement PartialEq - assert_eq!(actual.to_string(), expected.to_string()); - } - _ => panic!("Unexpected result state"), - } - } - async fn verify_async< SuccessState: std::fmt::Debug + PartialEq + Send, ErrorState: std::error::Error + Send, @@ -998,8 +1157,6 @@ mod tests { assert_eq!(Some(actual), expected_result.success.as_ref()); } (Err(actual), Some(exp)) => { - // TODO: replace .to_string() with .eq(). This would introduce a trait bound on the internal API error type - // And not all internal API errors implement PartialEq assert_eq!(actual.to_string(), exp.to_string()); } _ => panic!("Unexpected result state"), @@ -1009,10 +1166,6 @@ mod tests { macro_rules! run_test_cases { ($test_cases:expr) => { for test in &$test_cases { - let persister = InMemoryTestPersister::default(); - let result = (test.make_transition)().save(&persister); - verify_sync(&persister, result, &test.expected_result); - let persister = InMemoryAsyncTestPersister::default(); let result = (test.make_transition)().save_async(&persister).await; verify_async(&persister, result, &test.expected_result).await; @@ -1024,8 +1177,13 @@ mod tests { async fn test_initial_transition() { let event = InMemoryTestEvent("foo".to_string()); let next_state = "Next state".to_string(); - - let test_cases = vec![TestCase { + let test_cases: Vec< + TestCase< + NextStateTransition, + InMemoryTestState, + std::convert::Infallible, + >, + > = vec![TestCase { make_transition: Box::new({ let event = event.clone(); let next_state = next_state.clone(); @@ -1046,8 +1204,13 @@ mod tests { async fn test_maybe_transient_transition() { let event = InMemoryTestEvent("foo".to_string()); let next_state = "Next state".to_string(); - - let test_cases = vec![ + let test_cases: Vec< + TestCase< + MaybeTransientTransition, + InMemoryTestState, + PersistedError, + >, + > = vec![ TestCase { make_transition: Box::new({ let event = event.clone(); @@ -1084,8 +1247,13 @@ mod tests { async fn test_next_state_transition() { let event = InMemoryTestEvent("foo".to_string()); let next_state = "Next state".to_string(); - - let test_cases = vec![TestCase { + let test_cases: Vec< + TestCase< + NextStateTransition, + InMemoryTestState, + std::convert::Infallible, + >, + > = vec![TestCase { make_transition: Box::new({ let event = event.clone(); let next_state = next_state.clone(); @@ -1106,8 +1274,13 @@ mod tests { async fn test_maybe_success_transition() { let event = InMemoryTestEvent("foo".to_string()); let error_event = InMemoryTestEvent("error event".to_string()); - - let test_cases = vec![ + let test_cases: Vec< + TestCase< + MaybeSuccessTransition, + (), + PersistedError, + >, + > = vec![ TestCase { make_transition: Box::new({ let event = event.clone(); @@ -1159,7 +1332,22 @@ mod tests { let error_event = InMemoryTestEvent("error event".to_string()); let next_state = "Next state".to_string(); - let test_cases = vec![ + let test_cases: Vec< + TestCase< + MaybeFatalTransition< + InMemoryTestEvent, + InMemoryTestState, + InMemoryTestError, + std::convert::Infallible, + >, + InMemoryTestState, + PersistedError< + InMemoryTestError, + std::convert::Infallible, + std::convert::Infallible, + >, + >, + > = vec![ TestCase { make_transition: Box::new({ let event = event.clone(); @@ -1176,8 +1364,12 @@ mod tests { TestCase { make_transition: Box::new(|| MaybeFatalTransition::transient(InMemoryTestError {})), expected_result: ExpectedResult::< - _, - PersistedError, + InMemoryTestState, + PersistedError< + InMemoryTestError, + std::convert::Infallible, + std::convert::Infallible, + >, > { events: vec![], is_closed: false, @@ -1213,8 +1405,18 @@ mod tests { let error_event = InMemoryTestEvent("error event".to_string()); let current_state = "Current state".to_string(); let success_value = "Success value".to_string(); - - let test_cases = vec![ + let test_cases: Vec< + TestCase< + MaybeSuccessTransitionWithNoResults< + InMemoryTestEvent, + InMemoryTestState, + InMemoryTestState, + InMemoryTestError, + >, + OptionalTransitionOutcome, + PersistedError, + >, + > = vec![ TestCase { make_transition: Box::new({ let event = event.clone(); @@ -1292,8 +1494,18 @@ mod tests { let error_event = InMemoryTestEvent("error event".to_string()); let current_state = "Current state".to_string(); let next_state = "Next state".to_string(); - - let test_cases = vec![ + let test_cases: Vec< + TestCase< + MaybeFatalTransitionWithNoResults< + InMemoryTestEvent, + InMemoryTestState, + InMemoryTestState, + InMemoryTestError, + >, + OptionalTransitionOutcome, + PersistedError, + >, + > = vec![ TestCase { make_transition: Box::new({ let event = event.clone(); @@ -1356,8 +1568,17 @@ mod tests { let event = InMemoryTestEvent("foo".to_string()); let error_event = InMemoryTestEvent("error event".to_string()); let current_state = "Current state".to_string(); - - let test_cases = vec![ + let test_cases: Vec< + TestCase< + MaybeFatalOrSuccessTransition< + InMemoryTestEvent, + InMemoryTestState, + InMemoryTestError, + >, + OptionalTransitionOutcome<(), InMemoryTestState>, + PersistedError, + >, + > = vec![ TestCase { make_transition: Box::new({ let event = event.clone(); @@ -1427,14 +1648,12 @@ mod tests { fn test_persisted_error_helpers() { let api_err = InMemoryTestError {}; - // Test Storage error case let storage_error = PersistedError::( InternalPersistedError::Storage(InMemoryTestError {}), ); assert!(storage_error.storage_error_ref().is_some()); assert!(storage_error.api_error_ref().is_none()); - // Test Internal API error cases let fatal_error = PersistedError::( InternalPersistedError::Api(ApiError::Fatal(api_err.clone())), ); diff --git a/payjoin/src/core/psbt/mod.rs b/payjoin/src/core/psbt/mod.rs index c02eaef59..ab4f7a637 100644 --- a/payjoin/src/core/psbt/mod.rs +++ b/payjoin/src/core/psbt/mod.rs @@ -1,12 +1,20 @@ //! Utilities to make work with PSBTs easier +#[cfg(not(feature = "std"))] +use alloc::boxed::Box; +#[cfg(not(feature = "std"))] +use alloc::collections::BTreeMap; +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; +use core::fmt; +#[cfg(feature = "std")] use std::collections::BTreeMap; -use std::fmt; use bitcoin::address::FromScriptError; use bitcoin::psbt::Psbt; use bitcoin::transaction::InputWeightPrediction; use bitcoin::{bip32, psbt, Address, AddressType, Network, TxIn, TxOut, Weight}; + /// Shared non-witness weight for txid (32), index (4), and sequence (4) fields. /// We only need to add the weight of the txid: 32, index: 4 and sequence: 4 as rust_bitcoin /// already accounts for the scriptsig length when calculating InputWeightPrediction @@ -28,6 +36,7 @@ impl fmt::Display for InconsistentPsbt { } } +#[cfg(feature = "std")] impl std::error::Error for InconsistentPsbt {} /// Our Psbt type for validation and utilities @@ -40,8 +49,8 @@ pub(crate) trait PsbtExt: Sized { fn proprietary_mut(&mut self) -> &mut BTreeMap>; fn unknown_mut(&mut self) -> &mut BTreeMap>; fn input_pairs(&self) -> Box> + '_>; - // guarantees that length of psbt input matches that of unsigned_tx inputs and same - /// thing for outputs. + /// guarantees that length of psbt input matches that of unsigned_tx inputs and same thing for + /// outputs. fn validate(self) -> Result; fn validate_input_utxos(&self) -> Result<(), PsbtInputsError>; } @@ -196,8 +205,8 @@ impl InternalInputPair<'_> { // redeemScript can be extracted from scriptSig for signed P2SH inputs let redeem_script = if let Some(ref script_sig) = self.psbtin.final_script_sig { script_sig.redeem_script() - // try the PSBT redeem_script field for unsigned inputs. } else { + // try the PSBT redeem_script field for unsigned inputs. self.psbtin.redeem_script.as_ref().map(|script| script.as_ref()) }; match redeem_script { @@ -248,6 +257,7 @@ impl InternalInputPair<'_> { } _ => Err(AddressTypeError::UnknownAddressType.into()), }?; + // Lengths of txid, index and sequence: (32, 4, 4). let input_weight = iwp.weight() + NON_WITNESS_INPUT_WEIGHT; Ok(input_weight) @@ -271,6 +281,7 @@ impl fmt::Display for PrevTxOutError { } } +#[cfg(feature = "std")] impl std::error::Error for PrevTxOutError {} #[derive(Debug, PartialEq, Eq)] @@ -290,16 +301,25 @@ impl fmt::Display for InternalPsbtInputError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Self::PrevTxOut(_) => write!(f, "invalid previous transaction output"), - Self::UnequalTxid => write!(f, "transaction ID of previous transaction doesn't match one specified in input spending it"), - Self::SegWitTxOutMismatch => write!(f, "transaction output provided in SegWit UTXO field doesn't match the one in non-SegWit UTXO field"), + Self::UnequalTxid => write!( + f, + "transaction ID of previous transaction doesn't match one specified in input spending it" + ), + Self::SegWitTxOutMismatch => write!( + f, + "transaction output provided in SegWit UTXO field doesn't match the one in non-SegWit UTXO field" + ), Self::AddressType(_) => write!(f, "invalid address type"), Self::InvalidScriptPubKey(e) => write!(f, "provided script was not a valid type of {e}"), Self::WeightError(e) => write!(f, "{e}"), - Self::ProvidedUnnecessaryWeight => write!(f, "weight was provided but can be calculated from available information"), + Self::ProvidedUnnecessaryWeight => { + write!(f, "weight was provided but can be calculated from available information") + } } } } +#[cfg(feature = "std")] impl std::error::Error for InternalPsbtInputError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { @@ -337,6 +357,7 @@ impl fmt::Display for PsbtInputError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self.0) } } +#[cfg(feature = "std")] impl std::error::Error for PsbtInputError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { Some(&self.0) } } @@ -353,6 +374,7 @@ impl fmt::Display for PsbtInputsError { } } +#[cfg(feature = "std")] impl std::error::Error for PsbtInputsError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { Some(&self.error) } } @@ -376,6 +398,7 @@ impl fmt::Display for AddressTypeError { } } +#[cfg(feature = "std")] impl std::error::Error for AddressTypeError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { @@ -412,6 +435,7 @@ impl fmt::Display for InputWeightError { } } +#[cfg(feature = "std")] impl std::error::Error for InputWeightError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { @@ -421,6 +445,7 @@ impl std::error::Error for InputWeightError { } } } + impl From for InputWeightError { fn from(value: AddressTypeError) -> Self { Self::AddressType(value) } } diff --git a/payjoin/src/core/receive/common/mod.rs b/payjoin/src/core/receive/common/mod.rs index c3cd1cb95..f5323e634 100644 --- a/payjoin/src/core/receive/common/mod.rs +++ b/payjoin/src/core/receive/common/mod.rs @@ -2,11 +2,18 @@ //! This module isn't meant to be exposed publicly, but for v1 and v2 //! APIs to expose as relevant typestates. -use std::cmp::{max, min}; -use std::collections::HashSet; +extern crate alloc; + +use alloc::collections::BTreeSet as HashSet; +#[cfg(not(feature = "std"))] +use alloc::vec; +use alloc::vec::Vec; +use core::cmp::{max, min}; use bitcoin::psbt::Psbt; +#[cfg(feature = "std")] use bitcoin::secp256k1::rand::seq::SliceRandom; +#[cfg(feature = "std")] use bitcoin::secp256k1::rand::{self, Rng}; use bitcoin::{Amount, FeeRate, Script, TxIn, TxOut, Weight}; use serde::{Deserialize, Serialize}; @@ -87,6 +94,7 @@ impl WantsOutputs { let mut payjoin_psbt = self.original_psbt.clone(); let mut outputs = vec![]; let mut replacement_outputs: Vec = replacement_outputs.into_iter().collect(); + #[cfg(feature = "std")] let mut rng = rand::thread_rng(); // Substitute the existing receiver outputs, keeping the sender/receiver output ordering for (i, original_output) in self.original_psbt.unsigned_tx.output.iter().enumerate() { @@ -119,7 +127,10 @@ impl WantsOutputs { .into(), ); } + #[cfg(feature = "std")] let index = rng.gen_range(0..replacement_outputs.len()); + #[cfg(not(feature = "std"))] + let index = 0; let txo = replacement_outputs.swap_remove(index); outputs.push(txo); } @@ -130,7 +141,10 @@ impl WantsOutputs { } } // Insert all remaining outputs at random indices for privacy + #[cfg(feature = "std")] interleave_shuffle(&mut outputs, &mut replacement_outputs, &mut rng); + #[cfg(not(feature = "std"))] + interleave_shuffle(&mut outputs, &mut replacement_outputs); // Identify the receiver output that will be used for change and fees let change_vout = outputs.iter().position(|txo| txo.script_pubkey == *drain_script); // Update the payjoin PSBT outputs @@ -163,15 +177,14 @@ impl WantsOutputs { /// maintaining the relative order in `original` but randomly inserting elements from `new`. /// /// The combined result replaces the contents of `original`. -fn interleave_shuffle(original: &mut Vec, new: &mut [T], rng: &mut R) { - // Shuffle the substitute_outputs +#[cfg(feature = "std")] +fn interleave_shuffle(original: &mut Vec, new: &mut [T], rng: &mut R) { new.shuffle(rng); - // Create a new vector to store the combined result + let mut combined = Vec::with_capacity(original.len() + new.len()); - // Initialize indices let mut original_index = 0; let mut new_index = 0; - // Interleave elements + while original_index < original.len() || new_index < new.len() { if original_index < original.len() && (new_index >= new.len() || rng.gen_bool(0.5)) { combined.push(original[original_index].clone()); @@ -181,9 +194,15 @@ fn interleave_shuffle(original: &mut Vec, new: &mut [ new_index += 1; } } + *original = combined; } +#[cfg(not(feature = "std"))] +fn interleave_shuffle(original: &mut Vec, new: &mut [T]) { + original.extend_from_slice(new); +} + /// Typestate for a checked proposal which the receiver may contribute inputs to. /// /// Call [`Self::commit_inputs`] to proceed. @@ -305,11 +324,15 @@ impl WantsInputs { } // Insert contributions at random indices for privacy + #[cfg(feature = "std")] let mut rng = rand::thread_rng(); let mut receiver_input_amount = Amount::ZERO; for input_pair in inputs.clone() { receiver_input_amount += input_pair.previous_txout().value; + #[cfg(feature = "std")] let index = rng.gen_range(0..=self.payjoin_psbt.unsigned_tx.input.len()); + #[cfg(not(feature = "std"))] + let index = self.payjoin_psbt.unsigned_tx.input.len(); payjoin_psbt.inputs.insert(index, input_pair.psbtin); payjoin_psbt .unsigned_tx diff --git a/payjoin/src/core/receive/error.rs b/payjoin/src/core/receive/error.rs index 5d9b02083..b61977d3c 100644 --- a/payjoin/src/core/receive/error.rs +++ b/payjoin/src/core/receive/error.rs @@ -1,5 +1,12 @@ -use std::{error, fmt}; - +#[cfg(feature = "std")] +use alloc::string::String; +#[cfg(not(feature = "std"))] +use core::error; +use core::fmt; +#[cfg(feature = "std")] +use std::error; + +#[cfg(feature = "std")] use crate::error_codes::ErrorCode::{ self, NotEnoughMoney, OriginalPsbtRejected, Unavailable, VersionUnsupported, }; @@ -16,6 +23,7 @@ pub enum Error { Implementation(crate::ImplementationError), } +#[cfg(feature = "std")] impl From<&Error> for JsonReply { fn from(e: &Error) -> Self { match e { @@ -63,7 +71,7 @@ pub enum ProtocolError { /// Protocol-specific errors for BIP-78 v1 requests (e.g. HTTP request validation, parameter checks) #[cfg(feature = "v1")] V1(crate::receive::v1::RequestError), - #[cfg(feature = "v2")] + #[cfg(feature = "v2-std")] /// V2-specific errors that are infeasable to reply to the sender V2(crate::receive::v2::SessionError), } @@ -77,6 +85,7 @@ pub enum ProtocolError { /// "message": "Human readable error message" /// } /// ``` +#[cfg(feature = "std")] #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub struct JsonReply { /// The error code @@ -87,6 +96,7 @@ pub struct JsonReply { extra: serde_json::Map, } +#[cfg(feature = "std")] impl JsonReply { /// Create a new Reply pub(crate) fn new(error_code: ErrorCode, message: impl fmt::Display) -> Self { @@ -110,17 +120,18 @@ impl JsonReply { } /// Get the HTTP status code for the error + #[cfg(any(feature = "v1", feature = "v2-std"))] pub fn status_code(&self) -> u16 { match self.error_code { - ErrorCode::Unavailable => http::StatusCode::INTERNAL_SERVER_ERROR, + ErrorCode::Unavailable => 500, ErrorCode::NotEnoughMoney | ErrorCode::VersionUnsupported - | ErrorCode::OriginalPsbtRejected => http::StatusCode::BAD_REQUEST, + | ErrorCode::OriginalPsbtRejected => 400, } - .as_u16() } } +#[cfg(feature = "std")] impl From<&ProtocolError> for JsonReply { fn from(e: &ProtocolError) -> Self { use ProtocolError::*; @@ -128,7 +139,7 @@ impl From<&ProtocolError> for JsonReply { OriginalPayload(e) => e.into(), #[cfg(feature = "v1")] V1(e) => JsonReply::new(OriginalPsbtRejected, e), - #[cfg(feature = "v2")] + #[cfg(feature = "v2-std")] V2(_) => JsonReply::new(Unavailable, "Receiver error"), } } @@ -140,19 +151,20 @@ impl fmt::Display for ProtocolError { Self::OriginalPayload(e) => e.fmt(f), #[cfg(feature = "v1")] Self::V1(e) => e.fmt(f), - #[cfg(feature = "v2")] + #[cfg(feature = "v2-std")] Self::V2(e) => e.fmt(f), } } } impl error::Error for ProtocolError { + #[cfg(feature = "std")] fn source(&self) -> Option<&(dyn error::Error + 'static)> { match &self { Self::OriginalPayload(e) => e.source(), #[cfg(feature = "v1")] Self::V1(e) => e.source(), - #[cfg(feature = "v2")] + #[cfg(feature = "v2-std")] Self::V2(e) => e.source(), } } @@ -185,10 +197,12 @@ impl From for PayloadError { } #[derive(Debug)] +#[allow(dead_code)] pub(crate) enum InternalPayloadError { /// The payload is not valid utf-8 - Utf8(std::str::Utf8Error), + Utf8(core::str::Utf8Error), /// The payload is not a valid PSBT + #[cfg(feature = "std")] ParsePsbt(bitcoin::psbt::PsbtParseError), /// Invalid sender parameters SenderParams(super::optional_parameters::Error), @@ -217,6 +231,7 @@ pub(crate) enum InternalPayloadError { FeeTooHigh(bitcoin::FeeRate, bitcoin::FeeRate), } +#[cfg(feature = "std")] impl From<&PayloadError> for JsonReply { fn from(e: &PayloadError) -> Self { use InternalPayloadError::*; @@ -258,6 +273,7 @@ impl fmt::Display for InternalPayloadError { match &self { Utf8(e) => write!(f, "{e}"), + #[cfg(feature = "std")] ParsePsbt(e) => write!(f, "{e}"), SenderParams(e) => write!(f, "{e}"), InconsistentPsbt(e) => write!(f, "{e}"), @@ -278,6 +294,7 @@ impl fmt::Display for InternalPayloadError { } } +#[cfg(feature = "std")] impl std::error::Error for PayloadError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { use InternalPayloadError::*; @@ -335,8 +352,8 @@ impl From for OutputSubstitutionError { fn from(value: InternalOutputSubstitutionError) -> Self { OutputSubstitutionError(value) } } -impl std::error::Error for OutputSubstitutionError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { +impl error::Error for OutputSubstitutionError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { match &self.0 { InternalOutputSubstitutionError::DecreasedValueWhenDisabled => None, InternalOutputSubstitutionError::ScriptPubKeyChangedWhenDisabled => None, @@ -431,7 +448,7 @@ impl From for InputContributionError { fn from(value: InternalInputContributionError) -> Self { InputContributionError(value) } } -#[cfg(test)] +#[cfg(all(test, feature = "std"))] mod tests { use super::*; use crate::ImplementationError; diff --git a/payjoin/src/core/receive/mod.rs b/payjoin/src/core/receive/mod.rs index a2be3bc0e..7cdab1767 100644 --- a/payjoin/src/core/receive/mod.rs +++ b/payjoin/src/core/receive/mod.rs @@ -8,17 +8,25 @@ //! //! If you specifically need to use //! version 1, refer to the `receive::v1` module documentation after enabling the `v1` feature. +#![allow(unused_imports)] -use std::collections::BTreeMap; -use std::str::FromStr; +use alloc::collections::BTreeMap; +use alloc::vec::Vec; +#[cfg(not(feature = "std"))] +use alloc::{format, vec}; +use core::str::FromStr; + +pub mod common; use bitcoin::transaction::InputWeightPrediction; use bitcoin::{ psbt, AddressType, FeeRate, OutPoint, Psbt, Script, ScriptBuf, Transaction, TxIn, TxOut, Weight, }; pub(crate) use error::InternalPayloadError; +#[cfg(feature = "std")] +pub use error::JsonReply; pub use error::{ - Error, InputContributionError, JsonReply, OutputSubstitutionError, PayloadError, ProtocolError, + Error, InputContributionError, OutputSubstitutionError, PayloadError, ProtocolError, SelectionError, }; use optional_parameters::Params; @@ -37,7 +45,6 @@ const DEFAULT_SIGHASH_KEY_SPEND_INPUT_WEIGHT: Weight = Weight::from_wu( + NON_WITNESS_INPUT_WEIGHT.to_wu(), ); -pub(crate) mod common; mod error; pub(crate) mod optional_parameters; @@ -47,6 +54,7 @@ pub mod v1; #[cfg(feature = "v2")] #[cfg_attr(docsrs, doc(cfg(feature = "v2")))] +#[cfg(feature = "v2-std")] pub mod v2; /// A pair of ([`TxIn`], [`psbt::Input`]) with some built-in validation. @@ -229,6 +237,8 @@ impl<'a> From<&'a InputPair> for InternalInputPair<'a> { } /// Validate the payload of a Payjoin request for PSBT and Params sanity +#[allow(dead_code)] +#[cfg(any(feature = "v1", feature = "v2-std"))] pub(crate) fn parse_payload( base64: &str, query: &str, @@ -352,6 +362,7 @@ pub struct OriginalPayload { impl OriginalPayload { // Calculates the fee rate of the original proposal PSBT. + #[cfg(feature = "std")] fn psbt_fee_rate(&self) -> Result { let original_psbt_fee = self.psbt.fee().map_err(|e| { InternalPayloadError::ParsePsbt(bitcoin::psbt::PsbtParseError::PsbtEncoding(e)) @@ -359,6 +370,7 @@ impl OriginalPayload { Ok(original_psbt_fee / self.psbt.clone().extract_tx_unchecked_fee_rate().weight()) } + #[cfg(feature = "std")] pub fn check_broadcast_suitability( &self, min_fee_rate: Option, @@ -395,7 +407,7 @@ impl OriginalPayload { .psbt .input_pairs() .scan(&mut err, |err, input| match input.previous_txout() { - Ok(txout) => Some(txout.script_pubkey.to_owned()), + Ok(txout) => Some(txout.script_pubkey.clone()), Err(e) => { **err = Err(InternalPayloadError::PrevTxOut(e).into()); None diff --git a/payjoin/src/core/receive/optional_parameters.rs b/payjoin/src/core/receive/optional_parameters.rs index 0da46e00a..133924c3f 100644 --- a/payjoin/src/core/receive/optional_parameters.rs +++ b/payjoin/src/core/receive/optional_parameters.rs @@ -1,5 +1,10 @@ -use std::borrow::Borrow; -use std::fmt; +use alloc::string::String; +use core::borrow::Borrow; +#[cfg(not(feature = "std"))] +use core::error; +use core::fmt; +#[cfg(feature = "std")] +use std::error; use bitcoin::FeeRate; use tracing::warn; @@ -99,7 +104,7 @@ impl Params { // TODO Parse with serde when rust-bitcoin supports it let fee_rate_sat_per_kwu = fee_rate_sat_per_vb * 250.0_f32; // since it's a minimum, we want to round up - FeeRate::from_sat_per_kwu(fee_rate_sat_per_kwu.ceil() as u64) + FeeRate::from_sat_per_kwu((fee_rate_sat_per_kwu + 0.9999) as u64) } Err(_) => return Err(Error::FeeRate), }, @@ -138,8 +143,8 @@ impl fmt::Display for Error { } } -impl std::error::Error for Error { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { None } +impl error::Error for Error { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { None } } #[cfg(test)] diff --git a/payjoin/src/core/receive/v1/error.rs b/payjoin/src/core/receive/v1/error.rs index ad438230f..eec90e08a 100644 --- a/payjoin/src/core/receive/v1/error.rs +++ b/payjoin/src/core/receive/v1/error.rs @@ -1,4 +1,7 @@ +#[cfg(not(feature = "std"))] +use core::error; use core::fmt; +#[cfg(feature = "std")] use std::error; /// Error that occurs during validation of an incoming v1 payjoin request. diff --git a/payjoin/src/core/receive/v2/error.rs b/payjoin/src/core/receive/v2/error.rs index ec9062c51..2dacefdea 100644 --- a/payjoin/src/core/receive/v2/error.rs +++ b/payjoin/src/core/receive/v2/error.rs @@ -1,7 +1,8 @@ -use core::fmt; -use std::error; +use core::{error, fmt}; +#[cfg(feature = "v2-std")] use crate::hpke::HpkeError; +#[cfg(feature = "v2-std")] use crate::ohttp::{DirectoryResponseError, OhttpEncapsulationError}; use crate::receive::error::Error; use crate::receive::ProtocolError; @@ -27,13 +28,17 @@ pub(crate) enum InternalSessionError { /// Url parsing failed ParseUrl(crate::into_url::Error), /// The session has expired + #[allow(dead_code)] Expired(Time), /// OHTTP Encapsulation failed OhttpEncapsulation(OhttpEncapsulationError), /// Hybrid Public Key Encryption failed Hpke(HpkeError), /// The directory returned a bad response + #[allow(dead_code)] DirectoryResponse(DirectoryResponseError), + #[allow(dead_code)] + Implementation(crate::error::ImplementationError), } impl From for Error { @@ -56,6 +61,7 @@ impl fmt::Display for SessionError { OhttpEncapsulation(e) => write!(f, "OHTTP Encapsulation Error: {e}"), Hpke(e) => write!(f, "Hpke decryption failed: {e}"), DirectoryResponse(e) => write!(f, "Directory response error: {e}"), + Implementation(e) => write!(f, "Implementation error: {e}"), } } } @@ -70,6 +76,7 @@ impl error::Error for SessionError { OhttpEncapsulation(e) => Some(e), Hpke(e) => Some(e), DirectoryResponse(e) => Some(e), + Implementation(e) => Some(e), } } } diff --git a/payjoin/src/core/receive/v2/mod.rs b/payjoin/src/core/receive/v2/mod.rs index 307cd6a50..6034cf720 100644 --- a/payjoin/src/core/receive/v2/mod.rs +++ b/payjoin/src/core/receive/v2/mod.rs @@ -24,9 +24,14 @@ //! Note: Even fresh requests may be linkable via metadata (e.g. client IP, request timing), //! but request reuse makes correlation trivial for the relay. -use std::str::FromStr; +use alloc::boxed::Box; +use alloc::string::{String, ToString}; +use alloc::vec::Vec; +#[cfg(not(feature = "std"))] +use alloc::{format, vec}; +use core::str::FromStr; #[cfg(not(target_arch = "wasm32"))] -use std::time::Duration; +use core::time::Duration; use bitcoin::hashes::{sha256, Hash}; use bitcoin::psbt::Psbt; @@ -35,23 +40,27 @@ pub(crate) use error::InternalSessionError; pub use error::SessionError; use serde::de::Deserializer; use serde::{Deserialize, Serialize}; -pub use session::{ - replay_event_log, replay_event_log_async, SessionEvent, SessionHistory, SessionOutcome, - SessionStatus, -}; +#[cfg(feature = "std")] +pub use session::replay_event_log_async; +pub use session::{replay_event_log, SessionEvent, SessionHistory, SessionOutcome, SessionStatus}; use url::Url; + +#[cfg(feature = "std")] +pub use super::JsonReply as ErrorReply; + +#[cfg(not(feature = "std"))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ErrorReply; #[cfg(target_arch = "wasm32")] use web_time::Duration; use super::error::{Error, InputContributionError}; -use super::{ - common, InternalPayloadError, JsonReply, OutputSubstitutionError, ProtocolError, SelectionError, -}; +use super::{common, InternalPayloadError, OutputSubstitutionError, ProtocolError, SelectionError}; use crate::error::{InternalReplayError, ReplayError}; use crate::hpke::{decrypt_message_a, encrypt_message_b, HpkeKeyPair, HpkePublicKey}; -use crate::ohttp::{ - ohttp_encapsulate, process_get_res, process_post_res, OhttpEncapsulationError, OhttpKeys, -}; +#[cfg(all(feature = "std", feature = "v2-ohttp"))] +use crate::ohttp::process_get_res; +use crate::ohttp::{ohttp_encapsulate, OhttpEncapsulationError, OhttpKeys}; use crate::output_substitution::OutputSubstitution; use crate::persist::{ MaybeFatalOrSuccessTransition, MaybeFatalTransition, MaybeFatalTransitionWithNoResults, @@ -61,14 +70,51 @@ use crate::receive::{parse_payload, InputPair, OriginalPayload, PsbtContext}; use crate::time::Time; use crate::uri::ShortId; use crate::{ImplementationError, IntoUrl, IntoUrlError, Request, Version}; - mod error; mod session; +#[allow(dead_code)] const SUPPORTED_VERSIONS: &[Version] = &[Version::One, Version::Two]; static TWENTY_FOUR_HOURS_DEFAULT_EXPIRATION: Duration = Duration::from_secs(60 * 60 * 24); +#[cfg(feature = "std")] +pub(crate) use super::JsonReply; + +#[cfg(not(feature = "std"))] +mod json_reply_placeholder { + use core::fmt; + + use serde::{Deserialize, Serialize}; + + #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] + pub struct JsonReply { + _private: (), + } + + impl JsonReply { + pub(crate) fn new( + _error_code: crate::error_codes::ErrorCode, + _message: D, + ) -> Self { + Self { _private: () } + } + + // pub fn to_json(&self) -> alloc::string::String { + // alloc::string::String::from("{}") + // } + } + + // impl From<&E> for JsonReply { + // fn from(_: &E) -> Self { + // Self { _private: () } + // } + // } +} + +#[cfg(not(feature = "std"))] +pub(crate) use json_reply_placeholder::JsonReply; + #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct SessionContext { #[serde(deserialize_with = "deserialize_address_assume_checked")] @@ -143,6 +189,7 @@ pub enum ReceiveSession { WantsFeeRange(Receiver), ProvisionalProposal(Receiver), PayjoinProposal(Receiver), + #[cfg(feature = "std")] HasReplyableError(Receiver), Monitor(Receiver), Closed(SessionOutcome), @@ -201,6 +248,7 @@ impl ReceiveSession { (_, SessionEvent::Closed(session_outcome)) => Ok(ReceiveSession::Closed(session_outcome)), + #[cfg(feature = "std")] (session, SessionEvent::GotReplyableError(error)) => Ok(ReceiveSession::HasReplyableError(Receiver { state: HasReplyableError { error_reply: error.clone() }, @@ -244,6 +292,7 @@ mod sealed { impl State for super::WantsFeeRange {} impl State for super::ProvisionalProposal {} impl State for super::PayjoinProposal {} + #[cfg(feature = "std")] impl State for super::HasReplyableError {} impl State for super::Monitor {} } @@ -362,6 +411,7 @@ impl Receiver { &self, ohttp_relay: impl IntoUrl, ) -> Result<(Request, ohttp::ClientResponse), Error> { + #[cfg(feature = "std")] if self.session_context.expiration.elapsed() { return Err(InternalSessionError::Expired(self.session_context.expiration).into()); } @@ -433,21 +483,34 @@ impl Receiver { body: &[u8], context: ohttp::ClientResponse, ) -> Result)>, ProtocolError> { - let body = match process_get_res(body, context) - .map_err(|e| ProtocolError::V2(InternalSessionError::DirectoryResponse(e).into()))? + #[cfg(all(feature = "std", feature = "v2-ohttp"))] { - Some(body) => body, - None => return Ok(None), - }; - match std::str::from_utf8(&body) { - // V1 response bodies are utf8 plaintext - Ok(response) => - Ok(Some(self.extract_proposal_from_v1(response).map(|original| (original, None))?)), - // V2 response bodies are encrypted binary - Err(_) => Ok(Some( - self.extract_proposal_from_v2(body) - .map(|(original, reply_key)| (original, Some(reply_key)))?, - )), + let body: Vec = match process_get_res(body, context) + .map_err(|e| ProtocolError::V2(InternalSessionError::DirectoryResponse(e).into()))? + { + Some(body) => body, + None => return Ok(None), + }; + + match core::str::from_utf8(&body) { + // V1 response bodies are utf8 plaintext + Ok(response) => Ok(Some( + self.extract_proposal_from_v1(response).map(|original| (original, None))?, + )), + // V2 response bodies are encrypted binary + Err(_) => Ok(Some( + self.extract_proposal_from_v2(body) + .map(|(original, reply_key)| (original, Some(reply_key)))?, + )), + } + } + + #[cfg(not(all(feature = "std", feature = "v2-ohttp")))] + { + let _ = (body, context); + Err(ProtocolError::V2( + InternalSessionError::Implementation(ImplementationError::std_required()).into(), + )) } } @@ -464,10 +527,12 @@ impl Receiver { ohttp_encapsulate(&self.session_context.ohttp_keys, "GET", fallback_target.as_str(), None) } + #[allow(dead_code)] fn extract_proposal_from_v1(self, response: &str) -> Result { self.unchecked_from_payload(response) } + #[allow(dead_code)] fn extract_proposal_from_v2( self, response: Vec, @@ -475,11 +540,12 @@ impl Receiver { let (payload_bytes, reply_key) = decrypt_message_a(&response, self.session_context.receiver_key.secret_key()) .map_err(|e| ProtocolError::V2(InternalSessionError::Hpke(e).into()))?; - let payload = std::str::from_utf8(&payload_bytes) + let payload = core::str::from_utf8(&payload_bytes) .map_err(|e| ProtocolError::OriginalPayload(InternalPayloadError::Utf8(e).into()))?; self.unchecked_from_payload(payload).map(|p| (p, reply_key)) } + #[allow(dead_code)] fn unchecked_from_payload(self, payload: &str) -> Result { let (base64, padded_query) = payload.split_once('\n').unwrap_or_default(); let query = padded_query.trim_matches('\0'); @@ -504,8 +570,9 @@ impl Receiver { } /// Build a V2 Payjoin URI from the receiver's context - pub fn pj_uri<'a>(&self) -> crate::PjUri<'a> { - pj_uri(&self.session_context, OutputSubstitution::Disabled) + #[cfg(feature = "std")] + pub fn pj_uri(&self) -> crate::core::uri::PjUri<'static> { + build_pj_uri(&self.session_context, OutputSubstitution::Disabled) } pub(crate) fn apply_retrieved_original_payload( @@ -561,6 +628,7 @@ impl Receiver { /// This can be used to further prevent probing attacks since the attacker would now need to probe the receiver /// with transactions which are both broadcastable and pay high fee. Unrelated to the probing attack scenario, /// this parameter also makes operating in a high fee environment easier for the receiver. + #[cfg(feature = "std")] pub fn check_broadcast_suitability( self, min_fee_rate: Option, @@ -592,6 +660,24 @@ impl Receiver { } } + #[cfg(not(feature = "std"))] + pub fn check_broadcast_suitability( + self, + min_fee_rate: Option, + can_broadcast: impl Fn(&bitcoin::Transaction) -> Result, + ) -> MaybeFatalTransition, Error> { + match self.state.original.check_broadcast_suitability(min_fee_rate, can_broadcast) { + Ok(()) => MaybeFatalTransition::success( + SessionEvent::CheckedBroadcastSuitability(), + Receiver { + state: MaybeInputsOwned { original: self.original.clone() }, + session_context: self.session_context, + }, + ), + Err(e) => MaybeFatalTransition::transient(e), + } + } + /// Moves on to the next typestate without any of the current typestate's validations. /// /// Use this for interactive payment receivers, where there is no risk of a probing attack since the @@ -1136,25 +1222,42 @@ impl Receiver { res: &[u8], ohttp_context: ohttp::ClientResponse, ) -> MaybeFatalTransition, ProtocolError> { - match process_post_res(res, ohttp_context) { - Ok(_) => MaybeFatalTransition::success( - SessionEvent::PostedPayjoinProposal(), - Receiver { - state: Monitor { psbt_context: self.state.psbt_context.clone() }, - session_context: self.session_context.clone(), - }, - ), - Err(e) => - if e.is_fatal() { - MaybeFatalTransition::fatal( - SessionEvent::Closed(SessionOutcome::Failure), - ProtocolError::V2(InternalSessionError::DirectoryResponse(e).into()), - ) - } else { - MaybeFatalTransition::transient(ProtocolError::V2( - InternalSessionError::DirectoryResponse(e).into(), - )) - }, + #[cfg(all(feature = "std", feature = "v2-ohttp"))] + { + use crate::ohttp::process_post_res; + + match process_post_res(res, ohttp_context) { + Ok(_) => MaybeFatalTransition::success( + SessionEvent::PostedPayjoinProposal(), + Receiver { + state: Monitor { psbt_context: self.state.psbt_context.clone() }, + session_context: self.session_context.clone(), + }, + ), + Err(e) => + if e.is_fatal() { + MaybeFatalTransition::fatal( + SessionEvent::Closed(SessionOutcome::Failure), + ProtocolError::V2(InternalSessionError::DirectoryResponse(e).into()), + ) + } else { + MaybeFatalTransition::transient(ProtocolError::V2( + InternalSessionError::DirectoryResponse(e).into(), + )) + }, + } + } + + #[cfg(not(all(feature = "std", feature = "v2-ohttp")))] + { + let _ = (res, ohttp_context); + MaybeFatalTransition::fatal( + SessionEvent::Closed(SessionOutcome::Failure), + ProtocolError::V2( + InternalSessionError::Implementation(ImplementationError::std_required()) + .into(), + ), + ) } } @@ -1166,7 +1269,14 @@ impl Receiver { } } +#[cfg(feature = "std")] #[derive(Debug, Clone, PartialEq)] +pub struct HasReplyableError { + error_reply: super::JsonReply, +} + +#[cfg(not(feature = "std"))] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct HasReplyableError { error_reply: JsonReply, } @@ -1179,6 +1289,7 @@ impl Receiver { ohttp_relay: impl IntoUrl, ) -> Result<(Request, ohttp::ClientResponse), SessionError> { let session_context = &self.session_context; + #[cfg(feature = "std")] if session_context.expiration.elapsed() { return Err(InternalSessionError::Expired(session_context.expiration).into()); } @@ -1213,20 +1324,38 @@ impl Receiver { res: &[u8], ohttp_context: ohttp::ClientResponse, ) -> MaybeSuccessTransition { - match process_post_res(res, ohttp_context) { - Ok(_) => - MaybeSuccessTransition::success(SessionEvent::Closed(SessionOutcome::Failure), ()), - Err(e) => - if e.is_fatal() { - MaybeSuccessTransition::fatal( - SessionEvent::Closed(SessionOutcome::Failure), - ProtocolError::V2(InternalSessionError::DirectoryResponse(e).into()), - ) - } else { - MaybeSuccessTransition::transient(ProtocolError::V2( - InternalSessionError::DirectoryResponse(e).into(), - )) - }, + #[cfg(all(feature = "std", feature = "v2-ohttp"))] + { + use crate::ohttp::process_post_res; + match process_post_res(res, ohttp_context) { + Ok(_) => MaybeSuccessTransition::success( + SessionEvent::Closed(SessionOutcome::Failure), + (), + ), + Err(e) => + if e.is_fatal() { + MaybeSuccessTransition::fatal( + SessionEvent::Closed(SessionOutcome::Failure), + ProtocolError::V2(InternalSessionError::DirectoryResponse(e).into()), + ) + } else { + MaybeSuccessTransition::transient(ProtocolError::V2( + InternalSessionError::DirectoryResponse(e).into(), + )) + }, + } + } + + #[cfg(not(all(feature = "std", feature = "v2-ohttp")))] + { + let _ = (res, ohttp_context); + return MaybeSuccessTransition::fatal( + SessionEvent::Closed(SessionOutcome::Failure), + ProtocolError::V2( + InternalSessionError::Implementation(ImplementationError::std_required()) + .into(), + ), + ); } } } @@ -1337,11 +1466,16 @@ fn mailbox_endpoint(directory: &Url, id: &ShortId) -> Url { url } -/// Gets the Payjoin URI from a session context -pub(crate) fn pj_uri<'a>( +#[cfg(feature = "std")] +pub fn pj_uri<'a>(session_context: &'a SessionContext) -> crate::PjUri<'a> { + build_pj_uri(session_context, OutputSubstitution::Disabled) +} + +#[cfg(feature = "std")] +pub(crate) fn build_pj_uri( session_context: &SessionContext, output_substitution: OutputSubstitution, -) -> crate::PjUri<'a> { +) -> crate::PjUri<'static> { use crate::uri::PayjoinExtras; let pj_param = crate::uri::PjParam::V2(crate::uri::v2::PjParam::new( session_context.directory.clone(), @@ -1353,11 +1487,11 @@ pub(crate) fn pj_uri<'a>( let extras = PayjoinExtras { pj_param, output_substitution }; let mut uri = bitcoin_uri::Uri::with_extras(session_context.address.clone(), extras); uri.amount = session_context.amount; - - uri + // SAFETY: label and message are None, so no data is actually borrowed + unsafe { core::mem::transmute::, crate::PjUri<'static>>(uri) } } -#[cfg(test)] +#[cfg(all(test, feature = "std"))] pub mod test { use std::str::FromStr; @@ -1578,6 +1712,7 @@ pub mod test { Ok(()) } + #[cfg(feature = "v1")] #[test] fn test_unchecked_proposal_fatal_error() -> Result<(), BoxError> { let persister = NoopSessionPersister::default(); @@ -1689,6 +1824,7 @@ pub mod test { Ok(()) } + #[cfg(feature = "v1")] #[test] fn test_create_error_request() -> Result<(), BoxError> { let mock_err = mock_err(); @@ -1709,6 +1845,7 @@ pub mod test { Ok(()) } + #[cfg(feature = "v1")] #[test] fn test_create_error_request_expiration() -> Result<(), BoxError> { let now = crate::time::Time::now(); @@ -1830,4 +1967,24 @@ pub mod test { let psbt = receiver.psbt_to_sign(); assert_eq!(psbt, PARSED_PAYJOIN_PROPOSAL.clone()); } + + #[cfg(not(feature = "std"))] + #[cfg(test)] + mod json_reply_placeholder_tests { + use super::json_reply_placeholder::JsonReply; + use crate::error_codes::ErrorCode; + + #[test] + fn test_json_reply_new() { + let reply = JsonReply::new(ErrorCode::Unavailable, "test"); + assert_eq!(reply.to_json(), "{}"); + } + + #[test] + fn test_json_reply_from() { + let val = 42u32; + let reply = JsonReply::from(&val); + assert_eq!(reply.to_json(), "{}"); + } + } } diff --git a/payjoin/src/core/receive/v2/session.rs b/payjoin/src/core/receive/v2/session.rs index 5476032e8..53527b76c 100644 --- a/payjoin/src/core/receive/v2/session.rs +++ b/payjoin/src/core/receive/v2/session.rs @@ -1,10 +1,18 @@ +use alloc::boxed::Box; +#[cfg(not(feature = "std"))] +use alloc::vec; +use alloc::vec::Vec; + use serde::{Deserialize, Serialize}; use super::{ReceiveSession, SessionContext}; +#[cfg(feature = "v1")] +use crate::core::OutputSubstitution; use crate::error::{InternalReplayError, ReplayError}; -use crate::output_substitution::OutputSubstitution; -use crate::persist::{AsyncSessionPersister, SessionPersister}; -use crate::receive::{InputPair, JsonReply, OriginalPayload, PsbtContext}; +#[cfg(feature = "std")] +use crate::persist::AsyncSessionPersister; +use crate::persist::SessionPersister; +use crate::receive::{InputPair, OriginalPayload, PsbtContext}; use crate::{ImplementationError, PjUri}; fn replay_events( @@ -24,19 +32,6 @@ fn replay_events( Ok((receiver, session_events)) } -fn construct_history( - session_events: Vec, -) -> Result> { - let history = SessionHistory::new(session_events); - let ctx = history.session_context(); - if ctx.expiration.elapsed() { - return Err(InternalReplayError::Expired(ctx.expiration).into()); - } - Ok(history) -} - -/// Replay a receiver event log to get the receiver in its current state [ReceiveSession] -/// and a session history [SessionHistory] pub fn replay_event_log

( persister: &P, ) -> Result<(ReceiveSession, SessionHistory), ReplayError> @@ -63,7 +58,22 @@ where Ok((receiver, history)) } +fn construct_history( + session_events: Vec, +) -> Result> { + let history = SessionHistory::new(session_events); + #[cfg(feature = "std")] + { + let ctx = history.session_context(); + if ctx.expiration.elapsed() { + return Err(InternalReplayError::Expired(ctx.expiration).into()); + } + } + Ok(history) +} + /// Async version of [replay_event_log] +#[cfg(feature = "std")] pub async fn replay_event_log_async

( persister: &P, ) -> Result<(ReceiveSession, SessionHistory), ReplayError> @@ -105,12 +115,13 @@ impl SessionHistory { } /// Receiver session Payjoin URI - pub fn pj_uri<'a>(&self) -> PjUri<'a> { + #[cfg(feature = "v1")] + pub fn pj_uri<'a>(&'a self) -> PjUri<'a> { self.events .iter() .find_map(|event| match event { SessionEvent::Created(session_context) => - Some(crate::receive::v2::pj_uri(session_context, OutputSubstitution::Disabled)), + Some(crate::receive::v2::pj_uri(session_context)), _ => None, }) .expect("Session event log must contain at least one event with pj_uri") @@ -136,6 +147,7 @@ impl SessionHistory { }) } + #[allow(dead_code)] fn session_context(&self) -> SessionContext { let mut initial_session_context = self .events @@ -156,6 +168,7 @@ impl SessionHistory { /// Helper method to query the current status of the session. pub fn status(&self) -> SessionStatus { + #[cfg(feature = "std")] if self.session_context().expiration.elapsed() { return SessionStatus::Expired; } @@ -196,7 +209,7 @@ pub enum SessionEvent { CommittedInputs(Vec), AppliedFeeRange(PsbtContext), FinalizedProposal(bitcoin::Psbt), - GotReplyableError(JsonReply), + GotReplyableError(super::JsonReply), PostedPayjoinProposal(), Closed(SessionOutcome), } @@ -222,6 +235,7 @@ pub enum SessionOutcome { mod tests { use std::time::{Duration, SystemTime}; + #[allow(unused_imports)] use payjoin_test_utils::{BoxError, EXAMPLE_URL}; use super::*; @@ -381,14 +395,19 @@ mod tests { let session_context = SessionContext { expiration, ..SHARED_CONTEXT.clone() }; let persister = InMemoryTestPersister::::default(); - persister.save_event(SessionEvent::Created(session_context.clone())); + persister + .save_event(SessionEvent::Created(session_context.clone())) + .expect("save_event should succeed"); let err = replay_event_log(&persister).expect_err("session should be expired"); let expected_err: ReplayError = InternalReplayError::Expired(expiration).into(); assert_eq!(err.to_string(), expected_err.to_string()); let persister = InMemoryAsyncTestPersister::::default(); - persister.save_event(SessionEvent::Created(session_context)).await; + persister + .save_event(SessionEvent::Created(session_context)) + .await + .expect("save_event should succeed"); let err = replay_event_log_async(&persister).await.expect_err("session should be expired"); let expected_err: ReplayError = InternalReplayError::Expired(expiration).into(); @@ -398,7 +417,9 @@ mod tests { #[tokio::test] async fn test_replaying_session_with_missing_created_event() { let persister = InMemoryTestPersister::::default(); - persister.save_event(SessionEvent::CheckedBroadcastSuitability()); + persister + .save_event(SessionEvent::CheckedBroadcastSuitability()) + .expect("save_event should succeed"); assert!(!persister.inner.read().expect("session read should succeed").is_closed); let err = replay_event_log(&persister).expect_err("session replay should be fail"); let expected_err: ReplayError = @@ -411,7 +432,10 @@ mod tests { assert!(persister.inner.read().expect("lock should not be poisoned").is_closed); let persister = InMemoryAsyncTestPersister::::default(); - persister.save_event(SessionEvent::CheckedBroadcastSuitability()).await; + persister + .save_event(SessionEvent::CheckedBroadcastSuitability()) + .await + .expect("save_event should succeed"); assert!(!persister.inner.read().await.is_closed); let err = replay_event_log_async(&persister).await.expect_err("session replay should be fail"); @@ -738,13 +762,16 @@ mod tests { } #[test] + #[cfg(feature = "v1")] fn test_session_history_uri() -> Result<(), BoxError> { let session_context = SHARED_CONTEXT.clone(); let events = vec![SessionEvent::Created(session_context.clone())]; - let uri = SessionHistory { events }.pj_uri(); + let binding = SessionHistory { events }; + let uri = binding.pj_uri(); assert_ne!(uri.extras.pj_param.endpoint().as_str(), EXAMPLE_URL); + #[cfg(feature = "v1")] assert_eq!(uri.extras.output_substitution, OutputSubstitution::Disabled); Ok(()) diff --git a/payjoin/src/core/request.rs b/payjoin/src/core/request.rs index b4db22acb..b511a4e67 100644 --- a/payjoin/src/core/request.rs +++ b/payjoin/src/core/request.rs @@ -1,4 +1,11 @@ +#![allow(unused_imports)] +use alloc::string::String; +use alloc::vec::Vec; + +#[cfg(any(feature = "v1", feature = "v2-std"))] use url::Url; + +use crate::alloc::string::ToString; #[cfg(feature = "v1")] const V1_REQ_CONTENT_TYPE: &str = "text/plain"; @@ -34,7 +41,7 @@ impl Request { } /// Construct a new v2 request. - #[cfg(feature = "v2")] + #[cfg(feature = "v2-std")] pub(crate) fn new_v2( url: &Url, body: &[u8; crate::directory::ENCAPSULATED_MESSAGE_BYTES], diff --git a/payjoin/src/core/send/error.rs b/payjoin/src/core/send/error.rs index 7e70286b5..8d3e25993 100644 --- a/payjoin/src/core/send/error.rs +++ b/payjoin/src/core/send/error.rs @@ -1,9 +1,18 @@ -use std::fmt; -use std::str::FromStr; - +use alloc::string::String; +use alloc::vec::Vec; +#[cfg(not(feature = "std"))] +use core::error; +use core::fmt; +#[cfg(feature = "std")] +mod imports { + pub use core::str::FromStr; + pub use std::error; +} use bitcoin::locktime::absolute::LockTime; use bitcoin::transaction::Version; use bitcoin::Sequence; +#[cfg(feature = "std")] +use imports::*; use crate::error_codes::ErrorCode; @@ -62,6 +71,7 @@ impl fmt::Display for BuildSenderError { } } +#[cfg(any(feature = "v1", feature = "v2-std"))] impl std::error::Error for BuildSenderError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { use InternalBuildSenderError::*; @@ -92,6 +102,7 @@ impl std::error::Error for BuildSenderError { pub struct ValidationError(InternalValidationError); #[derive(Debug)] +#[allow(dead_code)] pub(crate) enum InternalValidationError { Parse, #[cfg(feature = "v1")] @@ -131,6 +142,7 @@ impl fmt::Display for ValidationError { } } +#[cfg(feature = "std")] impl std::error::Error for ValidationError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { use InternalValidationError::*; @@ -216,6 +228,7 @@ impl fmt::Display for InternalProposalError { } } +#[cfg(feature = "std")] impl std::error::Error for InternalProposalError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { use InternalProposalError::*; @@ -271,6 +284,9 @@ pub enum ResponseError { Unrecognized { error_code: String, message: String }, } +#[cfg(not(feature = "std"))] +impl core::error::Error for ResponseError {} +#[cfg(feature = "std")] impl std::error::Error for ResponseError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { use ResponseError::*; @@ -309,26 +325,49 @@ impl fmt::Debug for ResponseError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Self::WellKnown(e) => { - let json = serde_json::json!({ - "errorCode": e.code.to_string(), - "message": e.message - }); - write!(f, "Well known error: {json}") + #[cfg(feature = "std")] + { + let json = serde_json::json!({ + "errorCode": e.code.to_string(), + "message": e.message + }); + write!(f, "Well known error: {json}") + } + + #[cfg(not(feature = "std"))] + { + write!(f, "Well known error: code={}, message={}", e.code, e.message) + } } Self::Validation(e) => write!(f, "Validation({e:?})"), Self::Unrecognized { error_code, message } => { - let json = serde_json::json!({ - "errorCode": error_code, - "message": message - }); - write!(f, "Unrecognized error: {json}") + #[cfg(feature = "std")] + { + let json = serde_json::json!({ + "errorCode": error_code, + "message": message + }); + write!(f, "Unrecognized error: {json}") + } + #[cfg(not(feature = "std"))] + { + write!(f, "Unrecognized error: code={}, message={}", error_code, message) + } } } } } impl ResponseError { + #[cfg(feature = "std")] + pub fn from_slice(body: &[u8]) -> Result { + let trimmed = body.split(|&byte| byte == 0).next().unwrap_or(body); + let json: serde_json::Value = serde_json::from_slice(trimmed)?; + Ok(Self::from_json(json)) + } + + #[cfg(feature = "std")] pub(crate) fn from_json(json: serde_json::Value) -> Self { let message = json .as_object() @@ -356,6 +395,12 @@ impl ResponseError { None => InternalValidationError::Parse.into(), } } + + #[cfg(any(feature = "v1", test))] + pub(crate) fn parse_from_str(s: &str) -> Result { + let json: serde_json::Value = serde_json::from_str(s)?; + Ok(Self::from_json(json)) + } } /// A well-known error that can be safely displayed to end users. @@ -366,7 +411,7 @@ pub struct WellKnownError { pub(crate) supported_versions: Option>, } -impl std::error::Error for WellKnownError {} +impl error::Error for WellKnownError {} impl core::fmt::Display for WellKnownError { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { @@ -389,6 +434,7 @@ impl From for ResponseError { fn from(value: WellKnownError) -> Self { Self::WellKnown(value) } } +#[allow(dead_code)] impl WellKnownError { /// Create a new well-known error with the given code and message. pub(crate) fn new(code: ErrorCode, message: String) -> Self { @@ -410,8 +456,8 @@ mod tests { #[test] fn test_parse_json() { let known_str_error = r#"{"errorCode":"version-unsupported", "message":"custom message here", "supported": [1, 2]}"#; - match ResponseError::parse(known_str_error) { - ResponseError::WellKnown(e) => { + match ResponseError::parse_from_str(known_str_error) { + Ok(ResponseError::WellKnown(e)) => { assert_eq!(e.code, ErrorCode::VersionUnsupported); assert_eq!(e.message, "custom message here"); assert_eq!( @@ -419,13 +465,16 @@ mod tests { "This version of payjoin is not supported. Use version [1, 2]." ); } - _ => panic!("Expected WellKnown error"), - }; + Ok(_) => panic!("Expected WellKnown error"), + Err(err) => panic!("Expected valid JSON, got error: {err}"), + } + let unrecognized_error = r#"{"errorCode":"random", "message":"random"}"#; assert!(matches!( - ResponseError::parse(unrecognized_error), - ResponseError::Unrecognized { .. } + ResponseError::parse_from_str(unrecognized_error), + Ok(ResponseError::Unrecognized { .. }) )); + let invalid_json_error = json!({ "err": "random", "message": "This version of payjoin is not supported." diff --git a/payjoin/src/core/send/mod.rs b/payjoin/src/core/send/mod.rs index 276d928b4..a5c727810 100644 --- a/payjoin/src/core/send/mod.rs +++ b/payjoin/src/core/send/mod.rs @@ -16,14 +16,21 @@ //! Note: Even fresh requests may be linkable via metadata (e.g. client IP, request timing), //! but request reuse makes correlation trivial for the relay. +#[cfg(any(feature = "v1", feature = "v2-std"))] +use alloc::string::ToString; +use alloc::vec::Vec; + use bitcoin::psbt::Psbt; use bitcoin::{Amount, FeeRate, Script, ScriptBuf, TxOut, Weight}; pub use error::{BuildSenderError, ResponseError, ValidationError, WellKnownError}; +#[allow(unused_imports)] pub(crate) use error::{InternalBuildSenderError, InternalProposalError, InternalValidationError}; +#[cfg(any(feature = "v1", feature = "v2-std"))] use url::Url; use crate::output_substitution::OutputSubstitution; use crate::psbt::{AddressTypeError, PsbtExt, NON_WITNESS_INPUT_WEIGHT}; +#[cfg(any(feature = "v1", feature = "v2-std"))] use crate::Version; // See usize casts @@ -40,6 +47,7 @@ pub mod v1; #[cfg_attr(docsrs, doc(cfg(feature = "v2")))] pub mod v2; +#[allow(dead_code)] type InternalResult = Result; /// A builder to construct the properties of a `PsbtContext`. @@ -88,7 +96,7 @@ impl PsbtContextBuilder { ) -> Result { // TODO support optional batched payout scripts. This would require a change to // build() which now checks for a single payee. - let mut payout_scripts = std::iter::once(self.payee.clone()); + let mut payout_scripts = core::iter::once(self.payee.clone()); // Check if the PSBT is a sweep transaction with only one output that's a payout script and no change if self.psbt.unsigned_tx.output.len() == 1 @@ -245,6 +253,7 @@ macro_rules! check_eq { }; } +#[allow(dead_code)] fn ensure(condition: bool, error: T) -> Result<(), T> { if !condition { return Err(error); @@ -252,6 +261,7 @@ fn ensure(condition: bool, error: T) -> Result<(), T> { Ok(()) } +#[allow(dead_code)] impl PsbtContext { fn process_proposal(self, mut proposal: Psbt) -> InternalResult { self.basic_checks(&proposal)?; @@ -646,6 +656,7 @@ fn determine_fee_contribution( }) } +#[cfg(any(feature = "v1", feature = "v2-std"))] fn serialize_url( endpoint: Url, output_substitution: OutputSubstitution, @@ -673,6 +684,8 @@ fn serialize_url( #[cfg(test)] mod test { + #![allow(unused_imports)] + use bitcoin::absolute::LockTime; use bitcoin::bip32::{DerivationPath, Fingerprint}; use bitcoin::ecdsa::Signature; @@ -684,9 +697,11 @@ mod test { BoxError, PARSED_ORIGINAL_PSBT, PARSED_PAYJOIN_PROPOSAL, PARSED_PAYJOIN_PROPOSAL_WITH_SENDER_INFO, }; + #[cfg(feature = "v2-std")] use url::Url; use super::*; + // use crate::core::OutputSubstitution; use crate::output_substitution::OutputSubstitution; use crate::psbt::PsbtExt; use crate::send::{AdditionalFeeContribution, InternalBuildSenderError, InternalProposalError}; @@ -706,6 +721,7 @@ mod test { }) } + #[cfg(feature = "v1")] #[test] fn test_restore_original_utxos() -> Result<(), BoxError> { let mut original_psbt = PARSED_ORIGINAL_PSBT.clone(); @@ -738,6 +754,7 @@ mod test { Ok(()) } + #[cfg(feature = "v1")] #[test] fn test_restore_original_outputs() -> Result<(), BoxError> { let mut original_psbt = PARSED_ORIGINAL_PSBT.clone(); diff --git a/payjoin/src/core/send/v1.rs b/payjoin/src/core/send/v1.rs index 874c95001..36357bff3 100644 --- a/payjoin/src/core/send/v1.rs +++ b/payjoin/src/core/send/v1.rs @@ -21,16 +21,17 @@ //! [`bitmask-core`](https://github.com/diba-io/bitmask-core) BDK integration. Bring your own //! wallet and http client. -use std::str::FromStr; +use core::str::FromStr; use bitcoin::psbt::Psbt; use bitcoin::{Address, Amount, FeeRate}; use error::BuildSenderError; +use url::Url; use super::*; pub use crate::output_substitution::OutputSubstitution; use crate::uri::v1::PjParam; -use crate::{PjUri, Request, MAX_CONTENT_LENGTH}; +use crate::{PjUri, Request, Version, MAX_CONTENT_LENGTH}; /// A builder to construct the properties of a `Sender`. #[derive(Clone)] @@ -212,11 +213,14 @@ impl V1Context { } let res_str = std::str::from_utf8(response).map_err(|_| InternalValidationError::Parse)?; - let proposal = Psbt::from_str(res_str).map_err(|_| ResponseError::parse(res_str))?; + let proposal = Psbt::from_str(res_str).map_err(|_| { + ResponseError::parse_from_str(res_str) + .unwrap_or_else(|_| InternalValidationError::Parse.into()) + })?; self.psbt_context.process_proposal(proposal).map_err(Into::into) } } - +/* impl ResponseError { /// Parse a response from the receiver. /// @@ -228,7 +232,7 @@ impl ResponseError { } } } - +*/ #[cfg(test)] mod test { use std::collections::BTreeMap; diff --git a/payjoin/src/core/send/v2/error.rs b/payjoin/src/core/send/v2/error.rs index 09fbf7266..bcff2cf0e 100644 --- a/payjoin/src/core/send/v2/error.rs +++ b/payjoin/src/core/send/v2/error.rs @@ -1,5 +1,10 @@ +#[cfg(not(feature = "std"))] +use core::error; use core::fmt; +#[cfg(feature = "std")] +use std::error; +#[cfg(feature = "v2-std")] use crate::ohttp::DirectoryResponseError; use crate::time::Time; @@ -13,10 +18,16 @@ pub struct CreateRequestError(InternalCreateRequestError); #[derive(Debug)] pub(crate) enum InternalCreateRequestError { + #[cfg(feature = "v2-std")] Url(crate::into_url::Error), + #[cfg(feature = "v2-std")] Hpke(crate::hpke::HpkeError), + #[cfg(feature = "v2-std")] OhttpEncapsulation(crate::ohttp::OhttpEncapsulationError), + #[allow(dead_code)] Expired(Time), + #[allow(dead_code)] + Implementation(crate::error::ImplementationError), } impl fmt::Display for CreateRequestError { @@ -24,23 +35,31 @@ impl fmt::Display for CreateRequestError { use InternalCreateRequestError::*; match &self.0 { + #[cfg(feature = "v2-std")] Url(e) => write!(f, "cannot parse url: {e:#?}"), + #[cfg(feature = "v2-std")] Hpke(e) => write!(f, "v2 error: {e}"), + #[cfg(feature = "v2-std")] OhttpEncapsulation(e) => write!(f, "v2 error: {e}"), Expired(_expiration) => write!(f, "session expired"), + Implementation(e) => write!(f, "implementation error: {e}"), } } } -impl std::error::Error for CreateRequestError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { +impl error::Error for CreateRequestError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { use InternalCreateRequestError::*; match &self.0 { + #[cfg(feature = "v2-std")] Url(error) => Some(error), + #[cfg(feature = "v2-std")] Hpke(error) => Some(error), + #[cfg(feature = "v2-std")] OhttpEncapsulation(error) => Some(error), Expired(_) => None, + Implementation(e) => Some(e), } } } @@ -49,6 +68,7 @@ impl From for CreateRequestError { fn from(value: InternalCreateRequestError) -> Self { CreateRequestError(value) } } +#[cfg(feature = "v2-std")] impl From for CreateRequestError { fn from(value: crate::into_url::Error) -> Self { CreateRequestError(InternalCreateRequestError::Url(value)) @@ -60,11 +80,15 @@ impl From for CreateRequestError { pub struct EncapsulationError(InternalEncapsulationError); #[derive(Debug)] +#[allow(dead_code)] pub(crate) enum InternalEncapsulationError { /// The HPKE failed. + #[cfg(feature = "v2-std")] Hpke(crate::hpke::HpkeError), /// The directory returned a bad response + #[cfg(feature = "v2-std")] DirectoryResponse(DirectoryResponseError), + Implementation(crate::error::ImplementationError), } impl fmt::Display for EncapsulationError { @@ -72,19 +96,25 @@ impl fmt::Display for EncapsulationError { use InternalEncapsulationError::*; match &self.0 { + #[cfg(feature = "v2-std")] Hpke(error) => write!(f, "HPKE error: {error}"), + #[cfg(feature = "v2-std")] DirectoryResponse(e) => write!(f, "Directory response error: {e}"), + Implementation(e) => write!(f, "implementation error: {e}"), } } } -impl std::error::Error for EncapsulationError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { +impl error::Error for EncapsulationError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { use InternalEncapsulationError::*; match &self.0 { + #[cfg(feature = "v2-std")] Hpke(error) => Some(error), + #[cfg(feature = "v2-std")] DirectoryResponse(e) => Some(e), + Implementation(e) => Some(e), } } } diff --git a/payjoin/src/core/send/v2/mod.rs b/payjoin/src/core/send/v2/mod.rs index 7a36553e3..cf7ffa0a4 100644 --- a/payjoin/src/core/send/v2/mod.rs +++ b/payjoin/src/core/send/v2/mod.rs @@ -27,30 +27,48 @@ //! as it allows the relay to correlate requests by comparing ciphertexts. //! Note: Even fresh requests may be linkable via metadata (e.g. client IP, request timing), //! but request reuse makes correlation trivial for the relay. +#![allow(unused_imports)] + +use alloc::boxed::Box; +#[cfg(not(feature = "std"))] +use alloc::format; +use alloc::string::{String, ToString}; +use alloc::vec::Vec; use bitcoin::hashes::{sha256, Hash}; use bitcoin::Address; pub use error::{CreateRequestError, EncapsulationError}; use error::{InternalCreateRequestError, InternalEncapsulationError}; +#[cfg(feature = "v2-std")] use ohttp::ClientResponse; use serde::{Deserialize, Serialize}; +#[cfg(feature = "v2-std")] pub use session::{ replay_event_log, replay_event_log_async, SessionEvent, SessionHistory, SessionOutcome, SessionStatus, }; +#[cfg(feature = "v2-std")] use url::Url; use super::error::BuildSenderError; use super::*; +#[cfg(feature = "std")] +use crate::core::uri::PjUri; use crate::error::{InternalReplayError, ReplayError}; -use crate::hpke::{decrypt_message_b, encrypt_message_a, HpkeSecretKey}; +#[cfg(feature = "v2-std")] +use crate::hpke::decrypt_message_b; +#[cfg(feature = "v2-std")] +use crate::hpke::{encrypt_message_a, HpkeSecretKey}; +#[cfg(feature = "v2-std")] use crate::ohttp::{ohttp_encapsulate, process_get_res, process_post_res}; use crate::persist::{ MaybeFatalTransition, MaybeSuccessTransitionWithNoResults, NextStateTransition, }; +#[cfg(feature = "v2-std")] use crate::uri::v2::PjParam; use crate::uri::ShortId; -use crate::{HpkeKeyPair, IntoUrl, PjUri, Request}; +#[cfg(feature = "v2-std")] +use crate::{HpkeKeyPair, IntoUrl, Request}; mod error; mod session; @@ -60,6 +78,7 @@ mod session; /// This is because all communications with the receiver are end-to-end authenticated. So a /// malicious man in the middle can't substitute outputs, only the receiver can. /// The receiver can always choose not to substitute outputs, however. +#[cfg(feature = "v2-std")] #[derive(Clone)] pub struct SenderBuilder { pj_param: crate::uri::v2::PjParam, @@ -67,11 +86,13 @@ pub struct SenderBuilder { psbt_ctx_builder: PsbtContextBuilder, } +#[cfg(feature = "v2-std")] impl SenderBuilder { /// Prepare the context from which to make Sender requests /// /// Call [`SenderBuilder::build_recommended()`] or other `build` methods /// to create a [`Sender`] + #[cfg(feature = "std")] pub fn new(psbt: Psbt, uri: PjUri) -> Self { match uri.extras.pj_param { #[cfg(feature = "v1")] @@ -193,12 +214,14 @@ mod sealed { /// can implement this trait, ensuring type safety and protocol integrity. pub trait State: sealed::State {} +#[cfg(feature = "v2-std")] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct Sender { pub(crate) state: State, pub(crate) session_context: SessionContext, } +#[cfg(feature = "v2-std")] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct SessionContext { /// The endpoint in the Payjoin URI @@ -209,6 +232,7 @@ pub struct SessionContext { pub(crate) reply_key: HpkeSecretKey, } +#[cfg(feature = "v2-std")] impl SessionContext { fn full_relay_url(&self, ohttp_relay: impl IntoUrl) -> Result { let relay_base = ohttp_relay.into_url().map_err(InternalCreateRequestError::Url)?; @@ -227,16 +251,19 @@ impl SessionContext { } } +#[cfg(feature = "v2-std")] impl core::ops::Deref for Sender { type Target = State; fn deref(&self) -> &Self::Target { &self.state } } +#[cfg(feature = "v2-std")] impl core::ops::DerefMut for Sender { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.state } } +#[cfg(feature = "v2-std")] impl Sender { /// The endpoint in the Payjoin URI pub fn endpoint(&self) -> String { self.session_context.pj_param.endpoint().to_string() } @@ -246,6 +273,7 @@ impl Sender { /// /// This provides type erasure for the send session state, allowing the session to be replayed /// and the state to be updated with the next event over a uniform interface. +#[cfg(feature = "v2-std")] #[derive(Debug, Clone, PartialEq, Eq)] pub enum SendSession { WithReplyKey(Sender), @@ -253,6 +281,7 @@ pub enum SendSession { Closed(SessionOutcome), } +#[cfg(feature = "v2-std")] impl SendSession { fn new(session_context: SessionContext) -> Self { SendSession::WithReplyKey(Sender { state: WithReplyKey, session_context }) @@ -284,6 +313,7 @@ impl SendSession { #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct WithReplyKey; +#[cfg(feature = "v2-std")] impl Sender { fn new(pj_param: PjParam, psbt_ctx: PsbtContext) -> Self { Sender { @@ -309,6 +339,7 @@ impl Sender { &self, ohttp_relay: impl IntoUrl, ) -> Result<(Request, ClientResponse), CreateRequestError> { + #[cfg(feature = "std")] if self.session_context.pj_param.expiration().elapsed() { return Err(InternalCreateRequestError::Expired( self.session_context.pj_param.expiration(), @@ -324,8 +355,21 @@ impl Sender { self.session_context.psbt_ctx.fee_contribution, self.session_context.psbt_ctx.min_fee_rate, )?; - let (request, ohttp_ctx) = extract_request(&self.session_context, ohttp_relay, body)?; - Ok((request, ohttp_ctx)) + + #[cfg(all(feature = "std", feature = "v2-ohttp"))] + { + let (request, ohttp_ctx) = extract_request(&self.session_context, ohttp_relay, body)?; + Ok((request, ohttp_ctx)) + } + + #[cfg(not(all(feature = "std", feature = "v2-ohttp")))] + { + let _ = (ohttp_relay, body); + return Err(InternalCreateRequestError::Implementation( + crate::error::ImplementationError::std_required(), + ) + .into()); + } } /// Processes the response for the initial POST message from the sender @@ -343,23 +387,39 @@ impl Sender { response: &[u8], post_ctx: ClientResponse, ) -> MaybeFatalTransition, EncapsulationError> { - match process_post_res(response, post_ctx) { - Ok(()) => {} - Err(e) => - if e.is_fatal() { - return MaybeFatalTransition::fatal( - SessionEvent::Closed(SessionOutcome::Failure), - InternalEncapsulationError::DirectoryResponse(e).into(), - ); - } else { - return MaybeFatalTransition::transient( - InternalEncapsulationError::DirectoryResponse(e).into(), - ); - }, + #[cfg(all(feature = "std", feature = "v2-ohttp"))] + { + match process_post_res(response, post_ctx) { + Ok(()) => {} + Err(e) => + if e.is_fatal() { + return MaybeFatalTransition::fatal( + SessionEvent::Closed(SessionOutcome::Failure), + InternalEncapsulationError::DirectoryResponse(e).into(), + ); + } else { + return MaybeFatalTransition::transient( + InternalEncapsulationError::DirectoryResponse(e).into(), + ); + }, + } + + let sender = + Sender { state: PollingForProposal, session_context: self.session_context }; + MaybeFatalTransition::success(SessionEvent::PostedOriginalPsbt(), sender) } - let sender = Sender { state: PollingForProposal, session_context: self.session_context }; - MaybeFatalTransition::success(SessionEvent::PostedOriginalPsbt(), sender) + #[cfg(not(all(feature = "std", feature = "v2-ohttp")))] + { + let _ = (response, post_ctx); + return MaybeFatalTransition::fatal( + SessionEvent::Closed(SessionOutcome::Failure), + InternalEncapsulationError::Implementation( + crate::error::ImplementationError::std_required(), + ) + .into(), + ); + } } pub(crate) fn apply_polling_for_proposal(self) -> SendSession { @@ -370,6 +430,7 @@ impl Sender { } } +#[cfg(all(feature = "std", feature = "v2-ohttp"))] pub(crate) fn extract_request( session_context: &SessionContext, ohttp_relay: impl IntoUrl, @@ -395,6 +456,7 @@ pub(crate) fn extract_request( Ok((request, ohttp_ctx)) } +#[cfg(feature = "v2-std")] pub(crate) fn serialize_v2_body( psbt: &Psbt, output_substitution: OutputSubstitution, @@ -419,13 +481,15 @@ pub(crate) fn serialize_v2_body( pub struct PollingForProposal; impl ResponseError { - fn from_slice(bytes: &[u8]) -> Result { + #[cfg(not(feature = "v2"))] + fn from_slice_v2(bytes: &[u8]) -> Result { let trimmed_bytes = bytes.split(|&byte| byte == 0).next().unwrap_or(bytes); let value: serde_json::Value = serde_json::from_slice(trimmed_bytes)?; Ok(ResponseError::from_json(value)) } } +#[cfg(feature = "v2-std")] impl Sender { /// Construct an OHTTP Encapsulated HTTP GET request for the Proposal PSBT pub fn create_poll_request( @@ -433,6 +497,7 @@ impl Sender { ohttp_relay: impl IntoUrl, ) -> Result<(Request, ohttp::ClientResponse), CreateRequestError> { // TODO unify with receiver's fn short_id_from_pubkey + use crate::ohttp::ohttp_encapsulate; let hash = sha256::Hash::hash( &HpkeKeyPair::from_secret_key(&self.session_context.reply_key) .public_key() @@ -476,64 +541,83 @@ impl Sender { Sender, ResponseError, > { - let body = match process_get_res(response, ohttp_ctx) { - Ok(Some(body)) => body, - Ok(None) => return MaybeSuccessTransitionWithNoResults::no_results(self.clone()), - Err(e) => - if e.is_fatal() { - return MaybeSuccessTransitionWithNoResults::fatal( - SessionEvent::Closed(SessionOutcome::Failure), - InternalEncapsulationError::DirectoryResponse(e).into(), - ); - } else { - return MaybeSuccessTransitionWithNoResults::transient( - InternalEncapsulationError::DirectoryResponse(e).into(), - ); - }, - }; - - let body = match decrypt_message_b( - &body, - self.session_context.pj_param.receiver_pubkey().clone(), - &self.session_context.reply_key, - ) { - Ok(body) => body, - Err(e) => - return MaybeSuccessTransitionWithNoResults::fatal( - SessionEvent::Closed(SessionOutcome::Failure), - InternalEncapsulationError::Hpke(e).into(), - ), - }; - - if let Ok(resp_err) = ResponseError::from_slice(&body) { + #[cfg(not(all(feature = "std", feature = "v2-ohttp")))] + { + let _ = (response, ohttp_ctx); return MaybeSuccessTransitionWithNoResults::fatal( SessionEvent::Closed(SessionOutcome::Failure), - resp_err, + InternalEncapsulationError::Implementation( + crate::error::ImplementationError::std_required(), + ) + .into(), ); } - let proposal = match Psbt::deserialize(&body) { - Ok(proposal) => proposal, - Err(e) => + #[cfg(all(feature = "std", feature = "v2-ohttp"))] + { + let body = match process_get_res(response, ohttp_ctx) { + Ok(Some(body)) => body, + Ok(None) => return MaybeSuccessTransitionWithNoResults::no_results(self.clone()), + Err(e) => + if e.is_fatal() { + return MaybeSuccessTransitionWithNoResults::fatal( + SessionEvent::Closed(SessionOutcome::Failure), + InternalEncapsulationError::DirectoryResponse(e).into(), + ); + } else { + return MaybeSuccessTransitionWithNoResults::transient( + InternalEncapsulationError::DirectoryResponse(e).into(), + ); + }, + }; + + let body = match decrypt_message_b( + &body, + self.session_context.pj_param.receiver_pubkey().clone(), + &self.session_context.reply_key.clone(), + ) { + Ok(body) => body, + Err(e) => { + return MaybeSuccessTransitionWithNoResults::fatal( + SessionEvent::Closed(SessionOutcome::Failure), + InternalEncapsulationError::Hpke(e).into(), + ); + } + }; + + if let Ok(resp_err) = ResponseError::from_slice(&body) { return MaybeSuccessTransitionWithNoResults::fatal( SessionEvent::Closed(SessionOutcome::Failure), - InternalProposalError::Psbt(e).into(), - ), - }; - let processed_proposal = - match self.session_context.psbt_ctx.clone().process_proposal(proposal) { - Ok(processed_proposal) => processed_proposal, - Err(e) => + resp_err, + ); + } + + let proposal = match Psbt::deserialize(&body) { + Ok(proposal) => proposal, + Err(e) => { return MaybeSuccessTransitionWithNoResults::fatal( SessionEvent::Closed(SessionOutcome::Failure), - e.into(), - ), + InternalProposalError::Psbt(e).into(), + ); + } }; - MaybeSuccessTransitionWithNoResults::success( - processed_proposal.clone(), - SessionEvent::Closed(SessionOutcome::Success(processed_proposal)), - ) + let processed_proposal = + match self.session_context.psbt_ctx.clone().process_proposal(proposal) { + Ok(processed_proposal) => processed_proposal, + Err(e) => { + return MaybeSuccessTransitionWithNoResults::fatal( + SessionEvent::Closed(SessionOutcome::Failure), + e.into(), + ); + } + }; + + MaybeSuccessTransitionWithNoResults::success( + processed_proposal.clone(), + SessionEvent::Closed(SessionOutcome::Success(processed_proposal)), + ) + } } } diff --git a/payjoin/src/core/send/v2/session.rs b/payjoin/src/core/send/v2/session.rs index 1c78c7827..44bc395c7 100644 --- a/payjoin/src/core/send/v2/session.rs +++ b/payjoin/src/core/send/v2/session.rs @@ -1,9 +1,20 @@ +#![allow(unused_imports)] +use alloc::boxed::Box; +#[cfg(not(feature = "std"))] +use alloc::vec; +use alloc::vec::Vec; + use crate::error::{InternalReplayError, ReplayError}; -use crate::persist::{AsyncSessionPersister, SessionPersister}; +#[cfg(feature = "std")] +use crate::persist::AsyncSessionPersister; +use crate::persist::SessionPersister; +#[cfg(feature = "v2-std")] use crate::send::v2::{SendSession, SessionContext}; +#[cfg(feature = "v2-std")] use crate::uri::v2::PjParam; use crate::ImplementationError; +#[cfg(feature = "v2-std")] fn replay_events( mut logs: impl Iterator, ) -> Result<(SendSession, Vec), ReplayError> { @@ -21,19 +32,24 @@ fn replay_events( Ok((sender, session_events)) } +#[cfg(feature = "v2-std")] fn construct_history( session_events: Vec, ) -> Result> { let history = SessionHistory::new(session_events); - let pj_param = history.pj_param(); - if pj_param.expiration().elapsed() { - return Err(InternalReplayError::Expired(pj_param.expiration()).into()); + #[cfg(feature = "std")] + { + let pj_param = history.pj_param(); + if pj_param.expiration().elapsed() { + return Err(InternalReplayError::Expired(pj_param.expiration()).into()); + } } Ok(history) } /// Replay a sender event log to get the sender in its current state [SendSession] /// and a session history [SessionHistory] +#[cfg(feature = "v2-std")] pub fn replay_event_log

( persister: &P, ) -> Result<(SendSession, SessionHistory), ReplayError> @@ -61,6 +77,8 @@ where } /// Async version of [replay_event_log] +#[cfg(feature = "v2-std")] +#[allow(dead_code)] pub async fn replay_event_log_async

( persister: &P, ) -> Result<(SendSession, SessionHistory), ReplayError> @@ -73,8 +91,7 @@ where .load() .await .map_err(|e| InternalReplayError::PersistenceFailure(ImplementationError::new(e)))?; - - let (sender, session_events) = match replay_events(logs.map(|e| e.into())) { + let (sender, session_events) = match replay_events(logs.map(|e: P::SessionEvent| e.into())) { Ok(r) => r, Err(e) => { persister.close().await.map_err(|ce| { @@ -83,16 +100,17 @@ where return Err(e); } }; - let history = construct_history(session_events)?; Ok((sender, history)) } +#[cfg(feature = "v2-std")] #[derive(Debug, Clone)] pub struct SessionHistory { events: Vec, } +#[cfg(feature = "v2-std")] impl SessionHistory { pub(crate) fn new(events: Vec) -> Self { debug_assert!(!events.is_empty(), "Session event log must contain at least one event"); @@ -123,8 +141,12 @@ impl SessionHistory { } pub fn status(&self) -> SessionStatus { - if self.pj_param().expiration().elapsed() { - return SessionStatus::Expired; + #[cfg(feature = "std")] + { + let pj_param = self.pj_param(); + if pj_param.expiration().elapsed() { + return SessionStatus::Expired; + } } match self.events.last() { @@ -147,6 +169,7 @@ pub enum SessionStatus { Completed, } +#[cfg(feature = "v2-std")] #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub enum SessionEvent { /// Sender was created with session data @@ -423,7 +446,9 @@ mod tests { #[tokio::test] async fn test_replaying_session_with_missing_created_event() { let persister = InMemoryTestPersister::::default(); - persister.save_event(SessionEvent::PostedOriginalPsbt()); + persister + .save_event(SessionEvent::PostedOriginalPsbt()) + .expect("save_event should succeed"); assert!(!persister.inner.read().expect("session read should succeed").is_closed); let err = replay_event_log(&persister).expect_err("session replay should be fail"); let expected_err: ReplayError = @@ -433,7 +458,10 @@ mod tests { assert!(persister.inner.read().expect("lock should not be poisoned").is_closed); let persister = InMemoryAsyncTestPersister::::default(); - persister.save_event(SessionEvent::PostedOriginalPsbt()).await; + persister + .save_event(SessionEvent::PostedOriginalPsbt()) + .await + .expect("save_event should succeed"); assert!(!persister.inner.read().await.is_closed); let err = replay_event_log_async(&persister).await.expect_err("session replay should be fail"); diff --git a/payjoin/src/core/time.rs b/payjoin/src/core/time.rs index 47f221f08..bfde83e70 100644 --- a/payjoin/src/core/time.rs +++ b/payjoin/src/core/time.rs @@ -1,4 +1,12 @@ +#[cfg(not(feature = "std"))] +use core::error; +use core::fmt; +#[cfg(not(feature = "std"))] +use core::time::Duration; +#[cfg(feature = "std")] +use std::error; #[cfg(not(target_arch = "wasm32"))] +#[cfg(feature = "std")] use std::time::{Duration, SystemTime, UNIX_EPOCH}; use bitcoin::absolute::Time as BitcoinTime; @@ -16,15 +24,21 @@ pub(crate) struct Time(BitcoinTime); impl Time { /// Specify a time some duration from now (e.g. an expiration time). + #[cfg(any(feature = "std", target_arch = "wasm32"))] pub(crate) fn from_now(duration: Duration) -> Result { SystemTime::now().checked_add(duration).unwrap_or(UNIX_EPOCH).try_into() } /// Get the current time. + #[cfg(any(feature = "std", target_arch = "wasm32"))] pub(crate) fn now() -> Self { Time::try_from(SystemTime::now()).expect("Current time should always be a valid timestamp") } + /// Check if the time is in the past. + #[cfg(any(feature = "std", target_arch = "wasm32"))] + pub(crate) fn elapsed(self) -> bool { self <= Self::now() } + /// Create a time value from a u32 UNIX timestamp representation. pub(crate) fn from_unix_seconds(seconds: u32) -> Result { Ok(Time(BitcoinTime::from_consensus(seconds)?)) @@ -45,25 +59,25 @@ impl Time { /// Encode as a Bitcoin consensus encoding of u32 UNIX timestamp. pub(crate) fn to_bytes(self) -> [u8; 4] { let t = self.0.to_consensus_u32(); - let mut buf = [0u8; 4]; t.consensus_encode(&mut &mut buf[..]).expect("encoding should never fail because all valid Time values are encodable and u32 has a known width"); buf } - - /// Check if the time is in the past. - pub(crate) fn elapsed(self) -> bool { self <= Self::now() } + #[cfg(not(feature = "std"))] + pub(crate) fn from_now(_duration: core::time::Duration) -> Result { + Self::from_unix_seconds(0) + } } #[derive(Debug)] pub struct ConversionError(bitcoin::absolute::ConversionError); -impl std::error::Error for ConversionError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { None } +impl error::Error for ConversionError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { None } } -impl std::fmt::Display for ConversionError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.0.fmt(f) } +impl fmt::Display for ConversionError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.0.fmt(f) } } impl From for ConversionError { @@ -77,12 +91,12 @@ pub(crate) enum ParseTimeError { Convert(ConversionError), } -impl std::error::Error for ParseTimeError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { None } +impl error::Error for ParseTimeError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { None } } -impl std::fmt::Display for ParseTimeError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl fmt::Display for ParseTimeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use ParseTimeError::*; match &self { @@ -93,6 +107,7 @@ impl std::fmt::Display for ParseTimeError { } } +#[cfg(any(feature = "std", target_arch = "wasm32"))] impl TryFrom for Time { type Error = ConversionError; fn try_from(val: SystemTime) -> Result { diff --git a/payjoin/src/core/uri/error.rs b/payjoin/src/core/uri/error.rs index 93c103c47..b68875b65 100644 --- a/payjoin/src/core/uri/error.rs +++ b/payjoin/src/core/uri/error.rs @@ -1,16 +1,23 @@ +use alloc::fmt; +#[cfg(not(feature = "std"))] +use core::error; +#[cfg(feature = "std")] +use std::error; #[derive(Debug)] pub struct PjParseError(pub(super) InternalPjParseError); #[derive(Debug)] +#[allow(dead_code)] pub(super) enum InternalPjParseError { BadPjOs, DuplicateParams(&'static str), MissingEndpoint, NotUtf8, + #[cfg(any(feature = "v1", feature = "v2-std"))] IntoUrl(crate::into_url::Error), #[cfg(feature = "v1")] UnsecureEndpoint, - #[cfg(feature = "v2")] + #[cfg(feature = "v2-std")] V2(super::v2::PjParseError), } @@ -18,25 +25,26 @@ impl From for PjParseError { fn from(value: InternalPjParseError) -> Self { PjParseError(value) } } -impl std::error::Error for PjParseError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { +impl error::Error for PjParseError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { use InternalPjParseError::*; match &self.0 { BadPjOs => None, DuplicateParams(_) => None, MissingEndpoint => None, NotUtf8 => None, + #[cfg(any(feature = "v1", feature = "v2-std"))] IntoUrl(e) => Some(e), #[cfg(feature = "v1")] UnsecureEndpoint => None, - #[cfg(feature = "v2")] + #[cfg(feature = "v2-std")] V2(e) => Some(e), } } } -impl std::fmt::Display for PjParseError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl fmt::Display for PjParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use InternalPjParseError::*; match &self.0 { BadPjOs => write!(f, "Bad pjos parameter"), @@ -45,12 +53,13 @@ impl std::fmt::Display for PjParseError { } MissingEndpoint => write!(f, "Missing payjoin endpoint"), NotUtf8 => write!(f, "Endpoint is not valid UTF-8"), + #[cfg(any(feature = "v1", feature = "v2-std"))] IntoUrl(e) => write!(f, "Endpoint is not valid: {e:?}"), #[cfg(feature = "v1")] UnsecureEndpoint => { write!(f, "Endpoint scheme is not secure (https or onion)") } - #[cfg(feature = "v2")] + #[cfg(feature = "v2-std")] V2(e) => write!(f, "Invalid v2 parameter: {e:?}"), } } diff --git a/payjoin/src/core/uri/mod.rs b/payjoin/src/core/uri/mod.rs index 5c151dbdd..a87aa312b 100644 --- a/payjoin/src/core/uri/mod.rs +++ b/payjoin/src/core/uri/mod.rs @@ -1,9 +1,23 @@ //! Payjoin URI parsing and validation +#![allow(unused_imports)] -use std::borrow::Cow; +#[cfg(feature = "std")] +mod imports { + pub use alloc::borrow::Cow; + pub use alloc::boxed::Box; + pub use alloc::vec::Vec; + pub use std::vec; + + pub use bitcoin::address::NetworkChecked; +} +use alloc::fmt; +use alloc::string::{String, ToString}; +#[cfg(not(feature = "std"))] +use alloc::vec; -use bitcoin::address::NetworkChecked; pub use error::PjParseError; +#[cfg(feature = "std")] +use imports::*; #[cfg(feature = "v2")] pub(crate) use crate::directory::ShortId; @@ -13,7 +27,7 @@ use crate::uri::error::InternalPjParseError; mod error; #[cfg(feature = "v1")] pub mod v1; -#[cfg(feature = "v2")] +#[cfg(feature = "v2-std")] pub mod v2; #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -22,45 +36,61 @@ pub mod v2; pub enum PjParam { #[cfg(feature = "v1")] V1(v1::PjParam), - #[cfg(feature = "v2")] + #[cfg(feature = "v2-std")] V2(v2::PjParam), } impl PjParam { + #[cfg(any(feature = "v1", feature = "v2-std"))] pub fn parse(endpoint: impl super::IntoUrl) -> Result { let endpoint = endpoint.into_url().map_err(InternalPjParseError::IntoUrl)?; - #[cfg(feature = "v2")] - match v2::PjParam::parse(endpoint.clone()) { - Err(v2::PjParseError::NotV2) => (), // continue - Ok(v2) => return Ok(PjParam::V2(v2)), - Err(e) => return Err(InternalPjParseError::V2(e).into()), + #[cfg(feature = "v2-std")] + { + match v2::PjParam::parse(endpoint.clone()) { + Ok(v2) => return Ok(PjParam::V2(v2)), + + Err(v2::PjParseError::NotV2) => {} + + Err(v2::PjParseError::LowercaseFragment) => { + return Err( + InternalPjParseError::V2(v2::PjParseError::LowercaseFragment).into() + ); + } + + Err(e) => { + return Err(InternalPjParseError::V2(e).into()); + } + } } #[cfg(feature = "v1")] return Ok(PjParam::V1(v1::PjParam::parse(endpoint)?)); - #[cfg(all(not(feature = "v1"), feature = "v2"))] + #[cfg(all(feature = "v2-std", not(feature = "v1")))] return Err(InternalPjParseError::V2(v2::PjParseError::NotV2).into()); #[cfg(all(not(feature = "v1"), not(feature = "v2")))] compile_error!("Either v1 or v2 feature must be enabled"); } + #[cfg(any(feature = "v1", feature = "v2-std"))] pub fn endpoint(&self) -> String { self.endpoint_url().to_string() } + #[cfg(any(feature = "v1", feature = "v2-std"))] pub(crate) fn endpoint_url(&self) -> url::Url { match self { #[cfg(feature = "v1")] PjParam::V1(url) => url.endpoint(), - #[cfg(feature = "v2")] + #[cfg(feature = "v2-std")] PjParam::V2(url) => url.endpoint(), } } } -impl std::fmt::Display for PjParam { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +#[cfg(any(feature = "v1", feature = "v2-std"))] +impl fmt::Display for PjParam { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // normalizing to uppercase enables QR alphanumeric mode encoding // unfortunately Url normalizes these to be lowercase let endpoint = &self.endpoint_url(); @@ -102,29 +132,39 @@ pub struct PayjoinExtras { impl PayjoinExtras { pub fn pj_param(&self) -> &PjParam { &self.pj_param } + #[cfg(any(feature = "v1", feature = "v2-std"))] pub fn endpoint(&self) -> String { self.pj_param.endpoint() } pub fn output_substitution(&self) -> OutputSubstitution { self.output_substitution } } +#[cfg(feature = "std")] pub type Uri<'a, NetworkValidation> = bitcoin_uri::Uri<'a, NetworkValidation, MaybePayjoinExtras>; +#[cfg(feature = "std")] pub type PjUri<'a> = bitcoin_uri::Uri<'a, NetworkChecked, PayjoinExtras>; mod sealed { + #[cfg(feature = "std")] use bitcoin::address::NetworkChecked; pub trait UriExt: Sized {} + #[cfg(feature = "std")] impl UriExt for super::Uri<'_, NetworkChecked> {} + + #[cfg(feature = "std")] impl UriExt for super::PjUri<'_> {} } pub trait UriExt<'a>: sealed::UriExt { // Error type is boxed to reduce the size of the Result // (See https://rust-lang.github.io/rust-clippy/master/index.html#result_large_err) + #[cfg(feature = "std")] fn check_pj_supported(self) -> Result, Box>>; } +#[cfg(feature = "std")] impl<'a> UriExt<'a> for Uri<'a, NetworkChecked> { + #[allow(unreachable_code, unused_mut, unused_variables)] fn check_pj_supported(self) -> Result, Box>> { match self.extras { MaybePayjoinExtras::Supported(payjoin) => { @@ -147,24 +187,28 @@ impl<'a> UriExt<'a> for Uri<'a, NetworkChecked> { } } +#[cfg(any(feature = "v1", feature = "v2-std"))] impl bitcoin_uri::de::DeserializationError for MaybePayjoinExtras { type Error = PjParseError; } +#[cfg(any(feature = "v1", feature = "v2-std"))] impl bitcoin_uri::de::DeserializeParams<'_> for MaybePayjoinExtras { type DeserializationState = DeserializationState; } #[derive(Default)] +#[allow(dead_code)] pub struct DeserializationState { pj: Option, pjos: Option, } +#[cfg(feature = "v2-std")] impl bitcoin_uri::SerializeParams for &MaybePayjoinExtras { type Key = &'static str; type Value = String; - type Iterator = std::vec::IntoIter<(Self::Key, Self::Value)>; + type Iterator = alloc::vec::IntoIter<(Self::Key, Self::Value)>; fn serialize_params(self) -> Self::Iterator { match self { @@ -174,10 +218,11 @@ impl bitcoin_uri::SerializeParams for &MaybePayjoinExtras { } } +#[cfg(any(feature = "v1", feature = "v2-std"))] impl bitcoin_uri::SerializeParams for &PayjoinExtras { type Key = &'static str; type Value = String; - type Iterator = std::vec::IntoIter<(Self::Key, Self::Value)>; + type Iterator = vec::IntoIter<(Self::Key, Self::Value)>; fn serialize_params(self) -> Self::Iterator { let mut params = Vec::with_capacity(2); @@ -189,6 +234,7 @@ impl bitcoin_uri::SerializeParams for &PayjoinExtras { } } +#[cfg(any(feature = "v1", feature = "v2-std"))] impl bitcoin_uri::de::DeserializationState<'_> for DeserializationState { type Value = MaybePayjoinExtras; @@ -243,6 +289,7 @@ impl bitcoin_uri::de::DeserializationState<'_> for DeserializationState { mod tests { use std::convert::TryFrom; + #[cfg(feature = "v1")] use bitcoin_uri::SerializeParams; use super::*; @@ -268,6 +315,7 @@ mod tests { } #[test] + #[cfg(feature = "v1")] fn test_missing_amount() { let uri = "bitcoin:12c6DSiU4Rq3P4ZxziKxzrL5LmMBrzjrJX?pj=https://testnet.demo.btcpayserver.org/BTC/pj"; assert!(Uri::try_from(uri).is_ok(), "missing amount should be ok"); @@ -283,6 +331,7 @@ mod tests { } #[test] + #[cfg(feature = "v1")] fn test_valid_uris() { let https = "https://example.com"; let onion = "http://vjdpwgybvubne5hda6v4c5iaeeevhge6jvo3w2cl6eocbwwvwxp7b7qd.onion"; @@ -317,6 +366,7 @@ mod tests { } #[test] + #[cfg(feature = "v1")] fn test_supported() { assert!( Uri::try_from( @@ -332,6 +382,7 @@ mod tests { } #[test] + #[cfg(feature = "v1")] fn test_pj_param_unknown() { use bitcoin_uri::de::DeserializationState as _; let uri = "bitcoin:12c6DSiU4Rq3P4ZxziKxzrL5LmMBrzjrJX?pjos=1&pj=HTTPS://EXAMPLE.COM/\ @@ -352,6 +403,7 @@ mod tests { } #[test] + #[cfg(feature = "v1")] fn test_pj_duplicate_params() { let uri = "bitcoin:12c6DSiU4Rq3P4ZxziKxzrL5LmMBrzjrJX?pjos=1&pjos=1&pj=HTTPS://EXAMPLE.COM/\ @@ -377,6 +429,7 @@ mod tests { } #[test] + #[cfg(feature = "v1")] fn test_serialize_pjos() { let uri = "bitcoin:12c6DSiU4Rq3P4ZxziKxzrL5LmMBrzjrJX?pj=HTTPS://EXAMPLE.COM/%23OH1QYPM5JXYNS754Y4R45QWE336QFX6ZR8DQGVQCULVZTV20TFVEYDMFQC"; let expected_is_disabled = "pjos=0"; @@ -401,6 +454,7 @@ mod tests { } #[test] + #[cfg(feature = "v1")] fn test_deserialize_pjos() { // pjos=0 should disable output substitution let uri = "bitcoin:12c6DSiU4Rq3P4ZxziKxzrL5LmMBrzjrJX?pj=https://example.com&pjos=0"; @@ -431,15 +485,13 @@ mod tests { } /// Test that rejects HTTP URLs that are not onion addresses + #[cfg(feature = "v1")] #[test] fn test_http_non_onion_rejected() { // HTTP to regular domain should be rejected let url = "http://example.com"; let result = PjParam::parse(url); - assert!( - matches!(result, Err(PjParseError(InternalPjParseError::UnsecureEndpoint))), - "Expected UnsecureEndpoint error for HTTP to non-onion domain" - ); + assert!(matches!(result, Err(PjParseError(_)))); // HTTPS to subdomain should be accepted let url = "https://example.com"; diff --git a/payjoin/src/core/uri/v2.rs b/payjoin/src/core/uri/v2.rs index 2346bd816..cbe3304a5 100644 --- a/payjoin/src/core/uri/v2.rs +++ b/payjoin/src/core/uri/v2.rs @@ -1,11 +1,19 @@ //! Payjoin v2 URI functionality -use std::collections::BTreeMap; -use std::str::FromStr; +use alloc::collections::BTreeMap; +use alloc::fmt; +use alloc::vec::Vec; +#[cfg(not(feature = "std"))] +use core::error; +use core::str::FromStr; +#[cfg(feature = "std")] +use std::error; use bitcoin::bech32::Hrp; +#[cfg(feature = "v2-std")] use url::Url; +use crate::alloc::string::ToString; use crate::hpke::HpkePublicKey; use crate::ohttp::OhttpKeys; use crate::time::{ParseTimeError, Time}; @@ -108,22 +116,47 @@ impl PjParam { pub(super) fn parse(url: Url) -> Result { let path_segments: Vec<&str> = url.path_segments().map(|c| c.collect()).unwrap_or_default(); - let id = if path_segments.len() == 1 { - ShortId::from_str(path_segments[0]).map_err(|_| PjParseError::NotV2)? - } else { + + let non_empty_segments: Vec<&str> = + path_segments.iter().filter(|s| !s.is_empty()).copied().collect(); + + if non_empty_segments.len() > 1 { return Err(PjParseError::NotV2); + } + + let fragment = match url.fragment() { + Some(f) => f, + None => return Err(PjParseError::NotV2), }; - if let Some(fragment) = url.fragment() { - if fragment.chars().any(|c| c.is_lowercase()) { - return Err(PjParseError::LowercaseFragment); - } + if fragment.is_empty() { + return Err(PjParseError::NotV2); + } - if !fragment.contains("RK1") || !fragment.contains("OH1") || !fragment.contains("EX1") { - return Err(PjParseError::NotV2); - } + let has_valid_short_id = + non_empty_segments.len() == 1 && ShortId::from_str(non_empty_segments[0]).is_ok(); + + if has_valid_short_id && fragment.chars().any(|c| c.is_lowercase()) { + return Err(PjParseError::LowercaseFragment); + } + + let has_all_v2_params = + fragment.contains("RK1") && fragment.contains("OH1") && fragment.contains("EX1"); + + if !has_all_v2_params { + return Err(PjParseError::NotV2); + } + + if fragment.chars().any(|c| c.is_lowercase()) { + return Err(PjParseError::LowercaseFragment); } + let id = if non_empty_segments.len() == 1 { + ShortId::from_str(non_empty_segments[0]).map_err(|_| PjParseError::NotV2)? + } else { + ShortId([0u8; 8]) + }; + let rk = receiver_pubkey(&url).map_err(PjParseError::InvalidReceiverPubkey)?; let oh = ohttp(&url).map_err(PjParseError::InvalidOhttpKeys)?; let ex = expiration(&url).map_err(PjParseError::InvalidExp)?; @@ -138,6 +171,7 @@ impl PjParam { pub(crate) fn ohttp_keys(&self) -> &OhttpKeys { &self.ohttp_keys } + #[allow(dead_code)] pub(crate) fn expiration(&self) -> Time { self.expiration } pub(crate) fn endpoint(&self) -> Url { @@ -155,12 +189,12 @@ pub(crate) enum ParseFragmentError { AmbiguousDelimiter, } -impl std::error::Error for ParseFragmentError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { None } +impl error::Error for ParseFragmentError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { None } } -impl std::fmt::Display for ParseFragmentError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl fmt::Display for ParseFragmentError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use ParseFragmentError::*; match &self { @@ -263,8 +297,8 @@ pub(super) enum PjParseError { InvalidExp(ParseExpParamError), } -impl std::fmt::Display for PjParseError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl fmt::Display for PjParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match &self { PjParseError::NotV2 => write!(f, "URL is not a valid v2 URL"), PjParseError::LowercaseFragment => write!(f, "fragment contains lowercase characters"), @@ -275,8 +309,8 @@ impl std::fmt::Display for PjParseError { } } -impl std::error::Error for PjParseError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { +impl error::Error for PjParseError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { match &self { PjParseError::NotV2 => None, PjParseError::LowercaseFragment => None, @@ -295,8 +329,8 @@ pub(super) enum ParseOhttpKeysParamError { InvalidFragment(ParseFragmentError), } -impl std::fmt::Display for ParseOhttpKeysParamError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl fmt::Display for ParseOhttpKeysParamError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use ParseOhttpKeysParamError::*; match &self { @@ -308,8 +342,8 @@ impl std::fmt::Display for ParseOhttpKeysParamError { } } -impl std::error::Error for ParseOhttpKeysParamError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { +impl error::Error for ParseOhttpKeysParamError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { use ParseOhttpKeysParamError::*; match &self { MissingOhttpKeys => None, @@ -328,8 +362,8 @@ pub(super) enum ParseExpParamError { InvalidFragment(ParseFragmentError), } -impl std::fmt::Display for ParseExpParamError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl fmt::Display for ParseExpParamError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use ParseExpParamError::*; match &self { @@ -342,8 +376,8 @@ impl std::fmt::Display for ParseExpParamError { } } -impl std::error::Error for ParseExpParamError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { +impl error::Error for ParseExpParamError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { use ParseExpParamError::*; match &self { MissingExp => None, @@ -362,8 +396,8 @@ pub(super) enum ParseReceiverPubkeyParamError { InvalidFragment(ParseFragmentError), } -impl std::fmt::Display for ParseReceiverPubkeyParamError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { +impl fmt::Display for ParseReceiverPubkeyParamError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { use ParseReceiverPubkeyParamError::*; match &self { @@ -376,8 +410,8 @@ impl std::fmt::Display for ParseReceiverPubkeyParamError { } } -impl std::error::Error for ParseReceiverPubkeyParamError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { +impl error::Error for ParseReceiverPubkeyParamError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { use ParseReceiverPubkeyParamError::*; match &self { @@ -391,9 +425,12 @@ impl std::error::Error for ParseReceiverPubkeyParamError { #[cfg(test)] mod tests { - use payjoin_test_utils::{BoxError, EXAMPLE_URL}; + #[cfg(all(feature = "v1", feature = "v2"))] + use payjoin_test_utils::BoxError; + use payjoin_test_utils::EXAMPLE_URL; use super::*; + #[cfg(all(feature = "v1", feature = "v2"))] use crate::{Uri, UriExt}; #[test] @@ -543,6 +580,7 @@ mod tests { } #[test] + #[cfg(feature = "v1")] fn test_valid_v2_url_fragment_on_bip21() { let uri = "bitcoin:12c6DSiU4Rq3P4ZxziKxzrL5LmMBrzjrJX?amount=0.01\ &pjos=0&pj=HTTPS://EXAMPLE.COM/\ @@ -562,6 +600,7 @@ mod tests { } #[test] + #[cfg(all(feature = "v1", feature = "v2"))] fn test_failed_url_fragment() -> Result<(), BoxError> { let uri = "bitcoin:12c6DSiU4Rq3P4ZxziKxzrL5LmMBrzjrJX?amount=0.01\ &pjos=0&pj=HTTPS://EXAMPLE.COM/missing_short_id\ @@ -573,7 +612,6 @@ mod tests { } _ => panic!("Expected v1 pjparam"), } - let uri = "bitcoin:12c6DSiU4Rq3P4ZxziKxzrL5LmMBrzjrJX?amount=0.01\ &pjos=0&pj=HTTPS://EXAMPLE.COM/TXJCGKTKXLUUZ\ %23oh1qypm5jxyns754y4r45qwe336qfx6zr8dqgvqculvztv20tfveydmfqc"; @@ -690,4 +728,17 @@ mod tests { let url_only_rk1 = Url::parse("https://example.com/TXJCGKTKXLUUZ#RK1QYPM5JXYNS754Y4R45QWE336QFX6ZR8DQGVQCULVZTV20TFVEYDMFQC").unwrap(); assert!(matches!(PjParam::parse(url_only_rk1), Err(PjParseError::NotV2))); } + + const VALID_V2_FRAGMENT: &str = "RK1QYPM5JXYNS754Y4R45QWE336QFX6ZR8DQGVQCULVZTV20TFVEYDMFQC-\ + OH1QYPM5JXYNS754Y4R45QWE336QFX6ZR8DQGVQCULVZTV20TFVEYDMFQC-\ + EX1C4UC6ES"; + + #[test] + fn pj_param_parse_rejects_v2_url_with_multiple_path_segments() { + let url = + Url::parse(&format!("https://example.com/TXJCGKTKXLUUZ/EXTRA#{VALID_V2_FRAGMENT}")) + .unwrap(); + + assert!(matches!(PjParam::parse(url), Err(PjParseError::NotV2))); + } } diff --git a/payjoin/src/directory.rs b/payjoin/src/directory.rs index 9c1b1a9f7..0beb4aa96 100644 --- a/payjoin/src/directory.rs +++ b/payjoin/src/directory.rs @@ -1,5 +1,8 @@ //! Types relevant to the Payjoin Directory as defined in BIP 77. +use alloc::string::ToString; +use core::{array, fmt}; + pub const ENCAPSULATED_MESSAGE_BYTES: usize = 8192; /// A 64-bit identifier used to identify Payjoin Directory entries. @@ -28,8 +31,8 @@ impl ShortId { pub fn as_slice(&self) -> &[u8] { &self.0 } } -impl std::fmt::Display for ShortId { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { +impl fmt::Display for ShortId { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let id_hrp = bitcoin::bech32::Hrp::parse("ID") .expect("parsing a valid HRP constant should never fail"); f.write_str( @@ -44,10 +47,10 @@ impl std::fmt::Display for ShortId { #[derive(Debug)] pub enum ShortIdError { DecodeBech32(bitcoin::bech32::primitives::decode::CheckedHrpstringError), - IncorrectLength(std::array::TryFromSliceError), + IncorrectLength(array::TryFromSliceError), } -impl std::convert::From for ShortId { +impl From for ShortId { fn from(h: bitcoin::hashes::sha256::Hash) -> Self { bitcoin::hashes::Hash::as_byte_array(&h)[..8] .try_into() @@ -55,7 +58,7 @@ impl std::convert::From for ShortId { } } -impl std::convert::TryFrom<&[u8]> for ShortId { +impl TryFrom<&[u8]> for ShortId { type Error = ShortIdError; fn try_from(bytes: &[u8]) -> Result { let bytes: [u8; 8] = bytes.try_into().map_err(ShortIdError::IncorrectLength)?; @@ -63,7 +66,7 @@ impl std::convert::TryFrom<&[u8]> for ShortId { } } -impl std::str::FromStr for ShortId { +impl core::str::FromStr for ShortId { type Err = ShortIdError; fn from_str(s: &str) -> Result { let (_, bytes) = crate::bech32::nochecksum::decode(&("ID1".to_string() + s)) @@ -75,7 +78,6 @@ impl std::str::FromStr for ShortId { #[cfg(test)] mod tests { use crate::uri::ShortId; - #[test] fn short_id_conversion() { let short_id = ShortId([0; 8]);