From 2c007b0be2f80fe09d1429417e6f73d3a38b1915 Mon Sep 17 00:00:00 2001 From: David Vargas Date: Mon, 17 Nov 2025 11:41:39 -0500 Subject: [PATCH] WIP - Investigating all changes needed for reasonable cancellation --- src/vellum/workflows/events/node.py | 4 ++ src/vellum/workflows/events/workflow.py | 4 ++ .../core/inline_subworkflow_node/node.py | 7 ++- src/vellum/workflows/runner/runner.py | 57 ++++++------------- .../tests/test_workflow.py | 6 +- .../workflow.py | 11 +++- 6 files changed, 44 insertions(+), 45 deletions(-) diff --git a/src/vellum/workflows/events/node.py b/src/vellum/workflows/events/node.py index 0a3084d021..b55da3bc3c 100644 --- a/src/vellum/workflows/events/node.py +++ b/src/vellum/workflows/events/node.py @@ -24,6 +24,10 @@ class _BaseNodeExecutionBody(UniversalBaseModel): def serialize_node_definition(self, node_definition: Type, _info: Any) -> Dict[str, Any]: return serialize_type_encoder_with_id(node_definition) + @property + def definition_name(self) -> str: + return self.node_definition.__name__ + class _BaseNodeEvent(BaseEvent): body: _BaseNodeExecutionBody diff --git a/src/vellum/workflows/events/workflow.py b/src/vellum/workflows/events/workflow.py index d430b0f092..5c8d9b7e81 100644 --- a/src/vellum/workflows/events/workflow.py +++ b/src/vellum/workflows/events/workflow.py @@ -59,6 +59,10 @@ class _BaseWorkflowExecutionBody(UniversalBaseModel): def serialize_workflow_definition(self, workflow_definition: Type, _info: Any) -> Dict[str, Any]: return serialize_type_encoder_with_id(workflow_definition) + @property + def definition_name(self) -> str: + return self.workflow_definition.__name__ + class _BaseWorkflowEvent(BaseEvent): body: _BaseWorkflowExecutionBody diff --git a/src/vellum/workflows/nodes/core/inline_subworkflow_node/node.py b/src/vellum/workflows/nodes/core/inline_subworkflow_node/node.py index 9d7c6038c2..a5e162ef06 100644 --- a/src/vellum/workflows/nodes/core/inline_subworkflow_node/node.py +++ b/src/vellum/workflows/nodes/core/inline_subworkflow_node/node.py @@ -75,11 +75,11 @@ def run(self) -> Iterator[BaseOutput]: self._child_cancel_signal = ThreadingEvent() with execution_context(parent_context=get_parent_context()): - subworkflow = self.subworkflow( + self._subworkflow_instance = self.subworkflow( parent_state=self.state, context=WorkflowContext.create_from(self._context), ) - subworkflow_stream = subworkflow.stream( + subworkflow_stream = self._subworkflow_instance.stream( inputs=self._compile_subworkflow_inputs(), event_filter=all_workflow_event_filter, node_output_mocks=self._context._get_all_node_output_mocks(), @@ -91,6 +91,7 @@ def run(self) -> Iterator[BaseOutput]: fulfilled_output_names: Set[str] = set() for event in subworkflow_stream: + print("subevent", event.name, event.body.definition_name) self._context._emit_subworkflow_event(event) if exception: continue @@ -136,8 +137,10 @@ def __cancel__(self, message: str) -> None: """ Propagate cancellation to the nested workflow by setting its cancel signal. """ + print("cancelling subworkflow node") if hasattr(self, "_child_cancel_signal"): self._child_cancel_signal.set() + self._subworkflow_instance.join() def _compile_subworkflow_inputs(self) -> InputsType: if self.subworkflow is None: diff --git a/src/vellum/workflows/runner/runner.py b/src/vellum/workflows/runner/runner.py index ed3a7dae0f..39d70afd32 100644 --- a/src/vellum/workflows/runner/runner.py +++ b/src/vellum/workflows/runner/runner.py @@ -263,6 +263,7 @@ def __init__( self._timeout = timeout self._execution_context = init_execution_context or get_execution_context() self._trigger = trigger + self._is_cancelling = False setattr( self._initial_state, @@ -664,6 +665,7 @@ def _handle_run_node_exception( ) -> NodeExecutionRejectedEvent: logger.info(f"{prefix}: {exception}") captured_stacktrace = traceback.format_exc() + print("handle run node exception", self.workflow.__class__.__name__, node.__class__.__name__) return NodeExecutionRejectedEvent( trace_id=execution.trace_id, @@ -870,7 +872,7 @@ def _handle_work_item_event(self, event: WorkflowEvent) -> Optional[NodeExecutio return None - def _emit_node_cancellation_events( + def _cancel_active_nodes( self, error_message: str, ) -> None: @@ -880,32 +882,9 @@ def _emit_node_cancellation_events( Args: error_message: The error message to include in the cancellation events """ - parent_context = WorkflowParentContext( - span_id=self._initial_state.meta.span_id, - workflow_definition=self.workflow.__class__, - parent=self._execution_context.parent_context, - ) - captured_stacktrace = "".join(traceback.format_stack()) - active_span_ids = list(self._active_nodes_by_execution_id.keys()) - for span_id in active_span_ids: - active_node = self._active_nodes_by_execution_id.pop(span_id, None) - if active_node is not None: - active_node.node.__cancel__(error_message) - - rejection_event = NodeExecutionRejectedEvent( - trace_id=self._execution_context.trace_id, - span_id=span_id, - body=NodeExecutionRejectedBody( - node_definition=active_node.node.__class__, - error=WorkflowError( - code=WorkflowErrorCode.NODE_CANCELLED, - message=error_message, - ), - stacktrace=captured_stacktrace, - ), - parent=parent_context, - ) - self._workflow_event_outer_queue.put(rejection_event) + self._is_cancelling = True + for active_node in self._active_nodes_by_execution_id.values(): + active_node.node.__cancel__(error_message) def _initiate_workflow_event(self) -> WorkflowExecutionInitiatedEvent: links: Optional[List[SpanLink]] = None @@ -1059,18 +1038,21 @@ def _stream(self) -> None: break event = self._workflow_event_inner_queue.get() + print("outer flow event", self.workflow.__class__.__name__, event.name, event.body.definition_name) self._workflow_event_outer_queue.put(event) with execution_context(parent_context=current_parent, trace_id=self._execution_context.trace_id): rejection_event = self._handle_work_item_event(event) - if rejection_event: + if rejection_event and not self._is_cancelling: failed_node_name = rejection_event.body.node_definition.__name__ - self._emit_node_cancellation_events( + self._cancel_active_nodes( error_message=f"Node execution cancelled due to {failed_node_name} failure", ) - break + self._workflow_event_inner_queue.put( + self._reject_workflow_event(rejection_event.error, rejection_event.body.stacktrace) + ) # Handle any remaining events try: @@ -1079,9 +1061,6 @@ def _stream(self) -> None: with execution_context(parent_context=current_parent, trace_id=self._execution_context.trace_id): rejection_event = self._handle_work_item_event(event) - - if rejection_event: - break except Empty: pass @@ -1100,10 +1079,7 @@ def _stream(self) -> None: ) return - if rejection_event: - self._workflow_event_outer_queue.put( - self._reject_workflow_event(rejection_event.error, rejection_event.body.stacktrace) - ) + if self._is_cancelling: return fulfilled_outputs = self.workflow.Outputs() @@ -1139,11 +1115,14 @@ def _run_cancel_thread(self, kill_switch: ThreadingEvent) -> None: while not kill_switch.wait(timeout=0.1): if self._cancel_signal.is_set(): - self._emit_node_cancellation_events( + self._cancel_active_nodes( error_message="Workflow run cancelled", ) captured_stacktrace = "".join(traceback.format_stack()) + print( + "cancel thread rejection event put", self.workflow.__class__.__name__, "workflow.execution.rejected" + ) self._workflow_event_outer_queue.put( self._reject_workflow_event( WorkflowError( @@ -1162,7 +1141,7 @@ def _run_timeout_thread(self, kill_switch: ThreadingEvent) -> None: if kill_switch.wait(timeout=self._timeout): return - self._emit_node_cancellation_events( + self._cancel_active_nodes( error_message=f"Workflow execution exceeded timeout of {self._timeout} seconds", ) diff --git a/tests/workflows/parallel_inline_subworkflow_cancellation/tests/test_workflow.py b/tests/workflows/parallel_inline_subworkflow_cancellation/tests/test_workflow.py index fe885f663d..d4fe662564 100644 --- a/tests/workflows/parallel_inline_subworkflow_cancellation/tests/test_workflow.py +++ b/tests/workflows/parallel_inline_subworkflow_cancellation/tests/test_workflow.py @@ -12,7 +12,6 @@ ) -@pytest.mark.xfail(reason="Substantial changes are needed in Workflow Rejection to get this test to pass") def test_parallel_inline_subworkflow_cancellation__streaming(): """ Tests that when one parallel node fails, the other parallel inline subworkflow node @@ -24,10 +23,13 @@ def test_parallel_inline_subworkflow_cancellation__streaming(): stream = workflow.stream(event_filter=all_workflow_event_filter) events = list(stream) + workflow.join() rejection_events = [e for e in events if e.name == "node.execution.rejected"] - assert len(rejection_events) >= 3, f"Expected at least 3 rejection events, got {len(rejection_events)}" + assert ( + len(rejection_events) == 3 + ), f"Expected 3 rejection events, got {[e.node_definition.__name__ for e in rejection_events]}" fast_failing_rejection = next((e for e in rejection_events if e.node_definition == FastFailingNode), None) assert fast_failing_rejection is not None, "Expected FastFailingNode rejection event" diff --git a/tests/workflows/parallel_inline_subworkflow_cancellation/workflow.py b/tests/workflows/parallel_inline_subworkflow_cancellation/workflow.py index 71b34cc9c5..45bdbbbda6 100644 --- a/tests/workflows/parallel_inline_subworkflow_cancellation/workflow.py +++ b/tests/workflows/parallel_inline_subworkflow_cancellation/workflow.py @@ -1,3 +1,4 @@ +import threading import time from vellum.workflows import BaseWorkflow @@ -14,7 +15,7 @@ class Outputs(BaseNode.Outputs): value: str def run(self) -> Outputs: - time.sleep(0.01) + time.sleep(0.1) raise NodeException(code=WorkflowErrorCode.USER_DEFINED_ERROR, message="Fast node failed") @@ -25,9 +26,15 @@ class Outputs(BaseNode.Outputs): value: str def run(self) -> Outputs: - time.sleep(0.5) + self._cancelled = threading.Event() + if self._cancelled.wait(timeout=0.5): + raise NodeException(code=WorkflowErrorCode.NODE_CANCELLED, message="Slow node cancelled") + return self.Outputs(value="slow complete") + def __cancel__(self, message: str) -> None: + self._cancelled.set() + class SlowSubworkflow(BaseWorkflow): """Subworkflow containing a slow node."""