1
0
forked from baron/baron-sso
Files
baron-sso/backend/internal/service/headless_jwks_cache_test.go
chan 31d107ff2e feat(user): support fixed UUID registration and enhance bulk import results
- Added support for fixed UUIDs during bulk registration (Search-first + ExternalID mapping)
- Implemented idempotency and visibility restoration for soft-deleted users
- Enhanced bulk upload UI to show 'New/Updated/Unchanged' status and modified fields
- Added logic to reclaim identifiers (login_id) from colliding records
- Added frontend E2E and backend unit tests for UUID integrity and conflict handling
- Fixed i18n, formatting, and mock tests to satisfy code-check
- Applied 'go fix' for 'omitzero' tags and general Go standards
2026-06-01 15:34:08 +09:00

451 lines
14 KiB
Go

package service
import (
"baron-sso-backend/internal/domain"
"context"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/go-jose/go-jose/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// headlessJWKSCacheTestRedis is an in-memory fake of the Redis repository
// handed to NewHeadlessJWKSCacheService in these tests. Only the plain
// key/value methods keep state; the verification-code methods are inert stubs.
// The zero value is ready to use.
type headlessJWKSCacheTestRedis struct {
	data map[string]string
}

// Set stores value under key, allocating the backing map on first write.
// expiration is accepted for interface compatibility but not enforced.
func (f *headlessJWKSCacheTestRedis) Set(key string, value string, expiration time.Duration) error {
	if f.data == nil {
		f.data = make(map[string]string)
	}
	f.data[key] = value
	return nil
}

// Get returns the value stored under key. A missing key (including before the
// first Set) yields "" and a nil error; indexing a nil map is safe in Go.
func (f *headlessJWKSCacheTestRedis) Get(key string) (string, error) {
	return f.data[key], nil
}

// Delete removes key if present. Deleting from a nil map is a no-op in Go.
func (f *headlessJWKSCacheTestRedis) Delete(key string) error {
	delete(f.data, key)
	return nil
}

// StoreVerificationCode is a no-op stub that always succeeds.
func (f *headlessJWKSCacheTestRedis) StoreVerificationCode(phone, code string) error {
	return nil
}

// GetVerificationCode is a stub that always reports an empty code.
func (f *headlessJWKSCacheTestRedis) GetVerificationCode(phone string) (string, error) {
	return "", nil
}

// DeleteVerificationCode is a no-op stub that always succeeds.
func (f *headlessJWKSCacheTestRedis) DeleteVerificationCode(phone string) error {
	return nil
}
// TestHeadlessJWKSCacheService_EnsureFreshKeySet_UsesCachedJWKSWhenFresh
// verifies that an unexpired cached JWKS that already contains the requested
// kid is served from the cache: no network fetch happens, refreshed is false,
// and the returned key set is the cached one.
func TestHeadlessJWKSCacheService_EnsureFreshKeySet_UsesCachedJWKSWhenFresh(t *testing.T) {
	_, jwks := mustServiceHeadlessRSAJWK(t, "cached-key")
	raw, err := json.Marshal(jwks)
	require.NoError(t, err)
	redisRepo := &headlessJWKSCacheTestRedis{}
	cacheService := NewHeadlessJWKSCacheService(redisRepo, nil)
	now := time.Now()
	err = cacheService.SaveState("client-headless", domain.HeadlessJWKSCacheState{
		ClientID:            "client-headless",
		JWKSURI:             "https://rp.example.com/.well-known/jwks.json",
		RawJWKS:             string(raw),
		CachedKids:          []string{"cached-key"},
		CachedAt:            &now,
		LastCheckedAt:       &now,
		ExpiresAt:           new(now.Add(30 * time.Minute)),
		LastRefreshStatus:   "success",
		ConsecutiveFailures: 0,
	})
	require.NoError(t, err)
	// Any fetch is a test failure. Use t.Errorf rather than t.Fatalf because
	// FailNow may only be called from the test goroutine, and the handler may
	// run elsewhere depending on how clientForHandler dispatches requests.
	cacheService.HTTPClient = clientForHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t.Errorf("unexpected network fetch: %s", r.URL.String())
		http.Error(w, "unexpected network fetch", http.StatusInternalServerError)
	}))
	keySet, state, refreshed, err := cacheService.EnsureFreshKeySet(context.Background(), domain.HydraClient{
		ClientID: "client-headless",
		Metadata: map[string]any{
			domain.MetadataHeadlessLoginEnabled: true,
			domain.MetadataHeadlessJWKSURI:      "https://rp.example.com/.well-known/jwks.json",
		},
	}, "cached-key")
	require.NoError(t, err)
	assert.False(t, refreshed)
	require.NotNil(t, keySet)
	require.Len(t, keySet.Keys, 1)
	// Pin the identity of the key, not just the count.
	assert.Equal(t, "cached-key", keySet.Keys[0].KeyID)
	require.NotNil(t, state)
	assert.Equal(t, []string{"cached-key"}, state.CachedKids)
}
// TestHeadlessJWKSCacheService_EnsureFreshKeySet_RefreshesWhenKidMissing
// verifies that an unexpired cache is bypassed when the requested kid is not
// among the cached kids. The JWKS URI is re-fetched, the fresh key set is
// returned with refreshed=true, and the new state is persisted.
func TestHeadlessJWKSCacheService_EnsureFreshKeySet_RefreshesWhenKidMissing(t *testing.T) {
	_, staleJWKS := mustServiceHeadlessRSAJWK(t, "stale-key")
	staleRaw, err := json.Marshal(staleJWKS)
	require.NoError(t, err)
	_, freshJWKS := mustServiceHeadlessRSAJWK(t, "fresh-key")
	freshRaw, err := json.Marshal(freshJWKS)
	require.NoError(t, err)
	redisRepo := &headlessJWKSCacheTestRedis{}
	cacheService := NewHeadlessJWKSCacheService(redisRepo, nil)
	now := time.Now()
	err = cacheService.SaveState("client-headless", domain.HeadlessJWKSCacheState{
		ClientID:          "client-headless",
		JWKSURI:           "https://rp.example.com/.well-known/jwks.json",
		RawJWKS:           string(staleRaw),
		CachedKids:        []string{"stale-key"},
		CachedAt:          &now,
		LastCheckedAt:     &now,
		ExpiresAt:         new(now.Add(30 * time.Minute)),
		LastRefreshStatus: "success",
	})
	require.NoError(t, err)
	cacheService.HTTPClient = clientForHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "https://rp.example.com/.well-known/jwks.json", r.URL.String())
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write(freshRaw) // write errors surface as a decode failure in the service
	}))
	keySet, state, refreshed, err := cacheService.EnsureFreshKeySet(context.Background(), domain.HydraClient{
		ClientID: "client-headless",
		Metadata: map[string]any{
			domain.MetadataHeadlessLoginEnabled: true,
			domain.MetadataHeadlessJWKSURI:      "https://rp.example.com/.well-known/jwks.json",
		},
	}, "fresh-key")
	require.NoError(t, err)
	assert.True(t, refreshed)
	require.NotNil(t, keySet)
	require.Len(t, keySet.Keys, 1)
	// The stale set also has exactly one key, so check which key came back.
	assert.Equal(t, "fresh-key", keySet.Keys[0].KeyID)
	require.NotNil(t, state)
	assert.Equal(t, []string{"fresh-key"}, state.CachedKids)
	stored, err := cacheService.GetState("client-headless")
	require.NoError(t, err)
	require.NotNil(t, stored)
	assert.Equal(t, []string{"fresh-key"}, stored.CachedKids)
}
func TestHeadlessJWKSCacheService_PersistRefreshFailure_SetsNextRetryAtAfterThreshold(t *testing.T) {
redisRepo := &headlessJWKSCacheTestRedis{}
cacheService := NewHeadlessJWKSCacheService(redisRepo, nil)
cacheService.FailureThreshold = 3
cacheService.FailureBackoff = 15 * time.Minute
client := domain.HydraClient{
ClientID: "client-headless",
Metadata: map[string]any{
domain.MetadataHeadlessLoginEnabled: true,
domain.MetadataHeadlessJWKSURI: "https://rp.example.com/.well-known/jwks.json",
},
}
previous := &domain.HeadlessJWKSCacheState{
ClientID: client.ClientID,
JWKSURI: "https://rp.example.com/.well-known/jwks.json",
LastRefreshStatus: "failure",
ConsecutiveFailures: 2,
}
state := cacheService.persistRefreshFailure(client, previous, assert.AnError)
require.NotNil(t, state)
assert.Equal(t, 3, state.ConsecutiveFailures)
require.NotNil(t, state.NextRetryAt)
assert.WithinDuration(t, time.Now().Add(15*time.Minute), *state.NextRetryAt, 3*time.Second)
}
// TestHeadlessJWKSCacheService_ShouldPrefetch_SkipsUntilNextRetryAt checks
// that a client in failure backoff is not prefetched before its NextRetryAt
// but is prefetched once that time has passed.
func TestHeadlessJWKSCacheService_ShouldPrefetch_SkipsUntilNextRetryAt(t *testing.T) {
	svc := NewHeadlessJWKSCacheService(&headlessJWKSCacheTestRedis{}, nil)

	now := time.Now()
	retryAt := now.Add(10 * time.Minute)
	backedOff := &domain.HeadlessJWKSCacheState{
		ClientID:            "client-headless",
		LastRefreshStatus:   "failure",
		ConsecutiveFailures: 3,
		NextRetryAt:         &retryAt,
	}

	beforeRetry, afterRetry := now, now.Add(11*time.Minute)
	assert.False(t, svc.ShouldPrefetch(backedOff, beforeRetry))
	assert.True(t, svc.ShouldPrefetch(backedOff, afterRetry))
}
// TestHeadlessJWKSCacheWorker_RunOnce_SkipsBackoffTargets checks two things
// in a single worker pass:
//   - a client below the failure threshold is fetched, and the failure
//     pushes it over the threshold into backoff;
//   - a client whose NextRetryAt is still in the future is not fetched and
//     its state is left unchanged.
func TestHeadlessJWKSCacheWorker_RunOnce_SkipsBackoffTargets(t *testing.T) {
	clients := []domain.HydraClient{
		newTestHeadlessClient("client-fail", "https://fail.example.com/.well-known/jwks.json"),
		newTestHeadlessClient("client-skip", "https://skip.example.com/.well-known/jwks.json"),
	}
	hydra := &HydraAdminService{
		AdminURL:   "http://hydra.test",
		HTTPClient: clientForHandler(jsonHandler(t, clients)),
	}
	redisRepo := &headlessJWKSCacheTestRedis{}
	cacheService := NewHeadlessJWKSCacheService(redisRepo, nil)
	cacheService.FailureThreshold = 3
	cacheService.FailureBackoff = 15 * time.Minute
	now := time.Now()
	require.NoError(t, cacheService.SaveState("client-fail", domain.HeadlessJWKSCacheState{
		ClientID:            "client-fail",
		JWKSURI:             clients[0].HeadlessJWKSURI(),
		LastRefreshStatus:   "failure",
		ConsecutiveFailures: 2,
	}))
	require.NoError(t, cacheService.SaveState("client-skip", domain.HeadlessJWKSCacheState{
		ClientID:            "client-skip",
		JWKSURI:             clients[1].HeadlessJWKSURI(),
		LastRefreshStatus:   "failure",
		ConsecutiveFailures: 3,
		NextRetryAt:         new(now.Add(10 * time.Minute)),
	}))
	fetchCounts := map[string]int{}
	cacheService.HTTPClient = &http.Client{
		Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
			fetchCounts[req.URL.Host]++
			if req.URL.Host == "fail.example.com" {
				return jsonHTTPResponse(http.StatusInternalServerError, `{"error":"boom"}`), nil
			}
			// Report with t.Errorf and fail the request instead of calling t.Fatalf:
			// FailNow is only valid on the test goroutine, and RoundTrip may be
			// invoked from another one depending on how the worker schedules fetches.
			t.Errorf("unexpected fetch for host %s", req.URL.Host)
			return nil, assert.AnError
		}),
	}
	worker := &HeadlessJWKSCacheWorker{
		Hydra:    hydra,
		Cache:    cacheService,
		PageSize: 100,
	}
	worker.runOnce(context.Background())
	assert.Equal(t, 1, fetchCounts["fail.example.com"])
	assert.Equal(t, 0, fetchCounts["skip.example.com"])
	failedState, err := cacheService.GetState("client-fail")
	require.NoError(t, err)
	require.NotNil(t, failedState)
	assert.Equal(t, "failure", failedState.LastRefreshStatus)
	assert.Equal(t, 3, failedState.ConsecutiveFailures)
	require.NotNil(t, failedState.NextRetryAt)
	// Reaching the threshold should schedule the next retry one FailureBackoff out.
	assert.WithinDuration(t, now.Add(15*time.Minute), *failedState.NextRetryAt, 5*time.Second)
	skippedState, err := cacheService.GetState("client-skip")
	require.NoError(t, err)
	require.NotNil(t, skippedState)
	assert.Equal(t, 3, skippedState.ConsecutiveFailures)
	require.NotNil(t, skippedState.NextRetryAt)
	assert.WithinDuration(t, now.Add(10*time.Minute), *skippedState.NextRetryAt, time.Second)
}
// TestHeadlessJWKSCacheWorker_RunOnce_RetriesAfterBackoffAndClearsFailureStateOnSuccess
// covers recovery. A client whose backoff window has already expired is
// fetched again by the worker. A successful fetch resets the failure
// bookkeeping (status, error, counter, NextRetryAt) and caches the new kid.
func TestHeadlessJWKSCacheWorker_RunOnce_RetriesAfterBackoffAndClearsFailureStateOnSuccess(t *testing.T) {
	_, jwks := mustServiceHeadlessRSAJWK(t, "fresh-key")
	payload, err := json.Marshal(jwks)
	require.NoError(t, err)

	recovering := newTestHeadlessClient("client-recover", "https://recover.example.com/.well-known/jwks.json")
	hydra := &HydraAdminService{
		AdminURL:   "http://hydra.test",
		HTTPClient: clientForHandler(jsonHandler(t, []domain.HydraClient{recovering})),
	}

	cache := NewHeadlessJWKSCacheService(&headlessJWKSCacheTestRedis{}, nil)
	cache.FailureThreshold = 3
	cache.FailureBackoff = 15 * time.Minute

	// Already at the threshold, but the scheduled retry time is a minute in the past.
	retryAt := time.Now().Add(-time.Minute)
	require.NoError(t, cache.SaveState("client-recover", domain.HeadlessJWKSCacheState{
		ClientID:            "client-recover",
		JWKSURI:             recovering.HeadlessJWKSURI(),
		LastRefreshStatus:   "failure",
		LastError:           "previous failure",
		ConsecutiveFailures: 3,
		NextRetryAt:         &retryAt,
	}))

	var fetches int
	cache.HTTPClient = &http.Client{
		Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
			fetches++
			assert.Equal(t, "recover.example.com", req.URL.Host)
			return jsonHTTPResponse(http.StatusOK, string(payload)), nil
		}),
	}

	worker := &HeadlessJWKSCacheWorker{Hydra: hydra, Cache: cache, PageSize: 100}
	worker.runOnce(context.Background())

	assert.Equal(t, 1, fetches)

	got, err := cache.GetState("client-recover")
	require.NoError(t, err)
	require.NotNil(t, got)
	assert.Equal(t, "success", got.LastRefreshStatus)
	assert.Empty(t, got.LastError)
	assert.Equal(t, 0, got.ConsecutiveFailures)
	assert.Nil(t, got.NextRetryAt)
	assert.Equal(t, []string{"fresh-key"}, got.CachedKids)
}
// TestHeadlessJWKSCacheWorker_RunOnce_MixedClients runs one worker pass over
// four clients at once:
//   - one that refreshes successfully;
//   - one that fails and crosses the failure threshold;
//   - one still inside its backoff window;
//   - one with headless login disabled.
//
// Only the first two may be fetched, and each client's persisted state must
// reflect its own outcome.
func TestHeadlessJWKSCacheWorker_RunOnce_MixedClients(t *testing.T) {
	_, successJWKS := mustServiceHeadlessRSAJWK(t, "success-key")
	successRaw, err := json.Marshal(successJWKS)
	require.NoError(t, err)
	successClient := newTestHeadlessClient("client-success", "https://success.example.com/.well-known/jwks.json")
	failClient := newTestHeadlessClient("client-fail", "https://fail.example.com/.well-known/jwks.json")
	skipClient := newTestHeadlessClient("client-skip", "https://skip.example.com/.well-known/jwks.json")
	disabledClient := domain.HydraClient{
		ClientID: "client-disabled",
		Metadata: map[string]any{
			domain.MetadataHeadlessLoginEnabled:            false,
			domain.MetadataHeadlessJWKSURI:                 "https://disabled.example.com/.well-known/jwks.json",
			domain.MetadataHeadlessTokenEndpointAuthMethod: "private_key_jwt",
		},
	}
	hydra := &HydraAdminService{
		AdminURL: "http://hydra.test",
		HTTPClient: clientForHandler(jsonHandler(t, []domain.HydraClient{
			successClient,
			failClient,
			skipClient,
			disabledClient,
		})),
	}
	redisRepo := &headlessJWKSCacheTestRedis{}
	cacheService := NewHeadlessJWKSCacheService(redisRepo, nil)
	cacheService.FailureThreshold = 3
	// This differs from the 15m used by the other tests. The NextRetryAt
	// assertion below checks that this configured value is what gets applied.
	cacheService.FailureBackoff = 20 * time.Minute
	now := time.Now()
	require.NoError(t, cacheService.SaveState("client-fail", domain.HeadlessJWKSCacheState{
		ClientID:            "client-fail",
		JWKSURI:             failClient.HeadlessJWKSURI(),
		LastRefreshStatus:   "failure",
		ConsecutiveFailures: 2,
	}))
	require.NoError(t, cacheService.SaveState("client-skip", domain.HeadlessJWKSCacheState{
		ClientID:            "client-skip",
		JWKSURI:             skipClient.HeadlessJWKSURI(),
		LastRefreshStatus:   "failure",
		ConsecutiveFailures: 3,
		NextRetryAt:         new(now.Add(10 * time.Minute)),
	}))
	fetchCounts := map[string]int{}
	cacheService.HTTPClient = &http.Client{
		Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
			fetchCounts[req.URL.Host]++
			switch req.URL.Host {
			case "success.example.com":
				return jsonHTTPResponse(http.StatusOK, string(successRaw)), nil
			case "fail.example.com":
				return jsonHTTPResponse(http.StatusInternalServerError, `{"error":"boom"}`), nil
			default:
				// t.Errorf, not t.Fatalf: FailNow is only valid on the test
				// goroutine, and RoundTrip may be invoked from another one.
				t.Errorf("unexpected fetch for host %s", req.URL.Host)
				return nil, assert.AnError
			}
		}),
	}
	worker := &HeadlessJWKSCacheWorker{
		Hydra:    hydra,
		Cache:    cacheService,
		PageSize: 100,
	}
	worker.runOnce(context.Background())
	assert.Equal(t, 1, fetchCounts["success.example.com"])
	assert.Equal(t, 1, fetchCounts["fail.example.com"])
	assert.Equal(t, 0, fetchCounts["skip.example.com"])
	assert.Equal(t, 0, fetchCounts["disabled.example.com"])

	successState, err := cacheService.GetState("client-success")
	require.NoError(t, err)
	require.NotNil(t, successState)
	assert.Equal(t, "success", successState.LastRefreshStatus)
	assert.Equal(t, 0, successState.ConsecutiveFailures)
	assert.Nil(t, successState.NextRetryAt)
	assert.Equal(t, []string{"success-key"}, successState.CachedKids)

	failState, err := cacheService.GetState("client-fail")
	require.NoError(t, err)
	require.NotNil(t, failState)
	assert.Equal(t, "failure", failState.LastRefreshStatus)
	assert.Equal(t, 3, failState.ConsecutiveFailures)
	require.NotNil(t, failState.NextRetryAt)
	assert.WithinDuration(t, now.Add(20*time.Minute), *failState.NextRetryAt, 5*time.Second)

	// The backed-off client must come out of the pass untouched.
	skipState, err := cacheService.GetState("client-skip")
	require.NoError(t, err)
	require.NotNil(t, skipState)
	assert.Equal(t, 3, skipState.ConsecutiveFailures)
	require.NotNil(t, skipState.NextRetryAt)
	assert.WithinDuration(t, now.Add(10*time.Minute), *skipState.NextRetryAt, time.Second)
}
// mustServiceHeadlessRSAJWK generates a fresh 2048-bit RSA key pair. It
// returns the private key and a one-entry JWKS that publishes the public half
// under kid with alg RS256 and use "sig". The test fails immediately if key
// generation returns an error.
func mustServiceHeadlessRSAJWK(t *testing.T, kid string) (*rsa.PrivateKey, jose.JSONWebKeySet) {
	t.Helper()

	key, err := rsa.GenerateKey(rand.Reader, 2048)
	require.NoError(t, err)

	return key, jose.JSONWebKeySet{
		Keys: []jose.JSONWebKey{{
			Key:       key.Public(), // same *rsa.PublicKey as &key.PublicKey
			KeyID:     kid,
			Algorithm: string(jose.RS256),
			Use:       "sig",
		}},
	}
}
// ptrTestTime returns a pointer to a newly allocated copy of value.
//
// The //go:fix inline directive lets `go fix` rewrite calls to this helper as
// its body, new(value). No calls appear in this part of the file; any that
// remain are presumably elsewhere in the package.
//
//go:fix inline
func ptrTestTime(value time.Time) *time.Time {
	return new(value)
}
// newTestHeadlessClient returns a HydraClient with the given ID whose metadata
// enables headless login, sets jwksURI as the headless JWKS URI, and selects
// "private_key_jwt" as the headless token endpoint auth method.
func newTestHeadlessClient(clientID, jwksURI string) domain.HydraClient {
	metadata := map[string]any{
		domain.MetadataHeadlessJWKSURI:                 jwksURI,
		domain.MetadataHeadlessLoginEnabled:            true,
		domain.MetadataHeadlessTokenEndpointAuthMethod: "private_key_jwt",
	}
	return domain.HydraClient{Metadata: metadata, ClientID: clientID}
}
// jsonHandler returns an http.HandlerFunc that asserts the request path is
// "/clients" and replies with payload encoded as JSON.
//
// Failures are reported with assert (t.Errorf) rather than require
// (t.FailNow). FailNow is only valid on the goroutine running the test, and
// whether this handler runs there depends on how clientForHandler dispatches
// requests.
func jsonHandler(t *testing.T, payload any) http.HandlerFunc {
	t.Helper()
	return func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "/clients", r.URL.Path)
		w.Header().Set("Content-Type", "application/json")
		assert.NoError(t, json.NewEncoder(w).Encode(payload))
	}
}
// jsonHTTPResponse builds a canned *http.Response for a fake RoundTripper. It
// has the given status code, body as a readable and closable Body, and the
// headers a real JSON response would carry.
//
// Content-Type (matching what jsonHandler sends), ContentLength and the HTTP
// version fields are filled in. Code under test that inspects them then sees
// a realistic response rather than zero values.
func jsonHTTPResponse(status int, body string) *http.Response {
	header := make(http.Header)
	header.Set("Content-Type", "application/json")
	return &http.Response{
		StatusCode:    status,
		Proto:         "HTTP/1.1",
		ProtoMajor:    1,
		ProtoMinor:    1,
		Header:        header,
		Body:          io.NopCloser(strings.NewReader(body)),
		ContentLength: int64(len(body)),
	}
}