4 changes: 4 additions & 0 deletions src/vellum/workflows/events/node.py
@@ -24,6 +24,10 @@ class _BaseNodeExecutionBody(UniversalBaseModel):
def serialize_node_definition(self, node_definition: Type, _info: Any) -> Dict[str, Any]:
return serialize_type_encoder_with_id(node_definition)

@property
def definition_name(self) -> str:
return self.node_definition.__name__


class _BaseNodeEvent(BaseEvent):
body: _BaseNodeExecutionBody
4 changes: 4 additions & 0 deletions src/vellum/workflows/events/workflow.py
@@ -59,6 +59,10 @@ class _BaseWorkflowExecutionBody(UniversalBaseModel):
def serialize_workflow_definition(self, workflow_definition: Type, _info: Any) -> Dict[str, Any]:
return serialize_type_encoder_with_id(workflow_definition)

@property
def definition_name(self) -> str:
return self.workflow_definition.__name__


class _BaseWorkflowEvent(BaseEvent):
body: _BaseWorkflowExecutionBody
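Both event bodies gain the same convenience property: expose the `__name__` of the stored definition class so that consumers can log a readable label without going through the serializer (the debug prints added later in this diff call `event.body.definition_name`). A minimal standalone sketch of the pattern, using hypothetical `ExampleBody` and `MyNode` names rather than the real Vellum models:

```python
from typing import Type


class ExampleBody:
    """Stand-in for the event body models; it holds a class object, not an instance."""

    def __init__(self, node_definition: Type) -> None:
        self.node_definition = node_definition

    @property
    def definition_name(self) -> str:
        # Same idea as the added properties: a readable label for logs and tests.
        return self.node_definition.__name__


class MyNode:
    pass


print(ExampleBody(MyNode).definition_name)  # -> "MyNode"
```

The next file in the diff is the subworkflow node's `run()` and `__cancel__` pair, where this label shows up in the debug output.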
@@ -75,11 +75,11 @@ def run(self) -> Iterator[BaseOutput]:
self._child_cancel_signal = ThreadingEvent()

with execution_context(parent_context=get_parent_context()):
subworkflow = self.subworkflow(
self._subworkflow_instance = self.subworkflow(
parent_state=self.state,
context=WorkflowContext.create_from(self._context),
)
subworkflow_stream = subworkflow.stream(
subworkflow_stream = self._subworkflow_instance.stream(
inputs=self._compile_subworkflow_inputs(),
event_filter=all_workflow_event_filter,
node_output_mocks=self._context._get_all_node_output_mocks(),
@@ -91,6 +91,7 @@ def run(self) -> Iterator[BaseOutput]:
fulfilled_output_names: Set[str] = set()

for event in subworkflow_stream:
print("subevent", event.name, event.body.definition_name)
self._context._emit_subworkflow_event(event)
if exception:
continue
@@ -136,8 +137,10 @@ def __cancel__(self, message: str) -> None:
"""
Propagate cancellation to the nested workflow by setting its cancel signal.
"""
print("cancelling subworkflow node")
if hasattr(self, "_child_cancel_signal"):
self._child_cancel_signal.set()
self._subworkflow_instance.join()

def _compile_subworkflow_inputs(self) -> InputsType:
if self.subworkflow is None:
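The node now keeps a handle to the instantiated subworkflow (`self._subworkflow_instance`) so that `__cancel__` can set the child's cancel signal and then block until the child run has wound down; the `hasattr` guard covers a cancel that arrives before `run()` has created the signal. A hedged, standalone sketch of that cancel-then-wait shape, with illustrative `ChildRunner` and `ParentNode` names rather than the real classes:

```python
import threading
import time


class ChildRunner:
    """Toy stand-in for a nested workflow run executing on its own thread."""

    def __init__(self) -> None:
        self.cancel_signal = threading.Event()
        self._thread = threading.Thread(target=self._run)
        self._thread.start()

    def _run(self) -> None:
        # Simulate work that checks the cancel signal between small steps.
        while not self.cancel_signal.wait(timeout=0.05):
            pass

    def join(self) -> None:
        self._thread.join()


class ParentNode:
    def start(self) -> None:
        self._child = ChildRunner()

    def cancel(self, message: str) -> None:
        # Mirrors __cancel__: act only if the child was actually started,
        # signal it, then wait so no stray events trail behind the cancel.
        if hasattr(self, "_child"):
            self._child.cancel_signal.set()
            self._child.join()


node = ParentNode()
node.start()
time.sleep(0.1)
node.cancel("parent workflow is shutting down")  # returns once the child thread exits
```

Blocking on `join()` inside the hook presumably makes cancellation synchronous from the runner's point of view: by the time `__cancel__` returns, the nested run has stopped emitting events.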
57 changes: 18 additions & 39 deletions src/vellum/workflows/runner/runner.py
@@ -263,6 +263,7 @@ def __init__(
self._timeout = timeout
self._execution_context = init_execution_context or get_execution_context()
self._trigger = trigger
self._is_cancelling = False

setattr(
self._initial_state,
@@ -664,6 +665,7 @@ def _handle_run_node_exception(
) -> NodeExecutionRejectedEvent:
logger.info(f"{prefix}: {exception}")
captured_stacktrace = traceback.format_exc()
print("handle run node exception", self.workflow.__class__.__name__, node.__class__.__name__)

return NodeExecutionRejectedEvent(
trace_id=execution.trace_id,
@@ -870,7 +872,7 @@ def _handle_work_item_event(self, event: WorkflowEvent) -> Optional[NodeExecutio

return None

def _emit_node_cancellation_events(
def _cancel_active_nodes(
self,
error_message: str,
) -> None:
@@ -880,32 +882,9 @@ def _emit_node_cancellation_events(
Args:
error_message: The error message to include in the cancellation events
"""
parent_context = WorkflowParentContext(
span_id=self._initial_state.meta.span_id,
workflow_definition=self.workflow.__class__,
parent=self._execution_context.parent_context,
)
captured_stacktrace = "".join(traceback.format_stack())
active_span_ids = list(self._active_nodes_by_execution_id.keys())
for span_id in active_span_ids:
active_node = self._active_nodes_by_execution_id.pop(span_id, None)
if active_node is not None:
active_node.node.__cancel__(error_message)

rejection_event = NodeExecutionRejectedEvent(
trace_id=self._execution_context.trace_id,
span_id=span_id,
body=NodeExecutionRejectedBody(
node_definition=active_node.node.__class__,
error=WorkflowError(
code=WorkflowErrorCode.NODE_CANCELLED,
message=error_message,
),
stacktrace=captured_stacktrace,
),
parent=parent_context,
)
self._workflow_event_outer_queue.put(rejection_event)
self._is_cancelling = True
for active_node in self._active_nodes_by_execution_id.values():
active_node.node.__cancel__(error_message)

def _initiate_workflow_event(self) -> WorkflowExecutionInitiatedEvent:
links: Optional[List[SpanLink]] = None
@@ -1059,18 +1038,21 @@ def _stream(self) -> None:
break

event = self._workflow_event_inner_queue.get()
print("outer flow event", self.workflow.__class__.__name__, event.name, event.body.definition_name)

self._workflow_event_outer_queue.put(event)

with execution_context(parent_context=current_parent, trace_id=self._execution_context.trace_id):
rejection_event = self._handle_work_item_event(event)

if rejection_event:
if rejection_event and not self._is_cancelling:
failed_node_name = rejection_event.body.node_definition.__name__
self._emit_node_cancellation_events(
self._cancel_active_nodes(
error_message=f"Node execution cancelled due to {failed_node_name} failure",
)
break
self._workflow_event_inner_queue.put(
self._reject_workflow_event(rejection_event.error, rejection_event.body.stacktrace)
)

# Handle any remaining events
try:

with execution_context(parent_context=current_parent, trace_id=self._execution_context.trace_id):
rejection_event = self._handle_work_item_event(event)

if rejection_event:
break
except Empty:
pass

)
return

if rejection_event:
self._workflow_event_outer_queue.put(
self._reject_workflow_event(rejection_event.error, rejection_event.body.stacktrace)
)
if self._is_cancelling:
return

fulfilled_outputs = self.workflow.Outputs()
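The new `_is_cancelling` flag keeps the fan-out from cascading: the first node rejection cancels the remaining active nodes and enqueues the workflow-level rejection, while the rejections produced by that cancellation are still forwarded but do not trigger another round, and once the flag is set `_stream` skips the fulfilled-outputs path entirely. A small sketch of just that guard, with illustrative `Runner` and `handle_rejection` names:

```python
class Runner:
    """Toy model of the first-failure-wins guard; not the real WorkflowRunner."""

    def __init__(self) -> None:
        self.is_cancelling = False
        self.cancellations: list[str] = []

    def cancel_active_nodes(self, error_message: str) -> None:
        self.is_cancelling = True
        self.cancellations.append(error_message)

    def handle_rejection(self, failed_node_name: str) -> None:
        # Only the first failure fans out; cancellation-induced rejections pass through.
        if not self.is_cancelling:
            self.cancel_active_nodes(f"Node execution cancelled due to {failed_node_name} failure")


runner = Runner()
runner.handle_rejection("FastFailingNode")  # the first failure triggers the fan-out
runner.handle_rejection("SlowNode")         # a rejection caused by the fan-out: no second round
assert runner.cancellations == ["Node execution cancelled due to FastFailingNode failure"]
```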
@@ -1139,11 +1115,14 @@ def _run_cancel_thread(self, kill_switch: ThreadingEvent) -> None:

while not kill_switch.wait(timeout=0.1):
if self._cancel_signal.is_set():
self._emit_node_cancellation_events(
self._cancel_active_nodes(
error_message="Workflow run cancelled",
)

captured_stacktrace = "".join(traceback.format_stack())
print(
"cancel thread rejection event put", self.workflow.__class__.__name__, "workflow.execution.rejected"
)
self._workflow_event_outer_queue.put(
self._reject_workflow_event(
WorkflowError(
if kill_switch.wait(timeout=self._timeout):
return

self._emit_node_cancellation_events(
self._cancel_active_nodes(
error_message=f"Workflow execution exceeded timeout of {self._timeout} seconds",
)

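Both watcher threads, the external-cancel watcher above and the timeout watcher, now call the refactored `_cancel_active_nodes`, so a user cancellation or a timeout triggers the same per-node `__cancel__` fan-out instead of synthesizing rejection events inside the runner. The polling shape they share is simple; a hedged sketch with illustrative names, assuming nothing beyond the standard library:

```python
import threading
from typing import Callable


def run_cancel_watcher(
    kill_switch: threading.Event,
    cancel_signal: threading.Event,
    on_cancel: Callable[[str], None],
) -> None:
    """Poll every 100 ms; exit when the runner shuts down, fan out if a cancel fires."""
    while not kill_switch.wait(timeout=0.1):
        if cancel_signal.is_set():
            on_cancel("Workflow run cancelled")
            return


# Usage sketch: the runner would start this on a background thread and set
# kill_switch once the stream finishes normally.
kill_switch, cancel_signal = threading.Event(), threading.Event()
watcher = threading.Thread(
    target=run_cancel_watcher,
    args=(kill_switch, cancel_signal, lambda message: print("cancelling:", message)),
    daemon=True,
)
watcher.start()
cancel_signal.set()
watcher.join(timeout=1)
```

The streaming test that follows drives the first-failure path end to end: one parallel node fails first, and the parallel subworkflow branch is expected to surface cancellation rejections instead of hanging.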
@@ -12,7 +12,6 @@
)


@pytest.mark.xfail(reason="Substantial changes are needed in Workflow Rejection to get this test to pass")
def test_parallel_inline_subworkflow_cancellation__streaming():
"""
Tests that when one parallel node fails, the other parallel inline subworkflow node
@@ -24,10 +23,13 @@ def test_parallel_inline_subworkflow_cancellation__streaming():

stream = workflow.stream(event_filter=all_workflow_event_filter)
events = list(stream)
workflow.join()

rejection_events = [e for e in events if e.name == "node.execution.rejected"]

assert len(rejection_events) >= 3, f"Expected at least 3 rejection events, got {len(rejection_events)}"
assert (
len(rejection_events) == 3
), f"Expected 3 rejection events, got {[e.node_definition.__name__ for e in rejection_events]}"

fast_failing_rejection = next((e for e in rejection_events if e.node_definition == FastFailingNode), None)
assert fast_failing_rejection is not None, "Expected FastFailingNode rejection event"
@@ -1,3 +1,4 @@
import threading
import time

from vellum.workflows import BaseWorkflow
@@ -14,7 +15,7 @@ class Outputs(BaseNode.Outputs):
value: str

def run(self) -> Outputs:
time.sleep(0.01)
time.sleep(0.1)
raise NodeException(code=WorkflowErrorCode.USER_DEFINED_ERROR, message="Fast node failed")


@@ -25,9 +26,15 @@ class Outputs(BaseNode.Outputs):
value: str

def run(self) -> Outputs:
time.sleep(0.5)
self._cancelled = threading.Event()
if self._cancelled.wait(timeout=0.5):
raise NodeException(code=WorkflowErrorCode.NODE_CANCELLED, message="Slow node cancelled")

return self.Outputs(value="slow complete")

def __cancel__(self, message: str) -> None:
self._cancelled.set()


class SlowSubworkflow(BaseWorkflow):
"""Subworkflow containing a slow node."""
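The fixture's slow node is what makes the cancellation observable: instead of sleeping, `run()` parks on a `threading.Event` and `__cancel__` sets it, so the node raises `NODE_CANCELLED` almost immediately once the fast node fails. A standalone sketch of the same idea with a hypothetical `CancellableTask`, creating the event eagerly in `__init__` so a cancel that lands before `run()` starts is also safe:

```python
import threading


class CancellableTask:
    def __init__(self) -> None:
        # Created up front so cancel() is safe to call at any point in the lifecycle.
        self._cancelled = threading.Event()

    def run(self) -> str:
        # Block for up to 0.5 s of "work"; wait() returns True as soon as cancel() fires.
        if self._cancelled.wait(timeout=0.5):
            return "cancelled"
        return "complete"

    def cancel(self) -> None:
        self._cancelled.set()


task = CancellableTask()
threading.Timer(0.1, task.cancel).start()
print(task.run())  # -> "cancelled" after roughly 0.1 s instead of the full 0.5 s
```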