1717from starlette .requests import Request
1818from starlette .responses import RedirectResponse
1919
20+ from fastapi_sso .pkce import get_pkce_challenge_pair
21+ from fastapi_sso .state import generate_random_state
22+
2023if sys .version_info >= (3 , 8 ):
2124 from typing import TypedDict
2225else :
@@ -63,6 +66,10 @@ class SSOBase:
6366 redirect_uri : Optional [Union [pydantic .AnyHttpUrl , str ]] = NotImplemented
6467 scope : List [str ] = NotImplemented
6568 additional_headers : Optional [Dict [str , Any ]] = None
69+ uses_pkce : bool = False
70+ requires_state : bool = False
71+
72+ _pkce_challenge_length : int = 96
6673
6774 def __init__ (
6875 self ,
@@ -79,6 +86,7 @@ def __init__(
7986 self .redirect_uri : Optional [Union [pydantic .AnyHttpUrl , str ]] = redirect_uri
8087 self .allow_insecure_http : bool = allow_insecure_http
8188 self ._oauth_client : Optional [WebApplicationClient ] = None
89+ self ._generated_state : Optional [str ] = None
8290
8391 if self .allow_insecure_http :
8492 os .environ ["OAUTHLIB_INSECURE_TRANSPORT" ] = "1"
@@ -96,6 +104,9 @@ def __init__(
96104 self ._refresh_token : Optional [str ] = None
97105 self ._id_token : Optional [str ] = None
98106 self ._state : Optional [str ] = None
107+ self ._pkce_code_challenge : Optional [str ] = None
108+ self ._pkce_code_verifier : Optional [str ] = None
109+ self ._pkce_challenge_method = "S256"
99110
100111 @property
101112 def state (self ) -> Optional [str ]:
@@ -236,8 +247,26 @@ async def get_login_url(
236247 redirect_uri = redirect_uri or self .redirect_uri
237248 if redirect_uri is None :
238249 raise ValueError ("redirect_uri must be provided, either at construction or request time" )
250+ if self .uses_pkce and not all ((self ._pkce_code_verifier , self ._pkce_code_challenge )):
251+ warnings .warn (
252+ f"{ self .__class__ .__name__ !r} uses PKCE and no code was generated yet. "
253+ "Use SSO class as a context manager to get rid of this warning and possible errors."
254+ )
255+ if self .requires_state and not state :
256+ if self ._generated_state is None :
257+ warnings .warn (
258+ f"{ self .__class__ .__name__ !r} requires state in the request but none was provided nor "
259+ "generated automatically. Use SSO as a context manager. The login process will most probably fail."
260+ )
261+ state = self ._generated_state
239262 request_uri = self .oauth_client .prepare_request_uri (
240- await self .authorization_endpoint , redirect_uri = redirect_uri , state = state , scope = self .scope , ** params
263+ await self .authorization_endpoint ,
264+ redirect_uri = redirect_uri ,
265+ state = state ,
266+ scope = self .scope ,
267+ code_challenge = self ._pkce_code_challenge ,
268+ code_challenge_method = self ._pkce_challenge_method ,
269+ ** params ,
241270 )
242271 return request_uri
243272
@@ -259,8 +288,12 @@ async def get_login_redirect(
259288 Returns:
260289 RedirectResponse: A Starlette response directing to the login page of the OAuth SSO provider.
261290 """
291+ if self .requires_state and not state :
292+ state = self ._generated_state
262293 login_uri = await self .get_login_url (redirect_uri = redirect_uri , params = params , state = state )
263294 response = RedirectResponse (login_uri , 303 )
295+ if self .uses_pkce :
296+ response .set_cookie ("pkce_code_verifier" , str (self ._pkce_code_verifier ))
264297 return response
265298
266299 async def verify_and_process (
@@ -291,14 +324,31 @@ async def verify_and_process(
291324 if code is None :
292325 raise SSOLoginError (400 , "'code' parameter was not found in callback request" )
293326 self ._state = request .query_params .get ("state" )
327+ pkce_code_verifier : Optional [str ] = None
328+ if self .uses_pkce :
329+ pkce_code_verifier = request .cookies .get ("pkce_code_verifier" )
330+ if pkce_code_verifier is None :
331+ warnings .warn (
332+ "PKCE code verifier was not found in the request Cookie. This will probably lead to a login error."
333+ )
294334 return await self .process_login (
295- code , request , params = params , additional_headers = headers , redirect_uri = redirect_uri
335+ code ,
336+ request ,
337+ params = params ,
338+ additional_headers = headers ,
339+ redirect_uri = redirect_uri ,
340+ pkce_code_verifier = pkce_code_verifier ,
296341 )
297342
298343 def __enter__ (self ) -> "SSOBase" :
299344 self ._oauth_client = None
300345 self ._refresh_token = None
301346 self ._id_token = None
347+ self ._state = None
348+ if self .requires_state :
349+ self ._generated_state = generate_random_state ()
350+ if self .uses_pkce :
351+ self ._pkce_code_verifier , self ._pkce_code_challenge = get_pkce_challenge_pair (self ._pkce_challenge_length )
302352 return self
303353
304354 def __exit__ (
@@ -321,6 +371,7 @@ async def process_login(
321371 params : Optional [Dict [str , Any ]] = None ,
322372 additional_headers : Optional [Dict [str , Any ]] = None ,
323373 redirect_uri : Optional [str ] = None ,
374+ pkce_code_verifier : Optional [str ] = None ,
324375 ) -> Optional [OpenID ]:
325376 """
326377 Processes login from the callback endpoint to verify the user and request user info endpoint.
@@ -332,6 +383,7 @@ async def process_login(
332383 params (Optional[Dict[str, Any]]): Additional query parameters to pass to the provider.
333384 additional_headers (Optional[Dict[str, Any]]): Additional headers to be added to all requests.
334385 redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
386+ pkce_code_verifier (Optional[str]): A PKCE code verifier sent to the server to verify the login request.
335387
336388 Raises:
337389 ReusedOauthClientWarning: If the SSO object is reused, which is not safe and caused security issues.
@@ -379,8 +431,12 @@ async def process_login(
379431 headers .update (additional_headers )
380432
381433 auth = httpx .BasicAuth (self .client_id , self .client_secret )
434+
435+ if pkce_code_verifier :
436+ params .update ({"code_verifier" : pkce_code_verifier })
437+
382438 async with httpx .AsyncClient () as session :
383- response = await session .post (token_url , headers = headers , content = body , auth = auth )
439+ response = await session .post (token_url , headers = headers , content = body , auth = auth , params = params )
384440 content = response .json ()
385441 self ._refresh_token = content .get ("refresh_token" )
386442 self ._id_token = content .get ("id_token" )
0 commit comments