Skip to content

Commit ffe1559

Browse files
author
Guru Hariharaun
committed
feat(multiagent): add structured_output_model support to Graph
- Add structured_output_model parameter to Graph.__call__, invoke_async, and stream_async - Support per-agent structured_output_model (uses agent's default if available, otherwise uses graph-level) - Pass structured_output_model through execution chain to all agent nodes - Support structured output for nested MultiAgentBase nodes (Graph/Swarm) - Add comprehensive tests for structured output functionality This enables structured output from all agent nodes in a graph, allowing for type-safe, validated responses. Each agent can have its own structured output model, or a graph-level model can be used as a fallback. Fixes #538
1 parent 2b0c6e6 commit ffe1559

File tree

2 files changed

+254
-16
lines changed

2 files changed

+254
-16
lines changed

src/strands/multiagent/graph.py

Lines changed: 73 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@
1919
import logging
2020
import time
2121
from dataclasses import dataclass, field
22-
from typing import Any, AsyncIterator, Callable, Optional, Tuple, cast
22+
from typing import Any, AsyncIterator, Callable, Optional, Tuple, Type, cast
2323

2424
from opentelemetry import trace as trace_api
25+
from pydantic import BaseModel
2526

2627
from .._async import run_async
2728
from ..agent import Agent
@@ -456,23 +457,34 @@ def __init__(
456457
run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))
457458

458459
def __call__(
459-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
460+
self,
461+
task: str | list[ContentBlock],
462+
invocation_state: dict[str, Any] | None = None,
463+
structured_output_model: Type[BaseModel] | None = None,
464+
**kwargs: Any,
460465
) -> GraphResult:
461466
"""Invoke the graph synchronously.
462467
463468
Args:
464469
task: The task to execute
465470
invocation_state: Additional state/context passed to underlying agents.
466471
Defaults to None to avoid mutable default argument issues.
472+
structured_output_model: Pydantic model type for structured output.
467473
**kwargs: Keyword arguments allowing backward compatible future changes.
468474
"""
469475
if invocation_state is None:
470476
invocation_state = {}
471477

472-
return run_async(lambda: self.invoke_async(task, invocation_state))
478+
return run_async(
479+
lambda: self.invoke_async(task, invocation_state, structured_output_model=structured_output_model)
480+
)
473481

474482
async def invoke_async(
475-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
483+
self,
484+
task: str | list[ContentBlock],
485+
invocation_state: dict[str, Any] | None = None,
486+
structured_output_model: Type[BaseModel] | None = None,
487+
**kwargs: Any,
476488
) -> GraphResult:
477489
"""Invoke the graph asynchronously.
478490
@@ -483,9 +495,10 @@ async def invoke_async(
483495
task: The task to execute
484496
invocation_state: Additional state/context passed to underlying agents.
485497
Defaults to None to avoid mutable default argument issues.
498+
structured_output_model: Pydantic model type for structured output.
486499
**kwargs: Keyword arguments allowing backward compatible future changes.
487500
"""
488-
events = self.stream_async(task, invocation_state, **kwargs)
501+
events = self.stream_async(task, invocation_state, structured_output_model=structured_output_model, **kwargs)
489502
final_event = None
490503
async for event in events:
491504
final_event = event
@@ -496,14 +509,19 @@ async def invoke_async(
496509
return cast(GraphResult, final_event["result"])
497510

498511
async def stream_async(
499-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
512+
self,
513+
task: str | list[ContentBlock],
514+
invocation_state: dict[str, Any] | None = None,
515+
structured_output_model: Type[BaseModel] | None = None,
516+
**kwargs: Any,
500517
) -> AsyncIterator[dict[str, Any]]:
501518
"""Stream events during graph execution.
502519
503520
Args:
504521
task: The task to execute
505522
invocation_state: Additional state/context passed to underlying agents.
506523
Defaults to None to avoid mutable default argument issues.
524+
structured_output_model: Pydantic model type for structured output.
507525
**kwargs: Keyword arguments allowing backward compatible future changes.
508526
509527
Yields:
@@ -546,7 +564,9 @@ async def stream_async(
546564
self.node_timeout or "None",
547565
)
548566

549-
async for event in self._execute_graph(invocation_state):
567+
async for event in self._execute_graph(
568+
invocation_state, structured_output_model=structured_output_model
569+
):
550570
yield event.as_dict()
551571

552572
# Set final status based on execution results
@@ -585,7 +605,9 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
585605
# Validate Agent-specific constraints for each node
586606
_validate_node_executor(node.executor)
587607

588-
async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
608+
async def _execute_graph(
609+
self, invocation_state: dict[str, Any], structured_output_model: Type[BaseModel] | None = None
610+
) -> AsyncIterator[Any]:
589611
"""Execute graph and yield TypedEvent objects."""
590612
ready_nodes = self._resume_next_nodes if self._resume_from_session else list(self.entry_points)
591613

@@ -604,7 +626,9 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterato
604626
ready_nodes.clear()
605627

606628
# Execute current batch
607-
async for event in self._execute_nodes_parallel(current_batch, invocation_state):
629+
async for event in self._execute_nodes_parallel(
630+
current_batch, invocation_state, structured_output_model=structured_output_model
631+
):
608632
yield event
609633

610634
# Find newly ready nodes after batch execution
@@ -628,7 +652,10 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterato
628652
ready_nodes.extend(newly_ready)
629653

630654
async def _execute_nodes_parallel(
631-
self, nodes: list["GraphNode"], invocation_state: dict[str, Any]
655+
self,
656+
nodes: list["GraphNode"],
657+
invocation_state: dict[str, Any],
658+
structured_output_model: Type[BaseModel] | None = None,
632659
) -> AsyncIterator[Any]:
633660
"""Execute multiple nodes in parallel and merge their event streams in real-time.
634661
@@ -638,7 +665,14 @@ async def _execute_nodes_parallel(
638665
event_queue: asyncio.Queue[Any | None | Exception] = asyncio.Queue()
639666

640667
# Start all node streams as independent tasks
641-
tasks = [asyncio.create_task(self._stream_node_to_queue(node, event_queue, invocation_state)) for node in nodes]
668+
tasks = [
669+
asyncio.create_task(
670+
self._stream_node_to_queue(
671+
node, event_queue, invocation_state, structured_output_model=structured_output_model
672+
)
673+
)
674+
for node in nodes
675+
]
642676

643677
try:
644678
# Consume events from the queue as they arrive
@@ -689,14 +723,17 @@ async def _stream_node_to_queue(
689723
node: GraphNode,
690724
event_queue: asyncio.Queue[Any | None | Exception],
691725
invocation_state: dict[str, Any],
726+
structured_output_model: Type[BaseModel] | None = None,
692727
) -> None:
693728
"""Stream events from a node to the shared queue with optional timeout."""
694729
try:
695730
# Apply timeout to the entire streaming process if configured
696731
if self.node_timeout is not None:
697732

698733
async def stream_node() -> None:
699-
async for event in self._execute_node(node, invocation_state):
734+
async for event in self._execute_node(
735+
node, invocation_state, structured_output_model=structured_output_model
736+
):
700737
await event_queue.put(event)
701738

702739
try:
@@ -707,7 +744,9 @@ async def stream_node() -> None:
707744
await event_queue.put(timeout_exc)
708745
else:
709746
# No timeout - stream normally
710-
async for event in self._execute_node(node, invocation_state):
747+
async for event in self._execute_node(
748+
node, invocation_state, structured_output_model=structured_output_model
749+
):
711750
await event_queue.put(event)
712751
except Exception as e:
713752
# Send exception through queue for fail-fast behavior
@@ -774,7 +813,12 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[
774813
)
775814
return False
776815

777-
async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
816+
async def _execute_node(
817+
self,
818+
node: GraphNode,
819+
invocation_state: dict[str, Any],
820+
structured_output_model: Type[BaseModel] | None = None,
821+
) -> AsyncIterator[Any]:
778822
"""Execute a single node and yield TypedEvent objects."""
779823
await self.hooks.invoke_callbacks_async(BeforeNodeCallEvent(self, node.node_id, invocation_state))
780824

@@ -802,7 +846,9 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
802846
if isinstance(node.executor, MultiAgentBase):
803847
# For nested multi-agent systems, stream their events and collect result
804848
multi_agent_result = None
805-
async for event in node.executor.stream_async(node_input, invocation_state):
849+
async for event in node.executor.stream_async(
850+
node_input, invocation_state, structured_output_model=structured_output_model
851+
):
806852
# Forward nested multi-agent events with node context
807853
wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event)
808854
yield wrapped_event
@@ -824,9 +870,20 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
824870
)
825871

826872
elif isinstance(node.executor, Agent):
873+
# For agents, use agent's default structured_output_model if available,
874+
# otherwise use the graph-level one
875+
agent_structured_output_model = structured_output_model
876+
if (
877+
hasattr(node.executor, "_default_structured_output_model")
878+
and node.executor._default_structured_output_model is not None
879+
):
880+
agent_structured_output_model = node.executor._default_structured_output_model
881+
827882
# For agents, stream their events and collect result
828883
agent_response = None
829-
async for event in node.executor.stream_async(node_input, invocation_state=invocation_state):
884+
async for event in node.executor.stream_async(
885+
node_input, invocation_state=invocation_state, structured_output_model=agent_structured_output_model
886+
):
830887
# Forward agent events with node context
831888
wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event)
832889
yield wrapped_event

0 commit comments

Comments
 (0)