package service import ( "baron-sso-backend/internal/domain" "context" "crypto/rand" "crypto/rsa" "encoding/json" "io" "net/http" "strings" "testing" "time" "github.com/go-jose/go-jose/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type headlessJWKSCacheTestRedis struct { data map[string]string } func (m *headlessJWKSCacheTestRedis) Set(key string, value string, expiration time.Duration) error { if m.data == nil { m.data = map[string]string{} } m.data[key] = value return nil } func (m *headlessJWKSCacheTestRedis) Get(key string) (string, error) { if m.data == nil { return "", nil } return m.data[key], nil } func (m *headlessJWKSCacheTestRedis) Delete(key string) error { if m.data != nil { delete(m.data, key) } return nil } func (m *headlessJWKSCacheTestRedis) StoreVerificationCode(phone, code string) error { return nil } func (m *headlessJWKSCacheTestRedis) GetVerificationCode(phone string) (string, error) { return "", nil } func (m *headlessJWKSCacheTestRedis) DeleteVerificationCode(phone string) error { return nil } func TestHeadlessJWKSCacheService_EnsureFreshKeySet_UsesCachedJWKSWhenFresh(t *testing.T) { _, jwks := mustServiceHeadlessRSAJWK(t, "cached-key") raw, err := json.Marshal(jwks) require.NoError(t, err) redisRepo := &headlessJWKSCacheTestRedis{} cacheService := NewHeadlessJWKSCacheService(redisRepo, nil) now := time.Now() err = cacheService.SaveState("client-headless", domain.HeadlessJWKSCacheState{ ClientID: "client-headless", JWKSURI: "https://rp.example.com/.well-known/jwks.json", RawJWKS: string(raw), CachedKids: []string{"cached-key"}, CachedAt: &now, LastCheckedAt: &now, ExpiresAt: new(now.Add(30 * time.Minute)), LastRefreshStatus: "success", ConsecutiveFailures: 0, }) require.NoError(t, err) cacheService.HTTPClient = clientForHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Fatalf("unexpected network fetch: %s", r.URL.String()) })) keySet, state, refreshed, err := cacheService.EnsureFreshKeySet(context.Background(), domain.HydraClient{ ClientID: "client-headless", Metadata: map[string]any{ domain.MetadataHeadlessLoginEnabled: true, domain.MetadataHeadlessJWKSURI: "https://rp.example.com/.well-known/jwks.json", }, }, "cached-key") require.NoError(t, err) assert.False(t, refreshed) require.NotNil(t, keySet) assert.Len(t, keySet.Keys, 1) require.NotNil(t, state) assert.Equal(t, []string{"cached-key"}, state.CachedKids) } func TestHeadlessJWKSCacheService_EnsureFreshKeySet_RefreshesWhenKidMissing(t *testing.T) { _, staleJWKS := mustServiceHeadlessRSAJWK(t, "stale-key") staleRaw, err := json.Marshal(staleJWKS) require.NoError(t, err) _, freshJWKS := mustServiceHeadlessRSAJWK(t, "fresh-key") freshRaw, err := json.Marshal(freshJWKS) require.NoError(t, err) redisRepo := &headlessJWKSCacheTestRedis{} cacheService := NewHeadlessJWKSCacheService(redisRepo, nil) now := time.Now() err = cacheService.SaveState("client-headless", domain.HeadlessJWKSCacheState{ ClientID: "client-headless", JWKSURI: "https://rp.example.com/.well-known/jwks.json", RawJWKS: string(staleRaw), CachedKids: []string{"stale-key"}, CachedAt: &now, LastCheckedAt: &now, ExpiresAt: new(now.Add(30 * time.Minute)), LastRefreshStatus: "success", }) require.NoError(t, err) cacheService.HTTPClient = clientForHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "https://rp.example.com/.well-known/jwks.json", r.URL.String()) w.Header().Set("Content-Type", "application/json") _, _ = w.Write(freshRaw) })) keySet, state, refreshed, err := cacheService.EnsureFreshKeySet(context.Background(), domain.HydraClient{ ClientID: "client-headless", Metadata: map[string]any{ domain.MetadataHeadlessLoginEnabled: true, domain.MetadataHeadlessJWKSURI: "https://rp.example.com/.well-known/jwks.json", }, }, "fresh-key") require.NoError(t, err) assert.True(t, refreshed) require.NotNil(t, keySet) assert.Len(t, keySet.Keys, 1) require.NotNil(t, state) assert.Equal(t, []string{"fresh-key"}, state.CachedKids) stored, err := cacheService.GetState("client-headless") require.NoError(t, err) require.NotNil(t, stored) assert.Equal(t, []string{"fresh-key"}, stored.CachedKids) } func TestHeadlessJWKSCacheService_PersistRefreshFailure_SetsNextRetryAtAfterThreshold(t *testing.T) { redisRepo := &headlessJWKSCacheTestRedis{} cacheService := NewHeadlessJWKSCacheService(redisRepo, nil) cacheService.FailureThreshold = 3 cacheService.FailureBackoff = 15 * time.Minute client := domain.HydraClient{ ClientID: "client-headless", Metadata: map[string]any{ domain.MetadataHeadlessLoginEnabled: true, domain.MetadataHeadlessJWKSURI: "https://rp.example.com/.well-known/jwks.json", }, } previous := &domain.HeadlessJWKSCacheState{ ClientID: client.ClientID, JWKSURI: "https://rp.example.com/.well-known/jwks.json", LastRefreshStatus: "failure", ConsecutiveFailures: 2, } state := cacheService.persistRefreshFailure(client, previous, assert.AnError) require.NotNil(t, state) assert.Equal(t, 3, state.ConsecutiveFailures) require.NotNil(t, state.NextRetryAt) assert.WithinDuration(t, time.Now().Add(15*time.Minute), *state.NextRetryAt, 3*time.Second) } func TestHeadlessJWKSCacheService_ShouldPrefetch_SkipsUntilNextRetryAt(t *testing.T) { cacheService := NewHeadlessJWKSCacheService(&headlessJWKSCacheTestRedis{}, nil) now := time.Now() state := &domain.HeadlessJWKSCacheState{ ClientID: "client-headless", LastRefreshStatus: "failure", ConsecutiveFailures: 3, NextRetryAt: new(now.Add(10 * time.Minute)), } assert.False(t, cacheService.ShouldPrefetch(state, now)) assert.True(t, cacheService.ShouldPrefetch(state, now.Add(11*time.Minute))) } func TestHeadlessJWKSCacheWorker_RunOnce_SkipsBackoffTargets(t *testing.T) { clients := []domain.HydraClient{ newTestHeadlessClient("client-fail", "https://fail.example.com/.well-known/jwks.json"), newTestHeadlessClient("client-skip", "https://skip.example.com/.well-known/jwks.json"), } hydra := &HydraAdminService{ AdminURL: "http://hydra.test", HTTPClient: clientForHandler(jsonHandler(t, clients)), } redisRepo := &headlessJWKSCacheTestRedis{} cacheService := NewHeadlessJWKSCacheService(redisRepo, nil) cacheService.FailureThreshold = 3 cacheService.FailureBackoff = 15 * time.Minute now := time.Now() require.NoError(t, cacheService.SaveState("client-fail", domain.HeadlessJWKSCacheState{ ClientID: "client-fail", JWKSURI: clients[0].HeadlessJWKSURI(), LastRefreshStatus: "failure", ConsecutiveFailures: 2, })) require.NoError(t, cacheService.SaveState("client-skip", domain.HeadlessJWKSCacheState{ ClientID: "client-skip", JWKSURI: clients[1].HeadlessJWKSURI(), LastRefreshStatus: "failure", ConsecutiveFailures: 3, NextRetryAt: new(now.Add(10 * time.Minute)), })) fetchCounts := map[string]int{} cacheService.HTTPClient = &http.Client{ Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { fetchCounts[req.URL.Host]++ if req.URL.Host == "fail.example.com" { return jsonHTTPResponse(http.StatusInternalServerError, `{"error":"boom"}`), nil } t.Fatalf("unexpected fetch for host %s", req.URL.Host) return nil, nil }), } worker := &HeadlessJWKSCacheWorker{ Hydra: hydra, Cache: cacheService, PageSize: 100, } worker.runOnce(context.Background()) assert.Equal(t, 1, fetchCounts["fail.example.com"]) assert.Equal(t, 0, fetchCounts["skip.example.com"]) failedState, err := cacheService.GetState("client-fail") require.NoError(t, err) require.NotNil(t, failedState) assert.Equal(t, 3, failedState.ConsecutiveFailures) require.NotNil(t, failedState.NextRetryAt) skippedState, err := cacheService.GetState("client-skip") require.NoError(t, err) require.NotNil(t, skippedState) assert.Equal(t, 3, skippedState.ConsecutiveFailures) require.NotNil(t, skippedState.NextRetryAt) assert.WithinDuration(t, now.Add(10*time.Minute), *skippedState.NextRetryAt, time.Second) } func TestHeadlessJWKSCacheWorker_RunOnce_RetriesAfterBackoffAndClearsFailureStateOnSuccess(t *testing.T) { _, freshJWKS := mustServiceHeadlessRSAJWK(t, "fresh-key") freshRaw, err := json.Marshal(freshJWKS) require.NoError(t, err) client := newTestHeadlessClient("client-recover", "https://recover.example.com/.well-known/jwks.json") hydra := &HydraAdminService{ AdminURL: "http://hydra.test", HTTPClient: clientForHandler(jsonHandler(t, []domain.HydraClient{client})), } redisRepo := &headlessJWKSCacheTestRedis{} cacheService := NewHeadlessJWKSCacheService(redisRepo, nil) cacheService.FailureThreshold = 3 cacheService.FailureBackoff = 15 * time.Minute require.NoError(t, cacheService.SaveState("client-recover", domain.HeadlessJWKSCacheState{ ClientID: "client-recover", JWKSURI: client.HeadlessJWKSURI(), LastRefreshStatus: "failure", LastError: "previous failure", ConsecutiveFailures: 3, NextRetryAt: new(time.Now().Add(-time.Minute)), })) fetchCount := 0 cacheService.HTTPClient = &http.Client{ Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { fetchCount++ assert.Equal(t, "recover.example.com", req.URL.Host) return jsonHTTPResponse(http.StatusOK, string(freshRaw)), nil }), } worker := &HeadlessJWKSCacheWorker{ Hydra: hydra, Cache: cacheService, PageSize: 100, } worker.runOnce(context.Background()) assert.Equal(t, 1, fetchCount) recoveredState, err := cacheService.GetState("client-recover") require.NoError(t, err) require.NotNil(t, recoveredState) assert.Equal(t, "success", recoveredState.LastRefreshStatus) assert.Empty(t, recoveredState.LastError) assert.Equal(t, 0, recoveredState.ConsecutiveFailures) assert.Nil(t, recoveredState.NextRetryAt) assert.Equal(t, []string{"fresh-key"}, recoveredState.CachedKids) } func TestHeadlessJWKSCacheWorker_RunOnce_MixedClients(t *testing.T) { _, successJWKS := mustServiceHeadlessRSAJWK(t, "success-key") successRaw, err := json.Marshal(successJWKS) require.NoError(t, err) successClient := newTestHeadlessClient("client-success", "https://success.example.com/.well-known/jwks.json") failClient := newTestHeadlessClient("client-fail", "https://fail.example.com/.well-known/jwks.json") skipClient := newTestHeadlessClient("client-skip", "https://skip.example.com/.well-known/jwks.json") disabledClient := domain.HydraClient{ ClientID: "client-disabled", Metadata: map[string]any{ domain.MetadataHeadlessLoginEnabled: false, domain.MetadataHeadlessJWKSURI: "https://disabled.example.com/.well-known/jwks.json", domain.MetadataHeadlessTokenEndpointAuthMethod: "private_key_jwt", }, } hydra := &HydraAdminService{ AdminURL: "http://hydra.test", HTTPClient: clientForHandler(jsonHandler(t, []domain.HydraClient{ successClient, failClient, skipClient, disabledClient, })), } redisRepo := &headlessJWKSCacheTestRedis{} cacheService := NewHeadlessJWKSCacheService(redisRepo, nil) cacheService.FailureThreshold = 3 cacheService.FailureBackoff = 20 * time.Minute require.NoError(t, cacheService.SaveState("client-fail", domain.HeadlessJWKSCacheState{ ClientID: "client-fail", JWKSURI: failClient.HeadlessJWKSURI(), LastRefreshStatus: "failure", ConsecutiveFailures: 2, })) require.NoError(t, cacheService.SaveState("client-skip", domain.HeadlessJWKSCacheState{ ClientID: "client-skip", JWKSURI: skipClient.HeadlessJWKSURI(), LastRefreshStatus: "failure", ConsecutiveFailures: 3, NextRetryAt: new(time.Now().Add(10 * time.Minute)), })) fetchCounts := map[string]int{} cacheService.HTTPClient = &http.Client{ Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { fetchCounts[req.URL.Host]++ switch req.URL.Host { case "success.example.com": return jsonHTTPResponse(http.StatusOK, string(successRaw)), nil case "fail.example.com": return jsonHTTPResponse(http.StatusInternalServerError, `{"error":"boom"}`), nil default: t.Fatalf("unexpected fetch for host %s", req.URL.Host) return nil, nil } }), } worker := &HeadlessJWKSCacheWorker{ Hydra: hydra, Cache: cacheService, PageSize: 100, } worker.runOnce(context.Background()) assert.Equal(t, 1, fetchCounts["success.example.com"]) assert.Equal(t, 1, fetchCounts["fail.example.com"]) assert.Equal(t, 0, fetchCounts["skip.example.com"]) assert.Equal(t, 0, fetchCounts["disabled.example.com"]) successState, err := cacheService.GetState("client-success") require.NoError(t, err) require.NotNil(t, successState) assert.Equal(t, "success", successState.LastRefreshStatus) assert.Equal(t, 0, successState.ConsecutiveFailures) assert.Nil(t, successState.NextRetryAt) failState, err := cacheService.GetState("client-fail") require.NoError(t, err) require.NotNil(t, failState) assert.Equal(t, "failure", failState.LastRefreshStatus) assert.Equal(t, 3, failState.ConsecutiveFailures) require.NotNil(t, failState.NextRetryAt) } func mustServiceHeadlessRSAJWK(t *testing.T, kid string) (*rsa.PrivateKey, jose.JSONWebKeySet) { t.Helper() privateKey, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) publicJWK := jose.JSONWebKey{ Key: &privateKey.PublicKey, KeyID: kid, Algorithm: string(jose.RS256), Use: "sig", } return privateKey, jose.JSONWebKeySet{Keys: []jose.JSONWebKey{publicJWK}} } //go:fix inline func ptrTestTime(value time.Time) *time.Time { return new(value) } func newTestHeadlessClient(clientID, jwksURI string) domain.HydraClient { return domain.HydraClient{ ClientID: clientID, Metadata: map[string]any{ domain.MetadataHeadlessLoginEnabled: true, domain.MetadataHeadlessJWKSURI: jwksURI, domain.MetadataHeadlessTokenEndpointAuthMethod: "private_key_jwt", }, } } func jsonHandler(t *testing.T, payload any) http.HandlerFunc { t.Helper() return func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "/clients", r.URL.Path) w.Header().Set("Content-Type", "application/json") require.NoError(t, json.NewEncoder(w).Encode(payload)) } } func jsonHTTPResponse(status int, body string) *http.Response { return &http.Response{ StatusCode: status, Header: make(http.Header), Body: io.NopCloser(strings.NewReader(body)), } }