Skip to content

Commit 49e432d

Browse files
authored
mcp elicitation (#1094)
1 parent 071f89f commit 49e432d

File tree

3 files changed

+92
-4
lines changed

3 files changed

+92
-4
lines changed

src/strands/tools/mcp/mcp_client.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import anyio
2222
from mcp import ClientSession, ListToolsResult
23+
from mcp.client.session import ElicitationFnT
2324
from mcp.types import BlobResourceContents, GetPromptResult, ListPromptsResult, TextResourceContents
2425
from mcp.types import CallToolResult as MCPCallToolResult
2526
from mcp.types import EmbeddedResource as MCPEmbeddedResource
@@ -98,19 +99,22 @@ def __init__(
9899
startup_timeout: int = 30,
99100
tool_filters: ToolFilters | None = None,
100101
prefix: str | None = None,
101-
):
102+
elicitation_callback: Optional[ElicitationFnT] = None,
103+
) -> None:
102104
"""Initialize a new MCP Server connection.
103105
104106
Args:
105-
transport_callable: A callable that returns an MCPTransport (read_stream, write_stream) tuple
106-
startup_timeout: Timeout after which MCP server initialization should be cancelled
107+
transport_callable: A callable that returns an MCPTransport (read_stream, write_stream) tuple.
108+
startup_timeout: Timeout after which MCP server initialization should be cancelled.
107109
Defaults to 30.
108110
tool_filters: Optional filters to apply to tools.
109111
prefix: Optional prefix for tool names.
112+
elicitation_callback: Optional callback function to handle elicitation requests from the MCP server.
110113
"""
111114
self._startup_timeout = startup_timeout
112115
self._tool_filters = tool_filters
113116
self._prefix = prefix
117+
self._elicitation_callback = elicitation_callback
114118

115119
mcp_instrumentation()
116120
self._session_id = uuid.uuid4()
@@ -563,7 +567,10 @@ async def _async_background_thread(self) -> None:
563567
async with self._transport_callable() as (read_stream, write_stream, *_):
564568
self._log_debug_with_thread("transport connection established")
565569
async with ClientSession(
566-
read_stream, write_stream, message_handler=self._handle_error_message
570+
read_stream,
571+
write_stream,
572+
message_handler=self._handle_error_message,
573+
elicitation_callback=self._elicitation_callback,
567574
) as session:
568575
self._log_debug_with_thread("initializing MCP session")
569576
await session.initialize()
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""MCP server for testing elicitation.
2+
3+
- Docs: https://modelcontextprotocol.io/specification/draft/client/elicitation
4+
"""
5+
6+
from mcp.server import FastMCP
7+
from mcp.types import ElicitRequest, ElicitRequestParams, ElicitResult
8+
9+
10+
def server() -> None:
11+
"""Simulate approval through MCP elicitation."""
12+
server_ = FastMCP()
13+
14+
@server_.tool(description="Tool to request approval")
15+
async def approval_tool() -> str:
16+
"""Simulated approval tool.
17+
18+
Returns:
19+
The elicitation result from the user.
20+
"""
21+
request = ElicitRequest(
22+
params=ElicitRequestParams(
23+
message="Do you approve",
24+
requestedSchema={
25+
"type": "object",
26+
"properties": {
27+
"message": {"type": "string", "description": "request message"},
28+
},
29+
"required": ["message"],
30+
},
31+
),
32+
)
33+
result = await server_.get_context().session.send_request(request, ElicitResult)
34+
35+
return result.model_dump_json()
36+
37+
server_.run(transport="stdio")
38+
39+
40+
if __name__ == "__main__":
41+
server()
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import json
2+
3+
import pytest
4+
from mcp import StdioServerParameters, stdio_client
5+
from mcp.types import ElicitResult
6+
7+
from strands import Agent
8+
from strands.tools.mcp import MCPClient
9+
10+
11+
@pytest.fixture
12+
def callback():
13+
async def callback_(_, params):
14+
return ElicitResult(action="accept", content={"message": params.message})
15+
16+
return callback_
17+
18+
19+
@pytest.fixture
20+
def client(callback):
21+
return MCPClient(
22+
lambda: stdio_client(
23+
StdioServerParameters(command="python", args=["tests_integ/mcp/elicitation_server.py"]),
24+
),
25+
elicitation_callback=callback,
26+
)
27+
28+
29+
def test_mcp_elicitation(client):
30+
with client:
31+
tools = client.list_tools_sync()
32+
agent = Agent(tools=tools)
33+
34+
agent("Can you get approval")
35+
36+
tool_result = agent.messages[-2]
37+
38+
tru_result = json.loads(tool_result["content"][0]["toolResult"]["content"][0]["text"])
39+
exp_result = {"meta": None, "action": "accept", "content": {"message": "Do you approve"}}
40+
assert tru_result == exp_result

0 commit comments

Comments
 (0)