1919import logging
2020import time
2121from dataclasses import dataclass , field
22- from typing import Any , AsyncIterator , Callable , Optional , Tuple , cast
22+ from typing import Any , AsyncIterator , Callable , Optional , Tuple , Type , cast
2323
2424from opentelemetry import trace as trace_api
25+ from pydantic import BaseModel
2526
2627from .._async import run_async
2728from ..agent import Agent
@@ -456,23 +457,34 @@ def __init__(
456457 run_async (lambda : self .hooks .invoke_callbacks_async (MultiAgentInitializedEvent (self )))
457458
458459 def __call__ (
459- self , task : str | list [ContentBlock ], invocation_state : dict [str , Any ] | None = None , ** kwargs : Any
460+ self ,
461+ task : str | list [ContentBlock ],
462+ invocation_state : dict [str , Any ] | None = None ,
463+ structured_output_model : Type [BaseModel ] | None = None ,
464+ ** kwargs : Any ,
460465 ) -> GraphResult :
461466 """Invoke the graph synchronously.
462467
463468 Args:
464469 task: The task to execute
465470 invocation_state: Additional state/context passed to underlying agents.
466471 Defaults to None to avoid mutable default argument issues.
472+ structured_output_model: Pydantic model type for structured output.
467473 **kwargs: Keyword arguments allowing backward compatible future changes.
468474 """
469475 if invocation_state is None :
470476 invocation_state = {}
471477
472- return run_async (lambda : self .invoke_async (task , invocation_state ))
478+ return run_async (
479+ lambda : self .invoke_async (task , invocation_state , structured_output_model = structured_output_model )
480+ )
473481
474482 async def invoke_async (
475- self , task : str | list [ContentBlock ], invocation_state : dict [str , Any ] | None = None , ** kwargs : Any
483+ self ,
484+ task : str | list [ContentBlock ],
485+ invocation_state : dict [str , Any ] | None = None ,
486+ structured_output_model : Type [BaseModel ] | None = None ,
487+ ** kwargs : Any ,
476488 ) -> GraphResult :
477489 """Invoke the graph asynchronously.
478490
@@ -483,9 +495,10 @@ async def invoke_async(
483495 task: The task to execute
484496 invocation_state: Additional state/context passed to underlying agents.
485497 Defaults to None to avoid mutable default argument issues.
498+ structured_output_model: Pydantic model type for structured output.
486499 **kwargs: Keyword arguments allowing backward compatible future changes.
487500 """
488- events = self .stream_async (task , invocation_state , ** kwargs )
501+ events = self .stream_async (task , invocation_state , structured_output_model = structured_output_model , ** kwargs )
489502 final_event = None
490503 async for event in events :
491504 final_event = event
@@ -496,14 +509,19 @@ async def invoke_async(
496509 return cast (GraphResult , final_event ["result" ])
497510
498511 async def stream_async (
499- self , task : str | list [ContentBlock ], invocation_state : dict [str , Any ] | None = None , ** kwargs : Any
512+ self ,
513+ task : str | list [ContentBlock ],
514+ invocation_state : dict [str , Any ] | None = None ,
515+ structured_output_model : Type [BaseModel ] | None = None ,
516+ ** kwargs : Any ,
500517 ) -> AsyncIterator [dict [str , Any ]]:
501518 """Stream events during graph execution.
502519
503520 Args:
504521 task: The task to execute
505522 invocation_state: Additional state/context passed to underlying agents.
506523 Defaults to None to avoid mutable default argument issues.
524+ structured_output_model: Pydantic model type for structured output.
507525 **kwargs: Keyword arguments allowing backward compatible future changes.
508526
509527 Yields:
@@ -546,7 +564,9 @@ async def stream_async(
546564 self .node_timeout or "None" ,
547565 )
548566
549- async for event in self ._execute_graph (invocation_state ):
567+ async for event in self ._execute_graph (
568+ invocation_state , structured_output_model = structured_output_model
569+ ):
550570 yield event .as_dict ()
551571
552572 # Set final status based on execution results
@@ -585,7 +605,9 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
585605 # Validate Agent-specific constraints for each node
586606 _validate_node_executor (node .executor )
587607
588- async def _execute_graph (self , invocation_state : dict [str , Any ]) -> AsyncIterator [Any ]:
608+ async def _execute_graph (
609+ self , invocation_state : dict [str , Any ], structured_output_model : Type [BaseModel ] | None = None
610+ ) -> AsyncIterator [Any ]:
589611 """Execute graph and yield TypedEvent objects."""
590612 ready_nodes = self ._resume_next_nodes if self ._resume_from_session else list (self .entry_points )
591613
@@ -604,7 +626,9 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterato
604626 ready_nodes .clear ()
605627
606628 # Execute current batch
607- async for event in self ._execute_nodes_parallel (current_batch , invocation_state ):
629+ async for event in self ._execute_nodes_parallel (
630+ current_batch , invocation_state , structured_output_model = structured_output_model
631+ ):
608632 yield event
609633
610634 # Find newly ready nodes after batch execution
@@ -628,7 +652,10 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterato
628652 ready_nodes .extend (newly_ready )
629653
630654 async def _execute_nodes_parallel (
631- self , nodes : list ["GraphNode" ], invocation_state : dict [str , Any ]
655+ self ,
656+ nodes : list ["GraphNode" ],
657+ invocation_state : dict [str , Any ],
658+ structured_output_model : Type [BaseModel ] | None = None ,
632659 ) -> AsyncIterator [Any ]:
633660 """Execute multiple nodes in parallel and merge their event streams in real-time.
634661
@@ -638,7 +665,14 @@ async def _execute_nodes_parallel(
638665 event_queue : asyncio .Queue [Any | None | Exception ] = asyncio .Queue ()
639666
640667 # Start all node streams as independent tasks
641- tasks = [asyncio .create_task (self ._stream_node_to_queue (node , event_queue , invocation_state )) for node in nodes ]
668+ tasks = [
669+ asyncio .create_task (
670+ self ._stream_node_to_queue (
671+ node , event_queue , invocation_state , structured_output_model = structured_output_model
672+ )
673+ )
674+ for node in nodes
675+ ]
642676
643677 try :
644678 # Consume events from the queue as they arrive
@@ -689,14 +723,17 @@ async def _stream_node_to_queue(
689723 node : GraphNode ,
690724 event_queue : asyncio .Queue [Any | None | Exception ],
691725 invocation_state : dict [str , Any ],
726+ structured_output_model : Type [BaseModel ] | None = None ,
692727 ) -> None :
693728 """Stream events from a node to the shared queue with optional timeout."""
694729 try :
695730 # Apply timeout to the entire streaming process if configured
696731 if self .node_timeout is not None :
697732
698733 async def stream_node () -> None :
699- async for event in self ._execute_node (node , invocation_state ):
734+ async for event in self ._execute_node (
735+ node , invocation_state , structured_output_model = structured_output_model
736+ ):
700737 await event_queue .put (event )
701738
702739 try :
@@ -707,7 +744,9 @@ async def stream_node() -> None:
707744 await event_queue .put (timeout_exc )
708745 else :
709746 # No timeout - stream normally
710- async for event in self ._execute_node (node , invocation_state ):
747+ async for event in self ._execute_node (
748+ node , invocation_state , structured_output_model = structured_output_model
749+ ):
711750 await event_queue .put (event )
712751 except Exception as e :
713752 # Send exception through queue for fail-fast behavior
@@ -774,7 +813,12 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[
774813 )
775814 return False
776815
777- async def _execute_node (self , node : GraphNode , invocation_state : dict [str , Any ]) -> AsyncIterator [Any ]:
816+ async def _execute_node (
817+ self ,
818+ node : GraphNode ,
819+ invocation_state : dict [str , Any ],
820+ structured_output_model : Type [BaseModel ] | None = None ,
821+ ) -> AsyncIterator [Any ]:
778822 """Execute a single node and yield TypedEvent objects."""
779823 await self .hooks .invoke_callbacks_async (BeforeNodeCallEvent (self , node .node_id , invocation_state ))
780824
@@ -802,7 +846,9 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
802846 if isinstance (node .executor , MultiAgentBase ):
803847 # For nested multi-agent systems, stream their events and collect result
804848 multi_agent_result = None
805- async for event in node .executor .stream_async (node_input , invocation_state ):
849+ async for event in node .executor .stream_async (
850+ node_input , invocation_state , structured_output_model = structured_output_model
851+ ):
806852 # Forward nested multi-agent events with node context
807853 wrapped_event = MultiAgentNodeStreamEvent (node .node_id , event )
808854 yield wrapped_event
@@ -824,9 +870,20 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
824870 )
825871
826872 elif isinstance (node .executor , Agent ):
873+ # For agents, use agent's default structured_output_model if available,
874+ # otherwise use the graph-level one
875+ agent_structured_output_model = structured_output_model
876+ if (
877+ hasattr (node .executor , "_default_structured_output_model" )
878+ and node .executor ._default_structured_output_model is not None
879+ ):
880+ agent_structured_output_model = node .executor ._default_structured_output_model
881+
827882 # For agents, stream their events and collect result
828883 agent_response = None
829- async for event in node .executor .stream_async (node_input , invocation_state = invocation_state ):
884+ async for event in node .executor .stream_async (
885+ node_input , invocation_state = invocation_state , structured_output_model = agent_structured_output_model
886+ ):
830887 # Forward agent events with node context
831888 wrapped_event = MultiAgentNodeStreamEvent (node .node_id , event )
832889 yield wrapped_event
0 commit comments