diff --git a/docs/source/user-guide/worker-resolver.md b/docs/source/user-guide/worker-resolver.md index a17e5915..7b119d73 100644 --- a/docs/source/user-guide/worker-resolver.md +++ b/docs/source/user-guide/worker-resolver.md @@ -70,4 +70,4 @@ It's up to you to decide how the URLs should be resolved. One important implemen A good example can be found in [benchmarks/cdk/bin/worker.rs](https://github.com/datafusion-contrib/datafusion-distributed/blob/main/benchmarks/cdk/bin/worker.rs), -where a cluster of AWS EC2 machines is discovered identified by tags with the AWS Rust SDK. \ No newline at end of file +where a cluster of AWS EC2 machines, identified by tags, is discovered with the AWS Rust SDK. diff --git a/docs/source/user-guide/worker.md b/docs/source/user-guide/worker.md index 91fa3bf9..fa3726bc 100644 --- a/docs/source/user-guide/worker.md +++ b/docs/source/user-guide/worker.md @@ -95,3 +95,146 @@ async fn main() { ``` The `into_worker_server()` method builds a `WorkerServiceServer` ready to be added as a Tonic service. + +## Worker Versioning + +Workers expose a `GetWorkerInfo` gRPC endpoint that reports metadata about the running worker, +including a user-defined version string. This is useful during rolling deployments: when +workers running different code versions coexist in the cluster, the coordinator can route queries +only to workers running compatible code. + +### Setting a version + +Use the `Worker::with_version()` builder method to tag a worker with a version string. + +The version string is free-form — it can be a semver tag, a git SHA, a build number, or any +identifier that makes sense for your deployment workflow. Workers that don't call `with_version()` +report an empty string.
+ +```rust +let worker = Worker::default() + .with_version("2.0.0"); +``` + +One way to avoid forgetting to bump the version on each deploy is to derive it from an environment variable +set by your CI/CD pipeline: + +```rust +let worker = Worker::default() + .with_version(std::env::var("COMMIT_HASH").unwrap_or_default()); +``` + +### Querying a worker's version + +From the coordinator, use `DefaultChannelResolver` to get a cached channel +and `create_worker_client` to build a client, then call `get_worker_info`: + +```rust +use datafusion_distributed::{DefaultChannelResolver, GetWorkerInfoRequest, create_worker_client}; + +let channel_resolver = DefaultChannelResolver::default(); +let channel = channel_resolver.get_channel(&worker_url).await?; +let mut client = create_worker_client(channel); + +let response = client.get_worker_info(GetWorkerInfoRequest {}).await?; +println!("version: {}", response.into_inner().version_number); +``` + +### Zero-downtime rolling deployments + +During a rolling deployment, workers transition from version A to version B over time. To avoid +routing queries to workers running incompatible code, you can filter workers by version before +the planner sees them. + +The recommended pattern is: + +1. **Background polling loop**: Periodically query each worker's version and maintain a filtered + list of compatible URLs. +2. **Version-aware WorkerResolver**: Implement `WorkerResolver::get_urls()` to return only the + compatible URLs from the filtered list. + +```rust +use std::sync::{Arc, RwLock}; +use std::time::Duration; +use url::Url; +use datafusion::common::DataFusionError; +use datafusion_distributed::{ + DefaultChannelResolver, GetWorkerInfoRequest, WorkerResolver, create_worker_client, +}; + +struct VersionAwareWorkerResolver { + compatible_urls: Arc<RwLock<Vec<Url>>>, +} + +impl VersionAwareWorkerResolver { + /// Starts a background task that periodically polls all known worker URLs + /// and filters them by the expected version.
+ fn start_version_filtering( + known_urls: Vec<Url>, + expected_version: String, + ) -> Self { + let compatible_urls = Arc::new(RwLock::new(vec![])); + let urls_handle = compatible_urls.clone(); + + tokio::spawn(async move { + let channel_resolver = DefaultChannelResolver::default(); + loop { + let mut filtered = vec![]; + for url in &known_urls { + if let Ok(channel) = channel_resolver.get_channel(url).await { + let mut client = create_worker_client(channel); + if let Ok(resp) = client.get_worker_info(GetWorkerInfoRequest {}).await { + if resp.into_inner().version_number == expected_version { + filtered.push(url.clone()); + } + } + } + } + *urls_handle.write().unwrap() = filtered; + tokio::time::sleep(Duration::from_secs(5)).await; + } + }); + + Self { compatible_urls } + } +} + +impl WorkerResolver for VersionAwareWorkerResolver { + fn get_urls(&self) -> Result<Vec<Url>, DataFusionError> { + Ok(self.compatible_urls.read().unwrap().clone()) + } +} +``` + +With the resolver in place, wire it into the session and tag each worker with a version. + +```rust +use datafusion::execution::SessionStateBuilder; +use datafusion_distributed::{DistributedExt, DistributedPhysicalOptimizerRule, Worker}; + +let worker_version = std::env::var("COMMIT_HASH").unwrap_or_default(); + +// `known_urls` comes from your service discovery. +let resolver = VersionAwareWorkerResolver::start_version_filtering( + known_urls, + worker_version.clone(), +); + +let state = SessionStateBuilder::new() + .with_default_features() + .with_distributed_worker_resolver(resolver) + .with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule)) + .build(); + +let ctx = SessionContext::from(state); + +let worker = Worker::default().with_version(worker_version); + +Server::builder() + .add_service(worker.into_worker_server()) + .serve(addr) + .await?; +``` + +The coordinator's resolver continuously polls all known URLs in the background.
+Only workers that respond with the correct version will appear in `get_urls()`. diff --git a/examples/localhost_run.rs b/examples/localhost_run.rs index eae9b6cd..073fb732 100644 --- a/examples/localhost_run.rs +++ b/examples/localhost_run.rs @@ -4,7 +4,8 @@ use datafusion::common::DataFusionError; use datafusion::execution::SessionStateBuilder; use datafusion::prelude::{ParquetReadOptions, SessionContext}; use datafusion_distributed::{ - DistributedExt, DistributedPhysicalOptimizerRule, WorkerResolver, display_plan_ascii, + DefaultChannelResolver, DistributedExt, DistributedPhysicalOptimizerRule, GetWorkerInfoRequest, + WorkerResolver, create_worker_client, display_plan_ascii, }; use futures::TryStreamExt; use std::error::Error; @@ -23,19 +24,65 @@ struct Args { #[structopt(long = "cluster-ports", use_delimiter = true)] cluster_ports: Vec, + /// Only use workers reporting this version (via GetWorkerInfo). + /// When omitted, all workers in --cluster-ports are used. + #[structopt(long)] + version: Option, + /// Whether the distributed plan should be rendered instead of executing the query. #[structopt(long)] show_distributed_plan: bool, } +/// Returns `true` if the worker at `url` reports `expected_version` via +/// `GetWorkerInfo`. Returns `false` if the worker is unreachable, returns +/// an error, or reports a different version. 
+async fn worker_has_version( + channel_resolver: &DefaultChannelResolver, + url: &Url, + expected_version: &str, +) -> bool { + let Ok(channel) = channel_resolver.get_channel(url).await else { + return false; + }; + let mut client = create_worker_client(channel); + let Ok(response) = client.get_worker_info(GetWorkerInfoRequest {}).await else { + return false; + }; + response.into_inner().version_number == expected_version +} + #[tokio::main] async fn main() -> Result<(), Box> { let args = Args::from_args(); - let localhost_resolver = LocalhostWorkerResolver { - ports: args.cluster_ports, + let ports = if let Some(target_version) = &args.version { + // Filter workers by version before building the session. + let channel_resolver = DefaultChannelResolver::default(); + let mut compatible = Vec::new(); + for &port in &args.cluster_ports { + let url = Url::parse(&format!("http://localhost:{port}"))?; + if worker_has_version(&channel_resolver, &url, target_version).await { + compatible.push(port); + } else { + println!("Excluding worker on port {port} (version mismatch)"); + } + } + if compatible.is_empty() { + return Err(format!("No workers matched version '{target_version}'").into()); + } + println!( + "Using {}/{} workers matching version '{target_version}'\n", + compatible.len(), + args.cluster_ports.len(), + ); + compatible + } else { + args.cluster_ports }; + let localhost_resolver = LocalhostWorkerResolver { ports }; + let state = SessionStateBuilder::new() .with_default_features() .with_distributed_worker_resolver(localhost_resolver) diff --git a/examples/localhost_worker.rs b/examples/localhost_worker.rs index 2ed0d6d3..5fd8482f 100644 --- a/examples/localhost_worker.rs +++ b/examples/localhost_worker.rs @@ -9,15 +9,27 @@ use tonic::transport::Server; struct Args { #[structopt(default_value = "8080")] port: u16, + + /// Optional version string reported via GetWorkerInfo (e.g. "2.0", a git SHA). 
+ #[structopt(long)] + version: Option, } #[tokio::main] async fn main() -> Result<(), Box> { let args = Args::from_args(); + let mut worker = Worker::default(); + if let Some(version) = &args.version { + worker = worker.with_version(version); + } + + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), args.port); + println!("Worker listening on {addr}"); + Server::builder() - .add_service(Worker::default().into_worker_server()) - .serve(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), args.port)) + .add_service(worker.into_worker_server()) + .serve(addr) .await?; Ok(()) diff --git a/examples/worker_versioning.md b/examples/worker_versioning.md new file mode 100644 index 00000000..2fbf24f8 --- /dev/null +++ b/examples/worker_versioning.md @@ -0,0 +1,173 @@ +# Worker Versioning + +Workers expose a `GetWorkerInfo` gRPC endpoint that reports metadata including a +user-defined version string. This is useful during **rolling deployments** when +workers running different code versions coexist in the cluster and you need the +coordinator to route queries only to workers running compatible code. + +## Quick start + +### 1. Start versioned workers + +Open separate terminals and start workers with different version strings: + +```bash +# Terminal 1 — v1.0 worker +cargo run --example localhost_worker -- 8080 --version 1.0 + +# Terminal 2 — v2.0 worker +cargo run --example localhost_worker -- 8081 --version 2.0 + +# Terminal 3 — v2.0 worker +cargo run --example localhost_worker -- 8082 --version 2.0 +``` + +The `--version` flag is optional. Workers started without it report an empty +version string. + +### 2. Check what versions are running + +```bash +cargo run --example worker_versioning -- --ports 8080,8081,8082 +``` + +``` +Worker at :8080 -> version 1.0 +Worker at :8081 -> version 2.0 +Worker at :8082 -> version 2.0 +``` + +### 3. 
Run a query against only v2.0 workers + +```bash +cargo run --example localhost_run -- \ + "SELECT city, count(*) FROM weather GROUP BY city" \ + --cluster-ports 8080,8081,8082 \ + --version 2.0 +``` + +``` +Excluding worker on port 8080 (version mismatch) +Using 2/3 workers matching version '2.0' + ++--------+----------+ +| city | count(*) | +... +``` + +The `--version` flag queries each worker's version via `GetWorkerInfo` and +excludes any worker that doesn't match before the query is planned. + +Without `--version`, all workers in `--cluster-ports` are used regardless of +what version they report. + +## How it works + +### Setting a version on a worker + +Use `Worker::with_version()` when building the worker: + +```rust +let worker = Worker::default().with_version("2.0"); + +Server::builder() + .add_service(worker.into_worker_server()) + .serve(addr) + .await?; +``` + +The version string is free-form — it can be a semver tag, a git SHA, or a build +number. + +### Querying a worker's version + +Use `DefaultChannelResolver` to obtain a cached gRPC channel and +`create_worker_client` to build a client: + +```rust +let channel_resolver = DefaultChannelResolver::default(); +let channel = channel_resolver.get_channel(&worker_url).await?; +let mut client = create_worker_client(channel); + +let response = client.get_worker_info(GetWorkerInfoRequest {}).await?; +println!("version: {}", response.into_inner().version_number); +``` + +### Checking version compatibility + +The `worker_has_version` helper wraps the above into a single boolean check. 
+Returns `false` if the worker is unreachable, returns an error, or reports a +different version: + +```rust +async fn worker_has_version( + channel_resolver: &DefaultChannelResolver, + url: &Url, + expected_version: &str, +) -> bool { + let Ok(channel) = channel_resolver.get_channel(url).await else { + return false; + }; + let mut client = create_worker_client(channel); + let Ok(response) = client.get_worker_info(GetWorkerInfoRequest {}).await else { + return false; + }; + response.into_inner().version_number == expected_version +} +``` + +## Production pattern: background polling + +In the examples above, version filtering happens once at startup. In production, +workers come and go during rolling deployments, so you need to poll continuously. + +`WorkerResolver::get_urls()` is **synchronous** — it's called during planning +and cannot do async I/O. The recommended pattern is a background task that +periodically polls each worker's version and writes a filtered URL list into +shared state. The resolver reads from that shared state. 
+ +```rust +// In a background task, periodically filter worker URLs by version: +let channel_resolver = DefaultChannelResolver::default(); +let compatible_urls: Arc<RwLock<Vec<Url>>> = /* shared with WorkerResolver */; + +tokio::spawn(async move { + loop { + let mut filtered = vec![]; + for url in &all_known_urls { + if worker_has_version(&channel_resolver, url, "2.0").await { + filtered.push(url.clone()); + } + } + *compatible_urls.write().unwrap() = filtered; + tokio::time::sleep(Duration::from_secs(5)).await; + } +}); +``` + +The `WorkerResolver` then just reads from the shared list: + +```rust +struct VersionAwareWorkerResolver { + compatible_urls: Arc<RwLock<Vec<Url>>>, +} + +impl WorkerResolver for VersionAwareWorkerResolver { + fn get_urls(&self) -> Result<Vec<Url>, DataFusionError> { + Ok(self.compatible_urls.read().unwrap().clone()) + } +} +``` + +Wire it into a session: + +```rust +let state = SessionStateBuilder::new() + .with_default_features() + .with_distributed_worker_resolver(resolver) + .with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule)) + .build(); +``` + +The planner calls `get_urls()` during planning and will only see workers that +passed the version check on the last polling cycle.
diff --git a/src/lib.rs b/src/lib.rs index e1771d3b..ef56cc34 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -40,11 +40,12 @@ pub use stage::{ DistributedTaskContext, ExecutionTask, Stage, display_plan_ascii, display_plan_graphviz, explain_analyze, }; -pub use worker::generated::worker::TaskKey; -pub use worker::generated::worker::worker_service_client::WorkerServiceClient; pub use worker::{ DefaultSessionBuilder, MappedWorkerSessionBuilder, MappedWorkerSessionBuilderExt, TaskData, Worker, WorkerQueryContext, WorkerSessionBuilder, + generated::worker::{ + GetWorkerInfoRequest, GetWorkerInfoResponse, worker_service_client::WorkerServiceClient, + }, }; pub use observability::{ diff --git a/src/worker/generated/worker.rs b/src/worker/generated/worker.rs index e659236b..2c80ef85 100644 --- a/src/worker/generated/worker.rs +++ b/src/worker/generated/worker.rs @@ -29,6 +29,13 @@ pub struct SetPlanRequest { #[prost(bytes = "vec", tag = "3")] pub plan_proto: ::prost::alloc::vec::Vec, } +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] +pub struct GetWorkerInfoRequest {} +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct GetWorkerInfoResponse { + #[prost(string, tag = "1")] + pub version_number: ::prost::alloc::string::String, +} #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ExecuteTaskRequest { /// The unique identifier of the task that is going to get executed. @@ -400,7 +407,8 @@ pub mod worker_service_client { self } /// Establishes a bidirectional message stream between a coordinator and a worker, over which messages - /// will be exchanged at any time during a query's lifetime. + /// will be exchanged at any time during a query's lifetime. It's expected to be one coordinator channel + /// per task. 
pub async fn coordinator_channel( &mut self, request: impl tonic::IntoStreamingRequest, @@ -439,6 +447,21 @@ pub mod worker_service_client { .insert(GrpcMethod::new("worker.WorkerService", "ExecuteTask")); self.inner.server_streaming(req, path, codec).await } + pub async fn get_worker_info( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> + { + self.inner.ready().await.map_err(|e| { + tonic::Status::unknown(format!("Service was not ready: {}", e.into())) + })?; + let codec = tonic_prost::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static("/worker.WorkerService/GetWorkerInfo"); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("worker.WorkerService", "GetWorkerInfo")); + self.inner.unary(req, path, codec).await + } } } /// Generated server implementations. @@ -460,7 +483,8 @@ pub mod worker_service_server { > + std::marker::Send + 'static; /// Establishes a bidirectional message stream between a coordinator and a worker, over which messages - /// will be exchanged at any time during a query's lifetime. + /// will be exchanged at any time during a query's lifetime. It's expected to be one coordinator channel + /// per task. 
async fn coordinator_channel( &self, request: tonic::Request>, @@ -475,6 +499,10 @@ pub mod worker_service_server { &self, request: tonic::Request, ) -> std::result::Result, tonic::Status>; + async fn get_worker_info( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; } #[derive(Debug)] pub struct WorkerServiceServer { @@ -639,6 +667,47 @@ pub mod worker_service_server { }; Box::pin(fut) } + "/worker.WorkerService/GetWorkerInfo" => { + #[allow(non_camel_case_types)] + struct GetWorkerInfoSvc(pub Arc); + impl tonic::server::UnaryService + for GetWorkerInfoSvc + { + type Response = super::GetWorkerInfoResponse; + type Future = BoxFuture, tonic::Status>; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::get_worker_info(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let method = GetWorkerInfoSvc(inner); + let codec = tonic_prost::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } _ => Box::pin(async move { let mut response = http::Response::new(tonic::body::Body::default()); let headers = response.headers_mut(); diff --git a/src/worker/worker.proto b/src/worker/worker.proto index bf42fe52..4fca1105 100644 --- a/src/worker/worker.proto +++ b/src/worker/worker.proto @@ -8,6 +8,7 @@ service WorkerService { rpc CoordinatorChannel(stream 
CoordinatorToWorkerMsg) returns (stream WorkerToCoordinatorMsg); // Executes the requested partition range of a subplan previously sent by the coordinator channel. rpc ExecuteTask(ExecuteTaskRequest) returns (stream FlightData); + rpc GetWorkerInfo(GetWorkerInfoRequest) returns (GetWorkerInfoResponse); } message CoordinatorToWorkerMsg { @@ -31,6 +32,12 @@ message SetPlanRequest { bytes plan_proto = 3; } +message GetWorkerInfoRequest {} + +message GetWorkerInfoResponse { + string version_number = 1; +} + message ExecuteTaskRequest { // The unique identifier of the task that is going to get executed. TaskKey task_key = 1; diff --git a/src/worker/worker_service.rs b/src/worker/worker_service.rs index 8a841f88..ef57ae7c 100644 --- a/src/worker/worker_service.rs +++ b/src/worker/worker_service.rs @@ -21,6 +21,8 @@ use std::time::Duration; use tonic::codegen::BoxStream; use tonic::{Request, Response, Status, Streaming}; +use super::generated::worker::{GetWorkerInfoRequest, GetWorkerInfoResponse}; + #[allow(clippy::type_complexity)] #[derive(Clone, Default)] pub(super) struct WorkerHooks { @@ -40,6 +42,7 @@ pub struct Worker { pub(super) session_builder: Arc, pub(super) hooks: WorkerHooks, pub(super) max_message_size: Option, + pub(super) version: String, } impl Default for Worker { @@ -53,6 +56,7 @@ impl Default for Worker { session_builder: Arc::new(DefaultSessionBuilder), hooks: WorkerHooks::default(), max_message_size: Some(usize::MAX), + version: String::default(), } } } @@ -151,6 +155,11 @@ impl Worker { )) } + pub fn with_version(mut self, version: impl Into) -> Self { + self.version = version.into(); + self + } + /// Returns the number of cached task entries currently held by this worker. 
#[cfg(any(test, feature = "integration"))] pub async fn tasks_running(&self) -> usize { @@ -196,4 +205,13 @@ impl WorkerService for Worker { ) -> Result, Status> { self.impl_execute_task(request).await } + + async fn get_worker_info( + &self, + _request: Request<GetWorkerInfoRequest>, + ) -> Result<Response<GetWorkerInfoResponse>, Status> { + Ok(Response::new(GetWorkerInfoResponse { + version_number: self.version.clone(), + })) + } }