diff --git a/src/fastapi_cloud_cli/commands/deploy.py b/src/fastapi_cloud_cli/commands/deploy.py index da26c03..e8f687f 100644 --- a/src/fastapi_cloud_cli/commands/deploy.py +++ b/src/fastapi_cloud_cli/commands/deploy.py @@ -496,7 +496,7 @@ def _waitlist_form(toolkit: RichToolkit) -> None: with contextlib.suppress(Exception): subprocess.run( - ["open", "raycast://confetti"], + ["open", "raycast://confetti?emojis=🐔⚡"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=False, diff --git a/src/fastapi_cloud_cli/utils/auth.py b/src/fastapi_cloud_cli/utils/auth.py index 246721b..5a72d49 100644 --- a/src/fastapi_cloud_cli/utils/auth.py +++ b/src/fastapi_cloud_cli/utils/auth.py @@ -1,4 +1,8 @@ +import base64 +import binascii +import json import logging +import time from typing import Optional from pydantic import BaseModel @@ -55,7 +59,64 @@ def get_auth_token() -> Optional[str]: return auth_data.access_token +def is_token_expired(token: str) -> bool: + try: + parts = token.split(".") + + if len(parts) != 3: + logger.debug("Invalid JWT format: expected 3 parts, got %d", len(parts)) + return True + + payload = parts[1] + + # Add padding if needed (JWT uses base64url encoding without padding) + if padding := len(payload) % 4: + payload += "=" * (4 - padding) + + payload = payload.replace("-", "+").replace("_", "/") + decoded_bytes = base64.b64decode(payload) + payload_data = json.loads(decoded_bytes) + + exp = payload_data.get("exp") + + if exp is None: + logger.debug("No 'exp' claim found in token") + + return False + + if not isinstance(exp, int): # pragma: no cover + logger.debug("Invalid 'exp' claim: expected int, got %s", type(exp)) + + return True + + current_time = time.time() + + is_expired = current_time >= exp + + logger.debug( + "Token expiration check: current=%d, exp=%d, expired=%s", + current_time, + exp, + is_expired, + ) + + return is_expired + except (binascii.Error, json.JSONDecodeError) as e: + logger.debug("Error parsing JWT token: %s", e) + + return True + + def is_logged_in() -> bool: - result = get_auth_token() is not None - logger.debug("Login status: %s", result) - return result + token = get_auth_token() + + if token is None: + logger.debug("Login status: False (no token)") + return False + + if is_token_expired(token): + logger.debug("Login status: False (token expired)") + return False + + logger.debug("Login status: True") + return True diff --git a/tests/conftest.py b/tests/conftest.py index ebe5ca8..6de52c5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,8 @@ import pytest from typer import rich_utils +from .utils import create_jwt_token + @pytest.fixture(autouse=True) def reset_syspath() -> Generator[None, None, None]: @@ -26,7 +28,9 @@ def setup_terminal() -> None: @pytest.fixture def logged_in_cli(temp_auth_config: Path) -> Generator[None, None, None]: - temp_auth_config.write_text('{"access_token": "test_token_12345"}') + valid_token = create_jwt_token({"sub": "test_user_12345"}) + + temp_auth_config.write_text(f'{{"access_token": "{valid_token}"}}') yield diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..96a4d21 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,107 @@ +import base64 +import time +from pathlib import Path + +import pytest + +from fastapi_cloud_cli.utils.auth import ( + AuthConfig, + is_logged_in, + is_token_expired, + write_auth_config, +) + +from .utils import create_jwt_token + + +def test_is_token_expired_with_valid_token() -> None: + future_exp = int(time.time()) + 3600 + + token = create_jwt_token({"exp": future_exp, "sub": "test_user"}) + + assert not is_token_expired(token) + + +def test_is_token_expired_with_expired_token() -> None: + past_exp = int(time.time()) - 3600 + token = create_jwt_token({"exp": past_exp, "sub": "test_user"}) + + assert is_token_expired(token) + + +def test_is_token_expired_with_no_exp_claim() -> None: + token = create_jwt_token({"sub": "test_user"}) + + # Tokens without exp claim should be considered valid + assert not is_token_expired(token) + + +@pytest.mark.parametrize( + "token", + [ + "not.a.valid.jwt.token", + "only.two", + "invalid", + "", + "...", + ], +) +def test_is_token_expired_with_malformed_token(token: str) -> None: + assert is_token_expired(token) + + +def test_is_token_expired_with_invalid_base64() -> None: + token = "header.!!!invalid_signature!!!.signature" + assert is_token_expired(token) + + +def test_is_token_expired_with_invalid_json() -> None: + header_encoded = base64.urlsafe_b64encode(b'{"alg":"HS256"}').decode().rstrip("=") + payload_encoded = base64.urlsafe_b64encode(b"{invalid json}").decode().rstrip("=") + signature = base64.urlsafe_b64encode(b"signature").decode().rstrip("=") + token = f"{header_encoded}.{payload_encoded}.{signature}" + + assert is_token_expired(token) + + +def test_is_logged_in_with_no_token(temp_auth_config: Path) -> None: + assert not temp_auth_config.exists() + assert not is_logged_in() + + +def test_is_logged_in_with_valid_token(temp_auth_config: Path) -> None: + future_exp = int(time.time()) + 3600 + token = create_jwt_token({"exp": future_exp, "sub": "test_user"}) + + write_auth_config(AuthConfig(access_token=token)) + + assert is_logged_in() + + +def test_is_logged_in_with_expired_token(temp_auth_config: Path) -> None: + past_exp = int(time.time()) - 3600 + token = create_jwt_token({"exp": past_exp, "sub": "test_user"}) + + write_auth_config(AuthConfig(access_token=token)) + + assert not is_logged_in() + + +def test_is_logged_in_with_malformed_token(temp_auth_config: Path) -> None: + write_auth_config(AuthConfig(access_token="not.a.valid.token")) + + assert not is_logged_in() + + +def test_is_token_expired_edge_case_exact_expiration() -> None: + current_time = int(time.time()) + token = create_jwt_token({"exp": current_time, "sub": "test_user"}) + + assert is_token_expired(token) + + +def test_is_token_expired_edge_case_one_second_before() -> None: + current_time = int(time.time()) + token = create_jwt_token({"exp": current_time + 1, "sub": "test_user"}) + + assert not is_token_expired(token) diff --git a/tests/utils.py b/tests/utils.py index 74aa97b..c0d2e32 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,7 +1,9 @@ +import base64 +import json import os from contextlib import contextmanager from pathlib import Path -from typing import Generator, Union +from typing import Any, Dict, Generator, Union @contextmanager @@ -20,3 +22,21 @@ class Keys: ENTER = "\r" CTRL_C = "\x03" TAB = "\t" + + +def create_jwt_token(payload: Dict[str, Any]) -> str: + # Note: This creates a JWT with an invalid signature, but that's OK for our tests + # since we only parse the payload, not verify the signature. + + header = {"alg": "HS256", "typ": "JWT"} + header_encoded = ( + base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=") + ) + + payload_encoded = ( + base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=") + ) + + signature = base64.urlsafe_b64encode(b"signature").decode().rstrip("=") + + return f"{header_encoded}.{payload_encoded}.{signature}"