Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 182 additions & 61 deletions crates/key-server/src/aggregator/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,12 @@ const REFRESH_INTERVAL_SECS: u64 = 30;

/// Default SDK version requirement.
fn default_ts_sdk_version_requirement() -> VersionReq {
VersionReq::parse(">=0.10.0").expect("Failed to parse default SDK version requirement")
VersionReq::parse(">=0.10.0").expect("Failed to parse default TS SDK version requirement")
}

/// Default Rust SDK version requirement.
fn default_rust_sdk_version_requirement() -> VersionReq {
VersionReq::parse(">=0.0.0").expect("Failed to parse default Rust SDK version requirement")
}

/// Default key server version requirement.
Expand Down Expand Up @@ -90,10 +95,14 @@ struct AggregatorOptions {

key_server_object_id: Address,

/// The minimum version of the SDK that is required to use this aggregator.
/// The minimum version of the TS SDK that is required to use this aggregator.
#[serde(default = "default_ts_sdk_version_requirement")]
ts_sdk_version_requirement: VersionReq,

/// The minimum version of the Rust SDK that is required to use this aggregator.
#[serde(default = "default_rust_sdk_version_requirement")]
rust_sdk_version_requirement: VersionReq,

/// The minimum version of the key server that is required by this aggregator.
#[serde(default = "default_key_server_version_requirement")]
key_server_version_requirement: VersionReq,
Expand Down Expand Up @@ -133,31 +142,51 @@ struct AppState {
options: AggregatorOptions,
}

impl AppState {
/// Validate SDK version against requirement based on SDK type.
fn validate_sdk_version(
&self,
version: &str,
sdk_type: Option<&HeaderValue>,
) -> Result<(), InternalError> {
let version = Version::parse(version).map_err(|_| InvalidSDKVersion)?;
let sdk_type = ClientSdkType::from_header(sdk_type.and_then(|v| v.to_str().ok()));

match sdk_type {
ClientSdkType::TypeScript => {
if !self.options.ts_sdk_version_requirement.matches(&version) {
return Err(DeprecatedSDKVersion);
}
fn validate_client_sdk_version(
version: &str,
sdk_type: ClientSdkType,
ts_requirement: &VersionReq,
rust_requirement: &VersionReq,
) -> Result<(), InternalError> {
let version = Version::parse(version).map_err(|_| InvalidSDKVersion)?;
match sdk_type {
ClientSdkType::TypeScript => {
if !ts_requirement.matches(&version) {
return Err(DeprecatedSDKVersion);
}
_ => {
// TODO: Add support for other SDK types.
return Err(InvalidSDKType);
}
ClientSdkType::Rust => {
if !rust_requirement.matches(&version) {
return Err(DeprecatedSDKVersion);
}
}

Ok(())
ClientSdkType::Aggregator | ClientSdkType::Other => {
return Err(InvalidSDKType);
}
}

Ok(())
}

fn validate_client_sdk_headers<'a>(
headers: &'a HeaderMap,
ts_requirement: &VersionReq,
rust_requirement: &VersionReq,
) -> Result<(ClientSdkType, &'a str), InternalError> {
let sdk_type_header = headers.get(HEADER_CLIENT_SDK_TYPE).ok_or(InvalidSDKType)?;
let sdk_type =
ClientSdkType::from_header(Some(sdk_type_header.to_str().map_err(|_| InvalidSDKType)?))?;
let version_str = headers
.get(HEADER_CLIENT_SDK_VERSION)
.ok_or_else(|| MissingRequiredHeader(HEADER_CLIENT_SDK_VERSION.to_string()))?
.to_str()
.map_err(|_| InvalidSDKVersion)?;

validate_client_sdk_version(version_str, sdk_type, ts_requirement, rust_requirement)?;

Ok((sdk_type, version_str))
}

#[tokio::main]
async fn main() -> Result<()> {
let _guard = mysten_service::logging::init();
Expand Down Expand Up @@ -276,55 +305,34 @@ async fn handle_fetch_key(
.and_then(|v| v.to_str().ok())
.unwrap_or_default();

// Extract headers and validate version.
let version = headers.get(HEADER_CLIENT_SDK_VERSION);
let sdk_type = headers.get(HEADER_CLIENT_SDK_TYPE);

let version_str = version
.ok_or_else(|| {
let err = MissingRequiredHeader(HEADER_CLIENT_SDK_VERSION.to_string());
debug!("Missing SDK version header (req_id: {})", req_id);
state.aggregator_metrics.observe_error(err.as_str());
ErrorResponse::from(err)
})
.and_then(|v| {
v.to_str().map_err(|_| {
debug!(
"Invalid SDK version header format (req_id: {}), header: {:?}",
req_id, v
);
state
.aggregator_metrics
.observe_error(InvalidSDKVersion.as_str());
ErrorResponse::from(InvalidSDKVersion)
})
})?;

// Validate and track SDK version.
state
.validate_sdk_version(version_str, sdk_type)
.map_err(|e| {
debug!(
"Invalid SDK version: {:?}, sdk_version: {:?}, sdk_type: {:?} (req_id: {})",
e, version, sdk_type, req_id
);
state.aggregator_metrics.observe_error(e.as_str());
ErrorResponse::from(e)
})?;
let (sdk_type, version_str) = validate_client_sdk_headers(
&headers,
&state.options.ts_sdk_version_requirement,
&state.options.rust_sdk_version_requirement,
)
.map_err(|e| {
debug!(
"Invalid SDK headers: {:?}, sdk_version: {:?}, sdk_type: {:?} (req_id: {})",
e,
headers.get(HEADER_CLIENT_SDK_VERSION),
headers.get(HEADER_CLIENT_SDK_TYPE),
req_id
);
state.aggregator_metrics.observe_error(e.as_str());
ErrorResponse::from(e)
})?;

// Track client SDK version by type
let sdk_type_enum = ClientSdkType::from_header(sdk_type.and_then(|v| v.to_str().ok()));
let sdk_type_str = sdk_type_enum.to_string();
state
.aggregator_metrics
.client_sdk_version
.with_label_values::<&str>(&[&sdk_type_str, version_str])
.with_label_values(&[sdk_type.as_str(), version_str])
.inc();

// Log incoming request with structured data
info!(
"Aggregator request - req_id: {}, SDK version: {}, SDK type: {:?}, user: {:?}",
req_id, version_str, sdk_type_enum, request.certificate.user
req_id, version_str, sdk_type, request.certificate.user
);

// Parse the PTB and build the set of expected full ids every honest committee member should
Expand Down Expand Up @@ -885,6 +893,7 @@ mod tests {
node_url: None,
key_server_object_id: Address::from([0u8; 32]),
ts_sdk_version_requirement: VersionReq::parse(">=0.9.0").unwrap(),
rust_sdk_version_requirement: VersionReq::parse(">=0.0.0").unwrap(),
key_server_version_requirement: VersionReq::parse(">=0.5.14").unwrap(),
key_server_timeout_secs: 8,
api_credentials,
Expand All @@ -905,6 +914,112 @@ mod tests {
}
}

#[test]
fn test_client_sdk_version_validation_matrix() {
let ts_requirement = VersionReq::parse(">=1.2.3").unwrap();
let rust_requirement = VersionReq::parse(">=2.0.0").unwrap();

assert_eq!(
validate_client_sdk_version(
"1.2.3",
ClientSdkType::TypeScript,
&ts_requirement,
&rust_requirement,
),
Ok(())
);
assert_eq!(
validate_client_sdk_version(
"1.2.2",
ClientSdkType::TypeScript,
&ts_requirement,
&rust_requirement,
),
Err(DeprecatedSDKVersion)
);
assert_eq!(
validate_client_sdk_version(
"2.0.0",
ClientSdkType::Rust,
&ts_requirement,
&rust_requirement,
),
Ok(())
);
assert_eq!(
validate_client_sdk_version(
"1.9.9",
ClientSdkType::Rust,
&ts_requirement,
&rust_requirement,
),
Err(DeprecatedSDKVersion)
);
assert_eq!(
validate_client_sdk_version(
"0.6.5",
ClientSdkType::Aggregator,
&ts_requirement,
&rust_requirement,
),
Err(InvalidSDKType)
);
assert_eq!(
validate_client_sdk_version(
"0.6.5",
ClientSdkType::Other,
&ts_requirement,
&rust_requirement,
),
Err(InvalidSDKType)
);
assert_eq!(
validate_client_sdk_version(
"not-semver",
ClientSdkType::TypeScript,
&ts_requirement,
&rust_requirement,
),
Err(InvalidSDKVersion)
);
}

#[test]
fn test_client_sdk_header_error_matrix() {
let ts_requirement = VersionReq::parse(">=1.2.3").unwrap();
let rust_requirement = VersionReq::parse(">=2.0.0").unwrap();

let mut headers = HeaderMap::new();
headers.insert(HEADER_CLIENT_SDK_TYPE, "aggregator".parse().unwrap());
headers.insert(HEADER_CLIENT_SDK_VERSION, "0.6.5".parse().unwrap());
assert_eq!(
validate_client_sdk_headers(&headers, &ts_requirement, &rust_requirement),
Err(InvalidSDKType)
);

let mut headers = HeaderMap::new();
headers.insert(HEADER_CLIENT_SDK_TYPE, "python".parse().unwrap());
headers.insert(HEADER_CLIENT_SDK_VERSION, "0.6.5".parse().unwrap());
assert_eq!(
validate_client_sdk_headers(&headers, &ts_requirement, &rust_requirement),
Err(InvalidSDKType)
);

let mut headers = HeaderMap::new();
headers.insert(HEADER_CLIENT_SDK_VERSION, "0.6.5".parse().unwrap());
assert_eq!(
validate_client_sdk_headers(&headers, &ts_requirement, &rust_requirement),
Err(InvalidSDKType)
);

let mut headers = HeaderMap::new();
headers.insert(HEADER_CLIENT_SDK_TYPE, "typescript".parse().unwrap());
assert_eq!(
validate_client_sdk_headers(&headers, &ts_requirement, &rust_requirement),
Err(MissingRequiredHeader(HEADER_CLIENT_SDK_VERSION.to_string()))
);
}

#[tokio::test]
async fn test_version_validations() {
// Test 1: Aggregator rejects client SDK if version is too old
Expand All @@ -927,6 +1042,7 @@ mod tests {
let (request, _, _) = create_test_fetch_key_request(&mut thread_rng());

let mut headers = HeaderMap::new();
headers.insert(HEADER_CLIENT_SDK_TYPE, "typescript".parse().unwrap());
headers.insert(HEADER_CLIENT_SDK_VERSION, "0.3.0".parse().unwrap()); // Too old
let result = handle_fetch_key(State(state), headers, Json(request)).await;

Expand Down Expand Up @@ -958,6 +1074,7 @@ mod tests {
let (request, _, _) = create_test_fetch_key_request(&mut thread_rng());

let mut headers = HeaderMap::new();
headers.insert(HEADER_CLIENT_SDK_TYPE, "typescript".parse().unwrap());
headers.insert(HEADER_CLIENT_SDK_VERSION, "0.9.6".parse().unwrap());
let result = handle_fetch_key(State(state), headers, Json(request)).await;

Expand Down Expand Up @@ -1006,6 +1123,7 @@ mod tests {
let state = create_test_app_state(&[server3], 1, vec![partial_pk]);

let mut headers = HeaderMap::new();
headers.insert(HEADER_CLIENT_SDK_TYPE, "typescript".parse().unwrap());
headers.insert(HEADER_CLIENT_SDK_VERSION, "0.9.6".parse().unwrap());
let result = handle_fetch_key(State(state), headers, Json(request)).await;
let response = result.unwrap().into_response();
Expand Down Expand Up @@ -1069,6 +1187,7 @@ mod tests {

// Call handle_fetch_key and check majority error.
let mut headers = HeaderMap::new();
headers.insert(HEADER_CLIENT_SDK_TYPE, "typescript".parse().unwrap());
headers.insert(HEADER_CLIENT_SDK_VERSION, "0.9.6".parse().unwrap());
let result = handle_fetch_key(State(state), headers, Json(request)).await;
match result {
Expand Down Expand Up @@ -1153,6 +1272,7 @@ mod tests {
);

let mut headers = HeaderMap::new();
headers.insert(HEADER_CLIENT_SDK_TYPE, "typescript".parse().unwrap());
headers.insert(HEADER_CLIENT_SDK_VERSION, "0.9.6".parse().unwrap());
let result = handle_fetch_key(State(state), headers, Json(request)).await;
let response =
Expand Down Expand Up @@ -1221,6 +1341,7 @@ mod tests {
);

let mut headers = HeaderMap::new();
headers.insert(HEADER_CLIENT_SDK_TYPE, "typescript".parse().unwrap());
headers.insert(HEADER_CLIENT_SDK_VERSION, "0.9.6".parse().unwrap());
let result = handle_fetch_key(State(state), headers, Json(request)).await;
assert_eq!(result.err().unwrap().error, "Failure");
Expand Down
21 changes: 14 additions & 7 deletions crates/key-server/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use axum::response::Response;
use serde::{Deserialize, Serialize};
use sui_types::base_types::ObjectID;

use crate::errors::InternalError;

/// Network configuration.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum Network {
Expand Down Expand Up @@ -50,6 +52,9 @@ pub const SDK_TYPE_AGGREGATOR: &str = "aggregator";
/// SDK type value for TypeScript clients.
pub const SDK_TYPE_TYPESCRIPT: &str = "typescript";

/// SDK type value for Rust clients.
pub const SDK_TYPE_RUST: &str = "rust";

/// Get the git version.
/// Based on https://github.com/MystenLabs/walrus/blob/7e282a681e6530ae4073210b33cac915fab439fa/crates/walrus-service/src/common/utils.rs#L69
#[macro_export]
Expand Down Expand Up @@ -80,23 +85,25 @@ macro_rules! git_version {
pub enum ClientSdkType {
Aggregator,
TypeScript,
Rust,
Other,
}

impl ClientSdkType {
pub fn from_header(header_value: Option<&str>) -> Self {
pub fn from_header(header_value: Option<&str>) -> Result<ClientSdkType, InternalError> {
match header_value {
Some(SDK_TYPE_AGGREGATOR) => ClientSdkType::Aggregator,
Some(SDK_TYPE_TYPESCRIPT) => ClientSdkType::TypeScript,
Some(_) => ClientSdkType::Other,
None => ClientSdkType::TypeScript, // Default to TypeScript for backward compatibility
Some(SDK_TYPE_AGGREGATOR) => Ok(ClientSdkType::Aggregator),
Some(SDK_TYPE_TYPESCRIPT) => Ok(ClientSdkType::TypeScript),
Some(SDK_TYPE_RUST) => Ok(ClientSdkType::Rust),
_ => Ok(ClientSdkType::Other),
}
}

pub fn as_str(&self) -> &'static str {
match self {
ClientSdkType::Aggregator => "aggregator",
ClientSdkType::TypeScript => "typescript",
ClientSdkType::Aggregator => SDK_TYPE_AGGREGATOR,
ClientSdkType::TypeScript => SDK_TYPE_TYPESCRIPT,
ClientSdkType::Rust => SDK_TYPE_RUST,
ClientSdkType::Other => "other",
}
}
Expand Down
Loading
Loading