diff --git a/Cargo.lock b/Cargo.lock index 1c8eb0729..8eab6842c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4158,6 +4158,7 @@ dependencies = [ "cfg-if", "js-sys", "wasm-bindgen", + "wasm-bindgen-futures", "web-sys", ] @@ -4172,6 +4173,15 @@ dependencies = [ "worker 0.7.5", ] +[[package]] +name = "workflow-example" +version = "0.1.0" +dependencies = [ + "serde", + "serde_json", + "worker 0.7.5", +] + [[package]] name = "writeable" version = "0.6.2" diff --git a/Cargo.toml b/Cargo.toml index dfb53de9a..5c33b7736 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -74,7 +74,7 @@ web-sys = { version = "0.3.90", features = [ "WritableStream", "WritableStreamDefaultWriter", ] } -worker = { version = "0.7.5", path = "worker", features = ["queue", "d1", "axum", "timezone"] } +worker = { version = "0.7.5", path = "worker", features = ["queue", "d1", "axum", "timezone", "workflow"] } worker-codegen = { path = "worker-codegen", version = "0.2.0" } worker-macros = { version = "0.7.5", path = "worker-macros", features = ["queue"] } worker-sys = { version = "0.7.5", path = "worker-sys", features = ["d1", "queue"] } diff --git a/examples/workflow/Cargo.toml b/examples/workflow/Cargo.toml new file mode 100644 index 000000000..41e4881b3 --- /dev/null +++ b/examples/workflow/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "workflow-example" +version = "0.1.0" +edition = "2021" + +[package.metadata.release] +release = false + +[lib] +crate-type = ["cdylib"] + +[dependencies] +serde = { version = "1", features = ["derive"] } +serde_json = "1" +worker = { path = "../../worker", features = ["workflow"] } diff --git a/examples/workflow/src/lib.rs b/examples/workflow/src/lib.rs new file mode 100644 index 000000000..6a840e4dc --- /dev/null +++ b/examples/workflow/src/lib.rs @@ -0,0 +1,178 @@ +use serde::{Deserialize, Serialize}; +use worker::wasm_bindgen::JsValue; +use worker::*; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MyParams { + pub email: String, + pub name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MyOutput { + pub message: String, + pub steps_completed: u32, +} + +#[workflow] +pub struct MyWorkflow { + #[allow(dead_code)] + env: Env, +} + +impl WorkflowEntrypoint for MyWorkflow { + fn new(_ctx: Context, env: Env) -> Self { + Self { env } + } + + async fn run(&self, event: WorkflowEvent, step: WorkflowStep) -> Result { + console_log!("Workflow started with instance ID: {}", event.instance_id); + + let params: MyParams = serde_wasm_bindgen::from_value(event.payload)?; + + let email_for_validation = params.email.clone(); + step.do_with_config( + "validate-params", + StepConfig { + retries: Some(RetryConfig { + limit: 3, + delay: "1 second".into(), + backoff: None, + }), + timeout: None, + }, + move |_ctx| { + let email = email_for_validation.clone(); + async move { + if !email.contains('@') { + return Err(NonRetryableError::new("invalid email address").into()); + } + Ok(serde_json::json!({ "valid": true })) + } + }, + ) + .await?; + + let name_for_step1 = params.name.clone(); + let step1_result = step + .do_("initial-processing", move |_ctx| { + let name = name_for_step1.clone(); + async move { + console_log!("Processing for user: {}", name); + Ok(serde_json::json!({ + "processed": true, + "user": name + })) + } + }) + .await?; + + console_log!("Step 1 completed: {:?}", step1_result); + + console_log!("Step 2: Sleeping for 10 seconds..."); + step.sleep("wait-for-processing", "10 seconds").await?; + + let email_for_step3 = params.email.clone(); + let notification_result = step + .do_with_config( + "send-notification", + StepConfig { + retries: Some(RetryConfig { + limit: 3, + delay: "5 seconds".into(), + backoff: Some(Backoff::Exponential), + }), + timeout: Some("1 minute".into()), + }, + move |_ctx| { + let email = email_for_step3.clone(); + async move { + console_log!("Sending notification to: {}", email); + if js_sys::Math::random() < 0.5 { + return Err("notification service temporarily unavailable".into()); + } + Ok(serde_json::json!({ + "notification_sent": true, + "email": email + })) + } + }, + ) + .await?; + + console_log!("Step 3 completed: {:?}", notification_result); + + let output = MyOutput { + message: format!("Workflow completed for {}", params.name), + steps_completed: 3, + }; + + Ok(serialize_as_object(&output)?) + } +} + +#[event(fetch)] +async fn fetch(mut req: Request, env: Env, _ctx: Context) -> Result { + let url = req.url()?; + let path = url.path(); + let workflow = env.workflow("MY_WORKFLOW")?; + + match (req.method(), path) { + (Method::Post, "/workflow") => { + let params: MyParams = req.json().await?; + + let instance = workflow + .create(Some(CreateOptions { + params: Some(params), + ..Default::default() + })) + .await?; + + Response::from_json(&serde_json::json!({ + "id": instance.id(), + "message": "Workflow created" + })) + } + + (Method::Get, path) if path.starts_with("/workflow/") => { + let id = path.trim_start_matches("/workflow/"); + let instance = workflow.get(id).await?; + let status = instance.status().await?; + + Response::from_json(&serde_json::json!({ + "id": instance.id(), + "status": format!("{:?}", status.status), + "error": status.error, + "output": status.output + })) + } + + (Method::Post, path) if path.starts_with("/workflow/") && path.ends_with("/pause") => { + let id = path + .trim_start_matches("/workflow/") + .trim_end_matches("/pause"); + let instance = workflow.get(id).await?; + instance.pause().await?; + + Response::from_json(&serde_json::json!({ + "id": instance.id(), + "message": "Workflow paused" + })) + } + + (Method::Post, path) if path.starts_with("/workflow/") && path.ends_with("/resume") => { + let id = path + .trim_start_matches("/workflow/") + .trim_end_matches("/resume"); + let instance = workflow.get(id).await?; + instance.resume().await?; + + Response::from_json(&serde_json::json!({ + "id": instance.id(), + "message": "Workflow resumed" + })) + } + + _ => Response::error("Not Found", 404), + } +} diff --git a/examples/workflow/wrangler.toml b/examples/workflow/wrangler.toml new file mode 100644 index 000000000..b20e6ef7c --- /dev/null +++ b/examples/workflow/wrangler.toml @@ -0,0 +1,13 @@ +name = "workflow-example" +main = "build/worker/shim.mjs" +compatibility_date = "2024-10-22" + +[build] +# For development: use local worker-build binary +# For production: command = "cargo install -q worker-build && worker-build --release" +command = "RUSTFLAGS='--cfg=web_sys_unstable_apis' ../../target/release/worker-build --release" + +[[workflows]] +name = "my-workflow" +binding = "MY_WORKFLOW" +class_name = "MyWorkflow" diff --git a/package-lock.json b/package-lock.json index df1c9579f..bd318c8ff 100644 --- a/package-lock.json +++ b/package-lock.json @@ -2714,6 +2714,16 @@ "node": ">=18.0.0" } }, + "node_modules/miniflare/node_modules/undici": { + "version": "7.24.4", + "resolved": "https://registry.npmjs.org/undici/-/undici-7.24.4.tgz", + "integrity": "sha512-BM/JzwwaRXxrLdElV2Uo6cTLEjhSb3WXboncJamZ15NgUURmvlXvxa6xkwIOILIjPNo9i8ku136ZvWV0Uly8+w==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=20.18.1" + } + }, "node_modules/mri": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/mri/-/mri-1.2.0.tgz", @@ -3434,16 +3444,6 @@ "node": ">=14.17" } }, - "node_modules/undici": { - "version": "7.24.4", - "resolved": "https://registry.npmjs.org/undici/-/undici-7.24.4.tgz", - "integrity": "sha512-BM/JzwwaRXxrLdElV2Uo6cTLEjhSb3WXboncJamZ15NgUURmvlXvxa6xkwIOILIjPNo9i8ku136ZvWV0Uly8+w==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=20.18.1" - } - }, "node_modules/undici-types": { "version": "7.12.0", "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-7.12.0.tgz", diff --git a/test/src/lib.rs b/test/src/lib.rs index 4f28ca705..d32e50a08 100644 --- a/test/src/lib.rs +++ b/test/src/lib.rs @@ -37,6 +37,7 @@ mod socket; mod sql_counter; mod sql_iterator; mod user; +mod workflow; mod ws; #[derive(Deserialize, Serialize)] diff --git a/test/src/router.rs b/test/src/router.rs index cdd0af85f..569e1fab9 100644 --- a/test/src/router.rs +++ b/test/src/router.rs @@ -1,7 +1,7 @@ use crate::{ alarm, analytics_engine, assets, auto_response, cache, container, counter, d1, durable, fetch, form, js_snippets, kv, put_raw, queue, r2, rate_limit, request, secret_store, service, socket, - sql_counter, sql_iterator, user, ws, SomeSharedData, GLOBAL_STATE, + sql_counter, sql_iterator, user, workflow, ws, SomeSharedData, GLOBAL_STATE, }; #[cfg(feature = "http")] use std::convert::TryInto; @@ -17,6 +17,8 @@ use axum::{ routing::{delete, get, head, options, patch, post, put}, Extension, }; +#[cfg(feature = "http")] +use worker::send::SendFuture; // Transform the argument into the correct form for the router. // For axum::Router: @@ -59,16 +61,16 @@ macro_rules! format_route ( #[cfg(feature = "http")] macro_rules! handler ( ($name:path) => { - |Extension(env): Extension, Extension(data): Extension, req: axum::extract::Request| async { + |Extension(env): Extension, Extension(data): Extension, req: axum::extract::Request| SendFuture::new(async { let resp = $name(req.try_into().expect("convert request"), env, data).await.expect("handler result"); Into::>::into(resp) - } + }) }; ($name:path, sync) => { - |Extension(env): Extension, Extension(data): Extension, req: axum::extract::Request| async { + |Extension(env): Extension, Extension(data): Extension, req: axum::extract::Request| SendFuture::new(async { let resp = $name(req.try_into().expect("convert request"), env, data).expect("handler result"); Into::>::into(resp) - } + }) }; ); #[cfg(not(feature = "http"))] @@ -239,6 +241,18 @@ macro_rules! add_routes ( add_route!($obj, get, format_route!("/rate-limit/key/{}", "key"), rate_limit::handle_rate_limit_with_key); add_route!($obj, get, "/rate-limit/bulk-test", rate_limit::handle_rate_limit_bulk_test); add_route!($obj, get, "/rate-limit/reset", rate_limit::handle_rate_limit_reset); + add_route!($obj, post, "/workflow/create", workflow::handle_workflow_create); + add_route!($obj, post, "/workflow/create-invalid", workflow::handle_workflow_create_invalid); + add_route!($obj, get, format_route!("/workflow/status/{}", "id"), workflow::handle_workflow_status); + add_route!($obj, post, "/workflow/event/create", workflow::handle_event_workflow_create); + add_route!($obj, post, format_route!("/workflow/event/send/{}", "id"), workflow::handle_event_workflow_send); + add_route!($obj, get, format_route!("/workflow/event/status/{}", "id"), workflow::handle_event_workflow_status); + add_route!($obj, post, "/workflow/lifecycle/create", workflow::handle_lifecycle_workflow_create); + add_route!($obj, get, format_route!("/workflow/lifecycle/status/{}", "id"), workflow::handle_lifecycle_workflow_status); + add_route!($obj, post, format_route!("/workflow/lifecycle/pause/{}", "id"), workflow::handle_lifecycle_workflow_pause); + add_route!($obj, post, format_route!("/workflow/lifecycle/resume/{}", "id"), workflow::handle_lifecycle_workflow_resume); + add_route!($obj, post, format_route!("/workflow/lifecycle/terminate/{}", "id"), workflow::handle_lifecycle_workflow_terminate); + add_route!($obj, post, format_route!("/workflow/lifecycle/restart/{}", "id"), workflow::handle_lifecycle_workflow_restart); }); #[cfg(feature = "http")] diff --git a/test/src/workflow.rs b/test/src/workflow.rs new file mode 100644 index 000000000..ba849b438 --- /dev/null +++ b/test/src/workflow.rs @@ -0,0 +1,265 @@ +use serde::{Deserialize, Serialize}; +use worker::wasm_bindgen::JsValue; +use worker::*; + +fn last_path_segment(req: &Request) -> Result { + let url = req.url()?; + url.path_segments() + .and_then(|mut s| s.next_back().map(String::from)) + .ok_or_else(|| Error::RustError("missing path segment".into())) +} + +async fn get_workflow_instance( + req: &Request, + env: &Env, + binding: &str, +) -> Result { + let id = last_path_segment(req)?; + let workflow = env.workflow(binding)?; + workflow.get(&id).await +} + +async fn create_workflow_no_params(env: &Env, binding: &str) -> Result { + let workflow = env.workflow(binding)?; + let instance = workflow.create(None::>).await?; + Response::from_json(&serde_json::json!({ "id": instance.id() })) +} + +fn status_response(status: InstanceStatus) -> Result { + Response::from_json(&serde_json::json!({ + "status": format!("{:?}", status.status), + "output": status.output, + "error": status.error, + })) +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TestParams { + pub value: String, +} + +#[workflow] +pub struct TestWorkflow { + #[allow(dead_code)] + env: Env, +} + +impl WorkflowEntrypoint for TestWorkflow { + fn new(_ctx: Context, env: Env) -> Self { + Self { env } + } + + async fn run(&self, event: WorkflowEvent, step: WorkflowStep) -> Result { + let params: TestParams = serde_wasm_bindgen::from_value(event.payload)?; + + let value_for_validation = params.value.clone(); + step.do_with_config( + "validate", + StepConfig { + retries: Some(RetryConfig { + limit: 2, + delay: "1 second".into(), + backoff: None, + }), + timeout: None, + }, + move |_ctx| { + let value = value_for_validation.clone(); + async move { + if value.is_empty() { + return Err(NonRetryableError::new("value must not be empty").into()); + } + Ok(serde_json::json!({ "valid": true })) + } + }, + ) + .await?; + + let result: serde_json::Value = step + .do_("process", move |_ctx| { + let params = params.clone(); + async move { Ok(serde_json::json!({ "processed": params.value })) } + }) + .await?; + + Ok(serialize_as_object(&result)?) + } +} + +async fn create_workflow_with_value(env: &Env, value: &str) -> Result { + let workflow = env.workflow("TEST_WORKFLOW")?; + let params = TestParams { + value: value.to_string(), + }; + let instance = workflow + .create(Some(CreateOptions { + params: Some(params), + ..Default::default() + })) + .await?; + + Response::from_json(&serde_json::json!({ "id": instance.id() })) +} + +pub async fn handle_workflow_create( + _req: Request, + env: Env, + _data: crate::SomeSharedData, +) -> Result { + create_workflow_with_value(&env, "hello").await +} + +pub async fn handle_workflow_create_invalid( + _req: Request, + env: Env, + _data: crate::SomeSharedData, +) -> Result { + create_workflow_with_value(&env, "").await +} + +pub async fn handle_workflow_status( + req: Request, + env: Env, + _data: crate::SomeSharedData, +) -> Result { + let instance = get_workflow_instance(&req, &env, "TEST_WORKFLOW").await?; + let status = instance.status().await?; + status_response(status) +} + +#[workflow] +pub struct EventWorkflow { + #[allow(dead_code)] + env: Env, +} + +impl WorkflowEntrypoint for EventWorkflow { + fn new(_ctx: Context, env: Env) -> Self { + Self { env } + } + + async fn run(&self, _event: WorkflowEvent, step: WorkflowStep) -> Result { + let event = step + .wait_for_event::( + "wait-for-approval", + WaitForEventOptions { + type_: "approval".to_string(), + timeout: Some("30 seconds".into()), + }, + ) + .await?; + + Ok(serialize_as_object(&serde_json::json!({ + "payload": event.payload, + "type": event.type_, + }))?) + } +} + +pub async fn handle_event_workflow_create( + _req: Request, + env: Env, + _data: crate::SomeSharedData, +) -> Result { + create_workflow_no_params(&env, "EVENT_WORKFLOW").await +} + +pub async fn handle_event_workflow_send( + mut req: Request, + env: Env, + _data: crate::SomeSharedData, +) -> Result { + let instance = get_workflow_instance(&req, &env, "EVENT_WORKFLOW").await?; + let payload: serde_json::Value = req.json().await?; + instance.send_event("approval", payload).await?; + Response::ok("sent") +} + +pub async fn handle_event_workflow_status( + req: Request, + env: Env, + _data: crate::SomeSharedData, +) -> Result { + let instance = get_workflow_instance(&req, &env, "EVENT_WORKFLOW").await?; + let status = instance.status().await?; + status_response(status) +} + +#[workflow] +pub struct LifecycleWorkflow { + #[allow(dead_code)] + env: Env, +} + +impl WorkflowEntrypoint for LifecycleWorkflow { + fn new(_ctx: Context, env: Env) -> Self { + Self { env } + } + + async fn run(&self, _event: WorkflowEvent, step: WorkflowStep) -> Result { + step.sleep( + "long-sleep", + WorkflowSleepDuration::new(60, WorkflowDuration::Seconds), + ) + .await?; + Ok(serialize_as_object(&serde_json::json!({ "done": true }))?) + } +} + +pub async fn handle_lifecycle_workflow_create( + _req: Request, + env: Env, + _data: crate::SomeSharedData, +) -> Result { + create_workflow_no_params(&env, "LIFECYCLE_WORKFLOW").await +} + +pub async fn handle_lifecycle_workflow_status( + req: Request, + env: Env, + _data: crate::SomeSharedData, +) -> Result { + let instance = get_workflow_instance(&req, &env, "LIFECYCLE_WORKFLOW").await?; + let status = instance.status().await?; + status_response(status) +} + +pub async fn handle_lifecycle_workflow_pause( + req: Request, + env: Env, + _data: crate::SomeSharedData, +) -> Result { + let instance = get_workflow_instance(&req, &env, "LIFECYCLE_WORKFLOW").await?; + instance.pause().await?; + Response::ok("paused") +} + +pub async fn handle_lifecycle_workflow_resume( + req: Request, + env: Env, + _data: crate::SomeSharedData, +) -> Result { + let instance = get_workflow_instance(&req, &env, "LIFECYCLE_WORKFLOW").await?; + instance.resume().await?; + Response::ok("resumed") +} + +pub async fn handle_lifecycle_workflow_terminate( + req: Request, + env: Env, + _data: crate::SomeSharedData, +) -> Result { + let instance = get_workflow_instance(&req, &env, "LIFECYCLE_WORKFLOW").await?; + instance.terminate().await?; + Response::ok("terminated") +} + +pub async fn handle_lifecycle_workflow_restart( + req: Request, + env: Env, + _data: crate::SomeSharedData, +) -> Result { + let instance = get_workflow_instance(&req, &env, "LIFECYCLE_WORKFLOW").await?; + instance.restart().await?; + Response::ok("restarted") +} diff --git a/test/tests/mf.ts b/test/tests/mf.ts index 3c881b187..bd65b3418 100644 --- a/test/tests/mf.ts +++ b/test/tests/mf.ts @@ -36,6 +36,7 @@ const mf_instance = new Miniflare({ kvPersist: false, r2Persist: false, cachePersist: false, + workflowsPersist: false, workers: [ { scriptPath: "./build/index.js", @@ -112,6 +113,25 @@ const mf_instance = new Miniflare({ scriptName: "mini-analytics-engine" // mock out analytics engine binding to the "mini-analytics-engine" worker } }, + // Workflow binding requires a separate worker via scriptName in + // the Miniflare JS API (wrangler dev handles this automatically). + workflows: { + TEST_WORKFLOW: { + name: "test-workflow", + className: "TestWorkflow", + scriptName: "workflow-worker", + }, + EVENT_WORKFLOW: { + name: "event-workflow", + className: "EventWorkflow", + scriptName: "workflow-worker", + }, + LIFECYCLE_WORKFLOW: { + name: "lifecycle-workflow", + className: "LifecycleWorkflow", + scriptName: "workflow-worker", + }, + }, ratelimits: { TEST_RATE_LIMITER: { simple: { @@ -121,6 +141,17 @@ const mf_instance = new Miniflare({ } } }, + { + // Dedicated worker for TestWorkflow; uses the generated JS class wrapper. + name: "workflow-worker", + scriptPath: "./build/worker/shim.mjs", + modules: true, + modulesRules: [ + { type: "ESModule", include: ["**/*.js"], fallthrough: true }, + { type: "CompiledWasm", include: ["**/*.wasm"], fallthrough: true }, + ], + compatibilityDate: "2025-07-24", + }, { name: "mini-analytics-engine", modules: true, diff --git a/test/tests/workflow.spec.ts b/test/tests/workflow.spec.ts new file mode 100644 index 000000000..b471490e0 --- /dev/null +++ b/test/tests/workflow.spec.ts @@ -0,0 +1,192 @@ +import { describe, test, expect } from "vitest"; +import { mf, mfUrl } from "./mf"; + +describe("workflow", () => { + test("create and poll status until completion", async () => { + const createResp = await mf.dispatchFetch(`${mfUrl}workflow/create`, { + method: "POST", + }); + expect(createResp.status).toBe(200); + const { id } = (await createResp.json()) as { id: string }; + expect(id).toBeDefined(); + expect(typeof id).toBe("string"); + + let status: string | undefined; + let output: unknown; + for (let i = 0; i < 30; i++) { + const statusResp = await mf.dispatchFetch( + `${mfUrl}workflow/status/${id}` + ); + expect(statusResp.status).toBe(200); + const body = (await statusResp.json()) as { + status: string; + output: unknown; + error: unknown; + }; + status = body.status; + output = body.output; + if (status === "Complete" || status === "Errored") { + break; + } + await new Promise((resolve) => setTimeout(resolve, 500)); + } + + expect(status).toBe("Complete"); + expect(output).toEqual({ processed: "hello" }); + }); + + test("non-retryable error stops workflow immediately", async () => { + const createResp = await mf.dispatchFetch( + `${mfUrl}workflow/create-invalid`, + { method: "POST" } + ); + expect(createResp.status).toBe(200); + const { id } = (await createResp.json()) as { id: string }; + expect(id).toBeDefined(); + + let status: string | undefined; + for (let i = 0; i < 30; i++) { + const statusResp = await mf.dispatchFetch( + `${mfUrl}workflow/status/${id}` + ); + expect(statusResp.status).toBe(200); + const body = (await statusResp.json()) as { + status: string; + output: unknown; + }; + status = body.status; + if (status === "Complete" || status === "Errored") { + break; + } + await new Promise((resolve) => setTimeout(resolve, 500)); + } + + expect(status).toBe("Errored"); + }); + + async function lifecycleStatus( + id: string + ): Promise<{ status: string; error: unknown }> { + const resp = await mf.dispatchFetch( + `${mfUrl}workflow/lifecycle/status/${id}` + ); + return (await resp.json()) as { status: string; error: unknown }; + } + + async function pollUntil( + id: string, + predicate: (status: string) => boolean + ): Promise { + for (let i = 0; i < 10; i++) { + const { status } = await lifecycleStatus(id); + if (predicate(status)) return status; + await new Promise((resolve) => setTimeout(resolve, 100)); + } + const { status } = await lifecycleStatus(id); + return status; + } + + async function createLifecycleWorkflow(): Promise { + const resp = await mf.dispatchFetch( + `${mfUrl}workflow/lifecycle/create`, + { method: "POST" } + ); + expect(resp.status).toBe(200); + const { id } = (await resp.json()) as { id: string }; + await pollUntil(id, (s) => s !== "Queued"); + return id; + } + + async function lifecycleAction( + action: string, + id: string + ): Promise { + return mf.dispatchFetch( + `${mfUrl}workflow/lifecycle/${action}/${id}`, + { method: "POST" } + ); + } + + test("pause and resume a running workflow", async () => { + const id = await createLifecycleWorkflow(); + + expect((await lifecycleAction("pause", id)).status).toBe(200); + const paused = await pollUntil(id, (s) => s === "Paused"); + expect(paused).toBe("Paused"); + + expect((await lifecycleAction("resume", id)).status).toBe(200); + const resumed = await pollUntil(id, (s) => s !== "Paused"); + expect(resumed).not.toBe("Paused"); + + await lifecycleAction("terminate", id); + }); + + test("terminate a running workflow", async () => { + const id = await createLifecycleWorkflow(); + + expect((await lifecycleAction("terminate", id)).status).toBe(200); + const status = await pollUntil(id, (s) => s === "Terminated"); + expect(status).toBe("Terminated"); + }); + + test("restart a running workflow", async () => { + const id = await createLifecycleWorkflow(); + + expect((await lifecycleAction("restart", id)).status).toBe(200); + const status = await pollUntil( + id, + (s) => s !== "Queued" + ); + expect(["Running", "Waiting", "Queued"]).toContain(status); + + await lifecycleAction("terminate", id); + }); + + test("wait_for_event receives sent event", async () => { + const createResp = await mf.dispatchFetch( + `${mfUrl}workflow/event/create`, + { method: "POST" } + ); + expect(createResp.status).toBe(200); + const { id } = (await createResp.json()) as { id: string }; + expect(id).toBeDefined(); + + // Give the workflow time to reach wait_for_event + await new Promise((resolve) => setTimeout(resolve, 1000)); + + // Send the event + const sendResp = await mf.dispatchFetch( + `${mfUrl}workflow/event/send/${id}`, + { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ approved: true, reason: "looks good" }), + } + ); + expect(sendResp.status).toBe(200); + + // Poll until complete + let status: string | undefined; + let output: any; + for (let i = 0; i < 30; i++) { + const statusResp = await mf.dispatchFetch( + `${mfUrl}workflow/event/status/${id}` + ); + expect(statusResp.status).toBe(200); + const body = (await statusResp.json()) as { + status: string; + output: any; + }; + status = body.status; + output = body.output; + if (status === "Complete" || status === "Errored") { + break; + } + await new Promise((resolve) => setTimeout(resolve, 500)); + } + + expect(status).toBe("Complete"); + expect(output.payload).toEqual({ approved: true, reason: "looks good" }); + expect(output.type).toBe("approval"); + }); +}); diff --git a/test/wrangler.toml b/test/wrangler.toml index 16b49d87e..99c0a2602 100644 --- a/test/wrangler.toml +++ b/test/wrangler.toml @@ -86,6 +86,21 @@ class_name = "EchoContainer" image = "./container-echo/Dockerfile" max_instances = 1 +[[workflows]] +name = "test-workflow" +binding = "TEST_WORKFLOW" +class_name = "TestWorkflow" + +[[workflows]] +name = "event-workflow" +binding = "EVENT_WORKFLOW" +class_name = "EventWorkflow" + +[[workflows]] +name = "lifecycle-workflow" +binding = "LIFECYCLE_WORKFLOW" +class_name = "LifecycleWorkflow" + [[ratelimits]] name = "TEST_RATE_LIMITER" namespace_id = "1" diff --git a/worker-build/src/main.rs b/worker-build/src/main.rs index c6b152d77..4353784c5 100644 --- a/worker-build/src/main.rs +++ b/worker-build/src/main.rs @@ -123,7 +123,7 @@ pub fn main() -> Result<()> { fs::write(&shim_path, shim) .with_context(|| format!("Failed to write {}", shim_path.display()))?; - add_export_wrappers(&staging_dir)?; + let has_workflows = add_export_wrappers(&staging_dir)?; update_package_json(&staging_dir)?; @@ -134,7 +134,9 @@ pub fn main() -> Result<()> { remove_unused_files(&staging_dir)?; - create_wrapper_alias(&staging_dir, false)?; + if !has_workflows { + create_wrapper_alias(&staging_dir, false)?; + } } else { main_legacy::process(&staging_dir)?; create_wrapper_alias(&staging_dir, true)?; @@ -161,7 +163,7 @@ fn generate_handlers(out_dir: &Path) -> Result { if let Some(bracket_pos) = rest.find("(") { let func_name = rest[..bracket_pos].trim(); // strip the exported function (we re-wrap all handlers) - if !SYSTEM_FNS.contains(&func_name) { + if !SYSTEM_FNS.contains(&func_name) && !func_name.starts_with("__wf_") { func_names.push(func_name); } } @@ -170,7 +172,7 @@ fn generate_handlers(out_dir: &Path) -> Result { let rest = &rest[as_pos + 4..]; if let Some(brace_pos) = rest.find("}") { let func_name = rest[..brace_pos].trim(); - if !SYSTEM_FNS.contains(&func_name) { + if !SYSTEM_FNS.contains(&func_name) && !func_name.starts_with("__wf_") { func_names.push(func_name); } } @@ -206,17 +208,26 @@ fn generate_handlers(out_dir: &Path) -> Result { static SYSTEM_FNS: &[&str] = &["__wbg_reset_state", "setPanicHook"]; -fn add_export_wrappers(out_dir: &Path) -> Result<()> { +/// Returns true if workflow classes were detected and a wrapper was generated +fn add_export_wrappers(out_dir: &Path) -> Result { let index_path = output_path(out_dir, "index.js"); let content = fs::read_to_string(&index_path) .with_context(|| format!("Failed to read {}", index_path.display()))?; let mut class_names = Vec::new(); + let mut workflow_classes = Vec::new(); for line in content.lines() { if let Some(rest) = line.strip_prefix("export class ") { if let Some(brace_pos) = rest.find("{") { - let class_name = rest[..brace_pos].trim(); - class_names.push(class_name.to_string()); + class_names.push(rest[..brace_pos].trim().to_string()); + } + } else if let Some(rest) = line.strip_prefix("export function __wf_") { + if let Some(paren_pos) = rest.find('(') { + workflow_classes.push(rest[..paren_pos].trim().to_string()); + } + } else if let Some(rest) = line.strip_prefix("export { __wf_") { + if let Some(as_pos) = rest.find(" as ") { + workflow_classes.push(rest[..as_pos].trim().to_string()); } } } @@ -224,13 +235,63 @@ fn add_export_wrappers(out_dir: &Path) -> Result<()> { let shim_path = output_path(out_dir, "shim.js"); let mut output = fs::read_to_string(&shim_path) .with_context(|| format!("Failed to read {}", shim_path.display()))?; - for class_name in class_names { - output.push_str(&format!( - "export const {class_name} = new Proxy(exports.{class_name}, classProxyHooks);\n" - )); + + for class_name in &class_names { + if workflow_classes.contains(class_name) { + output.push_str(&format!( + "export const {class_name} = exports.{class_name};\n" + )); + } else { + output.push_str(&format!( + "export const {class_name} = new Proxy(exports.{class_name}, classProxyHooks);\n" + )); + } } fs::write(&shim_path, output) .with_context(|| format!("Failed to write {}", shim_path.display()))?; + + // Workflows need a JS wrapper that extends WorkflowEntrypoint from cloudflare:workers + let has_workflows = !workflow_classes.is_empty(); + if has_workflows { + generate_workflow_wrapper(out_dir, &workflow_classes)?; + } + + Ok(has_workflows) +} + +fn generate_workflow_wrapper(out_dir: &Path, workflow_classes: &[String]) -> Result<()> { + let mut wrapper = String::from( + r#"import { WorkflowEntrypoint } from "cloudflare:workers"; +import * as wasm from "../index.js"; +export * from "../index.js"; +export { default } from "../index.js"; + +"#, + ); + + for class_name in workflow_classes { + wrapper.push_str(&format!( + r#"export class {class_name} extends WorkflowEntrypoint {{ + constructor(ctx, env) {{ + super(ctx, env); + this.inner = new wasm.{class_name}(ctx, env); + }} + async run(event, step) {{ + return await this.inner.run(event, step); + }} +}} + +"# + )); + } + + let worker_dir = output_path(out_dir, "worker"); + fs::create_dir_all(&worker_dir) + .with_context(|| format!("Failed to create directory {}", worker_dir.display()))?; + let shim_path = output_path(out_dir, "worker/shim.mjs"); + fs::write(&shim_path, wrapper) + .with_context(|| format!("Failed to write {}", shim_path.display()))?; + Ok(()) } @@ -301,9 +362,7 @@ fn wasm_coredump(out_dir: &Path) -> Result<()> { fn create_wrapper_alias(out_dir: &Path, legacy: bool) -> Result<()> { let msg = if !legacy { - "// Use index.js directly, this file provided for backwards compat -// with former shim.mjs only. -" + "// Use index.js directly, this file provided for backwards compat\n// with former shim.mjs only.\n" } else { "" }; @@ -375,6 +434,7 @@ fn bundle(out_dir: &Path, esbuild_path: &Path) -> Result<()> { "--external:./index_bg.wasm", "--external:cloudflare:sockets", "--external:cloudflare:workers", + "--external:cloudflare:workflows", "--format=esm", "--bundle", "./shim.js", diff --git a/worker-build/src/main_legacy.rs b/worker-build/src/main_legacy.rs index 845473421..2992e0989 100644 --- a/worker-build/src/main_legacy.rs +++ b/worker-build/src/main_legacy.rs @@ -169,6 +169,7 @@ fn bundle(out_dir: &Path, esbuild_path: &Path) -> Result<()> { "--external:./index.wasm", "--external:cloudflare:sockets", "--external:cloudflare:workers", + "--external:cloudflare:workflows", "--format=esm", "--bundle", "./shim.js", diff --git a/worker-macros/Cargo.toml b/worker-macros/Cargo.toml index 94fe30b20..e79a3f75c 100644 --- a/worker-macros/Cargo.toml +++ b/worker-macros/Cargo.toml @@ -28,3 +28,4 @@ trybuild.workspace = true [features] queue = [] http = [] +workflow = [] diff --git a/worker-macros/src/lib.rs b/worker-macros/src/lib.rs index c6f2cd376..ed3bb9ff9 100644 --- a/worker-macros/src/lib.rs +++ b/worker-macros/src/lib.rs @@ -1,6 +1,8 @@ mod durable_object; mod event; mod send; +#[cfg(feature = "workflow")] +mod workflow; use proc_macro::TokenStream; @@ -141,3 +143,36 @@ pub fn send(attr: TokenStream, stream: TokenStream) -> TokenStream { pub fn consume(_: TokenStream, _: TokenStream) -> TokenStream { TokenStream::new() } + +/// Integrate the struct with the Workers Runtime as a Workflow Entrypoint. +/// Requires the `WorkflowEntrypoint` trait with the workflow attribute macro on the struct. +/// +/// ## Example +/// +/// ```rust,ignore +/// #[workflow] +/// pub struct MyWorkflow { +/// env: Env, +/// } +/// +/// impl WorkflowEntrypoint for MyWorkflow { +/// fn new(ctx: Context, env: Env) -> Self { +/// Self { env } +/// } +/// +/// async fn run(&self, event: WorkflowEvent, step: WorkflowStep) -> Result { +/// let result = step.do_("my step", || async { +/// Ok(serde_json::json!({"data": "value"})) +/// }).await?; +/// +/// Ok(result) +/// } +/// } +/// ``` +#[cfg(feature = "workflow")] +#[proc_macro_attribute] +pub fn workflow(_attr: TokenStream, item: TokenStream) -> TokenStream { + workflow::expand_macro(item.into()) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} diff --git a/worker-macros/src/workflow.rs b/worker-macros/src/workflow.rs new file mode 100644 index 000000000..657acef09 --- /dev/null +++ b/worker-macros/src/workflow.rs @@ -0,0 +1,81 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{Error, ItemImpl, ItemStruct}; + +pub fn expand_macro(tokens: TokenStream) -> syn::Result { + let target = match syn::parse2::(tokens.clone()) { + Ok(s) => s, + Err(e) => { + if syn::parse2::(tokens).is_ok() { + return Err(Error::new( + proc_macro2::Span::call_site(), + "#[workflow] should only be applied to struct definitions, not impl blocks", + )); + } + return Err(e); + } + }; + let target_name = &target.ident; + let marker_fn_name = format_ident!("__wf_{}", target_name); + let marker_js_name = format!("__wf_{target_name}"); + let target_name_str = target_name.to_string(); + + Ok(quote! { + #target + + impl ::worker::HasWorkflowAttribute for #target_name {} + + const _: () = { + use ::worker::wasm_bindgen::prelude::*; + #[allow(unused_imports)] + use ::worker::WorkflowEntrypoint; + + #[allow(non_snake_case)] + #[wasm_bindgen(js_name = #marker_js_name, wasm_bindgen=::worker::wasm_bindgen)] + pub fn #marker_fn_name() -> ::worker::js_sys::JsString { + ::worker::js_sys::JsString::from(#target_name_str) + } + + #[wasm_bindgen(wasm_bindgen=::worker::wasm_bindgen)] + #[::worker::consume] + #target + + #[wasm_bindgen(wasm_bindgen=::worker::wasm_bindgen)] + impl #target_name { + #[wasm_bindgen(constructor, wasm_bindgen=::worker::wasm_bindgen)] + pub fn new( + ctx: ::worker::worker_sys::Context, + env: ::worker::Env + ) -> Self { + ::new( + ::worker::Context::new(ctx), + env + ) + } + + #[wasm_bindgen(js_name = run, wasm_bindgen=::worker::wasm_bindgen)] + pub fn run( + &self, + event: ::worker::wasm_bindgen::JsValue, + step: ::worker::worker_sys::WorkflowStep + ) -> ::worker::js_sys::Promise { + // SAFETY: The Cloudflare Workers runtime manages the Workflow instance + // lifecycle. The runtime guarantees that: + // 1. The instance is created before run() is called + // 2. The instance is not destroyed while any Promise returned by run() is pending + // 3. WASM execution is single-threaded, so no concurrent access is possible + // This is the same lifecycle model used by Durable Objects and WorkerEntrypoint. + let static_self: &'static Self = unsafe { &*(self as *const _) }; + + ::worker::wasm_bindgen_futures::future_to_promise(::std::panic::AssertUnwindSafe(async move { + let event = ::worker::WorkflowEvent::from_js(event) + .map_err(|e| ::worker::wasm_bindgen::JsValue::from_str(&e.to_string()))?; + let step = ::worker::WorkflowStep::from(step); + ::run(static_self, event, step).await + .map_err(::worker::wasm_bindgen::JsValue::from) + })) + } + } + }; + }) +} diff --git a/worker-sys/Cargo.toml b/worker-sys/Cargo.toml index c8d9baad8..84d9dd160 100644 --- a/worker-sys/Cargo.toml +++ b/worker-sys/Cargo.toml @@ -11,8 +11,10 @@ description = "Low-level extern definitions / FFI bindings to the Cloudflare Wor cfg-if.workspace = true js-sys.workspace = true wasm-bindgen.workspace = true +wasm-bindgen-futures.workspace = true web-sys.workspace = true [features] d1 = [] queue = [] +workflow = [] diff --git a/worker-sys/src/types.rs b/worker-sys/src/types.rs index afef5bc11..95158307e 100644 --- a/worker-sys/src/types.rs +++ b/worker-sys/src/types.rs @@ -22,6 +22,8 @@ mod tls_client_auth; mod version; mod websocket_pair; mod websocket_request_response_pair; +#[cfg(feature = "workflow")] +mod workflow; pub use ai::*; pub use analytics_engine::*; @@ -47,3 +49,5 @@ pub use tls_client_auth::*; pub use version::*; pub use websocket_pair::*; pub use websocket_request_response_pair::*; +#[cfg(feature = "workflow")] +pub use workflow::*; diff --git a/worker-sys/src/types/workflow.rs b/worker-sys/src/types/workflow.rs new file mode 100644 index 000000000..ee3ce2d3e --- /dev/null +++ b/worker-sys/src/types/workflow.rs @@ -0,0 +1,95 @@ +use wasm_bindgen::prelude::*; + +#[wasm_bindgen] +extern "C" { + #[wasm_bindgen(extends=js_sys::Object)] + #[derive(Debug, Clone, PartialEq, Eq)] + pub type WorkflowStep; + + #[wasm_bindgen(method, catch, js_name = "do")] + pub async fn do_( + this: &WorkflowStep, + name: &str, + callback: &js_sys::Function, + ) -> Result; + + #[wasm_bindgen(method, catch, js_name = "do")] + pub async fn do_with_config( + this: &WorkflowStep, + name: &str, + config: JsValue, + callback: &js_sys::Function, + ) -> Result; + + #[wasm_bindgen(method, catch)] + pub async fn sleep( + this: &WorkflowStep, + name: &str, + duration: JsValue, + ) -> Result; + + #[wasm_bindgen(method, catch, js_name = sleepUntil)] + pub async fn sleep_until( + this: &WorkflowStep, + name: &str, + timestamp: JsValue, + ) -> Result; + + #[wasm_bindgen(method, catch, js_name = waitForEvent)] + pub async fn wait_for_event( + this: &WorkflowStep, + name: &str, + options: JsValue, + ) -> Result; + + /// Workflow binding type - may be a Workflow object, WorkflowImpl, or Fetcher (RPC stub). + #[wasm_bindgen(extends=js_sys::Object)] + #[derive(Debug, Clone, PartialEq, Eq)] + pub type WorkflowBinding; + + #[wasm_bindgen(method, catch)] + pub async fn get(this: &WorkflowBinding, id: &str) -> Result; + + #[wasm_bindgen(method, catch)] + pub async fn create(this: &WorkflowBinding, options: JsValue) -> Result; + + #[wasm_bindgen(method, catch, js_name = createBatch)] + pub async fn create_batch( + this: &WorkflowBinding, + batch: &js_sys::Array, + ) -> Result; + + /// Workflow instance handle - may be an RPC stub in Miniflare. + #[wasm_bindgen(extends=js_sys::Object)] + #[derive(Debug, Clone, PartialEq, Eq)] + pub type WorkflowInstanceSys; + + #[wasm_bindgen(method, catch)] + pub async fn pause(this: &WorkflowInstanceSys) -> Result; + + #[wasm_bindgen(method, catch)] + pub async fn resume(this: &WorkflowInstanceSys) -> Result; + + #[wasm_bindgen(method, catch)] + pub async fn terminate(this: &WorkflowInstanceSys) -> Result; + + #[wasm_bindgen(method, catch)] + pub async fn restart(this: &WorkflowInstanceSys) -> Result; + + #[wasm_bindgen(method, catch)] + pub async fn status(this: &WorkflowInstanceSys) -> Result; + + #[wasm_bindgen(method, catch, js_name = sendEvent)] + pub async fn send_event(this: &WorkflowInstanceSys, event: JsValue) + -> Result; +} + +#[wasm_bindgen(module = "cloudflare:workflows")] +extern "C" { + #[wasm_bindgen(extends = js_sys::Error)] + #[derive(Debug, Clone)] + pub type NonRetryableErrorSys; + + #[wasm_bindgen(constructor, js_class = "NonRetryableError")] + pub fn new(message: &str) -> NonRetryableErrorSys; +} diff --git a/worker/Cargo.toml b/worker/Cargo.toml index bc9aed46d..75f8f0196 100644 --- a/worker/Cargo.toml +++ b/worker/Cargo.toml @@ -47,6 +47,7 @@ axum = { version = "0.8", optional = true, default-features = false } [features] queue = ["worker-macros/queue", "worker-sys/queue"] d1 = ["worker-sys/d1"] +workflow = ["worker-macros/workflow", "worker-sys/workflow"] http = ["worker-macros/http"] axum = ["dep:axum"] timezone = ["dep:chrono-tz"] diff --git a/worker/src/env.rs b/worker/src/env.rs index 780c2585b..396fbb9b7 100644 --- a/worker/src/env.rs +++ b/worker/src/env.rs @@ -8,6 +8,8 @@ use crate::rate_limit::RateLimiter; use crate::Ai; #[cfg(feature = "queue")] use crate::Queue; +#[cfg(feature = "workflow")] +use crate::Workflow; use crate::{durable::ObjectNamespace, Bucket, DynamicDispatcher, Fetcher, Result, SecretStore}; use crate::{error::Error, hyperdrive::Hyperdrive}; @@ -99,6 +101,12 @@ impl Env { self.get_binding(binding) } + #[cfg(feature = "workflow")] + /// Access a Workflow by the binding name configured in your wrangler.toml file. + pub fn workflow(&self, binding: &str) -> Result { + self.get_binding(binding) + } + /// Access an R2 Bucket by the binding name configured in your wrangler.toml file. pub fn bucket(&self, binding: &str) -> Result { self.get_binding(binding) diff --git a/worker/src/error.rs b/worker/src/error.rs index d57cb72ae..eec2ae6e6 100644 --- a/worker/src/error.rs +++ b/worker/src/error.rs @@ -124,7 +124,19 @@ impl std::fmt::Display for Error { #[cfg(feature = "http")] Error::Http(e) => write!(f, "http::Error: {e}"), Error::Infallible => write!(f, "infallible"), - Error::Internal(_) => write!(f, "unrecognized JavaScript object"), + Error::Internal(v) => { + if let Some(e) = v.dyn_ref::() { + let name = String::from(e.name()); + let msg = String::from(e.message()); + if name.is_empty() { + write!(f, "{msg}") + } else { + write!(f, "{name}: {msg}") + } + } else { + write!(f, "unrecognized JavaScript object") + } + } Error::Io(e) => write!(f, "IO Error: {e}"), Error::BindingError(name) => write!(f, "no binding found for `{name}`"), Error::RouteInsertError(e) => write!(f, "failed to insert route: {e}"), @@ -149,24 +161,16 @@ impl std::fmt::Display for Error { impl std::error::Error for Error {} -// Not sure if the changes I've made here are good or bad... impl From for Error { fn from(v: JsValue) -> Self { - match v.as_string().or_else(|| { - v.dyn_ref::().map(|e| { - format!( - "Error: {} - Cause: {}", - e.to_string(), - e.cause() - .as_string() - .or_else(|| { Some(e.to_string().into()) }) - .unwrap_or(String::from("N/A")) - ) - }) - }) { - Some(s) => Self::JsError(s), - None => Self::Internal(v), + if let Some(s) = v.as_string() { + return Self::JsError(s); } + // Preserve JS Error objects (and other non-string JsValues) as + // Internal so they survive roundtrips back to JS unchanged. + // This is important for workflow abort errors whose identity the + // engine checks with `instanceof Error` / `.message`. + Self::Internal(v) } } @@ -178,7 +182,10 @@ impl From for Error { impl From for JsValue { fn from(e: Error) -> Self { - JsValue::from_str(&e.to_string()) + match e { + Error::Internal(v) => v, + _ => JsValue::from_str(&e.to_string()), + } } } diff --git a/worker/src/lib.rs b/worker/src/lib.rs index 174427495..fa3f047cb 100644 --- a/worker/src/lib.rs +++ b/worker/src/lib.rs @@ -149,12 +149,16 @@ use std::result::Result as StdResult; #[doc(hidden)] pub use async_trait; pub use js_sys; +pub use serde_json; +pub use serde_wasm_bindgen; pub use url::Url; pub use wasm_bindgen; pub use wasm_bindgen_futures; pub use web_sys; pub use cf::{Cf, CfResponseProperties, TlsClientAuth}; +#[cfg(feature = "workflow")] +pub use worker_macros::workflow; pub use worker_macros::{consume, durable_object, event, send}; #[doc(hidden)] pub use worker_sys; @@ -197,6 +201,8 @@ pub use crate::socket::*; pub use crate::streams::*; pub use crate::version::*; pub use crate::websocket::*; +#[cfg(feature = "workflow")] +pub use crate::workflow::*; mod abort; mod ai; @@ -241,6 +247,8 @@ mod sql; mod streams; mod version; mod websocket; +#[cfg(feature = "workflow")] +mod workflow; /// A `Result` alias defaulting to [`Error`]. pub type Result = StdResult; diff --git a/worker/src/workflow.rs b/worker/src/workflow.rs new file mode 100644 index 000000000..ce698f51c --- /dev/null +++ b/worker/src/workflow.rs @@ -0,0 +1,590 @@ +//! Cloudflare Workflows support for Rust Workers. + +use std::future::Future; +use std::panic::AssertUnwindSafe; +use std::rc::Rc; + +use js_sys::{Object, Reflect}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use wasm_bindgen::prelude::*; +use wasm_bindgen::JsCast; +use wasm_bindgen_futures::future_to_promise; +use worker_sys::types::{ + NonRetryableErrorSys, WorkflowBinding as WorkflowBindingSys, WorkflowInstanceSys, + WorkflowStep as WorkflowStepSys, +}; + +use crate::env::EnvBinding; +use crate::send::SendFuture; +use crate::Result; + +/// Serialize a value to a JS object, ensuring maps are serialized as plain objects. +/// +/// This is useful when returning values from [`WorkflowEntrypoint::run`] that +/// need to be plain JS objects rather than `Map` instances. +pub fn serialize_as_object( + value: &T, +) -> std::result::Result { + value.serialize(&serde_wasm_bindgen::Serializer::new().serialize_maps_as_objects(true)) +} + +fn get_property(target: &JsValue, name: &str) -> Result { + Reflect::get(target, &JsValue::from_str(name)) + .map_err(|e| crate::Error::JsError(format!("failed to get property '{name}': {e:?}"))) +} + +fn get_string_property(target: &JsValue, name: &str) -> Result { + get_property(target, name)? + .as_string() + .ok_or_else(|| crate::Error::JsError(format!("{name} is not a string"))) +} + +fn get_timestamp_property(target: &JsValue, name: &str) -> Result { + let val = get_property(target, name)?; + Ok(crate::Date::from(js_sys::Date::from(val))) +} + +/// A Workflow binding for creating and managing workflow instances. +#[derive(Debug, Clone)] +pub struct Workflow { + inner: WorkflowBindingSys, +} + +// SAFETY: WASM is single-threaded. These types wrap JS objects that are only +// accessed from the main thread. Send/Sync are implemented to satisfy Rust's +// async machinery (e.g., holding references across await points), but actual +// cross-thread access is impossible in the Workers runtime. +unsafe impl Send for Workflow {} +unsafe impl Sync for Workflow {} + +impl Workflow { + /// Get a handle to an existing workflow instance by ID. + pub async fn get(&self, id: &str) -> Result { + let result = SendFuture::new(self.inner.get(id)).await?; + Ok(WorkflowInstance::from_js(result)) + } + + /// Create a new workflow instance. + pub async fn create( + &self, + options: Option>, + ) -> Result { + let js_options = match options { + Some(opts) => serde_wasm_bindgen::to_value(&opts)?, + None => JsValue::UNDEFINED, + }; + let result = SendFuture::new(self.inner.create(js_options)).await?; + Ok(WorkflowInstance::from_js(result)) + } + + /// Create a batch of workflow instances (limited to 100 at a time). + pub async fn create_batch( + &self, + batch: Vec>, + ) -> Result> { + let js_array = js_sys::Array::new(); + for opts in batch { + js_array.push(&serde_wasm_bindgen::to_value(&opts)?); + } + let result = SendFuture::new(self.inner.create_batch(&js_array)).await?; + let result_array: js_sys::Array = result.unchecked_into(); + + let len = result_array.length(); + let mut instances = Vec::with_capacity(len as usize); + for i in 0..len { + instances.push(WorkflowInstance::from_js(result_array.get(i))); + } + Ok(instances) + } +} + +impl EnvBinding for Workflow { + const TYPE_NAME: &'static str = "Workflow"; + + fn get(val: JsValue) -> Result { + let obj = Object::from(val); + let constructor_name = obj.constructor().name(); + if constructor_name == Self::TYPE_NAME || constructor_name == "WorkflowImpl" { + Ok(Self { + inner: obj.unchecked_into(), + }) + } else { + Err(format!( + "Binding cannot be cast to the type {} from {}", + Self::TYPE_NAME, + constructor_name + ) + .into()) + } + } +} + +impl JsCast for Workflow { + fn instanceof(_val: &JsValue) -> bool { + true + } + + fn unchecked_from_js(val: JsValue) -> Self { + Self { + inner: val.unchecked_into(), + } + } + + fn unchecked_from_js_ref(val: &JsValue) -> &Self { + unsafe { &*(val as *const JsValue as *const Self) } + } +} + +impl From for JsValue { + fn from(workflow: Workflow) -> Self { + workflow.inner.into() + } +} + +impl AsRef for Workflow { + fn as_ref(&self) -> &JsValue { + self.inner.as_ref() + } +} + +/// Options for creating a new workflow instance. +#[derive(Debug, Clone, Serialize)] +pub struct CreateOptions { + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub params: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub retention: Option, +} + +impl Default for CreateOptions { + fn default() -> Self { + Self { + id: None, + params: None, + retention: None, + } + } +} + +/// Retention policy for workflow instances. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct RetentionOptions { + #[serde(skip_serializing_if = "Option::is_none")] + pub success_retention: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error_retention: Option, +} + +/// A handle to a workflow instance. +#[derive(Debug, Clone)] +pub struct WorkflowInstance { + inner: WorkflowInstanceSys, +} + +// SAFETY: See Workflow for rationale - WASM is single-threaded. +unsafe impl Send for WorkflowInstance {} +unsafe impl Sync for WorkflowInstance {} + +impl WorkflowInstance { + fn from_js(val: JsValue) -> Self { + Self { + inner: val.unchecked_into(), + } + } + + /// The unique ID of this workflow instance. + pub fn id(&self) -> String { + get_string_property(self.inner.as_ref(), "id") + .expect("WorkflowInstance always has an id property") + } + + /// Pause the workflow instance. + pub async fn pause(&self) -> Result<()> { + SendFuture::new(self.inner.pause()).await?; + Ok(()) + } + + /// Resume a paused workflow instance. + pub async fn resume(&self) -> Result<()> { + SendFuture::new(self.inner.resume()).await?; + Ok(()) + } + + /// Terminate the workflow instance. + pub async fn terminate(&self) -> Result<()> { + SendFuture::new(self.inner.terminate()).await?; + Ok(()) + } + + /// Restart the workflow instance. + pub async fn restart(&self) -> Result<()> { + SendFuture::new(self.inner.restart()).await?; + Ok(()) + } + + /// Get the current status of the workflow instance. + pub async fn status(&self) -> Result { + let result = SendFuture::new(self.inner.status()).await?; + Ok(serde_wasm_bindgen::from_value(result)?) + } + + /// Send an event to the workflow instance to trigger `step.wait_for_event()` calls. + pub async fn send_event(&self, type_: &str, payload: T) -> Result<()> { + #[derive(Serialize)] + struct SendEventPayload<'a, P: Serialize> { + #[serde(rename = "type")] + type_: &'a str, + payload: P, + } + let event = serde_wasm_bindgen::to_value(&SendEventPayload { type_, payload })?; + SendFuture::new(self.inner.send_event(event)).await?; + Ok(()) + } +} + +/// The status of a workflow instance. +#[derive(Debug, Clone, Deserialize)] +pub struct InstanceStatus { + pub status: InstanceStatusKind, + #[serde(default)] + pub error: Option, + #[serde(default)] + pub output: Option, +} + +/// The possible status values for a workflow instance. +#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] +#[serde(rename_all = "camelCase")] +pub enum InstanceStatusKind { + Queued, + Running, + Paused, + Errored, + Terminated, + Complete, + Waiting, + WaitingForPause, + Unknown, +} + +/// Error information for a failed workflow. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InstanceError { + pub name: String, + pub message: String, +} + +/// Context passed to step callbacks, providing information about the current execution attempt. +#[derive(Debug, Clone, Copy)] +pub struct WorkflowStepContext { + /// The current retry attempt number (starts at 1). + pub attempt: u32, +} + +/// Provides methods for executing durable workflow steps. +#[derive(Debug)] +pub struct WorkflowStep(WorkflowStepSys); + +// SAFETY: See Workflow for rationale - WASM is single-threaded. +unsafe impl Send for WorkflowStep {} +unsafe impl Sync for WorkflowStep {} + +impl WorkflowStep { + fn wrap_callback( + callback: F, + ) -> wasm_bindgen::closure::Closure js_sys::Promise> + where + T: Serialize + 'static, + F: Fn(WorkflowStepContext) -> Fut + 'static, + Fut: Future> + 'static, + { + let callback = Rc::new(AssertUnwindSafe(callback)); + wasm_bindgen::closure::Closure::new(move |ctx: JsValue| -> js_sys::Promise { + let callback = callback.clone(); + let attempt = Reflect::get(&ctx, &JsValue::from_str("attempt")) + .ok() + .and_then(|v| v.as_f64()) + .unwrap_or(1.0) as u32; + future_to_promise(AssertUnwindSafe(async move { + let result = (callback.0)(WorkflowStepContext { attempt }) + .await + .map_err(JsValue::from)?; + serialize_as_object(&result).map_err(|e| JsValue::from_str(&e.to_string())) + })) + }) + } + + /// Execute a named step. The callback's return value is persisted and + /// returned without re-executing on replay. + pub async fn do_(&self, name: &str, callback: F) -> Result + where + T: Serialize + DeserializeOwned + 'static, + F: Fn(WorkflowStepContext) -> Fut + 'static, + Fut: Future> + 'static, + { + let closure = Self::wrap_callback(callback); + let js_fn = closure.as_ref().unchecked_ref::(); + let result = SendFuture::new(self.0.do_(name, js_fn)).await?; + Ok(serde_wasm_bindgen::from_value(result)?) + } + + /// Execute a named step with retry and timeout configuration. + pub async fn do_with_config( + &self, + name: &str, + config: StepConfig, + callback: F, + ) -> Result + where + T: Serialize + DeserializeOwned + 'static, + F: Fn(WorkflowStepContext) -> Fut + 'static, + Fut: Future> + 'static, + { + let config_js = serde_wasm_bindgen::to_value(&config)?; + let closure = Self::wrap_callback(callback); + let js_fn = closure.as_ref().unchecked_ref::(); + let result = SendFuture::new(self.0.do_with_config(name, config_js, js_fn)).await?; + Ok(serde_wasm_bindgen::from_value(result)?) + } + + /// Sleep for a specified duration (e.g., "1 minute", "5 seconds"). + pub async fn sleep( + &self, + name: &str, + duration: impl Into, + ) -> Result<()> { + let duration_js = duration.into().to_js_value(); + SendFuture::new(self.0.sleep(name, duration_js)).await?; + Ok(()) + } + + /// Sleep until a specific timestamp. + pub async fn sleep_until(&self, name: &str, timestamp: impl Into) -> Result<()> { + let date: crate::Date = timestamp.into(); + let ts_ms = date.as_millis() as f64; + SendFuture::new(self.0.sleep_until(name, ts_ms.into())).await?; + Ok(()) + } + + /// Wait for an external event sent via `WorkflowInstance::send_event()`. + pub async fn wait_for_event( + &self, + name: &str, + options: WaitForEventOptions, + ) -> Result> { + let options_js = serde_wasm_bindgen::to_value(&options)?; + let result = SendFuture::new(self.0.wait_for_event(name, options_js)).await?; + WorkflowStepEvent::from_js(result) + } +} + +impl From for WorkflowStep { + fn from(inner: WorkflowStepSys) -> Self { + Self(inner) + } +} + +/// Configuration for a workflow step. +#[derive(Debug, Clone, Default, Serialize)] +pub struct StepConfig { + #[serde(skip_serializing_if = "Option::is_none")] + pub retries: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub timeout: Option, +} + +/// Retry configuration for a workflow step. +#[derive(Debug, Clone, Serialize)] +pub struct RetryConfig { + pub limit: u32, + pub delay: WorkflowSleepDuration, + #[serde(skip_serializing_if = "Option::is_none")] + pub backoff: Option, +} + +/// Backoff strategy for retries. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum Backoff { + Constant, + Linear, + Exponential, +} + +/// Options for waiting for an external event. +#[derive(Debug, Clone, Serialize)] +pub struct WaitForEventOptions { + #[serde(rename = "type")] + pub type_: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub timeout: Option, +} + +/// An event received from `wait_for_event`. +#[derive(Debug, Clone)] +pub struct WorkflowStepEvent { + pub payload: T, + pub timestamp: crate::Date, + pub type_: String, +} + +impl WorkflowStepEvent { + fn from_js(value: JsValue) -> Result { + Ok(Self { + payload: serde_wasm_bindgen::from_value(get_property(&value, "payload")?)?, + timestamp: get_timestamp_property(&value, "timestamp")?, + type_: get_string_property(&value, "type")?, + }) + } +} + +/// The event passed to a workflow's run method. +#[derive(Debug, Clone)] +pub struct WorkflowEvent { + pub payload: JsValue, + pub timestamp: crate::Date, + pub instance_id: String, +} + +impl WorkflowEvent { + pub fn from_js(value: JsValue) -> Result { + Ok(Self { + payload: get_property(&value, "payload")?, + timestamp: get_timestamp_property(&value, "timestamp")?, + instance_id: get_string_property(&value, "instanceId")?, + }) + } +} + +/// Unit of time for workflow durations. +#[derive(Debug, Clone, Copy)] +pub enum WorkflowDuration { + Seconds, + Minutes, + Hours, + Days, + Weeks, + Months, + Years, +} + +/// A typed duration used throughout the Workflows API for sleep, timeout, +/// retry delay, and retention fields. +/// +/// Corresponds to the `WorkflowSleepDuration` type in the Workers runtime, +/// which accepts either a string like `"5 seconds"` or a number of milliseconds. +#[derive(Debug, Clone)] +enum WorkflowSleepDurationInner { + Text(String), + Millis(f64), +} + +#[derive(Debug, Clone)] +pub struct WorkflowSleepDuration(WorkflowSleepDurationInner); + +impl WorkflowSleepDuration { + /// Create a new duration with the given amount and unit. + pub fn new(amount: u32, unit: WorkflowDuration) -> Self { + let unit_str = match unit { + WorkflowDuration::Seconds => "seconds", + WorkflowDuration::Minutes => "minutes", + WorkflowDuration::Hours => "hours", + WorkflowDuration::Days => "days", + WorkflowDuration::Weeks => "weeks", + WorkflowDuration::Months => "months", + WorkflowDuration::Years => "years", + }; + Self(WorkflowSleepDurationInner::Text(format!( + "{amount} {unit_str}" + ))) + } + + fn to_js_value(&self) -> JsValue { + match &self.0 { + WorkflowSleepDurationInner::Text(s) => JsValue::from_str(s), + WorkflowSleepDurationInner::Millis(ms) => JsValue::from_f64(*ms), + } + } +} + +impl Serialize for WorkflowSleepDuration { + fn serialize( + &self, + serializer: S, + ) -> std::result::Result { + match &self.0 { + WorkflowSleepDurationInner::Text(s) => serializer.serialize_str(s), + WorkflowSleepDurationInner::Millis(ms) => serializer.serialize_f64(*ms), + } + } +} + +impl From<&str> for WorkflowSleepDuration { + fn from(s: &str) -> Self { + Self(WorkflowSleepDurationInner::Text(s.to_string())) + } +} + +impl From for WorkflowSleepDuration { + fn from(s: String) -> Self { + Self(WorkflowSleepDurationInner::Text(s)) + } +} + +impl From for WorkflowSleepDuration { + fn from(d: std::time::Duration) -> Self { + Self(WorkflowSleepDurationInner::Millis(d.as_millis() as f64)) + } +} + +/// Error type for non-retryable workflow errors. +/// +/// This wraps the JavaScript `NonRetryableError` from `cloudflare:workflows`, +/// which the Workflows runtime uses to identify errors that should not be retried. +#[derive(Debug)] +pub struct NonRetryableError { + inner: NonRetryableErrorSys, +} + +impl NonRetryableError { + pub fn new(message: impl Into) -> Self { + Self { + inner: NonRetryableErrorSys::new(&message.into()), + } + } +} + +impl std::fmt::Display for NonRetryableError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.inner.message()) + } +} + +impl std::error::Error for NonRetryableError {} + +impl From for JsValue { + fn from(e: NonRetryableError) -> Self { + e.inner.into() + } +} + +impl From for crate::Error { + fn from(e: NonRetryableError) -> Self { + crate::Error::Internal(e.inner.into()) + } +} + +/// Marker trait implemented by the `#[workflow]` macro. +#[doc(hidden)] +pub trait HasWorkflowAttribute {} + +/// Trait for implementing a Workflow entrypoint. +#[allow(async_fn_in_trait)] +pub trait WorkflowEntrypoint: HasWorkflowAttribute { + fn new(ctx: crate::Context, env: crate::Env) -> Self; + + async fn run(&self, event: WorkflowEvent, step: WorkflowStep) -> Result; +}