Skip to content

Commit 8c88772

Browse files
committed
feat(agent-interface): introduce AgentBase abstract class as the interface for agent classes to implement
1 parent db671ba commit 8c88772

File tree

5 files changed

+193
-8
lines changed

5 files changed

+193
-8
lines changed

src/strands/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22

33
from . import agent, models, telemetry, types
44
from .agent.agent import Agent
5+
from .agent.base import AgentBase
56
from .tools.decorator import tool
67
from .types.tools import ToolContext
78

89
__all__ = [
910
"Agent",
11+
"AgentBase",
1012
"agent",
1113
"models",
1214
"tool",

src/strands/agent/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
33
It includes:
44
5-
- Agent: The main interface for interacting with AI models and tools
5+
- AgentBase: Abstract interface for all agent types
6+
- Agent: The main implementation for interacting with AI models and tools
67
- ConversationManager: Classes for managing conversation history and context windows
78
"""
89

910
from .agent import Agent
1011
from .agent_result import AgentResult
12+
from .base import AgentBase
1113
from .conversation_manager import (
1214
ConversationManager,
1315
NullConversationManager,
@@ -17,6 +19,7 @@
1719

1820
__all__ = [
1921
"Agent",
22+
"AgentBase",
2023
"AgentResult",
2124
"ConversationManager",
2225
"NullConversationManager",

src/strands/agent/agent.py

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
from ..types.tools import ToolResult, ToolUse
6464
from ..types.traces import AttributeValue
6565
from .agent_result import AgentResult
66+
from .base import AgentBase
6667
from .conversation_manager import (
6768
ConversationManager,
6869
SlidingWindowConversationManager,
@@ -88,8 +89,8 @@ class _DefaultCallbackHandlerSentinel:
8889
_DEFAULT_AGENT_ID = "default"
8990

9091

91-
class Agent:
92-
"""Core Agent interface.
92+
class Agent(AgentBase):
93+
"""Core Agent implementation.
9394
9495
An agent orchestrates the following workflow:
9596
@@ -289,8 +290,8 @@ def __init__(
289290
self.messages = messages if messages is not None else []
290291
self.system_prompt = system_prompt
291292
self._default_structured_output_model = structured_output_model
292-
self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT)
293-
self.name = name or _DEFAULT_AGENT_NAME
293+
self._agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT)
294+
self._name = name or _DEFAULT_AGENT_NAME
294295
self.description = description
295296

296297
# If not provided, create a new PrintingCallbackHandler instance
@@ -338,13 +339,13 @@ def __init__(
338339
# Initialize agent state management
339340
if state is not None:
340341
if isinstance(state, dict):
341-
self.state = AgentState(state)
342+
self._state = AgentState(state)
342343
elif isinstance(state, AgentState):
343-
self.state = state
344+
self._state = state
344345
else:
345346
raise ValueError("state must be an AgentState object or a dict")
346347
else:
347-
self.state = AgentState()
348+
self._state = AgentState()
348349

349350
self.tool_caller = Agent.ToolCaller(self)
350351

@@ -389,6 +390,60 @@ def tool_names(self) -> list[str]:
389390
all_tools = self.tool_registry.get_all_tools_config()
390391
return list(all_tools.keys())
391392

393+
@property
394+
def agent_id(self) -> str:
395+
"""Unique identifier for the agent.
396+
397+
Returns:
398+
Unique string identifier for this agent instance.
399+
"""
400+
return self._agent_id
401+
402+
@agent_id.setter
403+
def agent_id(self, value: str) -> None:
404+
"""Set the agent identifier.
405+
406+
Args:
407+
value: New agent identifier.
408+
"""
409+
self._agent_id = value
410+
411+
@property
412+
def name(self) -> str:
413+
"""Human-readable name of the agent.
414+
415+
Returns:
416+
Display name for the agent.
417+
"""
418+
return self._name
419+
420+
@name.setter
421+
def name(self, value: str) -> None:
422+
"""Set the agent name.
423+
424+
Args:
425+
value: New agent name.
426+
"""
427+
self._name = value
428+
429+
@property
430+
def state(self) -> AgentState:
431+
"""Current state of the agent.
432+
433+
Returns:
434+
AgentState object containing stateful information.
435+
"""
436+
return self._state
437+
438+
@state.setter
439+
def state(self, value: AgentState) -> None:
440+
"""Set the agent state.
441+
442+
Args:
443+
value: New agent state.
444+
"""
445+
self._state = value
446+
392447
def __call__(
393448
self,
394449
prompt: AgentInput = None,

src/strands/agent/base.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""Agent Interface.
2+
3+
Defines the minimal interface that all agent types must implement.
4+
"""
5+
6+
from abc import ABC, abstractmethod
7+
from typing import Any, AsyncIterator, Type
8+
9+
from pydantic import BaseModel
10+
11+
from ..types.agent import AgentInput
12+
from .agent_result import AgentResult
13+
from .state import AgentState
14+
15+
16+
class AgentBase(ABC):
17+
"""Abstract interface for all agent types in Strands.
18+
19+
This interface defines the minimal contract that all agent implementations
20+
must satisfy.
21+
"""
22+
23+
@property
24+
@abstractmethod
25+
def agent_id(self) -> str:
26+
"""Unique identifier for the agent.
27+
28+
Returns:
29+
Unique string identifier for this agent instance.
30+
"""
31+
pass
32+
33+
@property
34+
@abstractmethod
35+
def name(self) -> str:
36+
"""Human-readable name of the agent.
37+
38+
Returns:
39+
Display name for the agent.
40+
"""
41+
pass
42+
43+
@property
44+
@abstractmethod
45+
def state(self) -> AgentState:
46+
"""Current state of the agent.
47+
48+
Returns:
49+
AgentState object containing stateful information.
50+
"""
51+
pass
52+
53+
@abstractmethod
54+
async def invoke_async(
55+
self,
56+
prompt: AgentInput = None,
57+
*,
58+
invocation_state: dict[str, Any] | None = None,
59+
structured_output_model: Type[BaseModel] | None = None,
60+
**kwargs: Any,
61+
) -> AgentResult:
62+
"""Asynchronously invoke the agent with the given prompt.
63+
64+
Args:
65+
prompt: Input to the agent.
66+
invocation_state: Optional state to pass to the agent invocation.
67+
structured_output_model: Optional Pydantic model for structured output.
68+
**kwargs: Additional provider-specific arguments.
69+
70+
Returns:
71+
AgentResult containing the agent's response.
72+
"""
73+
pass
74+
75+
@abstractmethod
76+
def __call__(
77+
self,
78+
prompt: AgentInput = None,
79+
*,
80+
invocation_state: dict[str, Any] | None = None,
81+
structured_output_model: Type[BaseModel] | None = None,
82+
**kwargs: Any,
83+
) -> AgentResult:
84+
"""Synchronously invoke the agent with the given prompt.
85+
86+
Args:
87+
prompt: Input to the agent.
88+
invocation_state: Optional state to pass to the agent invocation.
89+
structured_output_model: Optional Pydantic model for structured output.
90+
**kwargs: Additional provider-specific arguments.
91+
92+
Returns:
93+
AgentResult containing the agent's response.
94+
"""
95+
pass
96+
97+
@abstractmethod
98+
def stream_async(
99+
self,
100+
prompt: AgentInput = None,
101+
*,
102+
invocation_state: dict[str, Any] | None = None,
103+
structured_output_model: Type[BaseModel] | None = None,
104+
**kwargs: Any,
105+
) -> AsyncIterator[Any]:
106+
"""Stream agent execution asynchronously.
107+
108+
Args:
109+
prompt: Input to the agent.
110+
invocation_state: Optional state to pass to the agent invocation.
111+
structured_output_model: Optional Pydantic model for structured output.
112+
**kwargs: Additional provider-specific arguments.
113+
114+
Yields:
115+
Events representing the streaming execution.
116+
"""
117+
pass

tests/strands/multiagent/test_graph.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen
1717
"""Create a mock Agent with specified properties."""
1818
agent = Mock(spec=Agent)
1919
agent.name = name
20+
agent.agent_id = agent_id or f"{name}_id"
2021
agent.id = agent_id or f"{name}_id"
2122
agent._session_manager = None
2223
agent.hooks = HookRegistry()
24+
agent.state = AgentState()
2325

2426
if metrics is None:
2527
metrics = Mock(
@@ -280,12 +282,14 @@ async def test_graph_execution_with_failures(mock_strands_tracer, mock_use_span)
280282
"""Test graph execution error handling and failure propagation."""
281283
failing_agent = Mock(spec=Agent)
282284
failing_agent.name = "failing_agent"
285+
failing_agent.agent_id = "fail_node"
283286
failing_agent.id = "fail_node"
284287
failing_agent.__call__ = Mock(side_effect=Exception("Simulated failure"))
285288

286289
# Add required attributes for validation
287290
failing_agent._session_manager = None
288291
failing_agent.hooks = HookRegistry()
292+
failing_agent.state = AgentState()
289293

290294
async def mock_invoke_failure(*args, **kwargs):
291295
raise Exception("Simulated failure")
@@ -1524,9 +1528,11 @@ async def test_graph_streaming_with_failures(mock_strands_tracer, mock_use_span)
15241528
# Create a failing agent
15251529
failing_agent = Mock(spec=Agent)
15261530
failing_agent.name = "failing_agent"
1531+
failing_agent.agent_id = "fail_node"
15271532
failing_agent.id = "fail_node"
15281533
failing_agent._session_manager = None
15291534
failing_agent.hooks = HookRegistry()
1535+
failing_agent.state = AgentState()
15301536

15311537
async def failing_stream(*args, **kwargs):
15321538
yield {"agent_start": True}
@@ -1697,9 +1703,11 @@ async def test_graph_parallel_with_failures(mock_strands_tracer, mock_use_span):
16971703
# Create a failing agent
16981704
failing_agent = Mock(spec=Agent)
16991705
failing_agent.name = "failing_agent"
1706+
failing_agent.agent_id = "fail_node"
17001707
failing_agent.id = "fail_node"
17011708
failing_agent._session_manager = None
17021709
failing_agent.hooks = HookRegistry()
1710+
failing_agent.state = AgentState()
17031711

17041712
async def mock_invoke_failure(*args, **kwargs):
17051713
await asyncio.sleep(0.05) # Small delay

0 commit comments

Comments
 (0)