1
0
forked from baron/baron-sso
Files
baron-sso/backend/internal/service/headless_jwks_cache.go
chan 31d107ff2e feat(user): support fixed UUID registration and enhance bulk import results
- Added support for fixed UUIDs during bulk registration (Search-first + ExternalID mapping)
- Implemented idempotency and visibility restoration for soft-deleted users
- Enhanced bulk upload UI to show 'New/Updated/Unchanged' status and modified fields
- Added logic to reclaim identifiers (login_id) from colliding records
- Added frontend E2E and backend unit tests for UUID integrity and conflict handling
- Fixed i18n, formatting, and mock tests to satisfy code-check
- Applied 'go fix' for 'omitzero' tags and general Go standards
2026-06-01 15:34:08 +09:00

546 lines
17 KiB
Go

package service
import (
"baron-sso-backend/internal/domain"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"strconv"
"strings"
"time"
"github.com/go-jose/go-jose/v4"
)
// headlessJWKSCacheKeyPrefix namespaces the Redis keys holding per-client
// headless JWKS cache state; the trimmed client ID is appended to it.
const headlessJWKSCacheKeyPrefix = "headless_jwks_cache:"
// HeadlessJWKSCacheService caches, per OAuth client, the JWKS document fetched
// from that client's jwksUri. State is persisted in Redis under
// headlessJWKSCacheKeyPrefix + clientID and refreshed over HTTP.
type HeadlessJWKSCacheService struct {
	// Redis stores the serialized cache state; when nil, SaveState, GetState
	// and DeleteState are no-ops.
	Redis domain.RedisRepository
	// HTTPClient, when non-nil, is used for jwksUri fetches; otherwise
	// httpClient() builds a client bounded by RequestTimeout.
	HTTPClient *http.Client
	// TTL is how long a freshly fetched key set stays valid
	// (ExpiresAt = fetch time + TTL).
	TTL time.Duration
	// PrefetchWindow makes ShouldPrefetch report true once ExpiresAt is within
	// this distance of the current time.
	PrefetchWindow time.Duration
	// RequestTimeout bounds JWKS fetches; httpClient() uses it (5s when not
	// positive) for the client it builds when HTTPClient is nil.
	RequestTimeout time.Duration
	// FailureThreshold is the number of consecutive refresh failures after
	// which NextRetryAt is set (treated as 3 when not positive).
	FailureThreshold int
	// FailureBackoff is the delay applied to NextRetryAt once FailureThreshold
	// is reached (30 minutes when not positive).
	FailureBackoff time.Duration
}
// headlessJWKSCacheStateStore is the JSON record written to Redis by SaveState
// and read back by GetState. Its fields are copied one-to-one to and from
// domain.HeadlessJWKSCacheState; ParsedKeys is not persisted (PublicState
// derives it from RawJWKS on demand).
//
// RawJWKS holds the jwksUri response body verbatim. ETag and LastModified are
// the response validators that refreshClient replays as If-None-Match and
// If-Modified-Since. NextRetryAt, when set and still in the future, makes
// ShouldSkipRefresh report true.
type headlessJWKSCacheStateStore struct {
	ClientID                     string     `json:"clientId"`
	JWKSURI                      string     `json:"jwksUri"`
	CachedAt                     *time.Time `json:"cachedAt,omitempty"`
	ExpiresAt                    *time.Time `json:"expiresAt,omitempty"`
	LastCheckedAt                *time.Time `json:"lastCheckedAt,omitempty"`
	NextRetryAt                  *time.Time `json:"nextRetryAt,omitempty"`
	LastSuccessfulVerificationAt *time.Time `json:"lastSuccessfulVerificationAt,omitempty"`
	LastRefreshStatus            string     `json:"lastRefreshStatus,omitempty"`
	LastError                    string     `json:"lastError,omitempty"`
	ConsecutiveFailures          int        `json:"consecutiveFailures,omitempty"`
	CachedKids                   []string   `json:"cachedKids,omitempty"`
	ETag                         string     `json:"etag,omitempty"`
	LastModified                 string     `json:"lastModified,omitempty"`
	RawJWKS                      string     `json:"rawJwks,omitempty"`
}
// HeadlessJWKSCacheWorker periodically walks Hydra's client list and prefetches
// the JWKS of headless-login clients whose cache ShouldPrefetch flags.
type HeadlessJWKSCacheWorker struct {
	Hydra    *HydraAdminService        // source of the paginated client list
	Cache    *HeadlessJWKSCacheService // cache that is inspected and refreshed
	Interval time.Duration             // delay between passes after the initial one
	PageSize int                       // clients per ListClients call (100 when not positive)
}
// NewHeadlessJWKSCacheService builds a cache service backed by redis that
// fetches JWKS documents with httpClient (nil means "build one per fetch from
// RequestTimeout"). Tunables come from the environment; any value that is
// missing, malformed, not strictly positive, or too large to represent falls
// back to its default:
//
//	HEADLESS_JWKS_CACHE_TTL_SECONDS        1800
//	HEADLESS_JWKS_PREFETCH_WINDOW_SECONDS   600
//	HEADLESS_JWKS_FETCH_TIMEOUT_SECONDS       2
//	HEADLESS_JWKS_FAILURE_THRESHOLD           3
//	HEADLESS_JWKS_FAILURE_BACKOFF_SECONDS  1800
func NewHeadlessJWKSCacheService(redis domain.RedisRepository, httpClient *http.Client) *HeadlessJWKSCacheService {
	return &HeadlessJWKSCacheService{
		Redis:            redis,
		HTTPClient:       httpClient,
		TTL:              headlessJWKSEnvSeconds("HEADLESS_JWKS_CACHE_TTL_SECONDS", 1800),
		PrefetchWindow:   headlessJWKSEnvSeconds("HEADLESS_JWKS_PREFETCH_WINDOW_SECONDS", 600),
		RequestTimeout:   headlessJWKSEnvSeconds("HEADLESS_JWKS_FETCH_TIMEOUT_SECONDS", 2),
		FailureThreshold: headlessJWKSEnvPositiveInt("HEADLESS_JWKS_FAILURE_THRESHOLD", 3),
		FailureBackoff:   headlessJWKSEnvSeconds("HEADLESS_JWKS_FAILURE_BACKOFF_SECONDS", 1800),
	}
}

// headlessJWKSEnvPositiveInt reads key via getenv (defaulting to fallback) and
// returns the parsed value, or fallback when the value is malformed, out of
// range for int, or not strictly positive.
func headlessJWKSEnvPositiveInt(key string, fallback int) int {
	value, err := strconv.Atoi(strings.TrimSpace(getenv(key, strconv.Itoa(fallback))))
	if err != nil || value <= 0 {
		return fallback
	}
	return value
}

// headlessJWKSEnvSeconds is headlessJWKSEnvPositiveInt interpreted as whole
// seconds. Values whose nanosecond count does not fit in a time.Duration also
// fall back, instead of silently wrapping around when multiplied.
func headlessJWKSEnvSeconds(key string, fallback int) time.Duration {
	seconds := headlessJWKSEnvPositiveInt(key, fallback)
	if int64(seconds) > int64(time.Duration(1<<63-1)/time.Second) {
		seconds = fallback
	}
	return time.Duration(seconds) * time.Second
}
// NewHeadlessJWKSCacheWorker builds a prefetch worker over hydra's client list
// and cache. The pass interval comes from HEADLESS_JWKS_REFRESH_INTERVAL_SECONDS
// (default 600); the page size is fixed at 100.
func NewHeadlessJWKSCacheWorker(hydra *HydraAdminService, cache *HeadlessJWKSCacheService) *HeadlessJWKSCacheWorker {
	const defaultIntervalSeconds = 600
	intervalSeconds, err := strconv.Atoi(strings.TrimSpace(getenv("HEADLESS_JWKS_REFRESH_INTERVAL_SECONDS", "600")))
	// Reject malformed, non-positive and oversized values. An oversized count of
	// seconds would wrap around when multiplied into a time.Duration and could
	// become negative, which makes time.NewTicker panic in Start.
	if err != nil || intervalSeconds <= 0 || int64(intervalSeconds) > int64(time.Duration(1<<63-1)/time.Second) {
		intervalSeconds = defaultIntervalSeconds
	}
	return &HeadlessJWKSCacheWorker{
		Hydra:    hydra,
		Cache:    cache,
		Interval: time.Duration(intervalSeconds) * time.Second,
		PageSize: 100,
	}
}
// httpClient returns the client used for JWKS fetches: the injected HTTPClient
// when one is configured, otherwise a new client whose timeout is
// RequestTimeout, or 5 seconds when RequestTimeout is not positive.
func (s *HeadlessJWKSCacheService) httpClient() *http.Client {
	if injected := s.HTTPClient; injected != nil {
		return injected
	}
	limit := 5 * time.Second
	if s.RequestTimeout > 0 {
		limit = s.RequestTimeout
	}
	return &http.Client{Timeout: limit}
}
// cacheKey derives the Redis key for clientID: headlessJWKSCacheKeyPrefix
// followed by the whitespace-trimmed client ID.
func (s *HeadlessJWKSCacheService) cacheKey(clientID string) string {
	id := strings.TrimSpace(clientID)
	return headlessJWKSCacheKeyPrefix + id
}
// SaveState serializes state as JSON and writes it to Redis under the key for
// clientID. The key comes from the clientID argument, not from state.ClientID.
// It is a no-op (nil error) when the service or its Redis repository is not
// configured or clientID is blank; marshal and Redis errors are returned.
func (s *HeadlessJWKSCacheService) SaveState(clientID string, state domain.HeadlessJWKSCacheState) error {
	if s == nil || s.Redis == nil {
		return nil
	}
	if strings.TrimSpace(clientID) == "" {
		return nil
	}
	record := headlessJWKSCacheStateStore{
		ClientID:                     state.ClientID,
		JWKSURI:                      state.JWKSURI,
		CachedAt:                     state.CachedAt,
		ExpiresAt:                    state.ExpiresAt,
		LastCheckedAt:                state.LastCheckedAt,
		NextRetryAt:                  state.NextRetryAt,
		LastSuccessfulVerificationAt: state.LastSuccessfulVerificationAt,
		LastRefreshStatus:            state.LastRefreshStatus,
		LastError:                    state.LastError,
		ConsecutiveFailures:          state.ConsecutiveFailures,
		CachedKids:                   state.CachedKids,
		ETag:                         state.ETag,
		LastModified:                 state.LastModified,
		RawJWKS:                      state.RawJWKS,
	}
	encoded, err := json.Marshal(record)
	if err != nil {
		return err
	}
	key := s.cacheKey(clientID)
	// NOTE(review): an expiration of 0 presumably means "no expiry" in
	// domain.RedisRepository; freshness is tracked via ExpiresAt instead.
	return s.Redis.Set(key, string(encoded), 0)
}
// GetState loads the cached JWKS state for clientID from Redis.
//
// It returns (nil, nil) when the service or its Redis repository is not
// configured, when clientID is blank, or when the stored value is blank.
// Redis and JSON decoding errors are returned as-is.
//
// NOTE(review): if RedisRepository.Get reports a missing key as an error,
// callers (e.g. the worker) treat a plain cache miss as a load failure; confirm.
func (s *HeadlessJWKSCacheService) GetState(clientID string) (*domain.HeadlessJWKSCacheState, error) {
	if s == nil || s.Redis == nil || strings.TrimSpace(clientID) == "" {
		return nil, nil
	}
	encoded, err := s.Redis.Get(s.cacheKey(clientID))
	if err != nil {
		return nil, err
	}
	if strings.TrimSpace(encoded) == "" {
		return nil, nil
	}
	var record headlessJWKSCacheStateStore
	if err = json.Unmarshal([]byte(encoded), &record); err != nil {
		return nil, err
	}
	state := &domain.HeadlessJWKSCacheState{
		ClientID:                     record.ClientID,
		JWKSURI:                      record.JWKSURI,
		CachedAt:                     record.CachedAt,
		ExpiresAt:                    record.ExpiresAt,
		LastCheckedAt:                record.LastCheckedAt,
		NextRetryAt:                  record.NextRetryAt,
		LastSuccessfulVerificationAt: record.LastSuccessfulVerificationAt,
		LastRefreshStatus:            record.LastRefreshStatus,
		LastError:                    record.LastError,
		ConsecutiveFailures:          record.ConsecutiveFailures,
		CachedKids:                   record.CachedKids,
		ETag:                         record.ETag,
		LastModified:                 record.LastModified,
		RawJWKS:                      record.RawJWKS,
	}
	return state, nil
}
// DeleteState removes the cached state for clientID. Like SaveState and
// GetState, it is a no-op (nil error) when the service or its Redis repository
// is not configured, or when clientID is blank. A blank ID would otherwise
// address the bare prefix key "headless_jwks_cache:".
func (s *HeadlessJWKSCacheService) DeleteState(clientID string) error {
	if s == nil || s.Redis == nil || strings.TrimSpace(clientID) == "" {
		return nil
	}
	return s.Redis.Delete(s.cacheKey(clientID))
}
// PublicState returns the cached state for clientID in a form suitable for
// display: RawJWKS is cleared and ParsedKeys is filled with a per-key summary
// of it. A missing entry yields (nil, nil); load errors are returned unchanged.
func (s *HeadlessJWKSCacheService) PublicState(clientID string) (*domain.HeadlessJWKSCacheState, error) {
	state, err := s.GetState(clientID)
	if err != nil {
		return state, err
	}
	if state == nil {
		return nil, nil
	}
	raw := state.RawJWKS
	state.RawJWKS = ""
	state.ParsedKeys = summarizeHeadlessJWKS(raw)
	return state, nil
}
// MarkVerificationSuccess stamps the current time as
// LastSuccessfulVerificationAt on the existing cache entry for clientID.
// It does nothing when there is no entry and returns any load or save error.
func (s *HeadlessJWKSCacheService) MarkVerificationSuccess(clientID string) error {
	state, err := s.GetState(clientID)
	switch {
	case err != nil:
		return err
	case state == nil:
		return nil
	}
	verifiedAt := time.Now()
	state.LastSuccessfulVerificationAt = &verifiedAt
	return s.SaveState(clientID, *state)
}
// ShouldPrefetch reports whether the background worker should refresh a
// client's JWKS at time now.
//
// The checks apply in order:
//   - A nil state always needs a fetch.
//   - A state still inside its failure backoff (ShouldSkipRefresh) is left alone.
//   - A blank RawJWKS or a missing ExpiresAt needs a fetch.
//   - Otherwise a fetch is due once ExpiresAt falls within PrefetchWindow of now.
func (s *HeadlessJWKSCacheService) ShouldPrefetch(state *domain.HeadlessJWKSCacheState, now time.Time) bool {
	switch {
	case state == nil:
		return true
	case s.ShouldSkipRefresh(state, now):
		return false
	case strings.TrimSpace(state.RawJWKS) == "", state.ExpiresAt == nil:
		return true
	}
	prefetchFrom := now.Add(s.PrefetchWindow)
	return !state.ExpiresAt.After(prefetchFrom)
}
// ShouldSkipRefresh reports whether state is still inside a failure backoff,
// i.e. it carries a NextRetryAt that lies strictly after now.
func (s *HeadlessJWKSCacheService) ShouldSkipRefresh(state *domain.HeadlessJWKSCacheState, now time.Time) bool {
	if state != nil && state.NextRetryAt != nil {
		return now.Before(*state.NextRetryAt)
	}
	return false
}
// EnsureFreshKeySet returns the key set used to verify a headless-login
// assertion for client, refreshing from the client's jwksUri when the cached
// copy is unusable. The bool result reports whether a refresh happened.
//
// A refresh is triggered, with the given reason passed to refreshClient, when:
//   - there is no cached state ("cache_miss");
//   - the cached jwksUri differs from the client's ("config_changed");
//   - no raw JWKS is cached ("cache_empty");
//   - ExpiresAt is missing or has passed ("ttl_expired");
//   - expectedKid is set but not among the cached kids ("kid_missing");
//   - the cached JWKS fails to decode ("cache_corrupt").
//
// A failure to load state is logged and treated as a cache miss.
//
// NOTE(review): this path never consults ShouldSkipRefresh/NextRetryAt, so the
// failure backoff only throttles the background worker. It also means every
// request carrying an unknown kid triggers an outbound fetch. Confirm both are
// intended.
func (s *HeadlessJWKSCacheService) EnsureFreshKeySet(ctx context.Context, client domain.HydraClient, expectedKid string) (*jose.JSONWebKeySet, *domain.HeadlessJWKSCacheState, bool, error) {
	if s == nil {
		return nil, nil, false, fmt.Errorf("headless jwks cache service is not configured")
	}
	jwksURI := strings.TrimSpace(client.HeadlessJWKSURI())
	if jwksURI == "" {
		return nil, nil, false, fmt.Errorf("headless login requires jwksUri; inline jwks is not supported")
	}
	cached, loadErr := s.GetState(client.ClientID)
	if loadErr != nil {
		slog.Warn("failed to load headless jwks cache state", "clientID", client.ClientID, "error", loadErr)
	}
	now := time.Now()
	if cached == nil {
		return s.refreshClient(ctx, client, nil, "cache_miss")
	}

	var reason string
	switch {
	case strings.TrimSpace(cached.JWKSURI) != jwksURI:
		reason = "config_changed"
	case strings.TrimSpace(cached.RawJWKS) == "":
		reason = "cache_empty"
	case cached.ExpiresAt == nil || !cached.ExpiresAt.After(now):
		reason = "ttl_expired"
	case expectedKid != "" && !containsString(cached.CachedKids, expectedKid):
		reason = "kid_missing"
	}
	if reason == "" {
		keySet, decodeErr := decodeHeadlessJWKS(cached.RawJWKS)
		if decodeErr == nil {
			return keySet, cached, false, nil
		}
		reason = "cache_corrupt"
	}
	return s.refreshClient(ctx, client, cached, reason)
}
// ForceRefresh refetches client's JWKS regardless of cache freshness and
// returns only the resulting cache state; the decoded key set is discarded.
func (s *HeadlessJWKSCacheService) ForceRefresh(ctx context.Context, client domain.HydraClient, reason string) (*domain.HeadlessJWKSCacheState, error) {
	_, refreshed, refreshErr := s.ForceRefreshKeySet(ctx, client, reason)
	return refreshed, refreshErr
}
// ForceRefreshKeySet refetches client's JWKS regardless of cache freshness.
// The currently cached state, if any, is passed to refreshClient. It returns
// the decoded key set and the resulting cache state.
//
// A nil service returns an error instead of panicking inside refreshClient,
// matching EnsureFreshKeySet.
func (s *HeadlessJWKSCacheService) ForceRefreshKeySet(ctx context.Context, client domain.HydraClient, reason string) (*jose.JSONWebKeySet, *domain.HeadlessJWKSCacheState, error) {
	if s == nil {
		return nil, nil, fmt.Errorf("headless jwks cache service is not configured")
	}
	previous, err := s.GetState(client.ClientID)
	if err != nil {
		// Best-effort: GetState returns a nil state on error, so the refresh
		// proceeds without previous cache data.
		slog.Warn("failed to load headless jwks cache state before force refresh", "clientID", client.ClientID, "error", err)
	}
	keySet, state, _, err := s.refreshClient(ctx, client, previous, reason)
	return keySet, state, err
}
// refreshClient fetches client's jwksUri and updates the Redis-backed cache.
//
// previous is the last cached state and may be nil. It is reused only when it
// was recorded for the same jwksUri: for conditional-request validators, for
// 304 revalidation and for failure bookkeeping. State from a different jwksUri
// is stripped of key material first, so keys published by an old jwksUri can
// never be served, or revalidated by a 304, for a new one.
//
// The fetch is bounded by RequestTimeout (when positive) even if an injected
// HTTPClient has no timeout of its own. Response bodies larger than 1 MiB are
// rejected rather than truncated.
//
// It returns the decoded key set, the resulting cache state, whether the fetch
// completed (true on 200 and on 304 revalidation), and any error. On fetch
// failure the returned state is the record written by persistRefreshFailure.
func (s *HeadlessJWKSCacheService) refreshClient(ctx context.Context, client domain.HydraClient, previous *domain.HeadlessJWKSCacheState, reason string) (*jose.JSONWebKeySet, *domain.HeadlessJWKSCacheState, bool, error) {
	// Upper bound on the accepted JWKS document size.
	const maxJWKSBytes = 1024 * 1024

	jwksURI := strings.TrimSpace(client.HeadlessJWKSURI())
	if jwksURI == "" {
		return nil, nil, false, fmt.Errorf("headless login requires jwksUri; inline jwks is not supported")
	}
	previous = headlessJWKSReusablePrevious(previous, jwksURI)

	if s.RequestTimeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, s.RequestTimeout)
		defer cancel()
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, jwksURI, nil)
	if err != nil {
		return nil, s.persistRefreshFailure(client, previous, fmt.Errorf("failed to build jwks request: %w", err)), false, err
	}
	if previous != nil {
		if etag := strings.TrimSpace(previous.ETag); etag != "" {
			req.Header.Set("If-None-Match", etag)
		}
		if lastModified := strings.TrimSpace(previous.LastModified); lastModified != "" {
			req.Header.Set("If-Modified-Since", lastModified)
		}
	}
	resp, err := s.httpClient().Do(req)
	if err != nil {
		return nil, s.persistRefreshFailure(client, previous, fmt.Errorf("failed to fetch jwksUri: %w", err)), false, err
	}
	defer func() {
		// Drain a bounded remainder so the transport can reuse the connection.
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4096))
		_ = resp.Body.Close()
	}()

	now := time.Now()
	expiresAt := now.Add(s.TTL)
	if resp.StatusCode == http.StatusNotModified && previous != nil && strings.TrimSpace(previous.RawJWKS) != "" {
		updated := *previous
		updated.JWKSURI = jwksURI
		updated.LastCheckedAt = &now
		updated.ExpiresAt = &expiresAt
		updated.NextRetryAt = nil
		updated.LastRefreshStatus = "success"
		updated.LastError = ""
		updated.ConsecutiveFailures = 0
		if saveErr := s.SaveState(client.ClientID, updated); saveErr != nil {
			// Best-effort: the revalidated keys are still usable for this call.
			slog.Warn("failed to persist headless jwks cache revalidation", "clientID", client.ClientID, "error", saveErr)
		}
		keySet, decodeErr := decodeHeadlessJWKS(updated.RawJWKS)
		return keySet, &updated, true, decodeErr
	}
	if resp.StatusCode >= 300 {
		snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
		err = fmt.Errorf("failed to fetch jwksUri status=%d body=%s", resp.StatusCode, string(snippet))
		return nil, s.persistRefreshFailure(client, previous, err), false, err
	}

	// Read one byte past the limit so oversized bodies are detected instead of
	// being truncated into invalid JSON.
	body, err := io.ReadAll(io.LimitReader(resp.Body, maxJWKSBytes+1))
	if err != nil {
		return nil, s.persistRefreshFailure(client, previous, fmt.Errorf("failed to read jwks response: %w", err)), false, err
	}
	if len(body) > maxJWKSBytes {
		err = fmt.Errorf("jwks response from jwksUri exceeds %d bytes", maxJWKSBytes)
		return nil, s.persistRefreshFailure(client, previous, err), false, err
	}
	keySet, err := decodeHeadlessJWKS(string(body))
	if err != nil {
		return nil, s.persistRefreshFailure(client, previous, err), false, err
	}
	state := domain.HeadlessJWKSCacheState{
		ClientID:                     client.ClientID,
		JWKSURI:                      jwksURI,
		CachedAt:                     &now,
		ExpiresAt:                    &expiresAt,
		LastCheckedAt:                &now,
		NextRetryAt:                  nil,
		LastSuccessfulVerificationAt: previousLastVerification(previous),
		LastRefreshStatus:            "success",
		LastError:                    "",
		ConsecutiveFailures:          0,
		CachedKids:                   extractHeadlessKids(keySet),
		ETag:                         strings.TrimSpace(resp.Header.Get("ETag")),
		LastModified:                 strings.TrimSpace(resp.Header.Get("Last-Modified")),
		RawJWKS:                      string(body),
	}
	if err := s.SaveState(client.ClientID, state); err != nil {
		return nil, &state, false, err
	}
	slog.Info("headless jwks cache refreshed", "clientID", client.ClientID, "reason", reason, "keyCount", len(keySet.Keys))
	return keySet, &state, true, nil
}

// headlessJWKSReusablePrevious returns previous unchanged when it is nil or was
// recorded for jwksURI. For state recorded under a different jwksUri it returns
// a fresh record for jwksURI that keeps only the client-level
// LastSuccessfulVerificationAt. Cached keys, HTTP validators, freshness
// timestamps and failure counters are dropped, because they describe the old
// URI.
func headlessJWKSReusablePrevious(previous *domain.HeadlessJWKSCacheState, jwksURI string) *domain.HeadlessJWKSCacheState {
	if previous == nil || strings.TrimSpace(previous.JWKSURI) == jwksURI {
		return previous
	}
	return &domain.HeadlessJWKSCacheState{
		ClientID:                     previous.ClientID,
		JWKSURI:                      jwksURI,
		LastSuccessfulVerificationAt: previous.LastSuccessfulVerificationAt,
	}
}
// persistRefreshFailure records a failed refresh for client and returns the
// record. refreshErr must be non-nil.
//
// The failure count starts at 1. Cached key material (CachedAt, ExpiresAt,
// CachedKids, ETag, LastModified, RawJWKS) and the running failure count carry
// over from previous only when previous was recorded for the client's current
// jwksUri. Otherwise the new URI would be paired with the old URI's keys and
// EnsureFreshKeySet would keep serving them until they expire.
// LastSuccessfulVerificationAt always carries over.
//
// Once the failure count reaches the backoff threshold, NextRetryAt is set
// failureBackoffDuration() into the future. Persisting is best-effort: a save
// error is logged, not returned.
func (s *HeadlessJWKSCacheService) persistRefreshFailure(client domain.HydraClient, previous *domain.HeadlessJWKSCacheState, refreshErr error) *domain.HeadlessJWKSCacheState {
	now := time.Now()
	jwksURI := strings.TrimSpace(client.HeadlessJWKSURI())
	state := domain.HeadlessJWKSCacheState{
		ClientID:            client.ClientID,
		JWKSURI:             jwksURI,
		LastCheckedAt:       &now,
		LastRefreshStatus:   "failure",
		LastError:           refreshErr.Error(),
		ConsecutiveFailures: 1,
	}
	if previous != nil {
		state.LastSuccessfulVerificationAt = previous.LastSuccessfulVerificationAt
		if strings.TrimSpace(previous.JWKSURI) == jwksURI {
			state.CachedAt = previous.CachedAt
			state.ExpiresAt = previous.ExpiresAt
			state.CachedKids = previous.CachedKids
			state.ETag = previous.ETag
			state.LastModified = previous.LastModified
			state.RawJWKS = previous.RawJWKS
			state.ConsecutiveFailures = previous.ConsecutiveFailures + 1
		}
	}
	if s.shouldBackoff(state.ConsecutiveFailures) {
		retryAt := now.Add(s.failureBackoffDuration())
		state.NextRetryAt = &retryAt
	}
	if err := s.SaveState(client.ClientID, state); err != nil {
		slog.Warn("failed to persist headless jwks refresh failure", "clientID", client.ClientID, "error", err)
	}
	return &state
}
// shouldBackoff reports whether consecutiveFailures has reached the configured
// FailureThreshold (3 when the threshold is not positive).
func (s *HeadlessJWKSCacheService) shouldBackoff(consecutiveFailures int) bool {
	const defaultThreshold = 3
	limit := defaultThreshold
	if s.FailureThreshold > 0 {
		limit = s.FailureThreshold
	}
	return limit <= consecutiveFailures
}
// failureBackoffDuration returns how long to wait before retrying after the
// failure threshold is hit: FailureBackoff, or 30 minutes when not positive.
func (s *HeadlessJWKSCacheService) failureBackoffDuration() time.Duration {
	if s.FailureBackoff <= 0 {
		return 30 * time.Minute
	}
	return s.FailureBackoff
}
// decodeHeadlessJWKS parses raw as a JWKS document. It fails when the JSON is
// invalid or when the set contains no keys.
func decodeHeadlessJWKS(raw string) (*jose.JSONWebKeySet, error) {
	keySet := new(jose.JSONWebKeySet)
	if err := json.Unmarshal([]byte(raw), keySet); err != nil {
		return nil, fmt.Errorf("failed to decode jwks from jwksUri: %w", err)
	}
	if len(keySet.Keys) > 0 {
		return keySet, nil
	}
	return nil, fmt.Errorf("configured jwksUri returned no keys")
}
// headlessJWKSPreviewDocument is a plain decoding target for a raw JWKS. It is
// used only to build display summaries in summarizeHeadlessJWKS.
type headlessJWKSPreviewDocument struct {
	Keys []headlessJWKSPreviewKey `json:"keys"`
}

// headlessJWKSPreviewKey holds the JWK members surfaced in summaries; any
// other members are ignored when decoding.
type headlessJWKSPreviewKey struct {
	Kid string `json:"kid"` // key ID
	Kty string `json:"kty"` // key type
	Use string `json:"use"` // intended public-key use
	Alg string `json:"alg"` // algorithm
	N   string `json:"n"`   // RSA modulus (base64url) for RSA keys
}
// summarizeHeadlessJWKS extracts a whitespace-trimmed kid/kty/use/alg/n summary
// for every key in raw. It returns nil for blank or undecodable input, and an
// empty non-nil slice for a document without keys.
func summarizeHeadlessJWKS(raw string) []domain.HeadlessJWKSParsedKey {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return nil
	}
	var document headlessJWKSPreviewDocument
	if json.Unmarshal([]byte(trimmed), &document) != nil {
		return nil
	}
	summaries := make([]domain.HeadlessJWKSParsedKey, len(document.Keys))
	for i, key := range document.Keys {
		summaries[i] = domain.HeadlessJWKSParsedKey{
			Kid: strings.TrimSpace(key.Kid),
			Kty: strings.TrimSpace(key.Kty),
			Use: strings.TrimSpace(key.Use),
			Alg: strings.TrimSpace(key.Alg),
			N:   strings.TrimSpace(key.N),
		}
	}
	return summaries
}
// extractHeadlessKids returns the non-blank, whitespace-trimmed key IDs of
// keySet in order, or nil when keySet is nil.
func extractHeadlessKids(keySet *jose.JSONWebKeySet) []string {
	if keySet == nil {
		return nil
	}
	ids := make([]string, 0, len(keySet.Keys))
	for i := range keySet.Keys {
		id := strings.TrimSpace(keySet.Keys[i].KeyID)
		if id == "" {
			continue
		}
		ids = append(ids, id)
	}
	return ids
}
// containsString reports whether any element of values equals needle once
// both are trimmed of surrounding whitespace. A blank needle never matches.
func containsString(values []string, needle string) bool {
	target := strings.TrimSpace(needle)
	found := false
	for i := 0; target != "" && !found && i < len(values); i++ {
		found = strings.TrimSpace(values[i]) == target
	}
	return found
}
// previousLastVerification returns previous.LastSuccessfulVerificationAt, or
// nil when there is no previous state.
func previousLastVerification(previous *domain.HeadlessJWKSCacheState) *time.Time {
	var last *time.Time
	if previous != nil {
		last = previous.LastSuccessfulVerificationAt
	}
	return last
}
// ptrTime returns a pointer to a newly allocated copy of value.
//
// The body is the inlining template for `go fix` (see the directive below),
// so it is kept as the single expression new(value) (the Go 1.26 new(expr)
// form).
//
//go:fix inline
func ptrTime(value time.Time) *time.Time {
	return new(value)
}
// Start runs one prefetch pass immediately and then one pass per Interval
// until ctx is done. It returns at once when the worker, its Hydra service or
// its cache is nil. A non-positive Interval falls back to 600 seconds, the
// constructor default. This covers a worker built as a struct literal rather
// than via NewHeadlessJWKSCacheWorker; time.NewTicker panics on non-positive
// durations.
func (w *HeadlessJWKSCacheWorker) Start(ctx context.Context) {
	if w == nil || w.Hydra == nil || w.Cache == nil {
		return
	}
	interval := w.Interval
	if interval <= 0 {
		interval = 600 * time.Second
	}
	w.runOnce(ctx)
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			w.runOnce(ctx)
		}
	}
}
// runOnce performs a single prefetch pass. It pages through Hydra's clients
// (PageSize per page, 100 when not positive). For each client with headless
// login enabled, it force-refreshes the JWKS when ShouldPrefetch says so.
//
// A listing error aborts the pass; per-client load and refresh errors are
// logged and the pass continues. The pass stops early once ctx is done. Each
// client is judged against the current time, so a long pass does not compare
// expiry and backoff timestamps against a stale clock.
func (w *HeadlessJWKSCacheWorker) runOnce(ctx context.Context) {
	pageSize := w.PageSize
	if pageSize <= 0 {
		pageSize = 100
	}
	offset := 0
	for {
		if ctx.Err() != nil {
			return
		}
		clients, err := w.Hydra.ListClients(ctx, pageSize, offset)
		if err != nil {
			slog.Warn("headless jwks worker failed to list clients", "error", err)
			return
		}
		if len(clients) == 0 {
			return
		}
		for _, client := range clients {
			// Stop promptly on shutdown instead of issuing fetches that would
			// fail on the cancelled context.
			if ctx.Err() != nil {
				return
			}
			if !client.IsHeadlessLoginEnabled() {
				continue
			}
			state, err := w.Cache.GetState(client.ClientID)
			if err != nil {
				slog.Warn("headless jwks worker failed to load cache state", "clientID", client.ClientID, "error", err)
				continue
			}
			if !w.Cache.ShouldPrefetch(state, time.Now()) {
				continue
			}
			if _, err := w.Cache.ForceRefresh(ctx, client, "cron_prefetch"); err != nil {
				slog.Warn("headless jwks worker refresh failed", "clientID", client.ClientID, "error", err)
			}
		}
		// A short page means the listing is exhausted.
		if len(clients) < pageSize {
			return
		}
		offset += len(clients)
	}
}