forked from baron/baron-sso
feat(headless-login): add jwks cache visibility and refresh flow
- replace inline headless jwks support with jwksUri-only validation - add cached jwks refresh worker, manual refresh/revoke endpoints, and parsed key summaries - expose allowed algorithms and key previews in DevFront with regression coverage
This commit is contained in:
@@ -87,6 +87,7 @@ type AuthHandler struct {
|
||||
SmsService domain.SmsService
|
||||
EmailService domain.EmailService
|
||||
RedisService domain.RedisRepository
|
||||
HeadlessJWKS *service.HeadlessJWKSCacheService
|
||||
KratosAdmin service.KratosAdminService
|
||||
IdpProvider domain.IdentityProvider
|
||||
AuditRepo domain.AuditRepository
|
||||
@@ -193,6 +194,7 @@ func NewAuthHandler(redisService domain.RedisRepository, idpProvider domain.Iden
|
||||
SmsService: service.NewSmsService(),
|
||||
EmailService: service.NewEmailService(),
|
||||
RedisService: redisService,
|
||||
HeadlessJWKS: service.NewHeadlessJWKSCacheService(redisService, nil),
|
||||
KratosAdmin: kratos,
|
||||
IdpProvider: idpProvider,
|
||||
AuditRepo: auditRepo,
|
||||
@@ -1740,47 +1742,15 @@ func containsHeadlessAudience(expected []string, actual headlessAssertionAud) bo
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *AuthHandler) loadHeadlessJWKS(ctx context.Context, client domain.HydraClient) (*jose.JSONWebKeySet, error) {
|
||||
var raw []byte
|
||||
switch {
|
||||
case client.HeadlessJWKS() != nil:
|
||||
data, err := json.Marshal(client.HeadlessJWKS())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encode jwks: %w", err)
|
||||
}
|
||||
raw = data
|
||||
case client.HeadlessJWKSURI() != "":
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, client.HeadlessJWKSURI(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build jwks request: %w", err)
|
||||
}
|
||||
client := &http.Client{Timeout: headlessJWKSFetchTTL}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch jwks: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return nil, fmt.Errorf("failed to fetch jwks status=%d body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read jwks response: %w", err)
|
||||
}
|
||||
raw = body
|
||||
default:
|
||||
return nil, fmt.Errorf("headless login public key is not configured")
|
||||
func (h *AuthHandler) loadHeadlessJWKS(ctx context.Context, client domain.HydraClient, expectedKid string) (*jose.JSONWebKeySet, bool, error) {
|
||||
if h.HeadlessJWKS == nil {
|
||||
h.HeadlessJWKS = service.NewHeadlessJWKSCacheService(h.RedisService, nil)
|
||||
}
|
||||
|
||||
var keySet jose.JSONWebKeySet
|
||||
if err := json.Unmarshal(raw, &keySet); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode jwks: %w", err)
|
||||
keySet, _, refreshed, err := h.HeadlessJWKS.EnsureFreshKeySet(ctx, client, expectedKid)
|
||||
if err != nil {
|
||||
return nil, refreshed, err
|
||||
}
|
||||
if len(keySet.Keys) == 0 {
|
||||
return nil, fmt.Errorf("headless login jwks has no keys")
|
||||
}
|
||||
return &keySet, nil
|
||||
return keySet, refreshed, nil
|
||||
}
|
||||
|
||||
func validateHeadlessClientAssertionClaims(c *fiber.Ctx, claims headlessClientAssertionClaims, clientID string) error {
|
||||
@@ -1809,12 +1779,6 @@ func (h *AuthHandler) verifyHeadlessClientAssertion(c *fiber.Ctx, client domain.
|
||||
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_assertion is required")
|
||||
}
|
||||
|
||||
keySet, err := h.loadHeadlessJWKS(c.Context(), client)
|
||||
if err != nil {
|
||||
slog.Error("failed to load jwks for headless client assertion", "clientID", clientID, "error", err)
|
||||
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
|
||||
}
|
||||
|
||||
token, err := josejwt.ParseSigned(assertion, []jose.SignatureAlgorithm{
|
||||
jose.RS256, jose.RS384, jose.RS512,
|
||||
jose.PS256, jose.PS384, jose.PS512,
|
||||
@@ -1830,6 +1794,13 @@ func (h *AuthHandler) verifyHeadlessClientAssertion(c *fiber.Ctx, client domain.
|
||||
expectedKid = strings.TrimSpace(token.Headers[0].KeyID)
|
||||
}
|
||||
|
||||
keySet, refreshed, err := h.loadHeadlessJWKS(c.Context(), client, expectedKid)
|
||||
if err != nil {
|
||||
slog.Error("failed to load jwks for headless client assertion", "clientID", clientID, "error", err)
|
||||
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", headlessClientAssertionErrorMessage(err))
|
||||
}
|
||||
|
||||
matchingKidPresent := expectedKid != "" && containsHeadlessKeyID(keySet, expectedKid)
|
||||
for _, key := range keySet.Keys {
|
||||
if expectedKid != "" && key.KeyID != "" && key.KeyID != expectedKid {
|
||||
continue
|
||||
@@ -1840,12 +1811,65 @@ func (h *AuthHandler) verifyHeadlessClientAssertion(c *fiber.Ctx, client domain.
|
||||
continue
|
||||
}
|
||||
if err := validateHeadlessClientAssertionClaims(c, claims, clientID); err != nil {
|
||||
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
|
||||
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion claims")
|
||||
}
|
||||
_ = h.HeadlessJWKS.MarkVerificationSuccess(clientID)
|
||||
return nil
|
||||
}
|
||||
|
||||
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
|
||||
if matchingKidPresent && !refreshed && h.HeadlessJWKS != nil {
|
||||
refreshedKeySet, _, refreshErr := h.HeadlessJWKS.ForceRefreshKeySet(c.Context(), client, "signature_verification_failed")
|
||||
if refreshErr == nil && refreshedKeySet != nil {
|
||||
for _, key := range refreshedKeySet.Keys {
|
||||
if expectedKid != "" && key.KeyID != "" && key.KeyID != expectedKid {
|
||||
continue
|
||||
}
|
||||
|
||||
var claims headlessClientAssertionClaims
|
||||
if err := token.Claims(key.Key, &claims); err != nil {
|
||||
continue
|
||||
}
|
||||
if err := validateHeadlessClientAssertionClaims(c, claims, clientID); err != nil {
|
||||
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion claims")
|
||||
}
|
||||
_ = h.HeadlessJWKS.MarkVerificationSuccess(clientID)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion signature with jwksUri")
|
||||
}
|
||||
|
||||
func headlessClientAssertionErrorMessage(err error) string {
|
||||
if err == nil {
|
||||
return "Failed to verify client assertion"
|
||||
}
|
||||
message := strings.TrimSpace(err.Error())
|
||||
switch {
|
||||
case strings.Contains(message, "requires jwksUri"):
|
||||
return "Headless login requires jwksUri. Inline jwks is not supported."
|
||||
case strings.Contains(message, "no keys"):
|
||||
return "Configured jwksUri returned no keys for headless login."
|
||||
case strings.Contains(message, "failed to fetch jwksUri"):
|
||||
return "Failed to refresh headless login jwks from jwksUri."
|
||||
case strings.Contains(message, "failed to decode jwks"):
|
||||
return "Configured jwksUri returned an invalid jwks document."
|
||||
default:
|
||||
return "Failed to verify client assertion"
|
||||
}
|
||||
}
|
||||
|
||||
func containsHeadlessKeyID(keySet *jose.JSONWebKeySet, expectedKid string) bool {
|
||||
if keySet == nil {
|
||||
return false
|
||||
}
|
||||
for _, key := range keySet.Keys {
|
||||
if strings.TrimSpace(key.KeyID) == strings.TrimSpace(expectedKid) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *AuthHandler) storeHeadlessLinkState(pendingRef string, state headlessLinkState, ttl time.Duration) {
|
||||
|
||||
Reference in New Issue
Block a user