Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
6c00bbe
feat(multiagent): Add stream async
Oct 2, 2025
b09b539
Merge branch 'main' into multiagent-streaming
Oct 2, 2025
08141a0
fix(graph): improve parallel node calling
Oct 2, 2025
d4f5571
fix: Fix double execution
Oct 2, 2025
fc0a272
fix: improve graph timeout
Oct 3, 2025
ca59221
Merge branch 'main' into multiagent-streaming
Oct 3, 2025
60f16b9
fix: Add integ tests
Oct 3, 2025
a307f37
refactor(multiagent): improve streaming event handling and documentation
Oct 10, 2025
24502fc
fix(multiagent): remove no-op asyncio.gather in parallel execution
Oct 10, 2025
dd5445a
refactor: Fix streaming timeout logic
Oct 13, 2025
050c369
refactor: rename result to multiagent_result
Oct 13, 2025
defb5e5
refactor: simplify timeout logic
Oct 13, 2025
0b49c15
refactor: exception handling in graphs
Oct 14, 2025
6b64254
refactor: use alist in tests
Oct 14, 2025
d035654
Merge branch 'main' into multiagent-streaming
Oct 14, 2025
f018ea0
feat(multiagent): add type details to result events
Oct 14, 2025
3df5ee3
refactor: include node result in node complete event
Oct 14, 2025
19c93cc
refactor: change node complete to node stop
Oct 14, 2025
cd583ad
fix: fix failing integ tests
Oct 14, 2025
d97e5f4
refactor: address pr comments
Oct 17, 2025
45a1ee1
refactor: update multiagent types to use type key and update handoff …
Oct 17, 2025
01cb874
refactor: address comments
Oct 17, 2025
68d0b96
refactor: simplify integ tests
Oct 17, 2025
fb670ba
refactor: revert agent result changes
Oct 17, 2025
7f34e2d
refactor: update handoff event to use ids
Oct 17, 2025
7bede48
fix: remove comment
Oct 17, 2025
ff5bec8
chore: Merge main
Oct 17, 2025
012ef4a
Merge branch 'main' into multiagent-streaming
Oct 29, 2025
82fdd4f
test: improve test coverage
Oct 29, 2025
053f483
Merge branch 'main' into multiagent-streaming
mkmeral Oct 30, 2025
dbf1bf8
fix: add none check to fix mypy errors
Oct 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,8 +605,9 @@ async def stream_async(
yield as_dict

result = AgentResult(*event["stop"])
callback_handler(result=result)
yield AgentResultEvent(result=result).as_dict()
result_event = AgentResultEvent(result=result)
callback_handler(**result_event.as_dict())
yield result_event.as_dict()

self._end_agent_trace_span(response=result)

Expand Down
28 changes: 27 additions & 1 deletion src/strands/multiagent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Union
from typing import Any, AsyncIterator, Union

from ..agent import AgentResult
from ..types.content import ContentBlock
Expand Down Expand Up @@ -98,6 +98,32 @@ async def invoke_async(
"""
raise NotImplementedError("invoke_async not implemented")

async def stream_async(
    self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> AsyncIterator[dict[str, Any]]:
    """Run the multi-agent task to completion and emit its outcome as one event.

    This is the fallback used by orchestrators that do not implement real
    streaming: it awaits ``invoke_async`` and then yields exactly one
    dictionary holding the final result. Subclasses override this method
    to emit events incrementally while the work is in progress.

    Args:
        task: The task to execute.
        invocation_state: Additional state/context passed to underlying agents.
            Forwarded unchanged (including ``None``) to ``invoke_async``.
        **kwargs: Additional keyword arguments forwarded to ``invoke_async``.

    Yields:
        A single ``{"result": <value returned by invoke_async>}`` dictionary.
        Overriding implementations may additionally yield coordination
        events (node start/stop, handoffs) and forwarded single-agent
        events before the final result event.
    """
    # Nothing is emitted until invoke_async has finished; the whole run is
    # reported as a single terminal event.
    final_result = await self.invoke_async(task, invocation_state, **kwargs)
    result_event: dict[str, Any] = {"result": final_result}
    yield result_event

def __call__(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> MultiAgentResult:
Expand Down
323 changes: 242 additions & 81 deletions src/strands/multiagent/graph.py

Large diffs are not rendered by default.

206 changes: 164 additions & 42 deletions src/strands/multiagent/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,21 @@
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Any, Callable, Tuple
from typing import Any, AsyncIterator, Callable, Tuple, cast

from opentelemetry import trace as trace_api

from ..agent import Agent, AgentResult
from ..agent import Agent
from ..agent.state import AgentState
from ..telemetry import get_tracer
from ..tools.decorator import tool
from ..types._events import (
MultiAgentHandoffEvent,
MultiAgentNodeStartEvent,
MultiAgentNodeStopEvent,
MultiAgentNodeStreamEvent,
MultiAgentResultEvent,
)
from ..types.content import ContentBlock, Messages
from ..types.event_loop import Metrics, Usage
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
Expand Down Expand Up @@ -266,12 +273,43 @@ async def invoke_async(
) -> SwarmResult:
"""Invoke the swarm asynchronously.

This method uses stream_async internally and consumes all events until completion,
following the same pattern as the Agent class.

Args:
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues - a new empty dict
is created if None is provided.
Defaults to None to avoid mutable default argument issues.
**kwargs: Keyword arguments allowing backward compatible future changes.
"""
events = self.stream_async(task, invocation_state, **kwargs)
final_event = None
async for event in events:
final_event = event

if final_event is None or "result" not in final_event:
raise ValueError("Swarm streaming completed without producing a result event")

return cast(SwarmResult, final_event["result"])

async def stream_async(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> AsyncIterator[dict[str, Any]]:
"""Stream events during swarm execution.

Args:
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues.
**kwargs: Keyword arguments allowing backward compatible future changes.

Yields:
Dictionary events during swarm execution, such as:
- multi_agent_node_start: When a node begins execution
- multi_agent_node_stream: Forwarded agent events with node context
- multi_agent_handoff: When control is handed off between agents
- multi_agent_node_stop: When a node stops execution
- result: Final swarm result
"""
if invocation_state is None:
invocation_state = {}
Expand All @@ -282,7 +320,7 @@ async def invoke_async(
if self.entry_point:
initial_node = self.nodes[str(self.entry_point.name)]
else:
initial_node = next(iter(self.nodes.values())) # First SwarmNode
initial_node = next(iter(self.nodes.values()))

self.state = SwarmState(
current_node=initial_node,
Expand All @@ -303,15 +341,65 @@ async def invoke_async(
self.execution_timeout,
)

await self._execute_swarm(invocation_state)
async for event in self._execute_swarm(invocation_state):
yield event.as_dict()

# Set execution time before building result
self.state.execution_time = round((time.time() - start_time) * 1000)

# Yield final result (consistent with Agent's AgentResultEvent format)
result = self._build_result()
yield MultiAgentResultEvent(result=result).as_dict()

except Exception:
logger.exception("swarm execution failed")
self.state.completion_status = Status.FAILED
raise
finally:
# Set execution time even on failure
self.state.execution_time = round((time.time() - start_time) * 1000)
raise

return self._build_result()
async def _stream_with_timeout(
self, async_generator: AsyncIterator[Any], timeout: float | None, timeout_message: str
) -> AsyncIterator[Any]:
"""Wrap an async generator with timeout for total execution time.

Tracks elapsed time from start and enforces timeout across all events.
Each event wait uses remaining time from the total timeout budget.

Args:
async_generator: The generator to wrap
timeout: Total timeout in seconds for entire stream, or None for no timeout
timeout_message: Message to include in timeout exception

Yields:
Events from the wrapped generator as they arrive

Raises:
Exception: If total execution time exceeds timeout
"""
if timeout is None:
# No timeout - just pass through
async for event in async_generator:
yield event
else:
# Track start time for total timeout
start_time = asyncio.get_event_loop().time()

while True:
# Calculate remaining time from total timeout budget
elapsed = asyncio.get_event_loop().time() - start_time
remaining = timeout - elapsed

if remaining <= 0:
raise Exception(timeout_message)

try:
event = await asyncio.wait_for(async_generator.__anext__(), timeout=remaining)
yield event
except StopAsyncIteration:
break
except asyncio.TimeoutError:
raise Exception(timeout_message) from None

def _setup_swarm(self, nodes: list[Agent]) -> None:
"""Initialize swarm configuration."""
Expand Down Expand Up @@ -533,14 +621,14 @@ def _build_node_input(self, target_node: SwarmNode) -> str:

return context_text

async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None:
"""Shared execution logic used by execute_async."""
async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
"""Execute swarm and yield TypedEvent objects."""
try:
# Main execution loop
while True:
if self.state.completion_status != Status.EXECUTING:
reason = f"Completion status is: {self.state.completion_status}"
logger.debug("reason=<%s> | stopping execution", reason)
logger.debug("reason=<%s> | stopping streaming execution", reason)
break

should_continue, reason = self.state.should_continue(
Expand Down Expand Up @@ -568,34 +656,44 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None:
len(self.state.node_history) + 1,
)

# Store the current node before execution to detect handoffs
previous_node = current_node

# Execute node with timeout protection
# TODO: Implement cancellation token to stop _execute_node from continuing
try:
await asyncio.wait_for(
# Execute with timeout wrapper for async generator streaming
node_stream = self._stream_with_timeout(
self._execute_node(current_node, self.state.task, invocation_state),
timeout=self.node_timeout,
self.node_timeout,
f"Node '{current_node.node_id}' execution timed out after {self.node_timeout}s",
)
async for event in node_stream:
yield event

self.state.node_history.append(current_node)

logger.debug("node=<%s> | node execution completed", current_node.node_id)

# Check if the current node is still the same after execution
# If it is, then no handoff occurred and we consider the swarm complete
if self.state.current_node == current_node:
# Check if handoff occurred during execution
if self.state.current_node != previous_node:
# Emit handoff event
handoff_event = MultiAgentHandoffEvent(
from_node=previous_node.node_id,
to_node=self.state.current_node.node_id,
message=self.state.handoff_message or "Agent handoff occurred",
)
yield handoff_event
logger.debug(
"from_node=<%s>, to_node=<%s> | handoff detected",
previous_node.node_id,
self.state.current_node.node_id,
)
else:
# No handoff occurred, mark swarm as complete
logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id)
self.state.completion_status = Status.COMPLETED
break

except asyncio.TimeoutError:
logger.exception(
"node=<%s>, timeout=<%s>s | node execution timed out after timeout",
current_node.node_id,
self.node_timeout,
)
self.state.completion_status = Status.FAILED
break

except Exception:
logger.exception("node=<%s> | node execution failed", current_node.node_id)
self.state.completion_status = Status.FAILED
Expand All @@ -615,11 +713,15 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None:

async def _execute_node(
self, node: SwarmNode, task: str | list[ContentBlock], invocation_state: dict[str, Any]
) -> AgentResult:
"""Execute swarm node."""
) -> AsyncIterator[Any]:
"""Execute swarm node and yield TypedEvent objects."""
start_time = time.time()
node_name = node.node_id

# Emit node start event
start_event = MultiAgentNodeStartEvent(node_id=node_name, node_type="agent")
yield start_event

try:
# Prepare context for node
context_text = self._build_node_input(node)
Expand All @@ -632,21 +734,29 @@ async def _execute_node(
# Include additional ContentBlocks in node input
node_input = node_input + task

# Execute node
result = None
# Execute node with streaming
node.reset_executor_state()
result = await node.executor.invoke_async(node_input, invocation_state=invocation_state)

# Stream agent events with node context and capture final result
result = None
async for event in node.executor.stream_async(node_input, invocation_state=invocation_state):
# Forward agent events with node context
wrapped_event = MultiAgentNodeStreamEvent(node_name, event)
yield wrapped_event
# Capture the final result event
if "result" in event:
result = event["result"]

# Use the captured result from streaming to avoid double execution
if result is None:
raise ValueError(f"Node '{node_name}' did not produce a result event")

execution_time = round((time.time() - start_time) * 1000)

# Create NodeResult
usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0)
metrics = Metrics(latencyMs=execution_time)
if hasattr(result, "metrics") and result.metrics:
if hasattr(result.metrics, "accumulated_usage"):
usage = result.metrics.accumulated_usage
if hasattr(result.metrics, "accumulated_metrics"):
metrics = result.metrics.accumulated_metrics
# Create NodeResult with extracted metrics
result_metrics = getattr(result, "metrics", None)
usage = getattr(result_metrics, "accumulated_usage", Usage(inputTokens=0, outputTokens=0, totalTokens=0))
metrics = getattr(result_metrics, "accumulated_metrics", Metrics(latencyMs=execution_time))

node_result = NodeResult(
result=result,
Expand All @@ -663,15 +773,20 @@ async def _execute_node(
# Accumulate metrics
self._accumulate_metrics(node_result)

return result
# Emit node stop event with full NodeResult
complete_event = MultiAgentNodeStopEvent(
node_id=node_name,
node_result=node_result,
)
yield complete_event

except Exception as e:
execution_time = round((time.time() - start_time) * 1000)
logger.exception("node=<%s> | node execution failed", node_name)

# Create a NodeResult for the failed node
node_result = NodeResult(
result=e, # Store exception as result
result=e,
execution_time=execution_time,
status=Status.FAILED,
accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0),
Expand All @@ -682,6 +797,13 @@ async def _execute_node(
# Store result in state
self.state.results[node_name] = node_result

# Emit node stop event even for failures
complete_event = MultiAgentNodeStopEvent(
node_id=node_name,
node_result=node_result,
)
yield complete_event

raise

def _accumulate_metrics(self, node_result: NodeResult) -> None:
Expand Down
Loading