Skip to content

Commit 13e3374

Browse files
committed
✨ Check if token is expired when checking if user is logged in
1 parent 0296882 commit 13e3374

File tree

5 files changed

+198
-6
lines changed

5 files changed

+198
-6
lines changed

src/fastapi_cloud_cli/commands/deploy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ def _waitlist_form(toolkit: RichToolkit) -> None:
496496

497497
with contextlib.suppress(Exception):
498498
subprocess.run(
499-
["open", "raycast://confetti"],
499+
["open", "raycast://confetti?emojis=🐔⚡"],
500500
stdout=subprocess.DEVNULL,
501501
stderr=subprocess.DEVNULL,
502502
check=False,

src/fastapi_cloud_cli/utils/auth.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1+
import base64
2+
import binascii
3+
import json
14
import logging
5+
import time
26
from typing import Optional
37

48
from pydantic import BaseModel
@@ -55,7 +59,64 @@ def get_auth_token() -> Optional[str]:
5559
return auth_data.access_token
5660

5761

62+
def is_token_expired(token: str) -> bool:
63+
try:
64+
parts = token.split(".")
65+
66+
if len(parts) != 3:
67+
logger.debug("Invalid JWT format: expected 3 parts, got %d", len(parts))
68+
return True
69+
70+
payload = parts[1]
71+
72+
# Add padding if needed (JWT uses base64url encoding without padding)
73+
if padding := len(payload) % 4:
74+
payload += "=" * (4 - padding)
75+
76+
payload = payload.replace("-", "+").replace("_", "/")
77+
decoded_bytes = base64.b64decode(payload)
78+
payload_data = json.loads(decoded_bytes)
79+
80+
exp = payload_data.get("exp")
81+
82+
if exp is None:
83+
logger.debug("No 'exp' claim found in token")
84+
85+
return False
86+
87+
if not isinstance(exp, int):
88+
logger.debug("Invalid 'exp' claim: expected int, got %s", type(exp))
89+
90+
return True
91+
92+
current_time = time.time()
93+
94+
is_expired = current_time >= exp
95+
96+
logger.debug(
97+
"Token expiration check: current=%d, exp=%d, expired=%s",
98+
current_time,
99+
exp,
100+
is_expired,
101+
)
102+
103+
return is_expired
104+
except (binascii.Error, json.JSONDecodeError) as e:
105+
logger.debug("Error parsing JWT token: %s", e)
106+
107+
return True
108+
109+
58110
def is_logged_in() -> bool:
59-
result = get_auth_token() is not None
60-
logger.debug("Login status: %s", result)
61-
return result
111+
token = get_auth_token()
112+
113+
if token is None:
114+
logger.debug("Login status: False (no token)")
115+
return False
116+
117+
if is_token_expired(token):
118+
logger.debug("Login status: False (token expired)")
119+
return False
120+
121+
logger.debug("Login status: True")
122+
return True

tests/conftest.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import pytest
88
from typer import rich_utils
99

10+
from .utils import create_jwt_token
11+
1012

1113
@pytest.fixture(autouse=True)
1214
def reset_syspath() -> Generator[None, None, None]:
@@ -26,7 +28,9 @@ def setup_terminal() -> None:
2628

2729
@pytest.fixture
2830
def logged_in_cli(temp_auth_config: Path) -> Generator[None, None, None]:
29-
temp_auth_config.write_text('{"access_token": "test_token_12345"}')
31+
valid_token = create_jwt_token({"sub": "test_user_12345"})
32+
33+
temp_auth_config.write_text(f'{{"access_token": "{valid_token}"}}')
3034

3135
yield
3236

tests/test_auth.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import base64
2+
import time
3+
from pathlib import Path
4+
5+
import pytest
6+
7+
from fastapi_cloud_cli.utils.auth import (
8+
AuthConfig,
9+
is_logged_in,
10+
is_token_expired,
11+
write_auth_config,
12+
)
13+
14+
from .utils import create_jwt_token
15+
16+
17+
def test_is_token_expired_with_valid_token() -> None:
18+
future_exp = int(time.time()) + 3600
19+
20+
token = create_jwt_token({"exp": future_exp, "sub": "test_user"})
21+
22+
assert not is_token_expired(token)
23+
24+
25+
def test_is_token_expired_with_expired_token() -> None:
26+
past_exp = int(time.time()) - 3600
27+
token = create_jwt_token({"exp": past_exp, "sub": "test_user"})
28+
29+
assert is_token_expired(token)
30+
31+
32+
def test_is_token_expired_with_no_exp_claim() -> None:
33+
token = create_jwt_token({"sub": "test_user"})
34+
35+
# Tokens without exp claim should be considered valid
36+
assert not is_token_expired(token)
37+
38+
39+
@pytest.mark.parametrize(
40+
"token",
41+
[
42+
"not.a.valid.jwt.token",
43+
"only.two",
44+
"invalid",
45+
"",
46+
"...",
47+
],
48+
)
49+
def test_is_token_expired_with_malformed_token(token: str) -> None:
50+
assert is_token_expired(token)
51+
52+
53+
def test_is_token_expired_with_invalid_base64() -> None:
54+
token = "header.!!!invalid_signature!!!.signature"
55+
assert is_token_expired(token)
56+
57+
58+
def test_is_token_expired_with_invalid_json() -> None:
59+
header_encoded = base64.urlsafe_b64encode(b'{"alg":"HS256"}').decode().rstrip("=")
60+
payload_encoded = base64.urlsafe_b64encode(b"{invalid json}").decode().rstrip("=")
61+
signature = base64.urlsafe_b64encode(b"signature").decode().rstrip("=")
62+
token = f"{header_encoded}.{payload_encoded}.{signature}"
63+
64+
assert is_token_expired(token)
65+
66+
67+
def test_is_logged_in_with_no_token(temp_auth_config: Path) -> None:
68+
assert not temp_auth_config.exists()
69+
assert not is_logged_in()
70+
71+
72+
def test_is_logged_in_with_valid_token(temp_auth_config: Path) -> None:
73+
future_exp = int(time.time()) + 3600
74+
token = create_jwt_token({"exp": future_exp, "sub": "test_user"})
75+
76+
write_auth_config(AuthConfig(access_token=token))
77+
78+
assert is_logged_in()
79+
80+
81+
def test_is_logged_in_with_expired_token(temp_auth_config: Path) -> None:
82+
past_exp = int(time.time()) - 3600
83+
token = create_jwt_token({"exp": past_exp, "sub": "test_user"})
84+
85+
write_auth_config(AuthConfig(access_token=token))
86+
87+
assert not is_logged_in()
88+
89+
90+
def test_is_logged_in_with_malformed_token(temp_auth_config: Path) -> None:
91+
write_auth_config(AuthConfig(access_token="not.a.valid.token"))
92+
93+
assert not is_logged_in()
94+
95+
96+
def test_is_token_expired_edge_case_exact_expiration() -> None:
97+
current_time = int(time.time())
98+
token = create_jwt_token({"exp": current_time, "sub": "test_user"})
99+
100+
assert is_token_expired(token)
101+
102+
103+
def test_is_token_expired_edge_case_one_second_before() -> None:
104+
current_time = int(time.time())
105+
token = create_jwt_token({"exp": current_time + 1, "sub": "test_user"})
106+
107+
assert not is_token_expired(token)

tests/utils.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import base64
2+
import json
13
import os
24
from contextlib import contextmanager
35
from pathlib import Path
4-
from typing import Generator, Union
6+
from typing import Any, Dict, Generator, Union
57

68

79
@contextmanager
@@ -20,3 +22,21 @@ class Keys:
2022
ENTER = "\r"
2123
CTRL_C = "\x03"
2224
TAB = "\t"
25+
26+
27+
def create_jwt_token(payload: Dict[str, Any]) -> str:
28+
# Note: This creates a JWT with an invalid signature, but that's OK for our tests
29+
# since we only parse the payload, not verify the signature.
30+
31+
header = {"alg": "HS256", "typ": "JWT"}
32+
header_encoded = (
33+
base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=")
34+
)
35+
36+
payload_encoded = (
37+
base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=")
38+
)
39+
40+
signature = base64.urlsafe_b64encode(b"signature").decode().rstrip("=")
41+
42+
return f"{header_encoded}.{payload_encoded}.{signature}"

0 commit comments

Comments
 (0)