Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ lint-api:

lint-ui:
@echo "Linting UI..."
npm install --prefix $(UI_DIR)
npm run --prefix $(UI_DIR) lint

dev-api:
Expand Down
110 changes: 99 additions & 11 deletions pkg/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ func (s *server) setupRouter() {
r.Post("/auth/login", s.handleLogin)
r.Get("/auth/github", s.handleGitHubAuth)
r.Get("/auth/github/callback", s.handleGitHubCallback)
r.Post("/auth/exchange", s.handleExchangeCode)

// WebSocket (authentication handled in handler).
r.Get("/ws", s.handleWebSocket)
Expand Down Expand Up @@ -1248,7 +1249,7 @@ func (s *server) handleRefreshRunners(w http.ResponseWriter, _ *http.Request) {
}

func (s *server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
ServeWs(s.hub, s.auth, w, r)
ServeWs(s.hub, s.auth, s.cfg.Server.CORSOrigins, w, r)
}

// ============================================================================
Expand Down Expand Up @@ -1294,6 +1295,7 @@ func (s *server) handleLogin(w http.ResponseWriter, r *http.Request) {
Path: "/",
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Secure: s.isSecureRequest(r),
MaxAge: int(s.cfg.Auth.SessionTTL.Seconds()),
})

Expand Down Expand Up @@ -1330,6 +1332,8 @@ func (s *server) handleLogout(w http.ResponseWriter, r *http.Request) {
Value: "",
Path: "/",
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Secure: s.isSecureRequest(r),
MaxAge: -1,
})

Expand All @@ -1354,11 +1358,13 @@ func (s *server) handleGitHubAuth(w http.ResponseWriter, r *http.Request) {
return
}

// Generate state for CSRF protection.
// In production, store this in a session/cookie and validate on callback.
state := r.URL.Query().Get("state")
if state == "" {
state = "dispatchoor"
// Generate cryptographically secure state for CSRF protection.
state, err := s.auth.CreateOAuthState(r.Context())
if err != nil {
s.log.WithError(err).Error("Failed to create OAuth state")
s.writeError(w, http.StatusInternalServerError, "Failed to initiate OAuth flow")

return
}

authURL := s.auth.GetGitHubAuthURL(state)
Expand All @@ -1379,7 +1385,20 @@ func (s *server) handleGitHubCallback(w http.ResponseWriter, r *http.Request) {
return
}

// TODO: Validate state parameter for CSRF protection.
// Validate state parameter for CSRF protection.
state := r.URL.Query().Get("state")
if state == "" {
s.writeError(w, http.StatusBadRequest, "Missing state parameter")

return
}

if err := s.auth.ValidateOAuthState(r.Context(), state); err != nil {
s.log.WithError(err).Warn("Invalid OAuth state")
s.writeError(w, http.StatusBadRequest, "Invalid or expired state parameter")

return
}

user, token, err := s.auth.AuthenticateGitHub(r.Context(), code)
if err != nil {
Expand All @@ -1389,13 +1408,14 @@ func (s *server) handleGitHubCallback(w http.ResponseWriter, r *http.Request) {
return
}

// Set session cookie.
// Set session cookie (works for same-origin requests).
http.SetCookie(w, &http.Cookie{
Name: "session",
Value: token,
Path: "/",
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Secure: s.isSecureRequest(r),
MaxAge: int(s.cfg.Auth.SessionTTL.Seconds()),
})

Expand All @@ -1419,16 +1439,84 @@ func (s *server) handleGitHubCallback(w http.ResponseWriter, r *http.Request) {
redirectURL = "/"
}

// Append token to redirect URL for the UI to capture.
// Generate one-time authorization code for cross-origin token exchange.
// This is more secure than putting the session token in the URL.
authCode, err := s.auth.CreateAuthCode(r.Context(), user.ID)
if err != nil {
s.log.WithError(err).Error("Failed to create auth code")
s.writeError(w, http.StatusInternalServerError, "Failed to complete authentication")

return
}

// Append auth code to redirect URL for the UI to exchange for a token.
if strings.Contains(redirectURL, "?") {
redirectURL += "&token=" + token
redirectURL += "&code=" + authCode
} else {
redirectURL += "?token=" + token
redirectURL += "?code=" + authCode
}

http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect)
}

type exchangeCodeRequest struct {
Code string `json:"code"`
}

func (s *server) handleExchangeCode(w http.ResponseWriter, r *http.Request) {
var req exchangeCodeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.writeError(w, http.StatusBadRequest, "Invalid request body")

return
}

if req.Code == "" {
s.writeError(w, http.StatusBadRequest, "Code is required")

return
}

user, token, err := s.auth.ExchangeAuthCode(r.Context(), req.Code)
if err != nil {
s.log.WithError(err).Warn("Code exchange failed")
s.writeError(w, http.StatusUnauthorized, "Invalid or expired code")

return
}

// Set session cookie.
http.SetCookie(w, &http.Cookie{
Name: "session",
Value: token,
Path: "/",
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Secure: s.isSecureRequest(r),
MaxAge: int(s.cfg.Auth.SessionTTL.Seconds()),
})

s.writeJSON(w, http.StatusOK, loginResponse{
Token: token,
User: user,
})
}

// isSecureRequest checks if the request was made over HTTPS.
func (s *server) isSecureRequest(r *http.Request) bool {
// Check TLS directly.
if r.TLS != nil {
return true
}

// Check X-Forwarded-Proto header (common with reverse proxies).
if r.Header.Get("X-Forwarded-Proto") == "https" {
return true
}

return false
}

// SyncGroupsFromConfig synchronizes groups and job templates from configuration.
func SyncGroupsFromConfig(ctx context.Context, log logrus.FieldLogger, st store.Store, cfg *config.Config) error {
log.Info("Syncing groups from configuration")
Expand Down
40 changes: 32 additions & 8 deletions pkg/api/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,34 @@ const (
maxMessageSize = 512
)

var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
// In production, validate the origin properly.
return true
},
// createUpgrader creates a WebSocket upgrader with origin validation.
func createUpgrader(allowedOrigins []string) websocket.Upgrader {
allowAll := len(allowedOrigins) == 1 && allowedOrigins[0] == "*"

originSet := make(map[string]bool, len(allowedOrigins))
for _, origin := range allowedOrigins {
originSet[origin] = true
}

return websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
// If no origins configured, reject all cross-origin requests.
if len(allowedOrigins) == 0 {
return r.Header.Get("Origin") == ""
}

// Allow all origins if configured with "*".
if allowAll {
return true
}

// Check if origin is in allowed list.
origin := r.Header.Get("Origin")
return originSet[origin]
},
}
}

// MessageType represents the type of WebSocket message.
Expand Down Expand Up @@ -411,7 +432,7 @@ func (c *Client) handleMessage(msg *Message) {
}

// ServeWs handles WebSocket requests from the peer.
func ServeWs(hub *Hub, authSvc auth.Service, w http.ResponseWriter, r *http.Request) {
func ServeWs(hub *Hub, authSvc auth.Service, allowedOrigins []string, w http.ResponseWriter, r *http.Request) {
// Authenticate the user.
token := r.URL.Query().Get("token")
if token == "" {
Expand All @@ -438,6 +459,9 @@ func ServeWs(hub *Hub, authSvc auth.Service, w http.ResponseWriter, r *http.Requ
return
}

// Create upgrader with origin validation.
upgrader := createUpgrader(allowedOrigins)

// Upgrade to WebSocket.
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
Expand Down
Loading
Loading