Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions payjoin-mailroom/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ pub struct Config {
pub storage_dir: PathBuf,
#[serde(deserialize_with = "deserialize_duration_secs")]
pub timeout: Duration,
#[serde(deserialize_with = "deserialize_optional_duration_secs")]
pub ohttp_keys_max_age: Option<Duration>,
pub v1: Option<V1Config>,
#[cfg(feature = "telemetry")]
pub telemetry: Option<TelemetryConfig>,
Expand Down Expand Up @@ -85,6 +87,7 @@ impl Default for Config {
listener: "[::]:8080".parse().expect("valid default listener address"),
storage_dir: PathBuf::from("./data"),
timeout: Duration::from_secs(30),
ohttp_keys_max_age: None, //Some(Duration::from_secs(30)),
v1: None,
#[cfg(feature = "telemetry")]
telemetry: None,
Expand All @@ -104,17 +107,33 @@ where
Ok(Duration::from_secs(secs))
}

fn deserialize_optional_duration_secs<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
where
D: serde::Deserializer<'de>,
{
let secs: Option<u64> = Option::deserialize(deserializer)?;
match secs {
None => Ok(None),
Some(0) => Err(<D::Error as serde::de::Error>::custom(
"ohttp_keys_max_age must be greater than 0 seconds when set",
)),
Some(s) => Ok(Some(Duration::from_secs(s))),
}
}

impl Config {
pub fn new(
listener: ListenerAddress,
storage_dir: PathBuf,
timeout: Duration,
ohttp_keys_max_age: Option<Duration>,
v1: Option<V1Config>,
) -> Self {
Self {
listener,
storage_dir,
timeout,
ohttp_keys_max_age,
v1,
#[cfg(feature = "telemetry")]
telemetry: None,
Expand Down
216 changes: 200 additions & 16 deletions payjoin-mailroom/src/directory.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
use std::path::PathBuf;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};

use anyhow::Result;
use axum::body::{Body, Bytes};
use axum::http::header::{HeaderValue, ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE};
use axum::http::header::{HeaderValue, ACCESS_CONTROL_ALLOW_ORIGIN, CACHE_CONTROL, CONTENT_TYPE};
use axum::http::{Method, Request, Response, StatusCode, Uri};
use http_body_util::BodyExt;
use payjoin::directory::{ShortId, ShortIdError, ENCAPSULATED_MESSAGE_BYTES};
use tokio::sync::RwLock;
use tracing::{debug, error, trace, warn};

use crate::db::{Db, Error as DbError, SendableError};
Expand All @@ -28,6 +31,83 @@ const V1_VERSION_UNSUPPORTED_RES_JSON: &str =

pub type BoxError = Box<dyn std::error::Error + Send + Sync>;

// Two-slot OHTTP key set supporting rotation overlap.
//
// Key IDs alternate between 0 and 1. Both slots are always populated.
// The current key is served to new clients; both slots are accepted
// for decapsulation so that clients with a cached previous key still
// work during the grace window after a switch.
/// One of the two key slots: wraps the `ohttp::Server` built for that slot's key.
#[derive(Debug)]
pub(crate) struct KeySlot {
// Server used to decapsulate requests for this slot and to encode its key config.
pub(crate) server: ohttp::Server,
}

/// Records which slot is currently advertised to clients and until when.
#[derive(Debug)]
struct ActiveKey {
// Index of the advertised slot (0 or 1).
key_id: u8,
// Instant until which the advertised key is considered current; re-stamped on every `switch`.
valid_until: Instant,
}

/// Two OHTTP key slots plus the record of which one is currently advertised.
///
/// Each slot and the active-key record sit behind their own async `RwLock`,
/// so they can be read and replaced through a shared reference.
#[derive(Debug)]
pub struct KeyRotatingServer {
// Slot `i` is looked up by key id `i`.
keys: [Box<RwLock<KeySlot>>; 2],
// The advertised key id and its `valid_until` instant.
current: RwLock<ActiveKey>,
}

impl KeyRotatingServer {
pub(crate) fn new(
slot0: KeySlot,
slot1: KeySlot,
current_key_id: u8,
valid_until: Instant,
) -> Self {
assert!(current_key_id <= 1, "key_id must be 0 or 1");
Self {
keys: [Box::new(RwLock::new(slot0)), Box::new(RwLock::new(slot1))],
current: RwLock::new(ActiveKey { key_id: current_key_id, valid_until }),
}
}

pub async fn current_key_id(&self) -> u8 { self.current.read().await.key_id }

pub async fn valid_until(&self) -> Instant { self.current.read().await.valid_until }

// Look up the server matching the key_id in an OHTTP message and
// decapsulate. The first byte of an OHTTP encapsulated request is the
// key identifier (RFC 9458 Section 4.3).
pub async fn decapsulate(
&self,
ohttp_body: &[u8],
) -> std::result::Result<(Vec<u8>, ohttp::ServerResponse), ohttp::Error> {
let key_id = ohttp_body.first().copied().ok_or(ohttp::Error::Truncated)?;
match self.keys.get(key_id as usize) {
Some(slot) => slot.read().await.server.decapsulate(ohttp_body),
None => Err(ohttp::Error::KeyId),
}
}

// Encode the current key's config for serving to clients.
pub async fn encode_current(&self) -> std::result::Result<Vec<u8>, ohttp::Error> {
let id = self.current_key_id().await;
self.keys[id as usize].read().await.server.config().encode()
}

// Flip which key is advertised to new clients and stamp the new expiry.
// Anchored to Instant::now() at the moment of the actual switch so that
// the next rotation cycle is measured from when the key became active,
pub async fn switch(&self, interval: Duration) {
let mut current = self.current.write().await;
current.key_id = 1 - current.key_id;
current.valid_until = Instant::now() + interval;
}

// Replace a slot with fresh key material.
pub async fn overwrite(&self, key_id: u8, server: ohttp::Server) {
assert!(key_id <= 1, "key_id must be 0 or 1");
*self.keys[key_id as usize].write().await = KeySlot { server };
}
}

/// Opaque blocklist of Bitcoin addresses stored as script pubkeys.
///
/// Addresses are converted to `ScriptBuf` at parse time so that
Expand Down Expand Up @@ -91,7 +171,8 @@ fn parse_address_lines(text: &str) -> std::collections::HashSet<bitcoin::ScriptB
/// State held by the directory service: database handle, OHTTP key set,
/// optional key max-age, sentinel tag, and optional V1 settings.
#[derive(Clone)]
pub struct Service<D: Db> {
db: D,
// NOTE(review): this `ohttp::Server` field appears to be the removed side of the diff;
// the `Arc<KeyRotatingServer>` field below looks like its replacement — confirm.
ohttp: ohttp::Server,
// Rotating two-slot OHTTP key set, shared via `Arc`.
ohttp: Arc<KeyRotatingServer>,
// When set, `get_ohttp_keys` attaches a Cache-Control header capped by this duration.
ohttp_keys_max_age: Option<Duration>,
sentinel_tag: SentinelTag,
v1: Option<V1>,
}
Expand All @@ -117,10 +198,18 @@ where
}

impl<D: Db> Service<D> {
pub fn new(db: D, ohttp: ohttp::Server, sentinel_tag: SentinelTag, v1: Option<V1>) -> Self {
Self { db, ohttp, sentinel_tag, v1 }
/// Creates a service from its parts; every argument is stored as-is in the
/// field of the same name.
pub fn new(
db: D,
ohttp: Arc<KeyRotatingServer>,
ohttp_keys_max_age: Option<Duration>,
sentinel_tag: SentinelTag,
v1: Option<V1>,
) -> Self {
Self { db, ohttp, ohttp_keys_max_age, sentinel_tag, v1 }
}

/// Borrows the shared rotating OHTTP key set held by this service.
pub fn ohttp_key_set(&self) -> &Arc<KeyRotatingServer> { &self.ohttp }

async fn serve_request<B>(&self, req: Request<B>) -> Result<Response<Body>>
where
B: axum::body::HttpBody<Data = Bytes> + Send + 'static,
Expand Down Expand Up @@ -200,10 +289,10 @@ impl<D: Db> Service<D> {
.map_err(|e| HandlerError::BadRequest(anyhow::anyhow!(e.into())))?
.to_bytes();

// Decapsulate OHTTP request
let (bhttp_req, res_ctx) = self
.ohttp
.decapsulate(&ohttp_body)
.await
.map_err(|e| HandlerError::OhttpKeyRejection(e.into()))?;
let mut cursor = std::io::Cursor::new(bhttp_req);
let req = bhttp::Message::read_bhttp(&mut cursor)
Expand Down Expand Up @@ -380,11 +469,31 @@ impl<D: Db> Service<D> {
/// Serves the currently advertised OHTTP key configuration.
///
/// The body is the encoded key config with content type
/// `application/ohttp-keys`. When `ohttp_keys_max_age` is configured, a
/// `Cache-Control` header is attached whose `s-maxage` is derived from the
/// time left until the active key's `valid_until`. Encoding failures map to
/// `HandlerError::InternalServerError`.
async fn get_ohttp_keys(&self) -> Result<Response<Body>, HandlerError> {
let ohttp_keys = self
.ohttp
// NOTE(review): `.config()` / `.encode()` look like the removed side of this diff hunk;
// `.encode_current().await` below appears to be their replacement — confirm.
.config()
.encode()
.encode_current()
.await
.map_err(|e| HandlerError::InternalServerError(e.into()))?;
let mut res = Response::new(full(ohttp_keys));
res.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static("application/ohttp-keys"));
if let Some(max_age) = self.ohttp_keys_max_age {
// Cache lifetime = time left until the active key's `valid_until`
// (zero if already past), capped at the configured `max_age`, then
// extended by ROTATION_GRACE / 3. A cached key may therefore outlive
// `valid_until` by at most a third of the grace window.
// NOTE(review): this adds to the lifetime (`saturating_add`) rather than
// subtracting from it; confirm that extending past `valid_until` is intended.
let remaining = self
.ohttp
.valid_until()
.await
.saturating_duration_since(Instant::now())
.min(max_age)
.saturating_add(ROTATION_GRACE / 3);
res.headers_mut().insert(
CACHE_CONTROL,
HeaderValue::from_str(&format!(
"public, s-maxage={}, immutable",
remaining.as_secs()
))
.expect("valid header value"),
);
}
Ok(res)
}

Expand Down Expand Up @@ -412,6 +521,66 @@ impl<D: Db> Service<D> {
}
}

// Grace period after a switch during which the old key is still
// accepted for decapsulation.
const ROTATION_GRACE: Duration = Duration::from_secs(30);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const ROTATION_GRACE: Duration = Duration::from_secs(30);
const ROTATION_GRACE: Duration = Duration::from_secs(15);

This delay should be reduced somewhat. It only really needs to account for end-to-end latency — and even then for expected delays, not the worst case — so on reflection 30 seconds seems a bit excessive.


/// Spawns the background task that rotates OHTTP keys on a fixed interval.
///
/// Each cycle:
/// 1. Sleep until the current key's `valid_until`.
/// 2. Refresh the mtime of the standby key file, then switch to the standby
///    slot; `switch` stamps `valid_until = now + interval`.
/// 3. Sleep for `ROTATION_GRACE` so requests encapsulated to the previous key
///    are still accepted for decapsulation.
/// 4. Replace the previous slot, on disk and in memory, with fresh key
///    material for the next cycle.
///
/// `keys_dir` is the directory holding the `<key_id>.ikm` files.
///
/// # Panics
///
/// The spawned task panics (ending rotation) if key generation or key
/// persistence fails. Must be called from within a Tokio runtime.
pub fn spawn_key_rotation(keyset: Arc<KeyRotatingServer>, keys_dir: PathBuf, interval: Duration) {
    tokio::spawn(async move {
        loop {
            // Sleep until the current key expires.
            let valid_until = keyset.valid_until().await;
            tracing::info!("Sleeping until {:?}", valid_until);
            tokio::time::sleep_until(valid_until.into()).await;

            // Capture the outgoing key id before switching.
            let old_key_id = keyset.current_key_id().await;
            let new_key_id = 1 - old_key_id;

            // Touch the file of the key that is about to become active so its
            // mtime is the newest on disk.
            // NOTE(review): presumably startup code picks the newest `.ikm`
            // file as the active key and derives `valid_until` from its age —
            // confirm against the key loader.
            let active_path = keys_dir.join(format!("{new_key_id}.ikm"));
            let times = std::fs::FileTimes::new().set_modified(std::time::SystemTime::now());
            // Changing timestamps needs a writable handle on some platforms;
            // fall back to a read-only handle in case the file itself is not
            // writable but the platform still permits the update.
            let touched = std::fs::OpenOptions::new()
                .write(true)
                .open(&active_path)
                .or_else(|_| std::fs::File::open(&active_path))
                .and_then(|file| file.set_times(times));
            if let Err(e) = touched {
                tracing::warn!("Failed to change mtime {}: {e}", active_path.display());
            }

            // `switch` stamps valid_until = Instant::now() + interval, anchored
            // to the actual moment the new key goes live.
            keyset.switch(interval).await;

            tracing::info!("Switched OHTTP serving: From key_id {old_key_id} -> TO {new_key_id}");

            // Wait until the old key's grace window has fully elapsed before
            // overwriting it, so in-flight clients using the old key still succeed.
            tokio::time::sleep(ROTATION_GRACE).await;

            let config = crate::key_config::gen_ohttp_server_config_with_id(old_key_id)
                .expect("OHTTP key generation must not fail");

            // Remove the stale key file before persisting the fresh one. A
            // missing file is fine; anything else is worth a warning.
            let stale_path = keys_dir.join(format!("{old_key_id}.ikm"));
            if let Err(e) = tokio::fs::remove_file(&stale_path).await {
                if e.kind() != std::io::ErrorKind::NotFound {
                    tracing::warn!("Failed to remove {}: {e}", stale_path.display());
                }
            }
            crate::key_config::persist_key_config(&config, &keys_dir)
                .await
                .expect("OHTTP key persistence must not fail");

            keyset.overwrite(old_key_id, config.into_server()).await;
            tracing::info!("Overwrote OHTTP key_id {old_key_id} with fresh material");
        }
    });
}

fn handle_peek<E: SendableError>(
result: Result<Arc<Vec<u8>>, DbError<E>>,
timeout_response: Response<Body>,
Expand Down Expand Up @@ -485,8 +654,8 @@ impl HandlerError {
}
HandlerError::OhttpKeyRejection(e) => {
const OHTTP_KEY_REJECTION_RES_JSON: &str = r#"{"type":"https://iana.org/assignments/http-problem-types#ohttp-key", "title": "key identifier unknown"}"#;
warn!("Bad request: Key configuration rejected: {}", e);
*res.status_mut() = StatusCode::BAD_REQUEST;
warn!("Key configuration rejected: {}", e);
*res.status_mut() = StatusCode::UNPROCESSABLE_ENTITY;
res.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("application/problem+json"));
*res.body_mut() = full(OHTTP_KEY_REJECTION_RES_JSON);
Expand Down Expand Up @@ -592,9 +761,17 @@ mod tests {
/// Test helper: builds a `Service` over a temp-dir `FilesDb` with two freshly
/// generated OHTTP key configs (ids 0 and 1), key 0 advertised, and no
/// `ohttp_keys_max_age`.
async fn test_service(v1: Option<V1>) -> Service<FilesDb> {
let dir = tempfile::tempdir().expect("tempdir");
let db = FilesDb::init(Duration::from_millis(100), dir.keep()).await.expect("db init");
// NOTE(review): the next three lines (single `ohttp::Server`, four-argument
// `Service::new`) appear to be the removed side of the diff — confirm.
let ohttp: ohttp::Server =
crate::key_config::gen_ohttp_server_config().expect("ohttp config").into();
Service::new(db, ohttp, SentinelTag::new([0u8; 32]), v1)
// One key config per slot.
let c0 = crate::key_config::gen_ohttp_server_config_with_id(0).expect("ohttp config");
let c1 = crate::key_config::gen_ohttp_server_config_with_id(1).expect("ohttp config");
// valid_until = now + a generous test interval so nothing rotates during tests
let valid_until = Instant::now() + Duration::from_secs(3600);
let keyset = Arc::new(KeyRotatingServer::new(
KeySlot { server: c0.into_server() },
KeySlot { server: c1.into_server() },
0,
valid_until,
));
Service::new(db, keyset, None, SentinelTag::new([0u8; 32]), v1)
}

/// A valid ShortId encoded as bech32 for use in URL paths.
Expand Down Expand Up @@ -826,9 +1003,16 @@ mod tests {
let dir = tempfile::tempdir().expect("tempdir");
let db = FilesDb::init(Duration::from_millis(100), dir.keep()).await.expect("db init");
let db = MetricsDb::new(db, metrics);
let ohttp: ohttp::Server =
crate::key_config::gen_ohttp_server_config().expect("ohttp config").into();
let svc = Service::new(db, ohttp, SentinelTag::new([0u8; 32]), None);
let c0 = crate::key_config::gen_ohttp_server_config_with_id(0).expect("ohttp config");
let c1 = crate::key_config::gen_ohttp_server_config_with_id(1).expect("ohttp config");
let valid_until = Instant::now() + Duration::from_secs(3600);
let keyset = Arc::new(KeyRotatingServer::new(
KeySlot { server: c0.into_server() },
KeySlot { server: c1.into_server() },
0,
valid_until,
));
let svc = Service::new(db, keyset, None, SentinelTag::new([0u8; 32]), None);

let id = valid_short_id_path();
let res = svc
Expand All @@ -849,7 +1033,7 @@ mod tests {
use opentelemetry::KeyValue;
use opentelemetry_sdk::metrics::data::{AggregatedMetrics, MetricData};

// This checks that counter value is 1 as post_mailbox was called once
// This checks that counter value is 1 as post_mailbox was called once
// Also confirms the v2 label is recorded
match db_metric.data() {
AggregatedMetrics::U64(MetricData::Sum(sum)) => {
Expand Down
Loading
Loading