diff --git a/backend/apps/oauth_app.py b/backend/apps/oauth_app.py index bda69f935..f05102d0c 100644 --- a/backend/apps/oauth_app.py +++ b/backend/apps/oauth_app.py @@ -1,27 +1,36 @@ import logging -from fastapi import APIRouter, Header, HTTPException +from fastapi import APIRouter, Header, HTTPException, Request from fastapi.responses import JSONResponse, RedirectResponse from http import HTTPStatus from typing import Optional +from pydantic import ValidationError as PydanticValidationError + +from consts.model import OAuthCompleteRequest from consts.exceptions import OAuthLinkError, OAuthProviderError, UnauthorizedError from consts.oauth_providers import get_all_provider_definitions from database.oauth_account_db import get_oauth_account_by_provider from services.oauth_service import ( + complete_pending_oauth_account, create_or_update_oauth_account, ensure_user_tenant_exists, exchange_code_for_provider_token, + find_supabase_user_id_by_email, + generate_pending_oauth_token, get_authorize_url, get_enabled_providers, + get_pending_oauth_info, get_provider_user_info, list_linked_accounts, - unlink_account, parse_state, + parse_state, + unlink_account, ) from utils.auth_utils import ( calculate_expires_at, generate_session_jwt, - get_current_user_id, get_supabase_admin_client, + get_current_user_id, + get_supabase_admin_client, ) logger = logging.getLogger(__name__) @@ -142,44 +151,37 @@ async def callback( if existing_binding: supabase_user_id = existing_binding["user_id"] else: - # No binding found, search/create user by email in Supabase - admin_client = get_supabase_admin_client() - if not admin_client: - raise RuntimeError("Supabase admin client not available") - supabase_user_id = None - page = 1 - while True: - users_resp = admin_client.auth.admin.list_users( - page=page, per_page=100 + if email: + admin_client = get_supabase_admin_client() + if not admin_client: + raise RuntimeError("Supabase admin client not available") + supabase_user_id = find_supabase_user_id_by_email( + admin_client, + email, ) - users = users_resp if len(users_resp) > 0 else [] - if not users: - break - for u in users: - if u.email and u.email.lower() == email.lower(): - supabase_user_id = u.id - break - if supabase_user_id: - break - if len(users) < 100: - break - page += 1 if not supabase_user_id: - if not email: - email = f"{provider}_{provider_user_id}@oauth.nexent" - create_resp = admin_client.auth.admin.create_user( - { - "email": email, - "email_confirm": True, - "user_metadata": { - "full_name": username, + pending_token = generate_pending_oauth_token( + provider=provider, + provider_user_id=provider_user_id, + provider_email=email, + provider_username=username, + ) + return JSONResponse( + status_code=HTTPStatus.OK, + content={ + "message": "OAuth account information required", + "data": { + "requires_account_completion": True, + "pending_token": pending_token, "provider": provider, + "provider_username": username, + "provider_email": email, + "email_required": not bool(email), }, - } + }, ) - supabase_user_id = create_resp.user.id ensure_user_tenant_exists(user_id=supabase_user_id, email=email) @@ -214,6 +216,18 @@ async def callback( }, ) + except OAuthLinkError as e: + logger.warning(f"OAuth callback link failed for provider={provider}: {e}") + return JSONResponse( + status_code=HTTPStatus.BAD_REQUEST, + content={ + "message": "OAuth account link failed", + "data": { + "oauth_error": "oauth_account_already_bound", + "oauth_error_description": "OAuth account is already bound to another user", + }, + }, + ) except Exception as e: logger.error(f"OAuth callback failed for provider={provider}: {e}") return JSONResponse( @@ -228,6 +242,67 @@ async def callback( ) +@router.get("/pending") +async def get_pending( + pending_token: Optional[str] = Header(None, alias="X-OAuth-Pending-Token"), +): + try: + pending = get_pending_oauth_info(pending_token or "") + return JSONResponse( + status_code=HTTPStatus.OK, + content={"message": "success", "data": pending}, + ) + except OAuthLinkError as e: + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail=str(e)) + except OAuthProviderError as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=str(e)) + except Exception as e: + logger.error(f"Failed to get pending OAuth info: {e}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="Failed to get pending OAuth info", + ) + + +@router.post("/complete") +async def complete( + request: Request, + pending_token: Optional[str] = Header(None, alias="X-OAuth-Pending-Token"), +): + try: + request_data = OAuthCompleteRequest(**(await request.json())) + result = await complete_pending_oauth_account( + pending_token=pending_token or "", + email=str(request_data.email) if request_data.email else None, + password=request_data.password, + invite_code=request_data.invite_code, + ) + return JSONResponse( + status_code=HTTPStatus.OK, + content={"message": "OAuth account completed", "data": result}, + ) + except OAuthLinkError as e: + status_code = ( + HTTPStatus.CONFLICT + if "Email already exists" in str(e) + else HTTPStatus.BAD_REQUEST + ) + raise HTTPException(status_code=status_code, detail=str(e)) + except PydanticValidationError as e: + raise HTTPException( + status_code=HTTPStatus.UNPROCESSABLE_ENTITY, + detail=e.errors(), + ) + except OAuthProviderError as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=str(e)) + except Exception as e: + logger.error(f"Failed to complete OAuth account: {e}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="Failed to complete OAuth account", + ) + + @router.get("/accounts") async def get_accounts(authorization: Optional[str] = Header(None)): if not authorization: @@ -257,20 +332,7 @@ async def delete_account(provider: str, authorization: Optional[str] = Header(No try: user_id, _ = get_current_user_id(authorization) - - has_password_auth = False - - admin_client = get_supabase_admin_client() - if admin_client: - try: - user_resp = admin_client.auth.admin.get_user_by_id(user_id) - user_metadata = getattr(user_resp.user, "user_metadata", {}) or {} - signup_provider = user_metadata.get("provider", "email") - has_password_auth = signup_provider == "email" - except Exception as e: - logger.warning(f"Failed to check user identities for {user_id}: {e}") - - unlink_account(user_id, provider, has_password_auth=has_password_auth) + unlink_account(user_id, provider) return JSONResponse( status_code=HTTPStatus.OK, content={ diff --git a/backend/consts/model.py b/backend/consts/model.py index 2f1d7aae3..5e54835bb 100644 --- a/backend/consts/model.py +++ b/backend/consts/model.py @@ -40,6 +40,13 @@ class UserSignInRequest(BaseModel): password: str +class OAuthCompleteRequest(BaseModel): + """Complete a pending OAuth signup.""" + email: Optional[EmailStr] = None + password: str = Field(..., min_length=6) + invite_code: str = Field(..., min_length=1) + + class UserUpdateRequest(BaseModel): """User update request model""" username: Optional[str] = Field(None, min_length=1, max_length=50) diff --git a/backend/services/oauth_service.py b/backend/services/oauth_service.py index 0083ad9ec..7d2319bad 100644 --- a/backend/services/oauth_service.py +++ b/backend/services/oauth_service.py @@ -3,15 +3,20 @@ import os import secrets import ssl +import time import urllib.request from typing import Any, Dict, List, Optional from urllib.parse import urlencode, quote +import jwt +from pydantic import EmailStr, TypeAdapter, ValidationError as PydanticValidationError + from consts.const import ( DEFAULT_TENANT_ID, OAUTH_CALLBACK_BASE_URL, OAUTH_SSL_VERIFY, OAUTH_CA_BUNDLE, + SUPABASE_JWT_SECRET, ) from consts.exceptions import OAuthLinkError, OAuthProviderError from consts.oauth_providers import ( @@ -20,7 +25,6 @@ is_provider_enabled, ) from database.oauth_account_db import ( - count_oauth_accounts_by_user_id, delete_oauth_account, get_oauth_account_by_provider, get_soft_deleted_oauth_account, @@ -33,6 +37,10 @@ logger = logging.getLogger(__name__) +OAUTH_PENDING_EXPIRE_SECONDS = 10 * 60 +OAUTH_PENDING_PURPOSE = "oauth_account_completion" +_EMAIL_ADAPTER = TypeAdapter(EmailStr) + def _build_ssl_context() -> ssl.SSLContext: if OAUTH_CA_BUNDLE and os.path.isfile(OAUTH_CA_BUNDLE): @@ -246,12 +254,233 @@ def get_provider_user_info( except Exception: logger.warning(f"Failed to fetch {provider} user emails") - if result.get("email", "") == "": - result["email"] = f"{result['username']}@nexent.com" - return result +def generate_pending_oauth_token( + provider: str, + provider_user_id: str, + provider_email: Optional[str] = None, + provider_username: Optional[str] = None, + expires_in: int = OAUTH_PENDING_EXPIRE_SECONDS, +) -> str: + if not SUPABASE_JWT_SECRET: + raise OAuthProviderError("JWT verification is not configured") + + now = int(time.time()) + payload = { + "purpose": OAUTH_PENDING_PURPOSE, + "provider": provider, + "provider_user_id": provider_user_id, + "provider_email": provider_email or "", + "provider_username": provider_username or "", + "iat": now, + "exp": now + expires_in, + } + return jwt.encode(payload, SUPABASE_JWT_SECRET, algorithm="HS256") + + +def parse_pending_oauth_token(pending_token: str) -> Dict[str, str]: + if not pending_token: + raise OAuthLinkError("OAuth account completion session is missing") + if not SUPABASE_JWT_SECRET: + raise OAuthProviderError("JWT verification is not configured") + + try: + payload = jwt.decode( + pending_token, + SUPABASE_JWT_SECRET, + algorithms=["HS256"], + options={"verify_exp": True, "verify_aud": False}, + ) + except jwt.ExpiredSignatureError as exc: + raise OAuthLinkError("OAuth account completion session has expired") from exc + except jwt.InvalidTokenError as exc: + raise OAuthLinkError("OAuth account completion session is invalid") from exc + + if payload.get("purpose") != OAUTH_PENDING_PURPOSE: + raise OAuthLinkError("OAuth account completion session is invalid") + if not payload.get("provider") or not payload.get("provider_user_id"): + raise OAuthLinkError("OAuth account completion session is incomplete") + + return { + "provider": str(payload.get("provider", "")), + "provider_user_id": str(payload.get("provider_user_id", "")), + "provider_email": str(payload.get("provider_email", "")), + "provider_username": str(payload.get("provider_username", "")), + } + + +def get_pending_oauth_info(pending_token: str) -> Dict[str, Any]: + payload = parse_pending_oauth_token(pending_token) + provider_email = payload.get("provider_email") or "" + return { + "provider": payload["provider"], + "provider_username": payload.get("provider_username") or "", + "provider_email": provider_email, + "email_required": not bool(provider_email), + } + + +def _validate_email(email: Optional[str]) -> str: + if not email: + raise OAuthLinkError("Email is required") + try: + return str(_EMAIL_ADAPTER.validate_python(email)).lower() + except PydanticValidationError as exc: + raise OAuthLinkError("Invalid email address") from exc + + +def find_supabase_user_id_by_email( + admin_client: Any, email: Optional[str] +) -> Optional[str]: + if not email: + return None + + page = 1 + while True: + users_resp = admin_client.auth.admin.list_users(page=page, per_page=100) + users = getattr(users_resp, "users", users_resp) + if users is None: + users = [] + if not users: + return None + for user in users: + user_email = getattr(user, "email", "") + if user_email and user_email.lower() == email.lower(): + return user.id + if len(users) < 100: + return None + page += 1 + + +def _role_from_invitation_type(code_type: str) -> str: + if code_type == "ADMIN_INVITE": + return "ADMIN" + if code_type == "DEV_INVITE": + return "DEV" + return "USER" + + +async def complete_pending_oauth_account( + pending_token: str, + password: str, + invite_code: str, + email: Optional[str] = None, +) -> Dict[str, Any]: + from services.group_service import add_user_to_groups + from services.invitation_service import ( + check_invitation_available, + get_invitation_by_code, + use_invitation_code, + ) + from services.tool_configuration_service import init_tool_list_for_tenant + from services.user_management_service import generate_tts_stt_4_admin + from utils.auth_utils import calculate_expires_at, generate_session_jwt + + pending = parse_pending_oauth_token(pending_token) + provider = pending["provider"] + provider_user_id = pending["provider_user_id"] + provider_email = pending.get("provider_email") or "" + provider_username = pending.get("provider_username") or "" + + if len(password or "") < 6: + raise OAuthLinkError("Password must be at least 6 characters") + + final_email = _validate_email(provider_email or email) + normalized_invite_code = invite_code.upper() + + if get_oauth_account_by_provider(provider, provider_user_id): + raise OAuthLinkError(f"This {provider} account is already bound to another user") + + if not check_invitation_available(normalized_invite_code): + raise OAuthLinkError("Invitation code is invalid or unavailable") + + invitation_info = get_invitation_by_code(normalized_invite_code) + if not invitation_info: + raise OAuthLinkError("Invitation code is invalid or unavailable") + + admin_client = None + try: + from utils.auth_utils import get_supabase_admin_client + + admin_client = get_supabase_admin_client() + except Exception: + admin_client = None + if not admin_client: + raise RuntimeError("Supabase admin client not available") + + existing_user_id = find_supabase_user_id_by_email(admin_client, final_email) + if existing_user_id: + raise OAuthLinkError( + "Email already exists. Please log in with email and password, " + "then link this OAuth account in settings." + ) + + create_resp = admin_client.auth.admin.create_user( + { + "email": final_email, + "password": password, + "email_confirm": True, + "user_metadata": { + "full_name": provider_username, + "provider": provider, + }, + } + ) + supabase_user_id = create_resp.user.id + + tenant_id = invitation_info["tenant_id"] + user_role = _role_from_invitation_type(invitation_info.get("code_type", "USER_INVITE")) + + insert_user_tenant( + user_id=supabase_user_id, + tenant_id=tenant_id, + user_role=user_role, + user_email=final_email, + ) + + invitation_result = use_invitation_code(normalized_invite_code, supabase_user_id) + group_ids = invitation_result.get("group_ids", []) + if isinstance(group_ids, str): + from utils.str_utils import convert_string_to_list + + group_ids = convert_string_to_list(group_ids) + if group_ids: + add_user_to_groups(supabase_user_id, group_ids, supabase_user_id) + + if user_role == "ADMIN": + await generate_tts_stt_4_admin(tenant_id, supabase_user_id) + await init_tool_list_for_tenant(tenant_id, supabase_user_id) + + create_or_update_oauth_account( + user_id=supabase_user_id, + provider=provider, + provider_user_id=provider_user_id, + email=final_email, + username=provider_username, + tenant_id=tenant_id, + ) + + expiry_seconds = 3600 + jwt_token = generate_session_jwt(supabase_user_id, expires_in=expiry_seconds) + expires_at = calculate_expires_at(jwt_token) + + return { + "user": { + "id": str(supabase_user_id), + "email": final_email, + "role": user_role, + }, + "session": { + "access_token": jwt_token, + "refresh_token": "", + "expires_at": expires_at, + "expires_in_seconds": expiry_seconds, + }, + } + + def create_or_update_oauth_account( user_id: str, provider: str, @@ -330,13 +559,7 @@ def list_linked_accounts(user_id: str) -> List[Dict[str, Any]]: return result -def unlink_account( - user_id: str, provider: str, has_password_auth: bool = False -) -> bool: - oauth_count = count_oauth_accounts_by_user_id(user_id) - if oauth_count <= 1 and not has_password_auth: - raise OAuthLinkError("Cannot unlink the last authentication method") - +def unlink_account(user_id: str, provider: str) -> bool: success = delete_oauth_account(user_id, provider) if not success: raise OAuthLinkError(f"No linked {provider} account found") diff --git a/docker/.env.bak b/docker/.env.bak deleted file mode 100644 index 24b53751b..000000000 --- a/docker/.env.bak +++ /dev/null @@ -1,168 +0,0 @@ -# ===== Necessary Configs (Necessary till now, will be migrated to frontend page) ===== - -# Voice Service Config -APPID=app_id -TOKEN=token - -# ===== Non-essential Configs (Modify if you know what you are doing) ===== - -CLUSTER=volcano_tts -VOICE_TYPE=zh_male_jieshuonansheng_mars_bigtts -SPEED_RATIO=1.3 - -# ===== Proxy Configuration (Optional) ===== - -# HTTP_PROXY=http://proxy-server:port -# HTTPS_PROXY=http://proxy-server:port -# NO_PROXY=localhost,127.0.0.1 - -# ===== Backend Configuration (No need to modify at all) ===== - -# Model Path Config -CLIP_MODEL_PATH=/opt/models/clip-vit-base-patch32 -NLTK_DATA=/opt/models/nltk_data - -# Elasticsearch Service -ELASTICSEARCH_HOST=http://nexent-elasticsearch:9200 -ELASTIC_PASSWORD=nexent@2025 - -# Elasticsearch Memory Configuration -ES_JAVA_OPTS="-Xms2g -Xmx2g" - -# Elasticsearch Disk Watermark Configuration -ES_DISK_WATERMARK_LOW=85% -ES_DISK_WATERMARK_HIGH=90% -ES_DISK_WATERMARK_FLOOD_STAGE=95% - -# Main Services -# Config service (port 5010) - Main API service for config operations -CONFIG_SERVICE_URL=http://nexent-config:5010 -ELASTICSEARCH_SERVICE=http://nexent-config:5010/api - -# Runtime service (port 5014) - Runtime execution service for agent operations -RUNTIME_SERVICE_URL=http://nexent-runtime:5014 - -# MCP service (port 5011) - MCP protocol service -NEXENT_MCP_SERVER=http://nexent-mcp:5011 -MCP_MANAGEMENT_API=http://nexent-mcp:5015 - -# Data process service (port 5012) - Data processing service -DATA_PROCESS_SERVICE=http://nexent-data-process:5012/api - -# Northbound service (port 5013) - Northbound API service -NORTHBOUND_API_SERVER=http://nexent-northbound:5013/api - -# Postgres Config -POSTGRES_HOST=nexent-postgresql -POSTGRES_USER=root -NEXENT_POSTGRES_PASSWORD=nexent@4321 -POSTGRES_DB=nexent -POSTGRES_PORT=5432 - -# Minio Config -MINIO_ENDPOINT=http://nexent-minio:9000 -MINIO_ROOT_USER=nexent -MINIO_ROOT_PASSWORD=nexent@4321 -MINIO_REGION=cn-north-1 -MINIO_DEFAULT_BUCKET=nexent - -# Redis Config -REDIS_URL=redis://redis:6379/0 -REDIS_BACKEND_URL=redis://redis:6379/1 - -# Model Engine Config -MODEL_ENGINE_ENABLED=false - -# Supabase Config -DASHBOARD_USERNAME=supabase -DASHBOARD_PASSWORD=Huawei123 - -# Supabase db Config -SUPABASE_POSTGRES_PASSWORD=Huawei123 -SUPABASE_POSTGRES_HOST=db -SUPABASE_POSTGRES_DB=supabase -SUPABASE_POSTGRES_PORT=5436 - -# Supabase Auth Config -SITE_URL=http://localhost:3011 -SUPABASE_URL=http://supabase-kong-mini:8000 -API_EXTERNAL_URL=http://supabase-kong-mini:8000 -DISABLE_SIGNUP=false -JWT_EXPIRY=3600 -DEBUG_JWT_EXPIRE_SECONDS=0 - -# Supabase Configuration -ENABLE_EMAIL_SIGNUP=true -ENABLE_EMAIL_AUTOCONFIRM=true -ENABLE_ANONYMOUS_USERS=false - -# Supabase Phone Config -ENABLE_PHONE_SIGNUP=false -ENABLE_PHONE_AUTOCONFIRM=false - -MAILER_URLPATHS_CONFIRMATION="/auth/v1/verify" -MAILER_URLPATHS_INVITE="/auth/v1/verify" -MAILER_URLPATHS_RECOVERY="/auth/v1/verify" -MAILER_URLPATHS_EMAIL_CHANGE="/auth/v1/verify" - -INVITE_CODE=nexent2025 - -# Terminal Tool SSH Key Path -SSH_PRIVATE_KEY_PATH=/path/to/openssh-server/ssh-keys/openssh_server_key - -# ===== Data Processing Service Configuration ===== - -# Redis Port -REDIS_PORT=6379 - -# Flower Monitoring -FLOWER_PORT=5555 - -# Ray Configuration -RAY_ACTOR_NUM_CPUS=2 -RAY_DASHBOARD_PORT=8265 -RAY_DASHBOARD_HOST=0.0.0.0 -RAY_NUM_CPUS=4 -RAY_OBJECT_STORE_MEMORY_GB=0.25 -RAY_TEMP_DIR=/tmp/ray -RAY_LOG_LEVEL=INFO - -# Service Control Flags -DISABLE_RAY_DASHBOARD=true -DISABLE_CELERY_FLOWER=true -DOCKER_ENVIRONMENT=false -ENABLE_UPLOAD_IMAGE=false - -# Celery Configuration -CELERY_WORKER_PREFETCH_MULTIPLIER=1 -CELERY_TASK_TIME_LIMIT=3600 -ELASTICSEARCH_REQUEST_TIMEOUT=30 - -# Worker Configuration -QUEUES=process_q,forward_q -WORKER_NAME= -WORKER_CONCURRENCY=4 - -# Skills Configuration -SKILLS_PATH=/mnt/nexent/skills - -# Telemetry and Monitoring Configuration -ENABLE_TELEMETRY=false -SERVICE_NAME=nexent-backend -JAEGER_ENDPOINT=http://localhost:14268/api/traces -PROMETHEUS_PORT=8000 -TELEMETRY_SAMPLE_RATE=1.0 -LLM_SLOW_REQUEST_THRESHOLD_SECONDS=5.0 -LLM_SLOW_TOKEN_RATE_THRESHOLD=10.0 - -# Market Backend Address -MARKET_BACKEND=http://60.204.251.153:8010 -DEPLOYMENT_VERSION="speed" -# Root dir -ROOT_DIR="/c/Users/18270/nexent-data" -TERMINAL_MOUNT_DIR="/opt/terminal" -SSH_USERNAME="root" -SSH_PASSWORD="731215" -NEXENT_MCP_DOCKER_IMAGE="ccr.ccs.tencentyun.com/nexent-hub/nexent-mcp:v2.0.1" -MINIO_ACCESS_KEY="72c31cb5b521511cea652723" -MINIO_SECRET_KEY="m5gcSuKzZnp84CqmG7z5VKnd2C+H5U3PSr7eoJeygmI=" diff --git a/docker/create-su.sh b/docker/create-su.sh old mode 100644 new mode 100755 diff --git a/frontend/app/[locale]/layout.client.tsx b/frontend/app/[locale]/layout.client.tsx index 5f8c7d5fa..619596213 100644 --- a/frontend/app/[locale]/layout.client.tsx +++ b/frontend/app/[locale]/layout.client.tsx @@ -32,7 +32,9 @@ export function ClientLayout({ children }: { children: ReactNode }) { const isChatPage = pathname?.includes("/chat"); // Home page does not require authorization - const isHomePage = getEffectiveRoutePath(pathname) === "/"; + const effectivePath = getEffectiveRoutePath(pathname); + const isHomePage = effectivePath === "/"; + const isOAuthCompletePage = effectivePath === "/oauth/complete"; // Sidebar collapse state const [collapsed, setCollapsed] = useState(false); @@ -146,7 +148,7 @@ export function ClientLayout({ children }: { children: ReactNode }) { {/* Don't render children until authorization is complete (except home page) */} - {isHomePage || isAuthorized ? ( + {isHomePage || isOAuthCompletePage || isAuthorized ? ( children ) : (
diff --git a/frontend/app/[locale]/layout.tsx b/frontend/app/[locale]/layout.tsx index 71e2f32c1..d28b52422 100644 --- a/frontend/app/[locale]/layout.tsx +++ b/frontend/app/[locale]/layout.tsx @@ -1,5 +1,4 @@ import type { Metadata } from "next"; -import { Inter } from "next/font/google"; import React, { ReactNode } from "react"; import { RootProvider } from "@/components/providers/rootProvider"; import { DeploymentProvider } from "@/components/providers/deploymentProvider"; @@ -14,8 +13,6 @@ import "katex/dist/katex.min.css"; import "react-pdf/dist/Page/TextLayer.css"; import "react-pdf/dist/Page/AnnotationLayer.css"; -const inter = Inter({ subsets: ["latin"] }); - export async function generateMetadata({ params, }: { @@ -45,7 +42,7 @@ export default async function RootLayout({ return ( - + (); + const locale = params?.locale === "en" ? "en" : "zh"; + const { t } = useTranslation("common"); + const { openRegisterModal } = useAuthenticationContext(); + const [status, setStatus] = useState<"loading" | "ready" | "expired">( + "loading" + ); + + useEffect(() => { + let mounted = true; + + oauthService.getPendingOAuth().then((pending) => { + if (!mounted) return; + + if (!pending) { + setStatus("expired"); + return; + } + + openRegisterModal({ + mode: "oauth_complete", + email: pending.provider_email || "", + emailReadOnly: !pending.email_required, + }); + setStatus("ready"); + }); + + return () => { + mounted = false; + }; + }, [openRegisterModal]); + + if (status === "expired") { + return ( +
+ + + + +
+ ); + } + + if (status === "ready") { + return null; + } + + return ( +
+ +
+ ); +} diff --git a/frontend/components/auth/loginModal.tsx b/frontend/components/auth/loginModal.tsx index ba7ea9ff2..3a4b94a90 100644 --- a/frontend/components/auth/loginModal.tsx +++ b/frontend/components/auth/loginModal.tsx @@ -1,6 +1,6 @@ "use client"; -import { useState, useEffect } from "react"; +import { useCallback, useState, useEffect } from "react"; import { useTranslation } from "react-i18next"; import { Modal, Form, Input, Button, Typography, Space, Divider, Alert } from "antd"; import { UserRound, LockKeyhole, Github, Link2 } from "lucide-react"; @@ -73,15 +73,27 @@ export function LoginModal() { const [emailError, setEmailError] = useState(""); const [passwordError, setPasswordError] = useState(false); const [oauthError, setOauthError] = useState(null); + const { t } = useTranslation("common"); + + const getOAuthLoginErrorMessage = useCallback( + (error: string) => { + const key = `auth.oauthErrors.${error}`; + const translated = t(key); + if (translated !== key) { + return translated; + } + return t("auth.oauthLoginFailedGeneric"); + }, + [t] + ); useEffect(() => { const error = searchParams.get("oauth_error"); - const description = searchParams.get("oauth_error_description"); if (error) { - setOauthError(description || error); + setOauthError(getOAuthLoginErrorMessage(error)); router.replace("/"); } - }, [searchParams, router]); + }, [searchParams, router, getOAuthLoginErrorMessage]); const resetForm = () => { setEmailError(""); @@ -108,9 +120,6 @@ export function LoginModal() { } }; - // Internationalization hook for multi-language support - const { t } = useTranslation("common"); - /** * Handles form submission for user login * @param values - Object containing email and password diff --git a/frontend/components/auth/registerModal.tsx b/frontend/components/auth/registerModal.tsx index 860b600d5..08b1b848e 100644 --- a/frontend/components/auth/registerModal.tsx +++ b/frontend/components/auth/registerModal.tsx @@ -1,6 +1,6 @@ "use client"; -import { useState } from "react"; +import { useEffect, useState } from "react"; import { useTranslation } from "react-i18next"; import { usePathname, useRouter } from "next/navigation"; import { @@ -10,7 +10,6 @@ import { Button, Typography, Space, - Switch, App, Popover, } from "antd"; @@ -19,15 +18,16 @@ import { LockKeyhole, ShieldCheck, KeyRound, - BookMarked, HelpCircle, Users, } from "lucide-react"; import { useAuthenticationContext } from "@/components/providers/AuthenticationProvider"; import { useDeployment } from "@/components/providers/deploymentProvider"; -import { AuthFormValues } from "@/types/auth"; +import type { AuthFormValues } from "@/types/auth"; import { getEffectiveRoutePath } from "@/lib/auth"; +import { authEventUtils } from "@/lib/authEvents"; +import { oauthService } from "@/services/oauthService"; import log from "@/lib/logger"; const { Text } = Typography; @@ -35,6 +35,7 @@ const { Text } = Typography; export function RegisterModal() { const { isRegisterModalOpen, + registerModalOptions, isAuthenticated, closeRegisterModal, openLoginModal, @@ -54,6 +55,7 @@ export function RegisterModal() { }>({ target: "", message: "" }); const { t } = useTranslation("common"); const { message } = App.useApp(); + const isOAuthCompletion = registerModalOptions?.mode === "oauth_complete"; const validateEmail = (email: string): boolean => { if (!email) return false; @@ -74,6 +76,82 @@ export function RegisterModal() { form.resetFields(); }; + const setInviteCodeError = (errorMsg: string, value?: string) => { + message.error(errorMsg); + form.setFields([ + { + name: "inviteCode", + errors: [errorMsg], + value, + }, + ]); + }; + + const setPasswordFieldError = (errorMsg: string, value?: string) => { + message.error(errorMsg); + setPasswordError({ target: "password", message: errorMsg }); + form.setFields([ + { + name: "password", + errors: [errorMsg], + value, + }, + ]); + }; + + const setEmailFieldError = (errorMsg: string, value?: string) => { + message.error(errorMsg); + setEmailError(errorMsg); + form.setFields([ + { + name: "email", + errors: [errorMsg], + value, + }, + ]); + }; + + const handleOAuthCompleteError = ( + errorKey: string, + values: AuthFormValues + ) => { + const errorMsg = t(errorKey); + + if (errorKey === "auth.inviteCodeInvalid") { + setInviteCodeError(errorMsg, values.inviteCode); + return; + } + + if (errorKey === "auth.passwordMinLength") { + setPasswordFieldError(errorMsg, values.password); + return; + } + + if ( + errorKey === "auth.invalidEmailFormat" || + errorKey === "auth.emailRequired" || + errorKey === "auth.oauthEmailAlreadyExists" + ) { + setEmailFieldError(errorMsg, values.email); + return; + } + + message.error(errorMsg); + }; + + useEffect(() => { + if (!isRegisterModalOpen) return; + + setEmailError(""); + setPasswordError({ target: "", message: "" }); + form.resetFields(); + if (registerModalOptions?.email) { + form.setFieldsValue({ email: registerModalOptions.email }); + } else if (isOAuthCompletion) { + form.setFieldsValue({ email: "" }); + } + }, [form, isOAuthCompletion, isRegisterModalOpen, registerModalOptions]); + const handleSubmit = async (values: AuthFormValues) => { setIsLoading(true); setEmailError(""); // Reset error state @@ -103,6 +181,32 @@ export function RegisterModal() { } try { + if (isOAuthCompletion) { + const result = await oauthService.completeOAuth({ + email: registerModalOptions?.emailReadOnly ? undefined : values.email, + invite_code: values.inviteCode || "", + password: values.password, + }); + + if (result.error || !result.data) { + handleOAuthCompleteError( + result.errorKey || "auth.oauthCompleteFailed", + values + ); + setIsLoading(false); + return; + } + + resetForm(); + message.success(t("auth.oauthCompleteSuccess")); + authEventUtils.emitRegisterSuccess(); + authEventUtils.emitLoginSuccess(); + + const locale = pathname.split("/").find(Boolean) || "zh"; + window.location.href = `/${locale}`; + return; + } + await register( values.email, values.password, @@ -145,6 +249,12 @@ export function RegisterModal() { const httpStatusCode = error?.code; const errorType = error?.message; + if (isOAuthCompletion) { + handleOAuthCompleteError("auth.oauthCompleteFailed", values); + setIsLoading(false); + return; + } + // HTTP 409 Conflict if (httpStatusCode === 409 || errorType === "EMAIL_ALREADY_EXISTS") { const errorMsg = t("auth.emailAlreadyExists"); @@ -263,6 +373,12 @@ export function RegisterModal() { setPasswordError({ target: "", message: "" }); closeRegisterModal(); + if (isOAuthCompletion) { + const locale = pathname.split("/").find(Boolean) || "zh"; + router.push(`/${locale}`); + return; + } + // If user manually cancels registration from a protected page, // redirect back to home instead of keeping them on the restricted page if (!isAuthenticated && !isSpeedMode) { @@ -340,7 +456,9 @@ export function RegisterModal() { - {t("auth.registerTitle")} + {isOAuthCompletion + ? t("auth.oauthCompleteTitle") + : t("auth.registerTitle")}
} open={isRegisterModalOpen} @@ -383,6 +501,7 @@ export function RegisterModal() { prefix={} placeholder="your@email.com" size="large" + disabled={isOAuthCompletion && registerModalOptions?.emailReadOnly} onChange={handleEmailInputChange} /> @@ -576,20 +695,28 @@ export function RegisterModal() { block size="large" className="mt-2" - disabled={authServiceUnavailable} + disabled={!isOAuthCompletion && authServiceUnavailable} > - {isLoading? t("auth.registering"): t("auth.register")} + {isLoading + ? isOAuthCompletion + ? t("auth.oauthCompleting") + : t("auth.registering") + : isOAuthCompletion + ? t("auth.oauthCompleteSubmit") + : t("auth.register")} -
- - {t("auth.hasAccount")} - - -
+ {!isOAuthCompletion && ( +
+ + {t("auth.hasAccount")} + + +
+ )} diff --git a/frontend/const/auth.ts b/frontend/const/auth.ts index bf78490ee..009604ea5 100644 --- a/frontend/const/auth.ts +++ b/frontend/const/auth.ts @@ -33,6 +33,7 @@ export const COOKIE_NAMES = { ACCESS_TOKEN: "nexent_access_token", REFRESH_TOKEN: "nexent_refresh_token", EXPIRES_AT: "nexent_token_expires_at", + OAUTH_PENDING: "nexent_oauth_pending", } as const; // Type-safe authentication events (used with authEvents emitter) @@ -52,4 +53,3 @@ export const AUTHZ_EVENTS = { PERMISSIONS_READY: "authz:permissions-ready", PERMISSIONS_UPDATED: "authz:permissions-updated", } as const; - diff --git a/frontend/hooks/auth/useAuthentication.ts b/frontend/hooks/auth/useAuthentication.ts index b360d613e..2146349a4 100644 --- a/frontend/hooks/auth/useAuthentication.ts +++ b/frontend/hooks/auth/useAuthentication.ts @@ -36,6 +36,7 @@ export function useAuthentication(): AuthenticationContextType { // UI state isLoginModalOpen: authUI.isLoginModalOpen, isRegisterModalOpen: authUI.isRegisterModalOpen, + registerModalOptions: authUI.registerModalOptions, isAuthPromptModalOpen: authUI.isAuthPromptModalOpen, isSessionExpiredModalOpen: authUI.isSessionExpiredModalOpen, diff --git a/frontend/hooks/auth/useAuthenticationUI.ts b/frontend/hooks/auth/useAuthenticationUI.ts index cb0cbade0..d46f1c023 100644 --- a/frontend/hooks/auth/useAuthenticationUI.ts +++ b/frontend/hooks/auth/useAuthenticationUI.ts @@ -1,15 +1,15 @@ "use client"; -import { useState, useCallback, useRef, useEffect } from "react"; +import { useState, useCallback, useEffect } from "react"; import { useRouter, usePathname, useSearchParams } from "next/navigation"; +import { App } from "antd"; import { useTranslation } from "react-i18next"; import { useDeployment } from "@/components/providers/deploymentProvider"; import { AUTH_EVENTS } from "@/const/auth"; import { getEffectiveRoutePath } from "@/lib/auth"; import { authEvents, authEventUtils } from "@/lib/authEvents"; -import { AuthenticationUIReturn } from "@/types/auth"; -import log from "@/lib/logger"; +import { AuthenticationUIReturn, RegisterModalOptions } from "@/types/auth"; /** * Custom hook for authentication UI management @@ -28,12 +28,17 @@ export function useAuthenticationUI({ const router = useRouter(); const pathname = usePathname(); const searchParams = useSearchParams(); - const { t } = useTranslation("common"); const { isSpeedMode } = useDeployment(); + const { t } = useTranslation("common"); + const { message } = App.useApp(); + const effectivePath = pathname ? getEffectiveRoutePath(pathname) : "/"; + const isOAuthCompletePage = effectivePath === "/oauth/complete"; // UI state for modals - managed locally within the hook const [isLoginModalOpen, setIsLoginModalOpen] = useState(false); const [isRegisterModalOpen, setIsRegisterModalOpen] = useState(false); + const [registerModalOptions, setRegisterModalOptions] = + useState(null); const [isAuthPromptModalOpen, setIsAuthPromptModalOpen] = useState(false); const [isSessionExpiredModalOpen, setIsSessionExpiredModalOpen] = useState(false); @@ -44,8 +49,7 @@ export function useAuthenticationUI({ // Emit event to notify SideNavigation to reset selected key authEventUtils.emitBackToHome(); // Redirect to home page if not already there - const effectivePath = pathname ? getEffectiveRoutePath(pathname) : "/"; - if (effectivePath !== "/") { + if (effectivePath !== "/" && !isOAuthCompletePage) { router.push("/"); } } @@ -59,10 +63,14 @@ export function useAuthenticationUI({ handleUnauthenticatedModalClose(); }, [handleUnauthenticatedModalClose]); - const openRegisterModal = useCallback(() => setIsRegisterModalOpen(true), []); + const openRegisterModal = useCallback((options?: RegisterModalOptions) => { + setRegisterModalOptions(options || null); + setIsRegisterModalOpen(true); + }, []); const closeRegisterModal = useCallback(() => { setIsRegisterModalOpen(false); + setRegisterModalOptions(null); handleUnauthenticatedModalClose(); }, [handleUnauthenticatedModalClose]); @@ -81,6 +89,18 @@ export function useAuthenticationUI({ handleUnauthenticatedModalClose(); }, [handleUnauthenticatedModalClose]); + const getOAuthErrorMessage = useCallback( + (error: string) => { + const key = `auth.oauthErrors.${error}`; + const translated = t(key); + if (translated !== key) { + return translated; + } + return t("auth.oauthLoginFailedGeneric"); + }, + [t] + ); + useEffect(() => { if (isSpeedMode) return; @@ -90,6 +110,7 @@ export function useAuthenticationUI({ const handleRegisterSuccess = () => { setIsRegisterModalOpen(false); + setRegisterModalOptions(null); }; // Add event listener using type-safe auth events @@ -112,10 +133,12 @@ export function useAuthenticationUI({ // Auto-open login modal when returning from a failed OAuth redirect useEffect(() => { if (isSpeedMode) return; + if (isOAuthCompletePage) return; if (isAuthChecking) return; if (isAuthenticated) { const oauthError = searchParams.get("oauth_error"); if (oauthError) { + message.error(getOAuthErrorMessage(oauthError)); router.replace("/"); } return; @@ -125,11 +148,19 @@ export function useAuthenticationUI({ if (oauthError && !isLoginModalOpen) { setIsLoginModalOpen(true); } - }, [searchParams, isAuthChecking, isAuthenticated, isSpeedMode, isLoginModalOpen, router]); + }, [searchParams, isAuthChecking, isAuthenticated, isSpeedMode, isLoginModalOpen, router, isOAuthCompletePage, message, getOAuthErrorMessage]); + + useEffect(() => { + if (!isOAuthCompletePage) return; + setIsAuthPromptModalOpen(false); + setIsLoginModalOpen(false); + setIsSessionExpiredModalOpen(false); + }, [isOAuthCompletePage]); // Route guard for unauthenticated users - check when pathname changes useEffect(() => { if (isSpeedMode) return; + if (isOAuthCompletePage) return; // Skip while checking auth state if (isAuthChecking) return; // Skip if user is authenticated @@ -139,7 +170,7 @@ export function useAuthenticationUI({ if (isLoginModalOpen) return; if (isRegisterModalOpen) return; openAuthPromptModal(); - }, [pathname, isAuthenticated, isSpeedMode, isAuthChecking, isSessionExpiredModalOpen, openAuthPromptModal]); + }, [pathname, isAuthenticated, isSpeedMode, isAuthChecking, isSessionExpiredModalOpen, openAuthPromptModal, isOAuthCompletePage]); return { @@ -148,6 +179,7 @@ export function useAuthenticationUI({ openLoginModal, closeLoginModal, isRegisterModalOpen, + registerModalOptions, openRegisterModal, closeRegisterModal, diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index 22c17c2ca..87bc0faae 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -993,6 +993,21 @@ "auth.oauthDivider": "or continue with", "auth.oauthLogin": "{{provider}} Login", "auth.oauthLoginFailed": "Third-party login failed: {{error}}", + "auth.oauthLoginFailedGeneric": "Third-party login failed. Please try again.", + "auth.oauthCompleteTitle": "Complete Account Setup", + "auth.oauthCompleting": "Submitting...", + "auth.oauthCompleteSubmit": "Complete and Sign In", + "auth.oauthCompleteSuccess": "Account setup completed", + "auth.oauthCompleteFailed": "Failed to complete OAuth account setup", + "auth.oauthPendingExpired": "OAuth account setup session is invalid or expired. Please sign in again.", + "auth.oauthBackHome": "Back to home", + "auth.oauthEmailAlreadyExists": "This email is already registered. Please log in with email and password, then link OAuth in settings.", + "auth.oauthAccountAlreadyBound": "This OAuth account is already linked to another user.", + "auth.oauthErrors.access_denied": "You cancelled third-party authorization.", + "auth.oauthErrors.no_code": "No third-party authorization code was received. Please try again.", + "auth.oauthErrors.unsupported_provider": "This third-party login provider is not supported.", + "auth.oauthErrors.callback_failed": "Third-party login callback failed. Please try again later.", + "auth.oauthErrors.oauth_account_already_bound": "This OAuth account is already linked to another user.", "auth.linkedAccounts": "Linked Accounts", "auth.unlinkAccount": "Unlink", "auth.unlinkConfirm": "Are you sure you want to unlink this {{provider}} account? You will need to use another login method.", diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index 1cc83a802..55cc59bec 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -994,6 +994,21 @@ "auth.oauthDivider": "或使用第三方登录", "auth.oauthLogin": "{{provider}} 登录", "auth.oauthLoginFailed": "第三方登录失败:{{error}}", + "auth.oauthLoginFailedGeneric": "第三方登录失败,请重试", + "auth.oauthCompleteTitle": "补充账号信息", + "auth.oauthCompleting": "提交中...", + "auth.oauthCompleteSubmit": "完成并登录", + "auth.oauthCompleteSuccess": "账号信息已补充完成", + "auth.oauthCompleteFailed": "OAuth 补充信息提交失败", + "auth.oauthPendingExpired": "OAuth 补充信息会话已失效,请重新登录", + "auth.oauthBackHome": "返回首页", + "auth.oauthEmailAlreadyExists": "该邮箱已被注册,请先使用邮箱密码登录后在个人设置中绑定 OAuth 账号", + "auth.oauthAccountAlreadyBound": "该 OAuth 账号已绑定其他用户", + "auth.oauthErrors.access_denied": "您已取消第三方授权", + "auth.oauthErrors.no_code": "未收到第三方授权码,请重新登录", + "auth.oauthErrors.unsupported_provider": "当前第三方登录方式暂不支持", + "auth.oauthErrors.callback_failed": "第三方登录回调失败,请稍后重试", + "auth.oauthErrors.oauth_account_already_bound": "该 OAuth 账号已绑定其他用户", "auth.linkedAccounts": "已绑定的账号", "auth.unlinkAccount": "解绑", "auth.unlinkConfirm": "确定要解绑此 {{provider}} 账号吗?您将需要使用其他登录方式。", diff --git a/frontend/server.js b/frontend/server.js index b6c38c72d..8a53f2d2b 100644 --- a/frontend/server.js +++ b/frontend/server.js @@ -40,6 +40,7 @@ const COOKIE_NAMES = { ACCESS_TOKEN: "nexent_access_token", REFRESH_TOKEN: "nexent_refresh_token", EXPIRES_AT: "nexent_token_expires_at", + OAUTH_PENDING: "nexent_oauth_pending", }; const isProduction = process.env.NODE_ENV === "production"; @@ -53,6 +54,12 @@ function buildCookieOptions(httpOnly) { }; } +function appendSetCookies(res, cookies) { + const existing = res.getHeader("Set-Cookie") || []; + const existingCookies = Array.isArray(existing) ? existing : [existing]; + res.setHeader("Set-Cookie", [...existingCookies, ...cookies].filter(Boolean)); +} + function setAuthCookies(res, session) { const cookies = []; @@ -92,7 +99,7 @@ function setAuthCookies(res, session) { } if (cookies.length > 0) { - res.setHeader("Set-Cookie", cookies); + appendSetCookies(res, cookies); } } @@ -102,9 +109,34 @@ function clearAuthCookies(res) { cookie.serialize(COOKIE_NAMES.ACCESS_TOKEN, "", { ...expired, httpOnly: true }), cookie.serialize(COOKIE_NAMES.REFRESH_TOKEN, "", { ...expired, httpOnly: true }), cookie.serialize(COOKIE_NAMES.EXPIRES_AT, "", expired), + cookie.serialize(COOKIE_NAMES.OAUTH_PENDING, "", { ...expired, httpOnly: true }), + ]); +} + +function setPendingOAuthCookie(res, pendingToken) { + appendSetCookies(res, [ + cookie.serialize(COOKIE_NAMES.OAUTH_PENDING, pendingToken, { + ...buildCookieOptions(true), + maxAge: 10 * 60, + }), + ]); +} + +function clearPendingOAuthCookie(res) { + appendSetCookies(res, [ + cookie.serialize(COOKIE_NAMES.OAUTH_PENDING, "", { + maxAge: 0, + path: "/", + httpOnly: true, + }), ]); } +function getPreferredLocale(cookies) { + const locale = cookies.NEXT_LOCALE; + return locale === "en" || locale === "zh" ? locale : "zh"; +} + function parseCookies(req) { return cookie.parse(req.headers.cookie || ""); } @@ -120,6 +152,8 @@ const AUTH_INTERCEPT_ENDPOINTS = new Set([ "/api/user/revoke", "/api/user/oauth/callback", "/api/user/oauth/link", + "/api/user/oauth/pending", + "/api/user/oauth/complete", ]); function collectRequestBody(req) { @@ -163,6 +197,14 @@ function forwardAuthRequest(req, res, targetUrl) { forwardHeaders["authorization"] = `Bearer ${cookies[COOKIE_NAMES.ACCESS_TOKEN]}`; } + if ( + cookies[COOKIE_NAMES.OAUTH_PENDING] && + (req.parsedPathname === "/api/user/oauth/pending" || + req.parsedPathname === "/api/user/oauth/complete") + ) { + forwardHeaders["x-oauth-pending-token"] = cookies[COOKIE_NAMES.OAUTH_PENDING]; + } + // Update content-length if body was modified if (body.length !== rawBody.length) { forwardHeaders["content-length"] = String(body.length); @@ -193,6 +235,17 @@ function forwardAuthRequest(req, res, targetUrl) { if (isLogout || isRevoke) { clearAuthCookies(res); + } else if ( + req.parsedPathname === "/api/user/oauth/callback" && + data.data && + data.data.requires_account_completion && + data.data.pending_token + ) { + setPendingOAuthCookie(res, data.data.pending_token); + const locale = getPreferredLocale(cookies); + res.writeHead(302, { Location: `/${locale}/oauth/complete` }); + res.end(); + return; } else if (data.data && data.data.session) { const session = data.data.session; setAuthCookies(res, session); @@ -204,6 +257,10 @@ function forwardAuthRequest(req, res, targetUrl) { return; } + if (req.parsedPathname === "/api/user/oauth/complete") { + clearPendingOAuthCookie(res); + } + const sanitized = { ...data }; sanitized.data = { ...data.data }; sanitized.data.session = { diff --git a/frontend/services/api.ts b/frontend/services/api.ts index 34d359d0c..656ec0217 100644 --- a/frontend/services/api.ts +++ b/frontend/services/api.ts @@ -24,6 +24,8 @@ export const API_ENDPOINTS = { providers: `${API_BASE_URL}/user/oauth/providers`, authorize: `${API_BASE_URL}/user/oauth/authorize`, link: `${API_BASE_URL}/user/oauth/link`, + pending: `${API_BASE_URL}/user/oauth/pending`, + complete: `${API_BASE_URL}/user/oauth/complete`, accounts: `${API_BASE_URL}/user/oauth/accounts`, unlink: (provider: string) => `${API_BASE_URL}/user/oauth/accounts/${provider}`, }, diff --git a/frontend/services/oauthService.ts b/frontend/services/oauthService.ts index ba9b05bed..6bfed1f6b 100644 --- a/frontend/services/oauthService.ts +++ b/frontend/services/oauthService.ts @@ -16,6 +16,68 @@ export interface OAuthAccount { linked_at: string | null; } +export interface PendingOAuthInfo { + provider: string; + provider_username: string; + provider_email: string; + email_required: boolean; +} + +export interface CompleteOAuthRequest { + email?: string; + password: string; + invite_code: string; +} + +export interface CompleteOAuthResponse { + session: { + expires_at: number; + expires_in_seconds?: number; + }; +} + +export type OAuthErrorKey = + | "auth.oauthPendingExpired" + | "auth.oauthEmailAlreadyExists" + | "auth.oauthAccountAlreadyBound" + | "auth.invalidEmailFormat" + | "auth.emailRequired" + | "auth.passwordMinLength" + | "auth.inviteCodeInvalid" + | "auth.oauthCompleteFailed"; + +function getOAuthErrorKey(errorMessage: string, status?: number): OAuthErrorKey { + const normalized = errorMessage.toLowerCase(); + + if ( + status === 401 || + normalized.includes("completion session") || + normalized.includes("pending") + ) { + return "auth.oauthPendingExpired"; + } + if (normalized.includes("email already exists")) { + return "auth.oauthEmailAlreadyExists"; + } + if (normalized.includes("already bound")) { + return "auth.oauthAccountAlreadyBound"; + } + if (normalized.includes("invalid email")) { + return "auth.invalidEmailFormat"; + } + if (normalized.includes("email is required")) { + return "auth.emailRequired"; + } + if (normalized.includes("password")) { + return "auth.passwordMinLength"; + } + if (normalized.includes("invitation") || normalized.includes("invite")) { + return "auth.inviteCodeInvalid"; + } + + return "auth.oauthCompleteFailed"; +} + export const oauthService = { getEnabledProviders: async (): Promise => { try { @@ -40,6 +102,62 @@ export const oauthService = { window.location.href = `${API_ENDPOINTS.oauth.link}?provider=${provider}`; }, + getPendingOAuth: async (): Promise => { + try { + const response = await fetch(API_ENDPOINTS.oauth.pending); + if (!response.ok) { + log.warn("Failed to fetch pending OAuth info"); + return null; + } + const data = await response.json(); + return data.data || null; + } catch (error) { + log.error("Failed to fetch pending OAuth info:", error); + return null; + } + }, + + completeOAuth: async ( + payload: CompleteOAuthRequest + ): Promise<{ + data?: CompleteOAuthResponse; + error?: string; + errorKey?: OAuthErrorKey; + }> => { + try { + const response = await fetch(API_ENDPOINTS.oauth.complete, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(payload), + }); + const data = await response.json(); + if (!response.ok) { + const error = + data.detail || data.message || "Failed to complete OAuth account"; + return { + error, + errorKey: getOAuthErrorKey(error, response.status), + }; + } + return { + data: { + session: data.data.session, + }, + }; + } catch (error) { + log.error("Failed to complete OAuth account:", error); + return { + error: + error instanceof Error + ? error.message + : "Failed to complete OAuth account", + errorKey: "auth.oauthCompleteFailed", + }; + } + }, + getLinkedAccounts: async (): Promise => { try { const response = await fetchWithAuth(API_ENDPOINTS.oauth.accounts); diff --git a/frontend/types/auth.ts b/frontend/types/auth.ts index ed07a751a..09b9aedd2 100644 --- a/frontend/types/auth.ts +++ b/frontend/types/auth.ts @@ -37,6 +37,12 @@ export interface AuthFormValues { inviteCode?: string; } +export interface RegisterModalOptions { + mode?: "register" | "oauth_complete"; + email?: string; + emailReadOnly?: boolean; +} + // Authorization context type export interface AuthContextType { user: User | null; @@ -45,11 +51,12 @@ export interface AuthContextType { isLoading: boolean; isLoginModalOpen: boolean; isRegisterModalOpen: boolean; + registerModalOptions?: RegisterModalOptions | null; authServiceUnavailable: boolean; isAuthReady: boolean; openLoginModal: () => void; closeLoginModal: () => void; - openRegisterModal: () => void; + openRegisterModal: (options?: RegisterModalOptions) => void; closeRegisterModal: () => void; login: (email: string, password: string) => Promise; register: ( @@ -118,6 +125,7 @@ export interface AuthenticationContextType { // UI state isLoginModalOpen: boolean; isRegisterModalOpen: boolean; + registerModalOptions: RegisterModalOptions | null; authServiceUnavailable: boolean; // Methods @@ -138,7 +146,7 @@ export interface AuthenticationContextType { // UI methods openLoginModal: () => void; closeLoginModal: () => void; - openRegisterModal: () => void; + openRegisterModal: (options?: RegisterModalOptions) => void; closeRegisterModal: () => void; // Auth prompt modal (for side navigation pre-check) @@ -184,7 +192,8 @@ export interface AuthenticationUIReturn { openLoginModal: () => void; closeLoginModal: () => void; isRegisterModalOpen: boolean; - openRegisterModal: () => void; + registerModalOptions: RegisterModalOptions | null; + openRegisterModal: (options?: RegisterModalOptions) => void; closeRegisterModal: () => void; // Auth prompt modal (for side navigation pre-check) diff --git a/test/backend/app/test_oauth_app.py b/test/backend/app/test_oauth_app.py index 758ab75d2..c3920e407 100644 --- a/test/backend/app/test_oauth_app.py +++ b/test/backend/app/test_oauth_app.py @@ -1,7 +1,7 @@ import sys import os import unittest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, AsyncMock test_dir = os.path.dirname(__file__) backend_dir = os.path.abspath(os.path.join(test_dir, "../../../backend")) @@ -20,7 +20,18 @@ sys.modules["consts"] = consts_mock sys.modules["consts.const"] = consts_mock.const -sys.modules["consts.model"] = MagicMock() +consts_model_mock = MagicMock() + + +class _OAuthCompleteRequest: + def __init__(self, **data): + self.email = data.get("email") + self.password = data.get("password") + self.invite_code = data.get("invite_code") + + +consts_model_mock.OAuthCompleteRequest = _OAuthCompleteRequest +sys.modules["consts.model"] = consts_model_mock oauth_providers_mock = MagicMock() oauth_providers_mock.get_all_provider_definitions.return_value = { @@ -75,6 +86,9 @@ class _UnauthorizedError(Exception): oauth_service_mock.parse_state = MagicMock( return_value={"provider": "github", "token": "tok", "link_user_id": ""} ) +oauth_service_mock.generate_pending_oauth_token = MagicMock(return_value="pending.jwt") +oauth_service_mock.find_supabase_user_id_by_email = MagicMock(return_value=None) +oauth_service_mock.complete_pending_oauth_account = AsyncMock() sys.modules["services"] = MagicMock() sys.modules["services.oauth_service"] = oauth_service_mock @@ -232,6 +246,9 @@ def test_returns_500_on_unexpected_error(self): class TestCallback(unittest.TestCase): + def setUp(self): + oauth_service_mock.find_supabase_user_id_by_email.return_value = None + def test_returns_error_when_provider_error(self): response = client.get( "/user/oauth/callback?provider=github&error=access_denied&error_description=User+cancelled" @@ -258,7 +275,11 @@ def test_returns_error_for_unsupported_provider(self): def test_success_returns_session_data(self): oauth_service_mock.reset_mock() oauth_service_mock.parse_state.return_value = {"provider": "github", "token": "tok", "link_user_id": ""} - database_oauth_mock.get_oauth_account_by_provider.return_value = None + database_oauth_mock.get_oauth_account_by_provider.return_value = { + "provider": "github", + "provider_user_id": "12345", + "user_id": "user-uuid-123", + } database_oauth_mock.get_soft_deleted_oauth_account.return_value = None oauth_service_mock.exchange_code_for_provider_token.return_value = { "access_token": "ghu_provider_token_123", @@ -269,17 +290,6 @@ def test_success_returns_session_data(self): "username": "octocat", } - mock_existing_user = MagicMock() - mock_existing_user.id = "user-uuid-123" - mock_existing_user.email = "octocat@github.com" - - mock_users_resp = MagicMock() - mock_users_resp.users = [mock_existing_user] - - mock_admin_client = MagicMock() - mock_admin_client.auth.admin.list_users.return_value = mock_users_resp - - auth_utils_mock.get_supabase_admin_client.return_value = mock_admin_client auth_utils_mock.generate_session_jwt.return_value = "eyJ.mock.jwt.token" response = client.get("/user/oauth/callback?provider=github&code=valid_code") @@ -298,7 +308,7 @@ def test_success_returns_session_data(self): auth_utils_mock.get_supabase_admin_client.return_value = MagicMock() - def test_success_creates_new_user_when_not_found(self): + def test_new_unbound_oauth_requires_account_completion(self): oauth_service_mock.reset_mock() oauth_service_mock.parse_state.return_value = {"provider": "github", "token": "tok", "link_user_id": ""} database_oauth_mock.get_oauth_account_by_provider.return_value = None @@ -333,8 +343,67 @@ def test_success_creates_new_user_when_not_found(self): print("Response:", response.json()) self.assertEqual(response.status_code, HTTPStatus.OK) data = response.json() - self.assertEqual(data["data"]["user"]["email"], "newuser@github.com") - mock_admin_client.auth.admin.create_user.assert_called_once() + self.assertTrue(data["data"]["requires_account_completion"]) + self.assertEqual(data["data"]["pending_token"], "pending.jwt") + self.assertEqual(data["data"]["provider_email"], "newuser@github.com") + oauth_service_mock.find_supabase_user_id_by_email.assert_called_once_with( + mock_admin_client, + "newuser@github.com", + ) + mock_admin_client.auth.admin.create_user.assert_not_called() + + auth_utils_mock.get_supabase_admin_client.return_value = MagicMock() + + def test_unbound_oauth_with_existing_email_links_existing_account(self): + oauth_service_mock.reset_mock() + oauth_service_mock.parse_state.return_value = {"provider": "github", "token": "tok", "link_user_id": ""} + database_oauth_mock.get_oauth_account_by_provider.return_value = None + database_oauth_mock.get_soft_deleted_oauth_account.return_value = None + oauth_service_mock.exchange_code_for_provider_token.return_value = { + "access_token": "ghu_provider_token_existing", + } + oauth_service_mock.get_provider_user_info.return_value = { + "id": "67891", + "email": "existing@example.com", + "username": "existing-user", + } + oauth_service_mock.find_supabase_user_id_by_email.return_value = "existing-user-id" + oauth_service_mock.ensure_user_tenant_exists.return_value = { + "user_id": "existing-user-id", + "tenant_id": "t-1", + } + oauth_service_mock.create_or_update_oauth_account.return_value = { + "provider": "github", + "provider_user_id": "67891", + "user_id": "existing-user-id", + } + mock_admin_client = MagicMock() + auth_utils_mock.get_supabase_admin_client.return_value = mock_admin_client + auth_utils_mock.generate_session_jwt.return_value = "eyJ.existing.jwt" + + response = client.get("/user/oauth/callback?provider=github&code=existing_code") + + if response.status_code != HTTPStatus.OK: + print("Response:", response.json()) + self.assertEqual(response.status_code, HTTPStatus.OK) + data = response.json() + self.assertNotIn("requires_account_completion", data["data"]) + self.assertEqual(data["data"]["user"]["id"], "existing-user-id") + self.assertEqual(data["data"]["user"]["email"], "existing@example.com") + self.assertEqual(data["data"]["session"]["access_token"], "eyJ.existing.jwt") + + oauth_service_mock.generate_pending_oauth_token.assert_not_called() + oauth_service_mock.find_supabase_user_id_by_email.assert_called_once_with( + mock_admin_client, + "existing@example.com", + ) + oauth_service_mock.create_or_update_oauth_account.assert_called_once_with( + user_id="existing-user-id", + provider="github", + provider_user_id="67891", + email="existing@example.com", + username="existing-user", + ) auth_utils_mock.get_supabase_admin_client.return_value = MagicMock() @@ -415,6 +484,40 @@ def test_success_with_link_user_id_binding(self): username="octocat", ) + def test_link_user_id_binding_returns_specific_error_when_already_bound(self): + oauth_service_mock.reset_mock() + database_oauth_mock.reset_mock() + oauth_service_mock.parse_state.return_value = { + "provider": "github", + "token": "tok", + "link_user_id": "existing-user-uuid", + } + oauth_service_mock.exchange_code_for_provider_token.return_value = { + "access_token": "ghu_provider_token", + } + oauth_service_mock.get_provider_user_info.return_value = { + "id": "12345", + "email": "octocat@github.com", + "username": "octocat", + } + oauth_service_mock.create_or_update_oauth_account.side_effect = _OAuthLinkError( + "This github account is already bound to another user" + ) + + response = client.get( + "/user/oauth/callback?provider=github&code=bind_code&state=github:tok:existing-user-uuid" + ) + + self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST) + data = response.json() + self.assertEqual(data["data"]["oauth_error"], "oauth_account_already_bound") + self.assertEqual( + data["data"]["oauth_error_description"], + "OAuth account is already bound to another user", + ) + + oauth_service_mock.create_or_update_oauth_account.side_effect = None + def test_success_with_already_bound_oauth_account(self): """Callback with existing binding should use that user_id without Supabase lookup.""" oauth_service_mock.reset_mock() @@ -506,21 +609,7 @@ def test_returns_401_for_invalid_token(self, mock_get_user): class TestDeleteAccount(unittest.TestCase): def setUp(self): - mock_identity = MagicMock() - mock_identity.provider = "email" - - mock_user = MagicMock() - mock_user.identities = [mock_identity] - mock_user.app_metadata = MagicMock() - mock_user.app_metadata.get = MagicMock(return_value="email") - - mock_user_resp = MagicMock() - mock_user_resp.user = mock_user - - mock_admin = MagicMock() - mock_admin.auth.admin.get_user_by_id.return_value = mock_user_resp - auth_utils_mock.get_supabase_admin_client.return_value = mock_admin - oauth_service_mock.count_oauth_accounts_by_user_id.return_value = 2 + oauth_service_mock.unlink_account.side_effect = None def test_unlinks_successfully(self): oauth_service_mock.unlink_account.reset_mock() @@ -534,7 +623,7 @@ def test_unlinks_successfully(self): self.assertEqual(response.status_code, HTTPStatus.OK) data = response.json() self.assertTrue(data["data"]["unlinked"]) - oauth_service_mock.unlink_account.assert_called_once() + oauth_service_mock.unlink_account.assert_called_once_with("user-1", "github") def test_returns_401_without_auth(self): response = client.delete("/user/oauth/accounts/github") @@ -542,10 +631,10 @@ def test_returns_401_without_auth(self): self.assertEqual(response.status_code, HTTPStatus.UNAUTHORIZED) @patch("apps.oauth_app.get_current_user_id") - def test_returns_400_when_last_account(self, mock_get_user): + def test_returns_400_when_account_not_found(self, mock_get_user): mock_get_user.return_value = ("user-1", "t-1") oauth_service_mock.unlink_account.side_effect = _OAuthLinkError( - "Cannot unlink last" + "No linked github account found" ) response = client.delete( @@ -559,6 +648,9 @@ def test_returns_400_when_last_account(self, mock_get_user): class TestCallbackPagination(unittest.TestCase): + def setUp(self): + oauth_service_mock.find_supabase_user_id_by_email.return_value = None + def test_finds_user_on_second_page(self): oauth_service_mock.reset_mock() database_oauth_mock.reset_mock() @@ -582,31 +674,21 @@ def test_finds_user_on_second_page(self): "provider_user_id": "12345", "user_id": "page2-uuid", } - - mock_page1_user = MagicMock() - mock_page1_user.id = "user-page1" - mock_page1_user.email = "other@github.com" - mock_page2_user = MagicMock() - mock_page2_user.id = "page2-uuid" - mock_page2_user.email = "page2user@github.com" - - mock_page1_resp = MagicMock() - mock_page1_resp.users = [mock_page1_user] - mock_page1_resp.__len__ = lambda self: 1 - - mock_page2_resp = MagicMock() - mock_page2_resp.users = [mock_page2_user] - mock_page2_resp.__len__ = lambda self: 1 + oauth_service_mock.find_supabase_user_id_by_email.return_value = "page2-uuid" mock_admin_client = MagicMock() - mock_admin_client.auth.admin.list_users.side_effect = [mock_page1_resp, mock_page2_resp] auth_utils_mock.get_supabase_admin_client.return_value = mock_admin_client response = client.get("/user/oauth/callback?provider=github&code=page2_code&state=github:tok") self.assertEqual(response.status_code, HTTPStatus.OK) data = response.json() + self.assertEqual(data["data"]["user"]["id"], "page2-uuid") self.assertEqual(data["data"]["user"]["email"], "page2user@github.com") + oauth_service_mock.find_supabase_user_id_by_email.assert_called_once_with( + mock_admin_client, + "page2user@github.com", + ) auth_utils_mock.get_supabase_admin_client.return_value = MagicMock() @@ -649,7 +731,13 @@ def test_stops_pagination_when_less_than_100_users(self): response = client.get("/user/oauth/callback?provider=github&code=short_page_code&state=github:tok") self.assertEqual(response.status_code, HTTPStatus.OK) - mock_admin_client.auth.admin.list_users.assert_called_once() + data = response.json() + self.assertTrue(data["data"]["requires_account_completion"]) + oauth_service_mock.find_supabase_user_id_by_email.assert_called_once_with( + mock_admin_client, + "newuser@github.com", + ) + mock_admin_client.auth.admin.create_user.assert_not_called() auth_utils_mock.get_supabase_admin_client.return_value = MagicMock() @@ -695,58 +783,95 @@ def test_creates_user_with_oauth_fallback_email(self): self.assertEqual(response.status_code, HTTPStatus.OK) data = response.json() - self.assertIn("@oauth.nexent", data["data"]["user"]["email"]) + self.assertTrue(data["data"]["requires_account_completion"]) + self.assertTrue(data["data"]["email_required"]) + self.assertEqual(data["data"]["provider_email"], "") + oauth_service_mock.find_supabase_user_id_by_email.assert_not_called() auth_utils_mock.get_supabase_admin_client.return_value = MagicMock() -class TestDeleteAccountMetadata(unittest.TestCase): - def test_handles_get_user_exception_gracefully(self): - oauth_service_mock.reset_mock() - oauth_service_mock.count_oauth_accounts_by_user_id.return_value = 2 - oauth_service_mock.unlink_account.return_value = True - - mock_admin = MagicMock() - mock_admin.auth.admin.get_user_by_id.side_effect = Exception("User lookup failed") - auth_utils_mock.get_supabase_admin_client.return_value = mock_admin +class TestCompleteOAuth(unittest.TestCase): + def test_pending_returns_provider_info(self): + pending_info = { + "provider": "github", + "provider_username": "octocat", + "provider_email": "", + "email_required": True, + } - response = client.delete( - "/user/oauth/accounts/github", - headers={"Authorization": "Bearer valid_token"}, - ) + with patch("apps.oauth_app.get_pending_oauth_info", return_value=pending_info): + response = client.get( + "/user/oauth/pending", + headers={"X-OAuth-Pending-Token": "pending.jwt"}, + ) self.assertEqual(response.status_code, HTTPStatus.OK) + self.assertTrue(response.json()["data"]["email_required"]) - auth_utils_mock.get_supabase_admin_client.return_value = MagicMock() - - def test_unlinks_with_password_auth_detected(self): - oauth_service_mock.reset_mock() - oauth_service_mock.count_oauth_accounts_by_user_id.return_value = 1 - oauth_service_mock.unlink_account.return_value = True + def test_pending_returns_401_when_missing_or_invalid(self): + with patch( + "apps.oauth_app.get_pending_oauth_info", + side_effect=_OAuthLinkError("expired"), + ): + response = client.get("/user/oauth/pending") - mock_identity = MagicMock() - mock_identity.provider = "email" - - mock_user = MagicMock() - mock_user.identities = [mock_identity] - mock_user.app_metadata = MagicMock() - mock_user.app_metadata.get = MagicMock(return_value="email") + self.assertEqual(response.status_code, HTTPStatus.UNAUTHORIZED) - mock_user_resp = MagicMock() - mock_user_resp.user = mock_user + def test_complete_returns_session_data(self): + complete_mock = AsyncMock( + return_value={ + "user": {"id": "new-user", "email": "new@example.com", "role": "USER"}, + "session": { + "access_token": "jwt", + "refresh_token": "", + "expires_at": 1735689600, + "expires_in_seconds": 3600, + }, + } + ) - mock_admin = MagicMock() - mock_admin.auth.admin.get_user_by_id.return_value = mock_user_resp - auth_utils_mock.get_supabase_admin_client.return_value = mock_admin + with patch("apps.oauth_app.complete_pending_oauth_account", new=complete_mock): + response = client.post( + "/user/oauth/complete", + headers={"X-OAuth-Pending-Token": "pending.jwt"}, + json={ + "email": "new@example.com", + "password": "secret1", + "invite_code": "ABC123", + }, + ) - response = client.delete( - "/user/oauth/accounts/github", - headers={"Authorization": "Bearer valid_token"}, + self.assertEqual(response.status_code, HTTPStatus.OK) + data = response.json() + self.assertEqual(data["data"]["user"]["id"], "new-user") + self.assertEqual(data["data"]["session"]["expires_in_seconds"], 3600) + complete_mock.assert_awaited_once_with( + pending_token="pending.jwt", + email="new@example.com", + password="secret1", + invite_code="ABC123", ) - self.assertEqual(response.status_code, HTTPStatus.OK) + def test_complete_returns_conflict_for_existing_email(self): + complete_mock = AsyncMock( + side_effect=_OAuthLinkError( + "Email already exists. Please log in with email and password." + ) + ) - auth_utils_mock.get_supabase_admin_client.return_value = MagicMock() + with patch("apps.oauth_app.complete_pending_oauth_account", new=complete_mock): + response = client.post( + "/user/oauth/complete", + headers={"X-OAuth-Pending-Token": "pending.jwt"}, + json={ + "email": "taken@example.com", + "password": "secret1", + "invite_code": "ABC123", + }, + ) + + self.assertEqual(response.status_code, HTTPStatus.CONFLICT) class TestGetAccounts(unittest.TestCase): diff --git a/test/backend/services/test_oauth_service.py b/test/backend/services/test_oauth_service.py index c974c1b5b..55595471d 100644 --- a/test/backend/services/test_oauth_service.py +++ b/test/backend/services/test_oauth_service.py @@ -217,6 +217,7 @@ def _get_all_provider_definitions(): create_or_update_oauth_account, ensure_user_tenant_exists, exchange_code_for_provider_token, + find_supabase_user_id_by_email, get_authorize_url, get_enabled_providers, get_provider_user_info, @@ -536,6 +537,75 @@ def test_creates_tenant_when_missing(self): } +class TestFindSupabaseUserIdByEmail(unittest.TestCase): + def test_returns_none_without_email(self): + admin_client = MagicMock() + + result = find_supabase_user_id_by_email(admin_client, "") + + self.assertIsNone(result) + admin_client.auth.admin.list_users.assert_not_called() + + def test_finds_user_from_supabase_users_response(self): + existing_user = MagicMock() + existing_user.id = "existing-user-id" + existing_user.email = "Existing@Example.com" + + response = MagicMock() + response.users = [existing_user] + + admin_client = MagicMock() + admin_client.auth.admin.list_users.return_value = response + + result = find_supabase_user_id_by_email(admin_client, "existing@example.com") + + self.assertEqual(result, "existing-user-id") + admin_client.auth.admin.list_users.assert_called_once_with(page=1, per_page=100) + + def test_finds_user_on_second_page(self): + page1_users = [] + for index in range(100): + user = MagicMock() + user.id = f"user-{index}" + user.email = f"user-{index}@example.com" + page1_users.append(user) + + target_user = MagicMock() + target_user.id = "target-user-id" + target_user.email = "target@example.com" + + page1 = MagicMock() + page1.users = page1_users + page2 = MagicMock() + page2.users = [target_user] + + admin_client = MagicMock() + admin_client.auth.admin.list_users.side_effect = [page1, page2] + + result = find_supabase_user_id_by_email(admin_client, "target@example.com") + + self.assertEqual(result, "target-user-id") + self.assertEqual(admin_client.auth.admin.list_users.call_count, 2) + admin_client.auth.admin.list_users.assert_any_call(page=1, per_page=100) + admin_client.auth.admin.list_users.assert_any_call(page=2, per_page=100) + + def test_stops_when_page_has_less_than_page_size(self): + other_user = MagicMock() + other_user.id = "other-user-id" + other_user.email = "other@example.com" + + response = MagicMock() + response.users = [other_user] + + admin_client = MagicMock() + admin_client.auth.admin.list_users.return_value = response + + result = find_supabase_user_id_by_email(admin_client, "missing@example.com") + + self.assertIsNone(result) + admin_client.auth.admin.list_users.assert_called_once_with(page=1, per_page=100) + + class TestListLinkedAccounts(unittest.TestCase): def test_transforms_db_results(self): oauth_account_db_mock.list_oauth_accounts_by_user_id.return_value = [ @@ -563,30 +633,14 @@ def test_returns_empty_list(self): class TestUnlinkAccount(unittest.TestCase): - def test_success_with_multiple_accounts(self): - oauth_account_db_mock.count_oauth_accounts_by_user_id.return_value = 2 + def test_success(self): oauth_account_db_mock.delete_oauth_account.return_value = True result = unlink_account("user-1", "github") self.assertTrue(result) - def test_raises_when_last_account_no_password(self): - oauth_account_db_mock.count_oauth_accounts_by_user_id.return_value = 1 - - with self.assertRaises(_OAuthLinkError): - unlink_account("user-1", "github") - - def test_allows_last_unlink_when_has_password(self): - oauth_account_db_mock.count_oauth_accounts_by_user_id.return_value = 1 - oauth_account_db_mock.delete_oauth_account.return_value = True - - result = unlink_account("user-1", "github", has_password_auth=True) - - self.assertTrue(result) - def test_raises_when_account_not_found(self): - oauth_account_db_mock.count_oauth_accounts_by_user_id.return_value = 2 oauth_account_db_mock.delete_oauth_account.return_value = False with self.assertRaises(_OAuthLinkError): @@ -707,7 +761,7 @@ def test_fallback_email_when_no_email_found(self): with patch.dict(os.environ, env, clear=False): result = get_provider_user_info("github", "test_token") - self.assertEqual(result["email"], "testuser@nexent.com") + self.assertEqual(result["email"], "") def test_wechat_does_not_fetch_emails(self): mock_user_resp = MagicMock() @@ -841,4 +895,4 @@ def test_includes_state_token(self): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main()