Skip to content

Commit 738f2c5

Browse files
authored
Merge pull request #53 from sacha-development-stuff/codex/fix-merge-conflicts-in-oauth2.py-and-test_auth.py
Fix OAuth2 merge conflicts in auth flow
2 parents f049650 + b0674ab commit 738f2c5

File tree

2 files changed

+37
-117
lines changed

2 files changed

+37
-117
lines changed

src/mcp/client/auth/oauth2.py

Lines changed: 30 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
handle_token_response_scopes,
3636
)
3737
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
38+
from mcp.types import LATEST_PROTOCOL_VERSION
3839
from mcp.shared.auth import (
3940
OAuthClientInformationFull,
4041
OAuthClientMetadata,
@@ -341,35 +342,11 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
341342
logger.debug(f"Protected resource metadata not found at {response.request.url}, trying next URL")
342343
return False
343344
else:
344-
#<<<<<<< main
345-
# Priority 3: Omit scope parameter
346-
self.context.client_metadata.scope = None
347-
348-
# Discovery and registration helpers provided by BaseOAuthProvider
349-
#=======
350345
# Other error - fail immediately
351346
raise OAuthFlowError(
352347
f"Protected Resource Metadata request failed: {response.status_code}"
353348
) # pragma: no cover
354349

355-
async def _register_client(self) -> httpx.Request | None:
356-
"""Build registration request or skip if already registered."""
357-
if self.context.client_info:
358-
return None
359-
360-
if self.context.oauth_metadata and self.context.oauth_metadata.registration_endpoint:
361-
registration_url = str(self.context.oauth_metadata.registration_endpoint) # pragma: no cover
362-
else:
363-
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
364-
registration_url = urljoin(auth_base_url, "/register")
365-
366-
registration_data = self.context.client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)
367-
368-
return httpx.Request(
369-
"POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"}
370-
)
371-
#>>>>>>> main
372-
373350
async def _perform_authorization(self) -> httpx.Request:
374351
"""Perform the authorization flow."""
375352
auth_code, code_verifier = await self._perform_authorization_code_grant()
@@ -473,21 +450,10 @@ async def _exchange_token_authorization_code(
473450

474451
async def _handle_token_response(self, response: httpx.Response) -> None:
475452
"""Handle token exchange response."""
476-
#<<<<<<< main
477-
if response.status_code != 200: # pragma: no cover
478-
body = response.content or await response.aread()
479-
body = body.decode("utf-8")
480-
raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body}")
481-
482-
try:
483-
content = response.content or await response.aread()
484-
token_response = OAuthToken.model_validate_json(content)
485-
#=======
486453
if response.status_code != 200:
487454
body = await response.aread() # pragma: no cover
488455
body_text = body.decode("utf-8") # pragma: no cover
489456
raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body_text}") # pragma: no cover
490-
#>>>>>>> main
491457

492458
# Parse and validate response with scope validation
493459
token_response = await handle_token_response_scopes(response)
@@ -557,14 +523,6 @@ def _add_auth_header(self, request: httpx.Request) -> None:
557523
if self.context.current_tokens and self.context.current_tokens.access_token: # pragma: no branch
558524
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"
559525

560-
#<<<<<<< main
561-
#=======
562-
async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
563-
content = await response.aread()
564-
metadata = OAuthMetadata.model_validate_json(content)
565-
self.context.oauth_metadata = metadata
566-
567-
#>>>>>>> main
568526
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
569527
"""HTTPX auth flow integration."""
570528
async with self.context.lock:
@@ -593,6 +551,13 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
593551
try:
594552
# OAuth flow must be inline due to generator constraints
595553
www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response)
554+
www_auth_scope = extract_scope_from_www_auth(response)
555+
556+
# Reset discovery context before attempting new discovery sequence
557+
self.context.protected_resource_metadata = None
558+
self.context.auth_server_url = None
559+
self.context.oauth_metadata = None
560+
self._metadata = None
596561

597562
# Step 1: Discover protected resource metadata (SEP-985 with fallback support)
598563
prm_discovery_urls = build_protected_resource_metadata_discovery_urls(
@@ -601,84 +566,58 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
601566

602567
for url in prm_discovery_urls: # pragma: no branch
603568
discovery_request = create_oauth_metadata_request(url)
569+
discovery_response = yield discovery_request
604570

605-
discovery_response = yield discovery_request # sending request
606-
607-
#<<<<<<< main
608-
# Step 3: Discover OAuth metadata (with fallback for legacy servers)
609-
discovery_urls = self._get_discovery_urls(self.context.auth_server_url or self.context.server_url)
610-
for url in discovery_urls:
611-
oauth_metadata_request = self._create_oauth_metadata_request(url)
612-
oauth_metadata_response = yield oauth_metadata_request
613-
614-
if oauth_metadata_response.status_code == 200:
615-
try:
616-
await self._handle_oauth_metadata_response(oauth_metadata_response)
617-
self.context.oauth_metadata = self._metadata
618-
break
619-
except ValidationError: # pragma: no cover
620-
continue
621-
elif oauth_metadata_response.status_code < 400 or oauth_metadata_response.status_code >= 500:
622-
break # Non-4XX error, stop trying
623-
624-
# Step 4: Register client if needed
625-
registration_request = self._create_registration_request(self._metadata)
626-
if registration_request:
627-
registration_response = yield registration_request
628-
await self._handle_registration_response(registration_response)
629-
self.context.client_info = self._client_info
630-
#=======
631571
prm = await handle_protected_resource_response(discovery_response)
632572
if prm:
633573
self.context.protected_resource_metadata = prm
634-
635-
# todo: try all authorization_servers to find the OASM
636-
assert (
637-
len(prm.authorization_servers) > 0
638-
) # this is always true as authorization_servers has a min length of 1
639-
640-
self.context.auth_server_url = str(prm.authorization_servers[0])
574+
if prm.authorization_servers: # pragma: no branch
575+
self.context.auth_server_url = str(prm.authorization_servers[0])
641576
break
642-
else:
643-
logger.debug(f"Protected resource metadata discovery failed: {url}")
644577

578+
logger.debug(f"Protected resource metadata discovery failed: {url}")
579+
580+
# Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
645581
asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(
646582
self.context.auth_server_url, self.context.server_url
647583
)
648584

649-
# Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
650-
for url in asm_discovery_urls: # pragma: no cover
585+
authorization_metadata: OAuthMetadata | None = None
586+
for url in asm_discovery_urls: # pragma: no branch
651587
oauth_metadata_request = create_oauth_metadata_request(url)
652588
oauth_metadata_response = yield oauth_metadata_request
653589

654590
ok, asm = await handle_auth_metadata_response(oauth_metadata_response)
655591
if not ok:
656592
break
657-
if ok and asm:
658-
self.context.oauth_metadata = asm
593+
if asm:
594+
authorization_metadata = asm
659595
break
660-
else:
661-
logger.debug(f"OAuth metadata discovery failed: {url}")
596+
597+
logger.debug(f"OAuth metadata discovery failed: {url}")
598+
599+
if authorization_metadata:
600+
self.context.oauth_metadata = authorization_metadata
601+
self._metadata = authorization_metadata
662602

663603
# Step 3: Apply scope selection strategy
664604
self.context.client_metadata.scope = get_client_metadata_scopes(
665-
www_auth_resource_metadata_url,
605+
www_auth_scope,
666606
self.context.protected_resource_metadata,
667607
self.context.oauth_metadata,
668608
)
669609

670610
# Step 4: Register client if needed
671-
registration_request = create_client_registration_request(
672-
self.context.oauth_metadata,
673-
self.context.client_metadata,
674-
self.context.get_authorization_base_url(self.context.server_url),
675-
)
676611
if not self.context.client_info:
612+
registration_request = create_client_registration_request(
613+
self.context.oauth_metadata,
614+
self.context.client_metadata,
615+
self.context.get_authorization_base_url(self.context.server_url),
616+
)
677617
registration_response = yield registration_request
678618
client_information = await handle_registration_response(registration_response)
679619
self.context.client_info = client_information
680620
await self.context.storage.set_client_info(client_information)
681-
#>>>>>>> main
682621

683622
# Step 5: Perform authorization and complete token exchange
684623
token_response = yield await self._perform_authorization()

tests/client/test_auth.py

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,12 @@
1313
from inline_snapshot import Is, snapshot
1414
from pydantic import AnyHttpUrl, AnyUrl
1515

16-
#<<<<<<< main
1716
from mcp.client.auth import (
1817
ClientCredentialsProvider,
1918
OAuthClientProvider,
2019
PKCEParameters,
2120
TokenExchangeProvider,
2221
)
23-
from mcp.shared.auth import (
24-
OAuthClientInformationFull,
25-
OAuthClientMetadata,
26-
OAuthMetadata,
27-
OAuthToken,
28-
ProtectedResourceMetadata,
29-
)
30-
#=======
31-
from mcp.client.auth import OAuthClientProvider, PKCEParameters
3222
from mcp.client.auth.utils import (
3323
build_oauth_authorization_server_metadata_discovery_urls,
3424
build_protected_resource_metadata_discovery_urls,
@@ -39,8 +29,13 @@
3929
get_client_metadata_scopes,
4030
handle_registration_response,
4131
)
42-
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken, ProtectedResourceMetadata
43-
#>>>>>>> main
32+
from mcp.shared.auth import (
33+
OAuthClientInformationFull,
34+
OAuthClientMetadata,
35+
OAuthMetadata,
36+
OAuthToken,
37+
ProtectedResourceMetadata,
38+
)
4439

4540

4641
class MockTokenStorage:
@@ -556,23 +551,9 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl
556551
return_value=("test_auth_code", "test_code_verifier")
557552
)
558553

559-
#<<<<<<< main
560-
# Next request should fall back to legacy behavior: register then obtain token
561-
registration_request = await auth_flow.asend(oauth_metadata_response_3)
562-
assert str(registration_request.url) == "https://api.example.com/register"
563-
assert registration_request.method == "POST"
564-
565-
registration_response = httpx.Response(
566-
200,
567-
content=b'{"client_id":"c","redirect_uris":["http://localhost:3030/callback"]}',
568-
request=registration_request,
569-
)
570-
token_request = await auth_flow.asend(registration_response)
571-
#=======
572554
# All path-based URLs failed, flow continues with default endpoints
573555
# Next request should be token exchange using MCP server base URL (fallback when OAuth metadata not found)
574556
token_request = await auth_flow.asend(oauth_metadata_response_3)
575-
#>>>>>>> main
576557
assert str(token_request.url) == "https://api.example.com/token"
577558
assert token_request.method == "POST"
578559

0 commit comments

Comments
 (0)