Skip to content

Commit a37d3ed

Browse files
committed
call AgentInitializedEvent hooks synchronously
1 parent 083af6f commit a37d3ed

File tree

4 files changed

+15
-12
lines changed

4 files changed

+15
-12
lines changed

src/strands/agent/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ def __init__(
365365
for hook in hooks:
366366
self.hooks.add_hook(hook)
367367

368-
run_async(lambda: self.hooks.invoke_callbacks_async(AgentInitializedEvent(agent=self)))
368+
self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self))
369369

370370
@property
371371
def tool(self) -> ToolCaller:

src/strands/hooks/registry.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import inspect
1111
import logging
12-
import warnings
1312
from dataclasses import dataclass
1413
from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, Protocol, Type, TypeVar
1514

@@ -171,6 +170,9 @@ def my_handler(event: StartRequestEvent):
171170
registry.add_callback(StartRequestEvent, my_handler)
172171
```
173172
"""
173+
if event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback):
174+
raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback")
175+
174176
callbacks = self._registered_callbacks.setdefault(event_type, [])
175177
callbacks.append(callback)
176178

@@ -268,11 +270,7 @@ def invoke_callbacks(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Inte
268270
registry.invoke_callbacks(event)
269271
```
270272
"""
271-
warnings.warn(
272-
"invoke_callbacks is deprecated and replaced by invoke_callbacks_async", DeprecationWarning, stacklevel=2
273-
)
274-
275-
callbacks = self.get_callbacks_for(event)
273+
callbacks = list(self.get_callbacks_for(event))
276274
interrupts: dict[str, Interrupt] = {}
277275

278276
if any(inspect.iscoroutinefunction(callback) for callback in callbacks):

tests/strands/hooks/test_registry.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44

55
from strands.agent.interrupt import InterruptState
6-
from strands.hooks import AgentInitializedEvent, BeforeToolCallEvent, HookRegistry
6+
from strands.hooks import AgentInitializedEvent, BeforeInvocationEvent, BeforeToolCallEvent, HookRegistry
77
from strands.interrupt import Interrupt
88

99

@@ -19,6 +19,13 @@ def agent():
1919
return instance
2020

2121

22+
def test_hook_registry_add_callback_agent_init_coroutine(registry):
23+
callback = unittest.mock.AsyncMock()
24+
25+
with pytest.raises(ValueError, match=r"AgentInitializedEvent can only be registered with a synchronous callback"):
26+
registry.add_callback(AgentInitializedEvent, callback)
27+
28+
2229
@pytest.mark.asyncio
2330
async def test_hook_registry_invoke_callbacks_async_interrupt(registry, agent):
2431
event = BeforeToolCallEvent(
@@ -77,7 +84,7 @@ async def test_hook_registry_invoke_callbacks_async_interrupt_name_clash(registr
7784

7885
def test_hook_registry_invoke_callbacks_coroutine(registry, agent):
7986
callback = unittest.mock.AsyncMock()
80-
registry.add_callback(AgentInitializedEvent, callback)
87+
registry.add_callback(BeforeInvocationEvent, callback)
8188

8289
with pytest.raises(RuntimeError, match=r"use invoke_callbacks_async to invoke async callback"):
83-
registry.invoke_callbacks(AgentInitializedEvent(agent=agent))
90+
registry.invoke_callbacks(BeforeInvocationEvent(agent=agent))

tests_integ/hooks/test_events.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ def register_hooks(self, registry):
3030
registry.add_callback(AfterToolCallEvent, self.after_tool_call)
3131
registry.add_callback(AfterToolCallEvent, self.after_tool_call_async)
3232
registry.add_callback(AgentInitializedEvent, self.agent_initialized)
33-
registry.add_callback(AgentInitializedEvent, self.agent_initialized_async)
3433
registry.add_callback(BeforeInvocationEvent, self.before_invocation)
3534
registry.add_callback(BeforeInvocationEvent, self.before_invocation_async)
3635
registry.add_callback(BeforeModelCallEvent, self.before_model_call)
@@ -111,7 +110,6 @@ def test_events(agent, callback_names):
111110
tru_callback_names = callback_names
112111
exp_callback_names = [
113112
"agent_initialized",
114-
"agent_initialized_async",
115113
"before_invocation",
116114
"before_invocation_async",
117115
"message_added",

0 commit comments

Comments
 (0)