Skip to content

Commit 2b0c6e6

Browse files
authored
async hooks (#1119)
1 parent c250fc0 commit 2b0c6e6

File tree

15 files changed

+419
-71
lines changed

15 files changed

+419
-71
lines changed

src/strands/agent/agent.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -171,22 +171,21 @@ async def acall() -> ToolResult:
171171
self._agent._interrupt_state.deactivate()
172172
raise RuntimeError("cannot raise interrupt in direct tool call")
173173

174-
return tool_results[0]
174+
tool_result = tool_results[0]
175175

176-
tool_result = run_async(acall)
176+
if record_direct_tool_call is not None:
177+
should_record_direct_tool_call = record_direct_tool_call
178+
else:
179+
should_record_direct_tool_call = self._agent.record_direct_tool_call
177180

178-
if record_direct_tool_call is not None:
179-
should_record_direct_tool_call = record_direct_tool_call
180-
else:
181-
should_record_direct_tool_call = self._agent.record_direct_tool_call
181+
if should_record_direct_tool_call:
182+
# Create a record of this tool execution in the message history
183+
await self._agent._record_tool_execution(tool_use, tool_result, user_message_override)
182184

183-
if should_record_direct_tool_call:
184-
# Create a record of this tool execution in the message history
185-
self._agent._record_tool_execution(tool_use, tool_result, user_message_override)
185+
return tool_result
186186

187-
# Apply window management
187+
tool_result = run_async(acall)
188188
self._agent.conversation_manager.apply_management(self._agent)
189-
190189
return tool_result
191190

192191
return caller
@@ -534,15 +533,15 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
534533
category=DeprecationWarning,
535534
stacklevel=2,
536535
)
537-
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
536+
await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self))
538537
with self.tracer.tracer.start_as_current_span(
539538
"execute_structured_output", kind=trace_api.SpanKind.CLIENT
540539
) as structured_output_span:
541540
try:
542541
if not self.messages and not prompt:
543542
raise ValueError("No conversation history or prompt provided")
544543

545-
temp_messages: Messages = self.messages + self._convert_prompt_to_messages(prompt)
544+
temp_messages: Messages = self.messages + await self._convert_prompt_to_messages(prompt)
546545

547546
structured_output_span.set_attributes(
548547
{
@@ -575,7 +574,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
575574
return event["output"]
576575

577576
finally:
578-
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
577+
await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self))
579578

580579
def cleanup(self) -> None:
581580
"""Clean up resources used by the agent.
@@ -658,7 +657,7 @@ async def stream_async(
658657
callback_handler = kwargs.get("callback_handler", self.callback_handler)
659658

660659
# Process input and get message to add (if any)
661-
messages = self._convert_prompt_to_messages(prompt)
660+
messages = await self._convert_prompt_to_messages(prompt)
662661

663662
self.trace_span = self._start_agent_trace_span(messages)
664663

@@ -732,13 +731,13 @@ async def _run_loop(
732731
Yields:
733732
Events from the event loop cycle.
734733
"""
735-
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
734+
await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self))
736735

737736
try:
738737
yield InitEventLoopEvent()
739738

740739
for message in messages:
741-
self._append_message(message)
740+
await self._append_message(message)
742741

743742
structured_output_context = StructuredOutputContext(
744743
structured_output_model or self._default_structured_output_model
@@ -764,7 +763,7 @@ async def _run_loop(
764763

765764
finally:
766765
self.conversation_manager.apply_management(self)
767-
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
766+
await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self))
768767

769768
async def _execute_event_loop_cycle(
770769
self, invocation_state: dict[str, Any], structured_output_context: StructuredOutputContext | None = None
@@ -813,7 +812,7 @@ async def _execute_event_loop_cycle(
813812
if structured_output_context:
814813
structured_output_context.cleanup(self.tool_registry)
815814

816-
def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
815+
async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
817816
if self._interrupt_state.activated:
818817
return []
819818

@@ -828,7 +827,7 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
828827
tool_use_ids = [
829828
content["toolUse"]["toolUseId"] for content in self.messages[-1]["content"] if "toolUse" in content
830829
]
831-
self._append_message(
830+
await self._append_message(
832831
{
833832
"role": "user",
834833
"content": generate_missing_tool_result_content(tool_use_ids),
@@ -859,7 +858,7 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
859858
raise ValueError("Input prompt must be of type: `str | list[Contentblock] | Messages | None`.")
860859
return messages
861860

862-
def _record_tool_execution(
861+
async def _record_tool_execution(
863862
self,
864863
tool: ToolUse,
865864
tool_result: ToolResult,
@@ -919,10 +918,10 @@ def _record_tool_execution(
919918
}
920919

921920
# Add to message history
922-
self._append_message(user_msg)
923-
self._append_message(tool_use_msg)
924-
self._append_message(tool_result_msg)
925-
self._append_message(assistant_msg)
921+
await self._append_message(user_msg)
922+
await self._append_message(tool_use_msg)
923+
await self._append_message(tool_result_msg)
924+
await self._append_message(assistant_msg)
926925

927926
def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span:
928927
"""Starts a trace span for the agent.
@@ -1008,10 +1007,10 @@ def _initialize_system_prompt(
10081007
else:
10091008
return None, None
10101009

1011-
def _append_message(self, message: Message) -> None:
1010+
async def _append_message(self, message: Message) -> None:
10121011
"""Appends a message to the agent's list of messages and invokes the callbacks for the MessageCreatedEvent."""
10131012
self.messages.append(message)
1014-
self.hooks.invoke_callbacks(MessageAddedEvent(agent=self, message=message))
1013+
await self.hooks.invoke_callbacks_async(MessageAddedEvent(agent=self, message=message))
10151014

10161015
def _redact_user_content(self, content: list[ContentBlock], redact_message: str) -> list[ContentBlock]:
10171016
"""Redact user content preserving toolResult blocks.

src/strands/event_loop/event_loop.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ async def event_loop_cycle(
227227
)
228228
structured_output_context.set_forced_mode()
229229
logger.debug("Forcing structured output tool")
230-
agent._append_message(
230+
await agent._append_message(
231231
{"role": "user", "content": [{"text": "You must format the previous response as structured output."}]}
232232
)
233233

@@ -322,7 +322,7 @@ async def _handle_model_execution(
322322
model_id=model_id,
323323
)
324324
with trace_api.use_span(model_invoke_span):
325-
agent.hooks.invoke_callbacks(
325+
await agent.hooks.invoke_callbacks_async(
326326
BeforeModelCallEvent(
327327
agent=agent,
328328
)
@@ -347,7 +347,7 @@ async def _handle_model_execution(
347347
stop_reason, message, usage, metrics = event["stop"]
348348
invocation_state.setdefault("request_state", {})
349349

350-
agent.hooks.invoke_callbacks(
350+
await agent.hooks.invoke_callbacks_async(
351351
AfterModelCallEvent(
352352
agent=agent,
353353
stop_response=AfterModelCallEvent.ModelStopResponse(
@@ -368,7 +368,7 @@ async def _handle_model_execution(
368368
if model_invoke_span:
369369
tracer.end_span_with_error(model_invoke_span, str(e), e)
370370

371-
agent.hooks.invoke_callbacks(
371+
await agent.hooks.invoke_callbacks_async(
372372
AfterModelCallEvent(
373373
agent=agent,
374374
exception=e,
@@ -402,7 +402,7 @@ async def _handle_model_execution(
402402

403403
# Add the response message to the conversation
404404
agent.messages.append(message)
405-
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message))
405+
await agent.hooks.invoke_callbacks_async(MessageAddedEvent(agent=agent, message=message))
406406

407407
# Update metrics
408408
agent.event_loop_metrics.update_usage(usage)
@@ -507,7 +507,7 @@ async def _handle_tool_execution(
507507
}
508508

509509
agent.messages.append(tool_result_message)
510-
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message))
510+
await agent.hooks.invoke_callbacks_async(MessageAddedEvent(agent=agent, message=tool_result_message))
511511

512512
yield ToolResultMessageEvent(message=tool_result_message)
513513

src/strands/hooks/registry.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
via hook provider objects.
88
"""
99

10+
import inspect
1011
import logging
1112
from dataclasses import dataclass
12-
from typing import TYPE_CHECKING, Any, Generator, Generic, Protocol, Type, TypeVar
13+
from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, Protocol, Type, TypeVar
1314

1415
from ..interrupt import Interrupt, InterruptException
1516

@@ -122,10 +123,15 @@ class HookCallback(Protocol, Generic[TEvent]):
122123
```python
123124
def my_callback(event: StartRequestEvent) -> None:
124125
print(f"Request started for agent: {event.agent.name}")
126+
127+
# Or
128+
129+
async def my_callback(event: StartRequestEvent) -> None:
130+
# await an async operation
125131
```
126132
"""
127133

128-
def __call__(self, event: TEvent) -> None:
134+
def __call__(self, event: TEvent) -> None | Awaitable[None]:
129135
"""Handle a hook event.
130136
131137
Args:
@@ -164,6 +170,10 @@ def my_handler(event: StartRequestEvent):
164170
registry.add_callback(StartRequestEvent, my_handler)
165171
```
166172
"""
173+
# Related issue: https://github.com/strands-agents/sdk-python/issues/330
174+
if event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback):
175+
raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback")
176+
167177
callbacks = self._registered_callbacks.setdefault(event_type, [])
168178
callbacks.append(callback)
169179

@@ -189,6 +199,52 @@ def register_hooks(self, registry: HookRegistry):
189199
"""
190200
hook.register_hooks(self)
191201

202+
async def invoke_callbacks_async(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Interrupt]]:
203+
"""Invoke all registered callbacks for the given event.
204+
205+
This method finds all callbacks registered for the event's type and
206+
invokes them in the appropriate order. For events with should_reverse_callbacks=True,
207+
callbacks are invoked in reverse registration order. Any exceptions raised by callback
208+
functions will propagate to the caller.
209+
210+
Additionally, this method aggregates interrupts raised by the user to instantiate human-in-the-loop workflows.
211+
212+
Args:
213+
event: The event to dispatch to registered callbacks.
214+
215+
Returns:
216+
The event dispatched to registered callbacks and any interrupts raised by the user.
217+
218+
Raises:
219+
ValueError: If interrupt name is used more than once.
220+
221+
Example:
222+
```python
223+
event = StartRequestEvent(agent=my_agent)
224+
await registry.invoke_callbacks_async(event)
225+
```
226+
"""
227+
interrupts: dict[str, Interrupt] = {}
228+
229+
for callback in self.get_callbacks_for(event):
230+
try:
231+
if inspect.iscoroutinefunction(callback):
232+
await callback(event)
233+
else:
234+
callback(event)
235+
236+
except InterruptException as exception:
237+
interrupt = exception.interrupt
238+
if interrupt.name in interrupts:
239+
message = f"interrupt_name=<{interrupt.name}> | interrupt name used more than once"
240+
logger.error(message)
241+
raise ValueError(message) from exception
242+
243+
# Each callback is allowed to raise their own interrupt.
244+
interrupts[interrupt.name] = interrupt
245+
246+
return event, list(interrupts.values())
247+
192248
def invoke_callbacks(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Interrupt]]:
193249
"""Invoke all registered callbacks for the given event.
194250
@@ -206,6 +262,7 @@ def invoke_callbacks(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Inte
206262
The event dispatched to registered callbacks and any interrupts raised by the user.
207263
208264
Raises:
265+
RuntimeError: If at least one callback is async.
209266
ValueError: If interrupt name is used more than once.
210267
211268
Example:
@@ -214,9 +271,13 @@ def invoke_callbacks(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Inte
214271
registry.invoke_callbacks(event)
215272
```
216273
"""
274+
callbacks = list(self.get_callbacks_for(event))
217275
interrupts: dict[str, Interrupt] = {}
218276

219-
for callback in self.get_callbacks_for(event):
277+
if any(inspect.iscoroutinefunction(callback) for callback in callbacks):
278+
raise RuntimeError(f"event=<{event}> | use invoke_callbacks_async to invoke async callback")
279+
280+
for callback in callbacks:
220281
try:
221282
callback(event)
222283
except InterruptException as exception:

src/strands/multiagent/graph.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def __init__(
453453
self._resume_from_session = False
454454
self.id = id
455455

456-
self.hooks.invoke_callbacks(MultiAgentInitializedEvent(self))
456+
run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))
457457

458458
def __call__(
459459
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
@@ -516,7 +516,7 @@ async def stream_async(
516516
if invocation_state is None:
517517
invocation_state = {}
518518

519-
self.hooks.invoke_callbacks(BeforeMultiAgentInvocationEvent(self, invocation_state))
519+
await self.hooks.invoke_callbacks_async(BeforeMultiAgentInvocationEvent(self, invocation_state))
520520

521521
logger.debug("task=<%s> | starting graph execution", task)
522522

@@ -569,7 +569,7 @@ async def stream_async(
569569
raise
570570
finally:
571571
self.state.execution_time = round((time.time() - start_time) * 1000)
572-
self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(self))
572+
await self.hooks.invoke_callbacks_async(AfterMultiAgentInvocationEvent(self))
573573
self._resume_from_session = False
574574
self._resume_next_nodes.clear()
575575

@@ -776,7 +776,7 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[
776776

777777
async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
778778
"""Execute a single node and yield TypedEvent objects."""
779-
self.hooks.invoke_callbacks(BeforeNodeCallEvent(self, node.node_id, invocation_state))
779+
await self.hooks.invoke_callbacks_async(BeforeNodeCallEvent(self, node.node_id, invocation_state))
780780

781781
# Reset the node's state if reset_on_revisit is enabled, and it's being revisited
782782
if self.reset_on_revisit and node in self.state.completed_nodes:
@@ -920,7 +920,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
920920
raise
921921

922922
finally:
923-
self.hooks.invoke_callbacks(AfterNodeCallEvent(self, node.node_id, invocation_state))
923+
await self.hooks.invoke_callbacks_async(AfterNodeCallEvent(self, node.node_id, invocation_state))
924924

925925
def _accumulate_metrics(self, node_result: NodeResult) -> None:
926926
"""Accumulate metrics from a node result."""

0 commit comments

Comments
 (0)