Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/user-guide/worker-resolver.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,4 @@ It's up to you to decide how the URLs should be resolved. One important implemen

A good example can be found
in [benchmarks/cdk/bin/worker.rs](https://github.com/datafusion-contrib/datafusion-distributed/blob/main/benchmarks/cdk/bin/worker.rs),
where a cluster of AWS EC2 machines is discovered identified by tags with the AWS Rust SDK.
where a cluster of AWS EC2 machines is discovered identified by tags with the AWS Rust SDK.
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

newline removed by my IDEs markdown formatter

143 changes: 143 additions & 0 deletions docs/source/user-guide/worker.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,146 @@ async fn main() {
```

The `into_worker_server()` method builds a `WorkerServiceServer` ready to be added as a Tonic service.

## Worker Versioning

Workers expose a `GetWorkerInfo` gRPC endpoint that reports metadata about the running worker,
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Decided this wording was best since it's not unlikely that in the future GetWorkerInfo may send more than just a version string.

including a user-defined version string. This is useful during rolling deployments, when
workers running different code versions coexist in the cluster, the coordinator can route queries
only to workers running compatible code.

### Setting a version

Use the `Worker::with_version()` builder method to tag a worker with a version string.

The version string is free-form — it can be a semver tag, a git SHA, a build number, or any
identifier that makes sense for your deployment workflow. Workers that don't call `with_version()`
report an empty string.

```rust
let worker = Worker::default()
.with_version("2.0.0");
```

One way to avoid forgetting to bump the version on each deploy, derive it from an environment variable
set by your CI/CD pipeline:

```rust
let worker = Worker::default()
.with_version(std::env::var("COMMIT_HASH").unwrap_or_default());
```

### Querying a worker's version

From the coordinator, use `DefaultChannelResolver` to get a cached channel
and `create_worker_client` to build a client, then call `get_worker_info`:

```rust
use datafusion_distributed::{DefaultChannelResolver, GetWorkerInfoRequest, create_worker_client};

let channel_resolver = DefaultChannelResolver::default();
let channel = channel_resolver.get_channel(&worker_url).await?;
let mut client = create_worker_client(channel);

let response = client.get_worker_info(GetWorkerInfoRequest {}).await?;
println!("version: {}", response.into_inner().version_number);
```

### Zero-downtime rolling deployments

During a rolling deployment, workers transition from version A to version B over time. To avoid
routing queries to workers running incompatible code, you can filter workers by version before
the planner sees them.

The recommended pattern is:

1. **Background polling loop**: Periodically query each worker's version and maintain a filtered
list of compatible URLs.
2. **Version-aware WorkerResolver**: Implement `WorkerResolver::get_urls()` to return only the
compatible URLs from the filtered list.

```rust
use std::sync::{Arc, RwLock};
use std::time::Duration;
use url::Url;
use datafusion::common::DataFusionError;
use datafusion_distributed::{
DefaultChannelResolver, GetWorkerInfoRequest, WorkerResolver, create_worker_client,
};

struct VersionAwareWorkerResolver {
compatible_urls: Arc<RwLock<Vec<Url>>>,
}

impl VersionAwareWorkerResolver {
/// Starts a background task that periodically polls all known worker URLs
/// and filters them by the expected version.
fn start_version_filtering(
known_urls: Vec<Url>,
expected_version: String,
) -> Self {
let compatible_urls = Arc::new(RwLock::new(vec![]));
let urls_handle = compatible_urls.clone();

tokio::spawn(async move {
let channel_resolver = DefaultChannelResolver::default();
loop {
let mut filtered = vec![];
for url in &known_urls {
if let Ok(channel) = channel_resolver.get_channel(url).await {
let mut client = create_worker_client(channel);
if let Ok(resp) = client.get_worker_info(GetWorkerInfoRequest {}).await {
if resp.into_inner().version_number == expected_version {
filtered.push(url.clone());
}
}
}
}
*urls_handle.write().unwrap() = filtered;
tokio::time::sleep(Duration::from_secs(5)).await;
}
});

Self { compatible_urls }
}
}

impl WorkerResolver for VersionAwareWorkerResolver {
fn get_urls(&self) -> Result<Vec<Url>, DataFusionError> {
Ok(self.compatible_urls.read().unwrap().clone())
}
}
```

With the resolver in place, wire it into the session and tag each worker with a version.

```rust
use datafusion::execution::SessionStateBuilder;
use datafusion_distributed::{DistributedExt, DistributedPhysicalOptimizerRule, Worker};

let worker_version = std::env::var("COMMIT_HASH").unwrap_or_default();

// `known_urls` comes from your service discovery.
let resolver = VersionAwareWorkerResolver::start_version_filtering(
known_urls,
worker_version.clone(),
);

let state = SessionStateBuilder::new()
.with_default_features()
.with_distributed_worker_resolver(resolver)
.with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule))
.build();

let ctx = SessionContext::from(state);

let worker = Worker::default().with_version(worker_version);

Server::builder()
.add_service(worker.into_worker_server())
.serve(addr)
.await?;
```

The coordinator's resolver continuously polls all known URLs in the background.
Only workers that respond with the correct version will appear in `get_urls()`.
53 changes: 50 additions & 3 deletions examples/localhost_run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use datafusion::common::DataFusionError;
use datafusion::execution::SessionStateBuilder;
use datafusion::prelude::{ParquetReadOptions, SessionContext};
use datafusion_distributed::{
DistributedExt, DistributedPhysicalOptimizerRule, WorkerResolver, display_plan_ascii,
DefaultChannelResolver, DistributedExt, DistributedPhysicalOptimizerRule, GetWorkerInfoRequest,
WorkerResolver, create_worker_client, display_plan_ascii,
};
use futures::TryStreamExt;
use std::error::Error;
Expand All @@ -23,19 +24,65 @@ struct Args {
#[structopt(long = "cluster-ports", use_delimiter = true)]
cluster_ports: Vec<u16>,

/// Only use workers reporting this version (via GetWorkerInfo).
/// When omitted, all workers in --cluster-ports are used.
#[structopt(long)]
version: Option<String>,

/// Whether the distributed plan should be rendered instead of executing the query.
#[structopt(long)]
show_distributed_plan: bool,
}

/// Returns `true` if the worker at `url` reports `expected_version` via
/// `GetWorkerInfo`. Returns `false` if the worker is unreachable, returns
/// an error, or reports a different version.
async fn worker_has_version(
channel_resolver: &DefaultChannelResolver,
url: &Url,
expected_version: &str,
) -> bool {
let Ok(channel) = channel_resolver.get_channel(url).await else {
return false;
};
let mut client = create_worker_client(channel);
let Ok(response) = client.get_worker_info(GetWorkerInfoRequest {}).await else {
return false;
};
response.into_inner().version_number == expected_version
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let args = Args::from_args();

let localhost_resolver = LocalhostWorkerResolver {
ports: args.cluster_ports,
let ports = if let Some(target_version) = &args.version {
// Filter workers by version before building the session.
let channel_resolver = DefaultChannelResolver::default();
let mut compatible = Vec::new();
for &port in &args.cluster_ports {
let url = Url::parse(&format!("http://localhost:{port}"))?;
if worker_has_version(&channel_resolver, &url, target_version).await {
compatible.push(port);
} else {
println!("Excluding worker on port {port} (version mismatch)");
}
}
if compatible.is_empty() {
return Err(format!("No workers matched version '{target_version}'").into());
}
println!(
"Using {}/{} workers matching version '{target_version}'\n",
compatible.len(),
args.cluster_ports.len(),
);
compatible
} else {
args.cluster_ports
};

let localhost_resolver = LocalhostWorkerResolver { ports };

let state = SessionStateBuilder::new()
.with_default_features()
.with_distributed_worker_resolver(localhost_resolver)
Expand Down
16 changes: 14 additions & 2 deletions examples/localhost_worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,27 @@ use tonic::transport::Server;
struct Args {
#[structopt(default_value = "8080")]
port: u16,

/// Optional version string reported via GetWorkerInfo (e.g. "2.0", a git SHA).
#[structopt(long)]
version: Option<String>,
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let args = Args::from_args();

let mut worker = Worker::default();
if let Some(version) = &args.version {
worker = worker.with_version(version);
}

let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), args.port);
println!("Worker listening on {addr}");

Server::builder()
.add_service(Worker::default().into_worker_server())
.serve(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), args.port))
.add_service(worker.into_worker_server())
.serve(addr)
.await?;

Ok(())
Expand Down
Loading
Loading