Skip to content

Commit fd23647

Browse files
authored
feat: add TwitterSSO to support Twitter (X) login (#139)
* feat: add TwitterSSO to support Twitter (X) login
1 parent 2857ab7 commit fd23647

21 files changed

+347
-51
lines changed

docs/generate_reference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import TYPE_CHECKING
55

66
if TYPE_CHECKING:
7-
import mkdocs.config.defaults
7+
import mkdocs.config.defaults # pragma: no cover
88

99

1010
SKIPPED_MODULES = ("fastapi_sso.sso", "fastapi_sso")

examples/twitter.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
"""Twitter (X) Login Example
2+
"""
3+
4+
import os
5+
import uvicorn
6+
from fastapi import FastAPI, Request
7+
from fastapi_sso.sso.twitter import TwitterSSO
8+
9+
CLIENT_ID = os.environ["CLIENT_ID"]
10+
CLIENT_SECRET = os.environ["CLIENT_SECRET"]
11+
12+
app = FastAPI()
13+
14+
sso = TwitterSSO(
15+
client_id=CLIENT_ID,
16+
client_secret=CLIENT_SECRET,
17+
redirect_uri="http://127.0.0.1:5000/auth/callback",
18+
allow_insecure_http=True,
19+
)
20+
21+
22+
@app.get("/auth/login")
23+
async def auth_init():
24+
"""Initialize auth and redirect"""
25+
with sso:
26+
return await sso.get_login_redirect()
27+
28+
29+
@app.get("/auth/callback")
30+
async def auth_callback(request: Request):
31+
"""Verify login"""
32+
with sso:
33+
user = await sso.verify_and_process(request)
34+
return user
35+
36+
37+
if __name__ == "__main__":
38+
uvicorn.run(app="examples.twitter:app", host="127.0.0.1", port=5000)

fastapi_sso/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@
1616
from .sso.naver import NaverSSO
1717
from .sso.notion import NotionSSO
1818
from .sso.spotify import SpotifySSO
19+
from .sso.twitter import TwitterSSO

fastapi_sso/pkce.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
"""PKCE-related helper functions"""
2+
3+
import base64
4+
import hashlib
5+
import os
6+
from typing import Tuple
7+
8+
9+
def get_code_verifier(length: int = 96) -> str:
10+
"""Get code verifier for PKCE challenge"""
11+
length = max(43, min(length, 128))
12+
bytes_length = int(length * 3 / 4)
13+
return base64.urlsafe_b64encode(os.urandom(bytes_length)).decode("utf-8").replace("=", "")[:length]
14+
15+
16+
def get_pkce_challenge_pair(verifier_length: int = 96) -> Tuple[str, str]:
17+
"""Get tuple of (verifier, challenge) for PKCE challenge."""
18+
code_verifier = get_code_verifier(verifier_length)
19+
code_challenge = (
20+
base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode("utf-8")).digest())
21+
.decode("utf-8")
22+
.replace("=", "")
23+
)
24+
25+
return (code_verifier, code_challenge)

fastapi_sso/sso/base.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from starlette.requests import Request
1818
from starlette.responses import RedirectResponse
1919

20+
from fastapi_sso.pkce import get_pkce_challenge_pair
21+
from fastapi_sso.state import generate_random_state
22+
2023
if sys.version_info >= (3, 8):
2124
from typing import TypedDict
2225
else:
@@ -63,6 +66,10 @@ class SSOBase:
6366
redirect_uri: Optional[Union[pydantic.AnyHttpUrl, str]] = NotImplemented
6467
scope: List[str] = NotImplemented
6568
additional_headers: Optional[Dict[str, Any]] = None
69+
uses_pkce: bool = False
70+
requires_state: bool = False
71+
72+
_pkce_challenge_length: int = 96
6673

6774
def __init__(
6875
self,
@@ -79,6 +86,7 @@ def __init__(
7986
self.redirect_uri: Optional[Union[pydantic.AnyHttpUrl, str]] = redirect_uri
8087
self.allow_insecure_http: bool = allow_insecure_http
8188
self._oauth_client: Optional[WebApplicationClient] = None
89+
self._generated_state: Optional[str] = None
8290

8391
if self.allow_insecure_http:
8492
os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1"
@@ -96,6 +104,9 @@ def __init__(
96104
self._refresh_token: Optional[str] = None
97105
self._id_token: Optional[str] = None
98106
self._state: Optional[str] = None
107+
self._pkce_code_challenge: Optional[str] = None
108+
self._pkce_code_verifier: Optional[str] = None
109+
self._pkce_challenge_method = "S256"
99110

100111
@property
101112
def state(self) -> Optional[str]:
@@ -236,8 +247,26 @@ async def get_login_url(
236247
redirect_uri = redirect_uri or self.redirect_uri
237248
if redirect_uri is None:
238249
raise ValueError("redirect_uri must be provided, either at construction or request time")
250+
if self.uses_pkce and not all((self._pkce_code_verifier, self._pkce_code_challenge)):
251+
warnings.warn(
252+
f"{self.__class__.__name__!r} uses PKCE and no code was generated yet. "
253+
"Use SSO class as a context manager to get rid of this warning and possible errors."
254+
)
255+
if self.requires_state and not state:
256+
if self._generated_state is None:
257+
warnings.warn(
258+
f"{self.__class__.__name__!r} requires state in the request but none was provided nor "
259+
"generated automatically. Use SSO as a context manager. The login process will most probably fail."
260+
)
261+
state = self._generated_state
239262
request_uri = self.oauth_client.prepare_request_uri(
240-
await self.authorization_endpoint, redirect_uri=redirect_uri, state=state, scope=self.scope, **params
263+
await self.authorization_endpoint,
264+
redirect_uri=redirect_uri,
265+
state=state,
266+
scope=self.scope,
267+
code_challenge=self._pkce_code_challenge,
268+
code_challenge_method=self._pkce_challenge_method,
269+
**params,
241270
)
242271
return request_uri
243272

@@ -259,8 +288,12 @@ async def get_login_redirect(
259288
Returns:
260289
RedirectResponse: A Starlette response directing to the login page of the OAuth SSO provider.
261290
"""
291+
if self.requires_state and not state:
292+
state = self._generated_state
262293
login_uri = await self.get_login_url(redirect_uri=redirect_uri, params=params, state=state)
263294
response = RedirectResponse(login_uri, 303)
295+
if self.uses_pkce:
296+
response.set_cookie("pkce_code_verifier", str(self._pkce_code_verifier))
264297
return response
265298

266299
async def verify_and_process(
@@ -291,14 +324,31 @@ async def verify_and_process(
291324
if code is None:
292325
raise SSOLoginError(400, "'code' parameter was not found in callback request")
293326
self._state = request.query_params.get("state")
327+
pkce_code_verifier: Optional[str] = None
328+
if self.uses_pkce:
329+
pkce_code_verifier = request.cookies.get("pkce_code_verifier")
330+
if pkce_code_verifier is None:
331+
warnings.warn(
332+
"PKCE code verifier was not found in the request Cookie. This will probably lead to a login error."
333+
)
294334
return await self.process_login(
295-
code, request, params=params, additional_headers=headers, redirect_uri=redirect_uri
335+
code,
336+
request,
337+
params=params,
338+
additional_headers=headers,
339+
redirect_uri=redirect_uri,
340+
pkce_code_verifier=pkce_code_verifier,
296341
)
297342

298343
def __enter__(self) -> "SSOBase":
299344
self._oauth_client = None
300345
self._refresh_token = None
301346
self._id_token = None
347+
self._state = None
348+
if self.requires_state:
349+
self._generated_state = generate_random_state()
350+
if self.uses_pkce:
351+
self._pkce_code_verifier, self._pkce_code_challenge = get_pkce_challenge_pair(self._pkce_challenge_length)
302352
return self
303353

304354
def __exit__(
@@ -321,6 +371,7 @@ async def process_login(
321371
params: Optional[Dict[str, Any]] = None,
322372
additional_headers: Optional[Dict[str, Any]] = None,
323373
redirect_uri: Optional[str] = None,
374+
pkce_code_verifier: Optional[str] = None,
324375
) -> Optional[OpenID]:
325376
"""
326377
Processes login from the callback endpoint to verify the user and request user info endpoint.
@@ -332,6 +383,7 @@ async def process_login(
332383
params (Optional[Dict[str, Any]]): Additional query parameters to pass to the provider.
333384
additional_headers (Optional[Dict[str, Any]]): Additional headers to be added to all requests.
334385
redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
386+
pkce_code_verifier (Optional[str]): A PKCE code verifier sent to the server to verify the login request.
335387
336388
Raises:
337389
ReusedOauthClientWarning: If the SSO object is reused, which is not safe and caused security issues.
@@ -379,8 +431,12 @@ async def process_login(
379431
headers.update(additional_headers)
380432

381433
auth = httpx.BasicAuth(self.client_id, self.client_secret)
434+
435+
if pkce_code_verifier:
436+
params.update({"code_verifier": pkce_code_verifier})
437+
382438
async with httpx.AsyncClient() as session:
383-
response = await session.post(token_url, headers=headers, content=body, auth=auth)
439+
response = await session.post(token_url, headers=headers, content=body, auth=auth, params=params)
384440
content = response.json()
385441
self._refresh_token = content.get("refresh_token")
386442
self._id_token = content.get("id_token")

fastapi_sso/sso/facebook.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from fastapi_sso.sso.base import DiscoveryDocument, OpenID, SSOBase
77

88
if TYPE_CHECKING:
9-
import httpx
9+
import httpx # pragma: no cover
1010

1111

1212
class FacebookSSO(SSOBase):

fastapi_sso/sso/fitbit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from fastapi_sso.sso.base import DiscoveryDocument, OpenID, SSOBase, SSOLoginError
77

88
if TYPE_CHECKING:
9-
import httpx
9+
import httpx # pragma: no cover
1010

1111

1212
class FitbitSSO(SSOBase):

fastapi_sso/sso/generic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from fastapi_sso.sso.base import DiscoveryDocument, OpenID, SSOBase
99

1010
if TYPE_CHECKING:
11-
import httpx
11+
import httpx # pragma: no cover
1212

1313
logger = logging.getLogger(__name__)
1414

fastapi_sso/sso/github.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from fastapi_sso.sso.base import DiscoveryDocument, OpenID, SSOBase
66

77
if TYPE_CHECKING:
8-
import httpx
8+
import httpx # pragma: no cover
99

1010

1111
class GithubSSO(SSOBase):

fastapi_sso/sso/gitlab.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from fastapi_sso.sso.base import DiscoveryDocument, OpenID, SSOBase
66

77
if TYPE_CHECKING:
8-
import httpx
8+
import httpx # pragma: no cover
99

1010

1111
class GitlabSSO(SSOBase):

0 commit comments

Comments
 (0)