1
0
forked from baron/baron-sso

Merge remote-tracking branch 'origin/dev' into dev

This commit is contained in:
2026-05-20 18:16:03 +09:00
79 changed files with 6977 additions and 1099 deletions

View File

@@ -17,6 +17,7 @@ type HeadlessJWKSCacheState struct {
CachedAt *time.Time `json:"cachedAt,omitempty"`
ExpiresAt *time.Time `json:"expiresAt,omitempty"`
LastCheckedAt *time.Time `json:"lastCheckedAt,omitempty"`
NextRetryAt *time.Time `json:"nextRetryAt,omitempty"`
LastSuccessfulVerificationAt *time.Time `json:"lastSuccessfulVerificationAt,omitempty"`
LastRefreshStatus string `json:"lastRefreshStatus,omitempty"`
LastError string `json:"lastError,omitempty"`

View File

@@ -166,9 +166,11 @@ type passwordLoginUserRepo struct {
func (r *passwordLoginUserRepo) Create(ctx context.Context, user *domain.User) error { return nil }
func (r *passwordLoginUserRepo) Update(ctx context.Context, user *domain.User) error { return nil }
func (r *passwordLoginUserRepo) FindByEmail(ctx context.Context, email string) (*domain.User, error) {
return nil, errors.New("not found")
}
func (r *passwordLoginUserRepo) FindByID(ctx context.Context, id string) (*domain.User, error) {
if r != nil {
if user, ok := r.usersByID[id]; ok {
@@ -177,40 +179,53 @@ func (r *passwordLoginUserRepo) FindByID(ctx context.Context, id string) (*domai
}
return nil, errors.New("not found")
}
func (r *passwordLoginUserRepo) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) {
return nil, nil
}
func (r *passwordLoginUserRepo) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
return nil, nil
}
func (r *passwordLoginUserRepo) List(ctx context.Context, offset, limit int, search string, tenantSlug string) ([]domain.User, int64, error) {
return nil, 0, nil
}
func (r *passwordLoginUserRepo) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
return 0, nil
}
func (r *passwordLoginUserRepo) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
return nil, nil
}
func (r *passwordLoginUserRepo) CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error) {
return nil, nil
}
func (r *passwordLoginUserRepo) FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error) {
return nil, nil
}
func (r *passwordLoginUserRepo) FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error) {
return nil, nil
}
func (r *passwordLoginUserRepo) Delete(ctx context.Context, id string) error { return nil }
func (r *passwordLoginUserRepo) UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error {
return nil
}
func (r *passwordLoginUserRepo) GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error) {
return nil, nil
}
func (r *passwordLoginUserRepo) IsLoginIDTaken(ctx context.Context, loginID string) (bool, error) {
return false, nil
}
func (r *passwordLoginUserRepo) FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error) {
return "", nil
}

View File

@@ -20,11 +20,13 @@ const (
)
type HeadlessJWKSCacheService struct {
Redis domain.RedisRepository
HTTPClient *http.Client
TTL time.Duration
PrefetchWindow time.Duration
RequestTimeout time.Duration
Redis domain.RedisRepository
HTTPClient *http.Client
TTL time.Duration
PrefetchWindow time.Duration
RequestTimeout time.Duration
FailureThreshold int
FailureBackoff time.Duration
}
type headlessJWKSCacheStateStore struct {
@@ -33,6 +35,7 @@ type headlessJWKSCacheStateStore struct {
CachedAt *time.Time `json:"cachedAt,omitempty"`
ExpiresAt *time.Time `json:"expiresAt,omitempty"`
LastCheckedAt *time.Time `json:"lastCheckedAt,omitempty"`
NextRetryAt *time.Time `json:"nextRetryAt,omitempty"`
LastSuccessfulVerificationAt *time.Time `json:"lastSuccessfulVerificationAt,omitempty"`
LastRefreshStatus string `json:"lastRefreshStatus,omitempty"`
LastError string `json:"lastError,omitempty"`
@@ -61,17 +64,29 @@ func NewHeadlessJWKSCacheService(redis domain.RedisRepository, httpClient *http.
prefetchSeconds = 600
}
timeoutSeconds, _ := strconv.Atoi(strings.TrimSpace(getenv("HEADLESS_JWKS_FETCH_TIMEOUT_SECONDS", "5")))
timeoutSeconds, _ := strconv.Atoi(strings.TrimSpace(getenv("HEADLESS_JWKS_FETCH_TIMEOUT_SECONDS", "2")))
if timeoutSeconds <= 0 {
timeoutSeconds = 5
timeoutSeconds = 2
}
failureThreshold, _ := strconv.Atoi(strings.TrimSpace(getenv("HEADLESS_JWKS_FAILURE_THRESHOLD", "3")))
if failureThreshold <= 0 {
failureThreshold = 3
}
backoffSeconds, _ := strconv.Atoi(strings.TrimSpace(getenv("HEADLESS_JWKS_FAILURE_BACKOFF_SECONDS", "1800")))
if backoffSeconds <= 0 {
backoffSeconds = 1800
}
return &HeadlessJWKSCacheService{
Redis: redis,
HTTPClient: httpClient,
TTL: time.Duration(ttlSeconds) * time.Second,
PrefetchWindow: time.Duration(prefetchSeconds) * time.Second,
RequestTimeout: time.Duration(timeoutSeconds) * time.Second,
Redis: redis,
HTTPClient: httpClient,
TTL: time.Duration(ttlSeconds) * time.Second,
PrefetchWindow: time.Duration(prefetchSeconds) * time.Second,
RequestTimeout: time.Duration(timeoutSeconds) * time.Second,
FailureThreshold: failureThreshold,
FailureBackoff: time.Duration(backoffSeconds) * time.Second,
}
}
@@ -115,6 +130,7 @@ func (s *HeadlessJWKSCacheService) SaveState(clientID string, state domain.Headl
CachedAt: state.CachedAt,
ExpiresAt: state.ExpiresAt,
LastCheckedAt: state.LastCheckedAt,
NextRetryAt: state.NextRetryAt,
LastSuccessfulVerificationAt: state.LastSuccessfulVerificationAt,
LastRefreshStatus: state.LastRefreshStatus,
LastError: state.LastError,
@@ -151,6 +167,7 @@ func (s *HeadlessJWKSCacheService) GetState(clientID string) (*domain.HeadlessJW
CachedAt: stored.CachedAt,
ExpiresAt: stored.ExpiresAt,
LastCheckedAt: stored.LastCheckedAt,
NextRetryAt: stored.NextRetryAt,
LastSuccessfulVerificationAt: stored.LastSuccessfulVerificationAt,
LastRefreshStatus: stored.LastRefreshStatus,
LastError: stored.LastError,
@@ -193,6 +210,9 @@ func (s *HeadlessJWKSCacheService) ShouldPrefetch(state *domain.HeadlessJWKSCach
if state == nil {
return true
}
if s.ShouldSkipRefresh(state, now) {
return false
}
if strings.TrimSpace(state.RawJWKS) == "" {
return true
}
@@ -202,6 +222,13 @@ func (s *HeadlessJWKSCacheService) ShouldPrefetch(state *domain.HeadlessJWKSCach
return !state.ExpiresAt.After(now.Add(s.PrefetchWindow))
}
func (s *HeadlessJWKSCacheService) ShouldSkipRefresh(state *domain.HeadlessJWKSCacheState, now time.Time) bool {
if state == nil || state.NextRetryAt == nil {
return false
}
return state.NextRetryAt.After(now)
}
func (s *HeadlessJWKSCacheService) EnsureFreshKeySet(ctx context.Context, client domain.HydraClient, expectedKid string) (*jose.JSONWebKeySet, *domain.HeadlessJWKSCacheState, bool, error) {
if s == nil {
return nil, nil, false, fmt.Errorf("headless jwks cache service is not configured")
@@ -283,6 +310,7 @@ func (s *HeadlessJWKSCacheService) refreshClient(ctx context.Context, client dom
updated.JWKSURI = jwksURI
updated.LastCheckedAt = &now
updated.ExpiresAt = ptrTime(now.Add(s.TTL))
updated.NextRetryAt = nil
updated.LastRefreshStatus = "success"
updated.LastError = ""
updated.ConsecutiveFailures = 0
@@ -313,6 +341,7 @@ func (s *HeadlessJWKSCacheService) refreshClient(ctx context.Context, client dom
CachedAt: &now,
ExpiresAt: ptrTime(now.Add(s.TTL)),
LastCheckedAt: &now,
NextRetryAt: nil,
LastSuccessfulVerificationAt: previousLastVerification(previous),
LastRefreshStatus: "success",
LastError: "",
@@ -349,10 +378,28 @@ func (s *HeadlessJWKSCacheService) persistRefreshFailure(client domain.HydraClie
state.RawJWKS = previous.RawJWKS
state.ConsecutiveFailures = previous.ConsecutiveFailures + 1
}
if s.shouldBackoff(state.ConsecutiveFailures) {
state.NextRetryAt = ptrTime(now.Add(s.failureBackoffDuration()))
}
_ = s.SaveState(client.ClientID, state)
return &state
}
func (s *HeadlessJWKSCacheService) shouldBackoff(consecutiveFailures int) bool {
threshold := s.FailureThreshold
if threshold <= 0 {
threshold = 3
}
return consecutiveFailures >= threshold
}
func (s *HeadlessJWKSCacheService) failureBackoffDuration() time.Duration {
if s.FailureBackoff > 0 {
return s.FailureBackoff
}
return 30 * time.Minute
}
func decodeHeadlessJWKS(raw string) (*jose.JSONWebKeySet, error) {
var keySet jose.JSONWebKeySet
if err := json.Unmarshal([]byte(raw), &keySet); err != nil {

View File

@@ -6,7 +6,9 @@ import (
"crypto/rand"
"crypto/rsa"
"encoding/json"
"io"
"net/http"
"strings"
"testing"
"time"
@@ -143,6 +145,261 @@ func TestHeadlessJWKSCacheService_EnsureFreshKeySet_RefreshesWhenKidMissing(t *t
assert.Equal(t, []string{"fresh-key"}, stored.CachedKids)
}
func TestHeadlessJWKSCacheService_PersistRefreshFailure_SetsNextRetryAtAfterThreshold(t *testing.T) {
redisRepo := &headlessJWKSCacheTestRedis{}
cacheService := NewHeadlessJWKSCacheService(redisRepo, nil)
cacheService.FailureThreshold = 3
cacheService.FailureBackoff = 15 * time.Minute
client := domain.HydraClient{
ClientID: "client-headless",
Metadata: map[string]any{
domain.MetadataHeadlessLoginEnabled: true,
domain.MetadataHeadlessJWKSURI: "https://rp.example.com/.well-known/jwks.json",
},
}
previous := &domain.HeadlessJWKSCacheState{
ClientID: client.ClientID,
JWKSURI: "https://rp.example.com/.well-known/jwks.json",
LastRefreshStatus: "failure",
ConsecutiveFailures: 2,
}
state := cacheService.persistRefreshFailure(client, previous, assert.AnError)
require.NotNil(t, state)
assert.Equal(t, 3, state.ConsecutiveFailures)
require.NotNil(t, state.NextRetryAt)
assert.WithinDuration(t, time.Now().Add(15*time.Minute), *state.NextRetryAt, 3*time.Second)
}
func TestHeadlessJWKSCacheService_ShouldPrefetch_SkipsUntilNextRetryAt(t *testing.T) {
cacheService := NewHeadlessJWKSCacheService(&headlessJWKSCacheTestRedis{}, nil)
now := time.Now()
state := &domain.HeadlessJWKSCacheState{
ClientID: "client-headless",
LastRefreshStatus: "failure",
ConsecutiveFailures: 3,
NextRetryAt: ptrTestTime(now.Add(10 * time.Minute)),
}
assert.False(t, cacheService.ShouldPrefetch(state, now))
assert.True(t, cacheService.ShouldPrefetch(state, now.Add(11*time.Minute)))
}
func TestHeadlessJWKSCacheWorker_RunOnce_SkipsBackoffTargets(t *testing.T) {
clients := []domain.HydraClient{
newTestHeadlessClient("client-fail", "https://fail.example.com/.well-known/jwks.json"),
newTestHeadlessClient("client-skip", "https://skip.example.com/.well-known/jwks.json"),
}
hydra := &HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: clientForHandler(jsonHandler(t, clients)),
}
redisRepo := &headlessJWKSCacheTestRedis{}
cacheService := NewHeadlessJWKSCacheService(redisRepo, nil)
cacheService.FailureThreshold = 3
cacheService.FailureBackoff = 15 * time.Minute
now := time.Now()
require.NoError(t, cacheService.SaveState("client-fail", domain.HeadlessJWKSCacheState{
ClientID: "client-fail",
JWKSURI: clients[0].HeadlessJWKSURI(),
LastRefreshStatus: "failure",
ConsecutiveFailures: 2,
}))
require.NoError(t, cacheService.SaveState("client-skip", domain.HeadlessJWKSCacheState{
ClientID: "client-skip",
JWKSURI: clients[1].HeadlessJWKSURI(),
LastRefreshStatus: "failure",
ConsecutiveFailures: 3,
NextRetryAt: ptrTestTime(now.Add(10 * time.Minute)),
}))
fetchCounts := map[string]int{}
cacheService.HTTPClient = &http.Client{
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
fetchCounts[req.URL.Host]++
if req.URL.Host == "fail.example.com" {
return jsonHTTPResponse(http.StatusInternalServerError, `{"error":"boom"}`), nil
}
t.Fatalf("unexpected fetch for host %s", req.URL.Host)
return nil, nil
}),
}
worker := &HeadlessJWKSCacheWorker{
Hydra: hydra,
Cache: cacheService,
PageSize: 100,
}
worker.runOnce(context.Background())
assert.Equal(t, 1, fetchCounts["fail.example.com"])
assert.Equal(t, 0, fetchCounts["skip.example.com"])
failedState, err := cacheService.GetState("client-fail")
require.NoError(t, err)
require.NotNil(t, failedState)
assert.Equal(t, 3, failedState.ConsecutiveFailures)
require.NotNil(t, failedState.NextRetryAt)
skippedState, err := cacheService.GetState("client-skip")
require.NoError(t, err)
require.NotNil(t, skippedState)
assert.Equal(t, 3, skippedState.ConsecutiveFailures)
require.NotNil(t, skippedState.NextRetryAt)
assert.WithinDuration(t, now.Add(10*time.Minute), *skippedState.NextRetryAt, time.Second)
}
func TestHeadlessJWKSCacheWorker_RunOnce_RetriesAfterBackoffAndClearsFailureStateOnSuccess(t *testing.T) {
_, freshJWKS := mustServiceHeadlessRSAJWK(t, "fresh-key")
freshRaw, err := json.Marshal(freshJWKS)
require.NoError(t, err)
client := newTestHeadlessClient("client-recover", "https://recover.example.com/.well-known/jwks.json")
hydra := &HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: clientForHandler(jsonHandler(t, []domain.HydraClient{client})),
}
redisRepo := &headlessJWKSCacheTestRedis{}
cacheService := NewHeadlessJWKSCacheService(redisRepo, nil)
cacheService.FailureThreshold = 3
cacheService.FailureBackoff = 15 * time.Minute
require.NoError(t, cacheService.SaveState("client-recover", domain.HeadlessJWKSCacheState{
ClientID: "client-recover",
JWKSURI: client.HeadlessJWKSURI(),
LastRefreshStatus: "failure",
LastError: "previous failure",
ConsecutiveFailures: 3,
NextRetryAt: ptrTestTime(time.Now().Add(-time.Minute)),
}))
fetchCount := 0
cacheService.HTTPClient = &http.Client{
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
fetchCount++
assert.Equal(t, "recover.example.com", req.URL.Host)
return jsonHTTPResponse(http.StatusOK, string(freshRaw)), nil
}),
}
worker := &HeadlessJWKSCacheWorker{
Hydra: hydra,
Cache: cacheService,
PageSize: 100,
}
worker.runOnce(context.Background())
assert.Equal(t, 1, fetchCount)
recoveredState, err := cacheService.GetState("client-recover")
require.NoError(t, err)
require.NotNil(t, recoveredState)
assert.Equal(t, "success", recoveredState.LastRefreshStatus)
assert.Empty(t, recoveredState.LastError)
assert.Equal(t, 0, recoveredState.ConsecutiveFailures)
assert.Nil(t, recoveredState.NextRetryAt)
assert.Equal(t, []string{"fresh-key"}, recoveredState.CachedKids)
}
func TestHeadlessJWKSCacheWorker_RunOnce_MixedClients(t *testing.T) {
_, successJWKS := mustServiceHeadlessRSAJWK(t, "success-key")
successRaw, err := json.Marshal(successJWKS)
require.NoError(t, err)
successClient := newTestHeadlessClient("client-success", "https://success.example.com/.well-known/jwks.json")
failClient := newTestHeadlessClient("client-fail", "https://fail.example.com/.well-known/jwks.json")
skipClient := newTestHeadlessClient("client-skip", "https://skip.example.com/.well-known/jwks.json")
disabledClient := domain.HydraClient{
ClientID: "client-disabled",
Metadata: map[string]any{
domain.MetadataHeadlessLoginEnabled: false,
domain.MetadataHeadlessJWKSURI: "https://disabled.example.com/.well-known/jwks.json",
domain.MetadataHeadlessTokenEndpointAuthMethod: "private_key_jwt",
},
}
hydra := &HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: clientForHandler(jsonHandler(t, []domain.HydraClient{
successClient,
failClient,
skipClient,
disabledClient,
})),
}
redisRepo := &headlessJWKSCacheTestRedis{}
cacheService := NewHeadlessJWKSCacheService(redisRepo, nil)
cacheService.FailureThreshold = 3
cacheService.FailureBackoff = 20 * time.Minute
require.NoError(t, cacheService.SaveState("client-fail", domain.HeadlessJWKSCacheState{
ClientID: "client-fail",
JWKSURI: failClient.HeadlessJWKSURI(),
LastRefreshStatus: "failure",
ConsecutiveFailures: 2,
}))
require.NoError(t, cacheService.SaveState("client-skip", domain.HeadlessJWKSCacheState{
ClientID: "client-skip",
JWKSURI: skipClient.HeadlessJWKSURI(),
LastRefreshStatus: "failure",
ConsecutiveFailures: 3,
NextRetryAt: ptrTestTime(time.Now().Add(10 * time.Minute)),
}))
fetchCounts := map[string]int{}
cacheService.HTTPClient = &http.Client{
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
fetchCounts[req.URL.Host]++
switch req.URL.Host {
case "success.example.com":
return jsonHTTPResponse(http.StatusOK, string(successRaw)), nil
case "fail.example.com":
return jsonHTTPResponse(http.StatusInternalServerError, `{"error":"boom"}`), nil
default:
t.Fatalf("unexpected fetch for host %s", req.URL.Host)
return nil, nil
}
}),
}
worker := &HeadlessJWKSCacheWorker{
Hydra: hydra,
Cache: cacheService,
PageSize: 100,
}
worker.runOnce(context.Background())
assert.Equal(t, 1, fetchCounts["success.example.com"])
assert.Equal(t, 1, fetchCounts["fail.example.com"])
assert.Equal(t, 0, fetchCounts["skip.example.com"])
assert.Equal(t, 0, fetchCounts["disabled.example.com"])
successState, err := cacheService.GetState("client-success")
require.NoError(t, err)
require.NotNil(t, successState)
assert.Equal(t, "success", successState.LastRefreshStatus)
assert.Equal(t, 0, successState.ConsecutiveFailures)
assert.Nil(t, successState.NextRetryAt)
failState, err := cacheService.GetState("client-fail")
require.NoError(t, err)
require.NotNil(t, failState)
assert.Equal(t, "failure", failState.LastRefreshStatus)
assert.Equal(t, 3, failState.ConsecutiveFailures)
require.NotNil(t, failState.NextRetryAt)
}
func mustServiceHeadlessRSAJWK(t *testing.T, kid string) (*rsa.PrivateKey, jose.JSONWebKeySet) {
t.Helper()
@@ -162,3 +419,31 @@ func mustServiceHeadlessRSAJWK(t *testing.T, kid string) (*rsa.PrivateKey, jose.
func ptrTestTime(value time.Time) *time.Time {
return &value
}
func newTestHeadlessClient(clientID, jwksURI string) domain.HydraClient {
return domain.HydraClient{
ClientID: clientID,
Metadata: map[string]any{
domain.MetadataHeadlessLoginEnabled: true,
domain.MetadataHeadlessJWKSURI: jwksURI,
domain.MetadataHeadlessTokenEndpointAuthMethod: "private_key_jwt",
},
}
}
func jsonHandler(t *testing.T, payload any) http.HandlerFunc {
t.Helper()
return func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/clients", r.URL.Path)
w.Header().Set("Content-Type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(payload))
}
}
func jsonHTTPResponse(status int, body string) *http.Response {
return &http.Response{
StatusCode: status,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(body)),
}
}