Skip to content

Commit 1e6226d

Browse files
committed
tests
1 parent 0c71708 commit 1e6226d

File tree

7 files changed

+284
-8
lines changed

7 files changed

+284
-8
lines changed

src/strands/agent/agent.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,13 @@ async def acall() -> ToolResult:
173173

174174
tool_result = tool_results[0]
175175

176-
should_record_direct_tool_call = (
177-
record_direct_tool_call
178-
if record_direct_tool_call is not None
179-
else self._agent.record_direct_tool_call
180-
)
176+
if record_direct_tool_call is not None:
177+
should_record_direct_tool_call = record_direct_tool_call
178+
else:
179+
should_record_direct_tool_call = self._agent.record_direct_tool_call
180+
181181
if should_record_direct_tool_call:
182+
# Create a record of this tool execution in the message history
182183
await self._agent._record_tool_execution(tool_use, tool_result, user_message_override)
183184

184185
return tool_result

src/strands/hooks/registry.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import inspect
1111
import logging
12+
import warnings
1213
from dataclasses import dataclass
1314
from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, Protocol, Type, TypeVar
1415

@@ -267,6 +268,10 @@ def invoke_callbacks(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Inte
267268
registry.invoke_callbacks(event)
268269
```
269270
"""
271+
warnings.warn(
272+
"invoke_callbacks is deprecated and replaced by invoke_callbacks_async", DeprecationWarning, stacklevel=2
273+
)
274+
270275
callbacks = self.get_callbacks_for(event)
271276
interrupts: dict[str, Interrupt] = {}
272277

tests/strands/hooks/test_registry.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44

55
from strands.agent.interrupt import InterruptState
6-
from strands.hooks import BeforeToolCallEvent, HookRegistry
6+
from strands.hooks import AgentInitializedEvent, BeforeToolCallEvent, HookRegistry
77
from strands.interrupt import Interrupt
88

99

@@ -20,7 +20,7 @@ def agent():
2020

2121

2222
@pytest.mark.asyncio
23-
async def test_hook_registry_invoke_callbacks_interrupt(registry, agent):
23+
async def test_hook_registry_invoke_callbacks_async_interrupt(registry, agent):
2424
event = BeforeToolCallEvent(
2525
agent=agent,
2626
selected_tool=None,
@@ -57,7 +57,7 @@ async def test_hook_registry_invoke_callbacks_interrupt(registry, agent):
5757

5858

5959
@pytest.mark.asyncio
60-
async def test_hook_registry_invoke_callbacks_interrupt_name_clash(registry, agent):
60+
async def test_hook_registry_invoke_callbacks_async_interrupt_name_clash(registry, agent):
6161
event = BeforeToolCallEvent(
6262
agent=agent,
6363
selected_tool=None,
@@ -73,3 +73,11 @@ async def test_hook_registry_invoke_callbacks_interrupt_name_clash(registry, age
7373

7474
with pytest.raises(ValueError, match="interrupt_name=<test_name> | interrupt name used more than once"):
7575
await registry.invoke_callbacks_async(event)
76+
77+
78+
def test_hook_registry_invoke_callbacks_coroutine(registry, agent):
79+
callback = unittest.mock.AsyncMock()
80+
registry.add_callback(AgentInitializedEvent, callback)
81+
82+
with pytest.raises(RuntimeError, match=r"use invoke_callbacks_async to invoke async callback"):
83+
registry.invoke_callbacks(AgentInitializedEvent(agent=agent))

tests_integ/hooks/__init__.py

Whitespace-only changes.

tests_integ/hooks/multiagent/__init__.py

Whitespace-only changes.
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import pytest
2+
3+
from strands import Agent
4+
from strands.experimental.hooks.multiagent import (
5+
AfterMultiAgentInvocationEvent,
6+
AfterNodeCallEvent,
7+
BeforeMultiAgentInvocationEvent,
8+
BeforeNodeCallEvent,
9+
MultiAgentInitializedEvent,
10+
)
11+
from strands.hooks import HookProvider
12+
from strands.multiagent import GraphBuilder, Swarm
13+
14+
15+
@pytest.fixture
16+
def callback_names():
17+
return []
18+
19+
20+
@pytest.fixture
21+
def hook_provider(callback_names):
22+
class TestHook(HookProvider):
23+
def register_hooks(self, registry):
24+
registry.add_callback(AfterMultiAgentInvocationEvent, self.after_multi_agent_invocation)
25+
registry.add_callback(AfterMultiAgentInvocationEvent, self.after_multi_agent_invocation_async)
26+
registry.add_callback(AfterNodeCallEvent, self.after_node_call)
27+
registry.add_callback(AfterNodeCallEvent, self.after_node_call_async)
28+
registry.add_callback(BeforeMultiAgentInvocationEvent, self.before_multi_agent_invocation)
29+
registry.add_callback(BeforeMultiAgentInvocationEvent, self.before_multi_agent_invocation_async)
30+
registry.add_callback(BeforeNodeCallEvent, self.before_node_call)
31+
registry.add_callback(BeforeNodeCallEvent, self.before_node_call_async)
32+
registry.add_callback(MultiAgentInitializedEvent, self.multi_agent_initialized_event)
33+
registry.add_callback(MultiAgentInitializedEvent, self.multi_agent_initialized_event_async)
34+
35+
def after_multi_agent_invocation(self, _event):
36+
callback_names.append("after_multi_agent_invocation")
37+
38+
async def after_multi_agent_invocation_async(self, _event):
39+
callback_names.append("after_multi_agent_invocation_async")
40+
41+
def after_node_call(self, _event):
42+
callback_names.append("after_node_call")
43+
44+
async def after_node_call_async(self, _event):
45+
callback_names.append("after_node_call_async")
46+
47+
def before_multi_agent_invocation(self, _event):
48+
callback_names.append("before_multi_agent_invocation")
49+
50+
async def before_multi_agent_invocation_async(self, _event):
51+
callback_names.append("before_multi_agent_invocation_async")
52+
53+
def before_node_call(self, _event):
54+
callback_names.append("before_node_call")
55+
56+
async def before_node_call_async(self, _event):
57+
callback_names.append("before_node_call_async")
58+
59+
def multi_agent_initialized_event(self, _event):
60+
callback_names.append("multi_agent_initialized_event")
61+
62+
async def multi_agent_initialized_event_async(self, _event):
63+
callback_names.append("multi_agent_initialized_event_async")
64+
65+
return TestHook()
66+
67+
68+
@pytest.fixture
69+
def agent():
70+
return Agent()
71+
72+
73+
@pytest.fixture
74+
def graph(agent, hook_provider):
75+
builder = GraphBuilder()
76+
builder.add_node(agent, "agent")
77+
builder.set_entry_point("agent")
78+
builder.set_hook_providers([hook_provider])
79+
return builder.build()
80+
81+
82+
@pytest.fixture
83+
def swarm(agent, hook_provider):
84+
return Swarm([agent], hooks=[hook_provider])
85+
86+
87+
def test_graph_events(graph, callback_names):
88+
graph("Hello")
89+
90+
tru_callback_names = callback_names
91+
exp_callback_names = [
92+
"multi_agent_initialized_event",
93+
"multi_agent_initialized_event_async",
94+
"before_multi_agent_invocation",
95+
"before_multi_agent_invocation_async",
96+
"before_node_call",
97+
"before_node_call_async",
98+
"after_node_call_async",
99+
"after_node_call",
100+
"after_multi_agent_invocation_async",
101+
"after_multi_agent_invocation",
102+
]
103+
assert tru_callback_names == exp_callback_names
104+
105+
106+
def test_swarm_events(swarm, callback_names):
107+
swarm("Hello")
108+
109+
tru_callback_names = callback_names
110+
exp_callback_names = [
111+
"multi_agent_initialized_event",
112+
"multi_agent_initialized_event_async",
113+
"before_multi_agent_invocation",
114+
"before_multi_agent_invocation_async",
115+
"before_node_call",
116+
"before_node_call_async",
117+
"after_node_call_async",
118+
"after_node_call",
119+
"after_multi_agent_invocation_async",
120+
"after_multi_agent_invocation",
121+
]
122+
assert tru_callback_names == exp_callback_names

tests_integ/hooks/test_events.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import pytest
2+
3+
from strands import Agent, tool
4+
from strands.hooks import (
5+
AfterInvocationEvent,
6+
AfterModelCallEvent,
7+
AfterToolCallEvent,
8+
AgentInitializedEvent,
9+
BeforeInvocationEvent,
10+
BeforeModelCallEvent,
11+
BeforeToolCallEvent,
12+
HookProvider,
13+
MessageAddedEvent,
14+
)
15+
16+
17+
@pytest.fixture
18+
def callback_names():
19+
return []
20+
21+
22+
@pytest.fixture
23+
def hook_provider(callback_names):
24+
class TestHook(HookProvider):
25+
def register_hooks(self, registry):
26+
registry.add_callback(AfterInvocationEvent, self.after_invocation)
27+
registry.add_callback(AfterInvocationEvent, self.after_invocation_async)
28+
registry.add_callback(AfterModelCallEvent, self.after_model_call)
29+
registry.add_callback(AfterModelCallEvent, self.after_model_call_async)
30+
registry.add_callback(AfterToolCallEvent, self.after_tool_call)
31+
registry.add_callback(AfterToolCallEvent, self.after_tool_call_async)
32+
registry.add_callback(AgentInitializedEvent, self.agent_initialized)
33+
registry.add_callback(AgentInitializedEvent, self.agent_initialized_async)
34+
registry.add_callback(BeforeInvocationEvent, self.before_invocation)
35+
registry.add_callback(BeforeInvocationEvent, self.before_invocation_async)
36+
registry.add_callback(BeforeModelCallEvent, self.before_model_call)
37+
registry.add_callback(BeforeModelCallEvent, self.before_model_call_async)
38+
registry.add_callback(BeforeToolCallEvent, self.before_tool_call)
39+
registry.add_callback(BeforeToolCallEvent, self.before_tool_call_async)
40+
registry.add_callback(MessageAddedEvent, self.message_added)
41+
registry.add_callback(MessageAddedEvent, self.message_added_async)
42+
43+
def after_invocation(self, _event):
44+
callback_names.append("after_invocation")
45+
46+
async def after_invocation_async(self, _event):
47+
callback_names.append("after_invocation_async")
48+
49+
def after_model_call(self, _event):
50+
callback_names.append("after_model_call")
51+
52+
async def after_model_call_async(self, _event):
53+
callback_names.append("after_model_call_async")
54+
55+
def after_tool_call(self, _event):
56+
callback_names.append("after_tool_call")
57+
58+
async def after_tool_call_async(self, _event):
59+
callback_names.append("after_tool_call_async")
60+
61+
def agent_initialized(self, _event):
62+
callback_names.append("agent_initialized")
63+
64+
async def agent_initialized_async(self, _event):
65+
callback_names.append("agent_initialized_async")
66+
67+
def before_invocation(self, _event):
68+
callback_names.append("before_invocation")
69+
70+
async def before_invocation_async(self, _event):
71+
callback_names.append("before_invocation_async")
72+
73+
def before_model_call(self, _event):
74+
callback_names.append("before_model_call")
75+
76+
async def before_model_call_async(self, _event):
77+
callback_names.append("before_model_call_async")
78+
79+
def before_tool_call(self, _event):
80+
callback_names.append("before_tool_call")
81+
82+
async def before_tool_call_async(self, _event):
83+
callback_names.append("before_tool_call_async")
84+
85+
def message_added(self, _event):
86+
callback_names.append("message_added")
87+
88+
async def message_added_async(self, _event):
89+
callback_names.append("message_added_async")
90+
91+
return TestHook()
92+
93+
94+
@pytest.fixture
95+
def time_tool():
96+
@tool(name="time_tool")
97+
def tool_() -> str:
98+
return "12:00"
99+
100+
return tool_
101+
102+
103+
@pytest.fixture
104+
def agent(hook_provider, time_tool):
105+
return Agent(hooks=[hook_provider], tools=[time_tool])
106+
107+
108+
def test_events(agent, callback_names):
109+
agent("What time is it?")
110+
111+
tru_callback_names = callback_names
112+
exp_callback_names = [
113+
"agent_initialized",
114+
"agent_initialized_async",
115+
"before_invocation",
116+
"before_invocation_async",
117+
"message_added",
118+
"message_added_async",
119+
"before_model_call",
120+
"before_model_call_async",
121+
"after_model_call_async",
122+
"after_model_call",
123+
"message_added",
124+
"message_added_async",
125+
"before_tool_call",
126+
"before_tool_call_async",
127+
"after_tool_call_async",
128+
"after_tool_call",
129+
"message_added",
130+
"message_added_async",
131+
"before_model_call",
132+
"before_model_call_async",
133+
"after_model_call_async",
134+
"after_model_call",
135+
"message_added",
136+
"message_added_async",
137+
"after_invocation_async",
138+
"after_invocation",
139+
]
140+
assert tru_callback_names == exp_callback_names

0 commit comments

Comments
 (0)