Skip to content

Commit 0a30b70

Browse files
authored
Merge branch 'main' into fix-ci-highest-resolution-testing
2 parents 98e86fd + 91ccdb3 commit 0a30b70

File tree

7 files changed

+716
-293
lines changed

7 files changed

+716
-293
lines changed

src/mcp/client/auth/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,9 @@
44
Implements authorization code flow with PKCE and automatic token refresh.
55
"""
66

7+
from mcp.client.auth.exceptions import OAuthFlowError, OAuthRegistrationError, OAuthTokenError
78
from mcp.client.auth.oauth2 import (
89
OAuthClientProvider,
9-
OAuthFlowError,
10-
OAuthRegistrationError,
11-
OAuthTokenError,
1210
PKCEParameters,
1311
TokenStorage,
1412
)

src/mcp/client/auth/exceptions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
class OAuthFlowError(Exception):
2+
"""Base exception for OAuth flow errors."""
3+
4+
5+
class OAuthTokenError(OAuthFlowError):
6+
"""Raised when token operations fail."""
7+
8+
9+
class OAuthRegistrationError(OAuthFlowError):
10+
"""Raised when client registration fails."""

src/mcp/client/auth/oauth2.py

Lines changed: 101 additions & 229 deletions
Large diffs are not rendered by default.

src/mcp/client/auth/utils.py

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
import logging
2+
import re
3+
from urllib.parse import urljoin, urlparse
4+
5+
from httpx import Request, Response
6+
from pydantic import ValidationError
7+
8+
from mcp.client.auth import OAuthRegistrationError, OAuthTokenError
9+
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
10+
from mcp.shared.auth import (
11+
OAuthClientInformationFull,
12+
OAuthClientMetadata,
13+
OAuthMetadata,
14+
OAuthToken,
15+
ProtectedResourceMetadata,
16+
)
17+
from mcp.types import LATEST_PROTOCOL_VERSION
18+
19+
logger = logging.getLogger(__name__)
20+
21+
22+
def extract_field_from_www_auth(response: Response, field_name: str) -> str | None:
23+
"""
24+
Extract field from WWW-Authenticate header.
25+
26+
Returns:
27+
Field value if found in WWW-Authenticate header, None otherwise
28+
"""
29+
www_auth_header = response.headers.get("WWW-Authenticate")
30+
if not www_auth_header:
31+
return None
32+
33+
# Pattern matches: field_name="value" or field_name=value (unquoted)
34+
pattern = rf'{field_name}=(?:"([^"]+)"|([^\s,]+))'
35+
match = re.search(pattern, www_auth_header)
36+
37+
if match:
38+
# Return quoted value if present, otherwise unquoted value
39+
return match.group(1) or match.group(2)
40+
41+
return None
42+
43+
44+
def extract_scope_from_www_auth(response: Response) -> str | None:
45+
"""
46+
Extract scope parameter from WWW-Authenticate header as per RFC6750.
47+
48+
Returns:
49+
Scope string if found in WWW-Authenticate header, None otherwise
50+
"""
51+
return extract_field_from_www_auth(response, "scope")
52+
53+
54+
def extract_resource_metadata_from_www_auth(response: Response) -> str | None:
55+
"""
56+
Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728.
57+
58+
Returns:
59+
Resource metadata URL if found in WWW-Authenticate header, None otherwise
60+
"""
61+
if not response or response.status_code != 401:
62+
return None # pragma: no cover
63+
64+
return extract_field_from_www_auth(response, "resource_metadata")
65+
66+
67+
def build_protected_resource_metadata_discovery_urls(www_auth_url: str | None, server_url: str) -> list[str]:
68+
"""
69+
Build ordered list of URLs to try for protected resource metadata discovery.
70+
71+
Per SEP-985, the client MUST:
72+
1. Try resource_metadata from WWW-Authenticate header (if present)
73+
2. Fall back to path-based well-known URI: /.well-known/oauth-protected-resource/{path}
74+
3. Fall back to root-based well-known URI: /.well-known/oauth-protected-resource
75+
76+
Args:
77+
www_auth_url: optional resource_metadata url extracted from the WWW-Authenticate header
78+
server_url: server url
79+
80+
Returns:
81+
Ordered list of URLs to try for discovery
82+
"""
83+
urls: list[str] = []
84+
85+
# Priority 1: WWW-Authenticate header with resource_metadata parameter
86+
if www_auth_url:
87+
urls.append(www_auth_url)
88+
89+
# Priority 2-3: Well-known URIs (RFC 9728)
90+
parsed = urlparse(server_url)
91+
base_url = f"{parsed.scheme}://{parsed.netloc}"
92+
93+
# Priority 2: Path-based well-known URI (if server has a path component)
94+
if parsed.path and parsed.path != "/":
95+
path_based_url = urljoin(base_url, f"/.well-known/oauth-protected-resource{parsed.path}")
96+
urls.append(path_based_url)
97+
98+
# Priority 3: Root-based well-known URI
99+
root_based_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
100+
urls.append(root_based_url)
101+
102+
return urls
103+
104+
105+
def get_client_metadata_scopes(
106+
www_authenticate_scope: str | None,
107+
protected_resource_metadata: ProtectedResourceMetadata | None,
108+
authorization_server_metadata: OAuthMetadata | None = None,
109+
) -> str | None:
110+
"""Select scopes as outlined in the 'Scope Selection Strategy' in the MCP spec."""
111+
# Per MCP spec, scope selection priority order:
112+
# 1. Use scope from WWW-Authenticate header (if provided)
113+
# 2. Use all scopes from PRM scopes_supported (if available)
114+
# 3. Omit scope parameter if neither is available
115+
116+
if www_authenticate_scope is not None:
117+
# Priority 1: WWW-Authenticate header scope
118+
return www_authenticate_scope
119+
elif protected_resource_metadata is not None and protected_resource_metadata.scopes_supported is not None:
120+
# Priority 2: PRM scopes_supported
121+
return " ".join(protected_resource_metadata.scopes_supported)
122+
elif authorization_server_metadata is not None and authorization_server_metadata.scopes_supported is not None:
123+
return " ".join(authorization_server_metadata.scopes_supported) # pragma: no cover
124+
else:
125+
# Priority 3: Omit scope parameter
126+
return None
127+
128+
129+
def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]:
130+
"""
131+
Generate ordered list of (url, type) tuples for discovery attempts.
132+
133+
Args:
134+
auth_server_url: URL for the OAuth Authorization Metadata URL if found, otherwise None
135+
server_url: URL for the MCP server, used as a fallback if auth_server_url is None
136+
"""
137+
138+
if not auth_server_url:
139+
# Legacy path using the 2025-03-26 spec:
140+
# link: https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization
141+
parsed = urlparse(server_url)
142+
return [f"{parsed.scheme}://{parsed.netloc}/.well-known/oauth-authorization-server"]
143+
144+
urls: list[str] = []
145+
parsed = urlparse(auth_server_url)
146+
base_url = f"{parsed.scheme}://{parsed.netloc}"
147+
148+
# RFC 8414: Path-aware OAuth discovery
149+
if parsed.path and parsed.path != "/":
150+
oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}"
151+
urls.append(urljoin(base_url, oauth_path))
152+
153+
# RFC 8414 section 5: Path-aware OIDC discovery
154+
# See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
155+
oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}"
156+
urls.append(urljoin(base_url, oidc_path))
157+
158+
# https://openid.net/specs/openid-connect-discovery-1_0.html
159+
oidc_path = f"{parsed.path.rstrip('/')}/.well-known/openid-configuration"
160+
urls.append(urljoin(base_url, oidc_path))
161+
return urls
162+
163+
# OAuth root
164+
urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server"))
165+
166+
# OIDC 1.0 fallback (appends to full URL per OIDC spec)
167+
# https://openid.net/specs/openid-connect-discovery-1_0.html
168+
urls.append(urljoin(base_url, "/.well-known/openid-configuration"))
169+
170+
return urls
171+
172+
173+
async def handle_protected_resource_response(
174+
response: Response,
175+
) -> ProtectedResourceMetadata | None:
176+
"""
177+
Handle protected resource metadata discovery response.
178+
179+
Per SEP-985, supports fallback when discovery fails at one URL.
180+
181+
Returns:
182+
True if metadata was successfully discovered, False if we should try next URL
183+
"""
184+
if response.status_code == 200:
185+
try:
186+
content = await response.aread()
187+
metadata = ProtectedResourceMetadata.model_validate_json(content)
188+
return metadata
189+
190+
except ValidationError: # pragma: no cover
191+
# Invalid metadata - try next URL
192+
return None
193+
else:
194+
# Not found - try next URL in fallback chain
195+
return None
196+
197+
198+
async def handle_auth_metadata_response(response: Response) -> tuple[bool, OAuthMetadata | None]:
199+
if response.status_code == 200:
200+
try:
201+
content = await response.aread()
202+
asm = OAuthMetadata.model_validate_json(content)
203+
return True, asm
204+
except ValidationError: # pragma: no cover
205+
return True, None
206+
elif response.status_code < 400 or response.status_code >= 500:
207+
return False, None # Non-4XX error, stop trying
208+
return True, None
209+
210+
211+
def create_oauth_metadata_request(url: str) -> Request:
212+
return Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
213+
214+
215+
def create_client_registration_request(
216+
auth_server_metadata: OAuthMetadata | None, client_metadata: OAuthClientMetadata, auth_base_url: str
217+
) -> Request:
218+
"""Build registration request or skip if already registered."""
219+
220+
if auth_server_metadata and auth_server_metadata.registration_endpoint:
221+
registration_url = str(auth_server_metadata.registration_endpoint)
222+
else:
223+
registration_url = urljoin(auth_base_url, "/register")
224+
225+
registration_data = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)
226+
227+
return Request("POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"})
228+
229+
230+
async def handle_registration_response(response: Response) -> OAuthClientInformationFull:
231+
"""Handle registration response."""
232+
if response.status_code not in (200, 201):
233+
await response.aread()
234+
raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}")
235+
236+
try:
237+
content = await response.aread()
238+
client_info = OAuthClientInformationFull.model_validate_json(content)
239+
return client_info
240+
# self.context.client_info = client_info
241+
# await self.context.storage.set_client_info(client_info)
242+
except ValidationError as e: # pragma: no cover
243+
raise OAuthRegistrationError(f"Invalid registration response: {e}")
244+
245+
246+
async def handle_token_response_scopes(
247+
response: Response,
248+
) -> OAuthToken:
249+
"""Parse and validate token response with optional scope validation.
250+
251+
Parses token response JSON. Callers should check response.status_code before calling.
252+
253+
Args:
254+
response: HTTP response from token endpoint (status already checked by caller)
255+
256+
Returns:
257+
Validated OAuthToken model
258+
259+
Raises:
260+
OAuthTokenError: If response JSON is invalid
261+
"""
262+
try:
263+
content = await response.aread()
264+
token_response = OAuthToken.model_validate_json(content)
265+
return token_response
266+
except ValidationError as e: # pragma: no cover
267+
raise OAuthTokenError(f"Invalid token response: {e}")

src/mcp/shared/auth.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ class OAuthMetadata(BaseModel):
130130
introspection_endpoint_auth_methods_supported: list[str] | None = None
131131
introspection_endpoint_auth_signing_alg_values_supported: list[str] | None = None
132132
code_challenge_methods_supported: list[str] | None = None
133+
client_id_metadata_document_supported: bool | None = None
133134

134135

135136
class ProtectedResourceMetadata(BaseModel):

src/mcp/shared/auth_utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
"""Utilities for OAuth 2.0 Resource Indicators (RFC 8707)."""
1+
"""Utilities for OAuth 2.0 Resource Indicators (RFC 8707) and PKCE (RFC 7636)."""
22

3+
import time
34
from urllib.parse import urlparse, urlsplit, urlunsplit
45

56
from pydantic import AnyUrl, HttpUrl
@@ -67,3 +68,18 @@ def check_resource_allowed(requested_resource: str, configured_resource: str) ->
6768
configured_path += "/"
6869

6970
return requested_path.startswith(configured_path)
71+
72+
73+
def calculate_token_expiry(expires_in: int | str | None) -> float | None:
74+
"""Calculate token expiry timestamp from expires_in seconds.
75+
76+
Args:
77+
expires_in: Seconds until token expiration (may be string from some servers)
78+
79+
Returns:
80+
Unix timestamp when token expires, or None if no expiry specified
81+
"""
82+
if expires_in is None:
83+
return None # pragma: no cover
84+
# Defensive: handle servers that return expires_in as string
85+
return time.time() + int(expires_in)

0 commit comments

Comments
 (0)