Skip to content

Commit c091def

Browse files
NgoHarrisonHarrison Ngo
andauthored
[APO-2122] Serialize workflow trigger id for scenarios (#3069)
* [APO-2122] Serialize workflow trigger id for scenarios * udpate tests --------- Co-authored-by: Harrison Ngo <[email protected]>
1 parent d168b4b commit c091def

File tree

8 files changed

+155
-3
lines changed

8 files changed

+155
-3
lines changed

ee/vellum_ee/workflows/display/workflows/base_workflow_display.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1211,7 +1211,11 @@ def serialize_module(
12111211
for i, inputs_obj in enumerate(dataset_attr):
12121212
if isinstance(inputs_obj, DatasetRow):
12131213
serialized_inputs = json.loads(json.dumps(inputs_obj.inputs, cls=VellumJsonEncoder))
1214-
dataset.append({"label": inputs_obj.label, "inputs": serialized_inputs})
1214+
row_data = {"label": inputs_obj.label, "inputs": serialized_inputs}
1215+
trigger_class = inputs_obj.workflow_trigger
1216+
if trigger_class is not None:
1217+
row_data["workflow_trigger_id"] = str(trigger_class.__id__)
1218+
dataset.append(row_data)
12151219
elif isinstance(inputs_obj, BaseInputs):
12161220
serialized_inputs = json.loads(json.dumps(inputs_obj, cls=VellumJsonEncoder))
12171221
dataset.append({"label": f"Scenario {i + 1}", "inputs": serialized_inputs})

ee/vellum_ee/workflows/tests/test_serialize_module.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
import shutil
44
import sys
55
import tempfile
6+
import uuid
7+
from uuid import UUID
8+
9+
from pytest_mock import MockerFixture
610

711
from vellum.workflows.exceptions import WorkflowInitializationException
812
from vellum_ee.workflows.display.workflows.base_workflow_display import BaseWorkflowDisplay
@@ -20,6 +24,18 @@ def temp_module_path():
2024
shutil.rmtree(temp_dir)
2125

2226

27+
@pytest.fixture
28+
def metadata_trigger_factory(mocker: MockerFixture):
29+
def _factory(metadata_trigger_id: UUID) -> UUID:
30+
mocker.patch(
31+
"vellum.workflows.triggers.base._get_trigger_id_from_metadata",
32+
return_value=metadata_trigger_id,
33+
)
34+
return metadata_trigger_id
35+
36+
return _factory
37+
38+
2339
def test_serialize_module_with_dataset():
2440
"""Test that serialize_module correctly serializes dataset from sandbox modules."""
2541
module_path = "tests.workflows.basic_inputs_and_outputs"
@@ -50,6 +66,25 @@ def test_serialize_module_with_actual_dataset():
5066
assert result.dataset[1]["inputs"]["message"] == "DatasetRow Test"
5167

5268

69+
def test_serialize_module_with_actual_dataset_with_trigger(metadata_trigger_factory):
70+
"""Test that serialize_module correctly serializes dataset with trigger"""
71+
module_path = "tests.workflows.test_dataset_with_trigger_serialization"
72+
73+
metadata_trigger_id = uuid.uuid4()
74+
metadata_trigger_factory(metadata_trigger_id)
75+
76+
result = BaseWorkflowDisplay.serialize_module(module_path)
77+
78+
assert hasattr(result, "dataset")
79+
80+
assert result.dataset is not None
81+
assert isinstance(result.dataset, list)
82+
assert len(result.dataset) == 1
83+
84+
assert result.dataset[0]["label"] == "Scenario 1"
85+
assert result.dataset[0]["workflow_trigger_id"] == str(metadata_trigger_id)
86+
87+
5388
def test_serialize_module_happy_path():
5489
"""Test that serialize_module works with a valid module path."""
5590
module_path = "tests.workflows.trivial"

src/vellum/workflows/inputs/dataset_row.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
from typing import Any, Dict, Union
1+
from typing import Any, Dict, Optional, Type, Union
22

33
from pydantic import Field, field_serializer
44

55
from vellum.client.core.pydantic_utilities import UniversalBaseModel
66
from vellum.workflows.inputs.base import BaseInputs
7+
from vellum.workflows.triggers import BaseTrigger
78

89

910
class DatasetRow(UniversalBaseModel):
@@ -13,10 +14,12 @@ class DatasetRow(UniversalBaseModel):
1314
Attributes:
1415
label: String label for the dataset row
1516
inputs: BaseInputs instance or dict containing the input data
17+
workflow_trigger_id: Optional Trigger identifying the workflow trigger class for this scenario
1618
"""
1719

1820
label: str
1921
inputs: Union[BaseInputs, Dict[str, Any]] = Field(default_factory=BaseInputs)
22+
workflow_trigger: Optional[Type[BaseTrigger]] = None
2023

2124
@field_serializer("inputs")
2225
def serialize_inputs(self, inputs: Union[BaseInputs, Dict[str, Any]]) -> Dict[str, Any]:

src/vellum/workflows/sandbox.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from typing import Any, Dict, Generic, Optional, Sequence, Union
1+
from typing import Any, Dict, Generic, Optional, Sequence, Type, Union
22

33
import dotenv
44

55
from vellum.workflows.events.workflow import WorkflowEventStream
66
from vellum.workflows.inputs.base import BaseInputs
77
from vellum.workflows.inputs.dataset_row import DatasetRow
88
from vellum.workflows.logging import load_logger
9+
from vellum.workflows.triggers.base import BaseTrigger
910
from vellum.workflows.types.generics import WorkflowType
1011
from vellum.workflows.workflows.event_filters import root_workflow_event_filter
1112

@@ -52,8 +53,10 @@ def run(self, index: int = 0):
5253
selected_inputs = self._inputs[index]
5354

5455
raw_inputs: Union[BaseInputs, Dict[str, Any]]
56+
trigger_class: Optional[Type[BaseTrigger]] = None
5557
if isinstance(selected_inputs, DatasetRow):
5658
raw_inputs = selected_inputs.inputs
59+
trigger_class = selected_inputs.workflow_trigger
5760
else:
5861
raw_inputs = selected_inputs
5962

@@ -64,9 +67,15 @@ def run(self, index: int = 0):
6467
else:
6568
inputs_for_stream = raw_inputs
6669

70+
trigger_instance: Optional[BaseTrigger] = None
71+
if trigger_class is not None:
72+
# Instantiate the trigger with the inputs
73+
trigger_instance = trigger_class(**raw_inputs) if isinstance(raw_inputs, dict) else trigger_class()
74+
6775
events = self._workflow.stream(
6876
inputs=inputs_for_stream,
6977
event_filter=root_workflow_event_filter,
78+
trigger=trigger_instance,
7079
)
7180

7281
self._process_events(events)

src/vellum/workflows/tests/test_sandbox.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import pytest
2+
from datetime import datetime
23
from typing import List
34

45
from vellum.workflows.inputs.base import BaseInputs
56
from vellum.workflows.inputs.dataset_row import DatasetRow
67
from vellum.workflows.nodes.bases.base import BaseNode
78
from vellum.workflows.sandbox import WorkflowSandboxRunner
89
from vellum.workflows.state.base import BaseState
10+
from vellum.workflows.triggers import ScheduleTrigger
911
from vellum.workflows.workflows.base import BaseWorkflow
1012

1113

@@ -100,3 +102,53 @@ class Outputs(BaseWorkflow.Outputs):
100102
"----------------------------------",
101103
"final_output: Hello from dict",
102104
]
105+
106+
107+
def test_sandbox_runner_with_workflow_trigger(mock_logger):
108+
"""
109+
Test that WorkflowSandboxRunner can run with DatasetRow containing workflow_trigger.
110+
"""
111+
112+
# GIVEN we capture the logs to stdout
113+
logs = []
114+
mock_logger.return_value.info.side_effect = lambda msg: logs.append(msg)
115+
116+
class MySchedule(ScheduleTrigger):
117+
class Config(ScheduleTrigger.Config):
118+
cron = "* * * * *"
119+
timezone = "UTC"
120+
121+
class StartNode(BaseNode):
122+
class Outputs(BaseNode.Outputs):
123+
result = MySchedule.current_run_at
124+
125+
class Workflow(BaseWorkflow):
126+
graph = MySchedule >> StartNode
127+
128+
class Outputs(BaseWorkflow.Outputs):
129+
final_output = StartNode.Outputs.result
130+
131+
# AND a dataset with workflow_trigger
132+
dataset = [
133+
DatasetRow(
134+
label="test_row",
135+
inputs={"current_run_at": datetime.min, "next_run_at": datetime.now()},
136+
workflow_trigger=MySchedule,
137+
),
138+
]
139+
140+
# WHEN we run the sandbox with the DatasetRow containing workflow_trigger
141+
runner = WorkflowSandboxRunner(workflow=Workflow(), dataset=dataset)
142+
runner.run()
143+
144+
# THEN the workflow should run successfully
145+
assert logs == [
146+
"Just started Node: StartNode",
147+
"Just finished Node: StartNode",
148+
"Workflow fulfilled!",
149+
"----------------------------------",
150+
"final_output: 0001-01-01 00:00:00",
151+
]
152+
153+
# AND the dataset row should still have the trigger class
154+
assert dataset[0].workflow_trigger == MySchedule

tests/workflows/test_dataset_with_trigger_serialization/__init__.py

Whitespace-only changes.
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from datetime import datetime
2+
from typing import List
3+
4+
from vellum.workflows.inputs.dataset_row import DatasetRow
5+
from vellum.workflows.sandbox import WorkflowSandboxRunner
6+
7+
from .workflow import MySchedule, TestDatasetWithTriggerSerializationWorkflow
8+
9+
if __name__ == "__main__":
10+
raise Exception("This file is not meant to be imported")
11+
12+
dataset: List[DatasetRow] = [
13+
DatasetRow(
14+
label="Scenario 1",
15+
inputs={"current_run_at": datetime.min, "next_run_at": datetime.now()},
16+
workflow_trigger=MySchedule,
17+
),
18+
]
19+
20+
runner = WorkflowSandboxRunner(workflow=TestDatasetWithTriggerSerializationWorkflow(), dataset=dataset)
21+
22+
runner.run()
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from vellum.workflows import BaseWorkflow
2+
from vellum.workflows.nodes.bases import BaseNode
3+
from vellum.workflows.outputs import BaseOutputs
4+
from vellum.workflows.triggers import ScheduleTrigger
5+
6+
7+
class MySchedule(ScheduleTrigger):
8+
class Config(ScheduleTrigger.Config):
9+
cron = "* * * * *"
10+
timezone = "UTC"
11+
12+
13+
class SimpleNode(BaseNode):
14+
message = MySchedule.current_run_at
15+
16+
class Outputs(BaseOutputs):
17+
result: str
18+
19+
def run(self) -> BaseOutputs:
20+
return self.Outputs(result=f"Current run at: {str(self.message)}")
21+
22+
23+
class TestDatasetWithTriggerSerializationWorkflow(BaseWorkflow):
24+
graph = MySchedule >> SimpleNode
25+
26+
class Outputs(BaseOutputs):
27+
final_result = SimpleNode.Outputs.result

0 commit comments

Comments
 (0)