77import sys
88import warnings
99from types import TracebackType
10- from typing import Any , ClassVar , Dict , List , Literal , Optional , Type , TypedDict , TypeVar , Union , overload
10+ from typing import Any , ClassVar , Literal , Optional , TypedDict , TypeVar , Union , overload
1111
1212import httpx
13+ import jwt
1314import pydantic
1415from oauthlib .oauth2 import WebApplicationClient
1516from starlette .exceptions import HTTPException
3334P = ParamSpec ("P" )
3435
3536
37+ def _decode_id_token (id_token : str , verify : bool = False ) -> dict :
38+ return jwt .decode (id_token , options = {"verify_signature" : verify })
39+
40+
3641class DiscoveryDocument (TypedDict ):
3742 """Discovery document."""
3843
@@ -95,10 +100,11 @@ class SSOBase:
95100 client_id : str = NotImplemented
96101 client_secret : str = NotImplemented
97102 redirect_uri : Optional [Union [pydantic .AnyHttpUrl , str ]] = NotImplemented
98- scope : ClassVar [List [str ]] = []
99- additional_headers : ClassVar [Optional [Dict [str , Any ]]] = None
103+ scope : ClassVar [list [str ]] = []
104+ additional_headers : ClassVar [Optional [dict [str , Any ]]] = None
100105 uses_pkce : bool = False
101106 requires_state : bool = False
107+ use_id_token_for_user_info : ClassVar [bool ] = False
102108
103109 _pkce_challenge_length : int = 96
104110
@@ -109,7 +115,7 @@ def __init__(
109115 redirect_uri : Optional [Union [pydantic .AnyHttpUrl , str ]] = None ,
110116 allow_insecure_http : bool = False ,
111117 use_state : bool = False ,
112- scope : Optional [List [str ]] = None ,
118+ scope : Optional [list [str ]] = None ,
113119 ):
114120 """Base class (mixin) for all SSO providers."""
115121 self .client_id : str = client_id
@@ -224,6 +230,18 @@ async def openid_from_response(self, response: dict, session: Optional[httpx.Asy
224230 """
225231 raise NotImplementedError (f"Provider { self .provider } not supported" )
226232
233+ async def openid_from_token (self , id_token : dict , session : Optional [httpx .AsyncClient ] = None ) -> OpenID :
234+ """Converts an ID token from the provider's token endpoint to an OpenID object.
235+
236+ Args:
237+ id_token (dict): The id token data retrieved from the token endpoint.
238+ session: (Optional[httpx.AsyncClient]): The HTTPX AsyncClient session.
239+
240+ Returns:
241+ OpenID: The user information in a standardized format.
242+ """
243+ raise NotImplementedError (f"Provider { self .provider } not supported" )
244+
227245 async def get_discovery_document (self ) -> DiscoveryDocument :
228246 """Retrieves the discovery document containing useful URLs.
229247
@@ -257,14 +275,14 @@ async def get_login_url(
257275 self ,
258276 * ,
259277 redirect_uri : Optional [Union [pydantic .AnyHttpUrl , str ]] = None ,
260- params : Optional [Dict [str , Any ]] = None ,
278+ params : Optional [dict [str , Any ]] = None ,
261279 state : Optional [str ] = None ,
262280 ) -> str :
263281 """Generates and returns the prepared login URL.
264282
265283 Args:
266284 redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
267- params (Optional[Dict [str, Any]]): Additional query parameters to add to the login request.
285+ params (Optional[dict [str, Any]]): Additional query parameters to add to the login request.
268286 state (Optional[str]): The state parameter for the OAuth 2.0 authorization request.
269287
270288 Raises:
@@ -304,14 +322,14 @@ async def get_login_redirect(
304322 self ,
305323 * ,
306324 redirect_uri : Optional [str ] = None ,
307- params : Optional [Dict [str , Any ]] = None ,
325+ params : Optional [dict [str , Any ]] = None ,
308326 state : Optional [str ] = None ,
309327 ) -> RedirectResponse :
310328 """Constructs and returns a redirect response to the login page of OAuth SSO provider.
311329
312330 Args:
313331 redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
314- params (Optional[Dict [str, Any]]): Additional query parameters to add to the login request.
332+ params (Optional[dict [str, Any]]): Additional query parameters to add to the login request.
315333 state (Optional[str]): The state parameter for the OAuth 2.0 authorization request.
316334
317335 Returns:
@@ -330,8 +348,8 @@ async def verify_and_process(
330348 self ,
331349 request : Request ,
332350 * ,
333- params : Optional [Dict [str , Any ]] = None ,
334- headers : Optional [Dict [str , Any ]] = None ,
351+ params : Optional [dict [str , Any ]] = None ,
352+ headers : Optional [dict [str , Any ]] = None ,
335353 redirect_uri : Optional [str ] = None ,
336354 convert_response : Literal [True ] = True ,
337355 ) -> Optional [OpenID ]: ...
@@ -341,28 +359,28 @@ async def verify_and_process(
341359 self ,
342360 request : Request ,
343361 * ,
344- params : Optional [Dict [str , Any ]] = None ,
345- headers : Optional [Dict [str , Any ]] = None ,
362+ params : Optional [dict [str , Any ]] = None ,
363+ headers : Optional [dict [str , Any ]] = None ,
346364 redirect_uri : Optional [str ] = None ,
347365 convert_response : Literal [False ],
348- ) -> Optional [Dict [str , Any ]]: ...
366+ ) -> Optional [dict [str , Any ]]: ...
349367
350368 @requires_async_context
351369 async def verify_and_process (
352370 self ,
353371 request : Request ,
354372 * ,
355- params : Optional [Dict [str , Any ]] = None ,
356- headers : Optional [Dict [str , Any ]] = None ,
373+ params : Optional [dict [str , Any ]] = None ,
374+ headers : Optional [dict [str , Any ]] = None ,
357375 redirect_uri : Optional [str ] = None ,
358376 convert_response : Union [Literal [True ], Literal [False ]] = True ,
359- ) -> Union [Optional [OpenID ], Optional [Dict [str , Any ]]]:
377+ ) -> Union [Optional [OpenID ], Optional [dict [str , Any ]]]:
360378 """Processes the login given a FastAPI (Starlette) Request object. This should be used for the /callback path.
361379
362380 Args:
363381 request (Request): FastAPI or Starlette request object.
364- params (Optional[Dict [str, Any]]): Additional query parameters to pass to the provider.
365- headers (Optional[Dict [str, Any]]): Additional headers to pass to the provider.
382+ params (Optional[dict [str, Any]]): Additional query parameters to pass to the provider.
383+ headers (Optional[dict [str, Any]]): Additional headers to pass to the provider.
366384 redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
367385 convert_response (bool): If True, userinfo response is converted to OpenID object.
368386
@@ -371,7 +389,7 @@ async def verify_and_process(
371389
372390 Returns:
373391 Optional[OpenID]: User information as OpenID instance (if convert_response == True)
374- Optional[Dict [str, Any]]: The original JSON response from the API.
392+ Optional[dict [str, Any]]: The original JSON response from the API.
375393 """
376394 headers = headers or {}
377395 code = request .query_params .get ("code" )
@@ -433,7 +451,7 @@ async def __aenter__(self) -> "SSOBase":
433451
434452 async def __aexit__ (
435453 self ,
436- _exc_type : Optional [Type [BaseException ]],
454+ _exc_type : Optional [type [BaseException ]],
437455 _exc_val : Optional [BaseException ],
438456 _exc_tb : Optional [TracebackType ],
439457 ) -> None :
@@ -442,14 +460,14 @@ async def __aexit__(
442460
443461 def __exit__ (
444462 self ,
445- _exc_type : Optional [Type [BaseException ]],
463+ _exc_type : Optional [type [BaseException ]],
446464 _exc_val : Optional [BaseException ],
447465 _exc_tb : Optional [TracebackType ],
448466 ) -> None :
449467 return None
450468
451469 @property
452- def _extra_query_params (self ) -> Dict :
470+ def _extra_query_params (self ) -> dict :
453471 return {}
454472
455473 @overload
@@ -458,8 +476,8 @@ async def process_login(
458476 code : str ,
459477 request : Request ,
460478 * ,
461- params : Optional [Dict [str , Any ]] = None ,
462- additional_headers : Optional [Dict [str , Any ]] = None ,
479+ params : Optional [dict [str , Any ]] = None ,
480+ additional_headers : Optional [dict [str , Any ]] = None ,
463481 redirect_uri : Optional [str ] = None ,
464482 pkce_code_verifier : Optional [str ] = None ,
465483 convert_response : Literal [True ] = True ,
@@ -471,33 +489,33 @@ async def process_login(
471489 code : str ,
472490 request : Request ,
473491 * ,
474- params : Optional [Dict [str , Any ]] = None ,
475- additional_headers : Optional [Dict [str , Any ]] = None ,
492+ params : Optional [dict [str , Any ]] = None ,
493+ additional_headers : Optional [dict [str , Any ]] = None ,
476494 redirect_uri : Optional [str ] = None ,
477495 pkce_code_verifier : Optional [str ] = None ,
478496 convert_response : Literal [False ],
479- ) -> Optional [Dict [str , Any ]]: ...
497+ ) -> Optional [dict [str , Any ]]: ...
480498
481499 @requires_async_context
482500 async def process_login (
483501 self ,
484502 code : str ,
485503 request : Request ,
486504 * ,
487- params : Optional [Dict [str , Any ]] = None ,
488- additional_headers : Optional [Dict [str , Any ]] = None ,
505+ params : Optional [dict [str , Any ]] = None ,
506+ additional_headers : Optional [dict [str , Any ]] = None ,
489507 redirect_uri : Optional [str ] = None ,
490508 pkce_code_verifier : Optional [str ] = None ,
491509 convert_response : Union [Literal [True ], Literal [False ]] = True ,
492- ) -> Union [Optional [OpenID ], Optional [Dict [str , Any ]]]:
510+ ) -> Union [Optional [OpenID ], Optional [dict [str , Any ]]]:
493511 """Processes login from the callback endpoint to verify the user and request user info endpoint.
494512 It's a lower-level method, typically, you should use `verify_and_process` instead.
495513
496514 Args:
497515 code (str): The authorization code.
498516 request (Request): FastAPI or Starlette request object.
499- params (Optional[Dict [str, Any]]): Additional query parameters to pass to the provider.
500- additional_headers (Optional[Dict [str, Any]]): Additional headers to be added to all requests.
517+ params (Optional[dict [str, Any]]): Additional query parameters to pass to the provider.
518+ additional_headers (Optional[dict [str, Any]]): Additional headers to be added to all requests.
501519 redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
502520 pkce_code_verifier (Optional[str]): A PKCE code verifier sent to the server to verify the login request.
503521 convert_response (bool): If True, userinfo response is converted to OpenID object.
@@ -507,7 +525,7 @@ async def process_login(
507525
508526 Returns:
509527 Optional[OpenID]: User information in OpenID format if the login was successful (convert_response == True).
510- Optional[Dict [str, Any]]: Original userinfo API endpoint response.
528+ Optional[dict [str, Any]]: Original userinfo API endpoint response.
511529 """
512530 if self ._oauth_client is not None : # pragma: no cover
513531 self ._oauth_client = None
@@ -565,5 +583,9 @@ async def process_login(
565583 response = await session .get (uri )
566584 content = response .json ()
567585 if convert_response :
586+ if self .use_id_token_for_user_info :
587+ if not self ._id_token :
588+ raise SSOLoginError (401 , f"Provider { self .provider !r} did not return id token." )
589+ return await self .openid_from_token (_decode_id_token (self ._id_token ), session )
568590 return await self .openid_from_response (content , session )
569591 return content
0 commit comments