package service

import (
	"baron-sso-backend/internal/domain"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"strconv"
	"strings"
	"time"

	"github.com/go-jose/go-jose/v4"
)

const (
	// headlessJWKSCacheKeyPrefix is prepended to the trimmed client ID to form
	// the Redis key under which a client's cache state is stored.
	headlessJWKSCacheKeyPrefix = "headless_jwks_cache:"
)

// HeadlessJWKSCacheService fetches the JWKS document published at a client's
// jwksUri and keeps it, together with refresh bookkeeping, in Redis.
//
// Fields:
//   - Redis: backing store; when nil, SaveState/GetState/DeleteState are no-ops.
//   - HTTPClient: client used for fetches; when nil a client with
//     RequestTimeout (or 5s if that is not positive) is created per fetch.
//     RequestTimeout is NOT applied to an injected HTTPClient.
//   - TTL: how long a fetched (or revalidated) document is considered fresh.
//   - PrefetchWindow: how far ahead of expiry ShouldPrefetch starts returning true.
//   - FailureThreshold: consecutive failures after which NextRetryAt is set
//     (3 if not positive).
//   - FailureBackoff: distance of NextRetryAt from the failure time
//     (30 minutes if not positive).
type HeadlessJWKSCacheService struct {
	Redis            domain.RedisRepository
	HTTPClient       *http.Client
	TTL              time.Duration
	PrefetchWindow   time.Duration
	RequestTimeout   time.Duration
	FailureThreshold int
	FailureBackoff   time.Duration
}

// headlessJWKSCacheStateStore is the JSON shape persisted in Redis. It mirrors
// the fields of domain.HeadlessJWKSCacheState that SaveState and GetState copy.
type headlessJWKSCacheStateStore struct {
	ClientID                     string     `json:"clientId"`
	JWKSURI                      string     `json:"jwksUri"`
	CachedAt                     *time.Time `json:"cachedAt,omitempty"`
	ExpiresAt                    *time.Time `json:"expiresAt,omitempty"`
	LastCheckedAt                *time.Time `json:"lastCheckedAt,omitempty"`
	NextRetryAt                  *time.Time `json:"nextRetryAt,omitempty"`
	LastSuccessfulVerificationAt *time.Time `json:"lastSuccessfulVerificationAt,omitempty"`
	LastRefreshStatus            string     `json:"lastRefreshStatus,omitempty"`
	LastError                    string     `json:"lastError,omitempty"`
	ConsecutiveFailures          int        `json:"consecutiveFailures,omitempty"`
	CachedKids                   []string   `json:"cachedKids,omitempty"`
	ETag                         string     `json:"etag,omitempty"`
	LastModified                 string     `json:"lastModified,omitempty"`
	RawJWKS                      string     `json:"rawJwks,omitempty"`
}

// HeadlessJWKSCacheWorker periodically walks the Hydra client list and
// refreshes the JWKS cache of headless-login clients that are due.
//
// Interval is the tick period used by Start; PageSize is the page size passed
// to Hydra.ListClients (100 if not positive).
type HeadlessJWKSCacheWorker struct {
	Hydra    *HydraAdminService
	Cache    *HeadlessJWKSCacheService
	Interval time.Duration
	PageSize int
}

// headlessJWKSPositiveIntEnv reads key through getenv and parses it as an
// integer. It returns fallback when the value is unparsable or not positive.
func headlessJWKSPositiveIntEnv(key string, fallback int) int {
	value, err := strconv.Atoi(strings.TrimSpace(getenv(key, strconv.Itoa(fallback))))
	if err != nil || value <= 0 {
		return fallback
	}
	return value
}

// NewHeadlessJWKSCacheService builds a cache service configured from the
// HEADLESS_JWKS_* environment variables. Each setting falls back to its
// default (TTL 1800s, prefetch window 600s, fetch timeout 2s, failure
// threshold 3, failure backoff 1800s) when unset, unparsable or not positive.
func NewHeadlessJWKSCacheService(redis domain.RedisRepository, httpClient *http.Client) *HeadlessJWKSCacheService {
	ttlSeconds := headlessJWKSPositiveIntEnv("HEADLESS_JWKS_CACHE_TTL_SECONDS", 1800)
	prefetchSeconds := headlessJWKSPositiveIntEnv("HEADLESS_JWKS_PREFETCH_WINDOW_SECONDS", 600)
	timeoutSeconds := headlessJWKSPositiveIntEnv("HEADLESS_JWKS_FETCH_TIMEOUT_SECONDS", 2)
	failureThreshold := headlessJWKSPositiveIntEnv("HEADLESS_JWKS_FAILURE_THRESHOLD", 3)
	backoffSeconds := headlessJWKSPositiveIntEnv("HEADLESS_JWKS_FAILURE_BACKOFF_SECONDS", 1800)

	return &HeadlessJWKSCacheService{
		Redis:            redis,
		HTTPClient:       httpClient,
		TTL:              time.Duration(ttlSeconds) * time.Second,
		PrefetchWindow:   time.Duration(prefetchSeconds) * time.Second,
		RequestTimeout:   time.Duration(timeoutSeconds) * time.Second,
		FailureThreshold: failureThreshold,
		FailureBackoff:   time.Duration(backoffSeconds) * time.Second,
	}
}

// NewHeadlessJWKSCacheWorker builds a worker whose tick interval comes from
// HEADLESS_JWKS_REFRESH_INTERVAL_SECONDS (default 600s) and whose page size is 100.
func NewHeadlessJWKSCacheWorker(hydra *HydraAdminService, cache *HeadlessJWKSCacheService) *HeadlessJWKSCacheWorker {
	intervalSeconds := headlessJWKSPositiveIntEnv("HEADLESS_JWKS_REFRESH_INTERVAL_SECONDS", 600)
	return &HeadlessJWKSCacheWorker{
		Hydra:    hydra,
		Cache:    cache,
		Interval: time.Duration(intervalSeconds) * time.Second,
		PageSize: 100,
	}
}

// httpClient returns the injected HTTP client, or a new client whose timeout
// is RequestTimeout (5s when RequestTimeout is not positive).
func (s *HeadlessJWKSCacheService) httpClient() *http.Client {
	if s.HTTPClient != nil {
		return s.HTTPClient
	}
	timeout := s.RequestTimeout
	if timeout <= 0 {
		timeout = 5 * time.Second
	}
	return &http.Client{Timeout: timeout}
}

// cacheKey returns the Redis key for clientID (prefix + trimmed ID).
func (s *HeadlessJWKSCacheService) cacheKey(clientID string) string {
	return headlessJWKSCacheKeyPrefix + strings.TrimSpace(clientID)
}

// SaveState serializes state to JSON and stores it under clientID's cache key.
// It is a no-op returning nil when the service, its Redis repository, or the
// client ID is missing.
func (s *HeadlessJWKSCacheService) SaveState(clientID string, state domain.HeadlessJWKSCacheState) error {
	if s == nil || s.Redis == nil || strings.TrimSpace(clientID) == "" {
		return nil
	}
	payload, err := json.Marshal(headlessJWKSCacheStateStore{
		ClientID:                     state.ClientID,
		JWKSURI:                      state.JWKSURI,
		CachedAt:                     state.CachedAt,
		ExpiresAt:                    state.ExpiresAt,
		LastCheckedAt:                state.LastCheckedAt,
		NextRetryAt:                  state.NextRetryAt,
		LastSuccessfulVerificationAt: state.LastSuccessfulVerificationAt,
		LastRefreshStatus:            state.LastRefreshStatus,
		LastError:                    state.LastError,
		ConsecutiveFailures:          state.ConsecutiveFailures,
		CachedKids:                   state.CachedKids,
		ETag:                         state.ETag,
		LastModified:                 state.LastModified,
		RawJWKS:                      state.RawJWKS,
	})
	if err != nil {
		return err
	}
	// NOTE(review): the third argument is 0; presumably "no Redis-level
	// expiry" — confirm against domain.RedisRepository. Freshness is tracked
	// by the ExpiresAt field inside the payload.
	return s.Redis.Set(s.cacheKey(clientID), string(payload), 0)
}

// GetState loads and decodes the stored state for clientID. It returns
// (nil, nil) when the service, Redis or client ID is missing or the stored
// value is blank, and (nil, err) when Redis or JSON decoding fails.
func (s *HeadlessJWKSCacheService) GetState(clientID string) (*domain.HeadlessJWKSCacheState, error) {
	if s == nil || s.Redis == nil || strings.TrimSpace(clientID) == "" {
		return nil, nil
	}
	raw, err := s.Redis.Get(s.cacheKey(clientID))
	if err != nil || strings.TrimSpace(raw) == "" {
		return nil, err
	}
	var stored headlessJWKSCacheStateStore
	if err := json.Unmarshal([]byte(raw), &stored); err != nil {
		return nil, err
	}
	return &domain.HeadlessJWKSCacheState{
		ClientID:                     stored.ClientID,
		JWKSURI:                      stored.JWKSURI,
		CachedAt:                     stored.CachedAt,
		ExpiresAt:                    stored.ExpiresAt,
		LastCheckedAt:                stored.LastCheckedAt,
		NextRetryAt:                  stored.NextRetryAt,
		LastSuccessfulVerificationAt: stored.LastSuccessfulVerificationAt,
		LastRefreshStatus:            stored.LastRefreshStatus,
		LastError:                    stored.LastError,
		ConsecutiveFailures:          stored.ConsecutiveFailures,
		CachedKids:                   stored.CachedKids,
		ETag:                         stored.ETag,
		LastModified:                 stored.LastModified,
		RawJWKS:                      stored.RawJWKS,
	}, nil
}

// DeleteState removes the stored state for clientID. Like SaveState and
// GetState it is a no-op when the service, Redis or client ID is missing.
func (s *HeadlessJWKSCacheService) DeleteState(clientID string) error {
	if s == nil || s.Redis == nil || strings.TrimSpace(clientID) == "" {
		return nil
	}
	return s.Redis.Delete(s.cacheKey(clientID))
}

// PublicState returns the stored state with RawJWKS replaced by a per-key
// summary in ParsedKeys, so the raw document is not handed to the caller.
func (s *HeadlessJWKSCacheService) PublicState(clientID string) (*domain.HeadlessJWKSCacheState, error) {
	state, err := s.GetState(clientID)
	if err != nil || state == nil {
		return state, err
	}
	state.ParsedKeys = summarizeHeadlessJWKS(state.RawJWKS)
	state.RawJWKS = ""
	return state, nil
}

// MarkVerificationSuccess stamps LastSuccessfulVerificationAt with the current
// time on the stored state. It does nothing when no state exists.
func (s *HeadlessJWKSCacheService) MarkVerificationSuccess(clientID string) error {
	state, err := s.GetState(clientID)
	if err != nil || state == nil {
		return err
	}
	now := time.Now()
	state.LastSuccessfulVerificationAt = &now
	return s.SaveState(clientID, *state)
}

// ShouldPrefetch reports whether the cache described by state should be
// refreshed at time now: always for a missing state, never while a retry
// backoff is active, and otherwise when the cache is empty, has no expiry, or
// expires within PrefetchWindow.
func (s *HeadlessJWKSCacheService) ShouldPrefetch(state *domain.HeadlessJWKSCacheState, now time.Time) bool {
	if state == nil {
		return true
	}
	if s.ShouldSkipRefresh(state, now) {
		return false
	}
	if strings.TrimSpace(state.RawJWKS) == "" {
		return true
	}
	if state.ExpiresAt == nil {
		return true
	}
	return !state.ExpiresAt.After(now.Add(s.PrefetchWindow))
}

// ShouldSkipRefresh reports whether state carries a NextRetryAt that is still
// in the future relative to now.
func (s *HeadlessJWKSCacheService) ShouldSkipRefresh(state *domain.HeadlessJWKSCacheState, now time.Time) bool {
	if state == nil || state.NextRetryAt == nil {
		return false
	}
	return state.NextRetryAt.After(now)
}

// EnsureFreshKeySet returns a usable key set for client, refreshing from the
// client's jwksUri when the cache is missing, belongs to a different URI, is
// empty, has expired, lacks expectedKid (when non-empty), or cannot be decoded.
//
// It returns the key set, the cache state, whether a refresh was performed,
// and an error. A failure to load the cached state is logged and treated as a
// cache miss.
func (s *HeadlessJWKSCacheService) EnsureFreshKeySet(ctx context.Context, client domain.HydraClient, expectedKid string) (*jose.JSONWebKeySet, *domain.HeadlessJWKSCacheState, bool, error) {
	if s == nil {
		return nil, nil, false, fmt.Errorf("headless jwks cache service is not configured")
	}
	jwksURI := strings.TrimSpace(client.HeadlessJWKSURI())
	if jwksURI == "" {
		return nil, nil, false, fmt.Errorf("headless login requires jwksUri; inline jwks is not supported")
	}
	state, err := s.GetState(client.ClientID)
	if err != nil {
		slog.Warn("failed to load headless jwks cache state", "clientID", client.ClientID, "error", err)
	}
	now := time.Now()
	switch {
	case state == nil:
		return s.refreshClient(ctx, client, nil, "cache_miss")
	case strings.TrimSpace(state.JWKSURI) != jwksURI:
		return s.refreshClient(ctx, client, state, "config_changed")
	case strings.TrimSpace(state.RawJWKS) == "":
		return s.refreshClient(ctx, client, state, "cache_empty")
	case state.ExpiresAt == nil || !state.ExpiresAt.After(now):
		return s.refreshClient(ctx, client, state, "ttl_expired")
	case expectedKid != "" && !containsString(state.CachedKids, expectedKid):
		return s.refreshClient(ctx, client, state, "kid_missing")
	default:
		keySet, err := decodeHeadlessJWKS(state.RawJWKS)
		if err != nil {
			return s.refreshClient(ctx, client, state, "cache_corrupt")
		}
		return keySet, state, false, nil
	}
}

// ForceRefresh refreshes client's cache regardless of freshness and returns
// the resulting state.
func (s *HeadlessJWKSCacheService) ForceRefresh(ctx context.Context, client domain.HydraClient, reason string) (*domain.HeadlessJWKSCacheState, error) {
	_, state, err := s.ForceRefreshKeySet(ctx, client, reason)
	return state, err
}

// ForceRefreshKeySet refreshes client's cache regardless of freshness and
// returns both the decoded key set and the resulting state. reason is only
// used for logging.
func (s *HeadlessJWKSCacheService) ForceRefreshKeySet(ctx context.Context, client domain.HydraClient, reason string) (*jose.JSONWebKeySet, *domain.HeadlessJWKSCacheState, error) {
	// Same guard as EnsureFreshKeySet: refreshClient dereferences s.
	if s == nil {
		return nil, nil, fmt.Errorf("headless jwks cache service is not configured")
	}
	previous, err := s.GetState(client.ClientID)
	if err != nil {
		slog.Warn("failed to load headless jwks cache state before force refresh", "clientID", client.ClientID, "error", err)
	}
	keySet, state, _, err := s.refreshClient(ctx, client, previous, reason)
	return keySet, state, err
}

// refreshClient fetches client's jwksUri and updates the stored cache state.
//
// previous is the last known state (may be nil). Its ETag/Last-Modified are
// sent as conditional request headers only when previous belongs to the same
// jwksUri and its cached document still decodes; otherwise a 304 response
// would revalidate keys that are unusable or that came from another endpoint.
//
// On success it returns the key set, the new state and refreshed=true. On
// failure it records the failure via persistRefreshFailure and returns that
// state together with the error.
func (s *HeadlessJWKSCacheService) refreshClient(ctx context.Context, client domain.HydraClient, previous *domain.HeadlessJWKSCacheState, reason string) (*jose.JSONWebKeySet, *domain.HeadlessJWKSCacheState, bool, error) {
	jwksURI := strings.TrimSpace(client.HeadlessJWKSURI())
	if jwksURI == "" {
		return nil, nil, false, fmt.Errorf("headless login requires jwksUri; inline jwks is not supported")
	}

	// fail persists the failure and hands the same error back to the caller so
	// the stored LastError and the returned error never diverge.
	fail := func(refreshErr error) (*jose.JSONWebKeySet, *domain.HeadlessJWKSCacheState, bool, error) {
		return nil, s.persistRefreshFailure(client, previous, refreshErr), false, refreshErr
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, jwksURI, nil)
	if err != nil {
		return fail(fmt.Errorf("failed to build jwks request: %w", err))
	}

	// cachedKeySet is non-nil only when the previous document may be safely
	// revalidated with a conditional request.
	var cachedKeySet *jose.JSONWebKeySet
	if previous != nil && strings.TrimSpace(previous.JWKSURI) == jwksURI && strings.TrimSpace(previous.RawJWKS) != "" {
		if decoded, decodeErr := decodeHeadlessJWKS(previous.RawJWKS); decodeErr == nil {
			cachedKeySet = decoded
			if etag := strings.TrimSpace(previous.ETag); etag != "" {
				req.Header.Set("If-None-Match", etag)
			}
			if lastModified := strings.TrimSpace(previous.LastModified); lastModified != "" {
				req.Header.Set("If-Modified-Since", lastModified)
			}
		}
	}

	resp, err := s.httpClient().Do(req)
	if err != nil {
		return fail(fmt.Errorf("failed to fetch jwksUri: %w", err))
	}
	defer resp.Body.Close()

	now := time.Now()
	if resp.StatusCode == http.StatusNotModified && cachedKeySet != nil {
		// Document unchanged upstream: keep the cached body, extend the TTL
		// and clear any failure bookkeeping.
		updated := *previous
		updated.JWKSURI = jwksURI
		updated.LastCheckedAt = &now
		updated.ExpiresAt = new(now.Add(s.TTL))
		updated.NextRetryAt = nil
		updated.LastRefreshStatus = "success"
		updated.LastError = ""
		updated.ConsecutiveFailures = 0
		// Best effort: the cached key set is still valid even if the
		// bookkeeping update cannot be stored.
		if saveErr := s.SaveState(client.ClientID, updated); saveErr != nil {
			slog.Warn("failed to persist revalidated headless jwks cache state", "clientID", client.ClientID, "error", saveErr)
		}
		return cachedKeySet, &updated, true, nil
	}
	if resp.StatusCode >= 300 {
		// Only a bounded prefix of the body is kept for diagnostics.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
		return fail(fmt.Errorf("failed to fetch jwksUri status=%d body=%s", resp.StatusCode, string(body)))
	}

	// Cap the document size at 1 MiB; a longer body is truncated and will
	// typically fail to decode below.
	body, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024))
	if err != nil {
		return fail(fmt.Errorf("failed to read jwks response: %w", err))
	}
	keySet, err := decodeHeadlessJWKS(string(body))
	if err != nil {
		return fail(err)
	}

	state := domain.HeadlessJWKSCacheState{
		ClientID:                     client.ClientID,
		JWKSURI:                      jwksURI,
		CachedAt:                     &now,
		ExpiresAt:                    new(now.Add(s.TTL)),
		LastCheckedAt:                &now,
		NextRetryAt:                  nil,
		LastSuccessfulVerificationAt: previousLastVerification(previous),
		LastRefreshStatus:            "success",
		LastError:                    "",
		ConsecutiveFailures:          0,
		CachedKids:                   extractHeadlessKids(keySet),
		ETag:                         strings.TrimSpace(resp.Header.Get("ETag")),
		LastModified:                 strings.TrimSpace(resp.Header.Get("Last-Modified")),
		RawJWKS:                      string(body),
	}
	if err := s.SaveState(client.ClientID, state); err != nil {
		return nil, &state, false, err
	}
	slog.Info("headless jwks cache refreshed", "clientID", client.ClientID, "reason", reason, "keyCount", len(keySet.Keys))
	return keySet, &state, true, nil
}

// persistRefreshFailure records a failed refresh for client and returns the
// state that was written.
//
// Cached key material, validators and the consecutive-failure count are
// carried over from previous only when previous was fetched from the client's
// current jwksUri. If the URI changed, the old document is dropped so it can
// never be served as if it belonged to the new endpoint, and the failure
// streak restarts at 1. NextRetryAt is set once the streak reaches the
// failure threshold.
func (s *HeadlessJWKSCacheService) persistRefreshFailure(client domain.HydraClient, previous *domain.HeadlessJWKSCacheState, refreshErr error) *domain.HeadlessJWKSCacheState {
	now := time.Now()
	jwksURI := strings.TrimSpace(client.HeadlessJWKSURI())
	state := domain.HeadlessJWKSCacheState{
		ClientID:            client.ClientID,
		JWKSURI:             jwksURI,
		LastCheckedAt:       &now,
		LastRefreshStatus:   "failure",
		LastError:           refreshErr.Error(),
		ConsecutiveFailures: 1,
	}
	if previous != nil {
		state.LastSuccessfulVerificationAt = previous.LastSuccessfulVerificationAt
		if strings.TrimSpace(previous.JWKSURI) == jwksURI {
			state.CachedAt = previous.CachedAt
			state.ExpiresAt = previous.ExpiresAt
			state.CachedKids = previous.CachedKids
			state.ETag = previous.ETag
			state.LastModified = previous.LastModified
			state.RawJWKS = previous.RawJWKS
			state.ConsecutiveFailures = previous.ConsecutiveFailures + 1
		}
	}
	if s.shouldBackoff(state.ConsecutiveFailures) {
		state.NextRetryAt = new(now.Add(s.failureBackoffDuration()))
	}
	// Best effort: the original refresh error is what the caller needs to see,
	// so a bookkeeping write failure is only logged.
	if saveErr := s.SaveState(client.ClientID, state); saveErr != nil {
		slog.Warn("failed to persist headless jwks refresh failure", "clientID", client.ClientID, "error", saveErr)
	}
	return &state
}

// shouldBackoff reports whether consecutiveFailures has reached the failure
// threshold (3 when FailureThreshold is not positive).
func (s *HeadlessJWKSCacheService) shouldBackoff(consecutiveFailures int) bool {
	threshold := s.FailureThreshold
	if threshold <= 0 {
		threshold = 3
	}
	return consecutiveFailures >= threshold
}

// failureBackoffDuration returns FailureBackoff, or 30 minutes when it is not
// positive.
func (s *HeadlessJWKSCacheService) failureBackoffDuration() time.Duration {
	if s.FailureBackoff > 0 {
		return s.FailureBackoff
	}
	return 30 * time.Minute
}

// decodeHeadlessJWKS parses raw as a JSON Web Key Set. A document that fails
// to parse or contains no keys is an error.
func decodeHeadlessJWKS(raw string) (*jose.JSONWebKeySet, error) {
	var keySet jose.JSONWebKeySet
	if err := json.Unmarshal([]byte(raw), &keySet); err != nil {
		return nil, fmt.Errorf("failed to decode jwks from jwksUri: %w", err)
	}
	if len(keySet.Keys) == 0 {
		return nil, fmt.Errorf("configured jwksUri returned no keys")
	}
	return &keySet, nil
}

// headlessJWKSPreviewDocument is the minimal JWKS shape decoded by
// summarizeHeadlessJWKS.
type headlessJWKSPreviewDocument struct {
	Keys []headlessJWKSPreviewKey `json:"keys"`
}

// headlessJWKSPreviewKey holds the JWK members copied into the public summary.
type headlessJWKSPreviewKey struct {
	Kid string `json:"kid"`
	Kty string `json:"kty"`
	Use string `json:"use"`
	Alg string `json:"alg"`
	N   string `json:"n"`
}

// summarizeHeadlessJWKS extracts kid/kty/use/alg/n of every key in raw, with
// surrounding whitespace trimmed. It returns nil when raw is blank or is not
// valid JSON.
func summarizeHeadlessJWKS(raw string) []domain.HeadlessJWKSParsedKey {
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return nil
	}
	var document headlessJWKSPreviewDocument
	if err := json.Unmarshal([]byte(raw), &document); err != nil {
		return nil
	}
	parsedKeys := make([]domain.HeadlessJWKSParsedKey, 0, len(document.Keys))
	for _, key := range document.Keys {
		parsedKeys = append(parsedKeys, domain.HeadlessJWKSParsedKey{
			Kid: strings.TrimSpace(key.Kid),
			Kty: strings.TrimSpace(key.Kty),
			Use: strings.TrimSpace(key.Use),
			Alg: strings.TrimSpace(key.Alg),
			N:   strings.TrimSpace(key.N),
		})
	}
	return parsedKeys
}

// extractHeadlessKids returns the trimmed, non-empty key IDs of keySet in
// document order. It returns nil for a nil key set.
func extractHeadlessKids(keySet *jose.JSONWebKeySet) []string {
	if keySet == nil {
		return nil
	}
	kids := make([]string, 0, len(keySet.Keys))
	for _, key := range keySet.Keys {
		if kid := strings.TrimSpace(key.KeyID); kid != "" {
			kids = append(kids, kid)
		}
	}
	return kids
}

// containsString reports whether values contains needle, comparing after
// trimming whitespace on both sides. A blank needle never matches.
func containsString(values []string, needle string) bool {
	needle = strings.TrimSpace(needle)
	if needle == "" {
		return false
	}
	for _, value := range values {
		if strings.TrimSpace(value) == needle {
			return true
		}
	}
	return false
}

// previousLastVerification returns previous.LastSuccessfulVerificationAt, or
// nil when there is no previous state.
func previousLastVerification(previous *domain.HeadlessJWKSCacheState) *time.Time {
	if previous == nil {
		return nil
	}
	return previous.LastSuccessfulVerificationAt
}

// ptrTime returns a pointer to a copy of value.
//
//go:fix inline
func ptrTime(value time.Time) *time.Time {
	return new(value)
}

// Start runs one refresh pass immediately and then one per Interval until ctx
// is cancelled. It returns at once when the worker or one of its dependencies
// is nil. It blocks, so callers typically run it in its own goroutine.
func (w *HeadlessJWKSCacheWorker) Start(ctx context.Context) {
	if w == nil || w.Hydra == nil || w.Cache == nil {
		return
	}
	// time.NewTicker panics on a non-positive duration; fall back to the same
	// default the constructor uses.
	interval := w.Interval
	if interval <= 0 {
		interval = 600 * time.Second
	}

	w.runOnce(ctx)

	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			w.runOnce(ctx)
		}
	}
}

// runOnce pages through all Hydra clients and force-refreshes the JWKS cache
// of every headless-login client for which ShouldPrefetch is true. Listing
// errors abort the pass; per-client errors are logged and skipped.
func (w *HeadlessJWKSCacheWorker) runOnce(ctx context.Context) {
	offset := 0
	pageSize := w.PageSize
	if pageSize <= 0 {
		pageSize = 100
	}
	// A single timestamp is used for every prefetch decision in this pass.
	now := time.Now()
	for {
		clients, err := w.Hydra.ListClients(ctx, pageSize, offset)
		if err != nil {
			slog.Warn("headless jwks worker failed to list clients", "error", err)
			return
		}
		if len(clients) == 0 {
			return
		}
		for _, client := range clients {
			// Stop promptly on shutdown instead of issuing fetches that are
			// bound to fail and be recorded as refresh failures.
			if ctx.Err() != nil {
				return
			}
			if !client.IsHeadlessLoginEnabled() {
				continue
			}
			state, err := w.Cache.GetState(client.ClientID)
			if err != nil {
				slog.Warn("headless jwks worker failed to load cache state", "clientID", client.ClientID, "error", err)
				continue
			}
			if !w.Cache.ShouldPrefetch(state, now) {
				continue
			}
			if _, err := w.Cache.ForceRefresh(ctx, client, "cron_prefetch"); err != nil {
				slog.Warn("headless jwks worker refresh failed", "clientID", client.ClientID, "error", err)
			}
		}
		// A short page means the listing is exhausted.
		if len(clients) < pageSize {
			return
		}
		offset += len(clients)
	}
}