Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "vsock-relay"
version = "1.1.0"
version = "2.0.0"
description = "Relays TCP connections from IPv4/IPv6 to vsock."
license = "MPL-2.0"
homepage = "https://github.com/brave-experiments/vsock-relay"
Expand All @@ -19,3 +19,6 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] }

[features]
mock-vsock = []

[profile.release]
panic = "abort"
74 changes: 59 additions & 15 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream},
sync::Semaphore,
task::JoinSet,
};
use tracing::{debug, error, info, metadata::LevelFilter};
use tracing_subscriber::EnvFilter;
Expand Down Expand Up @@ -95,22 +96,28 @@ mod agnostic {
}

/// Relays TCP connections from IPV4/IPV6 to VSOCK.
#[derive(Parser, Debug)]
#[derive(Clone, Parser, Debug)]
#[command(version)]
struct Cli {
/// Buffer size to use when reading/writing data between peers
#[arg(long, default_value_t = 8192)]
buffer_size: usize,

/// IPV4/IPV6 address/port to listen on
#[arg(short = 's', long, default_value = "0.0.0.0:8443")]
source_address: String,

/// VSOCK address/port to connect to
#[arg(short = 'l', long, default_value = DEFAULT_DEST_ADDR)]
destination_address: String,

/// Maximum amount of allowed concurrent connections
/// IPV4/IPV6 addresses/ports to listen on (comma separated)
#[arg(
short = 's',
long,
default_value = "0.0.0.0:8443",
value_delimiter = ','
)]
source_addresses: Vec<String>,

/// VSOCK addresses/ports to connect to (comma separated)
#[arg(short = 'd', long, default_value = DEFAULT_DEST_ADDR, value_delimiter=',')]
destination_addresses: Vec<String>,

/// Maximum amount of allowed concurrent connections.
/// The limit will be enforced separately for each source/destination pair.
#[arg(short = 'c', long, default_value_t = 1250)]
max_concurrent_connections: usize,

Expand Down Expand Up @@ -224,13 +231,17 @@ async fn host_ip_provider_server(port: u16) -> Result<()> {
}
}

async fn listen_and_serve(args: &Cli) -> Result<()> {
let host_listener = TcpListener::bind(&args.source_address)
async fn listen_and_serve(
source_address: String,
destination_address: String,
args: Cli,
) -> Result<()> {
let host_listener = TcpListener::bind(&source_address)
.await
.context("failed to start source listener")?;
let conn_count_semaphore = Arc::new(Semaphore::new(args.max_concurrent_connections));
let destination_address = parse_enclave_addr(&args.destination_address)?;
info!("Listening on tcp {}...", args.source_address);
let destination_address = parse_enclave_addr(&destination_address)?;
info!("Listening on tcp {}...", source_address);

// Use semaphore to limit active connection count
while let Ok(semaphore_permit) = conn_count_semaphore.clone().acquire_owned().await {
Expand Down Expand Up @@ -262,6 +273,16 @@ async fn listen_and_serve(args: &Cli) -> Result<()> {
async fn main() -> Result<()> {
let args = Cli::parse();

assert!(
!args.source_addresses.is_empty(),
"at least one source/destination pair must be defined"
);
assert_eq!(
args.source_addresses.len(),
args.destination_addresses.len(),
"amount of source and destination addresses must match"
);

tracing_subscriber::fmt()
.with_env_filter(
EnvFilter::builder()
Expand All @@ -279,5 +300,28 @@ async fn main() -> Result<()> {
});
}

listen_and_serve(&args).await
let mut join_set = JoinSet::new();
for (source_address, destination_address) in args
.source_addresses
.clone()
.into_iter()
.zip(args.destination_addresses.clone().into_iter())
{
let args = args.clone();
join_set.spawn(async move {
listen_and_serve(
source_address.to_string(),
destination_address.to_string(),
args,
)
.await
});
}
let result = join_set
.join_next()
.await
.expect("join set should not be empty")
.expect("async task should not result in join error");
join_set.shutdown().await;
result
}