@@ -156,6 +156,7 @@ class SwarmState:
156156 # Total metrics across all agents
157157 accumulated_metrics : Metrics = field (default_factory = lambda : Metrics (latencyMs = 0 ))
158158 execution_time : int = 0 # Total execution time in milliseconds
159+ handoff_node : SwarmNode | None = None # The agent to execute next
159160 handoff_message : str | None = None # Message passed during agent handoff
160161
161162 def should_continue (
@@ -537,7 +538,7 @@ def handoff_to_agent(agent_name: str, message: str, context: dict[str, Any] | No
537538 # Execute handoff
538539 swarm_ref ._handle_handoff (target_node , message , context )
539540
540- return {"status" : "success" , "content" : [{"text" : f"Handed off to { agent_name } : { message } " }]}
541+ return {"status" : "success" , "content" : [{"text" : f"Handing off to { agent_name } : { message } " }]}
541542 except Exception as e :
542543 return {"status" : "error" , "content" : [{"text" : f"Error in handoff: { str (e )} " }]}
543544
@@ -553,21 +554,19 @@ def _handle_handoff(self, target_node: SwarmNode, message: str, context: dict[st
553554 )
554555 return
555556
556- # Update swarm state
557- previous_agent = cast (SwarmNode , self .state .current_node )
558- self .state .current_node = target_node
557+ current_node = cast (SwarmNode , self .state .current_node )
559558
560- # Store handoff message for the target agent
559+ self . state . handoff_node = target_node
561560 self .state .handoff_message = message
562561
563562 # Store handoff context as shared context
564563 if context :
565564 for key , value in context .items ():
566- self .shared_context .add_context (previous_agent , key , value )
565+ self .shared_context .add_context (current_node , key , value )
567566
568567 logger .debug (
569- "from_node=<%s>, to_node=<%s> | handed off from agent to agent" ,
570- previous_agent .node_id ,
568+ "from_node=<%s>, to_node=<%s> | handing off from agent to agent" ,
569+ current_node .node_id ,
571570 target_node .node_id ,
572571 )
573572
@@ -667,7 +666,6 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
667666 logger .debug ("reason=<%s> | stopping execution" , reason )
668667 break
669668
670- # Get current node
671669 current_node = self .state .current_node
672670 if not current_node or current_node .node_id not in self .nodes :
673671 logger .error ("node=<%s> | node not found" , current_node .node_id if current_node else "None" )
@@ -680,14 +678,10 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
680678 len (self .state .node_history ) + 1 ,
681679 )
682680
683- # Store the current node before execution to detect handoffs
684- previous_node = current_node
685-
686- # Execute node with timeout protection
687681 # TODO: Implement cancellation token to stop _execute_node from continuing
688682 try :
689- # Execute with timeout wrapper for async generator streaming
690683 self .hooks .invoke_callbacks (BeforeNodeCallEvent (self , current_node .node_id , invocation_state ))
684+
691685 node_stream = self ._stream_with_timeout (
692686 self ._execute_node (current_node , self .state .task , invocation_state ),
693687 self .node_timeout ,
@@ -697,28 +691,31 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
697691 yield event
698692
699693 self .state .node_history .append (current_node )
700-
701- # After self.state add current node, swarm state finish updating, we persist here
702694 self .hooks .invoke_callbacks (AfterNodeCallEvent (self , current_node .node_id , invocation_state ))
703695
704696 logger .debug ("node=<%s> | node execution completed" , current_node .node_id )
705697
706- # Check if handoff occurred during execution
707- if self .state .current_node is not None and self .state .current_node != previous_node :
708- # Emit handoff event (single node transition in Swarm)
698+ # Check if handoff requested during execution
699+ if self .state .handoff_node :
700+ previous_node = current_node
701+ current_node = self .state .handoff_node
702+
703+ self .state .handoff_node = None
704+ self .state .current_node = current_node
705+
709706 handoff_event = MultiAgentHandoffEvent (
710707 from_node_ids = [previous_node .node_id ],
711- to_node_ids = [self . state . current_node .node_id ],
708+ to_node_ids = [current_node .node_id ],
712709 message = self .state .handoff_message or "Agent handoff occurred" ,
713710 )
714711 yield handoff_event
715712 logger .debug (
716713 "from_node=<%s>, to_node=<%s> | handoff detected" ,
717714 previous_node .node_id ,
718- self . state . current_node .node_id ,
715+ current_node .node_id ,
719716 )
717+
720718 else :
721- # No handoff occurred, mark swarm as complete
722719 logger .debug ("node=<%s> | no handoff occurred, marking swarm as complete" , current_node .node_id )
723720 self .state .completion_status = Status .COMPLETED
724721 break
0 commit comments