Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions ct_monitor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@ license.workspace = true

[dependencies]
anyhow.workspace = true
clap = { workspace = true, features = ["derive"] }
clap = { workspace = true, features = ["derive", "env"] }
hex = { workspace = true, features = ["alloc", "std"] }
hex_fmt.workspace = true
regex.workspace = true
reqwest = { workspace = true, default-features = false, features = ["json", "rustls-tls", "charset", "hickory-dns"] }
serde = { workspace = true, features = ["derive"] }
serde-human-bytes.workspace = true
serde_json.workspace = true
sha2.workspace = true
tokio = { workspace = true, features = ["full"] }
tracing.workspace = true
tracing-subscriber.workspace = true
x509-parser.workspace = true

dstack-gateway-rpc.workspace = true
ra-rpc = { workspace = true, default-features = false, features = ["client"] }
295 changes: 272 additions & 23 deletions ct_monitor/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,95 @@

use anyhow::{bail, Context, Result};
use clap::Parser;
use dstack_gateway_rpc::gateway_client::GatewayClient;
use ra_rpc::client::RaClient;
use regex::Regex;
use serde::{Deserialize, Serialize};
use serde_human_bytes as hex_bytes;
use sha2::{Digest, Sha512};
use std::collections::BTreeSet;
use std::time::Duration;
use tracing::{debug, error, info};
use tracing::{debug, error, info, warn};
use x509_parser::prelude::*;

const BASE_URL: &str = "https://crt.sh";

/// A public key paired with the TDX quote material that attests to it.
///
/// Deserialized as one element of `AcmeInfoResponse::quoted_hist_keys` and
/// checked by `Monitor::verify_quoted_key`.
#[derive(Debug, Deserialize)]
struct QuotedPublicKey {
/// Hex-encoded public key; decoded with `hex::decode` before use.
public_key: String,
/// JSON-encoded `GetQuoteResponse`. An empty string is rejected during verification.
quote: String,
}

/// GetQuoteResponse from guest-agent.
///
/// Parsed from the JSON string stored in `QuotedPublicKey::quote`; its three
/// fields are forwarded to the verifier in a `VerificationRequest`.
#[derive(Debug, Deserialize)]
struct GetQuoteResponse {
/// TDX quote. Hex-encoded in JSON and decoded into raw bytes on deserialization.
#[serde(with = "hex_bytes")]
quote: Vec<u8>,
/// JSON-encoded event log, passed through to the verifier unchanged.
event_log: String,
/// VM configuration, passed through to the verifier unchanged.
vm_config: String,
}

/// Request for dstack-verifier.
///
/// Sent as the JSON body of the POST to `{verifier_url}/verify`.
#[derive(Debug, Serialize)]
struct VerificationRequest {
/// Hex-encoded quote bytes (re-encoded from `GetQuoteResponse::quote`).
quote: String,
/// Event log exactly as received in `GetQuoteResponse::event_log`.
event_log: String,
/// VM configuration exactly as received in `GetQuoteResponse::vm_config`.
vm_config: String,
/// Optional PCCS URL taken from the monitor configuration.
/// No skip attribute is set, so `None` is serialized as JSON `null`.
pccs_url: Option<String>,
}

/// Response from dstack-verifier.
#[derive(Debug, Deserialize)]
struct VerificationResponse {
/// Overall verdict; verification of the key fails when this is `false`.
is_valid: bool,
/// Per-check results plus the report data and app info extracted by the verifier.
details: VerificationDetails,
/// Optional explanation; included in the error message when `is_valid` is `false`.
reason: Option<String>,
}

/// Detailed results carried inside a `VerificationResponse`.
///
/// The three boolean fields are deserialized (and therefore must be present in
/// the response) but are not read by this program, hence `allow(dead_code)`.
#[derive(Debug, Deserialize)]
struct VerificationDetails {
// Not read here; the decision is based on `VerificationResponse::is_valid`.
#[allow(dead_code)]
quote_verified: bool,
// Not read here.
#[allow(dead_code)]
event_log_verified: bool,
// Not read here.
#[allow(dead_code)]
os_image_hash_verified: bool,
/// Report data as a string; compared verbatim against the hex encoding of the
/// expected SHA-512 value, so it must be present for verification to succeed.
report_data: Option<String>,
/// Application identity reported by the verifier; required for verification to succeed.
app_info: Option<AppInfo>,
}

/// App info from verification response.
///
/// Every field arrives hex-encoded in JSON and is decoded into raw bytes on
/// deserialization; the values are logged for each verified key.
#[derive(Debug, Deserialize)]
struct AppInfo {
/// Application identifier.
#[serde(with = "hex_bytes")]
app_id: Vec<u8>,
/// Compose hash reported by the verifier.
#[serde(with = "hex_bytes")]
compose_hash: Vec<u8>,
/// OS image hash reported by the verifier.
#[serde(with = "hex_bytes")]
os_image_hash: Vec<u8>,
}

/// JSON body returned by the gateway's `/.dstack/acme-info` endpoint.
///
/// Only `quoted_hist_keys` is consumed; the other fields are deserialized
/// (so they must be present) but not read, hence `allow(dead_code)`.
#[derive(Debug, Deserialize)]
struct AcmeInfoResponse {
// Not read by this program.
#[allow(dead_code)]
account_uri: String,
// Not read by this program; the quoted variants below are used instead.
#[allow(dead_code)]
hist_keys: Vec<String>,
/// Public keys with their quotes; each one is verified before being trusted.
quoted_hist_keys: Vec<QuotedPublicKey>,
}

/// State of the certificate-transparency monitor.
struct Monitor {
/// URI of the gateway that the known public keys are fetched from.
gateway_uri: String,
// NOTE(review): `domain` and `base_domain` are both used for the crt.sh query
// and the startup log line in this view; confirm whether both fields are needed.
domain: String,
/// Base URL of the dstack-verifier service; `/verify` is appended when calling it.
verifier_url: String,
/// Optional PCCS URL forwarded to the verifier in each verification request.
pccs_url: Option<String>,
/// Domain name that is validated at construction and queried on crt.sh.
base_domain: String,
/// Public keys (raw bytes) that passed quote verification on the last refresh.
known_keys: BTreeSet<Vec<u8>>,
// Initialised to `None`; its use is outside the visible code.
// Presumably tracks the most recent CT log entry processed — TODO confirm.
last_checked: Option<u64>,
/// Shared HTTP client used for the gateway and verifier requests.
client: reqwest::Client,
}

#[derive(Debug, Serialize, Deserialize)]
Expand All @@ -37,24 +110,194 @@ struct CTLog {
}

impl Monitor {
fn new(gateway_uri: String, domain: String) -> Result<Self> {
validate_domain(&domain)?;
/// Create a new monitor.
///
/// `gateway` format: `base_domain[:port]`, e.g., `example.com` or `example.com:8443`.
/// It is split by `Self::parse_gateway` into the base domain and the gateway URI;
/// `verifier_url` and `pccs_url` are stored as given.
///
/// The monitor starts with an empty `known_keys` set, `last_checked` unset, and a
/// fresh `reqwest::Client`.
///
/// # Errors
/// Fails if `gateway` cannot be parsed or if `validate_domain` rejects the base domain.
fn new(gateway: String, verifier_url: String, pccs_url: Option<String>) -> Result<Self> {
let (base_domain, gateway_uri) = Self::parse_gateway(&gateway)?;
validate_domain(&base_domain)?;
Ok(Self {
gateway_uri,
domain,
verifier_url,
pccs_url,
base_domain,
known_keys: BTreeSet::new(),
last_checked: None,
client: reqwest::Client::new(),
})
}

/// Parse gateway input into the base domain and the gateway URI.
///
/// Input: `base_domain[:port]`, e.g., `example.com` or `example.com:8443`.
/// The input is split at the last `:`; everything after it must parse as a `u16`.
///
/// Output: `(base_domain, gateway_uri)`, where `gateway_uri` is
/// `https://gateway.<base_domain>` with `:<port>` appended when a port was given.
///
/// The port is rendered from the parsed number rather than from the raw text,
/// so spellings the integer parser accepts but a URI does not (a leading `+`,
/// leading zeros) still yield a well-formed, canonical URI.
///
/// # Errors
/// Returns an error with context "invalid port number" when the text after
/// the last `:` is not a valid `u16`.
fn parse_gateway(gateway: &str) -> Result<(String, String)> {
    let (base_domain, port) = match gateway.rsplit_once(':') {
        Some((domain, port_str)) => {
            let port: u16 = port_str.parse().context("invalid port number")?;
            (domain, Some(port))
        }
        None => (gateway, None),
    };

    let gateway_uri = match port {
        Some(port) => format!("https://gateway.{}:{}", base_domain, port),
        None => format!("https://gateway.{}", base_domain),
    };

    Ok((base_domain.to_string(), gateway_uri))
}

/// Derive the report_data value a quote must carry for `public_key`.
///
/// The value is `sha512("zt-cert:" || public_key)`, returned as the raw
/// 64-byte digest.
fn compute_expected_report_data(public_key: &[u8]) -> [u8; 64] {
    // Content-type tag that is hashed ahead of the key bytes.
    const CONTENT_TYPE_TAG: &[u8] = b"zt-cert:";

    let mut sha = Sha512::new();
    for part in [CONTENT_TYPE_TAG, public_key] {
        sha.update(part);
    }
    sha.finalize().into()
}

/// Verify a quoted public key using the verifier service
/// Returns (public_key, app_info)
async fn verify_quoted_key(&self, quoted_key: &QuotedPublicKey) -> Result<(Vec<u8>, AppInfo)> {
let public_key =
hex::decode(&quoted_key.public_key).context("invalid hex in public_key")?;

if quoted_key.quote.is_empty() {
bail!("empty quote for public key");
}

// Parse the GetQuoteResponse from the quote field
let quote_response: GetQuoteResponse =
serde_json::from_str(&quoted_key.quote).context("failed to parse quote response")?;

// Build verification request
let verify_request = VerificationRequest {
quote: hex::encode(&quote_response.quote),
event_log: quote_response.event_log,
vm_config: quote_response.vm_config,
pccs_url: self.pccs_url.clone(),
};

// Call verifier
let verify_url = format!("{}/verify", self.verifier_url.trim_end_matches('/'));
let response = self
.client
.post(&verify_url)
.json(&verify_request)
.send()
.await
.context("failed to call verifier")?;

if !response.status().is_success() {
bail!("verifier returned HTTP {}", response.status().as_u16());
}

let verify_response: VerificationResponse = response
.json()
.await
.context("failed to parse verifier response")?;

if !verify_response.is_valid {
bail!(
"quote verification failed: {}",
verify_response.reason.unwrap_or_default()
);
}

// Verify report_data matches expected value
let expected_report_data = Self::compute_expected_report_data(&public_key);
let expected_hex = hex::encode(expected_report_data);

let actual_report_data = verify_response
.details
.report_data
.context("verifier did not return report_data")?;

if actual_report_data != expected_hex {
bail!(
"report_data mismatch: expected {}, got {}",
expected_hex,
actual_report_data
);
}

let app_info = verify_response
.details
.app_info
.context("verifier did not return app_info")?;

Ok((public_key, app_info))
}

async fn refresh_known_keys(&mut self) -> Result<()> {
info!("fetching known public keys from {}", self.gateway_uri);
// TODO: Use RA-TLS
let tls_no_check = true;
let rpc = GatewayClient::new(RaClient::new(self.gateway_uri.clone(), tls_no_check)?);
let info = rpc.acme_info().await?;
self.known_keys = info.hist_keys.into_iter().collect();
info!("got {} known public keys", self.known_keys.len());
let acme_info_url = format!(
"{}/.dstack/acme-info",
self.gateway_uri.trim_end_matches('/')
);
info!("fetching known public keys from {}", acme_info_url);

let response = self
.client
.get(&acme_info_url)
.send()
.await
.context("failed to fetch acme-info")?;

if !response.status().is_success() {
bail!(
"failed to fetch acme-info: HTTP {}",
response.status().as_u16()
);
}

let info: AcmeInfoResponse = response
.json()
.await
.context("failed to parse acme-info response")?;

info!(
"got {} quoted public keys, verifying...",
info.quoted_hist_keys.len()
);

let mut verified_keys = BTreeSet::new();
for (i, quoted_key) in info.quoted_hist_keys.iter().enumerate() {
match self.verify_quoted_key(quoted_key).await {
Ok((public_key, app_info)) => {
info!(
"✅ verified public key {}: {}",
i,
hex_fmt::HexFmt(&public_key)
);
info!(" app_id: {}", hex_fmt::HexFmt(&app_info.app_id));
info!(
" compose_hash: {}",
hex_fmt::HexFmt(&app_info.compose_hash)
);
info!(
" os_image_hash: {}",
hex_fmt::HexFmt(&app_info.os_image_hash)
);
verified_keys.insert(public_key);
}
Err(e) => {
warn!(
"⚠️ failed to verify public key {}: {}",
i,
hex_fmt::HexFmt(&quoted_key.public_key)
);
warn!(" error: {:#}", e);
// Continue with other keys, but don't add this one
}
}
}

if verified_keys.is_empty() && !info.quoted_hist_keys.is_empty() {
bail!("no public keys could be verified");
}

self.known_keys = verified_keys;
info!("verified {} public keys", self.known_keys.len());
for key in self.known_keys.iter() {
debug!(" {}", hex_fmt::HexFmt(key));
}
Expand All @@ -64,7 +307,7 @@ impl Monitor {
async fn get_logs(&self, count: u32) -> Result<Vec<CTLog>> {
let url = format!(
"{}/?q={}&output=json&limit={}",
BASE_URL, self.domain, count
BASE_URL, self.base_domain, count
);
let response = reqwest::get(&url).await?;
Ok(response.json().await?)
Expand Down Expand Up @@ -125,7 +368,7 @@ impl Monitor {
}

async fn run(&mut self) {
info!("monitoring {}...", self.domain);
info!("monitoring {}...", self.base_domain);
loop {
if let Err(err) = self.refresh_known_keys().await {
error!("error refreshing known keys: {}", err);
Expand All @@ -151,12 +394,18 @@ fn validate_domain(domain: &str) -> Result<()> {
// Command-line arguments, parsed with clap's derive API.
// The `///` comments on the fields double as clap help text, so they are
// runtime-visible strings and are intentionally left unchanged here.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The gateway URI
#[arg(short, long)]
gateway_uri: String,
/// Domain name to monitor
#[arg(short, long)]
domain: String,
/// Gateway address in format: base_domain[:port]
/// e.g., "example.com" or "example.com:8443"
// May also be supplied through the `GATEWAY` environment variable.
#[arg(short, long, env = "GATEWAY")]
gateway: String,

/// The dstack-verifier URL
// May also be supplied through the `VERIFIER_URL` environment variable.
#[arg(short, long, env = "VERIFIER_URL")]
verifier_url: String,

/// PCCS URL for TDX collateral fetching (optional)
// Long flag only; may also be supplied through the `PCCS_URL` environment variable.
#[arg(long, env = "PCCS_URL")]
pccs_url: Option<String>,
}

#[tokio::main]
Expand All @@ -167,7 +416,7 @@ async fn main() -> anyhow::Result<()> {
fmt().with_env_filter(filter).init();
}
let args = Args::parse();
let mut monitor = Monitor::new(args.gateway_uri, args.domain)?;
let mut monitor = Monitor::new(args.gateway, args.verifier_url, args.pccs_url)?;
monitor.run().await;
Ok(())
}