diff --git a/codecarbon/cli/auth.py b/codecarbon/cli/auth.py new file mode 100644 index 000000000..9b6fd0656 --- /dev/null +++ b/codecarbon/cli/auth.py @@ -0,0 +1,225 @@ +""" +OIDC Authentication helpers for the CodeCarbon CLI. + +Handles the full token lifecycle: browser-based login (Authorization Code + +PKCE), credential storage, JWKS validation, and transparent refresh. +""" + +import json +import os +import webbrowser +from http.server import BaseHTTPRequestHandler, HTTPServer +from pathlib import Path +from urllib.parse import parse_qs, urlparse + +import requests +from authlib.common.security import generate_token +from authlib.integrations.requests_client import OAuth2Session +from authlib.jose import JsonWebKey +from authlib.jose import jwt as jose_jwt +from authlib.oauth2.rfc7636 import create_s256_code_challenge + +AUTH_CLIENT_ID = os.environ.get( + "AUTH_CLIENT_ID", + "codecarbon-cli", +) +AUTH_SERVER_WELL_KNOWN = os.environ.get( + "AUTH_SERVER_WELL_KNOWN", + "https://authentication.codecarbon.io/realms/codecarbon/.well-known/openid-configuration", +) + +_REDIRECT_PORT = 8090 +_REDIRECT_URI = f"http://localhost:{_REDIRECT_PORT}/callback" +_CREDENTIALS_FILE = Path("./credentials.json") + + +# ── OAuth callback server ─────────────────────────────────────────── + + +class _CallbackHandler(BaseHTTPRequestHandler): + """HTTP handler that captures the OAuth2 authorization callback.""" + + callback_url = None + error = None + + def do_GET(self): + _CallbackHandler.callback_url = f"http://localhost:{_REDIRECT_PORT}{self.path}" + parsed = urlparse(self.path) + params = parse_qs(parsed.query) + + if "error" in params: + _CallbackHandler.error = params["error"][0] + self.send_response(400) + self.send_header("Content-Type", "text/html") + self.end_headers() + msg = params.get("error_description", [params["error"][0]])[0] + self.wfile.write( + f"

Login failed

{msg}

".encode() + ) + else: + self.send_response(200) + self.send_header("Content-Type", "text/html") + self.end_headers() + self.wfile.write( + b"

Login successful!

" + b"

You can close this window.

" + ) + + def log_message(self, format, *args): + pass + + +# ── OIDC discovery ────────────────────────────────────────────────── + + +def _discover_endpoints(): + """Fetch OpenID Connect discovery document.""" + resp = requests.get(AUTH_SERVER_WELL_KNOWN) + resp.raise_for_status() + return resp.json() + + +# ── Credential storage ────────────────────────────────────────────── + + +def _save_credentials(tokens): + """Save OAuth tokens to the local credentials file.""" + with open(_CREDENTIALS_FILE, "w") as f: + json.dump(tokens, f) + + +def _load_credentials(): + """Load OAuth tokens from the local credentials file.""" + with open(_CREDENTIALS_FILE, "r") as f: + return json.load(f) + + +# ── Token validation & refresh ────────────────────────────────────── + + +def _validate_access_token(access_token: str) -> bool: + """Validate access token against the current OIDC provider's JWKS. + + Returns False when the signature or expiry check fails (wrong provider, + expired, tampered). Returns True on network errors so the caller can + fall through to the API and let the server decide. + """ + try: + discovery = _discover_endpoints() + jwks_resp = requests.get(discovery["jwks_uri"]) + jwks_resp.raise_for_status() + keyset = JsonWebKey.import_key_set(jwks_resp.json()) + claims = jose_jwt.decode(access_token, keyset) + claims.validate() + return True + except requests.RequestException: + return True # Can't reach auth server — let the API handle it + except Exception: + return False + + +def _refresh_tokens(refresh_token: str) -> dict: + """Exchange a refresh token for a new token set via the OIDC token endpoint.""" + discovery = _discover_endpoints() + resp = requests.post( + discovery["token_endpoint"], + data={ + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": AUTH_CLIENT_ID, + }, + ) + resp.raise_for_status() + return resp.json() + + +# ── Public API ────────────────────────────────────────────────────── + + +def authorize(): + """Run the OAuth2 Authorization Code flow with PKCE.""" + discovery = _discover_endpoints() + + session = OAuth2Session( + client_id=AUTH_CLIENT_ID, + redirect_uri=_REDIRECT_URI, + scope="openid offline_access", + token_endpoint_auth_method="none", + ) + + code_verifier = generate_token(48) + code_challenge = create_s256_code_challenge(code_verifier) + + uri, state = session.create_authorization_url( + discovery["authorization_endpoint"], + code_challenge=code_challenge, + code_challenge_method="S256", + ) + + _CallbackHandler.callback_url = None + _CallbackHandler.error = None + + server = HTTPServer(("localhost", _REDIRECT_PORT), _CallbackHandler) + + print("Opening browser for authentication...") + webbrowser.open(uri) + + server.handle_request() + server.server_close() + + if _CallbackHandler.error: + raise ValueError(f"Authorization failed: {_CallbackHandler.error}") + + if not _CallbackHandler.callback_url: + raise ValueError("Authorization failed: no callback received") + + token = session.fetch_token( + discovery["token_endpoint"], + authorization_response=_CallbackHandler.callback_url, + code_verifier=code_verifier, + ) + + _save_credentials(token) + return token + + +def get_access_token() -> str: + """Return a valid access token, refreshing or failing with a clear message.""" + try: + creds = _load_credentials() + except Exception as e: + raise ValueError( + "Not able to retrieve the access token, " + f"please run `codecarbon login` first! (error: {e})" + ) + + access_token = creds.get("access_token") + if not access_token: + raise ValueError("No access token found. Please run `codecarbon login` first.") + + # Fast path: token is still valid for the current OIDC provider + if _validate_access_token(access_token): + return access_token + + # Token is expired or was issued by a different provider — try refresh + refresh_token = creds.get("refresh_token") + if refresh_token: + try: + new_tokens = _refresh_tokens(refresh_token) + _save_credentials(new_tokens) + return new_tokens["access_token"] + except Exception: + pass + + # Refresh failed — credentials are stale (e.g. auth provider migrated) + _CREDENTIALS_FILE.unlink(missing_ok=True) + raise ValueError( + "Your session has expired or the authentication provider has changed. " + "Please run `codecarbon login` again." + ) + + +def get_id_token() -> str: + """Return the stored OIDC id_token.""" + creds = _load_credentials() + return creds["id_token"] diff --git a/codecarbon/cli/main.py b/codecarbon/cli/main.py index 6f28c2309..36c654ee3 100644 --- a/codecarbon/cli/main.py +++ b/codecarbon/cli/main.py @@ -8,13 +8,12 @@ import questionary import requests import typer -from fief_client import Fief -from fief_client.integrations.cli import FiefAuth from rich import print from rich.prompt import Confirm from typing_extensions import Annotated from codecarbon import __app_name__, __version__ +from codecarbon.cli.auth import authorize, get_access_token from codecarbon.cli.cli_utils import ( create_new_config_file, get_api_endpoint, @@ -27,13 +26,6 @@ from codecarbon.core.schemas import ExperimentCreate, OrganizationCreate, ProjectCreate from codecarbon.emissions_tracker import EmissionsTracker, OfflineEmissionsTracker -AUTH_CLIENT_ID = os.environ.get( - "AUTH_CLIENT_ID", - "jsUPWIcUECQFE_ouanUuVhXx52TTjEVcVNNtNGeyAtU", -) -AUTH_SERVER_URL = os.environ.get( - "AUTH_SERVER_URL", "https://auth.codecarbon.io/codecarbon" -) API_URL = os.environ.get("API_URL", "https://dashboard.codecarbon.io/api") DEFAULT_PROJECT_ID = "e60afa92-17b7-4720-91a0-1ae91e409ba1" @@ -79,7 +71,7 @@ def show_config(path: Path = Path("./.codecarbon.config")) -> None: d = get_config(path) api_endpoint = get_api_endpoint(path) api = ApiClient(endpoint_url=api_endpoint) - api.set_access_token(_get_access_token()) + api.set_access_token(get_access_token()) print("Current configuration : \n") print("Config file content : ") print(d) @@ -115,28 +107,6 @@ def show_config(path: Path = Path("./.codecarbon.config")) -> None: ) -def get_fief_auth(): - fief = Fief(AUTH_SERVER_URL, AUTH_CLIENT_ID) - fief_auth = FiefAuth(fief, "./credentials.json") - return fief_auth - - -def _get_access_token(): - try: - access_token_info = get_fief_auth().access_token_info() - access_token = access_token_info["access_token"] - return access_token - except Exception as e: - raise ValueError( - f"Not able to retrieve the access token, please run `codecarbon login` first! (error: {e})" - ) - - -def _get_id_token(): - id_token = get_fief_auth()._tokens["id_token"] - return id_token - - @codecarbon.command( "test-api", short_help="Make an authenticated GET request to an API endpoint" ) @@ -145,16 +115,16 @@ def api_get(): ex: test-api """ api = ApiClient(endpoint_url=API_URL) # TODO: get endpoint from config - api.set_access_token(_get_access_token()) + api.set_access_token(get_access_token()) organizations = api.get_list_organizations() print(organizations) @codecarbon.command("login", short_help="Login to CodeCarbon") def login(): - get_fief_auth().authorize() + authorize() api = ApiClient(endpoint_url=API_URL) # TODO: get endpoint from config - access_token = _get_access_token() + access_token = get_access_token() api.set_access_token(access_token) api.check_auth() @@ -167,7 +137,7 @@ def get_api_key(project_id: str): "name": "api token", "x_token": "???", }, - headers={"Authorization": f"Bearer {_get_access_token()}"}, + headers={"Authorization": f"Bearer {get_access_token()}"}, ) api_key = req.json()["token"] return api_key @@ -176,7 +146,7 @@ def get_api_key(project_id: str): @codecarbon.command("get-token", short_help="Get project token") def get_token(project_id: str): # api = ApiClient(endpoint_url=API_URL) # TODO: get endpoint from config - # api.set_access_token(_get_access_token()) + # api.set_access_token(get_access_token()) token = get_api_key(project_id) print("Your token: " + token) print("Add it to the api_key field in your configuration file") @@ -224,7 +194,7 @@ def config(): ) overwrite_local_config("api_endpoint", api_endpoint, path=file_path) api = ApiClient(endpoint_url=api_endpoint) - api.set_access_token(_get_access_token()) + api.set_access_token(get_access_token()) organizations = api.get_list_organizations() org = questionary_prompt( "Pick existing organization from list or Create new organization ?", diff --git a/pyproject.toml b/pyproject.toml index 34d4e5d73..f74617486 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,8 +29,8 @@ classifiers = [ ] dependencies = [ "arrow", + "authlib>=1.2.1", "click", - "fief-client[cli]", "pandas>=2.3.3;python_version>='3.14'", "pandas;python_version<'3.14'", "prometheus_client", diff --git a/tests/test_cli.py b/tests/cli/test_cli.py similarity index 98% rename from tests/test_cli.py rename to tests/cli/test_cli.py index 35e7a4e22..0935cc069 100644 --- a/tests/test_cli.py +++ b/tests/cli/test_cli.py @@ -57,7 +57,7 @@ def test_app(self, MockApiClient): @patch("codecarbon.cli.main.Path.exists") @patch("codecarbon.cli.main.Confirm.ask") @patch("codecarbon.cli.main.questionary_prompt") - @patch("codecarbon.cli.main._get_access_token") + @patch("codecarbon.cli.main.get_access_token") @patch("typer.prompt") def test_config_no_local_new_all( self, @@ -147,7 +147,7 @@ def side_effect_wrapper(*args, **kwargs): except OSError: pass - @patch("codecarbon.cli.main._get_access_token") + @patch("codecarbon.cli.main.get_access_token") @patch("codecarbon.cli.main.Path.exists") @patch("codecarbon.cli.main.get_config") @patch("codecarbon.cli.main.questionary_prompt") diff --git a/tests/cli/test_cli_auth.py b/tests/cli/test_cli_auth.py new file mode 100644 index 000000000..048beaa23 --- /dev/null +++ b/tests/cli/test_cli_auth.py @@ -0,0 +1,143 @@ +import io +import json +import unittest +from unittest.mock import MagicMock, patch + +from codecarbon.cli import auth +from codecarbon.cli.auth import _CallbackHandler + + +class TestCallbackHandler(unittest.TestCase): + def setUp(self): + self.handler = _CallbackHandler + self.handler.callback_url = None + self.handler.error = None + + def _make_handler(self, path): + # Simulate BaseHTTPRequestHandler environment + request = MagicMock() + request.makefile.return_value = io.BytesIO() + server = MagicMock() + handler = _CallbackHandler(request, ("127.0.0.1", 12345), server) + handler.path = path + handler.wfile = io.BytesIO() + handler.send_response = MagicMock() + handler.send_header = MagicMock() + handler.end_headers = MagicMock() + handler.log_message = MagicMock() + return handler + + def test_do_get_success(self): + handler = self._make_handler("/callback?code=abc123&state=xyz") + handler.do_GET() + # Should set callback_url and not error + self.assertIsNone(_CallbackHandler.error) + self.assertTrue( + _CallbackHandler.callback_url.endswith("/callback?code=abc123&state=xyz") + ) + handler.send_response.assert_called_with(200) + handler.send_header.assert_called_with("Content-Type", "text/html") + handler.end_headers.assert_called() + output = handler.wfile.getvalue().decode() + self.assertIn("Login successful", output) + + def test_do_get_error(self): + handler = self._make_handler( + "/callback?error=access_denied&error_description=User+denied+access" + ) + handler.do_GET() + # Should set error and not callback_url + self.assertEqual(_CallbackHandler.error, "access_denied") + handler.send_response.assert_called_with(400) + handler.send_header.assert_called_with("Content-Type", "text/html") + handler.end_headers.assert_called() + output = handler.wfile.getvalue().decode() + self.assertIn("Login failed", output) + self.assertIn("User denied access", output) + + +class TestAuthMethods(unittest.TestCase): + @patch("codecarbon.cli.auth.requests.get") + def test_discover_endpoints(self, mock_get): + mock_get.return_value.json.return_value = { + "token_endpoint": "url", + "jwks_uri": "jwks", + } + mock_get.return_value.raise_for_status.return_value = None + result = auth._discover_endpoints() + self.assertIn("token_endpoint", result) + self.assertIn("jwks_uri", result) + + @patch("builtins.open") + def test_save_and_load_credentials(self, mock_open): + # Save + mock_file = MagicMock() + mock_open.return_value.__enter__.return_value = mock_file + tokens = {"access_token": "a", "refresh_token": "r", "id_token": "i"} + auth._save_credentials(tokens) + mock_file.write.assert_called() + # Load + mock_file.read.return_value = json.dumps(tokens) + mock_open.return_value.__enter__.return_value = mock_file + mock_file.__iter__.return_value = iter([json.dumps(tokens)]) + mock_file.read.return_value = json.dumps(tokens) + with patch("json.load", return_value=tokens): + loaded = auth._load_credentials() + self.assertEqual(loaded, tokens) + + @patch("codecarbon.cli.auth.requests.get") + @patch("codecarbon.cli.auth.JsonWebKey.import_key_set") + @patch("codecarbon.cli.auth.jose_jwt.decode") + def test_validate_access_token_valid( + self, mock_decode, mock_import_key_set, mock_get + ): + mock_get.return_value.json.return_value = {"jwks_uri": "jwks"} + mock_get.return_value.raise_for_status.return_value = None + mock_import_key_set.return_value = "keyset" + mock_decode.return_value.validate.return_value = None + with patch( + "codecarbon.cli.auth._discover_endpoints", return_value={"jwks_uri": "jwks"} + ): + self.assertTrue(auth._validate_access_token("token")) + + @patch("codecarbon.cli.auth.requests.post") + @patch("codecarbon.cli.auth._discover_endpoints") + def test_refresh_tokens(self, mock_discover, mock_post): + mock_discover.return_value = {"token_endpoint": "url"} + mock_post.return_value.raise_for_status.return_value = None + mock_post.return_value.json.return_value = { + "access_token": "a", + "refresh_token": "r", + } + result = auth._refresh_tokens("refresh") + self.assertIn("access_token", result) + self.assertIn("refresh_token", result) + + @patch("codecarbon.cli.auth._load_credentials") + @patch("codecarbon.cli.auth._validate_access_token") + def test_get_access_token_valid(self, mock_validate, mock_load): + mock_load.return_value = {"access_token": "a", "refresh_token": "r"} + mock_validate.return_value = True + self.assertEqual(auth.get_access_token(), "a") + + @patch("codecarbon.cli.auth._load_credentials") + @patch("codecarbon.cli.auth._validate_access_token") + @patch("codecarbon.cli.auth._refresh_tokens") + @patch("codecarbon.cli.auth._save_credentials") + def test_get_access_token_refresh( + self, mock_save, mock_refresh, mock_validate, mock_load + ): + mock_load.return_value = {"access_token": "a", "refresh_token": "r"} + mock_validate.return_value = False + mock_refresh.return_value = {"access_token": "b", "refresh_token": "r"} + self.assertEqual(auth.get_access_token(), "b") + mock_save.assert_called() + + @patch("codecarbon.cli.auth._load_credentials") + def test_get_id_token(self, mock_load): + mock_load.return_value = {"id_token": "i"} + self.assertEqual(auth.get_id_token(), "i") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_cli_main.py b/tests/cli/test_cli_main.py similarity index 96% rename from tests/test_cli_main.py rename to tests/cli/test_cli_main.py index 0796b4218..705ab9295 100644 --- a/tests/test_cli_main.py +++ b/tests/cli/test_cli_main.py @@ -32,7 +32,7 @@ def test_version_flag(): def test_api_get_calls_api_and_prints(monkeypatch): runner = CliRunner() monkeypatch.setattr(cli_main, "ApiClient", FakeApiClient) - monkeypatch.setattr(cli_main, "_get_access_token", fake_get_access_token) + monkeypatch.setattr(cli_main, "get_access_token", fake_get_access_token) result = runner.invoke(cli_main.codecarbon, ["test-api"]) assert result.exit_code == 0