|
20 | 20 |
|
21 | 21 | import anyio |
22 | 22 | from mcp import ClientSession, ListToolsResult |
| 23 | +from mcp.client.session import ElicitationFnT |
23 | 24 | from mcp.types import BlobResourceContents, GetPromptResult, ListPromptsResult, TextResourceContents |
24 | 25 | from mcp.types import CallToolResult as MCPCallToolResult |
25 | 26 | from mcp.types import EmbeddedResource as MCPEmbeddedResource |
@@ -98,19 +99,22 @@ def __init__( |
98 | 99 | startup_timeout: int = 30, |
99 | 100 | tool_filters: ToolFilters | None = None, |
100 | 101 | prefix: str | None = None, |
101 | | - ): |
| 102 | + elicitation_callback: Optional[ElicitationFnT] = None, |
| 103 | + ) -> None: |
102 | 104 | """Initialize a new MCP Server connection. |
103 | 105 |
|
104 | 106 | Args: |
105 | | - transport_callable: A callable that returns an MCPTransport (read_stream, write_stream) tuple |
106 | | - startup_timeout: Timeout after which MCP server initialization should be cancelled |
| 107 | + transport_callable: A callable that returns an MCPTransport (read_stream, write_stream) tuple. |
| 108 | + startup_timeout: Timeout after which MCP server initialization should be cancelled. |
107 | 109 | Defaults to 30. |
108 | 110 | tool_filters: Optional filters to apply to tools. |
109 | 111 | prefix: Optional prefix for tool names. |
| 112 | + elicitation_callback: Optional callback function to handle elicitation requests from the MCP server. |
110 | 113 | """ |
111 | 114 | self._startup_timeout = startup_timeout |
112 | 115 | self._tool_filters = tool_filters |
113 | 116 | self._prefix = prefix |
| 117 | + self._elicitation_callback = elicitation_callback |
114 | 118 |
|
115 | 119 | mcp_instrumentation() |
116 | 120 | self._session_id = uuid.uuid4() |
@@ -563,7 +567,10 @@ async def _async_background_thread(self) -> None: |
563 | 567 | async with self._transport_callable() as (read_stream, write_stream, *_): |
564 | 568 | self._log_debug_with_thread("transport connection established") |
565 | 569 | async with ClientSession( |
566 | | - read_stream, write_stream, message_handler=self._handle_error_message |
| 570 | + read_stream, |
| 571 | + write_stream, |
| 572 | + message_handler=self._handle_error_message, |
| 573 | + elicitation_callback=self._elicitation_callback, |
567 | 574 | ) as session: |
568 | 575 | self._log_debug_with_thread("initializing MCP session") |
569 | 576 | await session.initialize() |
|
0 commit comments