Skip to content

Commit 1b46b12

Browse files
committed
share interrupt state
1 parent 1df45be commit 1b46b12

File tree

13 files changed

+220
-181
lines changed

13 files changed

+220
-181
lines changed

src/strands/agent/agent.py

Lines changed: 3 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
HookRegistry,
4747
MessageAddedEvent,
4848
)
49+
from ..interrupt import _InterruptState
4950
from ..models.bedrock import BedrockModel
5051
from ..models.model import Model
5152
from ..session.session_manager import SessionManager
@@ -60,15 +61,13 @@
6061
from ..types.agent import AgentInput
6162
from ..types.content import ContentBlock, Message, Messages, SystemContentBlock
6263
from ..types.exceptions import ContextWindowOverflowException
63-
from ..types.interrupt import InterruptResponseContent
6464
from ..types.tools import ToolResult, ToolUse
6565
from ..types.traces import AttributeValue
6666
from .agent_result import AgentResult
6767
from .conversation_manager import (
6868
ConversationManager,
6969
SlidingWindowConversationManager,
7070
)
71-
from .interrupt import InterruptState
7271
from .state import AgentState
7372

7473
logger = logging.getLogger(__name__)
@@ -353,7 +352,7 @@ def __init__(
353352

354353
self.hooks = HookRegistry()
355354

356-
self._interrupt_state = InterruptState()
355+
self._interrupt_state = _InterruptState()
357356

358357
# Initialize session management functionality
359358
self._session_manager = session_manager
@@ -641,7 +640,7 @@ async def stream_async(
641640
yield event["data"]
642641
```
643642
"""
644-
self._resume_interrupt(prompt)
643+
self._interrupt_state.resume(prompt)
645644

646645
merged_state = {}
647646
if kwargs:
@@ -684,38 +683,6 @@ async def stream_async(
684683
self._end_agent_trace_span(error=e)
685684
raise
686685

687-
def _resume_interrupt(self, prompt: AgentInput) -> None:
688-
"""Configure the interrupt state if resuming from an interrupt event.
689-
690-
Args:
691-
prompt: User responses if resuming from interrupt.
692-
693-
Raises:
694-
TypeError: If in interrupt state but user did not provide responses.
695-
"""
696-
if not self._interrupt_state.activated:
697-
return
698-
699-
if not isinstance(prompt, list):
700-
raise TypeError(f"prompt_type={type(prompt)} | must resume from interrupt with list of interruptResponse's")
701-
702-
invalid_types = [
703-
content_type for content in prompt for content_type in content if content_type != "interruptResponse"
704-
]
705-
if invalid_types:
706-
raise TypeError(
707-
f"content_types=<{invalid_types}> | must resume from interrupt with list of interruptResponse's"
708-
)
709-
710-
for content in cast(list[InterruptResponseContent], prompt):
711-
interrupt_id = content["interruptResponse"]["interruptId"]
712-
interrupt_response = content["interruptResponse"]["response"]
713-
714-
if interrupt_id not in self._interrupt_state.interrupts:
715-
raise KeyError(f"interrupt_id=<{interrupt_id}> | no interrupt found")
716-
717-
self._interrupt_state.interrupts[interrupt_id].response = interrupt_response
718-
719686
async def _run_loop(
720687
self,
721688
messages: Messages,

src/strands/agent/interrupt.py

Lines changed: 0 additions & 59 deletions
This file was deleted.

src/strands/interrupt.py

Lines changed: 92 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
"""Human-in-the-loop interrupt system for agent workflows."""
22

3-
from dataclasses import asdict, dataclass
4-
from typing import Any
3+
from dataclasses import asdict, dataclass, field
4+
from typing import TYPE_CHECKING, Any, cast
5+
6+
if TYPE_CHECKING:
7+
from .types.agent import AgentInput
8+
from .types.interrupt import InterruptResponseContent
59

610

711
@dataclass
@@ -31,3 +35,89 @@ class InterruptException(Exception):
3135
def __init__(self, interrupt: Interrupt) -> None:
3236
"""Set the interrupt."""
3337
self.interrupt = interrupt
38+
39+
40+
@dataclass
41+
class _InterruptState:
42+
"""Track the state of interrupt events raised by the user.
43+
44+
Note, interrupt state is cleared after resuming.
45+
46+
Attributes:
47+
interrupts: Interrupts raised by the user.
48+
context: Additional context associated with an interrupt event.
49+
activated: True if agent is in an interrupt state, False otherwise.
50+
"""
51+
52+
interrupts: dict[str, Interrupt] = field(default_factory=dict)
53+
context: dict[str, Any] = field(default_factory=dict)
54+
activated: bool = False
55+
56+
def activate(self, context: dict[str, Any] | None = None) -> None:
57+
"""Activate the interrupt state.
58+
59+
Args:
60+
context: Context associated with the interrupt event.
61+
"""
62+
self.context = context or {}
63+
self.activated = True
64+
65+
def deactivate(self) -> None:
66+
"""Deacitvate the interrupt state.
67+
68+
Interrupts and context are cleared.
69+
"""
70+
self.interrupts = {}
71+
self.context = {}
72+
self.activated = False
73+
74+
def resume(self, prompt: "AgentInput") -> None:
75+
"""Configure the interrupt state if resuming from an interrupt event.
76+
77+
Args:
78+
prompt: User responses if resuming from interrupt.
79+
80+
Raises:
81+
TypeError: If in interrupt state but user did not provide responses.
82+
"""
83+
if not self.activated:
84+
return
85+
86+
if not isinstance(prompt, list):
87+
raise TypeError(f"prompt_type={type(prompt)} | must resume from interrupt with list of interruptResponse's")
88+
89+
invalid_types = [
90+
content_type for content in prompt for content_type in content if content_type != "interruptResponse"
91+
]
92+
if invalid_types:
93+
raise TypeError(
94+
f"content_types=<{invalid_types}> | must resume from interrupt with list of interruptResponse's"
95+
)
96+
97+
contents = cast(list["InterruptResponseContent"], prompt)
98+
for content in contents:
99+
interrupt_id = content["interruptResponse"]["interruptId"]
100+
interrupt_response = content["interruptResponse"]["response"]
101+
102+
if interrupt_id not in self.interrupts:
103+
raise KeyError(f"interrupt_id=<{interrupt_id}> | no interrupt found")
104+
105+
self.interrupts[interrupt_id].response = interrupt_response
106+
107+
def to_dict(self) -> dict[str, Any]:
108+
"""Serialize to dict for session management."""
109+
return asdict(self)
110+
111+
@classmethod
112+
def from_dict(cls, data: dict[str, Any]) -> "_InterruptState":
113+
"""Initiailize interrupt state from serialized interrupt state.
114+
115+
Interrupt state can be serialized with the `to_dict` method.
116+
"""
117+
return cls(
118+
interrupts={
119+
interrupt_id: Interrupt(**interrupt_data) for interrupt_id, interrupt_data in data["interrupts"].items()
120+
},
121+
context=data["context"],
122+
activated=data["activated"],
123+
)

src/strands/types/session.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from enum import Enum
88
from typing import TYPE_CHECKING, Any, Optional
99

10-
from ..agent.interrupt import InterruptState
10+
from ..interrupt import _InterruptState
1111
from .content import Message
1212

1313
if TYPE_CHECKING:
@@ -148,7 +148,7 @@ def to_dict(self) -> dict[str, Any]:
148148
def initialize_internal_state(self, agent: "Agent") -> None:
149149
"""Initialize internal state of agent."""
150150
if "interrupt_state" in self._internal_state:
151-
agent._interrupt_state = InterruptState.from_dict(self._internal_state["interrupt_state"])
151+
agent._interrupt_state = _InterruptState.from_dict(self._internal_state["interrupt_state"])
152152

153153

154154
@dataclass

tests/strands/agent/test_interrupt.py

Lines changed: 0 additions & 61 deletions
This file was deleted.

tests/strands/event_loop/test_event_loop.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,14 @@
66

77
import strands
88
import strands.telemetry
9-
from strands.agent.interrupt import InterruptState
109
from strands.hooks import (
1110
AfterModelCallEvent,
1211
BeforeModelCallEvent,
1312
BeforeToolCallEvent,
1413
HookRegistry,
1514
MessageAddedEvent,
1615
)
17-
from strands.interrupt import Interrupt
16+
from strands.interrupt import Interrupt, _InterruptState
1817
from strands.telemetry.metrics import EventLoopMetrics
1918
from strands.tools.executors import SequentialToolExecutor
2019
from strands.tools.registry import ToolRegistry
@@ -143,7 +142,7 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis
143142
mock.event_loop_metrics = EventLoopMetrics()
144143
mock.hooks = hook_registry
145144
mock.tool_executor = tool_executor
146-
mock._interrupt_state = InterruptState()
145+
mock._interrupt_state = _InterruptState()
147146

148147
return mock
149148

tests/strands/hooks/test_registry.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22

33
import pytest
44

5-
from strands.agent.interrupt import InterruptState
65
from strands.hooks import BeforeToolCallEvent, HookRegistry
7-
from strands.interrupt import Interrupt
6+
from strands.interrupt import Interrupt, _InterruptState
87

98

109
@pytest.fixture
@@ -15,7 +14,7 @@ def registry():
1514
@pytest.fixture
1615
def agent():
1716
instance = unittest.mock.Mock()
18-
instance._interrupt_state = InterruptState()
17+
instance._interrupt_state = _InterruptState()
1918
return instance
2019

2120

tests/strands/session/test_repository_session_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from strands.agent.agent import Agent
88
from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager
99
from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager
10-
from strands.agent.interrupt import InterruptState
10+
from strands.interrupt import _InterruptState
1111
from strands.session.repository_session_manager import RepositorySessionManager
1212
from strands.types.content import ContentBlock
1313
from strands.types.exceptions import SessionException
@@ -131,7 +131,7 @@ def test_initialize_restores_existing_agent(session_manager, agent):
131131
assert len(agent.messages) == 1
132132
assert agent.messages[0]["role"] == "user"
133133
assert agent.messages[0]["content"][0]["text"] == "Hello"
134-
assert agent._interrupt_state == InterruptState(interrupts={}, context={"test": "init"}, activated=False)
134+
assert agent._interrupt_state == _InterruptState(interrupts={}, context={"test": "init"}, activated=False)
135135

136136

137137
def test_initialize_restores_existing_agent_with_summarizing_conversation_manager(session_manager):

0 commit comments

Comments
 (0)