Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 225 additions & 0 deletions codecarbon/cli/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
"""
OIDC Authentication helpers for the CodeCarbon CLI.

Handles the full token lifecycle: browser-based login (Authorization Code +
PKCE), credential storage, JWKS validation, and transparent refresh.
"""

import json
import os
import webbrowser
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from urllib.parse import parse_qs, urlparse

import requests
from authlib.common.security import generate_token
from authlib.integrations.requests_client import OAuth2Session
from authlib.jose import JsonWebKey
from authlib.jose import jwt as jose_jwt
from authlib.oauth2.rfc7636 import create_s256_code_challenge

AUTH_CLIENT_ID = os.environ.get(
"AUTH_CLIENT_ID",
"codecarbon-cli",
)
AUTH_SERVER_WELL_KNOWN = os.environ.get(
"AUTH_SERVER_WELL_KNOWN",
"https://authentication.codecarbon.io/realms/codecarbon/.well-known/openid-configuration",
)

_REDIRECT_PORT = 8090
_REDIRECT_URI = f"http://localhost:{_REDIRECT_PORT}/callback"
_CREDENTIALS_FILE = Path("./credentials.json")


# ── OAuth callback server ───────────────────────────────────────────


class _CallbackHandler(BaseHTTPRequestHandler):
"""HTTP handler that captures the OAuth2 authorization callback."""

callback_url = None
error = None

def do_GET(self):
_CallbackHandler.callback_url = f"http://localhost:{_REDIRECT_PORT}{self.path}"
parsed = urlparse(self.path)
params = parse_qs(parsed.query)

if "error" in params:
_CallbackHandler.error = params["error"][0]
self.send_response(400)
self.send_header("Content-Type", "text/html")
self.end_headers()
msg = params.get("error_description", [params["error"][0]])[0]
self.wfile.write(
f"<html><body><h1>Login failed</h1><p>{msg}</p></body></html>".encode()
)
else:
self.send_response(200)
self.send_header("Content-Type", "text/html")
self.end_headers()
self.wfile.write(
b"<html><body><h1>Login successful!</h1>"
b"<p>You can close this window.</p></body></html>"
)

def log_message(self, format, *args):
pass


# ── OIDC discovery ──────────────────────────────────────────────────


def _discover_endpoints():
"""Fetch OpenID Connect discovery document."""
resp = requests.get(AUTH_SERVER_WELL_KNOWN)
resp.raise_for_status()
return resp.json()


# ── Credential storage ──────────────────────────────────────────────


def _save_credentials(tokens):
"""Save OAuth tokens to the local credentials file."""
with open(_CREDENTIALS_FILE, "w") as f:
json.dump(tokens, f)


def _load_credentials():
"""Load OAuth tokens from the local credentials file."""
with open(_CREDENTIALS_FILE, "r") as f:
return json.load(f)


# ── Token validation & refresh ──────────────────────────────────────


def _validate_access_token(access_token: str) -> bool:
"""Validate access token against the current OIDC provider's JWKS.

Returns False when the signature or expiry check fails (wrong provider,
expired, tampered). Returns True on network errors so the caller can
fall through to the API and let the server decide.
"""
try:
discovery = _discover_endpoints()
jwks_resp = requests.get(discovery["jwks_uri"])
jwks_resp.raise_for_status()
keyset = JsonWebKey.import_key_set(jwks_resp.json())
claims = jose_jwt.decode(access_token, keyset)
claims.validate()
return True
except requests.RequestException:
return True # Can't reach auth server — let the API handle it
except Exception:
return False


def _refresh_tokens(refresh_token: str) -> dict:
"""Exchange a refresh token for a new token set via the OIDC token endpoint."""
discovery = _discover_endpoints()
resp = requests.post(
discovery["token_endpoint"],
data={
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": AUTH_CLIENT_ID,
},
)
resp.raise_for_status()
return resp.json()


# ── Public API ──────────────────────────────────────────────────────


def authorize():
"""Run the OAuth2 Authorization Code flow with PKCE."""
discovery = _discover_endpoints()

session = OAuth2Session(
client_id=AUTH_CLIENT_ID,
redirect_uri=_REDIRECT_URI,
scope="openid offline_access",
token_endpoint_auth_method="none",
)

code_verifier = generate_token(48)
code_challenge = create_s256_code_challenge(code_verifier)

uri, state = session.create_authorization_url(
discovery["authorization_endpoint"],
code_challenge=code_challenge,
code_challenge_method="S256",
)

_CallbackHandler.callback_url = None
_CallbackHandler.error = None

server = HTTPServer(("localhost", _REDIRECT_PORT), _CallbackHandler)

print("Opening browser for authentication...")
webbrowser.open(uri)

server.handle_request()
server.server_close()

if _CallbackHandler.error:
raise ValueError(f"Authorization failed: {_CallbackHandler.error}")

if not _CallbackHandler.callback_url:
raise ValueError("Authorization failed: no callback received")

token = session.fetch_token(
discovery["token_endpoint"],
authorization_response=_CallbackHandler.callback_url,
code_verifier=code_verifier,
)

_save_credentials(token)
return token


def get_access_token() -> str:
"""Return a valid access token, refreshing or failing with a clear message."""
try:
creds = _load_credentials()
except Exception as e:
raise ValueError(
"Not able to retrieve the access token, "
f"please run `codecarbon login` first! (error: {e})"
)

access_token = creds.get("access_token")
if not access_token:
raise ValueError("No access token found. Please run `codecarbon login` first.")

# Fast path: token is still valid for the current OIDC provider
if _validate_access_token(access_token):
return access_token

# Token is expired or was issued by a different provider — try refresh
refresh_token = creds.get("refresh_token")
if refresh_token:
try:
new_tokens = _refresh_tokens(refresh_token)
_save_credentials(new_tokens)
return new_tokens["access_token"]
except Exception:
pass

# Refresh failed — credentials are stale (e.g. auth provider migrated)
_CREDENTIALS_FILE.unlink(missing_ok=True)
raise ValueError(
"Your session has expired or the authentication provider has changed. "
"Please run `codecarbon login` again."
)


def get_id_token() -> str:
"""Return the stored OIDC id_token."""
creds = _load_credentials()
return creds["id_token"]
46 changes: 8 additions & 38 deletions codecarbon/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@
import questionary
import requests
import typer
from fief_client import Fief
from fief_client.integrations.cli import FiefAuth
from rich import print
from rich.prompt import Confirm
from typing_extensions import Annotated

from codecarbon import __app_name__, __version__
from codecarbon.cli.auth import authorize, get_access_token
from codecarbon.cli.cli_utils import (
create_new_config_file,
get_api_endpoint,
Expand All @@ -27,13 +26,6 @@
from codecarbon.core.schemas import ExperimentCreate, OrganizationCreate, ProjectCreate
from codecarbon.emissions_tracker import EmissionsTracker, OfflineEmissionsTracker

AUTH_CLIENT_ID = os.environ.get(
"AUTH_CLIENT_ID",
"jsUPWIcUECQFE_ouanUuVhXx52TTjEVcVNNtNGeyAtU",
)
AUTH_SERVER_URL = os.environ.get(
"AUTH_SERVER_URL", "https://auth.codecarbon.io/codecarbon"
)
API_URL = os.environ.get("API_URL", "https://dashboard.codecarbon.io/api")

DEFAULT_PROJECT_ID = "e60afa92-17b7-4720-91a0-1ae91e409ba1"
Expand Down Expand Up @@ -79,7 +71,7 @@ def show_config(path: Path = Path("./.codecarbon.config")) -> None:
d = get_config(path)
api_endpoint = get_api_endpoint(path)
api = ApiClient(endpoint_url=api_endpoint)
api.set_access_token(_get_access_token())
api.set_access_token(get_access_token())
print("Current configuration : \n")
print("Config file content : ")
print(d)
Expand Down Expand Up @@ -115,28 +107,6 @@ def show_config(path: Path = Path("./.codecarbon.config")) -> None:
)


def get_fief_auth():
fief = Fief(AUTH_SERVER_URL, AUTH_CLIENT_ID)
fief_auth = FiefAuth(fief, "./credentials.json")
return fief_auth


def _get_access_token():
try:
access_token_info = get_fief_auth().access_token_info()
access_token = access_token_info["access_token"]
return access_token
except Exception as e:
raise ValueError(
f"Not able to retrieve the access token, please run `codecarbon login` first! (error: {e})"
)


def _get_id_token():
id_token = get_fief_auth()._tokens["id_token"]
return id_token


@codecarbon.command(
"test-api", short_help="Make an authenticated GET request to an API endpoint"
)
Expand All @@ -145,16 +115,16 @@ def api_get():
ex: test-api
"""
api = ApiClient(endpoint_url=API_URL) # TODO: get endpoint from config
api.set_access_token(_get_access_token())
api.set_access_token(get_access_token())
organizations = api.get_list_organizations()
print(organizations)


@codecarbon.command("login", short_help="Login to CodeCarbon")
def login():
get_fief_auth().authorize()
authorize()
api = ApiClient(endpoint_url=API_URL) # TODO: get endpoint from config
access_token = _get_access_token()
access_token = get_access_token()
api.set_access_token(access_token)
api.check_auth()

Expand All @@ -167,7 +137,7 @@ def get_api_key(project_id: str):
"name": "api token",
"x_token": "???",
},
headers={"Authorization": f"Bearer {_get_access_token()}"},
headers={"Authorization": f"Bearer {get_access_token()}"},
)
api_key = req.json()["token"]
return api_key
Expand All @@ -176,7 +146,7 @@ def get_api_key(project_id: str):
@codecarbon.command("get-token", short_help="Get project token")
def get_token(project_id: str):
# api = ApiClient(endpoint_url=API_URL) # TODO: get endpoint from config
# api.set_access_token(_get_access_token())
# api.set_access_token(get_access_token())
token = get_api_key(project_id)
print("Your token: " + token)
print("Add it to the api_key field in your configuration file")
Expand Down Expand Up @@ -224,7 +194,7 @@ def config():
)
overwrite_local_config("api_endpoint", api_endpoint, path=file_path)
api = ApiClient(endpoint_url=api_endpoint)
api.set_access_token(_get_access_token())
api.set_access_token(get_access_token())
organizations = api.get_list_organizations()
org = questionary_prompt(
"Pick existing organization from list or Create new organization ?",
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ classifiers = [
]
dependencies = [
"arrow",
"authlib>=1.2.1",
"click",
"fief-client[cli]",
"pandas>=2.3.3;python_version>='3.14'",
"pandas;python_version<'3.14'",
"prometheus_client",
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cli.py → tests/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_app(self, MockApiClient):
@patch("codecarbon.cli.main.Path.exists")
@patch("codecarbon.cli.main.Confirm.ask")
@patch("codecarbon.cli.main.questionary_prompt")
@patch("codecarbon.cli.main._get_access_token")
@patch("codecarbon.cli.main.get_access_token")
@patch("typer.prompt")
def test_config_no_local_new_all(
self,
Expand Down Expand Up @@ -147,7 +147,7 @@ def side_effect_wrapper(*args, **kwargs):
except OSError:
pass

@patch("codecarbon.cli.main._get_access_token")
@patch("codecarbon.cli.main.get_access_token")
@patch("codecarbon.cli.main.Path.exists")
@patch("codecarbon.cli.main.get_config")
@patch("codecarbon.cli.main.questionary_prompt")
Expand Down
Loading
Loading