1
0
forked from baron/baron-sso

feat(headless-login): add jwks cache visibility and refresh flow

- replace inline headless jwks support with jwksUri-only validation
- add cached jwks refresh worker, manual refresh/revoke endpoints, and parsed key summaries
- expose allowed algorithms and key previews in DevFront with regression coverage
This commit is contained in:
Lectom C Han
2026-04-01 18:33:22 +09:00
parent f51cdba51a
commit 9facd24a00
20 changed files with 2393 additions and 499 deletions

View File

@@ -87,6 +87,7 @@ type AuthHandler struct {
SmsService domain.SmsService
EmailService domain.EmailService
RedisService domain.RedisRepository
HeadlessJWKS *service.HeadlessJWKSCacheService
KratosAdmin service.KratosAdminService
IdpProvider domain.IdentityProvider
AuditRepo domain.AuditRepository
@@ -193,6 +194,7 @@ func NewAuthHandler(redisService domain.RedisRepository, idpProvider domain.Iden
SmsService: service.NewSmsService(),
EmailService: service.NewEmailService(),
RedisService: redisService,
HeadlessJWKS: service.NewHeadlessJWKSCacheService(redisService, nil),
KratosAdmin: kratos,
IdpProvider: idpProvider,
AuditRepo: auditRepo,
@@ -1740,47 +1742,15 @@ func containsHeadlessAudience(expected []string, actual headlessAssertionAud) bo
return false
}
func (h *AuthHandler) loadHeadlessJWKS(ctx context.Context, client domain.HydraClient) (*jose.JSONWebKeySet, error) {
var raw []byte
switch {
case client.HeadlessJWKS() != nil:
data, err := json.Marshal(client.HeadlessJWKS())
if err != nil {
return nil, fmt.Errorf("failed to encode jwks: %w", err)
}
raw = data
case client.HeadlessJWKSURI() != "":
req, err := http.NewRequestWithContext(ctx, http.MethodGet, client.HeadlessJWKSURI(), nil)
if err != nil {
return nil, fmt.Errorf("failed to build jwks request: %w", err)
}
client := &http.Client{Timeout: headlessJWKSFetchTTL}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to fetch jwks: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return nil, fmt.Errorf("failed to fetch jwks status=%d body=%s", resp.StatusCode, string(body))
}
body, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024))
if err != nil {
return nil, fmt.Errorf("failed to read jwks response: %w", err)
}
raw = body
default:
return nil, fmt.Errorf("headless login public key is not configured")
func (h *AuthHandler) loadHeadlessJWKS(ctx context.Context, client domain.HydraClient, expectedKid string) (*jose.JSONWebKeySet, bool, error) {
if h.HeadlessJWKS == nil {
h.HeadlessJWKS = service.NewHeadlessJWKSCacheService(h.RedisService, nil)
}
var keySet jose.JSONWebKeySet
if err := json.Unmarshal(raw, &keySet); err != nil {
return nil, fmt.Errorf("failed to decode jwks: %w", err)
keySet, _, refreshed, err := h.HeadlessJWKS.EnsureFreshKeySet(ctx, client, expectedKid)
if err != nil {
return nil, refreshed, err
}
if len(keySet.Keys) == 0 {
return nil, fmt.Errorf("headless login jwks has no keys")
}
return &keySet, nil
return keySet, refreshed, nil
}
func validateHeadlessClientAssertionClaims(c *fiber.Ctx, claims headlessClientAssertionClaims, clientID string) error {
@@ -1809,12 +1779,6 @@ func (h *AuthHandler) verifyHeadlessClientAssertion(c *fiber.Ctx, client domain.
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_assertion is required")
}
keySet, err := h.loadHeadlessJWKS(c.Context(), client)
if err != nil {
slog.Error("failed to load jwks for headless client assertion", "clientID", clientID, "error", err)
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
}
token, err := josejwt.ParseSigned(assertion, []jose.SignatureAlgorithm{
jose.RS256, jose.RS384, jose.RS512,
jose.PS256, jose.PS384, jose.PS512,
@@ -1830,6 +1794,13 @@ func (h *AuthHandler) verifyHeadlessClientAssertion(c *fiber.Ctx, client domain.
expectedKid = strings.TrimSpace(token.Headers[0].KeyID)
}
keySet, refreshed, err := h.loadHeadlessJWKS(c.Context(), client, expectedKid)
if err != nil {
slog.Error("failed to load jwks for headless client assertion", "clientID", clientID, "error", err)
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", headlessClientAssertionErrorMessage(err))
}
matchingKidPresent := expectedKid != "" && containsHeadlessKeyID(keySet, expectedKid)
for _, key := range keySet.Keys {
if expectedKid != "" && key.KeyID != "" && key.KeyID != expectedKid {
continue
@@ -1840,12 +1811,65 @@ func (h *AuthHandler) verifyHeadlessClientAssertion(c *fiber.Ctx, client domain.
continue
}
if err := validateHeadlessClientAssertionClaims(c, claims, clientID); err != nil {
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion claims")
}
_ = h.HeadlessJWKS.MarkVerificationSuccess(clientID)
return nil
}
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
if matchingKidPresent && !refreshed && h.HeadlessJWKS != nil {
refreshedKeySet, _, refreshErr := h.HeadlessJWKS.ForceRefreshKeySet(c.Context(), client, "signature_verification_failed")
if refreshErr == nil && refreshedKeySet != nil {
for _, key := range refreshedKeySet.Keys {
if expectedKid != "" && key.KeyID != "" && key.KeyID != expectedKid {
continue
}
var claims headlessClientAssertionClaims
if err := token.Claims(key.Key, &claims); err != nil {
continue
}
if err := validateHeadlessClientAssertionClaims(c, claims, clientID); err != nil {
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion claims")
}
_ = h.HeadlessJWKS.MarkVerificationSuccess(clientID)
return nil
}
}
}
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion signature with jwksUri")
}
func headlessClientAssertionErrorMessage(err error) string {
if err == nil {
return "Failed to verify client assertion"
}
message := strings.TrimSpace(err.Error())
switch {
case strings.Contains(message, "requires jwksUri"):
return "Headless login requires jwksUri. Inline jwks is not supported."
case strings.Contains(message, "no keys"):
return "Configured jwksUri returned no keys for headless login."
case strings.Contains(message, "failed to fetch jwksUri"):
return "Failed to refresh headless login jwks from jwksUri."
case strings.Contains(message, "failed to decode jwks"):
return "Configured jwksUri returned an invalid jwks document."
default:
return "Failed to verify client assertion"
}
}
func containsHeadlessKeyID(keySet *jose.JSONWebKeySet, expectedKid string) bool {
if keySet == nil {
return false
}
for _, key := range keySet.Keys {
if strings.TrimSpace(key.KeyID) == strings.TrimSpace(expectedKid) {
return true
}
}
return false
}
func (h *AuthHandler) storeHeadlessLinkState(pendingRef string, state headlessLinkState, ttl time.Duration) {