# Patch summary: serialize bytes node outputs and validate outputs before fulfillment.
#
# NOTE(review): in this copy the diff's line structure is collapsed, so the
# indentation and blank lines inside the hunks cannot be recovered with
# certainty. The patch text below is left untouched; only this preamble is added.
#
# 1. src/vellum/utils/json_encoder.py (hunk heading: `def default(self, obj: Any) -> Any:`)
#    - Adds `import base64`.
#    - Adds a branch for `bytes` / `bytearray` values, placed after the
#      `__vellum_encode__` hook and before the `UUID` branch: the value is
#      decoded as UTF-8; on `UnicodeDecodeError` it is base64-encoded and that
#      result is decoded as ASCII.
#    - NOTE(review): both paths return a plain `str`, so the serialized value
#      alone does not show which of the two encodings was applied. Confirm that
#      consumers never need to tell them apart.
#
# 2. src/vellum/workflows/nodes/bases/tests/test_base_node.py
#    - Imports `default_serializer` and `BaseWorkflow`.
#    - Adds `test_base_node__bytes_output_converts_to_string`: a node whose
#      `result: str` output is given `b"hello"`, wired through a
#      `FinalOutputNode` into a workflow. The test asserts the run result's
#      name is "workflow.execution.fulfilled" and that
#      `default_serializer(result.outputs.result) == "hello"`.
#    - NOTE(review): only the UTF-8 path is exercised; the base64 fallback has
#      no test in this patch.
#
# 3. src/vellum/workflows/runner/runner.py (hunk heading: `def initiate_node_streaming_output(`)
#    - Reformats the `vellum.workflows.events.types` import into a
#      parenthesized list and adds `default_serializer` to it.
#    - Before `fulfill_node_execution(...)` is called, iterates over `outputs`,
#      skips values that are `undefined`, and calls `default_serializer` on each
#      remaining value. A `TypeError` or `ValueError` is re-raised (chained with
#      `from exc`) as `NodeException` with `WorkflowErrorCode.INVALID_OUTPUTS`
#      and a message naming the node class and the output descriptor.
#    - NOTE(review): the return value of `default_serializer` is discarded, so
#      the call appears to act as a validation step only. Exception types other
#      than `TypeError` / `ValueError` are not caught by this handler.
#
diff --git a/src/vellum/utils/json_encoder.py b/src/vellum/utils/json_encoder.py index b8363525d7..197dc0d44f 100644 --- a/src/vellum/utils/json_encoder.py +++ b/src/vellum/utils/json_encoder.py @@ -1,3 +1,4 @@ +import base64 from dataclasses import asdict, is_dataclass from datetime import datetime import enum @@ -48,6 +49,12 @@ def default(self, obj: Any) -> Any: if hasattr(obj, "__vellum_encode__") and callable(getattr(obj, "__vellum_encode__")): return obj.__vellum_encode__() + if isinstance(obj, (bytes, bytearray)): + try: + return obj.decode("utf-8") + except UnicodeDecodeError: + return base64.b64encode(obj).decode("ascii") + if isinstance(obj, UUID): return str(obj) diff --git a/src/vellum/workflows/nodes/bases/tests/test_base_node.py b/src/vellum/workflows/nodes/bases/tests/test_base_node.py index a6996baa3a..b5bfdce500 100644 --- a/src/vellum/workflows/nodes/bases/tests/test_base_node.py +++ b/src/vellum/workflows/nodes/bases/tests/test_base_node.py @@ -6,6 +6,7 @@ from vellum.client.types.string_vellum_value_request import StringVellumValueRequest from vellum.workflows.constants import undefined from vellum.workflows.descriptors.tests.test_utils import FixtureState +from vellum.workflows.events.types import default_serializer from vellum.workflows.inputs.base import BaseInputs from vellum.workflows.nodes import FinalOutputNode from vellum.workflows.nodes.bases.base import BaseNode @@ -15,6 +16,7 @@ from vellum.workflows.references.node import NodeReference from vellum.workflows.references.output import OutputReference from vellum.workflows.state.base import BaseState, StateMeta +from vellum.workflows.workflows.base import BaseWorkflow def test_base_node__node_resolution__unset_pydantic_fields(): @@ -379,3 +381,37 @@ class Ports(MyNode.Ports): # Potentially in the future, we support inheriting ports from multiple parents. # For now, we take only the declared ports, so that not all nodes have the default port. 
assert ports == ["bar"] + + +def test_base_node__bytes_output_converts_to_string(): + """Test that returning bytes in node outputs automatically converts to string.""" + + # GIVEN a node that returns bytes + class BytesOutputNode(BaseNode): + class Outputs(BaseNode.Outputs): + result: str + + def run(self) -> "BytesOutputNode.Outputs": + b = b"hello" + return self.Outputs(result=b) # type: ignore[arg-type] + + class OutputNode(FinalOutputNode): + class Outputs(FinalOutputNode.Outputs): + result = BytesOutputNode.Outputs.result + + class BytesWorkflow(BaseWorkflow): + graph = BytesOutputNode >> OutputNode + + class Outputs(BaseWorkflow.Outputs): + result = OutputNode.Outputs.result + + workflow = BytesWorkflow() + + # WHEN we run the workflow + result = workflow.run() + + # THEN the execution is fulfilled successfully + assert result.name == "workflow.execution.fulfilled" + + # AND the bytes are converted to a UTF-8 string when serialized + assert default_serializer(result.outputs.result) == "hello" diff --git a/src/vellum/workflows/runner/runner.py b/src/vellum/workflows/runner/runner.py index ed3a7dae0f..fa5dc647e0 100644 --- a/src/vellum/workflows/runner/runner.py +++ b/src/vellum/workflows/runner/runner.py @@ -49,7 +49,14 @@ NodeExecutionRejectedBody, NodeExecutionStreamingBody, ) -from vellum.workflows.events.types import BaseEvent, NodeParentContext, ParentContext, SpanLink, WorkflowParentContext +from vellum.workflows.events.types import ( + BaseEvent, + NodeParentContext, + ParentContext, + SpanLink, + WorkflowParentContext, + default_serializer, +) from vellum.workflows.events.workflow import ( WorkflowEventStream, WorkflowExecutionFulfilledBody, @@ -599,6 +606,20 @@ def initiate_node_streaming_output( parent=execution.parent_context, ) + for descriptor, output_value in outputs: + if output_value is undefined: + continue + try: + default_serializer(output_value) + except (TypeError, ValueError) as exc: + raise NodeException( + message=( + f"Node 
{node.__class__.__name__} produced output '{descriptor.name}' " + f"that could not be serialized to JSON: {exc}" + ), + code=WorkflowErrorCode.INVALID_OUTPUTS, + ) from exc + node.state.meta.node_execution_cache.fulfill_node_execution(node.__class__, span_id) with execution_context(parent_context=updated_parent_context, trace_id=execution.trace_id):