Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/fastapi_cloud_cli/commands/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ def _waitlist_form(toolkit: RichToolkit) -> None:

with contextlib.suppress(Exception):
subprocess.run(
["open", "raycast://confetti"],
["open", "raycast://confetti?emojis=🐔⚡"],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
check=False,
Expand Down
67 changes: 64 additions & 3 deletions src/fastapi_cloud_cli/utils/auth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import base64
import binascii
import json
import logging
import time
from typing import Optional

from pydantic import BaseModel
Expand Down Expand Up @@ -55,7 +59,64 @@ def get_auth_token() -> Optional[str]:
return auth_data.access_token


def is_token_expired(token: str) -> bool:
try:
parts = token.split(".")

if len(parts) != 3:
logger.debug("Invalid JWT format: expected 3 parts, got %d", len(parts))
return True

payload = parts[1]

# Add padding if needed (JWT uses base64url encoding without padding)
if padding := len(payload) % 4:
payload += "=" * (4 - padding)

payload = payload.replace("-", "+").replace("_", "/")
decoded_bytes = base64.b64decode(payload)
payload_data = json.loads(decoded_bytes)

exp = payload_data.get("exp")

if exp is None:
logger.debug("No 'exp' claim found in token")

return False

if not isinstance(exp, int): # pragma: no cover
logger.debug("Invalid 'exp' claim: expected int, got %s", type(exp))

return True

current_time = time.time()

is_expired = current_time >= exp

logger.debug(
"Token expiration check: current=%d, exp=%d, expired=%s",
current_time,
exp,
is_expired,
)

return is_expired
except (binascii.Error, json.JSONDecodeError) as e:
logger.debug("Error parsing JWT token: %s", e)

return True


def is_logged_in() -> bool:
result = get_auth_token() is not None
logger.debug("Login status: %s", result)
return result
token = get_auth_token()

if token is None:
logger.debug("Login status: False (no token)")
return False

if is_token_expired(token):
logger.debug("Login status: False (token expired)")
return False

logger.debug("Login status: True")
return True
6 changes: 5 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import pytest
from typer import rich_utils

from .utils import create_jwt_token


@pytest.fixture(autouse=True)
def reset_syspath() -> Generator[None, None, None]:
Expand All @@ -26,7 +28,9 @@ def setup_terminal() -> None:

@pytest.fixture
def logged_in_cli(temp_auth_config: Path) -> Generator[None, None, None]:
temp_auth_config.write_text('{"access_token": "test_token_12345"}')
valid_token = create_jwt_token({"sub": "test_user_12345"})

temp_auth_config.write_text(f'{{"access_token": "{valid_token}"}}')

yield

Expand Down
107 changes: 107 additions & 0 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import base64
import time
from pathlib import Path

import pytest

from fastapi_cloud_cli.utils.auth import (
AuthConfig,
is_logged_in,
is_token_expired,
write_auth_config,
)

from .utils import create_jwt_token


def test_is_token_expired_with_valid_token() -> None:
future_exp = int(time.time()) + 3600

token = create_jwt_token({"exp": future_exp, "sub": "test_user"})

assert not is_token_expired(token)


def test_is_token_expired_with_expired_token() -> None:
past_exp = int(time.time()) - 3600
token = create_jwt_token({"exp": past_exp, "sub": "test_user"})

assert is_token_expired(token)


def test_is_token_expired_with_no_exp_claim() -> None:
token = create_jwt_token({"sub": "test_user"})

# Tokens without exp claim should be considered valid
assert not is_token_expired(token)


@pytest.mark.parametrize(
"token",
[
"not.a.valid.jwt.token",
"only.two",
"invalid",
"",
"...",
],
)
def test_is_token_expired_with_malformed_token(token: str) -> None:
assert is_token_expired(token)


def test_is_token_expired_with_invalid_base64() -> None:
token = "header.!!!invalid_signature!!!.signature"
assert is_token_expired(token)


def test_is_token_expired_with_invalid_json() -> None:
header_encoded = base64.urlsafe_b64encode(b'{"alg":"HS256"}').decode().rstrip("=")
payload_encoded = base64.urlsafe_b64encode(b"{invalid json}").decode().rstrip("=")
signature = base64.urlsafe_b64encode(b"signature").decode().rstrip("=")
token = f"{header_encoded}.{payload_encoded}.{signature}"

assert is_token_expired(token)


def test_is_logged_in_with_no_token(temp_auth_config: Path) -> None:
assert not temp_auth_config.exists()
assert not is_logged_in()


def test_is_logged_in_with_valid_token(temp_auth_config: Path) -> None:
future_exp = int(time.time()) + 3600
token = create_jwt_token({"exp": future_exp, "sub": "test_user"})

write_auth_config(AuthConfig(access_token=token))

assert is_logged_in()


def test_is_logged_in_with_expired_token(temp_auth_config: Path) -> None:
past_exp = int(time.time()) - 3600
token = create_jwt_token({"exp": past_exp, "sub": "test_user"})

write_auth_config(AuthConfig(access_token=token))

assert not is_logged_in()


def test_is_logged_in_with_malformed_token(temp_auth_config: Path) -> None:
write_auth_config(AuthConfig(access_token="not.a.valid.token"))

assert not is_logged_in()


def test_is_token_expired_edge_case_exact_expiration() -> None:
current_time = int(time.time())
token = create_jwt_token({"exp": current_time, "sub": "test_user"})

assert is_token_expired(token)


def test_is_token_expired_edge_case_one_second_before() -> None:
current_time = int(time.time())
token = create_jwt_token({"exp": current_time + 1, "sub": "test_user"})

assert not is_token_expired(token)
22 changes: 21 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import base64
import json
import os
from contextlib import contextmanager
from pathlib import Path
from typing import Generator, Union
from typing import Any, Dict, Generator, Union


@contextmanager
Expand All @@ -20,3 +22,21 @@ class Keys:
ENTER = "\r"
CTRL_C = "\x03"
TAB = "\t"


def create_jwt_token(payload: Dict[str, Any]) -> str:
# Note: This creates a JWT with an invalid signature, but that's OK for our tests
# since we only parse the payload, not verify the signature.

header = {"alg": "HS256", "typ": "JWT"}
header_encoded = (
base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=")
)

payload_encoded = (
base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=")
)

signature = base64.urlsafe_b64encode(b"signature").decode().rstrip("=")

return f"{header_encoded}.{payload_encoded}.{signature}"