Skip to content

Commit 55ef47c

Browse files
committed
fix: add BeforeMultiAgentInvocation events in Graph/Swarm, add more tests
1 parent 80a4e54 commit 55ef47c

File tree

5 files changed

+91
-21
lines changed

5 files changed

+91
-21
lines changed

src/strands/multiagent/graph.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from ..experimental.hooks.multiagent import (
3030
AfterMultiAgentInvocationEvent,
3131
AfterNodeCallEvent,
32+
BeforeMultiAgentInvocationEvent,
3233
BeforeNodeCallEvent,
3334
MultiAgentInitializedEvent,
3435
)
@@ -468,6 +469,7 @@ def __call__(
468469
if invocation_state is None:
469470
invocation_state = {}
470471

472+
self.hooks.invoke_callbacks(BeforeMultiAgentInvocationEvent(self, invocation_state))
471473
return run_async(lambda: self.invoke_async(task, invocation_state))
472474

473475
async def invoke_async(

src/strands/multiagent/swarm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from ..experimental.hooks.multiagent import (
2929
AfterMultiAgentInvocationEvent,
3030
AfterNodeCallEvent,
31+
BeforeMultiAgentInvocationEvent,
3132
BeforeNodeCallEvent,
3233
MultiAgentInitializedEvent,
3334
)
@@ -287,7 +288,7 @@ def __call__(
287288
"""
288289
if invocation_state is None:
289290
invocation_state = {}
290-
291+
self.hooks.invoke_callbacks(BeforeMultiAgentInvocationEvent(self, invocation_state))
291292
return run_async(lambda: self.invoke_async(task, invocation_state))
292293

293294
async def invoke_async(

tests/strands/experimental/hooks/multiagent/test_multi_agent_hooks.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from strands.experimental.hooks.multiagent.events import (
55
AfterMultiAgentInvocationEvent,
66
AfterNodeCallEvent,
7+
BeforeMultiAgentInvocationEvent,
78
BeforeNodeCallEvent,
89
MultiAgentInitializedEvent,
910
)
@@ -17,6 +18,7 @@
1718
def hook_provider():
1819
return MockMultiAgentHookProvider(
1920
[
21+
BeforeMultiAgentInvocationEvent,
2022
AfterMultiAgentInvocationEvent,
2123
AfterNodeCallEvent,
2224
BeforeNodeCallEvent,
@@ -67,7 +69,7 @@ def test_swarm_complete_hook_lifecycle(swarm, hook_provider):
6769
result = swarm("test task")
6870

6971
length, events = hook_provider.get_events()
70-
assert length == 4
72+
assert length == 5
7173
assert result.status.value == "completed"
7274

7375
events_list = list(events)
@@ -76,24 +78,27 @@ def test_swarm_complete_hook_lifecycle(swarm, hook_provider):
7678
assert isinstance(events_list[0], MultiAgentInitializedEvent)
7779
assert events_list[0].source == swarm
7880

79-
assert isinstance(events_list[1], BeforeNodeCallEvent)
81+
assert isinstance(events_list[1], BeforeMultiAgentInvocationEvent)
8082
assert events_list[1].source == swarm
81-
assert events_list[1].node_id == "agent1"
8283

83-
assert isinstance(events_list[2], AfterNodeCallEvent)
84+
assert isinstance(events_list[2], BeforeNodeCallEvent)
8485
assert events_list[2].source == swarm
8586
assert events_list[2].node_id == "agent1"
8687

87-
assert isinstance(events_list[3], AfterMultiAgentInvocationEvent)
88+
assert isinstance(events_list[3], AfterNodeCallEvent)
8889
assert events_list[3].source == swarm
90+
assert events_list[3].node_id == "agent1"
91+
92+
assert isinstance(events_list[4], AfterMultiAgentInvocationEvent)
93+
assert events_list[4].source == swarm
8994

9095

9196
def test_graph_complete_hook_lifecycle(graph, hook_provider):
9297
"""E2E test verifying complete hook lifecycle for Graph."""
9398
result = graph("test task")
9499

95100
length, events = hook_provider.get_events()
96-
assert length == 6
101+
assert length == 7
97102
assert result.status.value == "completed"
98103

99104
events_list = list(events)
@@ -102,21 +107,24 @@ def test_graph_complete_hook_lifecycle(graph, hook_provider):
102107
assert isinstance(events_list[0], MultiAgentInitializedEvent)
103108
assert events_list[0].source == graph
104109

105-
assert isinstance(events_list[1], BeforeNodeCallEvent)
110+
assert isinstance(events_list[1], BeforeMultiAgentInvocationEvent)
106111
assert events_list[1].source == graph
107-
assert events_list[1].node_id == "agent1"
108112

109-
assert isinstance(events_list[2], AfterNodeCallEvent)
113+
assert isinstance(events_list[2], BeforeNodeCallEvent)
110114
assert events_list[2].source == graph
111115
assert events_list[2].node_id == "agent1"
112116

113-
assert isinstance(events_list[3], BeforeNodeCallEvent)
117+
assert isinstance(events_list[3], AfterNodeCallEvent)
114118
assert events_list[3].source == graph
115-
assert events_list[3].node_id == "agent2"
119+
assert events_list[3].node_id == "agent1"
116120

117-
assert isinstance(events_list[4], AfterNodeCallEvent)
121+
assert isinstance(events_list[4], BeforeNodeCallEvent)
118122
assert events_list[4].source == graph
119123
assert events_list[4].node_id == "agent2"
120124

121-
assert isinstance(events_list[5], AfterMultiAgentInvocationEvent)
125+
assert isinstance(events_list[5], AfterNodeCallEvent)
122126
assert events_list[5].source == graph
127+
assert events_list[5].node_id == "agent2"
128+
129+
assert isinstance(events_list[6], AfterMultiAgentInvocationEvent)
130+
assert events_list[6].source == graph

tests_integ/test_multiagent_graph.py

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -488,13 +488,14 @@ async def test_graph_interrupt_and_resume():
488488
graph = builder.build()
489489

490490
# Mock agent2 to fail on first execution
491-
async def failing_invoke(*args, **kwargs):
491+
async def failing_stream_async(*args, **kwargs):
492492
raise Exception("Simulated failure in agent2")
493+
yield # This line is never reached, but makes it an async generator
493494

494-
with patch.object(agent2, "invoke_async", side_effect=failing_invoke):
495-
# First execution - should fail at agent2
495+
with patch.object(agent2, "stream_async", side_effect=failing_stream_async):
496496
try:
497-
await graph.invoke_async("Test task")
497+
await graph.invoke_async("This is a test task, just do it shortly")
498+
raise AssertionError("Expected exception was not raised")
498499
except Exception as e:
499500
assert "Simulated failure in agent2" in str(e)
500501

@@ -530,3 +531,57 @@ async def failing_invoke(*args, **kwargs):
530531

531532
# Clean up
532533
session_manager.delete_session(session_id)
534+
535+
536+
@pytest.mark.asyncio
537+
async def test_self_loop_resume_from_persisted_state(tmp_path):
538+
"""Test resuming self-loop from persisted state where next node is itself."""
539+
540+
session_id = f"self_loop_resume_{uuid4()}"
541+
session_manager = FileSessionManager(session_id=session_id, storage_dir=str(tmp_path))
542+
543+
counter_agent = Agent(
544+
model="us.amazon.nova-pro-v1:0",
545+
system_prompt="You are a counter. Just respond with 'Count: 1', 'Count: 2', Stop at 5.",
546+
)
547+
548+
def should_continue_loop(state):
549+
loop_executions = len([node for node in state.execution_order if node.node_id == "loop_node"])
550+
return loop_executions < 5
551+
552+
builder = GraphBuilder()
553+
builder.add_node(counter_agent, "loop_node")
554+
builder.add_edge("loop_node", "loop_node", condition=should_continue_loop)
555+
builder.set_entry_point("loop_node")
556+
builder.set_session_manager(session_manager)
557+
builder.reset_on_revisit(True)
558+
559+
graph = builder.build()
560+
561+
call_count = 0
562+
original_stream = counter_agent.stream_async
563+
564+
async def failing_after_two(*args, **kwargs):
565+
nonlocal call_count
566+
call_count += 1
567+
if call_count <= 2:
568+
async for event in original_stream(*args, **kwargs):
569+
yield event
570+
else:
571+
raise Exception("Simulated failure after two executions")
572+
573+
with patch.object(counter_agent, "stream_async", side_effect=failing_after_two):
574+
try:
575+
await graph.invoke_async("Count till 5")
576+
except Exception as e:
577+
assert "Simulated failure after two executions" in str(e)
578+
579+
persisted_state = session_manager.read_multi_agent(session_id, graph.id)
580+
assert persisted_state["status"] == "failed"
581+
assert "loop_node" in persisted_state.get("failed_nodes")
582+
assert len(persisted_state.get("execution_order")) == 2
583+
584+
result = await graph.invoke_async("Continue counting to 5")
585+
assert result.status == Status.COMPLETED
586+
assert len(result.execution_order) == 5
587+
assert all(node.node_id == "loop_node" for node in result.execution_order)

tests_integ/test_multiagent_swarm.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -337,14 +337,18 @@ async def test_swarm_interrupt_and_resume(researcher_agent, analyst_agent, write
337337
# Create swarm with session manager
338338
swarm = Swarm([researcher_agent, analyst_agent, writer_agent], session_manager=session_manager)
339339

340-
# Mock analyst_agent to fail
340+
# Mock analyst_agent's _invoke method to fail
341341
async def failing_invoke(*args, **kwargs):
342342
raise Exception("Simulated failure in analyst")
343+
yield # This line is never reached, but makes it an async generator
343344

344-
with patch.object(analyst_agent, "invoke_async", side_effect=failing_invoke):
345+
with patch.object(analyst_agent, "stream_async", side_effect=failing_invoke):
345346
# First execution - should fail at analyst
346347
result = await swarm.invoke_async("Research AI trends and create a brief report")
347-
assert result.status == Status.FAILED
348+
try:
349+
assert result.status == Status.FAILED
350+
except Exception as e:
351+
assert "Simulated failure in analyst" in str(e)
348352

349353
# Verify partial execution was persisted
350354
persisted_state = session_manager.read_multi_agent(session_id, swarm.id)

0 commit comments

Comments
 (0)