Skip to content

Commit 1e58e59

Browse files
committed
swarm - switch to handoff node only after current node stops
1 parent 1df45be commit 1e58e59

File tree

2 files changed

+21
-24
lines changed

2 files changed

+21
-24
lines changed

src/strands/multiagent/swarm.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ class SwarmState:
156156
# Total metrics across all agents
157157
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
158158
execution_time: int = 0 # Total execution time in milliseconds
159+
handoff_node: SwarmNode | None = None # The agent to execute next
159160
handoff_message: str | None = None # Message passed during agent handoff
160161

161162
def should_continue(
@@ -537,7 +538,7 @@ def handoff_to_agent(agent_name: str, message: str, context: dict[str, Any] | No
537538
# Execute handoff
538539
swarm_ref._handle_handoff(target_node, message, context)
539540

540-
return {"status": "success", "content": [{"text": f"Handed off to {agent_name}: {message}"}]}
541+
return {"status": "success", "content": [{"text": f"Handing off to {agent_name}: {message}"}]}
541542
except Exception as e:
542543
return {"status": "error", "content": [{"text": f"Error in handoff: {str(e)}"}]}
543544

@@ -553,21 +554,19 @@ def _handle_handoff(self, target_node: SwarmNode, message: str, context: dict[st
553554
)
554555
return
555556

556-
# Update swarm state
557-
previous_agent = cast(SwarmNode, self.state.current_node)
558-
self.state.current_node = target_node
557+
current_node = cast(SwarmNode, self.state.current_node)
559558

560-
# Store handoff message for the target agent
559+
self.state.handoff_node = target_node
561560
self.state.handoff_message = message
562561

563562
# Store handoff context as shared context
564563
if context:
565564
for key, value in context.items():
566-
self.shared_context.add_context(previous_agent, key, value)
565+
self.shared_context.add_context(current_node, key, value)
567566

568567
logger.debug(
569-
"from_node=<%s>, to_node=<%s> | handed off from agent to agent",
570-
previous_agent.node_id,
568+
"from_node=<%s>, to_node=<%s> | handing off from agent to agent",
569+
current_node.node_id,
571570
target_node.node_id,
572571
)
573572

@@ -667,7 +666,6 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
667666
logger.debug("reason=<%s> | stopping execution", reason)
668667
break
669668

670-
# Get current node
671669
current_node = self.state.current_node
672670
if not current_node or current_node.node_id not in self.nodes:
673671
logger.error("node=<%s> | node not found", current_node.node_id if current_node else "None")
@@ -680,14 +678,10 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
680678
len(self.state.node_history) + 1,
681679
)
682680

683-
# Store the current node before execution to detect handoffs
684-
previous_node = current_node
685-
686-
# Execute node with timeout protection
687681
# TODO: Implement cancellation token to stop _execute_node from continuing
688682
try:
689-
# Execute with timeout wrapper for async generator streaming
690683
self.hooks.invoke_callbacks(BeforeNodeCallEvent(self, current_node.node_id, invocation_state))
684+
691685
node_stream = self._stream_with_timeout(
692686
self._execute_node(current_node, self.state.task, invocation_state),
693687
self.node_timeout,
@@ -697,28 +691,31 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
697691
yield event
698692

699693
self.state.node_history.append(current_node)
700-
701-
# After self.state add current node, swarm state finish updating, we persist here
702694
self.hooks.invoke_callbacks(AfterNodeCallEvent(self, current_node.node_id, invocation_state))
703695

704696
logger.debug("node=<%s> | node execution completed", current_node.node_id)
705697

706-
# Check if handoff occurred during execution
707-
if self.state.current_node is not None and self.state.current_node != previous_node:
708-
# Emit handoff event (single node transition in Swarm)
698+
# Check if handoff requested during execution
699+
if self.state.handoff_node:
700+
previous_node = current_node
701+
current_node = self.state.handoff_node
702+
703+
self.state.handoff_node = None
704+
self.state.current_node = current_node
705+
709706
handoff_event = MultiAgentHandoffEvent(
710707
from_node_ids=[previous_node.node_id],
711-
to_node_ids=[self.state.current_node.node_id],
708+
to_node_ids=[current_node.node_id],
712709
message=self.state.handoff_message or "Agent handoff occurred",
713710
)
714711
yield handoff_event
715712
logger.debug(
716713
"from_node=<%s>, to_node=<%s> | handoff detected",
717714
previous_node.node_id,
718-
self.state.current_node.node_id,
715+
current_node.node_id,
719716
)
717+
720718
else:
721-
# No handoff occurred, mark swarm as complete
722719
logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id)
723720
self.state.completion_status = Status.COMPLETED
724721
break

src/strands/session/session_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from ..experimental.hooks.multiagent.events import (
88
AfterMultiAgentInvocationEvent,
9-
AfterNodeCallEvent,
9+
BeforeNodeCallEvent,
1010
MultiAgentInitializedEvent,
1111
)
1212
from ..hooks.events import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent
@@ -44,7 +44,7 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
4444
registry.add_callback(AfterInvocationEvent, lambda event: self.sync_agent(event.agent))
4545

4646
registry.add_callback(MultiAgentInitializedEvent, lambda event: self.initialize_multi_agent(event.source))
47-
registry.add_callback(AfterNodeCallEvent, lambda event: self.sync_multi_agent(event.source))
47+
registry.add_callback(BeforeNodeCallEvent, lambda event: self.sync_multi_agent(event.source))
4848
registry.add_callback(AfterMultiAgentInvocationEvent, lambda event: self.sync_multi_agent(event.source))
4949

5050
@abstractmethod

0 commit comments

Comments
 (0)