package service import ( "baron-sso-backend/internal/domain" "context" "crypto/rand" "crypto/rsa" "encoding/json" "net/http" "testing" "time" "github.com/go-jose/go-jose/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type headlessJWKSCacheTestRedis struct { data map[string]string } func (m *headlessJWKSCacheTestRedis) Set(key string, value string, expiration time.Duration) error { if m.data == nil { m.data = map[string]string{} } m.data[key] = value return nil } func (m *headlessJWKSCacheTestRedis) Get(key string) (string, error) { if m.data == nil { return "", nil } return m.data[key], nil } func (m *headlessJWKSCacheTestRedis) Delete(key string) error { if m.data != nil { delete(m.data, key) } return nil } func (m *headlessJWKSCacheTestRedis) StoreVerificationCode(phone, code string) error { return nil } func (m *headlessJWKSCacheTestRedis) GetVerificationCode(phone string) (string, error) { return "", nil } func (m *headlessJWKSCacheTestRedis) DeleteVerificationCode(phone string) error { return nil } func TestHeadlessJWKSCacheService_EnsureFreshKeySet_UsesCachedJWKSWhenFresh(t *testing.T) { _, jwks := mustServiceHeadlessRSAJWK(t, "cached-key") raw, err := json.Marshal(jwks) require.NoError(t, err) redisRepo := &headlessJWKSCacheTestRedis{} cacheService := NewHeadlessJWKSCacheService(redisRepo, nil) now := time.Now() err = cacheService.SaveState("client-headless", domain.HeadlessJWKSCacheState{ ClientID: "client-headless", JWKSURI: "https://rp.example.com/.well-known/jwks.json", RawJWKS: string(raw), CachedKids: []string{"cached-key"}, CachedAt: &now, LastCheckedAt: &now, ExpiresAt: ptrTestTime(now.Add(30 * time.Minute)), LastRefreshStatus: "success", ConsecutiveFailures: 0, }) require.NoError(t, err) cacheService.HTTPClient = clientForHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Fatalf("unexpected network fetch: %s", r.URL.String()) })) keySet, state, refreshed, err := cacheService.EnsureFreshKeySet(context.Background(), domain.HydraClient{ ClientID: "client-headless", Metadata: map[string]any{ domain.MetadataHeadlessLoginEnabled: true, domain.MetadataHeadlessJWKSURI: "https://rp.example.com/.well-known/jwks.json", }, }, "cached-key") require.NoError(t, err) assert.False(t, refreshed) require.NotNil(t, keySet) assert.Len(t, keySet.Keys, 1) require.NotNil(t, state) assert.Equal(t, []string{"cached-key"}, state.CachedKids) } func TestHeadlessJWKSCacheService_EnsureFreshKeySet_RefreshesWhenKidMissing(t *testing.T) { _, staleJWKS := mustServiceHeadlessRSAJWK(t, "stale-key") staleRaw, err := json.Marshal(staleJWKS) require.NoError(t, err) _, freshJWKS := mustServiceHeadlessRSAJWK(t, "fresh-key") freshRaw, err := json.Marshal(freshJWKS) require.NoError(t, err) redisRepo := &headlessJWKSCacheTestRedis{} cacheService := NewHeadlessJWKSCacheService(redisRepo, nil) now := time.Now() err = cacheService.SaveState("client-headless", domain.HeadlessJWKSCacheState{ ClientID: "client-headless", JWKSURI: "https://rp.example.com/.well-known/jwks.json", RawJWKS: string(staleRaw), CachedKids: []string{"stale-key"}, CachedAt: &now, LastCheckedAt: &now, ExpiresAt: ptrTestTime(now.Add(30 * time.Minute)), LastRefreshStatus: "success", }) require.NoError(t, err) cacheService.HTTPClient = clientForHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "https://rp.example.com/.well-known/jwks.json", r.URL.String()) w.Header().Set("Content-Type", "application/json") _, _ = w.Write(freshRaw) })) keySet, state, refreshed, err := cacheService.EnsureFreshKeySet(context.Background(), domain.HydraClient{ ClientID: "client-headless", Metadata: map[string]any{ domain.MetadataHeadlessLoginEnabled: true, domain.MetadataHeadlessJWKSURI: "https://rp.example.com/.well-known/jwks.json", }, }, "fresh-key") require.NoError(t, err) assert.True(t, refreshed) require.NotNil(t, keySet) assert.Len(t, keySet.Keys, 1) require.NotNil(t, state) assert.Equal(t, []string{"fresh-key"}, state.CachedKids) stored, err := cacheService.GetState("client-headless") require.NoError(t, err) require.NotNil(t, stored) assert.Equal(t, []string{"fresh-key"}, stored.CachedKids) } func mustServiceHeadlessRSAJWK(t *testing.T, kid string) (*rsa.PrivateKey, jose.JSONWebKeySet) { t.Helper() privateKey, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) publicJWK := jose.JSONWebKey{ Key: &privateKey.PublicKey, KeyID: kid, Algorithm: string(jose.RS256), Use: "sig", } return privateKey, jose.JSONWebKeySet{Keys: []jose.JSONWebKey{publicJWK}} } func ptrTestTime(value time.Time) *time.Time { return &value }