1
0
forked from baron/baron-sso

feat(headless-login): add jwks cache visibility and refresh flow

- replace inline headless jwks support with jwksUri-only validation
- add cached jwks refresh worker, manual refresh/revoke endpoints, and parsed key summaries
- expose allowed algorithms and key previews in DevFront with regression coverage
This commit is contained in:
Lectom C Han
2026-04-01 18:33:22 +09:00
parent f51cdba51a
commit 9facd24a00
20 changed files with 2393 additions and 499 deletions

View File

@@ -0,0 +1,29 @@
package domain
import "time"
type HeadlessJWKSParsedKey struct {
Kid string `json:"kid,omitempty"`
Kty string `json:"kty,omitempty"`
Use string `json:"use,omitempty"`
Alg string `json:"alg,omitempty"`
NPreview string `json:"nPreview,omitempty"`
}
// HeadlessJWKSCacheState는 headless login용 JWKS 캐시 상태와 최근 동기화 결과를 나타냅니다.
type HeadlessJWKSCacheState struct {
ClientID string `json:"clientId"`
JWKSURI string `json:"jwksUri"`
CachedAt *time.Time `json:"cachedAt,omitempty"`
ExpiresAt *time.Time `json:"expiresAt,omitempty"`
LastCheckedAt *time.Time `json:"lastCheckedAt,omitempty"`
LastSuccessfulVerificationAt *time.Time `json:"lastSuccessfulVerificationAt,omitempty"`
LastRefreshStatus string `json:"lastRefreshStatus,omitempty"`
LastError string `json:"lastError,omitempty"`
ConsecutiveFailures int `json:"consecutiveFailures,omitempty"`
CachedKids []string `json:"cachedKids,omitempty"`
ParsedKeys []HeadlessJWKSParsedKey `json:"parsedKeys,omitempty"`
ETag string `json:"etag,omitempty"`
LastModified string `json:"lastModified,omitempty"`
RawJWKS string `json:"-"`
}

View File

@@ -28,9 +28,8 @@ type HydraClient struct {
}
func (c *HydraClient) SupportsHeadlessLogin() bool {
// A headless login client must have a public key registered (URI or Inline)
// and use private_key_jwt for token endpoint authentication.
hasPublicKey := c.HeadlessJWKSURI() != "" || c.HeadlessJWKS() != nil
// Headless login now supports jwksUri only.
hasPublicKey := c.HeadlessJWKSURI() != ""
isPrivateKeyJwt := c.HeadlessTokenEndpointAuthMethod() == "private_key_jwt"
return hasPublicKey && isPrivateKeyJwt
}

View File

@@ -9,11 +9,7 @@ func TestHydraClient_HeadlessLoginFlags(t *testing.T) {
Metadata: map[string]any{
"headless_login_enabled": true,
"headless_token_endpoint_auth_method": "private_key_jwt",
"headless_jwks": map[string]any{
"keys": []map[string]any{{
"kty": "RSA",
}},
},
"headless_jwks_uri": "https://rp.example.com/.well-known/jwks.json",
},
}
@@ -25,7 +21,7 @@ func TestHydraClient_HeadlessLoginFlags(t *testing.T) {
}
})
t.Run("inline jwks with private_key_jwt and headless enabled", func(t *testing.T) {
t.Run("inline jwks without jwks uri does not support headless login", func(t *testing.T) {
client := HydraClient{
TokenEndpointAuthMethod: "private_key_jwt",
JWKS: map[string]any{
@@ -38,11 +34,11 @@ func TestHydraClient_HeadlessLoginFlags(t *testing.T) {
},
}
if !client.SupportsHeadlessLogin() {
t.Fatalf("expected headless login client")
if client.SupportsHeadlessLogin() {
t.Fatalf("expected headless login prerequisites to be missing")
}
if !client.IsHeadlessLoginEnabled() {
t.Fatalf("expected headless login enabled")
if client.IsHeadlessLoginEnabled() {
t.Fatalf("expected headless login disabled without jwks uri")
}
})

View File

@@ -87,6 +87,7 @@ type AuthHandler struct {
SmsService domain.SmsService
EmailService domain.EmailService
RedisService domain.RedisRepository
HeadlessJWKS *service.HeadlessJWKSCacheService
KratosAdmin service.KratosAdminService
IdpProvider domain.IdentityProvider
AuditRepo domain.AuditRepository
@@ -193,6 +194,7 @@ func NewAuthHandler(redisService domain.RedisRepository, idpProvider domain.Iden
SmsService: service.NewSmsService(),
EmailService: service.NewEmailService(),
RedisService: redisService,
HeadlessJWKS: service.NewHeadlessJWKSCacheService(redisService, nil),
KratosAdmin: kratos,
IdpProvider: idpProvider,
AuditRepo: auditRepo,
@@ -1740,47 +1742,15 @@ func containsHeadlessAudience(expected []string, actual headlessAssertionAud) bo
return false
}
func (h *AuthHandler) loadHeadlessJWKS(ctx context.Context, client domain.HydraClient) (*jose.JSONWebKeySet, error) {
var raw []byte
switch {
case client.HeadlessJWKS() != nil:
data, err := json.Marshal(client.HeadlessJWKS())
if err != nil {
return nil, fmt.Errorf("failed to encode jwks: %w", err)
}
raw = data
case client.HeadlessJWKSURI() != "":
req, err := http.NewRequestWithContext(ctx, http.MethodGet, client.HeadlessJWKSURI(), nil)
if err != nil {
return nil, fmt.Errorf("failed to build jwks request: %w", err)
}
client := &http.Client{Timeout: headlessJWKSFetchTTL}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to fetch jwks: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return nil, fmt.Errorf("failed to fetch jwks status=%d body=%s", resp.StatusCode, string(body))
}
body, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024))
if err != nil {
return nil, fmt.Errorf("failed to read jwks response: %w", err)
}
raw = body
default:
return nil, fmt.Errorf("headless login public key is not configured")
func (h *AuthHandler) loadHeadlessJWKS(ctx context.Context, client domain.HydraClient, expectedKid string) (*jose.JSONWebKeySet, bool, error) {
if h.HeadlessJWKS == nil {
h.HeadlessJWKS = service.NewHeadlessJWKSCacheService(h.RedisService, nil)
}
var keySet jose.JSONWebKeySet
if err := json.Unmarshal(raw, &keySet); err != nil {
return nil, fmt.Errorf("failed to decode jwks: %w", err)
keySet, _, refreshed, err := h.HeadlessJWKS.EnsureFreshKeySet(ctx, client, expectedKid)
if err != nil {
return nil, refreshed, err
}
if len(keySet.Keys) == 0 {
return nil, fmt.Errorf("headless login jwks has no keys")
}
return &keySet, nil
return keySet, refreshed, nil
}
func validateHeadlessClientAssertionClaims(c *fiber.Ctx, claims headlessClientAssertionClaims, clientID string) error {
@@ -1809,12 +1779,6 @@ func (h *AuthHandler) verifyHeadlessClientAssertion(c *fiber.Ctx, client domain.
return errorJSONCode(c, fiber.StatusBadRequest, "bad_request", "client_assertion is required")
}
keySet, err := h.loadHeadlessJWKS(c.Context(), client)
if err != nil {
slog.Error("failed to load jwks for headless client assertion", "clientID", clientID, "error", err)
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
}
token, err := josejwt.ParseSigned(assertion, []jose.SignatureAlgorithm{
jose.RS256, jose.RS384, jose.RS512,
jose.PS256, jose.PS384, jose.PS512,
@@ -1830,6 +1794,13 @@ func (h *AuthHandler) verifyHeadlessClientAssertion(c *fiber.Ctx, client domain.
expectedKid = strings.TrimSpace(token.Headers[0].KeyID)
}
keySet, refreshed, err := h.loadHeadlessJWKS(c.Context(), client, expectedKid)
if err != nil {
slog.Error("failed to load jwks for headless client assertion", "clientID", clientID, "error", err)
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", headlessClientAssertionErrorMessage(err))
}
matchingKidPresent := expectedKid != "" && containsHeadlessKeyID(keySet, expectedKid)
for _, key := range keySet.Keys {
if expectedKid != "" && key.KeyID != "" && key.KeyID != expectedKid {
continue
@@ -1840,12 +1811,65 @@ func (h *AuthHandler) verifyHeadlessClientAssertion(c *fiber.Ctx, client domain.
continue
}
if err := validateHeadlessClientAssertionClaims(c, claims, clientID); err != nil {
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion claims")
}
_ = h.HeadlessJWKS.MarkVerificationSuccess(clientID)
return nil
}
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion")
if matchingKidPresent && !refreshed && h.HeadlessJWKS != nil {
refreshedKeySet, _, refreshErr := h.HeadlessJWKS.ForceRefreshKeySet(c.Context(), client, "signature_verification_failed")
if refreshErr == nil && refreshedKeySet != nil {
for _, key := range refreshedKeySet.Keys {
if expectedKid != "" && key.KeyID != "" && key.KeyID != expectedKid {
continue
}
var claims headlessClientAssertionClaims
if err := token.Claims(key.Key, &claims); err != nil {
continue
}
if err := validateHeadlessClientAssertionClaims(c, claims, clientID); err != nil {
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion claims")
}
_ = h.HeadlessJWKS.MarkVerificationSuccess(clientID)
return nil
}
}
}
return errorJSONCode(c, fiber.StatusUnauthorized, "invalid_client_assertion", "Failed to verify client assertion signature with jwksUri")
}
func headlessClientAssertionErrorMessage(err error) string {
if err == nil {
return "Failed to verify client assertion"
}
message := strings.TrimSpace(err.Error())
switch {
case strings.Contains(message, "requires jwksUri"):
return "Headless login requires jwksUri. Inline jwks is not supported."
case strings.Contains(message, "no keys"):
return "Configured jwksUri returned no keys for headless login."
case strings.Contains(message, "failed to fetch jwksUri"):
return "Failed to refresh headless login jwks from jwksUri."
case strings.Contains(message, "failed to decode jwks"):
return "Configured jwksUri returned an invalid jwks document."
default:
return "Failed to verify client assertion"
}
}
func containsHeadlessKeyID(keySet *jose.JSONWebKeySet, expectedKid string) bool {
if keySet == nil {
return false
}
for _, key := range keySet.Keys {
if strings.TrimSpace(key.KeyID) == strings.TrimSpace(expectedKid) {
return true
}
}
return false
}
func (h *AuthHandler) storeHeadlessLinkState(pendingRef string, state headlessLinkState, ttl time.Duration) {

View File

@@ -175,6 +175,12 @@ func TestPollEnchantedLink_ExpiredToken_ReturnsCode(t *testing.T) {
func TestHeadlessLinkInit_HeadlessLoginClientSuccess(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
privateKey, jwks := mustHeadlessRSAJWK(t)
jwksBody, _ := json.Marshal(jwks)
jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(jwksBody)
}))
defer jwksServer.Close()
idp := &mockIdpProvider{
userExists: true,
@@ -192,7 +198,7 @@ func TestHeadlessLinkInit_HeadlessLoginClientSuccess(t *testing.T) {
"status": "active",
"headless_login_enabled": true,
"headless_token_endpoint_auth_method": "private_key_jwt",
"headless_jwks": jwks,
"headless_jwks_uri": jwksServer.URL + "/.well-known/jwks.json",
},
},
})
@@ -236,6 +242,12 @@ func TestHeadlessLinkInit_HeadlessLoginClientSuccess(t *testing.T) {
func TestHeadlessLinkPoll_AfterApprovalReturnsRedirect(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
privateKey, jwks := mustHeadlessRSAJWK(t)
jwksBody, _ := json.Marshal(jwks)
jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(jwksBody)
}))
defer jwksServer.Close()
idp := &mockIdpProvider{
userExists: true,
@@ -254,7 +266,7 @@ func TestHeadlessLinkPoll_AfterApprovalReturnsRedirect(t *testing.T) {
"status": "active",
"headless_login_enabled": true,
"headless_token_endpoint_auth_method": "private_key_jwt",
"headless_jwks": jwks,
"headless_jwks_uri": jwksServer.URL + "/.well-known/jwks.json",
},
},
})

View File

@@ -10,6 +10,9 @@ import (
"baron-sso-backend/internal/service"
"bytes"
"context"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"encoding/json"
@@ -199,6 +202,183 @@ func mustHeadlessClientAssertion(t *testing.T, privateKey *rsa.PrivateKey, clien
return raw
}
func mustHeadlessJWKForAlgorithm(t *testing.T, alg jose.SignatureAlgorithm) (any, map[string]any) {
t.Helper()
var privateKey any
var publicKey any
switch alg {
case jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512:
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("failed to generate rsa key: %v", err)
}
privateKey = key
publicKey = &key.PublicKey
case jose.ES256:
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("failed to generate ecdsa key: %v", err)
}
privateKey = key
publicKey = &key.PublicKey
case jose.ES384:
key, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
if err != nil {
t.Fatalf("failed to generate ecdsa key: %v", err)
}
privateKey = key
publicKey = &key.PublicKey
case jose.ES512:
key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
if err != nil {
t.Fatalf("failed to generate ecdsa key: %v", err)
}
privateKey = key
publicKey = &key.PublicKey
case jose.EdDSA:
_, key, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatalf("failed to generate ed25519 key: %v", err)
}
privateKey = key
publicKey = key.Public()
default:
t.Fatalf("unsupported test algorithm: %s", alg)
}
keySet := jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{
{
Key: publicKey,
KeyID: "test-kid",
Use: "sig",
Algorithm: string(alg),
},
},
}
raw, err := json.Marshal(keySet)
if err != nil {
t.Fatalf("failed to marshal jwks: %v", err)
}
var jwks map[string]any
if err := json.Unmarshal(raw, &jwks); err != nil {
t.Fatalf("failed to decode jwks map: %v", err)
}
return privateKey, jwks
}
func mustHeadlessClientAssertionWithAlgorithm(t *testing.T, privateKey any, alg jose.SignatureAlgorithm, clientID, audience string) string {
t.Helper()
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: alg,
Key: jose.JSONWebKey{
Key: privateKey,
KeyID: "test-kid",
Use: "sig",
Algorithm: string(alg),
},
}, nil)
if err != nil {
t.Fatalf("failed to create signer: %v", err)
}
now := time.Now()
raw, err := josejwt.Signed(signer).Claims(josejwt.Claims{
Issuer: clientID,
Subject: clientID,
Audience: josejwt.Audience{audience},
Expiry: josejwt.NewNumericDate(now.Add(5 * time.Minute)),
IssuedAt: josejwt.NewNumericDate(now),
NotBefore: josejwt.NewNumericDate(now.Add(-1 * time.Minute)),
ID: "assertion-1",
}).Serialize()
if err != nil {
t.Fatalf("failed to sign client assertion: %v", err)
}
return raw
}
func runHeadlessPasswordLoginWithAssertion(t *testing.T, jwks map[string]any, clientAssertion string) *http.Response {
t.Helper()
mockIdp := new(MockIdentityProvider)
mockIdp.On("SignIn", "employee001", "password").Return(&domain.AuthInfo{
SessionToken: &domain.Token{JWT: "valid-jwt"},
Subject: "kratos-identity-id",
}, nil)
mockKratos := new(MockKratosAdminService)
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "employee001").Return("kratos-identity-id", nil)
jwksBody, err := json.Marshal(jwks)
if err != nil {
t.Fatalf("failed to marshal jwks body: %v", err)
}
jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(jwksBody)
}))
t.Cleanup(jwksServer.Close)
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "headless-login-client",
TokenEndpointAuthMethod: "none",
Metadata: map[string]interface{}{
"status": "active",
"headless_login_enabled": true,
"headless_token_endpoint_auth_method": "private_key_jwt",
"headless_jwks_uri": jwksServer.URL + "/.well-known/jwks.json",
},
},
})
return
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
return
}
http.NotFound(w, r)
})
h := &AuthHandler{
IdpProvider: mockIdp,
KratosAdmin: mockKratos,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
},
}
app := newHeadlessPasswordLoginTestApp(h)
body, _ := json.Marshal(map[string]string{
"client_id": "headless-login-client",
"client_assertion": clientAssertion,
"loginId": "employee001",
"password": "password",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
return resp
}
// mockHydraTransport simulates Hydra API responses
func mockHydraTransport(handler http.Handler) http.RoundTripper {
return roundTripFunc(func(req *http.Request) (*http.Response, error) {
@@ -375,6 +555,206 @@ func TestHeadlessPasswordLogin_HeadlessLoginClientSuccess(t *testing.T) {
}
}
func TestHeadlessPasswordLogin_IgnoresInlineHeadlessJWKSWhenJWKSURIIsConfigured(t *testing.T) {
mockIdp := new(MockIdentityProvider)
mockIdp.On("SignIn", "employee001", "password").Return(&domain.AuthInfo{
SessionToken: &domain.Token{JWT: "valid-jwt"},
Subject: "kratos-identity-id",
}, nil)
privateKey, jwks := mustHeadlessRSAJWK(t)
jwksBody, _ := json.Marshal(jwks)
jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(jwksBody)
}))
defer jwksServer.Close()
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "headless-login-client",
TokenEndpointAuthMethod: "none",
Metadata: map[string]interface{}{
"status": "active",
"headless_login_enabled": true,
"headless_token_endpoint_auth_method": "private_key_jwt",
"headless_jwks_uri": jwksServer.URL + "/.well-known/jwks.json",
"headless_jwks": map[string]any{
"keys": []map[string]any{},
},
},
},
})
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
default:
http.NotFound(w, r)
}
})
mockKratos := new(MockKratosAdminService)
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "employee001").Return("kratos-identity-id", nil)
h := &AuthHandler{
IdpProvider: mockIdp,
KratosAdmin: mockKratos,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
},
}
app := newHeadlessPasswordLoginTestApp(h)
clientAssertion := mustHeadlessClientAssertion(
t,
privateKey,
"headless-login-client",
"http://example.com/api/v1/auth/headless/password/login",
)
body, _ := json.Marshal(map[string]string{
"client_id": "headless-login-client",
"client_assertion": clientAssertion,
"loginId": "employee001",
"password": "password",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 200, got %d, body: %s", resp.StatusCode, string(bodyBytes))
}
}
func TestHeadlessPasswordLogin_RefreshesJWKSWhenSignatureFailsForCachedKid(t *testing.T) {
mockIdp := new(MockIdentityProvider)
mockIdp.On("SignIn", "employee001", "password").Return(&domain.AuthInfo{
SessionToken: &domain.Token{JWT: "valid-jwt"},
Subject: "kratos-identity-id",
}, nil)
stalePrivateKey, staleJWKS := mustHeadlessRSAJWK(t)
staleRaw, err := json.Marshal(staleJWKS)
if err != nil {
t.Fatalf("failed to marshal stale jwks: %v", err)
}
freshPrivateKey, freshJWKS := mustHeadlessRSAJWK(t)
freshRaw, err := json.Marshal(freshJWKS)
if err != nil {
t.Fatalf("failed to marshal fresh jwks: %v", err)
}
fetchCount := 0
jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fetchCount++
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(freshRaw)
}))
defer jwksServer.Close()
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "headless-login-client",
TokenEndpointAuthMethod: "none",
Metadata: map[string]interface{}{
"status": "active",
"headless_login_enabled": true,
"headless_token_endpoint_auth_method": "private_key_jwt",
"headless_jwks_uri": jwksServer.URL + "/.well-known/jwks.json",
},
},
})
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
default:
http.NotFound(w, r)
}
})
mockKratos := new(MockKratosAdminService)
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "employee001").Return("kratos-identity-id", nil)
redisRepo := &testRedisRepo{values: map[string]string{}}
cacheService := service.NewHeadlessJWKSCacheService(redisRepo, jwksServer.Client())
now := time.Now()
expiresAt := now.Add(30 * time.Minute)
if err := cacheService.SaveState("headless-login-client", domain.HeadlessJWKSCacheState{
ClientID: "headless-login-client",
JWKSURI: jwksServer.URL + "/.well-known/jwks.json",
RawJWKS: string(staleRaw),
CachedKids: []string{"test-kid"},
CachedAt: &now,
LastCheckedAt: &now,
ExpiresAt: &expiresAt,
LastRefreshStatus: "success",
}); err != nil {
t.Fatalf("failed to save cached jwks state: %v", err)
}
h := &AuthHandler{
IdpProvider: mockIdp,
KratosAdmin: mockKratos,
RedisService: redisRepo,
HeadlessJWKS: cacheService,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
},
}
app := newHeadlessPasswordLoginTestApp(h)
clientAssertion := mustHeadlessClientAssertion(
t,
freshPrivateKey,
"headless-login-client",
"http://example.com/api/v1/auth/headless/password/login",
)
body, _ := json.Marshal(map[string]string{
"client_id": "headless-login-client",
"client_assertion": clientAssertion,
"loginId": "employee001",
"password": "password",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/password/login", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 200, got %d, body: %s", resp.StatusCode, string(bodyBytes))
}
if fetchCount != 1 {
t.Fatalf("expected exactly one jwks refresh fetch, got %d", fetchCount)
}
if stalePrivateKey == nil {
t.Fatalf("expected stale key to be generated")
}
}
func TestHeadlessPasswordLogin_MissingClientAssertionRejected(t *testing.T) {
mockIdp := new(MockIdentityProvider)
mockIdp.On("SignIn", "employee001", "password").Return(&domain.AuthInfo{
@@ -391,13 +771,12 @@ func TestHeadlessPasswordLogin_MissingClientAssertionRejected(t *testing.T) {
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "headless-login-client",
TokenEndpointAuthMethod: "private_key_jwt",
JWKS: map[string]any{
"keys": []map[string]any{},
},
TokenEndpointAuthMethod: "none",
Metadata: map[string]interface{}{
"status": "active",
"headless_login_enabled": true,
"status": "active",
"headless_login_enabled": true,
"headless_token_endpoint_auth_method": "private_key_jwt",
"headless_jwks_uri": "https://rp.example.com/.well-known/jwks.json",
},
},
})
@@ -453,6 +832,12 @@ func TestHeadlessPasswordLogin_InvalidClientAssertionRejected(t *testing.T) {
validKey, jwks := mustHeadlessRSAJWK(t)
invalidKey, _ := mustHeadlessRSAJWK(t)
_ = validKey
jwksBody, _ := json.Marshal(jwks)
jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(jwksBody)
}))
defer jwksServer.Close()
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
@@ -461,11 +846,12 @@ func TestHeadlessPasswordLogin_InvalidClientAssertionRejected(t *testing.T) {
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "headless-login-client",
TokenEndpointAuthMethod: "private_key_jwt",
JWKS: jwks,
TokenEndpointAuthMethod: "none",
Metadata: map[string]interface{}{
"status": "active",
"headless_login_enabled": true,
"status": "active",
"headless_login_enabled": true,
"headless_token_endpoint_auth_method": "private_key_jwt",
"headless_jwks_uri": jwksServer.URL + "/.well-known/jwks.json",
},
},
})
@@ -516,6 +902,42 @@ func TestHeadlessPasswordLogin_InvalidClientAssertionRejected(t *testing.T) {
}
}
func TestHeadlessPasswordLogin_AcceptsConfiguredClientAssertionAlgorithms(t *testing.T) {
algorithms := []jose.SignatureAlgorithm{
jose.RS256,
jose.RS384,
jose.RS512,
jose.PS256,
jose.PS384,
jose.PS512,
jose.ES256,
jose.ES384,
jose.ES512,
jose.EdDSA,
}
for _, algorithm := range algorithms {
t.Run(string(algorithm), func(t *testing.T) {
privateKey, jwks := mustHeadlessJWKForAlgorithm(t, algorithm)
clientAssertion := mustHeadlessClientAssertionWithAlgorithm(
t,
privateKey,
algorithm,
"headless-login-client",
"http://example.com/api/v1/auth/headless/password/login",
)
resp := runHeadlessPasswordLoginWithAssertion(t, jwks, clientAssertion)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 200 for %s, got %d, body: %s", algorithm, resp.StatusCode, string(bodyBytes))
}
})
}
}
func TestHeadlessPasswordLogin_HeadlessDisabledRejected(t *testing.T) {
mockIdp := new(MockIdentityProvider)

View File

@@ -22,16 +22,17 @@ import (
)
type DevHandler struct {
Hydra *service.HydraAdminService
Redis domain.RedisRepository
SecretRepo domain.ClientSecretRepository
AuditRepo domain.AuditRepository
KratosAdmin service.KratosAdminService
ConsentRepo repository.ClientConsentRepository
Keto service.KetoService
RPSvc service.RelyingPartyService
TenantSvc service.TenantService
Auth interface {
Hydra *service.HydraAdminService
Redis domain.RedisRepository
HeadlessJWKS *service.HeadlessJWKSCacheService
SecretRepo domain.ClientSecretRepository
AuditRepo domain.AuditRepository
KratosAdmin service.KratosAdminService
ConsentRepo repository.ClientConsentRepository
Keto service.KetoService
RPSvc service.RelyingPartyService
TenantSvc service.TenantService
Auth interface {
GetEnrichedProfile(c *fiber.Ctx) (*domain.UserProfileResponse, error)
}
}
@@ -54,16 +55,17 @@ func NewDevHandler(
}
return &DevHandler{
Hydra: service.NewHydraAdminService(),
Redis: redis,
SecretRepo: secretRepo,
AuditRepo: nil,
KratosAdmin: service.NewKratosAdminService(),
ConsentRepo: consentRepo,
Keto: keto,
RPSvc: rpSvc,
TenantSvc: tenantSvc,
Auth: authProvider,
Hydra: service.NewHydraAdminService(),
Redis: redis,
HeadlessJWKS: service.NewHeadlessJWKSCacheService(redis, nil),
SecretRepo: secretRepo,
AuditRepo: nil,
KratosAdmin: service.NewKratosAdminService(),
ConsentRepo: consentRepo,
Keto: keto,
RPSvc: rpSvc,
TenantSvc: tenantSvc,
Auth: authProvider,
}
}
@@ -102,8 +104,9 @@ type clientListResponse struct {
}
type clientDetailResponse struct {
Client clientSummary `json:"client"`
Endpoints clientEndpoints `json:"endpoints"`
Client clientSummary `json:"client"`
Endpoints clientEndpoints `json:"endpoints"`
HeadlessJWKSCache *domain.HeadlessJWKSCacheState `json:"headlessJwksCache,omitempty"`
}
type clientEndpoints struct {
@@ -697,8 +700,11 @@ func (h *DevHandler) GetClient(c *fiber.Ctx) error {
}
}
cacheState, _ := h.publicHeadlessJWKSCacheState(summary.ID)
return c.JSON(clientDetailResponse{
Client: summary,
Client: summary,
HeadlessJWKSCache: cacheState,
Endpoints: clientEndpoints{
Discovery: strings.TrimRight(h.Hydra.PublicURL, "/") + "/.well-known/openid-configuration",
Issuer: h.Hydra.PublicURL,
@@ -709,6 +715,32 @@ func (h *DevHandler) GetClient(c *fiber.Ctx) error {
})
}
func (h *DevHandler) publicHeadlessJWKSCacheState(clientID string) (*domain.HeadlessJWKSCacheState, error) {
if h.HeadlessJWKS == nil {
h.HeadlessJWKS = service.NewHeadlessJWKSCacheService(h.Redis, nil)
}
if h.HeadlessJWKS == nil {
return nil, nil
}
return h.HeadlessJWKS.PublicState(clientID)
}
func (h *DevHandler) syncHeadlessJWKSCache(ctx context.Context, client domain.HydraClient, reason string) {
if h.HeadlessJWKS == nil {
h.HeadlessJWKS = service.NewHeadlessJWKSCacheService(h.Redis, nil)
}
if h.HeadlessJWKS == nil {
return
}
if !client.IsHeadlessLoginEnabled() {
_ = h.HeadlessJWKS.DeleteState(client.ClientID)
return
}
if _, err := h.HeadlessJWKS.ForceRefresh(ctx, client, reason); err != nil {
slog.Warn("failed to refresh headless jwks cache after client save", "clientID", client.ClientID, "reason", reason, "error", err)
}
}
func (h *DevHandler) UpdateClientStatus(c *fiber.Ctx) error {
tenantID := h.injectTenantContextFromHeader(c)
clientID := c.Params("id")
@@ -790,8 +822,10 @@ func (h *DevHandler) UpdateClientStatus(c *fiber.Ctx) error {
}
updatedSummary := h.mapClientSummary(*updated)
cacheState, _ := h.publicHeadlessJWKSCacheState(updatedSummary.ID)
return c.JSON(clientDetailResponse{
Client: updatedSummary,
Client: updatedSummary,
HeadlessJWKSCache: cacheState,
Endpoints: clientEndpoints{
Discovery: strings.TrimRight(h.Hydra.PublicURL, "/") + "/.well-known/openid-configuration",
Issuer: h.Hydra.PublicURL,
@@ -863,6 +897,9 @@ func (h *DevHandler) CreateClient(c *fiber.Ctx) error {
if status != "active" && status != "inactive" {
return errorJSON(c, fiber.StatusBadRequest, "status must be active or inactive")
}
if requestIncludesInlineHeadlessJWKS(req) {
return errorJSON(c, fiber.StatusBadRequest, "headless login supports jwksUri only; inline jwks is not supported")
}
metadata := mergeMetadata(nil, req.Metadata)
if metadata == nil {
@@ -891,6 +928,9 @@ func (h *DevHandler) CreateClient(c *fiber.Ctx) error {
tokenAuthMethod = "client_secret_basic"
}
}
if err := validateHeadlessClientInput(clientType, valueOr(req.JwksUri, ""), req.Jwks, metadata); err != nil {
return errorJSON(c, fiber.StatusBadRequest, err.Error())
}
tokenAuthMethod, jwksURI, jwks, metadata := normalizeHeadlessClientConfig(
clientType,
tokenAuthMethod,
@@ -928,6 +968,7 @@ func (h *DevHandler) CreateClient(c *fiber.Ctx) error {
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
h.syncHeadlessJWKSCache(c.Context(), *created, "client_create")
// Store secret in metadata for later retrieval
if created.ClientSecret != "" {
@@ -945,8 +986,10 @@ func (h *DevHandler) CreateClient(c *fiber.Ctx) error {
h.setAuditDetailsExtra(c, map[string]any{"target_id": created.ClientID})
summary := h.mapClientSummary(*created)
cacheState, _ := h.publicHeadlessJWKSCacheState(summary.ID)
return c.Status(fiber.StatusCreated).JSON(clientDetailResponse{
Client: summary,
Client: summary,
HeadlessJWKSCache: cacheState,
Endpoints: clientEndpoints{
Discovery: strings.TrimRight(h.Hydra.PublicURL, "/") + "/.well-known/openid-configuration",
Issuer: h.Hydra.PublicURL,
@@ -1043,6 +1086,9 @@ func (h *DevHandler) UpdateClient(c *fiber.Ctx) error {
if req.RedirectURIs != nil && len(*req.RedirectURIs) == 0 {
return errorJSON(c, fiber.StatusBadRequest, "redirectUris cannot be empty")
}
if requestIncludesInlineHeadlessJWKS(req) {
return errorJSON(c, fiber.StatusBadRequest, "headless login supports jwksUri only; inline jwks is not supported")
}
metadata := mergeMetadata(current.Metadata, req.Metadata)
if status != "" {
@@ -1061,6 +1107,9 @@ func (h *DevHandler) UpdateClient(c *fiber.Ctx) error {
if req.Jwks == nil {
resolvedJWKS = current.JWKS
}
if err := validateHeadlessClientInput(resolvedClientType, resolvedJWKSURI, resolvedJWKS, metadata); err != nil {
return errorJSON(c, fiber.StatusBadRequest, err.Error())
}
resolvedTokenAuthMethod, resolvedJWKSURI, resolvedJWKS, metadata = normalizeHeadlessClientConfig(
resolvedClientType,
resolvedTokenAuthMethod,
@@ -1105,6 +1154,7 @@ func (h *DevHandler) UpdateClient(c *fiber.Ctx) error {
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
h.syncHeadlessJWKSCache(c.Context(), *updatedClient, "client_update")
if updatedClient.ClientSecret != "" {
if h.SecretRepo != nil {
@@ -1116,8 +1166,10 @@ func (h *DevHandler) UpdateClient(c *fiber.Ctx) error {
}
summary := h.mapClientSummary(*updatedClient)
cacheState, _ := h.publicHeadlessJWKSCacheState(summary.ID)
return c.JSON(clientDetailResponse{
Client: summary,
Client: summary,
HeadlessJWKSCache: cacheState,
Endpoints: clientEndpoints{
Discovery: strings.TrimRight(h.Hydra.PublicURL, "/") + "/.well-known/openid-configuration",
Issuer: h.Hydra.PublicURL,
@@ -1451,9 +1503,11 @@ func (h *DevHandler) RotateClientSecret(c *fiber.Ctx) error {
// Return the new secret
updatedSummary := h.mapClientSummary(*updated)
updatedSummary.ClientSecret = newSecret
cacheState, _ := h.publicHeadlessJWKSCacheState(updatedSummary.ID)
return c.JSON(clientDetailResponse{
Client: updatedSummary,
Client: updatedSummary,
HeadlessJWKSCache: cacheState,
Endpoints: clientEndpoints{
Discovery: strings.TrimRight(h.Hydra.PublicURL, "/") + "/.well-known/openid-configuration",
Issuer: h.Hydra.PublicURL,
@@ -1464,6 +1518,134 @@ func (h *DevHandler) RotateClientSecret(c *fiber.Ctx) error {
})
}
func (h *DevHandler) RefreshHeadlessJWKSCache(c *fiber.Ctx) error {
tenantID := h.injectTenantContextFromHeader(c)
clientID := strings.TrimSpace(c.Params("id"))
if clientID == "" {
return errorJSON(c, fiber.StatusBadRequest, "client id is required")
}
current, err := h.Hydra.GetClient(c.Context(), clientID)
if err != nil {
if errors.Is(err, service.ErrHydraNotFound) {
return errorJSON(c, fiber.StatusNotFound, "client not found")
}
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
if isHiddenSystemClient(*current) {
return errorJSON(c, fiber.StatusForbidden, "forbidden: protected system client")
}
summary := h.mapClientSummary(*current)
profile := h.getCurrentProfile(c)
if profile == nil {
return errorJSON(c, fiber.StatusUnauthorized, "unauthorized: authentication required")
}
role := normalizeUserRole(profile.Role)
if !isDevConsoleRoleAllowed(role) {
return errorJSON(c, fiber.StatusForbidden, "forbidden")
}
isSuperAdmin := role == domain.RoleSuperAdmin
userTenantID := tenantIDFromProfile(profile)
if !isSuperAdmin {
clientTenantID := resolveClientTenantID(summary)
if clientTenantID != userTenantID {
return errorJSON(c, fiber.StatusForbidden, "forbidden: access denied to client in another tenant")
}
}
if !isRPAdminClientAllowed(profile, summary.ID) {
return errorJSON(c, fiber.StatusForbidden, "forbidden: rp_admin scope does not include this client")
}
if !current.IsHeadlessLoginEnabled() {
return errorJSON(c, fiber.StatusBadRequest, "headless login is not enabled for this client")
}
if h.HeadlessJWKS == nil {
h.HeadlessJWKS = service.NewHeadlessJWKSCacheService(h.Redis, nil)
}
if h.HeadlessJWKS == nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "headless jwks cache service is unavailable")
}
if _, err := h.HeadlessJWKS.ForceRefresh(c.Context(), *current, "manual_refresh"); err != nil {
return errorJSON(c, fiber.StatusBadRequest, headlessClientAssertionErrorMessage(err))
}
h.setAuditDetailsExtra(c, map[string]any{
"action": "REFRESH_HEADLESS_JWKS_CACHE",
"target_id": clientID,
"tenant_id": tenantID,
})
cacheState, _ := h.publicHeadlessJWKSCacheState(clientID)
return c.JSON(clientDetailResponse{
Client: summary,
HeadlessJWKSCache: cacheState,
Endpoints: clientEndpoints{
Discovery: strings.TrimRight(h.Hydra.PublicURL, "/") + "/.well-known/openid-configuration",
Issuer: h.Hydra.PublicURL,
Authorization: strings.TrimRight(h.Hydra.PublicURL, "/") + "/oauth2/auth",
Token: strings.TrimRight(h.Hydra.PublicURL, "/") + "/oauth2/token",
UserInfo: strings.TrimRight(h.Hydra.PublicURL, "/") + "/userinfo",
},
})
}
func (h *DevHandler) RevokeHeadlessJWKSCache(c *fiber.Ctx) error {
tenantID := h.injectTenantContextFromHeader(c)
clientID := strings.TrimSpace(c.Params("id"))
if clientID == "" {
return errorJSON(c, fiber.StatusBadRequest, "client id is required")
}
current, err := h.Hydra.GetClient(c.Context(), clientID)
if err != nil {
if errors.Is(err, service.ErrHydraNotFound) {
return errorJSON(c, fiber.StatusNotFound, "client not found")
}
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
if isHiddenSystemClient(*current) {
return errorJSON(c, fiber.StatusForbidden, "forbidden: protected system client")
}
summary := h.mapClientSummary(*current)
profile := h.getCurrentProfile(c)
if profile == nil {
return errorJSON(c, fiber.StatusUnauthorized, "unauthorized: authentication required")
}
role := normalizeUserRole(profile.Role)
if !isDevConsoleRoleAllowed(role) {
return errorJSON(c, fiber.StatusForbidden, "forbidden")
}
isSuperAdmin := role == domain.RoleSuperAdmin
userTenantID := tenantIDFromProfile(profile)
if !isSuperAdmin {
clientTenantID := resolveClientTenantID(summary)
if clientTenantID != userTenantID {
return errorJSON(c, fiber.StatusForbidden, "forbidden: access denied to client in another tenant")
}
}
if !isRPAdminClientAllowed(profile, summary.ID) {
return errorJSON(c, fiber.StatusForbidden, "forbidden: rp_admin scope does not include this client")
}
if h.HeadlessJWKS == nil {
h.HeadlessJWKS = service.NewHeadlessJWKSCacheService(h.Redis, nil)
}
if h.HeadlessJWKS != nil {
_ = h.HeadlessJWKS.DeleteState(clientID)
}
h.setAuditDetailsExtra(c, map[string]any{
"action": "REVOKE_HEADLESS_JWKS_CACHE",
"target_id": clientID,
"tenant_id": tenantID,
})
return c.SendStatus(fiber.StatusNoContent)
}
func (h *DevHandler) ListAuditLogs(c *fiber.Ctx) error {
if h.AuditRepo == nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "Audit service unavailable")
@@ -1739,9 +1921,9 @@ func normalizeHeadlessClientConfig(
}
metadata[domain.MetadataHeadlessTokenEndpointAuthMethod] = headlessTokenAuthMethod
headlessJWKSURI := readMetadataStringValue(metadata, domain.MetadataHeadlessJWKSURI)
if headlessJWKSURI == "" && strings.TrimSpace(jwksURI) != "" {
headlessJWKSURI = strings.TrimSpace(jwksURI)
headlessJWKSURI := strings.TrimSpace(jwksURI)
if headlessJWKSURI == "" {
headlessJWKSURI = readMetadataStringValue(metadata, domain.MetadataHeadlessJWKSURI)
}
if headlessJWKSURI != "" {
metadata[domain.MetadataHeadlessJWKSURI] = headlessJWKSURI
@@ -1749,12 +1931,7 @@ func normalizeHeadlessClientConfig(
delete(metadata, domain.MetadataHeadlessJWKSURI)
}
if _, ok := metadata[domain.MetadataHeadlessJWKS]; !ok && jwks != nil {
metadata[domain.MetadataHeadlessJWKS] = jwks
}
if metadata[domain.MetadataHeadlessJWKS] == nil {
delete(metadata, domain.MetadataHeadlessJWKS)
}
delete(metadata, domain.MetadataHeadlessJWKS)
return "none", "", nil, metadata
}
@@ -1765,6 +1942,36 @@ func normalizeHeadlessClientConfig(
return tokenAuthMethod, jwksURI, jwks, metadata
}
func validateHeadlessClientInput(clientType string, jwksURI string, jwks interface{}, metadata map[string]interface{}) error {
if clientType != "pkce" || !readMetadataBoolValue(metadata, domain.MetadataHeadlessLoginEnabled) {
return nil
}
if jwks != nil {
return fmt.Errorf("headless login supports jwksUri only; inline jwks is not supported")
}
resolvedURI := strings.TrimSpace(jwksURI)
if resolvedURI == "" {
resolvedURI = readMetadataStringValue(metadata, domain.MetadataHeadlessJWKSURI)
}
if resolvedURI == "" {
return fmt.Errorf("headless login requires jwksUri; inline jwks is not supported")
}
return nil
}
func requestIncludesInlineHeadlessJWKS(req clientUpsertRequest) bool {
if req.Jwks != nil {
return true
}
if req.Metadata == nil {
return false
}
value, ok := (*req.Metadata)[domain.MetadataHeadlessJWKS]
return ok && value != nil
}
func defaultClientScopes() []string {
return []string{"openid", "profile", "email"}
}

View File

@@ -10,6 +10,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
@@ -89,6 +90,32 @@ func (m *devEnhancedMockAuditRepo) CountActiveSessionsSince(ctx context.Context,
return m.countSessions, nil
}
func devTestJWKSFirstKeyString(t *testing.T, jwks map[string]any, field string) string {
t.Helper()
keys, ok := jwks["keys"].([]any)
if !ok || len(keys) == 0 {
t.Fatalf("expected jwks keys")
}
key, ok := keys[0].(map[string]any)
if !ok {
t.Fatalf("expected jwks key object")
}
value, ok := key[field].(string)
if !ok {
t.Fatalf("expected jwks field %s", field)
}
return value
}
func devTestPreviewValue(value string) string {
value = strings.TrimSpace(value)
if len(value) <= 24 {
return value
}
return value[:12] + "..." + value[len(value)-12:]
}
// --- Tests ---
func TestListClients_Success(t *testing.T) {
@@ -652,6 +679,64 @@ func TestCreateClient_HeadlessLoginPayloadMapping(t *testing.T) {
})
app.Post("/api/v1/dev/clients", h.CreateClient)
body, _ := json.Marshal(map[string]any{
"name": "Headless Login App",
"type": "pkce",
"redirectUris": []string{"https://rp.example.com/callback"},
"scopes": []string{"openid", "profile"},
"tokenEndpointAuthMethod": "private_key_jwt",
"jwksUri": "https://rp.example.com/.well-known/jwks.json",
"metadata": map[string]any{
"headless_login_enabled": true,
"request_object_signing_alg": "RS256",
},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/dev/clients", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusCreated, resp.StatusCode)
assert.Equal(t, "none", captured.TokenEndpointAuthMethod)
assert.Nil(t, captured.JWKS)
assert.Equal(t, "private_key_jwt", captured.Metadata["headless_token_endpoint_auth_method"])
assert.Equal(t, "https://rp.example.com/.well-known/jwks.json", captured.Metadata["headless_jwks_uri"])
assert.True(t, captured.IsHeadlessLoginEnabled())
assert.Equal(t, true, captured.Metadata["headless_login_enabled"])
assert.Equal(t, "RS256", captured.Metadata["request_object_signing_alg"])
}
func TestCreateClient_HeadlessLoginRejectsInlineJWKS(t *testing.T) {
var hydraCalled bool
h := &DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
PublicURL: "http://hydra.public",
HTTPClient: &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
hydraCalled = true
return httpJSONAny(r, http.StatusCreated, map[string]any{
"client_id": "client-headless-login",
"client_name": "Headless Login App",
"redirect_uris": []string{"https://rp.example.com/callback"},
"grant_types": []string{"authorization_code", "refresh_token"},
"response_types": []string{"code"},
"scope": "openid profile",
"token_endpoint_auth_method": "none",
"metadata": map[string]any{
"headless_login_enabled": true,
},
}), nil
})},
},
Keto: new(devMockKetoService),
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "test-user", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Post("/api/v1/dev/clients", h.CreateClient)
body, _ := json.Marshal(map[string]any{
"name": "Headless Login App",
"type": "pkce",
@@ -675,14 +760,12 @@ func TestCreateClient_HeadlessLoginPayloadMapping(t *testing.T) {
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusCreated, resp.StatusCode)
assert.Equal(t, "none", captured.TokenEndpointAuthMethod)
assert.Nil(t, captured.JWKS)
assert.Equal(t, "private_key_jwt", captured.Metadata["headless_token_endpoint_auth_method"])
assert.NotNil(t, captured.Metadata["headless_jwks"])
assert.True(t, captured.IsHeadlessLoginEnabled())
assert.Equal(t, true, captured.Metadata["headless_login_enabled"])
assert.Equal(t, "RS256", captured.Metadata["request_object_signing_alg"])
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
defer resp.Body.Close()
bodyBytes, _ := io.ReadAll(resp.Body)
assert.Contains(t, string(bodyBytes), "headless login supports jwksUri only")
assert.False(t, hydraCalled)
}
func TestUpdateClient_HeadlessLoginPayloadMapping(t *testing.T) {
@@ -699,7 +782,10 @@ func TestUpdateClient_HeadlessLoginPayloadMapping(t *testing.T) {
"scope": "openid profile",
"token_endpoint_auth_method": "none",
"metadata": map[string]any{
"status": "active",
"status": "active",
"headless_jwks": map[string]any{"keys": []map[string]any{}},
"headless_jwks_uri": "https://stale.example.com/old.json",
"headless_login_enabled": true,
},
}), nil
}
@@ -759,10 +845,128 @@ func TestUpdateClient_HeadlessLoginPayloadMapping(t *testing.T) {
assert.Equal(t, "", captured.JWKSUri)
assert.Equal(t, "private_key_jwt", captured.Metadata["headless_token_endpoint_auth_method"])
assert.Equal(t, "https://rp.example.com/.well-known/jwks.json", captured.Metadata["headless_jwks_uri"])
_, hasInlineJWKS := captured.Metadata["headless_jwks"]
assert.False(t, hasInlineJWKS)
assert.True(t, captured.IsHeadlessLoginEnabled())
assert.Equal(t, true, captured.Metadata["headless_login_enabled"])
}
func TestRefreshHeadlessJWKSCache_ReturnsUpdatedCacheState(t *testing.T) {
privateKey, jwks := mustHeadlessRSAJWK(t)
_ = privateKey
jwksBody, _ := json.Marshal(jwks)
expectedNPreview := devTestPreviewValue(devTestJWKSFirstKeyString(t, jwks, "n"))
redisRepo := &devMockRedisRepo{data: map[string]string{}}
h := &DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
PublicURL: "http://hydra.public",
HTTPClient: &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.Method == http.MethodGet && r.URL.Path == "/clients/client-headless-login" {
return httpJSONAny(r, http.StatusOK, domain.HydraClient{
ClientID: "client-headless-login",
Metadata: map[string]any{
"status": "active",
"headless_login_enabled": true,
"headless_token_endpoint_auth_method": "private_key_jwt",
"headless_jwks_uri": "https://rp.example.com/.well-known/jwks.json",
},
}), nil
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
})},
},
Redis: redisRepo,
HeadlessJWKS: service.NewHeadlessJWKSCacheService(redisRepo, &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
assert.Equal(t, "https://rp.example.com/.well-known/jwks.json", r.URL.String())
var payload map[string]any
_ = json.Unmarshal(jwksBody, &payload)
return httpJSONAny(r, http.StatusOK, payload), nil
})}),
Keto: new(devMockKetoService),
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "test-user", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Post("/api/v1/dev/clients/:id/headless-jwks/refresh", h.RefreshHeadlessJWKSCache)
req := httptest.NewRequest(http.MethodPost, "/api/v1/dev/clients/client-headless-login/headless-jwks/refresh", nil)
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var got clientDetailResponse
err := json.NewDecoder(resp.Body).Decode(&got)
assert.NoError(t, err)
if assert.NotNil(t, got.HeadlessJWKSCache) {
assert.Equal(t, "success", got.HeadlessJWKSCache.LastRefreshStatus)
assert.Equal(t, []string{"test-kid"}, got.HeadlessJWKSCache.CachedKids)
if assert.Len(t, got.HeadlessJWKSCache.ParsedKeys, 1) {
assert.Equal(t, "test-kid", got.HeadlessJWKSCache.ParsedKeys[0].Kid)
assert.Equal(t, "RSA", got.HeadlessJWKSCache.ParsedKeys[0].Kty)
assert.Equal(t, "sig", got.HeadlessJWKSCache.ParsedKeys[0].Use)
assert.Equal(t, "RS256", got.HeadlessJWKSCache.ParsedKeys[0].Alg)
assert.Equal(t, expectedNPreview, got.HeadlessJWKSCache.ParsedKeys[0].NPreview)
}
}
}
func TestRevokeHeadlessJWKSCache_DeletesCachedState(t *testing.T) {
redisRepo := &devMockRedisRepo{data: map[string]string{}}
cacheService := service.NewHeadlessJWKSCacheService(redisRepo, nil)
now := time.Now()
expiresAt := now.Add(30 * time.Minute)
err := cacheService.SaveState("client-headless-login", domain.HeadlessJWKSCacheState{
ClientID: "client-headless-login",
JWKSURI: "https://rp.example.com/.well-known/jwks.json",
CachedAt: &now,
ExpiresAt: &expiresAt,
LastRefreshStatus: "success",
ConsecutiveFailures: 0,
RawJWKS: `{"keys":[{"kid":"cached-key","kty":"RSA","n":"AQIDBAUGBw","e":"AQAB"}]}`,
})
assert.NoError(t, err)
h := &DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
PublicURL: "http://hydra.public",
HTTPClient: &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.Method == http.MethodGet && r.URL.Path == "/clients/client-headless-login" {
return httpJSONAny(r, http.StatusOK, domain.HydraClient{
ClientID: "client-headless-login",
Metadata: map[string]any{
"status": "active",
"headless_login_enabled": true,
},
}), nil
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
})},
},
Redis: redisRepo,
HeadlessJWKS: cacheService,
Keto: new(devMockKetoService),
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "test-user", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Delete("/api/v1/dev/clients/:id/headless-jwks/cache", h.RevokeHeadlessJWKSCache)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/dev/clients/client-headless-login/headless-jwks/cache", nil)
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusNoContent, resp.StatusCode)
stored, err := cacheService.GetState("client-headless-login")
assert.Error(t, err)
assert.Nil(t, stored)
}
func TestListAuditLogs_TenantMemberForbidden(t *testing.T) {
h := &DevHandler{
Hydra: &service.HydraAdminService{AdminURL: "http://hydra.test"},

View File

@@ -0,0 +1,505 @@
package service
import (
"baron-sso-backend/internal/domain"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"strconv"
"strings"
"time"
"github.com/go-jose/go-jose/v4"
)
const (
headlessJWKSCacheKeyPrefix = "headless_jwks_cache:"
)
type HeadlessJWKSCacheService struct {
Redis domain.RedisRepository
HTTPClient *http.Client
TTL time.Duration
PrefetchWindow time.Duration
RequestTimeout time.Duration
}
type headlessJWKSCacheStateStore struct {
ClientID string `json:"clientId"`
JWKSURI string `json:"jwksUri"`
CachedAt *time.Time `json:"cachedAt,omitempty"`
ExpiresAt *time.Time `json:"expiresAt,omitempty"`
LastCheckedAt *time.Time `json:"lastCheckedAt,omitempty"`
LastSuccessfulVerificationAt *time.Time `json:"lastSuccessfulVerificationAt,omitempty"`
LastRefreshStatus string `json:"lastRefreshStatus,omitempty"`
LastError string `json:"lastError,omitempty"`
ConsecutiveFailures int `json:"consecutiveFailures,omitempty"`
CachedKids []string `json:"cachedKids,omitempty"`
ETag string `json:"etag,omitempty"`
LastModified string `json:"lastModified,omitempty"`
RawJWKS string `json:"rawJwks,omitempty"`
}
type HeadlessJWKSCacheWorker struct {
Hydra *HydraAdminService
Cache *HeadlessJWKSCacheService
Interval time.Duration
PageSize int
}
func NewHeadlessJWKSCacheService(redis domain.RedisRepository, httpClient *http.Client) *HeadlessJWKSCacheService {
ttlSeconds, _ := strconv.Atoi(strings.TrimSpace(getenv("HEADLESS_JWKS_CACHE_TTL_SECONDS", "1800")))
if ttlSeconds <= 0 {
ttlSeconds = 1800
}
prefetchSeconds, _ := strconv.Atoi(strings.TrimSpace(getenv("HEADLESS_JWKS_PREFETCH_WINDOW_SECONDS", "600")))
if prefetchSeconds <= 0 {
prefetchSeconds = 600
}
timeoutSeconds, _ := strconv.Atoi(strings.TrimSpace(getenv("HEADLESS_JWKS_FETCH_TIMEOUT_SECONDS", "5")))
if timeoutSeconds <= 0 {
timeoutSeconds = 5
}
return &HeadlessJWKSCacheService{
Redis: redis,
HTTPClient: httpClient,
TTL: time.Duration(ttlSeconds) * time.Second,
PrefetchWindow: time.Duration(prefetchSeconds) * time.Second,
RequestTimeout: time.Duration(timeoutSeconds) * time.Second,
}
}
func NewHeadlessJWKSCacheWorker(hydra *HydraAdminService, cache *HeadlessJWKSCacheService) *HeadlessJWKSCacheWorker {
intervalSeconds, _ := strconv.Atoi(strings.TrimSpace(getenv("HEADLESS_JWKS_REFRESH_INTERVAL_SECONDS", "600")))
if intervalSeconds <= 0 {
intervalSeconds = 600
}
return &HeadlessJWKSCacheWorker{
Hydra: hydra,
Cache: cache,
Interval: time.Duration(intervalSeconds) * time.Second,
PageSize: 100,
}
}
func (s *HeadlessJWKSCacheService) httpClient() *http.Client {
if s.HTTPClient != nil {
return s.HTTPClient
}
timeout := s.RequestTimeout
if timeout <= 0 {
timeout = 5 * time.Second
}
return &http.Client{Timeout: timeout}
}
func (s *HeadlessJWKSCacheService) cacheKey(clientID string) string {
return headlessJWKSCacheKeyPrefix + strings.TrimSpace(clientID)
}
func (s *HeadlessJWKSCacheService) SaveState(clientID string, state domain.HeadlessJWKSCacheState) error {
if s == nil || s.Redis == nil || strings.TrimSpace(clientID) == "" {
return nil
}
payload, err := json.Marshal(headlessJWKSCacheStateStore{
ClientID: state.ClientID,
JWKSURI: state.JWKSURI,
CachedAt: state.CachedAt,
ExpiresAt: state.ExpiresAt,
LastCheckedAt: state.LastCheckedAt,
LastSuccessfulVerificationAt: state.LastSuccessfulVerificationAt,
LastRefreshStatus: state.LastRefreshStatus,
LastError: state.LastError,
ConsecutiveFailures: state.ConsecutiveFailures,
CachedKids: state.CachedKids,
ETag: state.ETag,
LastModified: state.LastModified,
RawJWKS: state.RawJWKS,
})
if err != nil {
return err
}
return s.Redis.Set(s.cacheKey(clientID), string(payload), 0)
}
func (s *HeadlessJWKSCacheService) GetState(clientID string) (*domain.HeadlessJWKSCacheState, error) {
if s == nil || s.Redis == nil || strings.TrimSpace(clientID) == "" {
return nil, nil
}
raw, err := s.Redis.Get(s.cacheKey(clientID))
if err != nil || strings.TrimSpace(raw) == "" {
return nil, err
}
var stored headlessJWKSCacheStateStore
if err := json.Unmarshal([]byte(raw), &stored); err != nil {
return nil, err
}
return &domain.HeadlessJWKSCacheState{
ClientID: stored.ClientID,
JWKSURI: stored.JWKSURI,
CachedAt: stored.CachedAt,
ExpiresAt: stored.ExpiresAt,
LastCheckedAt: stored.LastCheckedAt,
LastSuccessfulVerificationAt: stored.LastSuccessfulVerificationAt,
LastRefreshStatus: stored.LastRefreshStatus,
LastError: stored.LastError,
ConsecutiveFailures: stored.ConsecutiveFailures,
CachedKids: stored.CachedKids,
ETag: stored.ETag,
LastModified: stored.LastModified,
RawJWKS: stored.RawJWKS,
}, nil
}
func (s *HeadlessJWKSCacheService) DeleteState(clientID string) error {
if s == nil || s.Redis == nil {
return nil
}
return s.Redis.Delete(s.cacheKey(clientID))
}
func (s *HeadlessJWKSCacheService) PublicState(clientID string) (*domain.HeadlessJWKSCacheState, error) {
state, err := s.GetState(clientID)
if err != nil || state == nil {
return state, err
}
state.ParsedKeys = summarizeHeadlessJWKS(state.RawJWKS)
state.RawJWKS = ""
return state, nil
}
func (s *HeadlessJWKSCacheService) MarkVerificationSuccess(clientID string) error {
state, err := s.GetState(clientID)
if err != nil || state == nil {
return err
}
now := time.Now()
state.LastSuccessfulVerificationAt = &now
return s.SaveState(clientID, *state)
}
func (s *HeadlessJWKSCacheService) ShouldPrefetch(state *domain.HeadlessJWKSCacheState, now time.Time) bool {
if state == nil {
return true
}
if strings.TrimSpace(state.RawJWKS) == "" {
return true
}
if state.ExpiresAt == nil {
return true
}
return !state.ExpiresAt.After(now.Add(s.PrefetchWindow))
}
func (s *HeadlessJWKSCacheService) EnsureFreshKeySet(ctx context.Context, client domain.HydraClient, expectedKid string) (*jose.JSONWebKeySet, *domain.HeadlessJWKSCacheState, bool, error) {
if s == nil {
return nil, nil, false, fmt.Errorf("headless jwks cache service is not configured")
}
jwksURI := strings.TrimSpace(client.HeadlessJWKSURI())
if jwksURI == "" {
return nil, nil, false, fmt.Errorf("headless login requires jwksUri; inline jwks is not supported")
}
state, err := s.GetState(client.ClientID)
if err != nil {
slog.Warn("failed to load headless jwks cache state", "clientID", client.ClientID, "error", err)
}
now := time.Now()
switch {
case state == nil:
return s.refreshClient(ctx, client, nil, "cache_miss")
case strings.TrimSpace(state.JWKSURI) != jwksURI:
return s.refreshClient(ctx, client, state, "config_changed")
case strings.TrimSpace(state.RawJWKS) == "":
return s.refreshClient(ctx, client, state, "cache_empty")
case state.ExpiresAt == nil || !state.ExpiresAt.After(now):
return s.refreshClient(ctx, client, state, "ttl_expired")
case expectedKid != "" && !containsString(state.CachedKids, expectedKid):
return s.refreshClient(ctx, client, state, "kid_missing")
default:
keySet, err := decodeHeadlessJWKS(state.RawJWKS)
if err != nil {
return s.refreshClient(ctx, client, state, "cache_corrupt")
}
return keySet, state, false, nil
}
}
func (s *HeadlessJWKSCacheService) ForceRefresh(ctx context.Context, client domain.HydraClient, reason string) (*domain.HeadlessJWKSCacheState, error) {
_, state, err := s.ForceRefreshKeySet(ctx, client, reason)
return state, err
}
func (s *HeadlessJWKSCacheService) ForceRefreshKeySet(ctx context.Context, client domain.HydraClient, reason string) (*jose.JSONWebKeySet, *domain.HeadlessJWKSCacheState, error) {
previous, err := s.GetState(client.ClientID)
if err != nil {
slog.Warn("failed to load headless jwks cache state before force refresh", "clientID", client.ClientID, "error", err)
}
keySet, state, _, err := s.refreshClient(ctx, client, previous, reason)
return keySet, state, err
}
func (s *HeadlessJWKSCacheService) refreshClient(ctx context.Context, client domain.HydraClient, previous *domain.HeadlessJWKSCacheState, reason string) (*jose.JSONWebKeySet, *domain.HeadlessJWKSCacheState, bool, error) {
jwksURI := strings.TrimSpace(client.HeadlessJWKSURI())
if jwksURI == "" {
return nil, nil, false, fmt.Errorf("headless login requires jwksUri; inline jwks is not supported")
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, jwksURI, nil)
if err != nil {
return nil, s.persistRefreshFailure(client, previous, fmt.Errorf("failed to build jwks request: %w", err)), false, err
}
if previous != nil {
if etag := strings.TrimSpace(previous.ETag); etag != "" {
req.Header.Set("If-None-Match", etag)
}
if lastModified := strings.TrimSpace(previous.LastModified); lastModified != "" {
req.Header.Set("If-Modified-Since", lastModified)
}
}
resp, err := s.httpClient().Do(req)
if err != nil {
return nil, s.persistRefreshFailure(client, previous, fmt.Errorf("failed to fetch jwksUri: %w", err)), false, err
}
defer resp.Body.Close()
now := time.Now()
if resp.StatusCode == http.StatusNotModified && previous != nil && strings.TrimSpace(previous.RawJWKS) != "" {
updated := *previous
updated.JWKSURI = jwksURI
updated.LastCheckedAt = &now
updated.ExpiresAt = ptrTime(now.Add(s.TTL))
updated.LastRefreshStatus = "success"
updated.LastError = ""
updated.ConsecutiveFailures = 0
_ = s.SaveState(client.ClientID, updated)
keySet, decodeErr := decodeHeadlessJWKS(updated.RawJWKS)
return keySet, &updated, true, decodeErr
}
if resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
err = fmt.Errorf("failed to fetch jwksUri status=%d body=%s", resp.StatusCode, string(body))
return nil, s.persistRefreshFailure(client, previous, err), false, err
}
body, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024))
if err != nil {
return nil, s.persistRefreshFailure(client, previous, fmt.Errorf("failed to read jwks response: %w", err)), false, err
}
keySet, err := decodeHeadlessJWKS(string(body))
if err != nil {
return nil, s.persistRefreshFailure(client, previous, err), false, err
}
state := domain.HeadlessJWKSCacheState{
ClientID: client.ClientID,
JWKSURI: jwksURI,
CachedAt: &now,
ExpiresAt: ptrTime(now.Add(s.TTL)),
LastCheckedAt: &now,
LastSuccessfulVerificationAt: previousLastVerification(previous),
LastRefreshStatus: "success",
LastError: "",
ConsecutiveFailures: 0,
CachedKids: extractHeadlessKids(keySet),
ETag: strings.TrimSpace(resp.Header.Get("ETag")),
LastModified: strings.TrimSpace(resp.Header.Get("Last-Modified")),
RawJWKS: string(body),
}
if err := s.SaveState(client.ClientID, state); err != nil {
return nil, &state, false, err
}
slog.Info("headless jwks cache refreshed", "clientID", client.ClientID, "reason", reason, "keyCount", len(keySet.Keys))
return keySet, &state, true, nil
}
func (s *HeadlessJWKSCacheService) persistRefreshFailure(client domain.HydraClient, previous *domain.HeadlessJWKSCacheState, refreshErr error) *domain.HeadlessJWKSCacheState {
now := time.Now()
state := domain.HeadlessJWKSCacheState{
ClientID: client.ClientID,
JWKSURI: strings.TrimSpace(client.HeadlessJWKSURI()),
LastCheckedAt: &now,
LastRefreshStatus: "failure",
LastError: refreshErr.Error(),
ConsecutiveFailures: 1,
}
if previous != nil {
state.CachedAt = previous.CachedAt
state.ExpiresAt = previous.ExpiresAt
state.LastSuccessfulVerificationAt = previous.LastSuccessfulVerificationAt
state.CachedKids = previous.CachedKids
state.ETag = previous.ETag
state.LastModified = previous.LastModified
state.RawJWKS = previous.RawJWKS
state.ConsecutiveFailures = previous.ConsecutiveFailures + 1
}
_ = s.SaveState(client.ClientID, state)
return &state
}
func decodeHeadlessJWKS(raw string) (*jose.JSONWebKeySet, error) {
var keySet jose.JSONWebKeySet
if err := json.Unmarshal([]byte(raw), &keySet); err != nil {
return nil, fmt.Errorf("failed to decode jwks from jwksUri: %w", err)
}
if len(keySet.Keys) == 0 {
return nil, fmt.Errorf("configured jwksUri returned no keys")
}
return &keySet, nil
}
type headlessJWKSPreviewDocument struct {
Keys []headlessJWKSPreviewKey `json:"keys"`
}
type headlessJWKSPreviewKey struct {
Kid string `json:"kid"`
Kty string `json:"kty"`
Use string `json:"use"`
Alg string `json:"alg"`
N string `json:"n"`
}
func summarizeHeadlessJWKS(raw string) []domain.HeadlessJWKSParsedKey {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil
}
var document headlessJWKSPreviewDocument
if err := json.Unmarshal([]byte(raw), &document); err != nil {
return nil
}
parsedKeys := make([]domain.HeadlessJWKSParsedKey, 0, len(document.Keys))
for _, key := range document.Keys {
parsedKeys = append(parsedKeys, domain.HeadlessJWKSParsedKey{
Kid: strings.TrimSpace(key.Kid),
Kty: strings.TrimSpace(key.Kty),
Use: strings.TrimSpace(key.Use),
Alg: strings.TrimSpace(key.Alg),
NPreview: previewHeadlessJWKValue(key.N),
})
}
return parsedKeys
}
func previewHeadlessJWKValue(value string) string {
value = strings.TrimSpace(value)
if len(value) <= 24 {
return value
}
return value[:12] + "..." + value[len(value)-12:]
}
func extractHeadlessKids(keySet *jose.JSONWebKeySet) []string {
if keySet == nil {
return nil
}
kids := make([]string, 0, len(keySet.Keys))
for _, key := range keySet.Keys {
if kid := strings.TrimSpace(key.KeyID); kid != "" {
kids = append(kids, kid)
}
}
return kids
}
func containsString(values []string, needle string) bool {
needle = strings.TrimSpace(needle)
if needle == "" {
return false
}
for _, value := range values {
if strings.TrimSpace(value) == needle {
return true
}
}
return false
}
func previousLastVerification(previous *domain.HeadlessJWKSCacheState) *time.Time {
if previous == nil {
return nil
}
return previous.LastSuccessfulVerificationAt
}
func ptrTime(value time.Time) *time.Time {
return &value
}
func (w *HeadlessJWKSCacheWorker) Start(ctx context.Context) {
if w == nil || w.Hydra == nil || w.Cache == nil {
return
}
w.runOnce(ctx)
ticker := time.NewTicker(w.Interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
w.runOnce(ctx)
}
}
}
func (w *HeadlessJWKSCacheWorker) runOnce(ctx context.Context) {
offset := 0
pageSize := w.PageSize
if pageSize <= 0 {
pageSize = 100
}
now := time.Now()
for {
clients, err := w.Hydra.ListClients(ctx, pageSize, offset)
if err != nil {
slog.Warn("headless jwks worker failed to list clients", "error", err)
return
}
if len(clients) == 0 {
return
}
for _, client := range clients {
if !client.IsHeadlessLoginEnabled() {
continue
}
state, err := w.Cache.GetState(client.ClientID)
if err != nil {
slog.Warn("headless jwks worker failed to load cache state", "clientID", client.ClientID, "error", err)
continue
}
if !w.Cache.ShouldPrefetch(state, now) {
continue
}
if _, err := w.Cache.ForceRefresh(ctx, client, "cron_prefetch"); err != nil {
slog.Warn("headless jwks worker refresh failed", "clientID", client.ClientID, "error", err)
}
}
if len(clients) < pageSize {
return
}
offset += len(clients)
}
}

View File

@@ -0,0 +1,164 @@
package service
import (
"baron-sso-backend/internal/domain"
"context"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"net/http"
"testing"
"time"
"github.com/go-jose/go-jose/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type headlessJWKSCacheTestRedis struct {
data map[string]string
}
func (m *headlessJWKSCacheTestRedis) Set(key string, value string, expiration time.Duration) error {
if m.data == nil {
m.data = map[string]string{}
}
m.data[key] = value
return nil
}
func (m *headlessJWKSCacheTestRedis) Get(key string) (string, error) {
if m.data == nil {
return "", nil
}
return m.data[key], nil
}
func (m *headlessJWKSCacheTestRedis) Delete(key string) error {
if m.data != nil {
delete(m.data, key)
}
return nil
}
func (m *headlessJWKSCacheTestRedis) StoreVerificationCode(phone, code string) error {
return nil
}
func (m *headlessJWKSCacheTestRedis) GetVerificationCode(phone string) (string, error) {
return "", nil
}
func (m *headlessJWKSCacheTestRedis) DeleteVerificationCode(phone string) error {
return nil
}
func TestHeadlessJWKSCacheService_EnsureFreshKeySet_UsesCachedJWKSWhenFresh(t *testing.T) {
_, jwks := mustServiceHeadlessRSAJWK(t, "cached-key")
raw, err := json.Marshal(jwks)
require.NoError(t, err)
redisRepo := &headlessJWKSCacheTestRedis{}
cacheService := NewHeadlessJWKSCacheService(redisRepo, nil)
now := time.Now()
err = cacheService.SaveState("client-headless", domain.HeadlessJWKSCacheState{
ClientID: "client-headless",
JWKSURI: "https://rp.example.com/.well-known/jwks.json",
RawJWKS: string(raw),
CachedKids: []string{"cached-key"},
CachedAt: &now,
LastCheckedAt: &now,
ExpiresAt: ptrTestTime(now.Add(30 * time.Minute)),
LastRefreshStatus: "success",
ConsecutiveFailures: 0,
})
require.NoError(t, err)
cacheService.HTTPClient = clientForHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("unexpected network fetch: %s", r.URL.String())
}))
keySet, state, refreshed, err := cacheService.EnsureFreshKeySet(context.Background(), domain.HydraClient{
ClientID: "client-headless",
Metadata: map[string]any{
domain.MetadataHeadlessLoginEnabled: true,
domain.MetadataHeadlessJWKSURI: "https://rp.example.com/.well-known/jwks.json",
},
}, "cached-key")
require.NoError(t, err)
assert.False(t, refreshed)
require.NotNil(t, keySet)
assert.Len(t, keySet.Keys, 1)
require.NotNil(t, state)
assert.Equal(t, []string{"cached-key"}, state.CachedKids)
}
func TestHeadlessJWKSCacheService_EnsureFreshKeySet_RefreshesWhenKidMissing(t *testing.T) {
_, staleJWKS := mustServiceHeadlessRSAJWK(t, "stale-key")
staleRaw, err := json.Marshal(staleJWKS)
require.NoError(t, err)
_, freshJWKS := mustServiceHeadlessRSAJWK(t, "fresh-key")
freshRaw, err := json.Marshal(freshJWKS)
require.NoError(t, err)
redisRepo := &headlessJWKSCacheTestRedis{}
cacheService := NewHeadlessJWKSCacheService(redisRepo, nil)
now := time.Now()
err = cacheService.SaveState("client-headless", domain.HeadlessJWKSCacheState{
ClientID: "client-headless",
JWKSURI: "https://rp.example.com/.well-known/jwks.json",
RawJWKS: string(staleRaw),
CachedKids: []string{"stale-key"},
CachedAt: &now,
LastCheckedAt: &now,
ExpiresAt: ptrTestTime(now.Add(30 * time.Minute)),
LastRefreshStatus: "success",
})
require.NoError(t, err)
cacheService.HTTPClient = clientForHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "https://rp.example.com/.well-known/jwks.json", r.URL.String())
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(freshRaw)
}))
keySet, state, refreshed, err := cacheService.EnsureFreshKeySet(context.Background(), domain.HydraClient{
ClientID: "client-headless",
Metadata: map[string]any{
domain.MetadataHeadlessLoginEnabled: true,
domain.MetadataHeadlessJWKSURI: "https://rp.example.com/.well-known/jwks.json",
},
}, "fresh-key")
require.NoError(t, err)
assert.True(t, refreshed)
require.NotNil(t, keySet)
assert.Len(t, keySet.Keys, 1)
require.NotNil(t, state)
assert.Equal(t, []string{"fresh-key"}, state.CachedKids)
stored, err := cacheService.GetState("client-headless")
require.NoError(t, err)
require.NotNil(t, stored)
assert.Equal(t, []string{"fresh-key"}, stored.CachedKids)
}
func mustServiceHeadlessRSAJWK(t *testing.T, kid string) (*rsa.PrivateKey, jose.JSONWebKeySet) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
publicJWK := jose.JSONWebKey{
Key: &privateKey.PublicKey,
KeyID: kid,
Algorithm: string(jose.RS256),
Use: "sig",
}
return privateKey, jose.JSONWebKeySet{Keys: []jose.JSONWebKey{publicJWK}}
}
func ptrTestTime(value time.Time) *time.Time {
return &value
}