diff --git a/Cargo.lock b/Cargo.lock index 348a9493..6346e70e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1731,13 +1731,14 @@ version = "0.5.5" dependencies = [ "anyhow", "clap", - "dstack-gateway-rpc", + "hex", "hex_fmt", - "ra-rpc", "regex", "reqwest", "serde", + "serde-human-bytes", "serde_json", + "sha2 0.10.9", "tokio", "tracing", "tracing-subscriber", diff --git a/ct_monitor/Cargo.toml b/ct_monitor/Cargo.toml index 3828526e..82c1dccd 100644 --- a/ct_monitor/Cargo.toml +++ b/ct_monitor/Cargo.toml @@ -11,16 +11,16 @@ license.workspace = true [dependencies] anyhow.workspace = true -clap = { workspace = true, features = ["derive"] } +clap = { workspace = true, features = ["derive", "env"] } +hex = { workspace = true, features = ["alloc", "std"] } hex_fmt.workspace = true regex.workspace = true reqwest = { workspace = true, default-features = false, features = ["json", "rustls-tls", "charset", "hickory-dns"] } serde = { workspace = true, features = ["derive"] } +serde-human-bytes.workspace = true serde_json.workspace = true +sha2.workspace = true tokio = { workspace = true, features = ["full"] } tracing.workspace = true tracing-subscriber.workspace = true x509-parser.workspace = true - -dstack-gateway-rpc.workspace = true -ra-rpc = { workspace = true, default-features = false, features = ["client"] } diff --git a/ct_monitor/src/main.rs b/ct_monitor/src/main.rs index 6a168cce..bfa0565c 100644 --- a/ct_monitor/src/main.rs +++ b/ct_monitor/src/main.rs @@ -4,22 +4,95 @@ use anyhow::{bail, Context, Result}; use clap::Parser; -use dstack_gateway_rpc::gateway_client::GatewayClient; -use ra_rpc::client::RaClient; use regex::Regex; use serde::{Deserialize, Serialize}; +use serde_human_bytes as hex_bytes; +use sha2::{Digest, Sha512}; use std::collections::BTreeSet; use std::time::Duration; -use tracing::{debug, error, info}; +use tracing::{debug, error, info, warn}; use x509_parser::prelude::*; const BASE_URL: &str = "https://crt.sh"; +/// Quoted public key with TDX quote +#[derive(Debug, Deserialize)] +struct QuotedPublicKey { + /// Hex-encoded public key + public_key: String, + /// JSON-encoded GetQuoteResponse + quote: String, +} + +/// GetQuoteResponse from guest-agent +#[derive(Debug, Deserialize)] +struct GetQuoteResponse { + /// TDX quote (hex-encoded in JSON) + #[serde(with = "hex_bytes")] + quote: Vec, + /// JSON-encoded event log + event_log: String, + /// VM configuration + vm_config: String, +} + +/// Request for dstack-verifier +#[derive(Debug, Serialize)] +struct VerificationRequest { + quote: String, + event_log: String, + vm_config: String, + pccs_url: Option, +} + +/// Response from dstack-verifier +#[derive(Debug, Deserialize)] +struct VerificationResponse { + is_valid: bool, + details: VerificationDetails, + reason: Option, +} + +#[derive(Debug, Deserialize)] +struct VerificationDetails { + #[allow(dead_code)] + quote_verified: bool, + #[allow(dead_code)] + event_log_verified: bool, + #[allow(dead_code)] + os_image_hash_verified: bool, + report_data: Option, + app_info: Option, +} + +/// App info from verification response +#[derive(Debug, Deserialize)] +struct AppInfo { + #[serde(with = "hex_bytes")] + app_id: Vec, + #[serde(with = "hex_bytes")] + compose_hash: Vec, + #[serde(with = "hex_bytes")] + os_image_hash: Vec, +} + +#[derive(Debug, Deserialize)] +struct AcmeInfoResponse { + #[allow(dead_code)] + account_uri: String, + #[allow(dead_code)] + hist_keys: Vec, + quoted_hist_keys: Vec, +} + struct Monitor { gateway_uri: String, - domain: String, + verifier_url: String, + pccs_url: Option, + base_domain: String, known_keys: BTreeSet>, last_checked: Option, + client: reqwest::Client, } #[derive(Debug, Serialize, Deserialize)] @@ -37,24 +110,194 @@ struct CTLog { } impl Monitor { - fn new(gateway_uri: String, domain: String) -> Result { - validate_domain(&domain)?; + /// Create a new monitor + /// `gateway` format: `base_domain[:port]`, e.g., `example.com` or `example.com:8443` + fn new(gateway: String, verifier_url: String, pccs_url: Option) -> Result { + let (base_domain, gateway_uri) = Self::parse_gateway(&gateway)?; + validate_domain(&base_domain)?; Ok(Self { gateway_uri, - domain, + verifier_url, + pccs_url, + base_domain, known_keys: BTreeSet::new(), last_checked: None, + client: reqwest::Client::new(), }) } + /// Parse gateway input into base_domain and gateway URI + /// Input: `base_domain[:port]`, e.g., `example.com` or `example.com:8443` + /// Output: (base_domain, gateway_uri) + fn parse_gateway(gateway: &str) -> Result<(String, String)> { + let (base_domain, port) = match gateway.rsplit_once(':') { + Some((domain, port_str)) => { + // Validate port is a number + let _: u16 = port_str.parse().context("invalid port number")?; + (domain.to_string(), Some(port_str.to_string())) + } + None => (gateway.to_string(), None), + }; + + let gateway_uri = match port { + Some(p) => format!("https://gateway.{}:{}", base_domain, p), + None => format!("https://gateway.{}", base_domain), + }; + + Ok((base_domain, gateway_uri)) + } + + /// Compute expected report_data for a public key using zt-cert content type + fn compute_expected_report_data(public_key: &[u8]) -> [u8; 64] { + // Format: sha512("zt-cert:" + public_key) + let mut hasher = Sha512::new(); + hasher.update(b"zt-cert:"); + hasher.update(public_key); + hasher.finalize().into() + } + + /// Verify a quoted public key using the verifier service + /// Returns (public_key, app_info) + async fn verify_quoted_key(&self, quoted_key: &QuotedPublicKey) -> Result<(Vec, AppInfo)> { + let public_key = + hex::decode("ed_key.public_key).context("invalid hex in public_key")?; + + if quoted_key.quote.is_empty() { + bail!("empty quote for public key"); + } + + // Parse the GetQuoteResponse from the quote field + let quote_response: GetQuoteResponse = + serde_json::from_str("ed_key.quote).context("failed to parse quote response")?; + + // Build verification request + let verify_request = VerificationRequest { + quote: hex::encode("e_response.quote), + event_log: quote_response.event_log, + vm_config: quote_response.vm_config, + pccs_url: self.pccs_url.clone(), + }; + + // Call verifier + let verify_url = format!("{}/verify", self.verifier_url.trim_end_matches('/')); + let response = self + .client + .post(&verify_url) + .json(&verify_request) + .send() + .await + .context("failed to call verifier")?; + + if !response.status().is_success() { + bail!("verifier returned HTTP {}", response.status().as_u16()); + } + + let verify_response: VerificationResponse = response + .json() + .await + .context("failed to parse verifier response")?; + + if !verify_response.is_valid { + bail!( + "quote verification failed: {}", + verify_response.reason.unwrap_or_default() + ); + } + + // Verify report_data matches expected value + let expected_report_data = Self::compute_expected_report_data(&public_key); + let expected_hex = hex::encode(expected_report_data); + + let actual_report_data = verify_response + .details + .report_data + .context("verifier did not return report_data")?; + + if actual_report_data != expected_hex { + bail!( + "report_data mismatch: expected {}, got {}", + expected_hex, + actual_report_data + ); + } + + let app_info = verify_response + .details + .app_info + .context("verifier did not return app_info")?; + + Ok((public_key, app_info)) + } + async fn refresh_known_keys(&mut self) -> Result<()> { - info!("fetching known public keys from {}", self.gateway_uri); - // TODO: Use RA-TLS - let tls_no_check = true; - let rpc = GatewayClient::new(RaClient::new(self.gateway_uri.clone(), tls_no_check)?); - let info = rpc.acme_info().await?; - self.known_keys = info.hist_keys.into_iter().collect(); - info!("got {} known public keys", self.known_keys.len()); + let acme_info_url = format!( + "{}/.dstack/acme-info", + self.gateway_uri.trim_end_matches('/') + ); + info!("fetching known public keys from {}", acme_info_url); + + let response = self + .client + .get(&acme_info_url) + .send() + .await + .context("failed to fetch acme-info")?; + + if !response.status().is_success() { + bail!( + "failed to fetch acme-info: HTTP {}", + response.status().as_u16() + ); + } + + let info: AcmeInfoResponse = response + .json() + .await + .context("failed to parse acme-info response")?; + + info!( + "got {} quoted public keys, verifying...", + info.quoted_hist_keys.len() + ); + + let mut verified_keys = BTreeSet::new(); + for (i, quoted_key) in info.quoted_hist_keys.iter().enumerate() { + match self.verify_quoted_key(quoted_key).await { + Ok((public_key, app_info)) => { + info!( + "✅ verified public key {}: {}", + i, + hex_fmt::HexFmt(&public_key) + ); + info!(" app_id: {}", hex_fmt::HexFmt(&app_info.app_id)); + info!( + " compose_hash: {}", + hex_fmt::HexFmt(&app_info.compose_hash) + ); + info!( + " os_image_hash: {}", + hex_fmt::HexFmt(&app_info.os_image_hash) + ); + verified_keys.insert(public_key); + } + Err(e) => { + warn!( + "⚠️ failed to verify public key {}: {}", + i, + hex_fmt::HexFmt("ed_key.public_key) + ); + warn!(" error: {:#}", e); + // Continue with other keys, but don't add this one + } + } + } + + if verified_keys.is_empty() && !info.quoted_hist_keys.is_empty() { + bail!("no public keys could be verified"); + } + + self.known_keys = verified_keys; + info!("verified {} public keys", self.known_keys.len()); for key in self.known_keys.iter() { debug!(" {}", hex_fmt::HexFmt(key)); } @@ -64,7 +307,7 @@ impl Monitor { async fn get_logs(&self, count: u32) -> Result> { let url = format!( "{}/?q={}&output=json&limit={}", - BASE_URL, self.domain, count + BASE_URL, self.base_domain, count ); let response = reqwest::get(&url).await?; Ok(response.json().await?) @@ -125,7 +368,7 @@ impl Monitor { } async fn run(&mut self) { - info!("monitoring {}...", self.domain); + info!("monitoring {}...", self.base_domain); loop { if let Err(err) = self.refresh_known_keys().await { error!("error refreshing known keys: {}", err); @@ -151,12 +394,18 @@ fn validate_domain(domain: &str) -> Result<()> { #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { - /// The gateway URI - #[arg(short, long)] - gateway_uri: String, - /// Domain name to monitor - #[arg(short, long)] - domain: String, + /// Gateway address in format: base_domain[:port] + /// e.g., "example.com" or "example.com:8443" + #[arg(short, long, env = "GATEWAY")] + gateway: String, + + /// The dstack-verifier URL + #[arg(short, long, env = "VERIFIER_URL")] + verifier_url: String, + + /// PCCS URL for TDX collateral fetching (optional) + #[arg(long, env = "PCCS_URL")] + pccs_url: Option, } #[tokio::main] @@ -167,7 +416,7 @@ async fn main() -> anyhow::Result<()> { fmt().with_env_filter(filter).init(); } let args = Args::parse(); - let mut monitor = Monitor::new(args.gateway_uri, args.domain)?; + let mut monitor = Monitor::new(args.gateway, args.verifier_url, args.pccs_url)?; monitor.run().await; Ok(()) }