Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions auth/clerk_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/http"
"net/url"
"strings"
"sync"
"time"

"github.com/flanksource/commons/logger"
Expand Down Expand Up @@ -42,6 +43,33 @@ type ClerkHandler struct {
tokenCache *cache.Cache
accessTokenCache *cache.Cache
userCache *cache.Cache
jwks *jwksCache
}

// jwksCache lazily fetches and caches the Clerk JWKS keyfunc. The underlying
// keyfunc.Get performs a network fetch and spawns a background-refresh goroutine,
// so it must be built once and shared rather than rebuilt per request. The cache
// is held behind a pointer so copies of ClerkHandler (value receivers) share the
// same instance and lock. A failed fetch is not cached, so the next request retries.
type jwksCache struct {
url string
mu sync.Mutex
fn jwt.Keyfunc
}

func (c *jwksCache) keyfunc() (jwt.Keyfunc, error) {
c.mu.Lock()
defer c.mu.Unlock()

if c.fn == nil {
fn, err := newClerkKeyfunc(c.url)
if err != nil {
return nil, err
}
c.fn = fn
}

return c.fn, nil
}

func NewClerkHandler() (*ClerkHandler, error) {
Expand All @@ -58,19 +86,25 @@ func NewClerkHandler() (*ClerkHandler, error) {
tokenCache: cache.New(3*24*time.Hour, 12*time.Hour),
accessTokenCache: cache.New(3*24*time.Hour, 12*time.Hour),
userCache: cache.New(3*24*time.Hour, 12*time.Hour),
jwks: &jwksCache{url: ClerkJwksUrl},
}, nil
}

func (h ClerkHandler) parseJWTToken(token string) (jwt.MapClaims, error) {
keyfunc, err := h.jwks.keyfunc()
if err != nil {
return nil, err
}

claims := jwt.MapClaims{}
jt, err := jwt.ParseWithClaims(token, claims, getJWTKeyFunc(h.jwksURL))
jt, err := jwt.ParseWithClaims(token, claims, keyfunc)
if err != nil {
return claims, err
}
if !jt.Valid {
return claims, fmt.Errorf("jwt token not valid")
}
return claims, err
return claims, nil
}

type AuthResult struct {
Expand Down
11 changes: 7 additions & 4 deletions auth/tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,11 @@ func newPostgRESTJWT(config api.PostgrestConfig, claims jwt.MapClaims) (string,
return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(config.JWTSecret))
}

func getJWTKeyFunc(jwksURL string) jwt.Keyfunc {
// newClerkKeyfunc fetches the JWKS from the given URL and returns a jwt.Keyfunc
// backed by it. keyfunc.Get performs a synchronous network fetch and spawns a
// background-refresh goroutine, so the result must be created once and reused
// rather than rebuilt per request.
func newClerkKeyfunc(jwksURL string) (jwt.Keyfunc, error) {
// Create the keyfunc options. Use an error handler that logs. Refresh the JWKS when a JWT signed by an unknown KID
// is found or at the specified interval. Rate limit these refreshes. Timeout the initial JWKS refresh request after
// 10 seconds. This timeout is also used to create the initial context.Context for keyfunc.Get.
Expand All @@ -127,10 +131,9 @@ func getJWTKeyFunc(jwksURL string) jwt.Keyfunc {
// Create the JWKS from the resource at the given URL.
jwks, err := keyfunc.Get(jwksURL, options)
if err != nil {
logger.Fatalf("Failed to create JWKS from resource at the given URL.\nError: %s", err.Error())
// TODO Handle
return nil, fmt.Errorf("failed to fetch JWKS from %q: %w", jwksURL, err)
}
return jwks.Keyfunc
return jwks.Keyfunc, nil
}

func getAccessToken(ctx context.Context, token string) (*models.AccessToken, error) {
Expand Down