Skip to content

Commit 0c71708

Browse files
committed
multiagent
1 parent a4de3cc commit 0c71708

File tree

9 files changed

+42
-30
lines changed

9 files changed

+42
-30
lines changed

src/strands/agent/agent.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
541541
if not self.messages and not prompt:
542542
raise ValueError("No conversation history or prompt provided")
543543

544-
temp_messages: Messages = self.messages + self._convert_prompt_to_messages(prompt)
544+
temp_messages: Messages = self.messages + await self._convert_prompt_to_messages(prompt)
545545

546546
structured_output_span.set_attributes(
547547
{
@@ -657,7 +657,7 @@ async def stream_async(
657657
callback_handler = kwargs.get("callback_handler", self.callback_handler)
658658

659659
# Process input and get message to add (if any)
660-
messages = self._convert_prompt_to_messages(prompt)
660+
messages = await self._convert_prompt_to_messages(prompt)
661661

662662
self.trace_span = self._start_agent_trace_span(messages)
663663

@@ -812,7 +812,7 @@ async def _execute_event_loop_cycle(
812812
if structured_output_context:
813813
structured_output_context.cleanup(self.tool_registry)
814814

815-
def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
815+
async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
816816
if self._interrupt_state.activated:
817817
return []
818818

@@ -827,7 +827,7 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
827827
tool_use_ids = [
828828
content["toolUse"]["toolUseId"] for content in self.messages[-1]["content"] if "toolUse" in content
829829
]
830-
self._append_message(
830+
await self._append_message(
831831
{
832832
"role": "user",
833833
"content": generate_missing_tool_result_content(tool_use_ids),

src/strands/multiagent/graph.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def __init__(
453453
self._resume_from_session = False
454454
self.id = id
455455

456-
self.hooks.invoke_callbacks(MultiAgentInitializedEvent(self))
456+
run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))
457457

458458
def __call__(
459459
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
@@ -516,7 +516,7 @@ async def stream_async(
516516
if invocation_state is None:
517517
invocation_state = {}
518518

519-
self.hooks.invoke_callbacks(BeforeMultiAgentInvocationEvent(self, invocation_state))
519+
await self.hooks.invoke_callbacks_async(BeforeMultiAgentInvocationEvent(self, invocation_state))
520520

521521
logger.debug("task=<%s> | starting graph execution", task)
522522

@@ -569,7 +569,7 @@ async def stream_async(
569569
raise
570570
finally:
571571
self.state.execution_time = round((time.time() - start_time) * 1000)
572-
self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(self))
572+
await self.hooks.invoke_callbacks_async(AfterMultiAgentInvocationEvent(self))
573573
self._resume_from_session = False
574574
self._resume_next_nodes.clear()
575575

@@ -776,7 +776,7 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[
776776

777777
async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
778778
"""Execute a single node and yield TypedEvent objects."""
779-
self.hooks.invoke_callbacks(BeforeNodeCallEvent(self, node.node_id, invocation_state))
779+
await self.hooks.invoke_callbacks_async(BeforeNodeCallEvent(self, node.node_id, invocation_state))
780780

781781
# Reset the node's state if reset_on_revisit is enabled, and it's being revisited
782782
if self.reset_on_revisit and node in self.state.completed_nodes:
@@ -920,7 +920,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
920920
raise
921921

922922
finally:
923-
self.hooks.invoke_callbacks(AfterNodeCallEvent(self, node.node_id, invocation_state))
923+
await self.hooks.invoke_callbacks_async(AfterNodeCallEvent(self, node.node_id, invocation_state))
924924

925925
def _accumulate_metrics(self, node_result: NodeResult) -> None:
926926
"""Accumulate metrics from a node result."""

src/strands/multiagent/swarm.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def __init__(
273273

274274
self._setup_swarm(nodes)
275275
self._inject_swarm_tools()
276-
self.hooks.invoke_callbacks(MultiAgentInitializedEvent(self))
276+
run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))
277277

278278
def __call__(
279279
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
@@ -336,7 +336,7 @@ async def stream_async(
336336
if invocation_state is None:
337337
invocation_state = {}
338338

339-
self.hooks.invoke_callbacks(BeforeMultiAgentInvocationEvent(self, invocation_state))
339+
await self.hooks.invoke_callbacks_async(BeforeMultiAgentInvocationEvent(self, invocation_state))
340340

341341
logger.debug("starting swarm execution")
342342

@@ -375,7 +375,7 @@ async def stream_async(
375375
raise
376376
finally:
377377
self.state.execution_time = round((time.time() - self.state.start_time) * 1000)
378-
self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(self, invocation_state))
378+
await self.hooks.invoke_callbacks_async(AfterMultiAgentInvocationEvent(self, invocation_state))
379379
self._resume_from_session = False
380380

381381
# Yield final result after execution_time is set
@@ -687,7 +687,9 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
687687
# TODO: Implement cancellation token to stop _execute_node from continuing
688688
try:
689689
# Execute with timeout wrapper for async generator streaming
690-
self.hooks.invoke_callbacks(BeforeNodeCallEvent(self, current_node.node_id, invocation_state))
690+
await self.hooks.invoke_callbacks_async(
691+
BeforeNodeCallEvent(self, current_node.node_id, invocation_state)
692+
)
691693
node_stream = self._stream_with_timeout(
692694
self._execute_node(current_node, self.state.task, invocation_state),
693695
self.node_timeout,
@@ -699,7 +701,9 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
699701
self.state.node_history.append(current_node)
700702

701703
# After self.state add current node, swarm state finish updating, we persist here
702-
self.hooks.invoke_callbacks(AfterNodeCallEvent(self, current_node.node_id, invocation_state))
704+
await self.hooks.invoke_callbacks_async(
705+
AfterNodeCallEvent(self, current_node.node_id, invocation_state)
706+
)
703707

704708
logger.debug("node=<%s> | node execution completed", current_node.node_id)
705709

tests/strands/agent/hooks/test_hook_registry.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,15 +113,16 @@ def test_get_callbacks_for_after_event(hook_registry, after_event):
113113
assert callbacks[1] == callback1 # Reverse order
114114

115115

116-
def test_invoke_callbacks(hook_registry, normal_event):
116+
@pytest.mark.asyncio
117+
async def test_invoke_callbacks(hook_registry, normal_event):
117118
"""Test that invoke_callbacks calls all registered callbacks for an event."""
118119
callback1 = Mock()
119120
callback2 = Mock()
120121

121122
hook_registry.add_callback(NormalTestEvent, callback1)
122123
hook_registry.add_callback(NormalTestEvent, callback2)
123124

124-
hook_registry.invoke_callbacks(normal_event)
125+
await hook_registry.invoke_callbacks_async(normal_event)
125126

126127
callback1.assert_called_once_with(normal_event)
127128
callback2.assert_called_once_with(normal_event)
@@ -134,7 +135,8 @@ def test_invoke_callbacks_no_registered_callbacks(hook_registry, normal_event):
134135
# Test passes if no exception is raised
135136

136137

137-
def test_invoke_callbacks_after_event(hook_registry, after_event):
138+
@pytest.mark.asyncio
139+
async def test_invoke_callbacks_after_event(hook_registry, after_event):
138140
"""Test that invoke_callbacks calls callbacks in reverse order for after events."""
139141
call_order: List[str] = []
140142

@@ -147,7 +149,7 @@ def callback2(_event):
147149
hook_registry.add_callback(AfterTestEvent, callback1)
148150
hook_registry.add_callback(AfterTestEvent, callback2)
149151

150-
hook_registry.invoke_callbacks(after_event)
152+
await hook_registry.invoke_callbacks_async(after_event)
151153

152154
assert call_order == ["callback2", "callback1"] # Reverse order
153155

tests/strands/agent/test_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2240,8 +2240,8 @@ def test_agent_backwards_compatibility_single_text_block():
22402240

22412241
# Should extract text for backwards compatibility
22422242
assert agent.system_prompt == text
2243-
2244-
2243+
2244+
22452245
@pytest.mark.parametrize(
22462246
"content, expected",
22472247
[

tests/strands/event_loop/test_event_loop.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import concurrent
22
import unittest.mock
3-
from unittest.mock import ANY, MagicMock, call, patch
3+
from unittest.mock import ANY, AsyncMock, MagicMock, call, patch
44

55
import pytest
66

@@ -750,6 +750,7 @@ async def test_request_state_initialization(alist):
750750
# not setting this to False results in endless recursion
751751
mock_agent._interrupt_state.activated = False
752752
mock_agent.event_loop_metrics.start_cycle.return_value = (0, MagicMock())
753+
mock_agent.hooks.invoke_callbacks_async = AsyncMock()
753754

754755
# Call without providing request_state
755756
stream = strands.event_loop.event_loop.event_loop_cycle(

tests/strands/event_loop/test_event_loop_structured_output.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Tests for structured output integration in the event loop."""
22

3-
from unittest.mock import Mock, patch
3+
from unittest.mock import AsyncMock, Mock, patch
44

55
import pytest
66
from pydantic import BaseModel
@@ -38,10 +38,10 @@ def mock_agent():
3838
agent.tool_registry = ToolRegistry()
3939
agent.event_loop_metrics = EventLoopMetrics()
4040
agent.hooks = Mock()
41-
agent.hooks.invoke_callbacks = Mock()
41+
agent.hooks.invoke_callbacks_async = AsyncMock()
4242
agent.trace_span = None
4343
agent.tool_executor = Mock()
44-
agent._append_message = Mock()
44+
agent._append_message = AsyncMock()
4545

4646
# Set up _interrupt_state properly
4747
agent._interrupt_state = Mock()

tests/strands/experimental/hooks/test_hook_aliases.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import sys
1010
from unittest.mock import Mock
1111

12+
import pytest
13+
1214
from strands.experimental.hooks import (
1315
AfterModelInvocationEvent,
1416
AfterToolInvocationEvent,
@@ -80,7 +82,8 @@ def test_after_model_call_event_type_equality():
8082
assert isinstance(after_model_event, AfterModelCallEvent)
8183

8284

83-
def test_experimental_aliases_in_hook_registry():
85+
@pytest.mark.asyncio
86+
async def test_experimental_aliases_in_hook_registry():
8487
"""Verify that experimental aliases work with hook registry callbacks."""
8588
hook_registry = HookRegistry()
8689
callback_called = False
@@ -103,7 +106,7 @@ def experimental_callback(event: BeforeToolInvocationEvent):
103106
)
104107

105108
# Invoke callbacks - should work since alias points to same type
106-
hook_registry.invoke_callbacks(test_event)
109+
await hook_registry.invoke_callbacks_async(test_event)
107110

108111
assert callback_called
109112
assert received_event is test_event

tests/strands/hooks/test_registry.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ def agent():
1919
return instance
2020

2121

22-
def test_hook_registry_invoke_callbacks_interrupt(registry, agent):
22+
@pytest.mark.asyncio
23+
async def test_hook_registry_invoke_callbacks_interrupt(registry, agent):
2324
event = BeforeToolCallEvent(
2425
agent=agent,
2526
selected_tool=None,
@@ -35,7 +36,7 @@ def test_hook_registry_invoke_callbacks_interrupt(registry, agent):
3536
registry.add_callback(BeforeToolCallEvent, callback2)
3637
registry.add_callback(BeforeToolCallEvent, callback3)
3738

38-
_, tru_interrupts = registry.invoke_callbacks(event)
39+
_, tru_interrupts = await registry.invoke_callbacks_async(event)
3940
exp_interrupts = [
4041
Interrupt(
4142
id="v1:before_tool_call:test_tool_id:da3551f3-154b-5978-827e-50ac387877ee",
@@ -55,7 +56,8 @@ def test_hook_registry_invoke_callbacks_interrupt(registry, agent):
5556
callback3.assert_called_once_with(event)
5657

5758

58-
def test_hook_registry_invoke_callbacks_interrupt_name_clash(registry, agent):
59+
@pytest.mark.asyncio
60+
async def test_hook_registry_invoke_callbacks_interrupt_name_clash(registry, agent):
5961
event = BeforeToolCallEvent(
6062
agent=agent,
6163
selected_tool=None,
@@ -70,4 +72,4 @@ def test_hook_registry_invoke_callbacks_interrupt_name_clash(registry, agent):
7072
registry.add_callback(BeforeToolCallEvent, callback2)
7173

7274
with pytest.raises(ValueError, match="interrupt_name=<test_name> | interrupt name used more than once"):
73-
registry.invoke_callbacks(event)
75+
await registry.invoke_callbacks_async(event)

0 commit comments

Comments
 (0)