diff --git a/Makefile b/Makefile index c36f5fd..74a7879 100644 --- a/Makefile +++ b/Makefile @@ -40,6 +40,7 @@ lint-api: lint-ui: @echo "Linting UI..." + npm install --prefix $(UI_DIR) npm run --prefix $(UI_DIR) lint dev-api: diff --git a/pkg/api/api.go b/pkg/api/api.go index f7d358a..039119e 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -159,6 +159,7 @@ func (s *server) setupRouter() { r.Post("/auth/login", s.handleLogin) r.Get("/auth/github", s.handleGitHubAuth) r.Get("/auth/github/callback", s.handleGitHubCallback) + r.Post("/auth/exchange", s.handleExchangeCode) // WebSocket (authentication handled in handler). r.Get("/ws", s.handleWebSocket) @@ -1248,7 +1249,7 @@ func (s *server) handleRefreshRunners(w http.ResponseWriter, _ *http.Request) { } func (s *server) handleWebSocket(w http.ResponseWriter, r *http.Request) { - ServeWs(s.hub, s.auth, w, r) + ServeWs(s.hub, s.auth, s.cfg.Server.CORSOrigins, w, r) } // ============================================================================ @@ -1294,6 +1295,7 @@ func (s *server) handleLogin(w http.ResponseWriter, r *http.Request) { Path: "/", HttpOnly: true, SameSite: http.SameSiteLaxMode, + Secure: s.isSecureRequest(r), MaxAge: int(s.cfg.Auth.SessionTTL.Seconds()), }) @@ -1330,6 +1332,8 @@ func (s *server) handleLogout(w http.ResponseWriter, r *http.Request) { Value: "", Path: "/", HttpOnly: true, + SameSite: http.SameSiteLaxMode, + Secure: s.isSecureRequest(r), MaxAge: -1, }) @@ -1354,11 +1358,13 @@ func (s *server) handleGitHubAuth(w http.ResponseWriter, r *http.Request) { return } - // Generate state for CSRF protection. - // In production, store this in a session/cookie and validate on callback. - state := r.URL.Query().Get("state") - if state == "" { - state = "dispatchoor" + // Generate cryptographically secure state for CSRF protection. + state, err := s.auth.CreateOAuthState(r.Context()) + if err != nil { + s.log.WithError(err).Error("Failed to create OAuth state") + s.writeError(w, http.StatusInternalServerError, "Failed to initiate OAuth flow") + + return } authURL := s.auth.GetGitHubAuthURL(state) @@ -1379,7 +1385,20 @@ func (s *server) handleGitHubCallback(w http.ResponseWriter, r *http.Request) { return } - // TODO: Validate state parameter for CSRF protection. + // Validate state parameter for CSRF protection. + state := r.URL.Query().Get("state") + if state == "" { + s.writeError(w, http.StatusBadRequest, "Missing state parameter") + + return + } + + if err := s.auth.ValidateOAuthState(r.Context(), state); err != nil { + s.log.WithError(err).Warn("Invalid OAuth state") + s.writeError(w, http.StatusBadRequest, "Invalid or expired state parameter") + + return + } user, token, err := s.auth.AuthenticateGitHub(r.Context(), code) if err != nil { @@ -1389,13 +1408,14 @@ func (s *server) handleGitHubCallback(w http.ResponseWriter, r *http.Request) { return } - // Set session cookie. + // Set session cookie (works for same-origin requests). http.SetCookie(w, &http.Cookie{ Name: "session", Value: token, Path: "/", HttpOnly: true, SameSite: http.SameSiteLaxMode, + Secure: s.isSecureRequest(r), MaxAge: int(s.cfg.Auth.SessionTTL.Seconds()), }) @@ -1419,16 +1439,84 @@ func (s *server) handleGitHubCallback(w http.ResponseWriter, r *http.Request) { redirectURL = "/" } - // Append token to redirect URL for the UI to capture. + // Generate one-time authorization code for cross-origin token exchange. + // This is more secure than putting the session token in the URL. + authCode, err := s.auth.CreateAuthCode(r.Context(), user.ID) + if err != nil { + s.log.WithError(err).Error("Failed to create auth code") + s.writeError(w, http.StatusInternalServerError, "Failed to complete authentication") + + return + } + + // Append auth code to redirect URL for the UI to exchange for a token. if strings.Contains(redirectURL, "?") { - redirectURL += "&token=" + token + redirectURL += "&code=" + authCode } else { - redirectURL += "?token=" + token + redirectURL += "?code=" + authCode } http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect) } +type exchangeCodeRequest struct { + Code string `json:"code"` +} + +func (s *server) handleExchangeCode(w http.ResponseWriter, r *http.Request) { + var req exchangeCodeRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + s.writeError(w, http.StatusBadRequest, "Invalid request body") + + return + } + + if req.Code == "" { + s.writeError(w, http.StatusBadRequest, "Code is required") + + return + } + + user, token, err := s.auth.ExchangeAuthCode(r.Context(), req.Code) + if err != nil { + s.log.WithError(err).Warn("Code exchange failed") + s.writeError(w, http.StatusUnauthorized, "Invalid or expired code") + + return + } + + // Set session cookie. + http.SetCookie(w, &http.Cookie{ + Name: "session", + Value: token, + Path: "/", + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + Secure: s.isSecureRequest(r), + MaxAge: int(s.cfg.Auth.SessionTTL.Seconds()), + }) + + s.writeJSON(w, http.StatusOK, loginResponse{ + Token: token, + User: user, + }) +} + +// isSecureRequest checks if the request was made over HTTPS. +func (s *server) isSecureRequest(r *http.Request) bool { + // Check TLS directly. + if r.TLS != nil { + return true + } + + // Check X-Forwarded-Proto header (common with reverse proxies). + if r.Header.Get("X-Forwarded-Proto") == "https" { + return true + } + + return false +} + // SyncGroupsFromConfig synchronizes groups and job templates from configuration. func SyncGroupsFromConfig(ctx context.Context, log logrus.FieldLogger, st store.Store, cfg *config.Config) error { log.Info("Syncing groups from configuration") diff --git a/pkg/api/websocket.go b/pkg/api/websocket.go index 14b747d..d69eae3 100644 --- a/pkg/api/websocket.go +++ b/pkg/api/websocket.go @@ -27,13 +27,34 @@ const ( maxMessageSize = 512 ) -var upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { - // In production, validate the origin properly. - return true - }, +// createUpgrader creates a WebSocket upgrader with origin validation. +func createUpgrader(allowedOrigins []string) websocket.Upgrader { + allowAll := len(allowedOrigins) == 1 && allowedOrigins[0] == "*" + + originSet := make(map[string]bool, len(allowedOrigins)) + for _, origin := range allowedOrigins { + originSet[origin] = true + } + + return websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + // If no origins configured, reject all cross-origin requests. + if len(allowedOrigins) == 0 { + return r.Header.Get("Origin") == "" + } + + // Allow all origins if configured with "*". + if allowAll { + return true + } + + // Check if origin is in allowed list. + origin := r.Header.Get("Origin") + return originSet[origin] + }, + } } // MessageType represents the type of WebSocket message. @@ -411,7 +432,7 @@ func (c *Client) handleMessage(msg *Message) { } // ServeWs handles WebSocket requests from the peer. -func ServeWs(hub *Hub, authSvc auth.Service, w http.ResponseWriter, r *http.Request) { +func ServeWs(hub *Hub, authSvc auth.Service, allowedOrigins []string, w http.ResponseWriter, r *http.Request) { // Authenticate the user. token := r.URL.Query().Get("token") if token == "" { @@ -438,6 +459,9 @@ func ServeWs(hub *Hub, authSvc auth.Service, w http.ResponseWriter, r *http.Requ return } + // Create upgrader with origin validation. + upgrader := createUpgrader(allowedOrigins) + // Upgrade to WebSocket. conn, err := upgrader.Upgrade(w, r, nil) if err != nil { diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index bd127a3..b25dc0a 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -34,6 +34,14 @@ type Service interface { // GitHub OAuth URL. GetGitHubAuthURL(state string) string + + // OAuth State (CSRF protection). + CreateOAuthState(ctx context.Context) (string, error) + ValidateOAuthState(ctx context.Context, state string) error + + // Auth Code (one-time exchange). + CreateAuthCode(ctx context.Context, userID string) (string, error) + ExchangeAuthCode(ctx context.Context, code string) (*store.User, string, error) } // service implements Service. @@ -380,7 +388,7 @@ func (s *service) createSession(ctx context.Context, user *store.User) (string, return token, nil } -// cleanupSessions periodically removes expired sessions. +// cleanupSessions periodically removes expired sessions, OAuth states, and auth codes. func (s *service) cleanupSessions(ctx context.Context) { ticker := time.NewTicker(time.Hour) defer ticker.Stop() @@ -393,6 +401,14 @@ func (s *service) cleanupSessions(ctx context.Context) { if err := s.store.DeleteExpiredSessions(ctx); err != nil { s.log.WithError(err).Error("Failed to cleanup expired sessions") } + + if err := s.store.DeleteExpiredOAuthStates(ctx); err != nil { + s.log.WithError(err).Error("Failed to cleanup expired oauth states") + } + + if err := s.store.DeleteExpiredAuthCodes(ctx); err != nil { + s.log.WithError(err).Error("Failed to cleanup expired auth codes") + } } } } @@ -420,3 +436,120 @@ type GitHubUser struct { ID string Login string } + +const ( + oauthStateTTL = 5 * time.Minute + authCodeTTL = 30 * time.Second +) + +// CreateOAuthState generates a random state token for CSRF protection. +func (s *service) CreateOAuthState(ctx context.Context) (string, error) { + stateBytes := make([]byte, 32) + + if _, err := rand.Read(stateBytes); err != nil { + return "", fmt.Errorf("generating state: %w", err) + } + + state := base64.URLEncoding.EncodeToString(stateBytes) + now := time.Now() + + oauthState := &store.OAuthState{ + State: state, + ExpiresAt: now.Add(oauthStateTTL), + CreatedAt: now, + } + + if err := s.store.CreateOAuthState(ctx, oauthState); err != nil { + return "", fmt.Errorf("storing oauth state: %w", err) + } + + return state, nil +} + +// ValidateOAuthState validates and consumes an OAuth state token. +func (s *service) ValidateOAuthState(ctx context.Context, state string) error { + oauthState, err := s.store.GetOAuthState(ctx, state) + if err != nil { + return fmt.Errorf("getting oauth state: %w", err) + } + + if oauthState == nil { + return fmt.Errorf("invalid oauth state") + } + + // Delete the state (single use). + if err := s.store.DeleteOAuthState(ctx, state); err != nil { + s.log.WithError(err).Error("Failed to delete oauth state") + } + + if time.Now().After(oauthState.ExpiresAt) { + return fmt.Errorf("oauth state expired") + } + + return nil +} + +// CreateAuthCode generates a one-time authorization code for token exchange. +func (s *service) CreateAuthCode(ctx context.Context, userID string) (string, error) { + codeBytes := make([]byte, 32) + + if _, err := rand.Read(codeBytes); err != nil { + return "", fmt.Errorf("generating code: %w", err) + } + + code := base64.URLEncoding.EncodeToString(codeBytes) + now := time.Now() + + authCode := &store.AuthCode{ + Code: code, + UserID: userID, + ExpiresAt: now.Add(authCodeTTL), + CreatedAt: now, + } + + if err := s.store.CreateAuthCode(ctx, authCode); err != nil { + return "", fmt.Errorf("storing auth code: %w", err) + } + + return code, nil +} + +// ExchangeAuthCode exchanges a one-time authorization code for a session token. +func (s *service) ExchangeAuthCode(ctx context.Context, code string) (*store.User, string, error) { + authCode, err := s.store.GetAuthCode(ctx, code) + if err != nil { + return nil, "", fmt.Errorf("getting auth code: %w", err) + } + + if authCode == nil { + return nil, "", fmt.Errorf("invalid authorization code") + } + + // Delete the code (single use). + if err := s.store.DeleteAuthCode(ctx, code); err != nil { + s.log.WithError(err).Error("Failed to delete auth code") + } + + if time.Now().After(authCode.ExpiresAt) { + return nil, "", fmt.Errorf("authorization code expired") + } + + user, err := s.store.GetUser(ctx, authCode.UserID) + if err != nil { + return nil, "", fmt.Errorf("getting user: %w", err) + } + + if user == nil { + return nil, "", fmt.Errorf("user not found") + } + + // Create session. + token, err := s.createSession(ctx, user) + if err != nil { + return nil, "", fmt.Errorf("creating session: %w", err) + } + + s.log.WithField("username", user.Username).Info("Auth code exchanged for session") + + return user, token, nil +} diff --git a/pkg/store/postgres.go b/pkg/store/postgres.go index d15c30e..bc3087b 100644 --- a/pkg/store/postgres.go +++ b/pkg/store/postgres.go @@ -249,6 +249,21 @@ func (s *PostgresStore) Migrate(ctx context.Context) error { END $$`, // Migration: Make template_id nullable for manual jobs. `ALTER TABLE jobs ALTER COLUMN template_id DROP NOT NULL`, + // OAuth states table (CSRF protection). + `CREATE TABLE IF NOT EXISTS oauth_states ( + state TEXT PRIMARY KEY, + expires_at TIMESTAMPTZ NOT NULL, + created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP + )`, + `CREATE INDEX IF NOT EXISTS idx_oauth_states_expires ON oauth_states(expires_at)`, + // Auth codes table (one-time exchange codes). + `CREATE TABLE IF NOT EXISTS auth_codes ( + code TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + expires_at TIMESTAMPTZ NOT NULL, + created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP + )`, + `CREATE INDEX IF NOT EXISTS idx_auth_codes_expires ON auth_codes(expires_at)`, } for _, migration := range migrations { @@ -1724,3 +1739,119 @@ func (s *PostgresStore) ListAuditEntries( return entries, total, rows.Err() } + +// ============================================================================ +// OAuth States +// ============================================================================ + +// CreateOAuthState creates a new OAuth state for CSRF protection. +func (s *PostgresStore) CreateOAuthState(ctx context.Context, state *OAuthState) error { + _, err := s.db.ExecContext(ctx, ` + INSERT INTO oauth_states (state, expires_at, created_at) + VALUES ($1, $2, $3) + `, state.State, state.ExpiresAt, state.CreatedAt) + + if err != nil { + return fmt.Errorf("inserting oauth_state: %w", err) + } + + return nil +} + +// GetOAuthState retrieves an OAuth state by its value. +func (s *PostgresStore) GetOAuthState(ctx context.Context, state string) (*OAuthState, error) { + var oauthState OAuthState + + err := s.db.QueryRowContext(ctx, ` + SELECT state, expires_at, created_at + FROM oauth_states WHERE state = $1 + `, state).Scan(&oauthState.State, &oauthState.ExpiresAt, &oauthState.CreatedAt) + + if err == sql.ErrNoRows { + return nil, nil + } + + if err != nil { + return nil, fmt.Errorf("querying oauth_state: %w", err) + } + + return &oauthState, nil +} + +// DeleteOAuthState deletes an OAuth state. +func (s *PostgresStore) DeleteOAuthState(ctx context.Context, state string) error { + _, err := s.db.ExecContext(ctx, `DELETE FROM oauth_states WHERE state = $1`, state) + if err != nil { + return fmt.Errorf("deleting oauth_state: %w", err) + } + + return nil +} + +// DeleteExpiredOAuthStates deletes all expired OAuth states. +func (s *PostgresStore) DeleteExpiredOAuthStates(ctx context.Context) error { + _, err := s.db.ExecContext(ctx, `DELETE FROM oauth_states WHERE expires_at < $1`, time.Now()) + if err != nil { + return fmt.Errorf("deleting expired oauth_states: %w", err) + } + + return nil +} + +// ============================================================================ +// Auth Codes +// ============================================================================ + +// CreateAuthCode creates a new one-time authorization code. +func (s *PostgresStore) CreateAuthCode(ctx context.Context, code *AuthCode) error { + _, err := s.db.ExecContext(ctx, ` + INSERT INTO auth_codes (code, user_id, expires_at, created_at) + VALUES ($1, $2, $3, $4) + `, code.Code, code.UserID, code.ExpiresAt, code.CreatedAt) + + if err != nil { + return fmt.Errorf("inserting auth_code: %w", err) + } + + return nil +} + +// GetAuthCode retrieves an authorization code by its value. +func (s *PostgresStore) GetAuthCode(ctx context.Context, code string) (*AuthCode, error) { + var authCode AuthCode + + err := s.db.QueryRowContext(ctx, ` + SELECT code, user_id, expires_at, created_at + FROM auth_codes WHERE code = $1 + `, code).Scan(&authCode.Code, &authCode.UserID, &authCode.ExpiresAt, &authCode.CreatedAt) + + if err == sql.ErrNoRows { + return nil, nil + } + + if err != nil { + return nil, fmt.Errorf("querying auth_code: %w", err) + } + + return &authCode, nil +} + +// DeleteAuthCode deletes an authorization code. +func (s *PostgresStore) DeleteAuthCode(ctx context.Context, code string) error { + _, err := s.db.ExecContext(ctx, `DELETE FROM auth_codes WHERE code = $1`, code) + if err != nil { + return fmt.Errorf("deleting auth_code: %w", err) + } + + return nil +} + +// DeleteExpiredAuthCodes deletes all expired authorization codes. +func (s *PostgresStore) DeleteExpiredAuthCodes(ctx context.Context) error { + _, err := s.db.ExecContext(ctx, `DELETE FROM auth_codes WHERE expires_at < $1`, time.Now()) + if err != nil { + return fmt.Errorf("deleting expired auth_codes: %w", err) + } + + return nil +} diff --git a/pkg/store/sqlite.go b/pkg/store/sqlite.go index c95f555..a7b16f5 100644 --- a/pkg/store/sqlite.go +++ b/pkg/store/sqlite.go @@ -186,6 +186,21 @@ func (s *SQLiteStore) Migrate(ctx context.Context) error { `ALTER TABLE jobs ADD COLUMN workflow_id TEXT`, `ALTER TABLE jobs ADD COLUMN ref TEXT`, `ALTER TABLE jobs ADD COLUMN labels TEXT`, + // OAuth states table (CSRF protection). + `CREATE TABLE IF NOT EXISTS oauth_states ( + state TEXT PRIMARY KEY, + expires_at TIMESTAMP NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + )`, + `CREATE INDEX IF NOT EXISTS idx_oauth_states_expires ON oauth_states(expires_at)`, + // Auth codes table (one-time exchange codes). + `CREATE TABLE IF NOT EXISTS auth_codes ( + code TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + expires_at TIMESTAMP NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + )`, + `CREATE INDEX IF NOT EXISTS idx_auth_codes_expires ON auth_codes(expires_at)`, } for _, migration := range migrations { @@ -1817,3 +1832,119 @@ func (s *SQLiteStore) ListAuditEntries( return entries, total, rows.Err() } + +// ============================================================================ +// OAuth States +// ============================================================================ + +// CreateOAuthState creates a new OAuth state for CSRF protection. +func (s *SQLiteStore) CreateOAuthState(ctx context.Context, state *OAuthState) error { + _, err := s.db.ExecContext(ctx, ` + INSERT INTO oauth_states (state, expires_at, created_at) + VALUES (?, ?, ?) + `, state.State, state.ExpiresAt, state.CreatedAt) + + if err != nil { + return fmt.Errorf("inserting oauth_state: %w", err) + } + + return nil +} + +// GetOAuthState retrieves an OAuth state by its value. +func (s *SQLiteStore) GetOAuthState(ctx context.Context, state string) (*OAuthState, error) { + var oauthState OAuthState + + err := s.db.QueryRowContext(ctx, ` + SELECT state, expires_at, created_at + FROM oauth_states WHERE state = ? + `, state).Scan(&oauthState.State, &oauthState.ExpiresAt, &oauthState.CreatedAt) + + if err == sql.ErrNoRows { + return nil, nil + } + + if err != nil { + return nil, fmt.Errorf("querying oauth_state: %w", err) + } + + return &oauthState, nil +} + +// DeleteOAuthState deletes an OAuth state. +func (s *SQLiteStore) DeleteOAuthState(ctx context.Context, state string) error { + _, err := s.db.ExecContext(ctx, `DELETE FROM oauth_states WHERE state = ?`, state) + if err != nil { + return fmt.Errorf("deleting oauth_state: %w", err) + } + + return nil +} + +// DeleteExpiredOAuthStates deletes all expired OAuth states. +func (s *SQLiteStore) DeleteExpiredOAuthStates(ctx context.Context) error { + _, err := s.db.ExecContext(ctx, `DELETE FROM oauth_states WHERE expires_at < ?`, time.Now()) + if err != nil { + return fmt.Errorf("deleting expired oauth_states: %w", err) + } + + return nil +} + +// ============================================================================ +// Auth Codes +// ============================================================================ + +// CreateAuthCode creates a new one-time auth code for token exchange. +func (s *SQLiteStore) CreateAuthCode(ctx context.Context, code *AuthCode) error { + _, err := s.db.ExecContext(ctx, ` + INSERT INTO auth_codes (code, user_id, expires_at, created_at) + VALUES (?, ?, ?, ?) + `, code.Code, code.UserID, code.ExpiresAt, code.CreatedAt) + + if err != nil { + return fmt.Errorf("inserting auth_code: %w", err) + } + + return nil +} + +// GetAuthCode retrieves an auth code by its value. +func (s *SQLiteStore) GetAuthCode(ctx context.Context, code string) (*AuthCode, error) { + var authCode AuthCode + + err := s.db.QueryRowContext(ctx, ` + SELECT code, user_id, expires_at, created_at + FROM auth_codes WHERE code = ? + `, code).Scan(&authCode.Code, &authCode.UserID, &authCode.ExpiresAt, &authCode.CreatedAt) + + if err == sql.ErrNoRows { + return nil, nil + } + + if err != nil { + return nil, fmt.Errorf("querying auth_code: %w", err) + } + + return &authCode, nil +} + +// DeleteAuthCode deletes an auth code. +func (s *SQLiteStore) DeleteAuthCode(ctx context.Context, code string) error { + _, err := s.db.ExecContext(ctx, `DELETE FROM auth_codes WHERE code = ?`, code) + if err != nil { + return fmt.Errorf("deleting auth_code: %w", err) + } + + return nil +} + +// DeleteExpiredAuthCodes deletes all expired auth codes. +func (s *SQLiteStore) DeleteExpiredAuthCodes(ctx context.Context) error { + _, err := s.db.ExecContext(ctx, `DELETE FROM auth_codes WHERE expires_at < ?`, time.Now()) + if err != nil { + return fmt.Errorf("deleting expired auth_codes: %w", err) + } + + return nil +} diff --git a/pkg/store/store.go b/pkg/store/store.go index 52afa57..0b2e559 100644 --- a/pkg/store/store.go +++ b/pkg/store/store.go @@ -71,6 +71,18 @@ type Store interface { DeleteExpiredSessions(ctx context.Context) error DeleteUserSessions(ctx context.Context, userID string) error + // OAuth States (CSRF protection). + CreateOAuthState(ctx context.Context, state *OAuthState) error + GetOAuthState(ctx context.Context, state string) (*OAuthState, error) + DeleteOAuthState(ctx context.Context, state string) error + DeleteExpiredOAuthStates(ctx context.Context) error + + // Auth Codes (one-time exchange codes). + CreateAuthCode(ctx context.Context, code *AuthCode) error + GetAuthCode(ctx context.Context, code string) (*AuthCode, error) + DeleteAuthCode(ctx context.Context, code string) error + DeleteExpiredAuthCodes(ctx context.Context) error + // Audit. CreateAuditEntry(ctx context.Context, entry *AuditEntry) error ListAuditEntries(ctx context.Context, opts AuditQueryOpts) ([]*AuditEntry, int, error) @@ -210,6 +222,21 @@ type Session struct { CreatedAt time.Time `json:"created_at"` } +// OAuthState represents a CSRF state token for OAuth flows. +type OAuthState struct { + State string `json:"state"` + ExpiresAt time.Time `json:"expires_at"` + CreatedAt time.Time `json:"created_at"` +} + +// AuthCode represents a one-time authorization code for token exchange. +type AuthCode struct { + Code string `json:"code"` + UserID string `json:"user_id"` + ExpiresAt time.Time `json:"expires_at"` + CreatedAt time.Time `json:"created_at"` +} + // AuditAction represents the type of action being audited. type AuditAction string diff --git a/ui/src/App.tsx b/ui/src/App.tsx index df9b1f5..6a7adc3 100644 --- a/ui/src/App.tsx +++ b/ui/src/App.tsx @@ -25,21 +25,31 @@ function AppRoutes() { const { checkAuth, isLoading } = useAuthStore(); useEffect(() => { - // Check for OAuth token in URL (from GitHub OAuth callback redirect). - const params = new URLSearchParams(window.location.search); - const token = params.get('token'); + const handleOAuthCallback = async () => { + // Check for OAuth code in URL (from GitHub OAuth callback redirect). + const params = new URLSearchParams(window.location.search); + const code = params.get('code'); - if (token) { - // Store the token and remove it from URL. - api.setToken(token); - params.delete('token'); - const newUrl = params.toString() - ? `${window.location.pathname}?${params.toString()}` - : window.location.pathname; - window.history.replaceState({}, '', newUrl); - } + if (code) { + // Exchange the one-time code for a session token. + try { + await api.exchangeCode(code); + } catch (error) { + console.error('Failed to exchange auth code:', error); + } + + // Remove code from URL. + params.delete('code'); + const newUrl = params.toString() + ? `${window.location.pathname}?${params.toString()}` + : window.location.pathname; + window.history.replaceState({}, '', newUrl); + } + + checkAuth(); + }; - checkAuth(); + handleOAuthCallback(); }, [checkAuth]); if (isLoading) { diff --git a/ui/src/api/client.ts b/ui/src/api/client.ts index 9d20f6e..41f843f 100644 --- a/ui/src/api/client.ts +++ b/ui/src/api/client.ts @@ -109,6 +109,15 @@ class ApiClient { return `${this.getApiBase()}/auth/github`; } + async exchangeCode(code: string): Promise<{ token: string; user: User }> { + const result = await this.request<{ token: string; user: User }>('/auth/exchange', { + method: 'POST', + body: JSON.stringify({ code }), + }); + this.setToken(result.token); + return result; + } + // Groups async getGroups(): Promise { return this.request('/groups');