첫 커밋: 로컬 프로젝트 업로드

This commit is contained in:
2026-06-10 15:51:34 +09:00
commit 6a8dbeb2e9
1211 changed files with 312864 additions and 0 deletions

View File

@@ -0,0 +1,192 @@
package service
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/hex"
"encoding/json"
"fmt"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"
"github.com/go-jose/go-jose/v4"
josejwt "github.com/go-jose/go-jose/v4/jwt"
)
const backchannelLogoutEventURI = "http://schemas.openid.net/event/backchannel-logout"
type BackchannelLogoutService struct {
issuer string
keyID string
signer jose.Signer
publicJWK jose.JSONWebKey
client *http.Client
HTTPClient *http.Client
}
func NewBackchannelLogoutService() (*BackchannelLogoutService, error) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, fmt.Errorf("failed to generate backchannel logout key: %w", err)
}
keyID := randomBackchannelKeyID()
if keyID == "" {
keyID = fmt.Sprintf("bcl-%d", time.Now().UnixNano())
}
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.RS256,
Key: jose.JSONWebKey{
Key: privateKey,
KeyID: keyID,
Algorithm: string(jose.RS256),
Use: "sig",
},
}, (&jose.SignerOptions{}).WithType("JWT"))
if err != nil {
return nil, fmt.Errorf("failed to initialize backchannel logout signer: %w", err)
}
return &BackchannelLogoutService{
issuer: resolveBackchannelLogoutIssuer(),
keyID: keyID,
signer: signer,
publicJWK: jose.JSONWebKey{
Key: &privateKey.PublicKey,
KeyID: keyID,
Algorithm: string(jose.RS256),
Use: "sig",
},
client: &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
DialContext: (&net.Dialer{
Timeout: 3 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
TLSHandshakeTimeout: 3 * time.Second,
},
},
}, nil
}
func randomBackchannelKeyID() string {
buf := make([]byte, 8)
if _, err := rand.Read(buf); err != nil {
return ""
}
return hex.EncodeToString(buf)
}
func resolveBackchannelLogoutIssuer() string {
if explicit := strings.TrimSpace(os.Getenv("BACKCHANNEL_LOGOUT_ISSUER")); explicit != "" {
return strings.TrimRight(explicit, "/")
}
if hydraPublic := strings.TrimSpace(os.Getenv("HYDRA_PUBLIC_URL")); hydraPublic != "" {
return strings.TrimRight(hydraPublic, "/")
}
if oathkeeperPublic := strings.TrimSpace(os.Getenv("OATHKEEPER_PUBLIC_URL")); oathkeeperPublic != "" {
return strings.TrimRight(oathkeeperPublic, "/") + "/oidc"
}
if userfrontURL := strings.TrimSpace(os.Getenv("USERFRONT_URL")); userfrontURL != "" {
return strings.TrimRight(userfrontURL, "/") + "/oidc"
}
return "http://localhost:5000/oidc"
}
func (s *BackchannelLogoutService) Issuer() string {
if s == nil {
return ""
}
return s.issuer
}
func (s *BackchannelLogoutService) PublicJWKS() map[string]any {
if s == nil {
return map[string]any{"keys": []any{}}
}
return map[string]any{
"keys": []jose.JSONWebKey{s.publicJWK.Public()},
}
}
func (s *BackchannelLogoutService) BuildLogoutToken(clientID, subject, sessionID string) (string, error) {
if s == nil || s.signer == nil {
return "", fmt.Errorf("backchannel logout service is unavailable")
}
clientID = strings.TrimSpace(clientID)
subject = strings.TrimSpace(subject)
sessionID = strings.TrimSpace(sessionID)
if clientID == "" {
return "", fmt.Errorf("client id is required")
}
if subject == "" && sessionID == "" {
return "", fmt.Errorf("subject or session id is required")
}
now := time.Now().UTC()
claims := josejwt.Claims{
Issuer: s.issuer,
Audience: josejwt.Audience{clientID},
IssuedAt: josejwt.NewNumericDate(now),
ID: fmt.Sprintf("%s-%d", s.keyID, now.UnixNano()),
}
if subject != "" {
claims.Subject = subject
}
extra := map[string]any{
"events": map[string]any{
backchannelLogoutEventURI: map[string]any{},
},
}
if sessionID != "" {
extra["sid"] = sessionID
}
return josejwt.Signed(s.signer).Claims(claims).Claims(extra).Serialize()
}
func (s *BackchannelLogoutService) SendLogoutToken(ctx context.Context, endpoint, logoutToken string) (int, error) {
if s == nil {
return 0, fmt.Errorf("backchannel logout service is unavailable")
}
form := url.Values{}
form.Set("logout_token", logoutToken)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
if err != nil {
return 0, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
client := s.client
if s.HTTPClient != nil {
client = s.HTTPClient
}
resp, err := client.Do(req)
if err != nil {
return 0, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return resp.StatusCode, fmt.Errorf("backchannel logout endpoint returned status %d", resp.StatusCode)
}
return resp.StatusCode, nil
}
func (s *BackchannelLogoutService) MarshalPublicJWKS() ([]byte, error) {
return json.Marshal(s.PublicJWKS())
}

View File

@@ -0,0 +1,85 @@
package service
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/go-jose/go-jose/v4"
josejwt "github.com/go-jose/go-jose/v4/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestBackchannelLogoutService_BuildLogoutToken(t *testing.T) {
t.Setenv("BACKCHANNEL_LOGOUT_ISSUER", "https://sso.example.com/oidc")
svc, err := NewBackchannelLogoutService()
require.NoError(t, err)
token, err := svc.BuildLogoutToken("client-1", "user-1", "sid-1")
require.NoError(t, err)
require.NotEmpty(t, token)
jwksRaw, err := svc.MarshalPublicJWKS()
require.NoError(t, err)
var jwks struct {
Keys []jose.JSONWebKey `json:"keys"`
}
require.NoError(t, json.Unmarshal(jwksRaw, &jwks))
require.Len(t, jwks.Keys, 1)
parsed, err := josejwt.ParseSigned(token, []jose.SignatureAlgorithm{jose.RS256})
require.NoError(t, err)
var claims struct {
Issuer string `json:"iss"`
Subject string `json:"sub"`
Aud any `json:"aud"`
Iat int64 `json:"iat"`
Jti string `json:"jti"`
Sid string `json:"sid"`
Events map[string]any `json:"events"`
}
require.NoError(t, parsed.Claims(jwks.Keys[0].Key, &claims))
assert.Equal(t, "https://sso.example.com/oidc", claims.Issuer)
assert.Equal(t, "user-1", claims.Subject)
switch aud := claims.Aud.(type) {
case string:
assert.Equal(t, "client-1", aud)
case []any:
assert.Len(t, aud, 1)
assert.Equal(t, "client-1", aud[0])
default:
t.Fatalf("unexpected aud type: %T", claims.Aud)
}
assert.NotZero(t, claims.Iat)
assert.NotEmpty(t, claims.Jti)
assert.Equal(t, "sid-1", claims.Sid)
_, ok := claims.Events[backchannelLogoutEventURI]
assert.True(t, ok)
}
func TestBackchannelLogoutService_SendLogoutToken(t *testing.T) {
var body string
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodPost, r.Method)
assert.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type"))
raw, _ := io.ReadAll(r.Body)
body = string(raw)
w.WriteHeader(http.StatusNoContent)
})
svc, err := NewBackchannelLogoutService()
require.NoError(t, err)
svc.HTTPClient = clientForHandler(handler)
statusCode, err := svc.SendLogoutToken(context.Background(), "https://rp.example.com/backchannel-logout", "signed-token")
require.NoError(t, err)
assert.Equal(t, http.StatusNoContent, statusCode)
assert.Equal(t, "logout_token=signed-token", body)
}

View File

@@ -0,0 +1,86 @@
package service
import (
"baron-sso-backend/internal/domain"
"context"
"errors"
"gorm.io/gorm"
)
type DeveloperService struct {
db *gorm.DB
}
func NewDeveloperService(db *gorm.DB) *DeveloperService {
return &DeveloperService{db: db}
}
func (s *DeveloperService) RequestAccess(ctx context.Context, req domain.DeveloperRequest) error {
// Check if there is already a pending request
var existing domain.DeveloperRequest
err := s.db.WithContext(ctx).Where("user_id = ? AND tenant_id = ? AND status = ?", req.UserID, req.TenantID, domain.DeveloperRequestStatusPending).First(&existing).Error
if err == nil {
return nil
}
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
return s.db.WithContext(ctx).Create(&req).Error
}
func (s *DeveloperService) GetRequestStatus(ctx context.Context, userID, tenantID string) (*domain.DeveloperRequest, error) {
var req domain.DeveloperRequest
err := s.db.WithContext(ctx).Where("user_id = ? AND tenant_id = ?", userID, tenantID).Order("created_at DESC").First(&req).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &req, nil
}
func (s *DeveloperService) GetRequestByID(ctx context.Context, id uint) (*domain.DeveloperRequest, error) {
var req domain.DeveloperRequest
err := s.db.WithContext(ctx).First(&req, id).Error
if err != nil {
return nil, err
}
return &req, nil
}
func (s *DeveloperService) ListRequests(ctx context.Context, userID, status string) ([]domain.DeveloperRequest, error) {
var requests []domain.DeveloperRequest
query := s.db.WithContext(ctx)
if userID != "" {
query = query.Where("user_id = ?", userID)
}
if status != "" {
query = query.Where("status = ?", status)
}
err := query.Order("created_at DESC").Find(&requests).Error
return requests, err
}
func (s *DeveloperService) ApproveRequest(ctx context.Context, id uint, adminNotes string) error {
return s.db.WithContext(ctx).Model(&domain.DeveloperRequest{}).Where("id = ?", id).Updates(map[string]any{
"status": domain.DeveloperRequestStatusApproved,
"admin_notes": adminNotes,
}).Error
}
func (s *DeveloperService) RejectRequest(ctx context.Context, id uint, adminNotes string) error {
return s.db.WithContext(ctx).Model(&domain.DeveloperRequest{}).Where("id = ?", id).Updates(map[string]any{
"status": domain.DeveloperRequestStatusRejected,
"admin_notes": adminNotes,
}).Error
}
func (s *DeveloperService) CancelApprovedRequest(ctx context.Context, id uint, adminNotes string) error {
return s.db.WithContext(ctx).Model(&domain.DeveloperRequest{}).Where("id = ?", id).Updates(map[string]any{
"status": domain.DeveloperRequestStatusCancelled,
"admin_notes": adminNotes,
}).Error
}

View File

@@ -0,0 +1,19 @@
package service
import (
"baron-sso-backend/internal/logger"
"os"
"strings"
)
func IsProductionEnv() bool {
env := strings.ToLower(os.Getenv("APP_ENV"))
if env == "" {
env = strings.ToLower(os.Getenv("GO_ENV"))
}
return logger.IsProductionLikeEnv(env)
}
func IsDryRunAllowed() bool {
return !IsProductionEnv()
}

View File

@@ -0,0 +1,43 @@
package service
import (
"os"
"testing"
)
func TestIsProductionEnv_StageIsProductionLike(t *testing.T) {
t.Setenv("APP_ENV", "stage")
t.Setenv("GO_ENV", "")
if !IsProductionEnv() {
t.Fatalf("expected stage to be treated as production-like")
}
}
func TestIsDryRunAllowed_DisabledInStage(t *testing.T) {
t.Setenv("APP_ENV", "stage")
t.Setenv("GO_ENV", "")
if IsDryRunAllowed() {
t.Fatalf("expected dry-run to be disabled in stage")
}
}
func TestIsProductionEnv_FallsBackToGoEnv(t *testing.T) {
originalAppEnv, hadAppEnv := os.LookupEnv("APP_ENV")
if hadAppEnv {
t.Cleanup(func() {
_ = os.Setenv("APP_ENV", originalAppEnv)
})
} else {
t.Cleanup(func() {
_ = os.Unsetenv("APP_ENV")
})
}
_ = os.Unsetenv("APP_ENV")
t.Setenv("GO_ENV", "production")
if !IsProductionEnv() {
t.Fatalf("expected GO_ENV=production fallback to be production-like")
}
}

View File

@@ -0,0 +1,90 @@
package service
import (
"baron-sso-backend/internal/repository"
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/oauth2"
)
type FederationService struct {
repo repository.FederationRepository
hydraSvc *HydraAdminService
redisSvc *RedisService
}
func NewFederationService(repo repository.FederationRepository, hydraSvc *HydraAdminService, redisSvc *RedisService) *FederationService {
return &FederationService{repo: repo, hydraSvc: hydraSvc, redisSvc: redisSvc}
}
func (s *FederationService) InitiateOIDCLogin(ctx context.Context, providerID, loginChallenge string) (string, error) {
provider, err := s.repo.FindProviderByID(ctx, providerID)
if err != nil {
return "", fmt.Errorf("failed to find provider: %w", err)
}
if provider == nil || provider.IssuerURL == nil || provider.OIDCClientID == nil || provider.OIDCClientSecret == nil || provider.Scopes == nil {
return "", fmt.Errorf("OIDC configuration for provider %s is incomplete", providerID)
}
oidcProvider, err := oidc.NewProvider(ctx, *provider.IssuerURL)
if err != nil {
return "", fmt.Errorf("failed to create OIDC provider: %w", err)
}
config := oauth2.Config{
ClientID: *provider.OIDCClientID,
ClientSecret: *provider.OIDCClientSecret,
Endpoint: oidcProvider.Endpoint(),
RedirectURL: "http://localhost:8080/api/v1/federation/oidc/callback", // This should be configurable
Scopes: []string{*provider.Scopes},
}
state, err := generateState()
if err != nil {
return "", fmt.Errorf("failed to generate state: %w", err)
}
// Store state and login_challenge in Redis
redisKey := fmt.Sprintf("oidc_state:%s", state)
if err := s.redisSvc.Set(redisKey, loginChallenge, 10*time.Minute); err != nil {
return "", fmt.Errorf("failed to save state to Redis: %w", err)
}
return config.AuthCodeURL(state), nil
}
func (s *FederationService) HandleOIDCCallback(ctx context.Context, code, state string) (string, error) {
// 1. Retrieve login_challenge from Redis
redisKey := fmt.Sprintf("oidc_state:%s", state)
loginChallenge, err := s.redisSvc.Get(redisKey)
if err != nil {
return "", fmt.Errorf("failed to get state from Redis or state expired: %w", err)
}
// Delete the state from Redis now that it's been used
s.redisSvc.Delete(redisKey)
// TODO: Finish the rest of the callback logic
// 2. Exchange code for token
// 3. Verify ID token
// 4. JIT Provisioning
// 5. Accept Hydra Login Request
fmt.Println("Login challenge found:", loginChallenge)
return "http://localhost:3000/login?login_successful=true", nil // Placeholder
}
func generateState() (string, error) {
b := make([]byte, 32)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}

View File

@@ -0,0 +1,545 @@
package service
import (
"baron-sso-backend/internal/domain"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"strconv"
"strings"
"time"
"github.com/go-jose/go-jose/v4"
)
const (
headlessJWKSCacheKeyPrefix = "headless_jwks_cache:"
)
type HeadlessJWKSCacheService struct {
Redis domain.RedisRepository
HTTPClient *http.Client
TTL time.Duration
PrefetchWindow time.Duration
RequestTimeout time.Duration
FailureThreshold int
FailureBackoff time.Duration
}
type headlessJWKSCacheStateStore struct {
ClientID string `json:"clientId"`
JWKSURI string `json:"jwksUri"`
CachedAt *time.Time `json:"cachedAt,omitempty"`
ExpiresAt *time.Time `json:"expiresAt,omitempty"`
LastCheckedAt *time.Time `json:"lastCheckedAt,omitempty"`
NextRetryAt *time.Time `json:"nextRetryAt,omitempty"`
LastSuccessfulVerificationAt *time.Time `json:"lastSuccessfulVerificationAt,omitempty"`
LastRefreshStatus string `json:"lastRefreshStatus,omitempty"`
LastError string `json:"lastError,omitempty"`
ConsecutiveFailures int `json:"consecutiveFailures,omitempty"`
CachedKids []string `json:"cachedKids,omitempty"`
ETag string `json:"etag,omitempty"`
LastModified string `json:"lastModified,omitempty"`
RawJWKS string `json:"rawJwks,omitempty"`
}
type HeadlessJWKSCacheWorker struct {
Hydra *HydraAdminService
Cache *HeadlessJWKSCacheService
Interval time.Duration
PageSize int
}
func NewHeadlessJWKSCacheService(redis domain.RedisRepository, httpClient *http.Client) *HeadlessJWKSCacheService {
ttlSeconds, _ := strconv.Atoi(strings.TrimSpace(getenv("HEADLESS_JWKS_CACHE_TTL_SECONDS", "1800")))
if ttlSeconds <= 0 {
ttlSeconds = 1800
}
prefetchSeconds, _ := strconv.Atoi(strings.TrimSpace(getenv("HEADLESS_JWKS_PREFETCH_WINDOW_SECONDS", "600")))
if prefetchSeconds <= 0 {
prefetchSeconds = 600
}
timeoutSeconds, _ := strconv.Atoi(strings.TrimSpace(getenv("HEADLESS_JWKS_FETCH_TIMEOUT_SECONDS", "2")))
if timeoutSeconds <= 0 {
timeoutSeconds = 2
}
failureThreshold, _ := strconv.Atoi(strings.TrimSpace(getenv("HEADLESS_JWKS_FAILURE_THRESHOLD", "3")))
if failureThreshold <= 0 {
failureThreshold = 3
}
backoffSeconds, _ := strconv.Atoi(strings.TrimSpace(getenv("HEADLESS_JWKS_FAILURE_BACKOFF_SECONDS", "1800")))
if backoffSeconds <= 0 {
backoffSeconds = 1800
}
return &HeadlessJWKSCacheService{
Redis: redis,
HTTPClient: httpClient,
TTL: time.Duration(ttlSeconds) * time.Second,
PrefetchWindow: time.Duration(prefetchSeconds) * time.Second,
RequestTimeout: time.Duration(timeoutSeconds) * time.Second,
FailureThreshold: failureThreshold,
FailureBackoff: time.Duration(backoffSeconds) * time.Second,
}
}
func NewHeadlessJWKSCacheWorker(hydra *HydraAdminService, cache *HeadlessJWKSCacheService) *HeadlessJWKSCacheWorker {
intervalSeconds, _ := strconv.Atoi(strings.TrimSpace(getenv("HEADLESS_JWKS_REFRESH_INTERVAL_SECONDS", "600")))
if intervalSeconds <= 0 {
intervalSeconds = 600
}
return &HeadlessJWKSCacheWorker{
Hydra: hydra,
Cache: cache,
Interval: time.Duration(intervalSeconds) * time.Second,
PageSize: 100,
}
}
func (s *HeadlessJWKSCacheService) httpClient() *http.Client {
if s.HTTPClient != nil {
return s.HTTPClient
}
timeout := s.RequestTimeout
if timeout <= 0 {
timeout = 5 * time.Second
}
return &http.Client{Timeout: timeout}
}
func (s *HeadlessJWKSCacheService) cacheKey(clientID string) string {
return headlessJWKSCacheKeyPrefix + strings.TrimSpace(clientID)
}
func (s *HeadlessJWKSCacheService) SaveState(clientID string, state domain.HeadlessJWKSCacheState) error {
if s == nil || s.Redis == nil || strings.TrimSpace(clientID) == "" {
return nil
}
payload, err := json.Marshal(headlessJWKSCacheStateStore{
ClientID: state.ClientID,
JWKSURI: state.JWKSURI,
CachedAt: state.CachedAt,
ExpiresAt: state.ExpiresAt,
LastCheckedAt: state.LastCheckedAt,
NextRetryAt: state.NextRetryAt,
LastSuccessfulVerificationAt: state.LastSuccessfulVerificationAt,
LastRefreshStatus: state.LastRefreshStatus,
LastError: state.LastError,
ConsecutiveFailures: state.ConsecutiveFailures,
CachedKids: state.CachedKids,
ETag: state.ETag,
LastModified: state.LastModified,
RawJWKS: state.RawJWKS,
})
if err != nil {
return err
}
return s.Redis.Set(s.cacheKey(clientID), string(payload), 0)
}
func (s *HeadlessJWKSCacheService) GetState(clientID string) (*domain.HeadlessJWKSCacheState, error) {
if s == nil || s.Redis == nil || strings.TrimSpace(clientID) == "" {
return nil, nil
}
raw, err := s.Redis.Get(s.cacheKey(clientID))
if err != nil || strings.TrimSpace(raw) == "" {
return nil, err
}
var stored headlessJWKSCacheStateStore
if err := json.Unmarshal([]byte(raw), &stored); err != nil {
return nil, err
}
return &domain.HeadlessJWKSCacheState{
ClientID: stored.ClientID,
JWKSURI: stored.JWKSURI,
CachedAt: stored.CachedAt,
ExpiresAt: stored.ExpiresAt,
LastCheckedAt: stored.LastCheckedAt,
NextRetryAt: stored.NextRetryAt,
LastSuccessfulVerificationAt: stored.LastSuccessfulVerificationAt,
LastRefreshStatus: stored.LastRefreshStatus,
LastError: stored.LastError,
ConsecutiveFailures: stored.ConsecutiveFailures,
CachedKids: stored.CachedKids,
ETag: stored.ETag,
LastModified: stored.LastModified,
RawJWKS: stored.RawJWKS,
}, nil
}
func (s *HeadlessJWKSCacheService) DeleteState(clientID string) error {
if s == nil || s.Redis == nil {
return nil
}
return s.Redis.Delete(s.cacheKey(clientID))
}
func (s *HeadlessJWKSCacheService) PublicState(clientID string) (*domain.HeadlessJWKSCacheState, error) {
state, err := s.GetState(clientID)
if err != nil || state == nil {
return state, err
}
state.ParsedKeys = summarizeHeadlessJWKS(state.RawJWKS)
state.RawJWKS = ""
return state, nil
}
func (s *HeadlessJWKSCacheService) MarkVerificationSuccess(clientID string) error {
state, err := s.GetState(clientID)
if err != nil || state == nil {
return err
}
now := time.Now()
state.LastSuccessfulVerificationAt = &now
return s.SaveState(clientID, *state)
}
func (s *HeadlessJWKSCacheService) ShouldPrefetch(state *domain.HeadlessJWKSCacheState, now time.Time) bool {
if state == nil {
return true
}
if s.ShouldSkipRefresh(state, now) {
return false
}
if strings.TrimSpace(state.RawJWKS) == "" {
return true
}
if state.ExpiresAt == nil {
return true
}
return !state.ExpiresAt.After(now.Add(s.PrefetchWindow))
}
func (s *HeadlessJWKSCacheService) ShouldSkipRefresh(state *domain.HeadlessJWKSCacheState, now time.Time) bool {
if state == nil || state.NextRetryAt == nil {
return false
}
return state.NextRetryAt.After(now)
}
func (s *HeadlessJWKSCacheService) EnsureFreshKeySet(ctx context.Context, client domain.HydraClient, expectedKid string) (*jose.JSONWebKeySet, *domain.HeadlessJWKSCacheState, bool, error) {
if s == nil {
return nil, nil, false, fmt.Errorf("headless jwks cache service is not configured")
}
jwksURI := strings.TrimSpace(client.HeadlessJWKSURI())
if jwksURI == "" {
return nil, nil, false, fmt.Errorf("headless login requires jwksUri; inline jwks is not supported")
}
state, err := s.GetState(client.ClientID)
if err != nil {
slog.Warn("failed to load headless jwks cache state", "clientID", client.ClientID, "error", err)
}
now := time.Now()
switch {
case state == nil:
return s.refreshClient(ctx, client, nil, "cache_miss")
case strings.TrimSpace(state.JWKSURI) != jwksURI:
return s.refreshClient(ctx, client, state, "config_changed")
case strings.TrimSpace(state.RawJWKS) == "":
return s.refreshClient(ctx, client, state, "cache_empty")
case state.ExpiresAt == nil || !state.ExpiresAt.After(now):
return s.refreshClient(ctx, client, state, "ttl_expired")
case expectedKid != "" && !containsString(state.CachedKids, expectedKid):
return s.refreshClient(ctx, client, state, "kid_missing")
default:
keySet, err := decodeHeadlessJWKS(state.RawJWKS)
if err != nil {
return s.refreshClient(ctx, client, state, "cache_corrupt")
}
return keySet, state, false, nil
}
}
func (s *HeadlessJWKSCacheService) ForceRefresh(ctx context.Context, client domain.HydraClient, reason string) (*domain.HeadlessJWKSCacheState, error) {
_, state, err := s.ForceRefreshKeySet(ctx, client, reason)
return state, err
}
func (s *HeadlessJWKSCacheService) ForceRefreshKeySet(ctx context.Context, client domain.HydraClient, reason string) (*jose.JSONWebKeySet, *domain.HeadlessJWKSCacheState, error) {
previous, err := s.GetState(client.ClientID)
if err != nil {
slog.Warn("failed to load headless jwks cache state before force refresh", "clientID", client.ClientID, "error", err)
}
keySet, state, _, err := s.refreshClient(ctx, client, previous, reason)
return keySet, state, err
}
func (s *HeadlessJWKSCacheService) refreshClient(ctx context.Context, client domain.HydraClient, previous *domain.HeadlessJWKSCacheState, reason string) (*jose.JSONWebKeySet, *domain.HeadlessJWKSCacheState, bool, error) {
jwksURI := strings.TrimSpace(client.HeadlessJWKSURI())
if jwksURI == "" {
return nil, nil, false, fmt.Errorf("headless login requires jwksUri; inline jwks is not supported")
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, jwksURI, nil)
if err != nil {
return nil, s.persistRefreshFailure(client, previous, fmt.Errorf("failed to build jwks request: %w", err)), false, err
}
if previous != nil {
if etag := strings.TrimSpace(previous.ETag); etag != "" {
req.Header.Set("If-None-Match", etag)
}
if lastModified := strings.TrimSpace(previous.LastModified); lastModified != "" {
req.Header.Set("If-Modified-Since", lastModified)
}
}
resp, err := s.httpClient().Do(req)
if err != nil {
return nil, s.persistRefreshFailure(client, previous, fmt.Errorf("failed to fetch jwksUri: %w", err)), false, err
}
defer resp.Body.Close()
now := time.Now()
if resp.StatusCode == http.StatusNotModified && previous != nil && strings.TrimSpace(previous.RawJWKS) != "" {
updated := *previous
updated.JWKSURI = jwksURI
updated.LastCheckedAt = &now
updated.ExpiresAt = new(now.Add(s.TTL))
updated.NextRetryAt = nil
updated.LastRefreshStatus = "success"
updated.LastError = ""
updated.ConsecutiveFailures = 0
_ = s.SaveState(client.ClientID, updated)
keySet, decodeErr := decodeHeadlessJWKS(updated.RawJWKS)
return keySet, &updated, true, decodeErr
}
if resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
err = fmt.Errorf("failed to fetch jwksUri status=%d body=%s", resp.StatusCode, string(body))
return nil, s.persistRefreshFailure(client, previous, err), false, err
}
body, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024))
if err != nil {
return nil, s.persistRefreshFailure(client, previous, fmt.Errorf("failed to read jwks response: %w", err)), false, err
}
keySet, err := decodeHeadlessJWKS(string(body))
if err != nil {
return nil, s.persistRefreshFailure(client, previous, err), false, err
}
state := domain.HeadlessJWKSCacheState{
ClientID: client.ClientID,
JWKSURI: jwksURI,
CachedAt: &now,
ExpiresAt: new(now.Add(s.TTL)),
LastCheckedAt: &now,
NextRetryAt: nil,
LastSuccessfulVerificationAt: previousLastVerification(previous),
LastRefreshStatus: "success",
LastError: "",
ConsecutiveFailures: 0,
CachedKids: extractHeadlessKids(keySet),
ETag: strings.TrimSpace(resp.Header.Get("ETag")),
LastModified: strings.TrimSpace(resp.Header.Get("Last-Modified")),
RawJWKS: string(body),
}
if err := s.SaveState(client.ClientID, state); err != nil {
return nil, &state, false, err
}
slog.Info("headless jwks cache refreshed", "clientID", client.ClientID, "reason", reason, "keyCount", len(keySet.Keys))
return keySet, &state, true, nil
}
func (s *HeadlessJWKSCacheService) persistRefreshFailure(client domain.HydraClient, previous *domain.HeadlessJWKSCacheState, refreshErr error) *domain.HeadlessJWKSCacheState {
now := time.Now()
state := domain.HeadlessJWKSCacheState{
ClientID: client.ClientID,
JWKSURI: strings.TrimSpace(client.HeadlessJWKSURI()),
LastCheckedAt: &now,
LastRefreshStatus: "failure",
LastError: refreshErr.Error(),
ConsecutiveFailures: 1,
}
if previous != nil {
state.CachedAt = previous.CachedAt
state.ExpiresAt = previous.ExpiresAt
state.LastSuccessfulVerificationAt = previous.LastSuccessfulVerificationAt
state.CachedKids = previous.CachedKids
state.ETag = previous.ETag
state.LastModified = previous.LastModified
state.RawJWKS = previous.RawJWKS
state.ConsecutiveFailures = previous.ConsecutiveFailures + 1
}
if s.shouldBackoff(state.ConsecutiveFailures) {
state.NextRetryAt = new(now.Add(s.failureBackoffDuration()))
}
_ = s.SaveState(client.ClientID, state)
return &state
}
func (s *HeadlessJWKSCacheService) shouldBackoff(consecutiveFailures int) bool {
threshold := s.FailureThreshold
if threshold <= 0 {
threshold = 3
}
return consecutiveFailures >= threshold
}
func (s *HeadlessJWKSCacheService) failureBackoffDuration() time.Duration {
if s.FailureBackoff > 0 {
return s.FailureBackoff
}
return 30 * time.Minute
}
func decodeHeadlessJWKS(raw string) (*jose.JSONWebKeySet, error) {
var keySet jose.JSONWebKeySet
if err := json.Unmarshal([]byte(raw), &keySet); err != nil {
return nil, fmt.Errorf("failed to decode jwks from jwksUri: %w", err)
}
if len(keySet.Keys) == 0 {
return nil, fmt.Errorf("configured jwksUri returned no keys")
}
return &keySet, nil
}
type headlessJWKSPreviewDocument struct {
Keys []headlessJWKSPreviewKey `json:"keys"`
}
type headlessJWKSPreviewKey struct {
Kid string `json:"kid"`
Kty string `json:"kty"`
Use string `json:"use"`
Alg string `json:"alg"`
N string `json:"n"`
}
func summarizeHeadlessJWKS(raw string) []domain.HeadlessJWKSParsedKey {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil
}
var document headlessJWKSPreviewDocument
if err := json.Unmarshal([]byte(raw), &document); err != nil {
return nil
}
parsedKeys := make([]domain.HeadlessJWKSParsedKey, 0, len(document.Keys))
for _, key := range document.Keys {
parsedKeys = append(parsedKeys, domain.HeadlessJWKSParsedKey{
Kid: strings.TrimSpace(key.Kid),
Kty: strings.TrimSpace(key.Kty),
Use: strings.TrimSpace(key.Use),
Alg: strings.TrimSpace(key.Alg),
N: strings.TrimSpace(key.N),
})
}
return parsedKeys
}
func extractHeadlessKids(keySet *jose.JSONWebKeySet) []string {
if keySet == nil {
return nil
}
kids := make([]string, 0, len(keySet.Keys))
for _, key := range keySet.Keys {
if kid := strings.TrimSpace(key.KeyID); kid != "" {
kids = append(kids, kid)
}
}
return kids
}
func containsString(values []string, needle string) bool {
needle = strings.TrimSpace(needle)
if needle == "" {
return false
}
for _, value := range values {
if strings.TrimSpace(value) == needle {
return true
}
}
return false
}
func previousLastVerification(previous *domain.HeadlessJWKSCacheState) *time.Time {
if previous == nil {
return nil
}
return previous.LastSuccessfulVerificationAt
}
//go:fix inline
func ptrTime(value time.Time) *time.Time {
return new(value)
}
func (w *HeadlessJWKSCacheWorker) Start(ctx context.Context) {
if w == nil || w.Hydra == nil || w.Cache == nil {
return
}
w.runOnce(ctx)
ticker := time.NewTicker(w.Interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
w.runOnce(ctx)
}
}
}
func (w *HeadlessJWKSCacheWorker) runOnce(ctx context.Context) {
offset := 0
pageSize := w.PageSize
if pageSize <= 0 {
pageSize = 100
}
now := time.Now()
for {
clients, err := w.Hydra.ListClients(ctx, pageSize, offset)
if err != nil {
slog.Warn("headless jwks worker failed to list clients", "error", err)
return
}
if len(clients) == 0 {
return
}
for _, client := range clients {
if !client.IsHeadlessLoginEnabled() {
continue
}
state, err := w.Cache.GetState(client.ClientID)
if err != nil {
slog.Warn("headless jwks worker failed to load cache state", "clientID", client.ClientID, "error", err)
continue
}
if !w.Cache.ShouldPrefetch(state, now) {
continue
}
if _, err := w.Cache.ForceRefresh(ctx, client, "cron_prefetch"); err != nil {
slog.Warn("headless jwks worker refresh failed", "clientID", client.ClientID, "error", err)
}
}
if len(clients) < pageSize {
return
}
offset += len(clients)
}
}

View File

@@ -0,0 +1,450 @@
package service
import (
"baron-sso-backend/internal/domain"
"context"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/go-jose/go-jose/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type headlessJWKSCacheTestRedis struct {
data map[string]string
}
func (m *headlessJWKSCacheTestRedis) Set(key string, value string, expiration time.Duration) error {
if m.data == nil {
m.data = map[string]string{}
}
m.data[key] = value
return nil
}
func (m *headlessJWKSCacheTestRedis) Get(key string) (string, error) {
if m.data == nil {
return "", nil
}
return m.data[key], nil
}
func (m *headlessJWKSCacheTestRedis) Delete(key string) error {
if m.data != nil {
delete(m.data, key)
}
return nil
}
func (m *headlessJWKSCacheTestRedis) StoreVerificationCode(phone, code string) error {
return nil
}
func (m *headlessJWKSCacheTestRedis) GetVerificationCode(phone string) (string, error) {
return "", nil
}
func (m *headlessJWKSCacheTestRedis) DeleteVerificationCode(phone string) error {
return nil
}
func TestHeadlessJWKSCacheService_EnsureFreshKeySet_UsesCachedJWKSWhenFresh(t *testing.T) {
_, jwks := mustServiceHeadlessRSAJWK(t, "cached-key")
raw, err := json.Marshal(jwks)
require.NoError(t, err)
redisRepo := &headlessJWKSCacheTestRedis{}
cacheService := NewHeadlessJWKSCacheService(redisRepo, nil)
now := time.Now()
err = cacheService.SaveState("client-headless", domain.HeadlessJWKSCacheState{
ClientID: "client-headless",
JWKSURI: "https://rp.example.com/.well-known/jwks.json",
RawJWKS: string(raw),
CachedKids: []string{"cached-key"},
CachedAt: &now,
LastCheckedAt: &now,
ExpiresAt: new(now.Add(30 * time.Minute)),
LastRefreshStatus: "success",
ConsecutiveFailures: 0,
})
require.NoError(t, err)
cacheService.HTTPClient = clientForHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("unexpected network fetch: %s", r.URL.String())
}))
keySet, state, refreshed, err := cacheService.EnsureFreshKeySet(context.Background(), domain.HydraClient{
ClientID: "client-headless",
Metadata: map[string]any{
domain.MetadataHeadlessLoginEnabled: true,
domain.MetadataHeadlessJWKSURI: "https://rp.example.com/.well-known/jwks.json",
},
}, "cached-key")
require.NoError(t, err)
assert.False(t, refreshed)
require.NotNil(t, keySet)
assert.Len(t, keySet.Keys, 1)
require.NotNil(t, state)
assert.Equal(t, []string{"cached-key"}, state.CachedKids)
}
func TestHeadlessJWKSCacheService_EnsureFreshKeySet_RefreshesWhenKidMissing(t *testing.T) {
_, staleJWKS := mustServiceHeadlessRSAJWK(t, "stale-key")
staleRaw, err := json.Marshal(staleJWKS)
require.NoError(t, err)
_, freshJWKS := mustServiceHeadlessRSAJWK(t, "fresh-key")
freshRaw, err := json.Marshal(freshJWKS)
require.NoError(t, err)
redisRepo := &headlessJWKSCacheTestRedis{}
cacheService := NewHeadlessJWKSCacheService(redisRepo, nil)
now := time.Now()
err = cacheService.SaveState("client-headless", domain.HeadlessJWKSCacheState{
ClientID: "client-headless",
JWKSURI: "https://rp.example.com/.well-known/jwks.json",
RawJWKS: string(staleRaw),
CachedKids: []string{"stale-key"},
CachedAt: &now,
LastCheckedAt: &now,
ExpiresAt: new(now.Add(30 * time.Minute)),
LastRefreshStatus: "success",
})
require.NoError(t, err)
cacheService.HTTPClient = clientForHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "https://rp.example.com/.well-known/jwks.json", r.URL.String())
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(freshRaw)
}))
keySet, state, refreshed, err := cacheService.EnsureFreshKeySet(context.Background(), domain.HydraClient{
ClientID: "client-headless",
Metadata: map[string]any{
domain.MetadataHeadlessLoginEnabled: true,
domain.MetadataHeadlessJWKSURI: "https://rp.example.com/.well-known/jwks.json",
},
}, "fresh-key")
require.NoError(t, err)
assert.True(t, refreshed)
require.NotNil(t, keySet)
assert.Len(t, keySet.Keys, 1)
require.NotNil(t, state)
assert.Equal(t, []string{"fresh-key"}, state.CachedKids)
stored, err := cacheService.GetState("client-headless")
require.NoError(t, err)
require.NotNil(t, stored)
assert.Equal(t, []string{"fresh-key"}, stored.CachedKids)
}
func TestHeadlessJWKSCacheService_PersistRefreshFailure_SetsNextRetryAtAfterThreshold(t *testing.T) {
redisRepo := &headlessJWKSCacheTestRedis{}
cacheService := NewHeadlessJWKSCacheService(redisRepo, nil)
cacheService.FailureThreshold = 3
cacheService.FailureBackoff = 15 * time.Minute
client := domain.HydraClient{
ClientID: "client-headless",
Metadata: map[string]any{
domain.MetadataHeadlessLoginEnabled: true,
domain.MetadataHeadlessJWKSURI: "https://rp.example.com/.well-known/jwks.json",
},
}
previous := &domain.HeadlessJWKSCacheState{
ClientID: client.ClientID,
JWKSURI: "https://rp.example.com/.well-known/jwks.json",
LastRefreshStatus: "failure",
ConsecutiveFailures: 2,
}
state := cacheService.persistRefreshFailure(client, previous, assert.AnError)
require.NotNil(t, state)
assert.Equal(t, 3, state.ConsecutiveFailures)
require.NotNil(t, state.NextRetryAt)
assert.WithinDuration(t, time.Now().Add(15*time.Minute), *state.NextRetryAt, 3*time.Second)
}
func TestHeadlessJWKSCacheService_ShouldPrefetch_SkipsUntilNextRetryAt(t *testing.T) {
cacheService := NewHeadlessJWKSCacheService(&headlessJWKSCacheTestRedis{}, nil)
now := time.Now()
state := &domain.HeadlessJWKSCacheState{
ClientID: "client-headless",
LastRefreshStatus: "failure",
ConsecutiveFailures: 3,
NextRetryAt: new(now.Add(10 * time.Minute)),
}
assert.False(t, cacheService.ShouldPrefetch(state, now))
assert.True(t, cacheService.ShouldPrefetch(state, now.Add(11*time.Minute)))
}
func TestHeadlessJWKSCacheWorker_RunOnce_SkipsBackoffTargets(t *testing.T) {
clients := []domain.HydraClient{
newTestHeadlessClient("client-fail", "https://fail.example.com/.well-known/jwks.json"),
newTestHeadlessClient("client-skip", "https://skip.example.com/.well-known/jwks.json"),
}
hydra := &HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: clientForHandler(jsonHandler(t, clients)),
}
redisRepo := &headlessJWKSCacheTestRedis{}
cacheService := NewHeadlessJWKSCacheService(redisRepo, nil)
cacheService.FailureThreshold = 3
cacheService.FailureBackoff = 15 * time.Minute
now := time.Now()
require.NoError(t, cacheService.SaveState("client-fail", domain.HeadlessJWKSCacheState{
ClientID: "client-fail",
JWKSURI: clients[0].HeadlessJWKSURI(),
LastRefreshStatus: "failure",
ConsecutiveFailures: 2,
}))
require.NoError(t, cacheService.SaveState("client-skip", domain.HeadlessJWKSCacheState{
ClientID: "client-skip",
JWKSURI: clients[1].HeadlessJWKSURI(),
LastRefreshStatus: "failure",
ConsecutiveFailures: 3,
NextRetryAt: new(now.Add(10 * time.Minute)),
}))
fetchCounts := map[string]int{}
cacheService.HTTPClient = &http.Client{
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
fetchCounts[req.URL.Host]++
if req.URL.Host == "fail.example.com" {
return jsonHTTPResponse(http.StatusInternalServerError, `{"error":"boom"}`), nil
}
t.Fatalf("unexpected fetch for host %s", req.URL.Host)
return nil, nil
}),
}
worker := &HeadlessJWKSCacheWorker{
Hydra: hydra,
Cache: cacheService,
PageSize: 100,
}
worker.runOnce(context.Background())
assert.Equal(t, 1, fetchCounts["fail.example.com"])
assert.Equal(t, 0, fetchCounts["skip.example.com"])
failedState, err := cacheService.GetState("client-fail")
require.NoError(t, err)
require.NotNil(t, failedState)
assert.Equal(t, 3, failedState.ConsecutiveFailures)
require.NotNil(t, failedState.NextRetryAt)
skippedState, err := cacheService.GetState("client-skip")
require.NoError(t, err)
require.NotNil(t, skippedState)
assert.Equal(t, 3, skippedState.ConsecutiveFailures)
require.NotNil(t, skippedState.NextRetryAt)
assert.WithinDuration(t, now.Add(10*time.Minute), *skippedState.NextRetryAt, time.Second)
}
func TestHeadlessJWKSCacheWorker_RunOnce_RetriesAfterBackoffAndClearsFailureStateOnSuccess(t *testing.T) {
_, freshJWKS := mustServiceHeadlessRSAJWK(t, "fresh-key")
freshRaw, err := json.Marshal(freshJWKS)
require.NoError(t, err)
client := newTestHeadlessClient("client-recover", "https://recover.example.com/.well-known/jwks.json")
hydra := &HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: clientForHandler(jsonHandler(t, []domain.HydraClient{client})),
}
redisRepo := &headlessJWKSCacheTestRedis{}
cacheService := NewHeadlessJWKSCacheService(redisRepo, nil)
cacheService.FailureThreshold = 3
cacheService.FailureBackoff = 15 * time.Minute
require.NoError(t, cacheService.SaveState("client-recover", domain.HeadlessJWKSCacheState{
ClientID: "client-recover",
JWKSURI: client.HeadlessJWKSURI(),
LastRefreshStatus: "failure",
LastError: "previous failure",
ConsecutiveFailures: 3,
NextRetryAt: new(time.Now().Add(-time.Minute)),
}))
fetchCount := 0
cacheService.HTTPClient = &http.Client{
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
fetchCount++
assert.Equal(t, "recover.example.com", req.URL.Host)
return jsonHTTPResponse(http.StatusOK, string(freshRaw)), nil
}),
}
worker := &HeadlessJWKSCacheWorker{
Hydra: hydra,
Cache: cacheService,
PageSize: 100,
}
worker.runOnce(context.Background())
assert.Equal(t, 1, fetchCount)
recoveredState, err := cacheService.GetState("client-recover")
require.NoError(t, err)
require.NotNil(t, recoveredState)
assert.Equal(t, "success", recoveredState.LastRefreshStatus)
assert.Empty(t, recoveredState.LastError)
assert.Equal(t, 0, recoveredState.ConsecutiveFailures)
assert.Nil(t, recoveredState.NextRetryAt)
assert.Equal(t, []string{"fresh-key"}, recoveredState.CachedKids)
}
func TestHeadlessJWKSCacheWorker_RunOnce_MixedClients(t *testing.T) {
_, successJWKS := mustServiceHeadlessRSAJWK(t, "success-key")
successRaw, err := json.Marshal(successJWKS)
require.NoError(t, err)
successClient := newTestHeadlessClient("client-success", "https://success.example.com/.well-known/jwks.json")
failClient := newTestHeadlessClient("client-fail", "https://fail.example.com/.well-known/jwks.json")
skipClient := newTestHeadlessClient("client-skip", "https://skip.example.com/.well-known/jwks.json")
disabledClient := domain.HydraClient{
ClientID: "client-disabled",
Metadata: map[string]any{
domain.MetadataHeadlessLoginEnabled: false,
domain.MetadataHeadlessJWKSURI: "https://disabled.example.com/.well-known/jwks.json",
domain.MetadataHeadlessTokenEndpointAuthMethod: "private_key_jwt",
},
}
hydra := &HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: clientForHandler(jsonHandler(t, []domain.HydraClient{
successClient,
failClient,
skipClient,
disabledClient,
})),
}
redisRepo := &headlessJWKSCacheTestRedis{}
cacheService := NewHeadlessJWKSCacheService(redisRepo, nil)
cacheService.FailureThreshold = 3
cacheService.FailureBackoff = 20 * time.Minute
require.NoError(t, cacheService.SaveState("client-fail", domain.HeadlessJWKSCacheState{
ClientID: "client-fail",
JWKSURI: failClient.HeadlessJWKSURI(),
LastRefreshStatus: "failure",
ConsecutiveFailures: 2,
}))
require.NoError(t, cacheService.SaveState("client-skip", domain.HeadlessJWKSCacheState{
ClientID: "client-skip",
JWKSURI: skipClient.HeadlessJWKSURI(),
LastRefreshStatus: "failure",
ConsecutiveFailures: 3,
NextRetryAt: new(time.Now().Add(10 * time.Minute)),
}))
fetchCounts := map[string]int{}
cacheService.HTTPClient = &http.Client{
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
fetchCounts[req.URL.Host]++
switch req.URL.Host {
case "success.example.com":
return jsonHTTPResponse(http.StatusOK, string(successRaw)), nil
case "fail.example.com":
return jsonHTTPResponse(http.StatusInternalServerError, `{"error":"boom"}`), nil
default:
t.Fatalf("unexpected fetch for host %s", req.URL.Host)
return nil, nil
}
}),
}
worker := &HeadlessJWKSCacheWorker{
Hydra: hydra,
Cache: cacheService,
PageSize: 100,
}
worker.runOnce(context.Background())
assert.Equal(t, 1, fetchCounts["success.example.com"])
assert.Equal(t, 1, fetchCounts["fail.example.com"])
assert.Equal(t, 0, fetchCounts["skip.example.com"])
assert.Equal(t, 0, fetchCounts["disabled.example.com"])
successState, err := cacheService.GetState("client-success")
require.NoError(t, err)
require.NotNil(t, successState)
assert.Equal(t, "success", successState.LastRefreshStatus)
assert.Equal(t, 0, successState.ConsecutiveFailures)
assert.Nil(t, successState.NextRetryAt)
failState, err := cacheService.GetState("client-fail")
require.NoError(t, err)
require.NotNil(t, failState)
assert.Equal(t, "failure", failState.LastRefreshStatus)
assert.Equal(t, 3, failState.ConsecutiveFailures)
require.NotNil(t, failState.NextRetryAt)
}
func mustServiceHeadlessRSAJWK(t *testing.T, kid string) (*rsa.PrivateKey, jose.JSONWebKeySet) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
publicJWK := jose.JSONWebKey{
Key: &privateKey.PublicKey,
KeyID: kid,
Algorithm: string(jose.RS256),
Use: "sig",
}
return privateKey, jose.JSONWebKeySet{Keys: []jose.JSONWebKey{publicJWK}}
}
//go:fix inline
func ptrTestTime(value time.Time) *time.Time {
return new(value)
}
func newTestHeadlessClient(clientID, jwksURI string) domain.HydraClient {
return domain.HydraClient{
ClientID: clientID,
Metadata: map[string]any{
domain.MetadataHeadlessLoginEnabled: true,
domain.MetadataHeadlessJWKSURI: jwksURI,
domain.MetadataHeadlessTokenEndpointAuthMethod: "private_key_jwt",
},
}
}
func jsonHandler(t *testing.T, payload any) http.HandlerFunc {
t.Helper()
return func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/clients", r.URL.Path)
w.Header().Set("Content-Type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(payload))
}
}
func jsonHTTPResponse(status int, body string) *http.Response {
return &http.Response{
StatusCode: status,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(body)),
}
}

View File

@@ -0,0 +1,639 @@
package service
import (
"baron-sso-backend/internal/domain"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"
)
var ErrHydraNotFound = errors.New("hydra admin: resource not found")
// HydraAdminService는 Hydra Admin API 호출을 래핑합니다.
type HydraAdminService struct {
AdminURL string
PublicURL string
HTTPClient *http.Client
}
func NewHydraAdminService() *HydraAdminService {
return &HydraAdminService{
AdminURL: getenv("HYDRA_ADMIN_URL", "http://hydra:4445"),
PublicURL: getenv("HYDRA_PUBLIC_URL", "http://hydra:4444"),
}
}
func (s *HydraAdminService) ListClients(ctx context.Context, limit, offset int) ([]domain.HydraClient, error) {
endpoint, err := s.buildURL("/clients", map[string]int{
"limit": limit,
"offset": offset,
})
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
resp, err := s.httpClient().Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return nil, ErrHydraNotFound
}
if resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return nil, fmt.Errorf("hydra admin: list clients failed status=%d body=%s", resp.StatusCode, string(body))
}
var clients []domain.HydraClient
if err := json.NewDecoder(resp.Body).Decode(&clients); err != nil {
return nil, fmt.Errorf("hydra admin: decode clients failed: %w", err)
}
return clients, nil
}
func (s *HydraAdminService) GetClient(ctx context.Context, clientID string) (*domain.HydraClient, error) {
endpoint := fmt.Sprintf("%s/clients/%s", strings.TrimRight(s.AdminURL, "/"), url.PathEscape(clientID))
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
resp, err := s.httpClient().Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return nil, ErrHydraNotFound
}
if resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return nil, fmt.Errorf("hydra admin: get client failed status=%d body=%s", resp.StatusCode, string(body))
}
var client domain.HydraClient
if err := json.NewDecoder(resp.Body).Decode(&client); err != nil {
return nil, fmt.Errorf("hydra admin: decode client failed: %w", err)
}
return &client, nil
}
func (s *HydraAdminService) PatchClientStatus(ctx context.Context, clientID, status string) (*domain.HydraClient, error) {
// JSON Patch format
payload := []map[string]any{
{
"op": "replace",
"path": "/metadata/status",
"value": status,
},
}
body, _ := json.Marshal(payload)
endpoint := fmt.Sprintf("%s/clients/%s", strings.TrimRight(s.AdminURL, "/"), url.PathEscape(clientID))
req, err := http.NewRequestWithContext(ctx, http.MethodPatch, endpoint, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json-patch+json")
resp, err := s.httpClient().Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return nil, ErrHydraNotFound
}
if resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return nil, fmt.Errorf("hydra admin: patch client failed status=%d body=%s", resp.StatusCode, string(respBody))
}
var updated domain.HydraClient
if err := json.NewDecoder(resp.Body).Decode(&updated); err != nil {
return nil, fmt.Errorf("hydra admin: decode patched client failed: %w", err)
}
return &updated, nil
}
func (s *HydraAdminService) CreateClient(ctx context.Context, client domain.HydraClient) (*domain.HydraClient, error) {
body, _ := json.Marshal(client)
endpoint := fmt.Sprintf("%s/clients", strings.TrimRight(s.AdminURL, "/"))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := s.httpClient().Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return nil, fmt.Errorf("hydra admin: create client failed status=%d body=%s", resp.StatusCode, string(respBody))
}
var created domain.HydraClient
if err := json.NewDecoder(resp.Body).Decode(&created); err != nil {
return nil, fmt.Errorf("hydra admin: decode created client failed: %w", err)
}
return &created, nil
}
func (s *HydraAdminService) UpdateClient(ctx context.Context, clientID string, client domain.HydraClient) (*domain.HydraClient, error) {
client.ClientID = clientID
body, _ := json.Marshal(client)
endpoint := fmt.Sprintf("%s/clients/%s", strings.TrimRight(s.AdminURL, "/"), url.PathEscape(clientID))
req, err := http.NewRequestWithContext(ctx, http.MethodPut, endpoint, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := s.httpClient().Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return nil, ErrHydraNotFound
}
if resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return nil, fmt.Errorf("hydra admin: update client failed status=%d body=%s", resp.StatusCode, string(respBody))
}
var updated domain.HydraClient
if err := json.NewDecoder(resp.Body).Decode(&updated); err != nil {
return nil, fmt.Errorf("hydra admin: decode updated client failed: %w", err)
}
return &updated, nil
}
func (s *HydraAdminService) DeleteClient(ctx context.Context, clientID string) error {
endpoint := fmt.Sprintf("%s/clients/%s", strings.TrimRight(s.AdminURL, "/"), url.PathEscape(clientID))
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, endpoint, nil)
if err != nil {
return err
}
resp, err := s.httpClient().Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return ErrHydraNotFound
}
if resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return fmt.Errorf("hydra admin: delete client failed status=%d body=%s", resp.StatusCode, string(body))
}
return nil
}
func (s *HydraAdminService) ListConsentSessions(ctx context.Context, subject, clientID string) ([]domain.HydraConsentSession, error) {
params := map[string]string{
"subject": subject,
}
if clientID != "" {
params["client"] = clientID
}
endpoint, err := s.buildURLWithParams("/oauth2/auth/sessions/consent", params)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
resp, err := s.httpClient().Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNoContent {
return []domain.HydraConsentSession{}, nil
}
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024*1024))
if resp.StatusCode >= 300 {
return nil, fmt.Errorf("hydra admin: list consent sessions failed status=%d body=%s", resp.StatusCode, string(body))
}
if len(body) == 0 {
return []domain.HydraConsentSession{}, nil
}
var sessions []domain.HydraConsentSession
if err := json.Unmarshal(body, &sessions); err != nil {
return nil, fmt.Errorf("hydra admin: decode consent sessions failed: %w body=%s", err, string(body))
}
return sessions, nil
}
func (s *HydraAdminService) RevokeConsentSessions(ctx context.Context, subject, clientID string) error {
params := map[string]string{
"subject": subject,
}
if clientID != "" {
params["client"] = clientID
} else {
params["all"] = "true"
}
endpoint, err := s.buildURLWithParams("/oauth2/auth/sessions/consent", params)
if err != nil {
return err
}
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, endpoint, nil)
if err != nil {
return err
}
resp, err := s.httpClient().Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return fmt.Errorf("hydra admin: revoke consent failed status=%d body=%s", resp.StatusCode, string(body))
}
return nil
}
func (s *HydraAdminService) httpClient() *http.Client {
if s.HTTPClient != nil {
return s.HTTPClient
}
return &http.Client{
Timeout: 10 * time.Second,
Transport: &http.Transport{
DialContext: (&net.Dialer{
Timeout: 5 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
TLSHandshakeTimeout: 5 * time.Second,
},
}
}
func (s *HydraAdminService) buildURL(path string, ints map[string]int) (string, error) {
base := strings.TrimRight(s.AdminURL, "/")
u, err := url.Parse(base + path)
if err != nil {
return "", err
}
q := u.Query()
for key, value := range ints {
if value > 0 {
q.Set(key, strconv.Itoa(value))
}
}
u.RawQuery = q.Encode()
return u.String(), nil
}
func (s *HydraAdminService) buildURLWithParams(path string, params map[string]string) (string, error) {
base := strings.TrimRight(s.AdminURL, "/")
u, err := url.Parse(base + path)
if err != nil {
return "", err
}
q := u.Query()
for key, value := range params {
if value != "" {
q.Set(key, value)
}
}
u.RawQuery = q.Encode()
return u.String(), nil
}
type AcceptLoginRequestResponse struct {
RedirectTo string `json:"redirectTo"`
}
type AcceptConsentRequestResponse struct {
RedirectTo string `json:"redirectTo"`
}
type RejectConsentRequestResponse struct {
RedirectTo string `json:"redirectTo"`
}
type RejectLoginRequestResponse struct {
RedirectTo string `json:"redirectTo"`
}
func (s *HydraAdminService) GetConsentRequest(ctx context.Context, challenge string) (*domain.HydraConsentRequest, error) {
params := map[string]string{
"consent_challenge": challenge,
}
endpoint, err := s.buildURLWithParams("/oauth2/auth/requests/consent", params)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("hydra admin: create request for get consent failed: %w", err)
}
resp, err := s.httpClient().Do(req)
if err != nil {
return nil, fmt.Errorf("hydra admin: get consent request failed: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("hydra admin: get consent failed status=%d body=%s", resp.StatusCode, string(body))
}
var consentReq domain.HydraConsentRequest
if err := json.Unmarshal(body, &consentReq); err != nil {
return nil, fmt.Errorf("hydra admin: decode get consent response failed: %w", err)
}
return &consentReq, nil
}
func (s *HydraAdminService) RejectConsentRequest(ctx context.Context, challenge string) (*RejectConsentRequestResponse, error) {
params := map[string]string{
"consent_challenge": challenge,
}
endpoint, err := s.buildURLWithParams("/oauth2/auth/requests/consent/reject", params)
if err != nil {
return nil, err
}
payload := map[string]any{
"error": "access_denied",
"error_description": "The user decided to reject the consent request.",
}
body, _ := json.Marshal(payload)
req, err := http.NewRequestWithContext(ctx, "PUT", endpoint, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("hydra admin: create request for reject consent failed: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := s.httpClient().Do(req)
if err != nil {
return nil, fmt.Errorf("hydra admin: reject consent request failed: %w", err)
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("hydra admin: reject consent failed status=%d body=%s", resp.StatusCode, string(respBody))
}
var hydraResp struct {
RedirectTo string `json:"redirect_to"`
}
if err := json.Unmarshal(respBody, &hydraResp); err != nil {
return nil, fmt.Errorf("hydra admin: decode reject consent response failed: %w", err)
}
return &RejectConsentRequestResponse{RedirectTo: hydraResp.RedirectTo}, nil
}
func (s *HydraAdminService) RejectLoginRequest(ctx context.Context, challenge, error, errorDescription string) (*RejectLoginRequestResponse, error) {
params := map[string]string{
"login_challenge": challenge,
}
endpoint, err := s.buildURLWithParams("/oauth2/auth/requests/login/reject", params)
if err != nil {
return nil, err
}
payload := map[string]any{
"error": error,
"error_description": errorDescription,
}
body, _ := json.Marshal(payload)
req, err := http.NewRequestWithContext(ctx, "PUT", endpoint, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("hydra admin: create request for reject login failed: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := s.httpClient().Do(req)
if err != nil {
return nil, fmt.Errorf("hydra admin: reject login request failed: %w", err)
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("hydra admin: reject login failed status=%d body=%s", resp.StatusCode, string(respBody))
}
var hydraResp struct {
RedirectTo string `json:"redirect_to"`
}
if err := json.Unmarshal(respBody, &hydraResp); err != nil {
return nil, fmt.Errorf("hydra admin: decode reject login response failed: %w", err)
}
return &RejectLoginRequestResponse{RedirectTo: hydraResp.RedirectTo}, nil
}
func (s *HydraAdminService) GetLoginRequest(ctx context.Context, challenge string) (*domain.HydraLoginRequest, error) {
params := map[string]string{
"login_challenge": challenge,
}
endpoint, err := s.buildURLWithParams("/oauth2/auth/requests/login", params)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("hydra admin: create request for get login failed: %w", err)
}
resp, err := s.httpClient().Do(req)
if err != nil {
return nil, fmt.Errorf("hydra admin: get login request failed: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("hydra admin: get login failed status=%d body=%s", resp.StatusCode, string(body))
}
var loginReq domain.HydraLoginRequest
if err := json.Unmarshal(body, &loginReq); err != nil {
return nil, fmt.Errorf("hydra admin: decode get login response failed: %w", err)
}
return &loginReq, nil
}
func (s *HydraAdminService) AcceptConsentRequest(ctx context.Context, challenge string, grantInfo *domain.HydraConsentRequest, sessionClaims map[string]any) (*AcceptConsentRequestResponse, error) {
params := map[string]string{
"consent_challenge": challenge,
}
endpoint, err := s.buildURLWithParams("/oauth2/auth/requests/consent/accept", params)
if err != nil {
return nil, err
}
payload := map[string]any{
"grant_scope": grantInfo.RequestedScope,
"grant_audience": grantInfo.RequestedAudience,
"remember": true,
"remember_for": 2592000,
}
if len(sessionClaims) > 0 {
payload["session"] = map[string]any{
"id_token": sessionClaims,
"access_token": sessionClaims,
}
}
body, _ := json.Marshal(payload)
req, err := http.NewRequestWithContext(ctx, "PUT", endpoint, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("hydra admin: create request for accept consent failed: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := s.httpClient().Do(req)
if err != nil {
return nil, fmt.Errorf("hydra admin: accept consent request failed: %w", err)
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("hydra admin: accept consent failed status=%d body=%s", resp.StatusCode, string(respBody))
}
// Hydra 응답(redirect_to)을 읽어서 우리 응답(redirectTo)으로 변환
var hydraResp struct {
RedirectTo string `json:"redirect_to"`
}
if err := json.Unmarshal(respBody, &hydraResp); err != nil {
return nil, fmt.Errorf("hydra admin: decode accept consent response failed: %w", err)
}
return &AcceptConsentRequestResponse{RedirectTo: hydraResp.RedirectTo}, nil
}
func (s *HydraAdminService) AcceptLoginRequest(ctx context.Context, challenge string, subject string) (*AcceptLoginRequestResponse, error) {
params := map[string]string{
"login_challenge": challenge,
}
endpoint, err := s.buildURLWithParams("/oauth2/auth/requests/login/accept", params)
if err != nil {
return nil, err
}
payload := map[string]any{
"subject": subject,
"remember": true,
"remember_for": 2592000,
}
body, _ := json.Marshal(payload)
req, err := http.NewRequestWithContext(ctx, "PUT", endpoint, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("hydra admin: create request for accept login failed: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := s.httpClient().Do(req)
if err != nil {
return nil, fmt.Errorf("hydra admin: accept login request failed: %w", err)
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("hydra admin: accept login failed status=%d body=%s", resp.StatusCode, string(respBody))
}
// Hydra 응답(redirect_to)을 읽어서 우리 응답(redirectTo)으로 변환
var hydraResp struct {
RedirectTo string `json:"redirect_to"`
}
if err := json.Unmarshal(respBody, &hydraResp); err != nil {
return nil, fmt.Errorf("hydra admin: decode accept login response failed: %w", err)
}
return &AcceptLoginRequestResponse{RedirectTo: hydraResp.RedirectTo}, nil
}
type HydraIntrospectionResponse struct {
Active bool `json:"active"`
Subject string `json:"sub"`
ClientID string `json:"client_id"`
Scope string `json:"scope"`
ExpiresAt int64 `json:"exp"`
IssuedAt int64 `json:"iat"`
Ext map[string]any `json:"ext"`
}
func (s *HydraAdminService) IntrospectToken(ctx context.Context, token string) (*HydraIntrospectionResponse, error) {
endpoint := fmt.Sprintf("%s/oauth2/introspect", strings.TrimRight(s.AdminURL, "/"))
form := url.Values{}
form.Set("token", token)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := s.httpClient().Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return nil, fmt.Errorf("hydra admin: introspection failed status=%d body=%s", resp.StatusCode, string(body))
}
var res HydraIntrospectionResponse
if err := json.NewDecoder(resp.Body).Decode(&res); err != nil {
return nil, err
}
return &res, nil
}

View File

@@ -0,0 +1,331 @@
package service
import (
"baron-sso-backend/internal/domain"
"context"
"encoding/json"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
)
func TestHydraAdminService_ListClients(t *testing.T) {
clients := []domain.HydraClient{
{ClientID: "client1", ClientName: "Client 1"},
{ClientID: "client2", ClientName: "Client 2"},
}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/clients", r.URL.Path)
assert.Equal(t, "GET", r.Method)
assert.Equal(t, "10", r.URL.Query().Get("limit"))
assert.Equal(t, "5", r.URL.Query().Get("offset"))
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(clients)
})
s := &HydraAdminService{
AdminURL: "http://hydra-admin.local",
HTTPClient: clientForHandler(handler),
}
result, err := s.ListClients(context.Background(), 10, 5)
assert.NoError(t, err)
assert.Equal(t, clients, result)
}
func TestHydraAdminService_GetClient(t *testing.T) {
client := domain.HydraClient{ClientID: "test-client", ClientName: "Test Client"}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/clients/test-client", r.URL.Path)
assert.Equal(t, "GET", r.Method)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(client)
})
s := &HydraAdminService{
AdminURL: "http://hydra-admin.local",
HTTPClient: clientForHandler(handler),
}
result, err := s.GetClient(context.Background(), "test-client")
assert.NoError(t, err)
assert.Equal(t, &client, result)
}
func TestHydraAdminService_CreateClient(t *testing.T) {
client := domain.HydraClient{ClientName: "New Client"}
created := domain.HydraClient{ClientID: "new-id", ClientName: "New Client"}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/clients", r.URL.Path)
assert.Equal(t, "POST", r.Method)
var received domain.HydraClient
_ = json.NewDecoder(r.Body).Decode(&received)
assert.Equal(t, client.ClientName, received.ClientName)
w.WriteHeader(http.StatusCreated)
_ = json.NewEncoder(w).Encode(created)
})
s := &HydraAdminService{
AdminURL: "http://hydra-admin.local",
HTTPClient: clientForHandler(handler),
}
result, err := s.CreateClient(context.Background(), client)
assert.NoError(t, err)
assert.Equal(t, &created, result)
}
func TestHydraAdminService_DeleteClient(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/clients/to-delete", r.URL.Path)
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(http.StatusNoContent)
})
s := &HydraAdminService{
AdminURL: "http://hydra-admin.local",
HTTPClient: clientForHandler(handler),
}
err := s.DeleteClient(context.Background(), "to-delete")
assert.NoError(t, err)
}
func TestHydraAdminService_GetConsentRequest(t *testing.T) {
challenge := "challenge123"
consentReq := domain.HydraConsentRequest{Challenge: challenge, Subject: "user1"}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/oauth2/auth/requests/consent", r.URL.Path)
assert.Equal(t, challenge, r.URL.Query().Get("consent_challenge"))
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(consentReq)
})
s := &HydraAdminService{
AdminURL: "http://hydra-admin.local",
HTTPClient: clientForHandler(handler),
}
result, err := s.GetConsentRequest(context.Background(), challenge)
assert.NoError(t, err)
assert.Equal(t, &consentReq, result)
}
func TestHydraAdminService_PatchClientStatus(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/clients/test-client", r.URL.Path)
assert.Equal(t, "PATCH", r.Method)
assert.Equal(t, "application/json-patch+json", r.Header.Get("Content-Type"))
var payload []map[string]any
_ = json.NewDecoder(r.Body).Decode(&payload)
assert.Equal(t, "replace", payload[0]["op"])
assert.Equal(t, "/metadata/status", payload[0]["path"])
assert.Equal(t, "inactive", payload[0]["value"])
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(domain.HydraClient{ClientID: "test-client"})
})
s := &HydraAdminService{
AdminURL: "http://hydra-admin.local",
HTTPClient: clientForHandler(handler),
}
_, err := s.PatchClientStatus(context.Background(), "test-client", "inactive")
assert.NoError(t, err)
}
func TestHydraAdminService_UpdateClient(t *testing.T) {
client := domain.HydraClient{ClientName: "Updated Name"}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/clients/test-client", r.URL.Path)
assert.Equal(t, "PUT", r.Method)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(client)
})
s := &HydraAdminService{
AdminURL: "http://hydra-admin.local",
HTTPClient: clientForHandler(handler),
}
_, err := s.UpdateClient(context.Background(), "test-client", client)
assert.NoError(t, err)
}
func TestHydraAdminService_ListConsentSessions(t *testing.T) {
sessions := []domain.HydraConsentSession{{Subject: "user1"}}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/oauth2/auth/sessions/consent", r.URL.Path)
assert.Equal(t, "user1", r.URL.Query().Get("subject"))
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(sessions)
})
s := &HydraAdminService{
AdminURL: "http://hydra-admin.local",
HTTPClient: clientForHandler(handler),
}
result, err := s.ListConsentSessions(context.Background(), "user1", "")
assert.NoError(t, err)
assert.Equal(t, sessions, result)
}
func TestHydraAdminService_RevokeConsentSessions(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/oauth2/auth/sessions/consent", r.URL.Path)
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(http.StatusNoContent)
})
s := &HydraAdminService{
AdminURL: "http://hydra-admin.local",
HTTPClient: clientForHandler(handler),
}
err := s.RevokeConsentSessions(context.Background(), "user1", "")
assert.NoError(t, err)
}
func TestHydraAdminService_RejectConsentRequest(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/oauth2/auth/requests/consent/reject", r.URL.Path)
assert.Equal(t, "PUT", r.Method)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://reject"})
})
s := &HydraAdminService{
AdminURL: "http://hydra-admin.local",
HTTPClient: clientForHandler(handler),
}
resp, err := s.RejectConsentRequest(context.Background(), "challenge")
assert.NoError(t, err)
assert.Equal(t, "http://reject", resp.RedirectTo)
}
func TestHydraAdminService_RejectLoginRequest(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/oauth2/auth/requests/login/reject", r.URL.Path)
assert.Equal(t, "PUT", r.Method)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://reject-login"})
})
s := &HydraAdminService{
AdminURL: "http://hydra-admin.local",
HTTPClient: clientForHandler(handler),
}
resp, err := s.RejectLoginRequest(context.Background(), "challenge", "error", "desc")
assert.NoError(t, err)
assert.Equal(t, "http://reject-login", resp.RedirectTo)
}
func TestHydraAdminService_GetLoginRequest(t *testing.T) {
loginReq := domain.HydraLoginRequest{Challenge: "challenge"}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/oauth2/auth/requests/login", r.URL.Path)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(loginReq)
})
s := &HydraAdminService{
AdminURL: "http://hydra-admin.local",
HTTPClient: clientForHandler(handler),
}
result, err := s.GetLoginRequest(context.Background(), "challenge")
assert.NoError(t, err)
assert.Equal(t, &loginReq, result)
}
func TestHydraAdminService_AcceptConsentRequest(t *testing.T) {
grant := &domain.HydraConsentRequest{RequestedScope: []string{"openid"}}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/oauth2/auth/requests/consent/accept", r.URL.Path)
assert.Equal(t, "PUT", r.Method)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://accept"})
})
s := &HydraAdminService{
AdminURL: "http://hydra-admin.local",
HTTPClient: clientForHandler(handler),
}
resp, err := s.AcceptConsentRequest(context.Background(), "challenge", grant, nil)
assert.NoError(t, err)
assert.Equal(t, "http://accept", resp.RedirectTo)
}
func TestHydraAdminService_AcceptLoginRequest(t *testing.T) {
challenge := "login_challenge"
subject := "user@example.com"
redirectTo := "http://hydra/auth/confirm"
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/oauth2/auth/requests/login/accept", r.URL.Path)
assert.Equal(t, challenge, r.URL.Query().Get("login_challenge"))
var body map[string]any
_ = json.NewDecoder(r.Body).Decode(&body)
assert.Equal(t, subject, body["subject"])
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": redirectTo})
})
s := &HydraAdminService{
AdminURL: "http://hydra-admin.local",
HTTPClient: clientForHandler(handler),
}
result, err := s.AcceptLoginRequest(context.Background(), challenge, subject)
assert.NoError(t, err)
assert.Equal(t, redirectTo, result.RedirectTo)
}
func TestHydraAdminService_ErrorHandling(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte("bad request"))
})
s := &HydraAdminService{
AdminURL: "http://hydra-admin.local",
HTTPClient: clientForHandler(handler),
}
_, err := s.GetClient(context.Background(), "invalid")
assert.Error(t, err)
assert.Contains(t, err.Error(), "status=400")
err = s.DeleteClient(context.Background(), "invalid")
assert.Error(t, err)
_, err = s.ListClients(context.Background(), 10, 0)
assert.Error(t, err)
_, err = s.PatchClientStatus(context.Background(), "invalid", "active")
assert.Error(t, err)
}
func TestHydraAdminService_NotFound(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
})
s := &HydraAdminService{
AdminURL: "http://hydra-admin.local",
HTTPClient: clientForHandler(handler),
}
_, err := s.GetClient(context.Background(), "none")
assert.Equal(t, ErrHydraNotFound, err)
}

View File

@@ -0,0 +1,78 @@
package service
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/repository"
"context"
"log/slog"
"time"
)
type KetoRelayWorker interface {
Start(ctx context.Context)
}
type ketoRelayWorker struct {
outboxRepo repository.KetoOutboxRepository
ketoService KetoService
interval time.Duration
maxRetries int
}
func NewKetoRelayWorker(outboxRepo repository.KetoOutboxRepository, ketoService KetoService) KetoRelayWorker {
return &ketoRelayWorker{
outboxRepo: outboxRepo,
ketoService: ketoService,
interval: 5 * time.Second, // Poll every 5 seconds
maxRetries: 5,
}
}
func (w *ketoRelayWorker) Start(ctx context.Context) {
slog.Info("[KetoRelayWorker] Starting worker...")
ticker := time.NewTicker(w.interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
slog.Info("[KetoRelayWorker] Stopping worker...")
return
case <-ticker.C:
w.processEntries(ctx)
}
}
}
func (w *ketoRelayWorker) processEntries(ctx context.Context) {
entries, err := w.outboxRepo.FindPending(ctx, 50) // Process up to 50 at once
if err != nil {
slog.Error("[KetoRelayWorker] Failed to fetch pending entries", "error", err)
return
}
for _, entry := range entries {
w.processEntry(ctx, entry)
}
}
func (w *ketoRelayWorker) processEntry(ctx context.Context, entry domain.KetoOutbox) {
var err error
if entry.Action == domain.KetoOutboxActionCreate {
err = w.ketoService.CreateRelation(ctx, entry.Namespace, entry.Object, entry.Relation, entry.Subject)
} else if entry.Action == domain.KetoOutboxActionDelete {
err = w.ketoService.DeleteRelation(ctx, entry.Namespace, entry.Object, entry.Relation, entry.Subject)
}
if err != nil {
slog.Error("[KetoRelayWorker] Failed to process entry", "id", entry.ID, "error", err)
newRetryCount := entry.RetryCount + 1
status := domain.KetoOutboxStatusPending
if newRetryCount >= w.maxRetries {
status = domain.KetoOutboxStatusFailed
}
_ = w.outboxRepo.UpdateStatus(ctx, entry.ID, status, newRetryCount, err.Error())
} else {
_ = w.outboxRepo.MarkProcessed(ctx, entry.ID)
}
}

View File

@@ -0,0 +1,267 @@
package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"os"
"time"
)
type KetoService interface {
CheckPermission(ctx context.Context, subject, namespace, object, relation string) (bool, error)
CreateRelation(ctx context.Context, namespace, object, relation, subject string) error
DeleteRelation(ctx context.Context, namespace, object, relation, subject string) error
ListRelations(ctx context.Context, namespace, object, relation, subject string) ([]RelationTuple, error)
ListObjects(ctx context.Context, namespace, relation, subject string) ([]string, error)
}
type ketoService struct {
readURL string
writeURL string
client *http.Client
}
func NewKetoService() KetoService {
readURL := os.Getenv("KETO_READ_URL")
if readURL == "" {
readURL = "http://keto:4466"
}
writeURL := os.Getenv("KETO_WRITE_URL")
if writeURL == "" {
writeURL = "http://keto:4467"
}
return &ketoService{
readURL: readURL,
writeURL: writeURL,
client: &http.Client{},
}
}
type RelationTuple struct {
Namespace string `json:"namespace"`
Object string `json:"object"`
Relation string `json:"relation"`
SubjectID string `json:"subject_id"`
}
type relationTuplesResponse struct {
RelationTuples []RelationTuple `json:"relation_tuples"`
NextPageToken string `json:"next_page_token"`
}
func (s *ketoService) ListRelations(ctx context.Context, namespace, object, relation, subject string) ([]RelationTuple, error) {
u, _ := url.Parse(fmt.Sprintf("%s/relation-tuples", s.readURL))
q := u.Query()
if namespace != "" {
q.Set("namespace", namespace)
}
if object != "" {
q.Set("object", object)
}
if relation != "" {
q.Set("relation", relation)
}
if subject != "" {
q.Set("subject_id", subject)
}
u.RawQuery = q.Encode()
req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
resp, err := s.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("keto returned status %d: %s", resp.StatusCode, string(body))
}
var res relationTuplesResponse
if err := json.NewDecoder(resp.Body).Decode(&res); err != nil {
return nil, err
}
return res.RelationTuples, nil
}
type checkResponse struct {
Allowed bool `json:"allowed"`
}
func (s *ketoService) CheckPermission(ctx context.Context, subject, namespace, object, relation string) (bool, error) {
u, _ := url.Parse(fmt.Sprintf("%s/relation-tuples/check", s.readURL))
q := u.Query()
q.Set("namespace", namespace)
q.Set("object", object)
q.Set("relation", relation)
q.Set("subject_id", subject)
u.RawQuery = q.Encode()
var lastErr error
maxRetries := 5
backoff := 200 * time.Millisecond
for i := range maxRetries {
req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
resp, err := s.client.Do(req)
if err == nil {
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
var res checkResponse
if err := json.NewDecoder(resp.Body).Decode(&res); err != nil {
return false, err
}
return res.Allowed, nil
}
if resp.StatusCode == http.StatusForbidden {
return false, nil
}
body, _ := io.ReadAll(resp.Body)
lastErr = fmt.Errorf("keto returned status %d: %s", resp.StatusCode, string(body))
} else {
lastErr = err
}
if i < maxRetries-1 {
slog.Debug("Retrying Keto CheckPermission...", "attempt", i+1, "error", lastErr)
time.Sleep(backoff)
backoff *= 2
}
}
return false, lastErr
}
func (s *ketoService) CreateRelation(ctx context.Context, namespace, object, relation, subject string) error {
u := fmt.Sprintf("%s/admin/relation-tuples", s.writeURL)
payload := map[string]any{
"namespace": namespace,
"object": object,
"relation": relation,
"subject_id": subject,
}
body, _ := json.Marshal(payload)
// Exponential Backoff Retry Logic
var lastErr error
maxRetries := 5
backoff := 200 * time.Millisecond
for i := range maxRetries {
req, _ := http.NewRequestWithContext(ctx, "PUT", u, bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := s.client.Do(req)
if err == nil {
defer resp.Body.Close()
if resp.StatusCode == http.StatusCreated || resp.StatusCode == http.StatusNoContent || resp.StatusCode == http.StatusOK {
slog.Debug("Keto relation created", "namespace", namespace, "object", object, "relation", relation, "subject", subject)
return nil
}
resBody, _ := io.ReadAll(resp.Body)
lastErr = fmt.Errorf("keto returned status %d: %s", resp.StatusCode, string(resBody))
} else {
lastErr = err
}
if i < maxRetries-1 {
slog.Debug("Retrying Keto CreateRelation...", "attempt", i+1, "error", lastErr)
time.Sleep(backoff)
backoff *= 2
}
}
slog.Error("Keto create relation failed after retries", "error", lastErr, "namespace", namespace, "object", object, "relation", relation, "subject", subject)
return lastErr
}
func (s *ketoService) DeleteRelation(ctx context.Context, namespace, object, relation, subject string) error {
u, _ := url.Parse(fmt.Sprintf("%s/admin/relation-tuples", s.writeURL))
q := u.Query()
q.Set("namespace", namespace)
q.Set("object", object)
q.Set("relation", relation)
q.Set("subject_id", subject)
u.RawQuery = q.Encode()
var lastErr error
maxRetries := 5
backoff := 200 * time.Millisecond
for i := range maxRetries {
req, _ := http.NewRequestWithContext(ctx, "DELETE", u.String(), nil)
resp, err := s.client.Do(req)
if err == nil {
defer resp.Body.Close()
if resp.StatusCode == http.StatusNoContent || resp.StatusCode == http.StatusOK {
slog.Debug("Keto relation deleted", "namespace", namespace, "object", object, "relation", relation, "subject", subject)
return nil
}
resBody, _ := io.ReadAll(resp.Body)
lastErr = fmt.Errorf("keto returned status %d: %s", resp.StatusCode, string(resBody))
} else {
lastErr = err
}
if i < maxRetries-1 {
slog.Debug("Retrying Keto DeleteRelation...", "attempt", i+1, "error", lastErr)
time.Sleep(backoff)
backoff *= 2
}
}
slog.Error("Keto delete relation failed after retries", "error", lastErr)
return lastErr
}
func (s *ketoService) ListObjects(ctx context.Context, namespace, relation, subject string) ([]string, error) {
u, _ := url.Parse(fmt.Sprintf("%s/relation-tuples", s.readURL))
q := u.Query()
if namespace != "" {
q.Set("namespace", namespace)
}
if relation != "" {
q.Set("relation", relation)
}
if subject != "" {
q.Set("subject_id", subject)
}
u.RawQuery = q.Encode()
req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
resp, err := s.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("keto returned status %d: %s", resp.StatusCode, string(body))
}
var res relationTuplesResponse
if err := json.NewDecoder(resp.Body).Decode(&res); err != nil {
return nil, err
}
objects := make([]string, 0, len(res.RelationTuples))
seen := make(map[string]bool)
for _, rt := range res.RelationTuples {
if !seen[rt.Object] {
objects = append(objects, rt.Object)
seen[rt.Object] = true
}
}
return objects, nil
}

View File

@@ -0,0 +1,156 @@
package service
import (
"context"
"encoding/json"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
)
func TestKetoService_CheckPermission(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/relation-tuples/check", r.URL.Path)
assert.Equal(t, "user1", r.URL.Query().Get("subject_id"))
assert.Equal(t, "tenants", r.URL.Query().Get("namespace"))
assert.Equal(t, "tenant1", r.URL.Query().Get("object"))
assert.Equal(t, "admin", r.URL.Query().Get("relation"))
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(checkResponse{Allowed: true})
})
s := &ketoService{
readURL: "http://keto-read.local",
client: clientForHandler(handler),
}
allowed, err := s.CheckPermission(context.Background(), "user1", "tenants", "tenant1", "admin")
assert.NoError(t, err)
assert.True(t, allowed)
}
func TestKetoService_CreateRelation(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/admin/relation-tuples", r.URL.Path)
assert.Equal(t, "PUT", r.Method)
var body map[string]any
_ = json.NewDecoder(r.Body).Decode(&body)
assert.Equal(t, "tenants", body["namespace"])
assert.Equal(t, "tenant1", body["object"])
assert.Equal(t, "admin", body["relation"])
assert.Equal(t, "user1", body["subject_id"])
w.WriteHeader(http.StatusCreated)
})
s := &ketoService{
writeURL: "http://keto-write.local",
client: clientForHandler(handler),
}
err := s.CreateRelation(context.Background(), "tenants", "tenant1", "admin", "user1")
assert.NoError(t, err)
}
func TestKetoService_DeleteRelation(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/admin/relation-tuples", r.URL.Path)
assert.Equal(t, "DELETE", r.Method)
assert.Equal(t, "user1", r.URL.Query().Get("subject_id"))
assert.Equal(t, "tenants", r.URL.Query().Get("namespace"))
assert.Equal(t, "tenant1", r.URL.Query().Get("object"))
assert.Equal(t, "admin", r.URL.Query().Get("relation"))
w.WriteHeader(http.StatusNoContent)
})
s := &ketoService{
writeURL: "http://keto-write.local",
client: clientForHandler(handler),
}
err := s.DeleteRelation(context.Background(), "tenants", "tenant1", "admin", "user1")
assert.NoError(t, err)
}
func TestKetoService_ListRelations(t *testing.T) {
tuples := []RelationTuple{
{Namespace: "tenants", Object: "tenant1", Relation: "admin", SubjectID: "user1"},
}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/relation-tuples", r.URL.Path)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(relationTuplesResponse{RelationTuples: tuples})
})
s := &ketoService{
readURL: "http://keto-read.local",
client: clientForHandler(handler),
}
result, err := s.ListRelations(context.Background(), "tenants", "tenant1", "admin", "user1")
assert.NoError(t, err)
assert.Equal(t, tuples, result)
}
func TestKetoService_ErrorHandling(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("internal error"))
})
s := &ketoService{
readURL: "http://keto-read.local",
writeURL: "http://keto-write.local",
client: clientForHandler(handler),
}
_, err := s.CheckPermission(context.Background(), "u", "n", "o", "r")
assert.Error(t, err)
err = s.DeleteRelation(context.Background(), "n", "o", "r", "s")
assert.Error(t, err)
_, err = s.ListRelations(context.Background(), "n", "o", "r", "s")
assert.Error(t, err)
}
func TestKetoService_CheckPermission_Forbidden(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusForbidden)
})
s := &ketoService{
readURL: "http://keto-read.local",
client: clientForHandler(handler),
}
allowed, err := s.CheckPermission(context.Background(), "u", "n", "o", "r")
assert.NoError(t, err)
assert.False(t, allowed)
}
func TestKetoService_CreateRelation_Retry(t *testing.T) {
attempts := 0
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts < 2 {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusCreated)
})
s := &ketoService{
writeURL: "http://keto-write.local",
client: clientForHandler(handler),
}
err := s.CreateRelation(context.Background(), "n", "o", "r", "s")
assert.NoError(t, err)
assert.Equal(t, 2, attempts)
}

View File

@@ -0,0 +1,556 @@
package service
import (
"baron-sso-backend/internal/domain"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"
"golang.org/x/crypto/bcrypt"
)
type KratosIdentity struct {
ID string `json:"id"`
SchemaID string `json:"schema_id,omitempty"`
Traits map[string]any `json:"traits"`
State string `json:"state,omitempty"`
MetadataAdmin any `json:"metadata_admin,omitempty"`
MetadataPublic any `json:"metadata_public,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type KratosSessionDevice struct {
UserAgent string `json:"user_agent,omitempty"`
IPAddress string `json:"ip_address,omitempty"`
}
type KratosSession struct {
ID string `json:"id"`
Active bool `json:"active"`
AuthenticatedAt time.Time `json:"authenticated_at"`
ExpiresAt time.Time `json:"expires_at"`
IssuedAt time.Time `json:"issued_at"`
Identity *KratosIdentity `json:"identity,omitempty"`
Devices []KratosSessionDevice `json:"devices,omitempty"`
}
type KratosAdminService interface {
ListIdentities(ctx context.Context) ([]KratosIdentity, error)
FindIdentityIDByIdentifier(ctx context.Context, identifier string) (string, error)
GetIdentity(ctx context.Context, identityID string) (*KratosIdentity, error)
UpdateIdentity(ctx context.Context, identityID string, traits map[string]any, state string) (*KratosIdentity, error)
UpdateIdentityPassword(ctx context.Context, identityID, newPassword string) error
DeleteIdentity(ctx context.Context, identityID string) error
CreateUser(ctx context.Context, user *domain.BrokerUser, password string) (string, error)
ListIdentitySessions(ctx context.Context, identityID string) ([]KratosSession, error)
GetSession(ctx context.Context, sessionID string) (*KratosSession, error)
DeleteSession(ctx context.Context, sessionID string) error
}
type kratosAdminService struct {
AdminURL string
HTTPClient *http.Client
}
func NewKratosAdminService() KratosAdminService {
return &kratosAdminService{
AdminURL: getenvKratos("KRATOS_ADMIN_URL", "http://kratos:4434"),
}
}
func (s *kratosAdminService) ListIdentities(ctx context.Context) ([]KratosIdentity, error) {
endpoint := strings.TrimRight(s.AdminURL, "/") + "/admin/identities"
var identities []KratosIdentity
pageToken := ""
seenTokens := make(map[string]bool)
for {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
query := req.URL.Query()
query.Set("page_size", "250")
if pageToken != "" {
query.Set("page_token", pageToken)
}
req.URL.RawQuery = query.Encode()
resp, err := s.httpClient().Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
_ = resp.Body.Close()
return nil, fmt.Errorf("kratos admin list identities failed status=%d body=%s", resp.StatusCode, string(body))
}
var page []KratosIdentity
if err := json.NewDecoder(resp.Body).Decode(&page); err != nil {
_ = resp.Body.Close()
return nil, err
}
_ = resp.Body.Close()
identities = append(identities, page...)
nextToken := kratosNextPageToken(resp.Header.Values("Link"))
if nextToken == "" {
return identities, nil
}
if seenTokens[nextToken] {
return nil, fmt.Errorf("kratos admin list identities pagination loop detected page_token=%s", nextToken)
}
seenTokens[nextToken] = true
pageToken = nextToken
}
}
func kratosNextPageToken(linkHeaders []string) string {
for _, header := range linkHeaders {
for _, part := range strings.Split(header, ",") {
part = strings.TrimSpace(part)
if !strings.Contains(part, `rel="next"`) && !strings.Contains(part, `rel=next`) {
continue
}
start := strings.Index(part, "<")
end := strings.Index(part, ">")
if start < 0 || end <= start+1 {
continue
}
rawURL := part[start+1 : end]
parsed, err := url.Parse(rawURL)
if err != nil {
continue
}
if token := strings.TrimSpace(parsed.Query().Get("page_token")); token != "" {
return token
}
}
}
return ""
}
func (s *kratosAdminService) FindIdentityIDByIdentifier(ctx context.Context, identifier string) (string, error) {
identifier = strings.TrimSpace(identifier)
if identifier == "" {
return "", nil
}
endpoint := strings.TrimRight(s.AdminURL, "/") + "/admin/identities"
// 1. Try credentials_identifier (Email/LoginID/Phone)
id, err := s.searchIdentities(ctx, endpoint, "credentials_identifier", identifier)
if err == nil && id != "" {
// VERIFY: Kratos sometimes ignores unknown query params and returns the first identity.
if s.verifyIdentityMatch(ctx, id, identifier) {
return id, nil
}
}
identity, err := s.GetIdentity(ctx, identifier)
if err == nil && identity != nil {
return identity.ID, nil
}
return "", nil
}
func (s *kratosAdminService) verifyIdentityMatch(ctx context.Context, id, identifier string) bool {
identity, err := s.GetIdentity(ctx, id)
if err != nil || identity == nil {
return false
}
// Exact ID match
if strings.EqualFold(identity.ID, identifier) {
return true
}
// Check traits (Email, CustomLoginIDs)
if email, ok := identity.Traits["email"].(string); ok && strings.EqualFold(email, identifier) {
return true
}
if phone, ok := identity.Traits["phone_number"].(string); ok && strings.EqualFold(phone, identifier) {
return true
}
if lids, ok := identity.Traits["custom_login_ids"].([]any); ok {
for _, lid := range lids {
if s, ok := lid.(string); ok && strings.EqualFold(s, identifier) {
return true
}
}
} else if lids, ok := identity.Traits["custom_login_ids"].([]string); ok {
for _, lid := range lids {
if strings.EqualFold(lid, identifier) {
return true
}
}
}
return false
}
func (s *kratosAdminService) searchIdentities(ctx context.Context, endpoint, key, value string) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return "", err
}
query := req.URL.Query()
query.Set(key, value)
req.URL.RawQuery = query.Encode()
resp, err := s.httpClient().Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return "", nil
}
if resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return "", fmt.Errorf("kratos admin search by %s failed status=%d body=%s", key, resp.StatusCode, string(body))
}
var identities []struct {
ID string `json:"id"`
}
if err := json.NewDecoder(resp.Body).Decode(&identities); err != nil {
return "", err
}
if len(identities) == 0 {
return "", nil
}
return identities[0].ID, nil
}
func (s *kratosAdminService) GetIdentity(ctx context.Context, identityID string) (*KratosIdentity, error) {
endpoint := fmt.Sprintf("%s/admin/identities/%s", strings.TrimRight(s.AdminURL, "/"), identityID)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
resp, err := s.httpClient().Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return nil, nil
}
if resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return nil, fmt.Errorf("kratos admin get identity failed status=%d body=%s", resp.StatusCode, string(body))
}
var identity KratosIdentity
if err := json.NewDecoder(resp.Body).Decode(&identity); err != nil {
return nil, err
}
return &identity, nil
}
func (s *kratosAdminService) UpdateIdentity(ctx context.Context, identityID string, traits map[string]any, state string) (*KratosIdentity, error) {
payload := map[string]any{
"schema_id": "default",
"traits": traits,
}
if strings.TrimSpace(state) != "" {
payload["state"] = strings.TrimSpace(state)
}
body, _ := json.Marshal(payload)
endpoint := fmt.Sprintf("%s/admin/identities/%s", strings.TrimRight(s.AdminURL, "/"), identityID)
req, err := http.NewRequestWithContext(ctx, http.MethodPut, endpoint, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := s.httpClient().Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return nil, fmt.Errorf("kratos admin update identity failed status=%d body=%s", resp.StatusCode, string(respBody))
}
var updated KratosIdentity
if err := json.NewDecoder(resp.Body).Decode(&updated); err != nil {
return nil, err
}
return &updated, nil
}
func (s *kratosAdminService) UpdateIdentityPassword(ctx context.Context, identityID, newPassword string) error {
identity, err := s.GetIdentity(ctx, identityID)
if err != nil {
return err
}
if identity == nil {
return fmt.Errorf("kratos admin identity not found: %s", identityID)
}
hashedPassword, err := hashPasswordForKratosAdmin(newPassword)
if err != nil {
return err
}
payload := map[string]any{
"schema_id": identity.SchemaID,
"traits": identity.Traits,
"state": identity.State,
"credentials": map[string]any{
"password": map[string]any{
"config": map[string]string{
"hashed_password": hashedPassword,
},
},
},
}
if payload["schema_id"] == "" {
payload["schema_id"] = "default"
}
if payload["state"] == "" {
payload["state"] = "active"
}
if identity.MetadataAdmin != nil {
payload["metadata_admin"] = identity.MetadataAdmin
}
if identity.MetadataPublic != nil {
payload["metadata_public"] = identity.MetadataPublic
}
body, _ := json.Marshal(payload)
endpoint := fmt.Sprintf("%s/admin/identities/%s", strings.TrimRight(s.AdminURL, "/"), identityID)
req, err := http.NewRequestWithContext(ctx, http.MethodPut, endpoint, bytes.NewReader(body))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
resp, err := s.httpClient().Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return fmt.Errorf("kratos admin update password failed status=%d body=%s", resp.StatusCode, string(respBody))
}
return nil
}
func (s *kratosAdminService) CreateUser(ctx context.Context, user *domain.BrokerUser, password string) (string, error) {
if user == nil {
return "", fmt.Errorf("kratos admin: user payload is nil")
}
if strings.TrimSpace(user.ID) != "" {
return "", fmt.Errorf("kratos admin: requested identity id import is disabled; use backup/restore")
}
traits := map[string]any{
"email": user.Email,
"name": user.Name,
}
if user.PhoneNumber != "" {
traits["phone_number"] = user.PhoneNumber
}
for k, v := range user.Attributes {
if k == "id" || k == "email" {
continue
}
traits[k] = v
}
payload := map[string]any{
"schema_id": "default",
"traits": traits,
"credentials": map[string]any{
"password": map[string]any{
"config": map[string]string{
"password": password,
},
},
},
"state": "active",
}
body, _ := json.Marshal(payload)
endpoint := strings.TrimRight(s.AdminURL, "/") + "/admin/identities"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
resp, err := s.httpClient().Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return "", fmt.Errorf("kratos admin create identity failed status=%d body=%s", resp.StatusCode, string(respBody))
}
var created struct {
ID string `json:"id"`
}
if err := json.NewDecoder(resp.Body).Decode(&created); err != nil {
return "", err
}
return created.ID, nil
}
func hashPasswordForKratosAdmin(password string) (string, error) {
hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return "", err
}
return string(hashed), nil
}
func (s *kratosAdminService) DeleteIdentity(ctx context.Context, identityID string) error {
endpoint := fmt.Sprintf("%s/admin/identities/%s", strings.TrimRight(s.AdminURL, "/"), identityID)
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, endpoint, nil)
if err != nil {
return err
}
resp, err := s.httpClient().Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return fmt.Errorf("kratos admin delete identity failed status=%d body=%s", resp.StatusCode, string(respBody))
}
return nil
}
func (s *kratosAdminService) httpClient() *http.Client {
if s.HTTPClient != nil {
return s.HTTPClient
}
return &http.Client{
Timeout: 10 * time.Second,
Transport: &http.Transport{
DialContext: (&net.Dialer{
Timeout: 5 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
TLSHandshakeTimeout: 5 * time.Second,
},
}
}
func getenvKratos(key, fallback string) string {
if v := os.Getenv(key); v != "" {
return v
}
return fallback
}
func (s *kratosAdminService) ListIdentitySessions(ctx context.Context, identityID string) ([]KratosSession, error) {
url := fmt.Sprintf("%s/admin/identities/%s/sessions", s.AdminURL, identityID)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, err
}
client := s.HTTPClient
if client == nil {
client = http.DefaultClient
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return []KratosSession{}, nil
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status: %d", resp.StatusCode)
}
var sessions []KratosSession
if err := json.NewDecoder(resp.Body).Decode(&sessions); err != nil {
return nil, err
}
return sessions, nil
}
func (s *kratosAdminService) GetSession(ctx context.Context, sessionID string) (*KratosSession, error) {
url := fmt.Sprintf("%s/admin/sessions/%s", s.AdminURL, sessionID)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, err
}
client := s.HTTPClient
if client == nil {
client = http.DefaultClient
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return nil, nil
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status: %d", resp.StatusCode)
}
var session KratosSession
if err := json.NewDecoder(resp.Body).Decode(&session); err != nil {
return nil, err
}
return &session, nil
}
func (s *kratosAdminService) DeleteSession(ctx context.Context, sessionID string) error {
url := fmt.Sprintf("%s/admin/sessions/%s", s.AdminURL, sessionID)
req, err := http.NewRequestWithContext(ctx, "DELETE", url, nil)
if err != nil {
return err
}
client := s.HTTPClient
if client == nil {
client = http.DefaultClient
}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected status: %d", resp.StatusCode)
}
return nil
}

View File

@@ -0,0 +1,63 @@
package service
import (
"bytes"
"context"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func TestKratosAdminService_ListIdentitiesFollowsNextPagination(t *testing.T) {
var requestedTokens []string
client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
require.Equal(t, "/admin/identities", r.URL.Path)
token := r.URL.Query().Get("page_token")
requestedTokens = append(requestedTokens, token)
header := make(http.Header)
header.Set("Content-Type", "application/json")
status := http.StatusOK
body := "[]"
switch token {
case "":
header.Set(
"Link",
`</admin/identities?page_size=2&page_token=identity-2>; rel="next"`,
)
body = `[{"id":"identity-1","traits":{"email":"one@example.com"}},{"id":"identity-2","traits":{"email":"two@example.com"}}]`
case "identity-2":
body = `[{"id":"identity-3","traits":{"email":"three@example.com"}}]`
default:
t.Fatalf("unexpected page_token %q", token)
}
return &http.Response{
StatusCode: status,
Header: header,
Body: io.NopCloser(bytes.NewBufferString(body)),
Request: r,
}, nil
})}
service := &kratosAdminService{
AdminURL: "http://kratos.example",
HTTPClient: client,
}
identities, err := service.ListIdentities(context.Background())
require.NoError(t, err)
require.Equal(t, []string{"", "identity-2"}, requestedTokens)
require.Len(t, identities, 3)
require.Equal(t, "identity-1", identities[0].ID)
require.Equal(t, "identity-2", identities[1].ID)
require.Equal(t, "identity-3", identities[2].ID)
}

View File

@@ -0,0 +1,139 @@
package service
import (
"baron-sso-backend/internal/domain"
"context"
"github.com/stretchr/testify/mock"
"gorm.io/gorm"
)
// --- Shared Mocks for Service Tests ---
type MockKetoOutboxRepositoryShared struct {
mock.Mock
}
func (m *MockKetoOutboxRepositoryShared) Create(ctx context.Context, entry *domain.KetoOutbox) error {
return m.Called(ctx, entry).Error(0)
}
func (m *MockKetoOutboxRepositoryShared) CreateWithTx(tx *gorm.DB, entry *domain.KetoOutbox) error {
return m.Called(tx, entry).Error(0)
}
func (m *MockKetoOutboxRepositoryShared) FindPending(ctx context.Context, limit int) ([]domain.KetoOutbox, error) {
args := m.Called(ctx, limit)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).([]domain.KetoOutbox), args.Error(1)
}
func (m *MockKetoOutboxRepositoryShared) ListCurrentBySubject(ctx context.Context, namespace, subject string) ([]domain.KetoOutbox, error) {
args := m.Called(ctx, namespace, subject)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).([]domain.KetoOutbox), args.Error(1)
}
func (m *MockKetoOutboxRepositoryShared) UpdateStatus(ctx context.Context, id string, status string, retryCount int, lastError string) error {
return m.Called(ctx, id, status, retryCount, lastError).Error(0)
}
func (m *MockKetoOutboxRepositoryShared) MarkProcessed(ctx context.Context, id string) error {
return m.Called(ctx, id).Error(0)
}
type MockKetoServiceShared struct {
mock.Mock
}
func (m *MockKetoServiceShared) CheckPermission(ctx context.Context, subject, namespace, object, relation string) (bool, error) {
args := m.Called(ctx, subject, namespace, object, relation)
return args.Bool(0), args.Error(1)
}
func (m *MockKetoServiceShared) CreateRelation(ctx context.Context, namespace, object, relation, subject string) error {
args := m.Called(ctx, namespace, object, relation, subject)
return args.Error(0)
}
func (m *MockKetoServiceShared) DeleteRelation(ctx context.Context, namespace, object, relation, subject string) error {
args := m.Called(ctx, namespace, object, relation, subject)
return args.Error(0)
}
func (m *MockKetoServiceShared) ListRelations(ctx context.Context, namespace, object, relation, subject string) ([]RelationTuple, error) {
args := m.Called(ctx, namespace, object, relation, subject)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).([]RelationTuple), args.Error(1)
}
func (m *MockKetoServiceShared) ListObjects(ctx context.Context, namespace, relation, subject string) ([]string, error) {
args := m.Called(ctx, namespace, relation, subject)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).([]string), args.Error(1)
}
type MockKratosAdminServiceShared struct {
mock.Mock
}
func (m *MockKratosAdminServiceShared) ListIdentities(ctx context.Context) ([]KratosIdentity, error) {
args := m.Called(ctx)
return args.Get(0).([]KratosIdentity), args.Error(1)
}
func (m *MockKratosAdminServiceShared) FindIdentityIDByIdentifier(ctx context.Context, identifier string) (string, error) {
args := m.Called(ctx, identifier)
return args.String(0), args.Error(1)
}
func (m *MockKratosAdminServiceShared) GetIdentity(ctx context.Context, identityID string) (*KratosIdentity, error) {
args := m.Called(ctx, identityID)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*KratosIdentity), args.Error(1)
}
func (m *MockKratosAdminServiceShared) UpdateIdentity(ctx context.Context, identityID string, traits map[string]any, state string) (*KratosIdentity, error) {
args := m.Called(ctx, identityID, traits, state)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*KratosIdentity), args.Error(1)
}
func (m *MockKratosAdminServiceShared) UpdateIdentityPassword(ctx context.Context, identityID, newPassword string) error {
args := m.Called(ctx, identityID, newPassword)
return args.Error(0)
}
func (m *MockKratosAdminServiceShared) DeleteIdentity(ctx context.Context, identityID string) error {
args := m.Called(ctx, identityID)
return args.Error(0)
}
func (m *MockKratosAdminServiceShared) CreateUser(ctx context.Context, user *domain.BrokerUser, password string) (string, error) {
args := m.Called(ctx, user, password)
return args.String(0), args.Error(1)
}
func (m *MockKratosAdminServiceShared) ListIdentitySessions(ctx context.Context, identityID string) ([]KratosSession, error) {
return nil, nil
}
func (m *MockKratosAdminServiceShared) GetSession(ctx context.Context, sessionID string) (*KratosSession, error) {
return nil, nil
}
func (m *MockKratosAdminServiceShared) DeleteSession(ctx context.Context, sessionID string) error {
return nil
}

View File

@@ -0,0 +1,967 @@
package service
import (
"baron-sso-backend/internal/domain"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"
"golang.org/x/crypto/bcrypt"
)
// OryProvider는 Kratos/Hydra를 기반으로 하는 IDP 어댑터의 최소 스켈레톤입니다.
// 지금은 스키마 메타데이터만 반환하며, 나머지 동작은 후속 작업에서 구현합니다.
type OryProvider struct {
KratosAdminURL string
KratosPublicURL string
HydraAdminURL string
HTTPClient *http.Client
}
func NewOryProvider() *OryProvider {
return &OryProvider{
KratosAdminURL: getenv("KRATOS_ADMIN_URL", "http://kratos:4434"),
KratosPublicURL: getenv("KRATOS_PUBLIC_URL", "http://kratos:4433"),
HydraAdminURL: getenv("HYDRA_ADMIN_URL", "http://hydra:4445"),
}
}
func (o *OryProvider) Name() string {
return "Ory (Kratos/Hydra)"
}
// GetMetadata는 BrokerUser가 요구하는 필드를 Kratos traits에 매핑 가능하다는 가정으로 반환합니다.
func (o *OryProvider) GetMetadata() (*domain.IDPMetadata, error) {
return &domain.IDPMetadata{
SupportedFields: []string{
"id", "custom_login_ids", "login_id", "email", "name", "phone_number",
"grade", "department", "affiliationType", "tenant_id",
},
}, nil
}
// CreateUser는 Kratos Admin API를 통해 identity를 생성합니다.
func (o *OryProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
if user == nil {
return "", fmt.Errorf("ory provider: user payload is nil")
}
if user.Email == "" || password == "" {
return "", fmt.Errorf("ory provider: email and password are required")
}
if strings.TrimSpace(user.ID) != "" {
return "", fmt.Errorf("ory provider: requested identity id import is disabled; use backup/restore")
}
existingID, err := o.findIdentityID(user.Email)
if err != nil {
return "", fmt.Errorf("ory provider: search identity failed: %w", err)
}
if existingID != "" {
return "", fmt.Errorf("ory provider: identity already exists for email=%s", user.Email)
}
// [New] Check all custom login IDs for collisions
for _, lid := range user.CustomLoginIDs {
if lid == "" {
continue
}
existing, err := o.findIdentityID(lid)
if err != nil {
return "", fmt.Errorf("ory provider: search identity failed for %s: %w", lid, err)
}
if existing != "" {
return "", fmt.Errorf("ory provider: identifier %s already exists", lid)
}
}
// [Legacy] check single LoginID
if user.LoginID != "" {
existingLoginID, err := o.findIdentityID(user.LoginID)
if err != nil {
return "", fmt.Errorf("ory provider: search identity failed: %w", err)
}
if existingLoginID != "" {
return "", fmt.Errorf("ory provider: identity already exists for login_id=%s", user.LoginID)
}
}
if user.PhoneNumber != "" {
existingPhoneID, err := o.findIdentityID(user.PhoneNumber)
if err != nil {
return "", fmt.Errorf("ory provider: search identity failed: %w", err)
}
if existingPhoneID != "" {
return "", fmt.Errorf("ory provider: identity already exists for phone=%s", user.PhoneNumber)
}
}
traits := map[string]any{
"email": user.Email,
"name": user.Name,
}
if len(user.CustomLoginIDs) > 0 {
traits["custom_login_ids"] = user.CustomLoginIDs
} else if user.LoginID != "" {
traits["custom_login_ids"] = []string{user.LoginID}
}
if user.PhoneNumber != "" {
traits["phone_number"] = user.PhoneNumber
}
for k, v := range user.Attributes {
// [SoT Fix] Don't let attributes overwrite core traits or use old 'id' trait
if k == "id" || k == "email" || k == "custom_login_ids" {
continue
}
traits[k] = v
}
payload := map[string]any{
"schema_id": "default",
"traits": traits,
"credentials": map[string]any{
"password": map[string]any{
"config": map[string]string{
"password": password,
},
},
},
}
verifiable := []map[string]any{
{
"value": user.Email,
"verified": true,
"via": "email",
},
}
if user.PhoneNumber != "" {
verifiable = append(verifiable, map[string]any{
"value": user.PhoneNumber,
"verified": true,
"via": "sms",
})
}
payload["verifiable_addresses"] = verifiable
payload["recovery_addresses"] = []map[string]any{
{
"value": user.Email,
"via": "email",
},
}
body, _ := json.Marshal(payload)
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, fmt.Sprintf("%s/admin/identities", o.KratosAdminURL), bytes.NewReader(body))
if err != nil {
return "", fmt.Errorf("ory provider: build create request failed: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := o.httpClient().Do(req)
if err != nil {
return "", fmt.Errorf("ory provider: create identity request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return "", fmt.Errorf("ory provider: create identity failed status=%d body=%s", resp.StatusCode, string(respBody))
}
var created struct {
ID string `json:"id"`
}
if err := json.NewDecoder(resp.Body).Decode(&created); err != nil {
return "", fmt.Errorf("ory provider: decode create identity response failed: %w", err)
}
slog.Info("Ory identity created", "identity_id", created.ID, "email", user.Email)
return created.ID, nil
}
// SignIn은 Kratos Public API의 login API 플로우를 사용해 세션 토큰을 발급합니다.
func (o *OryProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
if loginID == "" || password == "" {
return nil, fmt.Errorf("ory provider: loginID and password are required")
}
flowID, err := o.startLoginFlow("")
if err != nil {
return nil, err
}
body, _ := json.Marshal(map[string]string{
"identifier": loginID,
"password": password,
"method": "password",
})
loginURL := fmt.Sprintf("%s/self-service/login?flow=%s", o.KratosPublicURL, flowID)
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, loginURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("ory provider: build login request failed: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := o.httpClient().Do(req)
if err != nil {
return nil, fmt.Errorf("ory provider: login request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return nil, fmt.Errorf("ory provider: login failed status=%d body=%s", resp.StatusCode, string(respBody))
}
var result struct {
SessionToken string `json:"session_token"`
SessionTokenExpiresAt time.Time `json:"session_token_expires_at"`
Session struct {
ID string `json:"id"`
Identity struct {
ID string `json:"id"`
} `json:"identity"`
} `json:"session"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("ory provider: decode login response failed: %w", err)
}
if result.SessionToken == "" {
return nil, fmt.Errorf("ory provider: empty session token returned")
}
slog.Info("Ory login successful",
"identity_id", result.Session.Identity.ID,
"loginID", loginID,
"expires_at", result.SessionTokenExpiresAt,
)
return &domain.AuthInfo{
SessionToken: &domain.Token{
JWT: result.SessionToken,
Expiration: result.SessionTokenExpiresAt,
SessionID: result.Session.ID,
},
Subject: result.Session.Identity.ID,
SetCookies: resp.Cookies(),
}, nil
}
// UserExists는 Kratos Admin API로 loginID 존재 여부를 확인합니다.
func (o *OryProvider) UserExists(loginID string) (bool, error) {
if loginID == "" {
return false, fmt.Errorf("ory provider: loginID is empty")
}
identityID, err := o.findIdentityID(loginID)
if err != nil {
return false, fmt.Errorf("ory provider: find identity failed: %w", err)
}
return identityID != "", nil
}
// IssueSession은 Ory에서 별도 세션 발급이 필요할 때 사용합니다. (현재 미지원)
func (o *OryProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
return nil, domain.ErrNotSupported
}
// InitiateLinkLogin은 Kratos Public API로 링크 로그인 플로우를 시작하고 이메일 전송을 트리거합니다.
func (o *OryProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
if loginID == "" {
return nil, fmt.Errorf("ory provider: loginID is required")
}
effectiveLoginID, err := o.resolveEffectiveLoginID(loginID)
if err != nil {
return nil, err
}
if err := o.ensureCodeLoginIdentifier(effectiveLoginID); err != nil {
return nil, err
}
init, err := o.submitLoginCodeInit(effectiveLoginID, returnTo)
if err == nil {
init.LoginID = effectiveLoginID
return init, nil
}
if shouldBootstrapCodeLogin(err) {
if ensureErr := o.ensureCodeLoginIdentifier(effectiveLoginID); ensureErr == nil {
init, initErr := o.submitLoginCodeInit(effectiveLoginID, returnTo)
if initErr == nil {
init.LoginID = effectiveLoginID
}
return init, initErr
} else {
slog.Warn("Ory code login bootstrap failed", "loginID", effectiveLoginID, "error", ensureErr)
}
}
return nil, err
}
func (o *OryProvider) resolveEffectiveLoginID(loginID string) (string, error) {
if strings.Contains(loginID, "@") {
return loginID, nil
}
identityID, err := o.findIdentityID(loginID)
if err != nil {
return "", err
}
if identityID == "" {
return "", fmt.Errorf("ory provider: identity not found for loginID=%s", loginID)
}
fullIdentity, err := o.fetchIdentityFull(identityID)
if err != nil {
return "", err
}
if fullIdentity != nil {
if emailRaw, ok := fullIdentity.Traits["email"]; ok {
if email, ok := emailRaw.(string); ok && email != "" {
return email, nil
}
}
}
return "", fmt.Errorf("ory provider: email trait missing for loginID=%s", loginID)
}
func (o *OryProvider) submitLoginCodeInit(loginID, returnTo string) (*domain.LinkLoginInit, error) {
flowID, err := o.startLoginFlow(returnTo)
if err != nil {
return nil, err
}
body, _ := json.Marshal(map[string]string{
"method": "code",
"identifier": loginID,
})
loginURL := fmt.Sprintf("%s/self-service/login?flow=%s", o.KratosPublicURL, flowID)
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, loginURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("ory provider: build link login request failed: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := o.httpClient().Do(req)
if err != nil {
return nil, fmt.Errorf("ory provider: link login request failed: %w", err)
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
if resp.StatusCode >= 300 {
init, ok := parseKratosLinkLoginResponse(flowID, respBody)
if ok {
slog.Info("Ory link login initiated with non-2xx response", "loginID", loginID, "flow_id", flowID, "status", resp.StatusCode)
return init, nil
}
return nil, fmt.Errorf("ory provider: link login failed status=%d body=%s", resp.StatusCode, string(respBody))
}
var result struct {
ExpiresAt time.Time `json:"expires_at"`
}
_ = json.Unmarshal(respBody, &result)
slog.Info("Ory link login initiated", "loginID", loginID, "flow_id", flowID)
return &domain.LinkLoginInit{
FlowID: flowID,
ExpiresAt: result.ExpiresAt,
Mode: "link",
}, nil
}
func parseKratosLinkLoginResponse(flowID string, body []byte) (*domain.LinkLoginInit, bool) {
if len(body) == 0 {
return nil, false
}
var parsed struct {
ExpiresAt time.Time `json:"expires_at"`
State string `json:"state"`
Active string `json:"active"`
}
if err := json.Unmarshal(body, &parsed); err != nil {
return nil, false
}
state := strings.ToLower(parsed.State)
active := strings.ToLower(parsed.Active)
if strings.Contains(state, "sent") || active == "code" {
return &domain.LinkLoginInit{
FlowID: flowID,
ExpiresAt: parsed.ExpiresAt,
Mode: "link",
}, true
}
return nil, false
}
func shouldBootstrapCodeLogin(err error) bool {
if err == nil {
return false
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "has not setup sign in with code") ||
strings.Contains(msg, "4000035")
}
type kratosVerifiableAddress struct {
Value string `json:"value"`
Via string `json:"via"`
Verified bool `json:"verified"`
Status string `json:"status,omitempty"`
}
func (o *OryProvider) ensureCodeLoginIdentifier(loginID string) error {
identityID, err := o.findIdentityID(loginID)
if err != nil {
return fmt.Errorf("ory provider: find identity failed: %w", err)
}
if identityID == "" {
return fmt.Errorf("ory provider: identity not found for loginID=%s", loginID)
}
identity, err := o.fetchIdentity(identityID)
if err != nil {
return err
}
via := "sms"
if strings.Contains(loginID, "@") {
via = "email"
}
exists := false
existingIndex := -1
addresses := make([]kratosVerifiableAddress, 0, len(identity.VerifiableAddresses)+1)
for idx, addr := range identity.VerifiableAddresses {
addresses = append(addresses, kratosVerifiableAddress{
Value: addr.Value,
Via: addr.Via,
Verified: addr.Verified,
Status: addr.Status,
})
if addr.Value == loginID && addr.Via == via {
exists = true
existingIndex = idx
}
}
ops := make([]map[string]any, 0, 2)
if !exists {
ops = append(ops, map[string]any{
"op": "add",
"path": "/verifiable_addresses/-",
"value": map[string]any{
"value": loginID,
"via": via,
"verified": true,
"status": "completed",
},
})
} else {
addr := identity.VerifiableAddresses[existingIndex]
if !addr.Verified {
ops = append(ops, map[string]any{
"op": "replace",
"path": fmt.Sprintf("/verifiable_addresses/%d/verified", existingIndex),
"value": true,
})
}
if addr.Status != "" && addr.Status != "completed" {
ops = append(ops, map[string]any{
"op": "replace",
"path": fmt.Sprintf("/verifiable_addresses/%d/status", existingIndex),
"value": "completed",
})
}
}
if len(ops) == 0 {
slog.Info("Ory identity verifiable address already ready", "identity_id", identityID, "loginID", loginID, "via", via)
return nil
}
if err := o.patchIdentity(identityID, ops); err != nil {
slog.Warn("Ory identity patch failed, trying full update", "identity_id", identityID, "error", err)
}
fullIdentity, err := o.fetchIdentityFull(identityID)
if err != nil {
return err
}
addresses = make([]kratosVerifiableAddress, 0, len(fullIdentity.VerifiableAddresses)+1)
found := false
for _, addr := range fullIdentity.VerifiableAddresses {
addresses = append(addresses, kratosVerifiableAddress{
Value: addr.Value,
Via: addr.Via,
Verified: addr.Verified,
Status: addr.Status,
})
if addr.Value == loginID && addr.Via == via {
found = true
}
}
if !found {
addresses = append(addresses, kratosVerifiableAddress{
Value: loginID,
Via: via,
Verified: true,
Status: "completed",
})
}
payload := map[string]any{
"schema_id": fullIdentity.SchemaID,
"traits": fullIdentity.Traits,
"verifiable_addresses": addresses,
}
if len(fullIdentity.RecoveryAddresses) > 0 {
payload["recovery_addresses"] = fullIdentity.RecoveryAddresses
}
body, _ := json.Marshal(payload)
req, err := http.NewRequestWithContext(context.Background(), http.MethodPut, fmt.Sprintf("%s/admin/identities/%s", o.KratosAdminURL, identityID), bytes.NewReader(body))
if err != nil {
return fmt.Errorf("ory provider: build identity update failed: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := o.httpClient().Do(req)
if err != nil {
return fmt.Errorf("ory provider: identity update failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
return fmt.Errorf("ory provider: identity update failed status=%d body=%s", resp.StatusCode, string(respBody))
}
slog.Info("Ory identity updated with verifiable address", "identity_id", identityID, "loginID", loginID, "via", via)
return nil
}
type kratosIdentity struct {
VerifiableAddresses []kratosVerifiableAddress `json:"verifiable_addresses"`
}
type kratosRecoveryAddress struct {
Value string `json:"value"`
Via string `json:"via"`
}
type kratosIdentityFull struct {
SchemaID string `json:"schema_id"`
Traits map[string]any `json:"traits"`
VerifiableAddresses []kratosVerifiableAddress `json:"verifiable_addresses"`
RecoveryAddresses []kratosRecoveryAddress `json:"recovery_addresses"`
}
func (o *OryProvider) patchIdentity(identityID string, ops []map[string]any) error {
body, _ := json.Marshal(ops)
req, err := http.NewRequestWithContext(context.Background(), http.MethodPatch, fmt.Sprintf("%s/admin/identities/%s", o.KratosAdminURL, identityID), bytes.NewReader(body))
if err != nil {
return fmt.Errorf("ory provider: build identity patch failed: %w", err)
}
req.Header.Set("Content-Type", "application/json-patch+json")
resp, err := o.httpClient().Do(req)
if err != nil {
return fmt.Errorf("ory provider: identity patch failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
return fmt.Errorf("ory provider: identity patch failed status=%d body=%s", resp.StatusCode, string(respBody))
}
slog.Info("Ory identity patched", "identity_id", identityID, "ops", len(ops))
return nil
}
func (o *OryProvider) fetchIdentity(identityID string) (*kratosIdentity, error) {
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, fmt.Sprintf("%s/admin/identities/%s", o.KratosAdminURL, identityID), nil)
if err != nil {
return nil, fmt.Errorf("ory provider: build identity get failed: %w", err)
}
resp, err := o.httpClient().Do(req)
if err != nil {
return nil, fmt.Errorf("ory provider: identity get failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
return nil, fmt.Errorf("ory provider: identity get failed status=%d body=%s", resp.StatusCode, string(body))
}
var identity kratosIdentity
if err := json.NewDecoder(resp.Body).Decode(&identity); err != nil {
return nil, fmt.Errorf("ory provider: decode identity failed: %w", err)
}
return &identity, nil
}
func (o *OryProvider) fetchIdentityFull(identityID string) (*kratosIdentityFull, error) {
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, fmt.Sprintf("%s/admin/identities/%s", o.KratosAdminURL, identityID), nil)
if err != nil {
return nil, fmt.Errorf("ory provider: build identity get failed: %w", err)
}
resp, err := o.httpClient().Do(req)
if err != nil {
return nil, fmt.Errorf("ory provider: identity get failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
return nil, fmt.Errorf("ory provider: identity get failed status=%d body=%s", resp.StatusCode, string(body))
}
var identity kratosIdentityFull
if err := json.NewDecoder(resp.Body).Decode(&identity); err != nil {
return nil, fmt.Errorf("ory provider: decode identity failed: %w", err)
}
return &identity, nil
}
// VerifyLoginCode는 Kratos 로그인 코드 제출로 세션을 발급합니다.
func (o *OryProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
if loginID == "" || flowID == "" || code == "" {
return nil, fmt.Errorf("ory provider: loginID, flowID and code are required")
}
body, _ := json.Marshal(map[string]string{
"method": "code",
"identifier": loginID,
"code": code,
})
loginURL := fmt.Sprintf("%s/self-service/login?flow=%s", o.KratosPublicURL, flowID)
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, loginURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("ory provider: build login code request failed: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := o.httpClient().Do(req)
if err != nil {
return nil, fmt.Errorf("ory provider: login code request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return nil, fmt.Errorf("ory provider: login code failed status=%d body=%s", resp.StatusCode, string(respBody))
}
var result struct {
SessionToken string `json:"session_token"`
SessionTokenExpiresAt time.Time `json:"session_token_expires_at"`
Session struct {
ID string `json:"id"`
Identity struct {
ID string `json:"id"`
} `json:"identity"`
} `json:"session"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("ory provider: decode login code response failed: %w", err)
}
if result.SessionToken == "" {
return nil, fmt.Errorf("ory provider: empty session token returned")
}
slog.Info("Ory login code successful",
"identity_id", result.Session.Identity.ID,
"loginID", loginID,
"expires_at", result.SessionTokenExpiresAt,
)
return &domain.AuthInfo{
SessionToken: &domain.Token{
JWT: result.SessionToken,
Expiration: result.SessionTokenExpiresAt,
SessionID: result.Session.ID,
},
Subject: result.Session.Identity.ID,
SetCookies: resp.Cookies(),
}, nil
}
// GetPasswordPolicy는 Ory 환경에서 사용하는 기본 정책을 반환합니다.
func (o *OryProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
return &domain.PasswordPolicy{
MinLength: 12,
Lowercase: true,
Uppercase: false,
Number: true,
NonAlphanumeric: true,
MinCharacterTypes: 0,
}, nil
}
// InitiatePasswordReset는 현재 내부 토큰/메일 흐름을 사용하고 있으므로 NO-OP로 둡니다.
func (o *OryProvider) InitiatePasswordReset(loginID, redirectUrl string) error {
slog.Info("Ory InitiatePasswordReset bypassed (handled by app internal flow)", "loginID", loginID, "redirect", redirectUrl)
return nil
}
// VerifyPasswordResetToken는 내부 토큰 검증 흐름을 사용하므로 아직 구현하지 않습니다.
func (o *OryProvider) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
return nil, fmt.Errorf("ory provider: VerifyPasswordResetToken not implemented (internal token flow expected)")
}
// UpdateUserPassword: Kratos Admin API를 통해 비밀번호를 갱신합니다.
func (o *OryProvider) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
if loginID == "" || newPassword == "" {
return fmt.Errorf("ory provider: loginID or new password missing")
}
identityID, err := o.findIdentityID(loginID)
if err != nil {
return fmt.Errorf("ory provider: find identity failed: %w", err)
}
if identityID == "" {
return fmt.Errorf("ory provider: identity not found for loginID=%s", loginID)
}
identity, err := o.getIdentity(identityID)
if err != nil {
return fmt.Errorf("ory provider: load identity failed: %w", err)
}
if identity == nil {
return fmt.Errorf("ory provider: identity payload missing for loginID=%s", loginID)
}
hashedPassword, err := hashPasswordForKratos(newPassword)
if err != nil {
return fmt.Errorf("ory provider: hash password failed: %w", err)
}
payload := map[string]any{
"schema_id": identity.SchemaID,
"traits": identity.Traits,
"state": identity.State,
"credentials": map[string]any{
"password": map[string]any{
"config": map[string]string{
"hashed_password": hashedPassword,
},
},
},
}
if payload["schema_id"] == "" {
payload["schema_id"] = "default"
}
if payload["state"] == "" {
payload["state"] = "active"
}
if identity.MetadataAdmin != nil {
payload["metadata_admin"] = identity.MetadataAdmin
}
if identity.MetadataPublic != nil {
payload["metadata_public"] = identity.MetadataPublic
}
body, _ := json.Marshal(payload)
req, err := http.NewRequestWithContext(context.Background(), http.MethodPut, fmt.Sprintf("%s/admin/identities/%s", o.KratosAdminURL, identityID), bytes.NewReader(body))
if err != nil {
return fmt.Errorf("ory provider: build request failed: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := o.httpClient().Do(req)
if err != nil {
return fmt.Errorf("ory provider: request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
return fmt.Errorf("ory provider: password update failed status=%d body=%s", resp.StatusCode, string(respBody))
}
slog.Info("Ory password updated via Kratos admin", "identity_id", identityID, "loginID", loginID)
return nil
}
func getenv(key, fallback string) string {
if v := os.Getenv(key); v != "" {
return v
}
return fallback
}
// findIdentityByID: Kratos Admin API에서 ID(UUID)로 직접 조회
func (o *OryProvider) findIdentityByID(id string) (string, error) {
identity, err := o.getIdentity(id)
if err != nil {
return "", err
}
if identity != nil {
return identity.ID, nil
}
return "", nil
}
// findIdentityID: Kratos Admin API에서 credentials_identifier로 검색 후 첫 번째 identity id 반환
func (o *OryProvider) findIdentityID(loginID string) (string, error) {
u, err := url.Parse(fmt.Sprintf("%s/admin/identities", o.KratosAdminURL))
if err != nil {
return "", err
}
query := u.Query()
query.Set("credentials_identifier", loginID)
u.RawQuery = query.Encode()
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, u.String(), nil)
if err != nil {
return "", err
}
resp, err := o.httpClient().Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return "", nil
}
if resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
return "", fmt.Errorf("kratos admin search failed status=%d body=%s", resp.StatusCode, string(body))
}
var identities []struct {
ID string `json:"id"`
Traits map[string]any `json:"traits"`
}
if err := json.NewDecoder(resp.Body).Decode(&identities); err != nil {
return "", fmt.Errorf("decode response failed: %w", err)
}
if len(identities) == 0 {
return "", nil
}
// VERIFY: Double check traits to avoid Kratos ignoring the query param
candidate := identities[0]
if email, ok := candidate.Traits["email"].(string); ok && strings.EqualFold(email, loginID) {
return candidate.ID, nil
}
if phone, ok := candidate.Traits["phone_number"].(string); ok && strings.EqualFold(phone, loginID) {
return candidate.ID, nil
}
if lids, ok := candidate.Traits["custom_login_ids"].([]any); ok {
for _, lid := range lids {
if s, ok := lid.(string); ok && strings.EqualFold(s, loginID) {
return candidate.ID, nil
}
}
} else if lids, ok := candidate.Traits["custom_login_ids"].([]string); ok {
for _, lid := range lids {
if strings.EqualFold(lid, loginID) {
return candidate.ID, nil
}
}
}
return "", nil
}
func (o *OryProvider) getIdentity(identityID string) (*KratosIdentity, error) {
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, fmt.Sprintf("%s/admin/identities/%s", o.KratosAdminURL, identityID), nil)
if err != nil {
return nil, err
}
resp, err := o.httpClient().Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return nil, nil
}
if resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
return nil, fmt.Errorf("ory provider: get identity failed status=%d body=%s", resp.StatusCode, string(respBody))
}
var identity KratosIdentity
if err := json.NewDecoder(resp.Body).Decode(&identity); err != nil {
return nil, err
}
return &identity, nil
}
func hashPasswordForKratos(password string) (string, error) {
hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return "", err
}
return string(hashed), nil
}
func (o *OryProvider) httpClient() *http.Client {
if o.HTTPClient != nil {
return o.HTTPClient
}
return &http.Client{
Timeout: 10 * time.Second,
Transport: &http.Transport{
DialContext: (&net.Dialer{
Timeout: 5 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
TLSHandshakeTimeout: 5 * time.Second,
},
}
}
// startLoginFlow는 Kratos Public API에서 login flow ID를 발급받습니다.
func (o *OryProvider) startLoginFlow(returnTo string) (string, error) {
loginURL := fmt.Sprintf("%s/self-service/login/api", o.KratosPublicURL)
if returnTo != "" {
loginURL = loginURL + "?return_to=" + url.QueryEscape(returnTo)
}
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, loginURL, nil)
if err != nil {
return "", fmt.Errorf("ory provider: build login flow request failed: %w", err)
}
resp, err := o.httpClient().Do(req)
if err != nil {
return "", fmt.Errorf("ory provider: login flow request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return "", fmt.Errorf("ory provider: login flow failed status=%d body=%s", resp.StatusCode, string(body))
}
var result struct {
ID string `json:"id"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", fmt.Errorf("ory provider: decode login flow failed: %w", err)
}
if result.ID == "" {
return "", fmt.Errorf("ory provider: empty login flow id")
}
return result.ID, nil
}

View File

@@ -0,0 +1,226 @@
package service
import (
"baron-sso-backend/internal/domain"
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
// clientForHandler returns an http.Client that routes requests to the given handler
// without real network sockets.
func clientForHandler(h http.Handler) *http.Client {
return &http.Client{
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
// Clone request body for handler
var bodyBytes []byte
if req.Body != nil {
bodyBytes, _ = io.ReadAll(req.Body)
}
r := httptest.NewRequest(req.Method, req.URL.String(), bytes.NewReader(bodyBytes))
r.Header = req.Header.Clone()
w := httptest.NewRecorder()
h.ServeHTTP(w, r)
return w.Result(), nil
}),
}
}
type roundTripperFunc func(req *http.Request) (*http.Response, error)
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { return f(req) }
func TestUpdateUserPassword_Success(t *testing.T) {
const (
loginID = "user@example.com"
identityID = "7f0dc8c3-9d5d-4f57-b3d1-123456789abc"
newPassword = "Sup3rStr0ng!Pass#2026"
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.HasPrefix(r.URL.Path, "/admin/identities") && r.Method == http.MethodGet:
if r.URL.Path == "/admin/identities" {
q := r.URL.Query()
if got := q.Get("credentials_identifier"); got != loginID {
t.Fatalf("expected credentials_identifier=%s, got=%s", loginID, got)
}
_ = json.NewEncoder(w).Encode([]map[string]any{
{
"id": identityID,
"traits": map[string]any{
"email": loginID,
},
},
})
return
}
if r.URL.Path != "/admin/identities/"+identityID {
t.Fatalf("unexpected identity lookup path: %s", r.URL.Path)
}
_ = json.NewEncoder(w).Encode(map[string]any{
"id": identityID,
"schema_id": "default",
"state": "active",
"traits": map[string]any{
"email": loginID,
},
})
return
case r.URL.Path == "/admin/identities/"+identityID && r.Method == http.MethodPut:
body, _ := io.ReadAll(r.Body)
if !strings.Contains(string(body), "\"hashed_password\"") {
t.Fatalf("payload missing hashed_password, body=%s", string(body))
}
if strings.Contains(string(body), newPassword) {
t.Fatalf("payload must not contain plain password, body=%s", string(body))
}
if !strings.Contains(string(body), "\"schema_id\":\"default\"") {
t.Fatalf("payload missing schema_id, body=%s", string(body))
}
w.WriteHeader(http.StatusOK)
return
default:
t.Fatalf("unexpected request: %s %s", r.Method, r.URL.String())
}
})
provider := &OryProvider{
KratosAdminURL: "http://kratos-admin.local",
HTTPClient: clientForHandler(handler),
}
if err := provider.UpdateUserPassword(loginID, newPassword, nil); err != nil {
t.Fatalf("UpdateUserPassword returned error: %v", err)
}
}
func TestUpdateUserPassword_NotFound(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.Path, "/admin/identities") && r.Method == http.MethodGet {
http.NotFound(w, r)
return
}
t.Fatalf("unexpected request: %s %s", r.Method, r.URL.String())
})
provider := &OryProvider{
KratosAdminURL: "http://kratos-admin.local",
HTTPClient: clientForHandler(handler),
}
err := provider.UpdateUserPassword("user@example.com", "Sup3rStr0ng!Pass#2026", nil)
if err == nil || !strings.Contains(err.Error(), "identity not found") {
t.Fatalf("expected identity not found error, got: %v", err)
}
}
func TestUpdateUserPassword_ServerError(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.HasPrefix(r.URL.Path, "/admin/identities") && r.Method == http.MethodGet:
if r.URL.Path == "/admin/identities" {
_ = json.NewEncoder(w).Encode([]map[string]any{
{
"id": "abc",
"traits": map[string]any{
"email": "user@example.com",
},
},
})
return
}
if r.URL.Path == "/admin/identities/abc" {
_ = json.NewEncoder(w).Encode(map[string]any{
"id": "abc",
"schema_id": "default",
"state": "active",
"traits": map[string]any{
"email": "user@example.com",
},
})
return
}
t.Fatalf("unexpected request: %s %s", r.Method, r.URL.String())
case r.URL.Path == "/admin/identities/abc" && r.Method == http.MethodPut:
http.Error(w, "boom", http.StatusInternalServerError)
return
default:
t.Fatalf("unexpected request: %s %s", r.Method, r.URL.String())
}
})
provider := &OryProvider{
KratosAdminURL: "http://kratos-admin.local",
HTTPClient: clientForHandler(handler),
}
err := provider.UpdateUserPassword("user@example.com", "Sup3rStr0ng!Pass#2026", nil)
if err == nil || !strings.Contains(err.Error(), "password update failed") {
t.Fatalf("expected server error, got: %v", err)
}
}
func TestFindIdentityID_QueryEncoding(t *testing.T) {
loginID := "user+alias@example.com"
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
values, _ := url.ParseQuery(r.URL.RawQuery)
if values.Get("credentials_identifier") != loginID {
t.Fatalf("expected credentials_identifier=%s, got=%s", loginID, values.Get("credentials_identifier"))
}
_ = json.NewEncoder(w).Encode([]map[string]any{
{
"id": "id-123",
"traits": map[string]any{
"email": loginID,
},
},
})
})
provider := &OryProvider{
KratosAdminURL: "http://kratos-admin.local",
HTTPClient: clientForHandler(handler),
}
id, err := provider.findIdentityID(loginID)
if err != nil {
t.Fatalf("findIdentityID returned error: %v", err)
}
if id != "id-123" {
t.Fatalf("expected id-123, got %s", id)
}
}
func TestOryProvider_CreateUser_RejectsRequestedIdentityID(t *testing.T) {
const (
email = "newuser@test.com"
name = "New User"
customUuid = "550e8400-e29b-41d4-a716-446655440000"
password = "secret123456"
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("unexpected request: %s %s", r.Method, r.URL.String())
})
provider := &OryProvider{
KratosAdminURL: "http://kratos-admin.local",
HTTPClient: clientForHandler(handler),
}
id, err := provider.CreateUser(&domain.BrokerUser{
ID: customUuid,
Email: email,
Name: name,
}, password)
if err == nil || !strings.Contains(err.Error(), "requested identity id import is disabled") {
t.Fatalf("expected requested identity id rejection, got id=%s err=%v", id, err)
}
}

View File

@@ -0,0 +1,238 @@
package service
import (
"baron-sso-backend/internal/domain"
"context"
"encoding/json"
"os"
"time"
"github.com/go-redis/redis/v8"
)
var ctx = context.Background()
type RedisService struct {
Client *redis.Client
}
type identityMirrorStateStore struct {
Status string `json:"status"`
LastRefreshedAt *time.Time `json:"lastRefreshedAt,omitempty"`
LastError string `json:"lastError,omitempty"`
ObservedCount int64 `json:"observedCount,omitempty"`
UpdatedAt *time.Time `json:"updatedAt,omitempty"`
}
// NewRedisService creates and returns a new RedisService
func NewRedisService() (*RedisService, error) {
redisAddr := os.Getenv("REDIS_ADDR")
if redisAddr == "" {
redisAddr = "localhost:6389" // Fallback for local dev without Docker
}
rdb := redis.NewClient(&redis.Options{
Addr: redisAddr,
})
// Ping the server to check the connection
if _, err := rdb.Ping(ctx).Result(); err != nil {
return nil, err
}
// [DEV-FIX] Disable stop-writes-on-bgsave-error to allow writes even if persistence fails
// This is common in dev docker environments with permission issues.
rdb.ConfigSet(ctx, "stop-writes-on-bgsave-error", "no")
return &RedisService{Client: rdb}, nil
}
func (s *RedisService) Ping(ctx context.Context) error {
if s.Client == nil {
return os.ErrInvalid
}
return s.Client.Ping(ctx).Err()
}
// StoreVerificationCode saves the SMS verification code with a 3-minute expiration
func (s *RedisService) StoreVerificationCode(phone, code string) error {
// Key format: "sms_verify:01012345678"
key := "sms_verify:" + phone
expiration := 3 * time.Minute
err := s.Client.Set(ctx, key, code, expiration).Err()
return err
}
// GetVerificationCode retrieves the SMS verification code
func (s *RedisService) GetVerificationCode(phone string) (string, error) {
key := "sms_verify:" + phone
code, err := s.Client.Get(ctx, key).Result()
if err == redis.Nil {
// Key does not exist (expired or incorrect phone number)
return "", nil
} else if err != nil {
return "", err
}
return code, nil
}
// DeleteVerificationCode removes the verification code after successful verification
func (s *RedisService) DeleteVerificationCode(phone string) error {
key := "sms_verify:" + phone
return s.Client.Del(ctx, key).Err()
}
// Set stores a key-value pair with expiration
func (s *RedisService) Set(key string, value string, expiration time.Duration) error {
return s.Client.Set(ctx, key, value, expiration).Err()
}
// Get retrieves a value by key
func (s *RedisService) Get(key string) (string, error) {
val, err := s.Client.Get(ctx, key).Result()
if err == redis.Nil {
return "", nil
}
return val, err
}
// Delete removes a key
func (s *RedisService) Delete(key string) error {
return s.Client.Del(ctx, key).Err()
}
func (s *RedisService) GetIdentityCacheStatus(ctx context.Context) (domain.IdentityCacheStatus, error) {
if s == nil || s.Client == nil {
return domain.IdentityCacheStatus{
Status: "unavailable",
RedisReady: false,
LastError: "redis service unavailable",
}, nil
}
if err := s.Client.Ping(ctx).Err(); err != nil {
return domain.IdentityCacheStatus{
Status: "failed",
RedisReady: false,
LastError: err.Error(),
}, nil
}
keyCount, err := s.countIdentityCacheKeys(ctx)
if err != nil {
return domain.IdentityCacheStatus{
Status: "failed",
RedisReady: true,
LastError: err.Error(),
}, nil
}
raw, err := s.Client.Get(ctx, "identity:mirror:state").Result()
if err == redis.Nil {
return domain.IdentityCacheStatus{
Status: "empty",
RedisReady: true,
KeyCount: keyCount,
}, nil
}
if err != nil {
return domain.IdentityCacheStatus{
Status: "failed",
RedisReady: true,
KeyCount: keyCount,
LastError: err.Error(),
}, nil
}
var stored identityMirrorStateStore
if err := json.Unmarshal([]byte(raw), &stored); err != nil {
return domain.IdentityCacheStatus{
Status: "failed",
RedisReady: true,
KeyCount: keyCount,
LastError: err.Error(),
}, nil
}
status := stored.Status
if status == "" {
status = "unknown"
}
return domain.IdentityCacheStatus{
Status: status,
RedisReady: true,
ObservedCount: stored.ObservedCount,
KeyCount: keyCount,
LastRefreshedAt: stored.LastRefreshedAt,
LastError: stored.LastError,
UpdatedAt: stored.UpdatedAt,
}, nil
}
func (s *RedisService) FlushIdentityCache(ctx context.Context) (domain.IdentityCacheFlushResult, error) {
if s == nil || s.Client == nil {
return domain.IdentityCacheFlushResult{}, os.ErrInvalid
}
keys, err := s.identityCacheKeys(ctx)
if err != nil {
return domain.IdentityCacheFlushResult{}, err
}
var deleted int64
for len(keys) > 0 {
chunkSize := len(keys)
if chunkSize > 500 {
chunkSize = 500
}
chunk := keys[:chunkSize]
count, err := s.Client.Del(ctx, chunk...).Result()
if err != nil {
return domain.IdentityCacheFlushResult{}, err
}
deleted += count
keys = keys[chunkSize:]
}
return domain.IdentityCacheFlushResult{
Status: "success",
FlushedKeys: deleted,
UpdatedAt: time.Now().UTC(),
}, nil
}
func (s *RedisService) countIdentityCacheKeys(ctx context.Context) (int64, error) {
keys, err := s.identityCacheKeys(ctx)
if err != nil {
return 0, err
}
return int64(len(keys)), nil
}
func (s *RedisService) identityCacheKeys(ctx context.Context) ([]string, error) {
seen := make(map[string]bool)
patterns := []string{
"identity:mirror:*",
"identity:index:*",
}
for _, pattern := range patterns {
var cursor uint64
for {
keys, next, err := s.Client.Scan(ctx, cursor, pattern, 250).Result()
if err != nil {
return nil, err
}
for _, key := range keys {
seen[key] = true
}
cursor = next
if cursor == 0 {
break
}
}
}
keys := make([]string, 0, len(seen))
for key := range seen {
keys = append(keys, key)
}
return keys, nil
}

View File

@@ -0,0 +1,150 @@
package service
import (
"context"
"encoding/json"
"os"
"testing"
"time"
"github.com/go-redis/redis/v8"
"github.com/stretchr/testify/require"
)
type redisCommandStub struct {
scans map[string][]string
stateValue string
deleted []string
}
func (h *redisCommandStub) BeforeProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
return ctx, nil
}
func (h *redisCommandStub) AfterProcess(ctx context.Context, cmd redis.Cmder) error {
switch cmd.Name() {
case "ping":
if status, ok := cmd.(*redis.StatusCmd); ok {
status.SetVal("PONG")
}
case "scan":
if scan, ok := cmd.(*redis.ScanCmd); ok {
scan.SetVal(h.scans[scanPattern(cmd.Args())], 0)
}
case "get":
if str, ok := cmd.(*redis.StringCmd); ok {
if h.stateValue == "" {
str.SetErr(redis.Nil)
return nil
}
str.SetVal(h.stateValue)
}
case "del":
args := cmd.Args()
keys := make([]string, 0, len(args)-1)
for _, arg := range args[1:] {
keys = append(keys, arg.(string))
}
h.deleted = append(h.deleted, keys...)
if count, ok := cmd.(*redis.IntCmd); ok {
count.SetVal(int64(len(keys)))
}
}
cmd.SetErr(nil)
return nil
}
func (h *redisCommandStub) BeforeProcessPipeline(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
return ctx, nil
}
func (h *redisCommandStub) AfterProcessPipeline(ctx context.Context, cmds []redis.Cmder) error {
return nil
}
func scanPattern(args []interface{}) string {
for index := 0; index < len(args)-1; index++ {
value, ok := args[index].(string)
if ok && value == "match" {
if pattern, ok := args[index+1].(string); ok {
return pattern
}
}
}
return ""
}
func newStubbedRedisService(stub *redisCommandStub) *RedisService {
client := redis.NewClient(&redis.Options{
Addr: "127.0.0.1:1",
MaxRetries: -1,
})
client.AddHook(stub)
return &RedisService{Client: client}
}
func TestRedisServiceGetIdentityCacheStatusReadsStateAndCountsCacheKeys(t *testing.T) {
now := time.Date(2026, 6, 9, 3, 20, 0, 0, time.UTC)
state, err := json.Marshal(identityMirrorStateStore{
Status: "ready",
LastRefreshedAt: &now,
ObservedCount: 42,
UpdatedAt: &now,
})
require.NoError(t, err)
stub := &redisCommandStub{
stateValue: string(state),
scans: map[string][]string{
"identity:mirror:*": {"identity:mirror:state", "identity:mirror:user:1"},
"identity:index:*": {"identity:index:email:a", "identity:mirror:user:1"},
},
}
service := newStubbedRedisService(stub)
status, err := service.GetIdentityCacheStatus(context.Background())
require.NoError(t, err)
require.Equal(t, "ready", status.Status)
require.True(t, status.RedisReady)
require.Equal(t, int64(42), status.ObservedCount)
require.Equal(t, int64(3), status.KeyCount)
require.Equal(t, &now, status.LastRefreshedAt)
require.Equal(t, &now, status.UpdatedAt)
}
func TestRedisServiceFlushIdentityCacheDeletesOnlyIdentityMirrorAndIndexKeys(t *testing.T) {
stub := &redisCommandStub{
scans: map[string][]string{
"identity:mirror:*": {"identity:mirror:state", "identity:mirror:user:1"},
"identity:index:*": {"identity:index:email:a", "identity:mirror:user:1"},
},
}
service := newStubbedRedisService(stub)
result, err := service.FlushIdentityCache(context.Background())
require.NoError(t, err)
require.Equal(t, "success", result.Status)
require.Equal(t, int64(3), result.FlushedKeys)
require.ElementsMatch(t, []string{
"identity:mirror:state",
"identity:mirror:user:1",
"identity:index:email:a",
}, stub.deleted)
}
func TestRedisServiceGetIdentityCacheStatusReturnsUnavailableWithoutClient(t *testing.T) {
status, err := (*RedisService)(nil).GetIdentityCacheStatus(context.Background())
require.NoError(t, err)
require.Equal(t, "unavailable", status.Status)
require.False(t, status.RedisReady)
require.NotEmpty(t, status.LastError)
}
func TestRedisServiceFlushIdentityCacheFailsWithoutClient(t *testing.T) {
_, err := (*RedisService)(nil).FlushIdentityCache(context.Background())
require.ErrorIs(t, err, os.ErrInvalid)
}

View File

@@ -0,0 +1,215 @@
package service
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/repository"
"context"
"fmt"
"log/slog"
"strings"
)
type RelyingPartyService interface {
Create(ctx context.Context, tenantID string, client domain.HydraClient) (*domain.RelyingParty, error)
Get(ctx context.Context, clientID string) (*domain.RelyingParty, *domain.HydraClient, error)
List(ctx context.Context, tenantID string) ([]domain.RelyingParty, error)
ListAll(ctx context.Context) ([]domain.RelyingParty, error)
ListByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.RelyingParty, error)
Update(ctx context.Context, clientID string, client domain.HydraClient) (*domain.RelyingParty, error)
Delete(ctx context.Context, clientID string) error
}
type relyingPartyService struct {
hydraService *HydraAdminService
ketoService KetoService
outboxRepo repository.KetoOutboxRepository
}
var defaultRelyingPartyOperatorRelations = []string{
"admins",
"creator",
"config_editor",
"secret_viewer",
"secret_rotator",
"jwks_viewer",
"jwks_operator",
"consent_viewer",
"consent_revoker",
"relationship_viewer",
"audit_viewer",
"status_operator",
}
func NewRelyingPartyService(
hydraService *HydraAdminService,
ketoService KetoService,
outboxRepo repository.KetoOutboxRepository,
) RelyingPartyService {
return &relyingPartyService{
hydraService: hydraService,
ketoService: ketoService,
outboxRepo: outboxRepo,
}
}
func extractRelyingPartyCreatorSubject(client *domain.HydraClient) string {
if client == nil || client.Metadata == nil {
return ""
}
raw, _ := client.Metadata["user_id"].(string)
raw = strings.TrimSpace(raw)
if raw == "" {
return ""
}
return "User:" + raw
}
func (s *relyingPartyService) enqueueRelyingPartyTuple(ctx context.Context, action, object, relation, subject string) {
if s.outboxRepo == nil || strings.TrimSpace(object) == "" || strings.TrimSpace(relation) == "" || strings.TrimSpace(subject) == "" {
return
}
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "RelyingParty",
Object: object,
Relation: relation,
Subject: subject,
Action: action,
})
}
func (s *relyingPartyService) enqueueDefaultRelyingPartyRelations(ctx context.Context, action string, client *domain.HydraClient, tenantID string) {
if client == nil {
return
}
tenantID = strings.TrimSpace(tenantID)
if tenantID != "" {
s.enqueueRelyingPartyTuple(ctx, action, client.ClientID, "parents", "Tenant:"+tenantID)
}
creatorSubject := extractRelyingPartyCreatorSubject(client)
if creatorSubject == "" {
return
}
for _, relation := range defaultRelyingPartyOperatorRelations {
s.enqueueRelyingPartyTuple(ctx, action, client.ClientID, relation, creatorSubject)
}
}
func (s *relyingPartyService) Create(ctx context.Context, tenantID string, client domain.HydraClient) (*domain.RelyingParty, error) {
// 1. Create Client in Hydra
if client.Metadata == nil {
client.Metadata = make(map[string]any)
}
client.Metadata["tenant_id"] = tenantID
createdClient, err := s.hydraService.CreateClient(ctx, client)
if err != nil {
return nil, fmt.Errorf("failed to create hydra client: %w", err)
}
// 2. Create default relations in Keto via Outbox.
s.enqueueDefaultRelyingPartyRelations(ctx, domain.KetoOutboxActionCreate, createdClient, tenantID)
return s.mapHydraToDomain(createdClient), nil
}
func (s *relyingPartyService) Get(ctx context.Context, clientID string) (*domain.RelyingParty, *domain.HydraClient, error) {
hydraClient, err := s.hydraService.GetClient(ctx, clientID)
if err != nil {
return nil, nil, err
}
return s.mapHydraToDomain(hydraClient), hydraClient, nil
}
func (s *relyingPartyService) List(ctx context.Context, tenantID string) ([]domain.RelyingParty, error) {
// 1. Fetch ClientIDs from Keto
// Relation tuple: RelyingParty:cid # parents @ Tenant:tid
tuples, err := s.ketoService.ListRelations(ctx, "RelyingParty", "", "parents", "Tenant:"+tenantID)
if err != nil {
return nil, err
}
var rps []domain.RelyingParty
for _, t := range tuples {
clientID := t.Object
client, err := s.hydraService.GetClient(ctx, clientID)
if err != nil {
slog.Warn("Failed to fetch relying party from hydra", "client_id", clientID, "error", err)
continue
}
if rp := s.mapHydraToDomain(client); rp != nil {
rps = append(rps, *rp)
}
}
return rps, nil
}
func (s *relyingPartyService) ListAll(ctx context.Context) ([]domain.RelyingParty, error) {
return nil, fmt.Errorf("ListAll not implemented in SSOT mode yet")
}
func (s *relyingPartyService) ListByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.RelyingParty, error) {
var allRps []domain.RelyingParty
for _, tid := range tenantIDs {
rps, err := s.List(ctx, tid)
if err == nil {
allRps = append(allRps, rps...)
}
}
return allRps, nil
}
func (s *relyingPartyService) Update(ctx context.Context, clientID string, client domain.HydraClient) (*domain.RelyingParty, error) {
updatedClient, err := s.hydraService.UpdateClient(ctx, clientID, client)
if err != nil {
return nil, err
}
return s.mapHydraToDomain(updatedClient), nil
}
func (s *relyingPartyService) Delete(ctx context.Context, clientID string) error {
// 1. Get client to find tenantID (for Keto cleanup)
client, err := s.hydraService.GetClient(ctx, clientID)
if err != nil {
return err
}
tenantID := ""
if client.Metadata != nil {
if tid, ok := client.Metadata["tenant_id"].(string); ok {
tenantID = tid
}
}
// 2. Delete from Hydra
if err := s.hydraService.DeleteClient(ctx, clientID); err != nil {
return err
}
// 3. Delete default relations from Keto via Outbox.
s.enqueueDefaultRelyingPartyRelations(ctx, domain.KetoOutboxActionDelete, client, tenantID)
return nil
}
func (s *relyingPartyService) mapHydraToDomain(client *domain.HydraClient) *domain.RelyingParty {
if client == nil {
return nil
}
rp := &domain.RelyingParty{
ClientID: client.ClientID,
Name: client.ClientName,
}
if client.Metadata != nil {
if tid, ok := client.Metadata["tenant_id"].(string); ok {
rp.TenantID = tid
}
if desc, ok := client.Metadata["description"].(string); ok {
rp.Description = desc
}
}
return rp
}

View File

@@ -0,0 +1,217 @@
/*
이 테스트 파일은 RelyingPartyService의 기능을 검증하기 위한 유닛 테스트입니다.
RelyingPartyService는 HydraAdminService, KetoService와 협력하므로
각 의존성을 모킹(Mocking)하여 통합 로직을 검증합니다.
주요 테스트 항목:
1. Create: Hydra 클라이언트 생성 -> Keto 권한 설정
2. Get: Hydra에서 정보 조회
3. Update: Hydra 업데이트
4. Delete: Hydra 삭제 + Keto 권한 정리
*/
package service
import (
"baron-sso-backend/internal/domain"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
// --- Test Helpers ---
type hydraRoundTripperFunc func(*http.Request) (*http.Response, error)
func (f hydraRoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func mockHydraClient(handler http.Handler) *http.Client {
return &http.Client{
Transport: hydraRoundTripperFunc(func(req *http.Request) (*http.Response, error) {
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
return rec.Result(), nil
}),
}
}
// --- Tests ---
func TestRelyingPartyService_Create_Success(t *testing.T) {
mockKeto := new(MockKetoServiceShared)
mockOutbox := new(MockKetoOutboxRepositoryShared)
tenantID := "tenant-1"
inputClient := domain.HydraClient{
ClientName: "Test App",
Metadata: map[string]any{
"user_id": "creator-1",
},
}
// Hydra Mock
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/clients") {
var req domain.HydraClient
_ = json.NewDecoder(r.Body).Decode(&req)
// 메타데이터 tenant_id 주입 확인
if req.Metadata["tenant_id"] != tenantID {
t.Errorf("expected tenant_id in metadata")
}
req.ClientID = "generated-client-id"
w.WriteHeader(http.StatusCreated)
_ = json.NewEncoder(w).Encode(req)
return
}
http.NotFound(w, r)
})
hydraSvc := &HydraAdminService{
AdminURL: "http://hydra:4445",
HTTPClient: mockHydraClient(hydraHandler),
}
// Keto sync via Outbox using 'parents' relation
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
return e.Namespace == "RelyingParty" && e.Object == "generated-client-id" && e.Relation == "parents" && e.Subject == "Tenant:"+tenantID
})).Return(nil)
for _, relation := range defaultRelyingPartyOperatorRelations {
rel := relation
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
return e.Namespace == "RelyingParty" && e.Object == "generated-client-id" && e.Relation == rel && e.Subject == "User:creator-1"
})).Return(nil)
}
svc := NewRelyingPartyService(hydraSvc, mockKeto, mockOutbox)
rp, err := svc.Create(context.Background(), tenantID, inputClient)
assert.NoError(t, err)
assert.Equal(t, "generated-client-id", rp.ClientID)
assert.Equal(t, tenantID, rp.TenantID)
mockOutbox.AssertExpectations(t)
}
func TestRelyingPartyService_Create_HydraFail(t *testing.T) {
mockKeto := new(MockKetoServiceShared)
mockOutbox := new(MockKetoOutboxRepositoryShared)
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
})
hydraSvc := &HydraAdminService{
AdminURL: "http://hydra:4445",
HTTPClient: mockHydraClient(hydraHandler),
}
svc := NewRelyingPartyService(hydraSvc, mockKeto, mockOutbox)
_, err := svc.Create(context.Background(), "tenant-1", domain.HydraClient{})
assert.Error(t, err)
}
func TestRelyingPartyService_Get_Success(t *testing.T) {
mockKeto := new(MockKetoServiceShared)
mockOutbox := new(MockKetoOutboxRepositoryShared)
clientID := "client-123"
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(domain.HydraClient{
ClientID: clientID,
ClientName: "Hydra Name",
Metadata: map[string]any{
"tenant_id": "tenant-1",
},
})
})
hydraSvc := &HydraAdminService{
AdminURL: "http://hydra:4445",
HTTPClient: mockHydraClient(hydraHandler),
}
svc := NewRelyingPartyService(hydraSvc, mockKeto, mockOutbox)
rp, hc, err := svc.Get(context.Background(), clientID)
assert.NoError(t, err)
assert.Equal(t, "Hydra Name", rp.Name)
assert.Equal(t, "Hydra Name", hc.ClientName)
}
func TestRelyingPartyService_Update_Success(t *testing.T) {
mockKeto := new(MockKetoServiceShared)
mockOutbox := new(MockKetoOutboxRepositoryShared)
clientID := "client-123"
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPut {
var req domain.HydraClient
_ = json.NewDecoder(r.Body).Decode(&req)
_ = json.NewEncoder(w).Encode(req)
return
}
})
hydraSvc := &HydraAdminService{
AdminURL: "http://hydra:4445",
HTTPClient: mockHydraClient(hydraHandler),
}
svc := NewRelyingPartyService(hydraSvc, mockKeto, mockOutbox)
updateReq := domain.HydraClient{ClientName: "New Name"}
rp, err := svc.Update(context.Background(), clientID, updateReq)
assert.NoError(t, err)
assert.Equal(t, "New Name", rp.Name)
}
func TestRelyingPartyService_Delete_Success(t *testing.T) {
mockKeto := new(MockKetoServiceShared)
mockOutbox := new(MockKetoOutboxRepositoryShared)
clientID := "client-123"
tenantID := "tenant-1"
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet && strings.Contains(r.URL.Path, clientID) {
_ = json.NewEncoder(w).Encode(domain.HydraClient{
ClientID: clientID,
Metadata: map[string]any{
"tenant_id": tenantID,
"user_id": "creator-1",
},
})
return
}
if r.Method == http.MethodDelete && strings.Contains(r.URL.Path, clientID) {
w.WriteHeader(http.StatusNoContent)
return
}
http.NotFound(w, r)
})
hydraSvc := &HydraAdminService{
AdminURL: "http://hydra:4445",
HTTPClient: mockHydraClient(hydraHandler),
}
// Delete relation via Outbox using 'parents'
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
return e.Namespace == "RelyingParty" && e.Object == clientID && e.Relation == "parents" && e.Subject == "Tenant:"+tenantID
})).Return(nil)
for _, relation := range defaultRelyingPartyOperatorRelations {
rel := relation
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
return e.Namespace == "RelyingParty" && e.Object == clientID && e.Relation == rel && e.Subject == "User:creator-1"
})).Return(nil)
}
svc := NewRelyingPartyService(hydraSvc, mockKeto, mockOutbox)
err := svc.Delete(context.Background(), clientID)
assert.NoError(t, err)
mockOutbox.AssertExpectations(t)
}

View File

@@ -0,0 +1,67 @@
package service
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/repository"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"time"
)
type RPUsageEventEmitter struct {
repo repository.RPUsageOutboxRepository
}
func NewRPUsageEventEmitter(repo repository.RPUsageOutboxRepository) *RPUsageEventEmitter {
return &RPUsageEventEmitter{repo: repo}
}
func (e *RPUsageEventEmitter) EmitRPUsageEvent(ctx context.Context, event domain.RPUsageEvent) error {
if e == nil || e.repo == nil {
return nil
}
event.EventType = strings.TrimSpace(event.EventType)
event.Subject = strings.TrimSpace(event.Subject)
event.ClientID = strings.TrimSpace(event.ClientID)
event.Source = strings.TrimSpace(event.Source)
event.CorrelationID = strings.TrimSpace(event.CorrelationID)
if event.EventType == "" {
return fmt.Errorf("rp usage event type is required")
}
if event.Subject == "" {
return fmt.Errorf("rp usage subject is required")
}
if event.ClientID == "" {
return fmt.Errorf("rp usage client_id is required")
}
if event.Source == "" {
event.Source = "backend"
}
if event.OccurredAt.IsZero() {
event.OccurredAt = time.Now()
}
if event.DedupeKey == "" {
event.DedupeKey = buildRPUsageDedupeKey(event)
}
if event.Payload == nil {
event.Payload = domain.JSONMap{}
}
return e.repo.Create(ctx, &event)
}
func buildRPUsageDedupeKey(event domain.RPUsageEvent) string {
raw := strings.Join([]string{
event.EventType,
event.Subject,
event.ClientID,
event.SessionID,
event.Source,
event.CorrelationID,
event.OccurredAt.UTC().Format("2006-01-02T15:04:05.000Z"),
}, "|")
sum := sha256.Sum256([]byte(raw))
return hex.EncodeToString(sum[:])
}

View File

@@ -0,0 +1,132 @@
package service
import (
"baron-sso-backend/internal/domain"
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/require"
)
type fakeRPUsageOutboxRepo struct {
created []domain.RPUsageEvent
ready []domain.RPUsageEvent
processing []string
processed []string
failed []string
createErr error
projectErr error
}
func (f *fakeRPUsageOutboxRepo) Create(ctx context.Context, event *domain.RPUsageEvent) error {
if f.createErr != nil {
return f.createErr
}
f.created = append(f.created, *event)
return nil
}
func (f *fakeRPUsageOutboxRepo) ListReady(ctx context.Context, limit int) ([]domain.RPUsageEvent, error) {
return f.ready, nil
}
func (f *fakeRPUsageOutboxRepo) MarkProcessing(ctx context.Context, id string) error {
f.processing = append(f.processing, id)
return nil
}
func (f *fakeRPUsageOutboxRepo) MarkProcessed(ctx context.Context, id string) error {
f.processed = append(f.processed, id)
return nil
}
func (f *fakeRPUsageOutboxRepo) MarkFailed(ctx context.Context, id string, message string, nextAttemptAt time.Time) error {
f.failed = append(f.failed, id)
return nil
}
type fakeRPUsageProjectionRepo struct {
created []domain.RPUsageEvent
err error
}
func (f *fakeRPUsageProjectionRepo) CreateRPUsageEvent(ctx context.Context, event domain.RPUsageEvent) error {
if f.err != nil {
return f.err
}
f.created = append(f.created, event)
return nil
}
func TestRPUsageEventEmitterRequiresCanonicalFields(t *testing.T) {
repo := &fakeRPUsageOutboxRepo{}
emitter := NewRPUsageEventEmitter(repo)
err := emitter.EmitRPUsageEvent(context.Background(), domain.RPUsageEvent{
EventType: domain.RPUsageEventTypeAuthorizationGranted,
ClientID: "client-app",
})
require.Error(t, err)
require.Empty(t, repo.created)
}
func TestRPUsageEventEmitterCreatesPendingOutboxEvent(t *testing.T) {
repo := &fakeRPUsageOutboxRepo{}
emitter := NewRPUsageEventEmitter(repo)
err := emitter.EmitRPUsageEvent(context.Background(), domain.RPUsageEvent{
EventType: domain.RPUsageEventTypeAuthorizationGranted,
Subject: "user-123",
ClientID: "client-app",
Source: "hydra_consent",
CorrelationID: "challenge-1",
})
require.NoError(t, err)
require.Len(t, repo.created, 1)
require.NotEmpty(t, repo.created[0].DedupeKey)
require.Equal(t, domain.RPUsageEventTypeAuthorizationGranted, repo.created[0].EventType)
require.Equal(t, "hydra_consent", repo.created[0].Source)
}
func TestRPUsageProjectorWorkerMarksProcessedAfterProjection(t *testing.T) {
outbox := &fakeRPUsageOutboxRepo{
ready: []domain.RPUsageEvent{{
ID: "event-1",
EventType: domain.RPUsageEventTypeAuthorizationGranted,
Subject: "user-123",
ClientID: "client-app",
}},
}
projection := &fakeRPUsageProjectionRepo{}
worker := NewRPUsageProjectorWorker(outbox, projection)
worker.processOnce(context.Background())
require.Equal(t, []string{"event-1"}, outbox.processing)
require.Equal(t, []string{"event-1"}, outbox.processed)
require.Empty(t, outbox.failed)
require.Len(t, projection.created, 1)
}
func TestRPUsageProjectorWorkerMarksFailedWhenProjectionFails(t *testing.T) {
outbox := &fakeRPUsageOutboxRepo{
ready: []domain.RPUsageEvent{{
ID: "event-1",
EventType: domain.RPUsageEventTypeAuthorizationGranted,
Subject: "user-123",
ClientID: "client-app",
}},
}
projection := &fakeRPUsageProjectionRepo{err: errors.New("clickhouse unavailable")}
worker := NewRPUsageProjectorWorker(outbox, projection)
worker.processOnce(context.Background())
require.Equal(t, []string{"event-1"}, outbox.processing)
require.Empty(t, outbox.processed)
require.Equal(t, []string{"event-1"}, outbox.failed)
}

View File

@@ -0,0 +1,82 @@
package service
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/repository"
"context"
"log/slog"
"time"
)
type RPUsageProjectorWorker struct {
outbox repository.RPUsageOutboxRepository
projection domain.RPUsageProjectionRepository
interval time.Duration
batchSize int
}
func NewRPUsageProjectorWorker(outbox repository.RPUsageOutboxRepository, projection domain.RPUsageProjectionRepository) *RPUsageProjectorWorker {
return &RPUsageProjectorWorker{
outbox: outbox,
projection: projection,
interval: 5 * time.Second,
batchSize: 50,
}
}
func (w *RPUsageProjectorWorker) Start(ctx context.Context) {
if w == nil || w.outbox == nil || w.projection == nil {
return
}
ticker := time.NewTicker(w.interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
default:
w.processOnce(ctx)
}
select {
case <-ctx.Done():
return
case <-ticker.C:
}
}
}
func (w *RPUsageProjectorWorker) processOnce(ctx context.Context) {
events, err := w.outbox.ListReady(ctx, w.batchSize)
if err != nil {
slog.Warn("failed to list rp usage outbox", "error", err)
return
}
for _, event := range events {
if err := w.outbox.MarkProcessing(ctx, event.ID); err != nil {
slog.Warn("failed to mark rp usage event processing", "event_id", event.ID, "error", err)
continue
}
if err := w.projection.CreateRPUsageEvent(ctx, event); err != nil {
nextAttempt := time.Now().Add(backoffDuration(event.RetryCount))
_ = w.outbox.MarkFailed(ctx, event.ID, err.Error(), nextAttempt)
slog.Warn("failed to project rp usage event", "event_id", event.ID, "error", err)
continue
}
if err := w.outbox.MarkProcessed(ctx, event.ID); err != nil {
slog.Warn("failed to mark rp usage event processed", "event_id", event.ID, "error", err)
}
}
}
func backoffDuration(retryCount int) time.Duration {
if retryCount < 0 {
retryCount = 0
}
delay := time.Duration(retryCount+1) * time.Minute
if delay > 30*time.Minute {
return 30 * time.Minute
}
return delay
}

View File

@@ -0,0 +1,78 @@
package service
import (
"baron-sso-backend/internal/domain"
"context"
"fmt"
"log/slog"
"os"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/ses"
"github.com/aws/aws-sdk-go-v2/service/ses/types"
)
type SesServiceImpl struct {
client *ses.Client
sender string
}
func NewEmailService() domain.EmailService {
region := os.Getenv("AWS_REGION")
accessKey := os.Getenv("AWS_ACCESS_KEY_ID")
secretKey := os.Getenv("AWS_SECRET_ACCESS_KEY")
sender := os.Getenv("AWS_SES_SENDER")
if region == "" || accessKey == "" || secretKey == "" {
slog.Warn("[EmailService] AWS configuration missing, email service will not work")
return nil
}
cfg, err := config.LoadDefaultConfig(context.TODO(),
config.WithRegion(region),
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(accessKey, secretKey, "")),
)
if err != nil {
slog.Error("Failed to load AWS config", "error", err)
return nil
}
return &SesServiceImpl{
client: ses.NewFromConfig(cfg),
sender: sender,
}
}
func (s *SesServiceImpl) SendEmail(to, subject, body string) error {
if s == nil || s.client == nil {
return fmt.Errorf("email service not initialized")
}
input := &ses.SendEmailInput{
Destination: &types.Destination{
ToAddresses: []string{to},
},
Message: &types.Message{
Body: &types.Body{
Html: &types.Content{
Charset: aws.String("UTF-8"),
Data: aws.String(body),
},
},
Subject: &types.Content{
Charset: aws.String("UTF-8"),
Data: aws.String(subject),
},
},
Source: aws.String(s.sender),
}
_, err := s.client.SendEmail(context.TODO(), input)
if err != nil {
slog.Error("[EmailService] Failed to send email", "to", to, "error", err)
} else {
slog.Info("[EmailService] Email sent successfully", "to", to)
}
return err
}

View File

@@ -0,0 +1,63 @@
package service
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/repository"
"context"
"errors"
"time"
)
type SharedLinkService interface {
CreateLink(ctx context.Context, tenantID, name, description string, expiresAt *time.Time) (*domain.SharedLink, error)
ValidateToken(ctx context.Context, token string) (*domain.SharedLink, error)
GetLinksByTenant(ctx context.Context, tenantID string) ([]domain.SharedLink, error)
DeactivateLink(ctx context.Context, id string) error
}
type sharedLinkService struct {
repo repository.SharedLinkRepository
}
func NewSharedLinkService(repo repository.SharedLinkRepository) SharedLinkService {
return &sharedLinkService{repo: repo}
}
func (s *sharedLinkService) CreateLink(ctx context.Context, tenantID, name, description string, expiresAt *time.Time) (*domain.SharedLink, error) {
link := &domain.SharedLink{
TenantID: tenantID,
Name: name,
Description: description,
ExpiresAt: expiresAt,
IsActive: true,
AccessLevel: "READ_ONLY",
}
if err := s.repo.Create(ctx, link); err != nil {
return nil, err
}
return link, nil
}
func (s *sharedLinkService) ValidateToken(ctx context.Context, token string) (*domain.SharedLink, error) {
link, err := s.repo.FindByToken(ctx, token)
if err != nil {
return nil, errors.New("invalid or expired share link")
}
if !link.IsValid() {
return nil, errors.New("share link has expired or is inactive")
}
return link, nil
}
func (s *sharedLinkService) GetLinksByTenant(ctx context.Context, tenantID string) ([]domain.SharedLink, error) {
return s.repo.FindByTenantID(ctx, tenantID)
}
func (s *sharedLinkService) DeactivateLink(ctx context.Context, id string) error {
// 실제 삭제 대신 비활성화 처리 (soft-delete와 유사)
// 하지만 여기서는 간단히 활성 플래그만 끔
return s.repo.Delete(ctx, id) // 리포지토리의 Delete는 GORM의 DeletedAt을 사용하여 soft-delete함
}

View File

@@ -0,0 +1,134 @@
package service
import (
"baron-sso-backend/internal/domain"
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"strconv"
"strings"
"time"
)
const naverSMSMaxBytes = 90
type SmsServiceImpl struct {
accessKey string
secretKey string
serviceID string
senderPhone string
}
func NewSmsService() domain.SmsService {
// Sanitize sender phone number right after reading from env
rawSenderPhone := os.Getenv("NAVER_SENDER_PHONE_NUMBER")
sanitizedSenderPhone := strings.ReplaceAll(rawSenderPhone, "-", "")
slog.Info("[서비스 초기화] 발신자 번호 처리", "원본", rawSenderPhone, "정제후", sanitizedSenderPhone)
return &SmsServiceImpl{
accessKey: os.Getenv("NAVER_CLOUD_ACCESS_KEY"),
secretKey: os.Getenv("NAVER_CLOUD_SECRET_KEY"),
serviceID: os.Getenv("NAVER_CLOUD_SERVICE_ID"),
senderPhone: sanitizedSenderPhone,
}
}
func (s *SmsServiceImpl) SendSms(to, content string) error {
timestamp := strconv.FormatInt(time.Now().UnixNano()/int64(time.Millisecond), 10)
apiURL := fmt.Sprintf("https://sens.apigw.ntruss.com/sms/v2/services/%s/messages", s.serviceID)
slog.Info("[SmsService] Requesting SENS API URL", "url", apiURL)
// Naver SENS API requires phone number without '+'
sanitizedTo := strings.Replace(to, "+", "", 1)
reqBody := buildNaverSmsRequest(s.senderPhone, sanitizedTo, content)
if reqBody.Type == "LMS" {
slog.Info("[SmsService] Upgrading message type to LMS due to content length",
"bytes", len([]byte(content)),
)
}
jsonBody, err := json.Marshal(reqBody)
if err != nil {
return fmt.Errorf("error marshalling request body: %w", err)
}
signature, err := s.makeSignature("POST", fmt.Sprintf("/sms/v2/services/%s/messages", s.serviceID), timestamp)
if err != nil {
return fmt.Errorf("error creating signature: %w", err)
}
req, err := http.NewRequest("POST", apiURL, bytes.NewBuffer(jsonBody))
if err != nil {
return fmt.Errorf("error creating request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("x-ncp-apigw-timestamp", timestamp)
req.Header.Set("x-ncp-iam-access-key", s.accessKey)
req.Header.Set("x-ncp-apigw-signature-v2", signature)
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("error sending request: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("error reading response body: %w", err)
}
if resp.StatusCode >= 300 {
slog.Error("[SmsService] error response from naver cloud sms api", "body", string(respBody))
return fmt.Errorf("error sending sms: status code %d", resp.StatusCode)
}
slog.Info("[SmsService] sms sent successfully", "body", string(respBody))
return nil
}
func buildNaverSmsRequest(senderPhone, sanitizedTo, content string) domain.NaverSmsRequest {
requestType := "SMS"
subject := ""
if len([]byte(content)) > naverSMSMaxBytes {
requestType = "LMS"
subject = "[Baron 로그인]"
}
return domain.NaverSmsRequest{
Type: requestType,
ContentType: "COMM",
CountryCode: "82",
From: senderPhone,
Subject: subject,
Content: content,
Messages: []domain.SmsMessage{
{
To: sanitizedTo,
},
},
}
}
func (s *SmsServiceImpl) makeSignature(method, url, timestamp string) (string, error) {
space := " "
newLine := "\n"
message := method + space + url + newLine + timestamp + newLine + s.accessKey
h := hmac.New(sha256.New, []byte(s.secretKey))
_, err := h.Write([]byte(message))
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(h.Sum(nil)), nil
}

View File

@@ -0,0 +1,26 @@
package service
import "testing"
func TestBuildNaverSmsRequest_UsesSMSForShortContent(t *testing.T) {
req := buildNaverSmsRequest("0262857755", "821012345678", "123456")
if req.Type != "SMS" {
t.Fatalf("expected SMS, got %s", req.Type)
}
if req.Subject != "" {
t.Fatalf("expected empty subject for SMS, got %q", req.Subject)
}
}
func TestBuildNaverSmsRequest_UsesLMSForLongContent(t *testing.T) {
content := "[Baron 로그인] 비밀번호 재설정 링크: http://sso.example.test/api/v1/auth/password/reset/v/1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef"
req := buildNaverSmsRequest("0262857755", "821012345678", content)
if req.Type != "LMS" {
t.Fatalf("expected LMS, got %s", req.Type)
}
if req.Subject == "" {
t.Fatal("expected LMS subject to be set")
}
}

View File

@@ -0,0 +1,394 @@
package service
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/repository"
"baron-sso-backend/internal/utils"
"context"
"errors"
"log/slog"
"strings"
"gorm.io/gorm"
)
type TenantService interface {
RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string, creatorID string) (*domain.Tenant, error)
RequestRegistration(ctx context.Context, name, slug, description string, domainName string, adminEmail string) (*domain.Tenant, error)
GetTenantByDomain(ctx context.Context, emailDomain string) (*domain.Tenant, error)
GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error)
GetTenant(ctx context.Context, id string) (*domain.Tenant, error)
ListTenants(ctx context.Context, limit, offset int, parentID string, search string) ([]domain.Tenant, int64, error)
ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error)
ListJoinedTenants(ctx context.Context, userID string) ([]domain.Tenant, error)
IsDomainAllowed(ctx context.Context, domainName string) (bool, error)
ApproveTenant(ctx context.Context, id string) error
ProvisionTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error)
SetKetoService(keto KetoService)
DeleteTenantsBulk(ctx context.Context, ids []string) error
}
type tenantService struct {
repo repository.TenantRepository
userRepo repository.UserRepository
userGroupRepo repository.UserGroupRepository
keto KetoService
outboxRepo repository.KetoOutboxRepository
}
func NewTenantService(repo repository.TenantRepository, userRepo repository.UserRepository, userGroupRepo repository.UserGroupRepository, outboxRepo repository.KetoOutboxRepository) TenantService {
return &tenantService{
repo: repo,
userRepo: userRepo,
userGroupRepo: userGroupRepo,
outboxRepo: outboxRepo,
}
}
func (s *tenantService) SetKetoService(keto KetoService) {
s.keto = keto
}
func (s *tenantService) GetTenant(ctx context.Context, id string) (*domain.Tenant, error) {
return s.repo.FindByID(ctx, id)
}
func (s *tenantService) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
if s.keto == nil {
return nil, errors.New("keto service not initialized")
}
allIDs, err := s.keto.ListObjects(ctx, "Tenant", "manage", "User:"+userID)
if err != nil {
slog.Error("Failed to list manageable tenants from Keto", "userID", userID, "error", err)
return []domain.Tenant{}, nil
}
if len(allIDs) == 0 {
directAdminIDs, _ := s.keto.ListObjects(ctx, "Tenant", "admins", "User:"+userID)
directOwnerIDs, _ := s.keto.ListObjects(ctx, "Tenant", "owners", "User:"+userID)
idMap := make(map[string]bool)
for _, id := range directAdminIDs {
idMap[id] = true
}
for _, id := range directOwnerIDs {
idMap[id] = true
}
allIDs = make([]string, 0, len(idMap))
for id := range idMap {
allIDs = append(allIDs, id)
}
}
if len(allIDs) == 0 {
return []domain.Tenant{}, nil
}
return s.repo.FindByIDs(ctx, allIDs)
}
func (s *tenantService) ListJoinedTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
if s.keto == nil {
return nil, errors.New("keto service not initialized")
}
memberIDs, err := s.keto.ListObjects(ctx, "Tenant", "members", "User:"+userID)
if err != nil {
slog.Error("Failed to list joined tenants from Keto", "userID", userID, "error", err)
return []domain.Tenant{}, nil
}
ownerIDs, _ := s.keto.ListObjects(ctx, "Tenant", "owners", "User:"+userID)
adminIDs, _ := s.keto.ListObjects(ctx, "Tenant", "admins", "User:"+userID)
idMap := make(map[string]bool)
for _, id := range memberIDs {
idMap[id] = true
}
for _, id := range ownerIDs {
idMap[id] = true
}
for _, id := range adminIDs {
idMap[id] = true
}
allIDs := make([]string, 0, len(idMap))
for id := range idMap {
allIDs = append(allIDs, id)
}
if len(allIDs) == 0 {
return []domain.Tenant{}, nil
}
return s.repo.FindByIDs(ctx, allIDs)
}
func (s *tenantService) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string, creatorID string) (*domain.Tenant, error) {
if ok, msg := utils.ValidateSlug(slug); !ok {
return nil, errors.New(msg)
}
existing, err := s.repo.FindBySlug(ctx, slug)
if err == nil && existing != nil {
return nil, errors.New("tenant slug already exists")
}
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
tenant := &domain.Tenant{
Type: tenantType,
Name: name,
Slug: slug,
Description: description,
Status: domain.TenantStatusActive,
ParentID: parentID,
}
if err := s.repo.Create(ctx, tenant); err != nil {
return nil, err
}
if s.outboxRepo != nil {
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: tenant.ID,
Relation: "admins",
Subject: "System:global#super_admins",
Action: domain.KetoOutboxActionCreate,
})
if tenant.ParentID != nil {
if err := s.outboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: tenant.ID,
Relation: "parents",
Subject: "Tenant:" + *tenant.ParentID,
Action: domain.KetoOutboxActionCreate,
}); err != nil {
slog.Error("Failed to create outbox entry for tenant hierarchy", "tenant", tenant.ID, "error", err)
}
}
if creatorID != "" {
slog.Info("Creating outbox entries for tenant creator", "tenant", tenant.ID, "creator", creatorID)
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: tenant.ID,
Relation: "owners",
Subject: "User:" + creatorID,
Action: domain.KetoOutboxActionCreate,
})
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: tenant.ID,
Relation: "admins",
Subject: "User:" + creatorID,
Action: domain.KetoOutboxActionCreate,
})
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: tenant.ID,
Relation: "members",
Subject: "User:" + creatorID,
Action: domain.KetoOutboxActionCreate,
})
}
}
for _, d := range domains {
if err := s.repo.AddDomain(ctx, tenant.ID, d, true); err != nil {
slog.Error("Failed to add domain to tenant", "tenant", slug, "domain", d, "error", err)
}
}
return s.repo.FindBySlug(ctx, slug)
}
func (s *tenantService) RequestRegistration(ctx context.Context, name, slug, description string, domainName string, adminEmail string) (*domain.Tenant, error) {
if ok, msg := utils.ValidateSlug(slug); !ok {
return nil, errors.New(msg)
}
parts := strings.Split(adminEmail, "@")
if len(parts) != 2 || parts[1] != domainName {
return nil, errors.New("admin email domain must match the tenant domain")
}
tenant := &domain.Tenant{
Type: domain.TenantTypeCompany,
Name: name,
Slug: slug,
Description: description,
Status: domain.TenantStatusPending,
Config: domain.JSONMap{"adminEmail": adminEmail},
}
if err := s.repo.Create(ctx, tenant); err != nil {
return nil, err
}
if s.outboxRepo != nil {
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: tenant.ID,
Relation: "admins",
Subject: "System:global#super_admins",
Action: domain.KetoOutboxActionCreate,
})
}
if err := s.repo.AddDomain(ctx, tenant.ID, domainName, false); err != nil {
return nil, err
}
return tenant, nil
}
func (s *tenantService) ApproveTenant(ctx context.Context, id string) error {
tenant, err := s.repo.FindByID(ctx, id)
if err != nil {
return err
}
tenant.Status = domain.TenantStatusActive
if err := s.repo.Update(ctx, tenant); err != nil {
return err
}
if s.outboxRepo != nil {
if adminEmail, ok := tenant.Config["adminEmail"].(string); ok && adminEmail != "" {
slog.Info("Queueing tenant admin/owner sync to Keto", "tenant", tenant.Slug, "adminEmail", adminEmail)
if s.userRepo != nil {
user, err := s.userRepo.FindByEmail(ctx, adminEmail)
if err == nil && user != nil {
slog.Info("Queueing tenant ownership/membership sync to Keto", "tenant", tenant.Slug, "userID", user.ID)
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: tenant.ID,
Relation: "owners",
Subject: "User:" + user.ID,
Action: domain.KetoOutboxActionCreate,
})
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: tenant.ID,
Relation: "admins",
Subject: "User:" + user.ID,
Action: domain.KetoOutboxActionCreate,
})
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: tenant.ID,
Relation: "members",
Subject: "User:" + user.ID,
Action: domain.KetoOutboxActionCreate,
})
} else {
slog.Info("Tenant admin user not found in local DB, will need manual sync or sync on signup", "email", adminEmail)
}
}
}
}
return nil
}
func (s *tenantService) GetTenantByDomain(ctx context.Context, emailDomain string) (*domain.Tenant, error) {
tenant, err := s.repo.FindByDomain(ctx, emailDomain)
if err != nil {
return nil, err
}
if tenant.Status != domain.TenantStatusActive {
return nil, errors.New("tenant is not active")
}
return tenant, nil
}
func (s *tenantService) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
return s.repo.FindBySlug(ctx, slug)
}
func (s *tenantService) ListTenants(ctx context.Context, limit, offset int, parentID string, search string) ([]domain.Tenant, int64, error) {
return s.repo.List(ctx, limit, offset, parentID, search)
}
func (s *tenantService) IsDomainAllowed(ctx context.Context, domainName string) (bool, error) {
tenant, err := s.repo.FindByDomain(ctx, domainName)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return false, nil
}
return false, err
}
return tenant != nil && tenant.Status == domain.TenantStatusActive, nil
}
func (s *tenantService) ProvisionTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
groups, err := s.repo.ListByType(ctx, domain.TenantTypeCompanyGroup)
if err != nil {
return nil, err
}
for _, g := range groups {
rawConfig, ok := g.Config["autoProvisioning"].(map[string]any)
if !ok {
continue
}
enabled, _ := rawConfig["enabled"].(bool)
if !enabled {
continue
}
mapping, ok := rawConfig["mappingRules"].(map[string]any)
if !ok {
continue
}
rule, ok := mapping[domainName].(map[string]any)
if !ok {
continue
}
slug, _ := rule["slug"].(string)
name, _ := rule["name"].(string)
if slug == "" || name == "" {
continue
}
slog.Info("[Provisioning] Found rule for domain, creating sub-tenant", "domain", domainName, "parent", g.Slug, "newTenant", slug)
return s.RegisterTenant(ctx, name, slug, domain.TenantTypeCompany, "Automatically provisioned via group policy", []string{domainName}, &g.ID, "")
}
return nil, gorm.ErrRecordNotFound
}
func (s *tenantService) DeleteTenantsBulk(ctx context.Context, ids []string) error {
if len(ids) == 0 {
return nil
}
if err := s.repo.DeleteBulk(ctx, ids); err != nil {
return err
}
if s.outboxRepo != nil {
for _, id := range ids {
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: id,
Relation: "parents",
Action: domain.KetoOutboxActionDelete,
})
}
}
return nil
}

View File

@@ -0,0 +1,114 @@
package service
import (
"baron-sso-backend/internal/domain"
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"gorm.io/gorm"
)
func TestTenantService_RegisterTenant_DuplicateSlug(t *testing.T) {
mockRepo := new(MockTenantRepoForSvc)
svc := NewTenantService(mockRepo, nil, nil, nil)
ctx := context.Background()
slug := "duplicate-slug"
// Mock: slug already exists
mockRepo.On("FindBySlug", ctx, slug).Return(&domain.Tenant{ID: "existing-id", Slug: slug}, nil)
tenant, err := svc.RegisterTenant(ctx, "New Name", slug, domain.TenantTypeCompany, "", nil, nil, "")
assert.Error(t, err)
assert.Contains(t, err.Error(), "already exists")
assert.Nil(t, tenant)
}
func TestTenantService_RegisterTenant_InvalidSlug(t *testing.T) {
svc := NewTenantService(nil, nil, nil, nil)
ctx := context.Background()
// Case 1: Too short
_, err := svc.RegisterTenant(ctx, "Name", "a", domain.TenantTypeCompany, "", nil, nil, "")
assert.Error(t, err)
// Case 2: Invalid characters
_, err = svc.RegisterTenant(ctx, "Name", "Invalid Slug!", domain.TenantTypeCompany, "", nil, nil, "")
assert.Error(t, err)
}
func TestTenantService_RequestRegistration_EmailMismatch(t *testing.T) {
svc := NewTenantService(nil, nil, nil, nil)
ctx := context.Background()
// admin email domain (gmail.com) != tenant domain (company.com)
tenant, err := svc.RequestRegistration(ctx, "Name", "slug", "", "company.com", "admin@gmail.com")
assert.Error(t, err)
assert.Contains(t, err.Error(), "must match")
assert.Nil(t, tenant)
}
func TestTenantService_ApproveTenant_NotFound(t *testing.T) {
mockRepo := new(MockTenantRepoForSvc)
svc := NewTenantService(mockRepo, nil, nil, nil)
ctx := context.Background()
id := "non-existent-id"
mockRepo.On("FindByID", ctx, id).Return(nil, gorm.ErrRecordNotFound)
err := svc.ApproveTenant(ctx, id)
assert.Error(t, err)
assert.True(t, errors.Is(err, gorm.ErrRecordNotFound))
}
func TestTenantService_GetTenantByDomain_Inactive(t *testing.T) {
mockRepo := new(MockTenantRepoForSvc)
svc := NewTenantService(mockRepo, nil, nil, nil)
ctx := context.Background()
domainName := "inactive.com"
mockRepo.On("FindByDomain", ctx, domainName).Return(&domain.Tenant{
ID: "t1",
Status: domain.TenantStatusPending,
}, nil)
tenant, err := svc.GetTenantByDomain(ctx, domainName)
assert.Error(t, err)
assert.Contains(t, err.Error(), "not active")
assert.Nil(t, tenant)
}
func TestTenantService_ApproveTenant_UserNotFound(t *testing.T) {
mockRepo := new(MockTenantRepoForSvc)
mockUserRepo := new(MockUserRepoForTenant)
mockOutbox := new(MockKetoOutboxRepositoryShared)
svc := NewTenantService(mockRepo, mockUserRepo, nil, mockOutbox)
ctx := context.Background()
tenantID := "t1"
adminEmail := "notfound@tenant.com"
tenant := &domain.Tenant{
ID: tenantID,
Slug: "tenant-slug",
Config: domain.JSONMap{"adminEmail": adminEmail},
}
mockRepo.On("FindByID", ctx, tenantID).Return(tenant, nil)
mockRepo.On("Update", ctx, mock.Anything).Return(nil)
// User not found in DB
mockUserRepo.On("FindByEmail", adminEmail).Return(nil, gorm.ErrRecordNotFound)
// Outbox should not be called since user is not found
err := svc.ApproveTenant(ctx, tenantID)
assert.NoError(t, err) // Should succeed but just log that user is not found
mockRepo.AssertExpectations(t)
mockUserRepo.AssertExpectations(t)
mockOutbox.AssertNotCalled(t, "Create")
}

View File

@@ -0,0 +1,345 @@
package service
import (
"baron-sso-backend/internal/domain"
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"gorm.io/gorm"
)
// --- Local Mocks to avoid collisions ---
type MockTenantRepoForSvc struct {
mock.Mock
}
func (m *MockTenantRepoForSvc) Create(ctx context.Context, tenant *domain.Tenant) error {
return m.Called(ctx, tenant).Error(0)
}
func (m *MockTenantRepoForSvc) Update(ctx context.Context, tenant *domain.Tenant) error {
return m.Called(ctx, tenant).Error(0)
}
func (m *MockTenantRepoForSvc) FindByID(ctx context.Context, id string) (*domain.Tenant, error) {
args := m.Called(ctx, id)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.Tenant), args.Error(1)
}
func (m *MockTenantRepoForSvc) FindBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
args := m.Called(ctx, slug)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.Tenant), args.Error(1)
}
func (m *MockTenantRepoForSvc) FindByName(ctx context.Context, name string) (*domain.Tenant, error) {
return nil, nil
}
func (m *MockTenantRepoForSvc) FindByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
args := m.Called(ctx, domainName)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.Tenant), args.Error(1)
}
func (m *MockTenantRepoForSvc) FindByIDs(ctx context.Context, ids []string) ([]domain.Tenant, error) {
return nil, nil
}
func (m *MockTenantRepoForSvc) AddDomain(ctx context.Context, tenantID string, domainName string, verified bool) error {
return m.Called(ctx, tenantID, domainName, verified).Error(0)
}
func (m *MockTenantRepoForSvc) List(ctx context.Context, limit, offset int, parentID string, search string) ([]domain.Tenant, int64, error) {
args := m.Called(ctx, limit, offset, parentID, search)
return args.Get(0).([]domain.Tenant), args.Get(1).(int64), args.Error(2)
}
func (m *MockTenantRepoForSvc) ListByType(ctx context.Context, tenantType string) ([]domain.Tenant, error) {
args := m.Called(ctx, tenantType)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).([]domain.Tenant), args.Error(1)
}
func (m *MockTenantRepoForSvc) DeleteBulk(ctx context.Context, ids []string) error {
args := m.Called(ctx, ids)
return args.Error(0)
}
type MockKetoSvcForTenant struct {
mock.Mock
}
func (m *MockKetoSvcForTenant) CreateRelation(ctx context.Context, namespace, object, relation, subject string) error {
return m.Called(ctx, namespace, object, relation, subject).Error(0)
}
func (m *MockKetoSvcForTenant) DeleteRelation(ctx context.Context, namespace, object, relation, subject string) error {
return m.Called(ctx, namespace, object, relation, subject).Error(0)
}
func (m *MockKetoSvcForTenant) ListRelations(ctx context.Context, namespace, object, relation, subject string) ([]RelationTuple, error) {
args := m.Called(ctx, namespace, object, relation, subject)
return args.Get(0).([]RelationTuple), args.Error(1)
}
func (m *MockKetoSvcForTenant) ListObjects(ctx context.Context, namespace, relation, subject string) ([]string, error) {
args := m.Called(ctx, namespace, relation, subject)
return args.Get(0).([]string), args.Error(1)
}
func (m *MockKetoSvcForTenant) CheckPermission(ctx context.Context, namespace, object, relation, subject string) (bool, error) {
args := m.Called(ctx, namespace, object, relation, subject)
return args.Bool(0), args.Error(1)
}
type MockUserRepoForTenant struct {
mock.Mock
}
func (m *MockUserRepoForTenant) Create(ctx context.Context, user *domain.User) error { return nil }
func (m *MockUserRepoForTenant) Update(ctx context.Context, user *domain.User) error { return nil }
func (m *MockUserRepoForTenant) FindByEmail(ctx context.Context, email string) (*domain.User, error) {
args := m.Called(email)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.User), args.Error(1)
}
func (m *MockUserRepoForTenant) Delete(ctx context.Context, id string) error {
return m.Called(ctx, id).Error(0)
}
func (m *MockUserRepoForTenant) FindByID(ctx context.Context, id string) (*domain.User, error) {
return nil, nil
}
func (m *MockUserRepoForTenant) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) {
return nil, nil
}
func (m *MockUserRepoForTenant) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
return nil, nil
}
func (m *MockUserRepoForTenant) List(ctx context.Context, offset, limit int, search string, tenantIDs []string, cursor string) ([]domain.User, int64, string, error) {
return nil, 0, "", nil
}
func (m *MockUserRepoForTenant) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
args := m.Called(tenantID)
return int64(args.Int(0)), args.Error(1)
}
func (m *MockUserRepoForTenant) FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error) {
args := m.Called(ctx, tenantIDs)
return args.Get(0).([]domain.User), args.Error(1)
}
func (m *MockUserRepoForTenant) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
args := m.Called(tenantIDs)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(map[string]int64), args.Error(1)
}
func (m *MockUserRepoForTenant) FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error) {
args := m.Called(ctx, codes)
return args.Get(0).([]domain.User), args.Error(1)
}
func (m *MockUserRepoForTenant) CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error) {
args := m.Called(ctx, codes)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(map[string]int64), args.Error(1)
}
func (m *MockUserRepoForTenant) UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error {
return nil
}
func (m *MockUserRepoForTenant) GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error) {
return nil, nil
}
func (m *MockUserRepoForTenant) IsLoginIDTaken(ctx context.Context, loginID string) (bool, error) {
return false, nil
}
func (m *MockUserRepoForTenant) FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error) {
return "", nil
}
func (m *MockUserRepoForTenant) DB() *gorm.DB {
return nil
}
// --- Tests ---
func TestTenantService_RegisterTenant_AutoVerify(t *testing.T) {
mockRepo := new(MockTenantRepoForSvc)
mockOutbox := new(MockKetoOutboxRepositoryShared)
svc := NewTenantService(mockRepo, nil, nil, mockOutbox)
ctx := context.Background()
name := "New Tenant"
slug := "new-tenant"
domains := []string{"example.com"}
// Use .Once() to ensure correct return values for sequential calls to FindBySlug
mockRepo.On("FindBySlug", ctx, slug).Return(nil, nil).Once()
mockRepo.On("Create", ctx, mock.Anything).Return(nil)
mockRepo.On("AddDomain", ctx, mock.Anything, "example.com", true).Return(nil)
mockOutbox.On("Create", ctx, mock.MatchedBy(func(k *domain.KetoOutbox) bool {
return k.Relation == "admins" && k.Subject == "System:global#super_admins"
})).Return(nil)
mockRepo.On("FindBySlug", ctx, slug).Return(&domain.Tenant{ID: "t1", Slug: slug}, nil).Once()
tenant, err := svc.RegisterTenant(ctx, name, slug, domain.TenantTypeCompany, "", domains, nil, "")
assert.NoError(t, err)
assert.NotNil(t, tenant)
assert.Equal(t, "t1", tenant.ID)
mockRepo.AssertExpectations(t)
}
func TestTenantService_RegisterTenant_WithCreator(t *testing.T) {
mockRepo := new(MockTenantRepoForSvc)
mockOutbox := new(MockKetoOutboxRepositoryShared)
svc := NewTenantService(mockRepo, nil, nil, mockOutbox)
ctx := context.Background()
name := "Creator Tenant"
slug := "creator-tenant"
creatorID := "creator-uuid"
tenantID := "t-new"
mockRepo.On("FindBySlug", ctx, slug).Return(nil, nil).Once()
mockRepo.On("Create", ctx, mock.MatchedBy(func(t *domain.Tenant) bool {
return t.Slug == slug
})).Run(func(args mock.Arguments) {
t := args.Get(1).(*domain.Tenant)
t.ID = tenantID
}).Return(nil)
// Expect global super admin sync
mockOutbox.On("Create", ctx, mock.MatchedBy(func(k *domain.KetoOutbox) bool {
return k.Relation == "admins" && k.Subject == "System:global#super_admins"
})).Return(nil)
// Expect owners sync
mockOutbox.On("Create", ctx, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
return e.Namespace == "Tenant" && e.Object == tenantID && e.Relation == "owners" && e.Subject == "User:"+creatorID
})).Return(nil)
// Expect admins sync
mockOutbox.On("Create", ctx, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
return e.Namespace == "Tenant" && e.Object == tenantID && e.Relation == "admins" && e.Subject == "User:"+creatorID
})).Return(nil)
// Expect members sync
mockOutbox.On("Create", ctx, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
return e.Namespace == "Tenant" && e.Object == tenantID && e.Relation == "members" && e.Subject == "User:"+creatorID
})).Return(nil)
mockRepo.On("FindBySlug", ctx, slug).Return(&domain.Tenant{ID: tenantID, Slug: slug}, nil).Once()
tenant, err := svc.RegisterTenant(ctx, name, slug, domain.TenantTypeCompany, "", nil, nil, creatorID)
assert.NoError(t, err)
assert.NotNil(t, tenant)
mockRepo.AssertExpectations(t)
mockOutbox.AssertExpectations(t)
}
func TestTenantService_RequestRegistration_NoVerify(t *testing.T) {
mockRepo := new(MockTenantRepoForSvc)
mockOutbox := new(MockKetoOutboxRepositoryShared)
svc := NewTenantService(mockRepo, nil, nil, mockOutbox)
ctx := context.Background()
name := "Public Tenant"
slug := "public-tenant"
domainName := "public.com"
adminEmail := "admin@public.com"
mockRepo.On("Create", ctx, mock.MatchedBy(func(tenant *domain.Tenant) bool {
return tenant.Status == domain.TenantStatusPending
})).Return(nil)
mockOutbox.On("Create", ctx, mock.MatchedBy(func(k *domain.KetoOutbox) bool {
return k.Relation == "admins" && k.Subject == "System:global#super_admins"
})).Return(nil)
mockRepo.On("AddDomain", ctx, mock.Anything, domainName, false).Return(nil)
tenant, err := svc.RequestRegistration(ctx, name, slug, "", domainName, adminEmail)
assert.NoError(t, err)
assert.NotNil(t, tenant)
mockRepo.AssertExpectations(t)
}
func TestTenantService_ApproveTenant_SyncAdmin(t *testing.T) {
mockRepo := new(MockTenantRepoForSvc)
mockUserRepo := new(MockUserRepoForTenant)
mockKeto := new(MockKetoSvcForTenant)
mockOutbox := new(MockKetoOutboxRepositoryShared)
svc := NewTenantService(mockRepo, mockUserRepo, nil, mockOutbox)
svc.SetKetoService(mockKeto)
ctx := context.Background()
tenantID := "t1"
adminEmail := "admin@tenant.com"
userID := "user-uuid"
tenant := &domain.Tenant{
ID: tenantID,
Slug: "tenant-slug",
Config: domain.JSONMap{"adminEmail": adminEmail},
}
mockRepo.On("FindByID", ctx, tenantID).Return(tenant, nil)
mockRepo.On("Update", ctx, mock.Anything).Return(nil)
mockUserRepo.On("FindByEmail", adminEmail).Return(&domain.User{ID: userID, Email: adminEmail}, nil)
// Now using Outbox instead of direct Keto call
mockOutbox.On("Create", ctx, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
return e.Namespace == "Tenant" && e.Object == tenantID && e.Relation == "owners" && e.Subject == "User:"+userID
})).Return(nil)
mockOutbox.On("Create", ctx, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
return e.Namespace == "Tenant" && e.Object == tenantID && e.Relation == "admins" && e.Subject == "User:"+userID
})).Return(nil)
mockOutbox.On("Create", ctx, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
return e.Namespace == "Tenant" && e.Object == tenantID && e.Relation == "members" && e.Subject == "User:"+userID
})).Return(nil)
err := svc.ApproveTenant(ctx, tenantID)
assert.NoError(t, err)
mockRepo.AssertExpectations(t)
mockUserRepo.AssertExpectations(t)
mockOutbox.AssertExpectations(t)
}
func TestTenantService_ListTenants(t *testing.T) {
mockRepo := new(MockTenantRepoForSvc)
svc := NewTenantService(mockRepo, nil, nil, nil)
ctx := context.Background()
tenants := []domain.Tenant{{ID: "t1", Name: "Tenant 1"}}
mockRepo.On("List", ctx, 10, 0, "", "").Return(tenants, int64(1), nil)
result, total, err := svc.ListTenants(ctx, 10, 0, "", "")
assert.NoError(t, err)
assert.Equal(t, int64(1), total)
assert.Equal(t, tenants, result)
mockRepo.AssertExpectations(t)
}

View File

@@ -0,0 +1,466 @@
package service
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/repository"
"context"
"fmt"
"log/slog"
"time"
"github.com/google/uuid"
)
type UserGroupService interface {
Create(ctx context.Context, tenantID string, parentID *string, name, description, unitType string) (*domain.UserGroup, error)
Get(ctx context.Context, id string) (*domain.UserGroup, error)
List(ctx context.Context, tenantID string) ([]domain.UserGroup, error)
Delete(ctx context.Context, tenantID, groupID string) error
Update(ctx context.Context, tenantID, groupID string, name, description, unitType string, parentID *string) (*domain.UserGroup, error)
SetWorksmobileSyncer(syncer WorksmobileSyncer)
// Member Management with Keto Sync
AddMember(ctx context.Context, groupID, userID string) error
RemoveMember(ctx context.Context, groupID, userID string) error
// Permission Management
ListRoles(ctx context.Context, groupID string) ([]domain.GroupRole, error)
AssignRoleToTenant(ctx context.Context, groupID, tenantID, relation string) error
RemoveRoleFromTenant(ctx context.Context, groupID, tenantID, relation string) error
}
type userGroupService struct {
repo repository.UserGroupRepository
userRepo repository.UserRepository
tenantRepo repository.TenantRepository
ketoService KetoService
outboxRepo repository.KetoOutboxRepository
kratos KratosAdminService
worksmobile WorksmobileSyncer
}
func NewUserGroupService(
repo repository.UserGroupRepository,
userRepo repository.UserRepository,
tenantRepo repository.TenantRepository,
keto KetoService,
outbox repository.KetoOutboxRepository,
kratos KratosAdminService,
) UserGroupService {
return &userGroupService{
repo: repo,
userRepo: userRepo,
tenantRepo: tenantRepo,
ketoService: keto,
outboxRepo: outbox,
kratos: kratos,
}
}
func (s *userGroupService) SetWorksmobileSyncer(syncer WorksmobileSyncer) {
s.worksmobile = syncer
}
func (s *userGroupService) Create(ctx context.Context, tenantID string, parentID *string, name, description, unitType string) (*domain.UserGroup, error) {
// For Keto and Tenant hierarchy, if no parent group, the company tenant is the parent.
actualParentID := parentID
if actualParentID == nil || *actualParentID == "" {
actualParentID = &tenantID
}
// Validate parent tenant exists
if _, err := s.tenantRepo.FindByID(ctx, *actualParentID); err != nil {
return nil, fmt.Errorf("parent tenant not found or invalid: %w", err)
}
unitID := uuid.NewString()
// 1. Create Tenant (Type: ORGANIZATION)
groupTenant := &domain.Tenant{
ID: unitID,
Type: domain.TenantTypeOrganization,
ParentID: actualParentID,
Name: name,
Slug: fmt.Sprintf("ug-%s", unitID[:8]),
Description: description,
Status: domain.TenantStatusActive,
}
if err := s.tenantRepo.Create(ctx, groupTenant); err != nil {
slog.Error("Failed to create tenant record for user group", "error", err)
return nil, err
}
// 2. Create UserGroup metadata
// parent_id in user_groups refers to other groups, so use original parentID (which might be nil)
group := &domain.UserGroup{
ID: unitID,
TenantID: tenantID,
ParentID: parentID,
Name: name,
Description: description,
UnitType: unitType,
}
if err := s.repo.Create(ctx, group); err != nil {
// Rollback Tenant creation? Or handle via cleanup job. For now, just log.
slog.Error("Failed to create user group metadata after creating tenant", "tenantId", unitID, "error", err)
return nil, err
}
// 3. Keto Hierarchy via Outbox: Tenant:<child_id>#parents@Tenant:<parent_id>
if s.outboxRepo != nil {
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: unitID,
Relation: "parents",
Subject: "Tenant:" + *actualParentID,
Action: domain.KetoOutboxActionCreate,
})
}
return group, nil
}
func (s *userGroupService) Update(ctx context.Context, tenantID, groupID string, name, description, unitType string, parentID *string) (*domain.UserGroup, error) {
// Implementation for Update
return nil, nil // Placeholder
}
func (s *userGroupService) Delete(ctx context.Context, tenantID, groupID string) error {
// Implementation for Delete
return nil // Placeholder
}
func (s *userGroupService) populateMembers(ctx context.Context, group *domain.UserGroup) {
tuples, err := s.ketoService.ListRelations(ctx, "Tenant", group.ID, "members", "")
if err != nil {
slog.Error("Failed to fetch group members from keto", "error", err, "group_id", group.ID)
group.Members = []domain.User{}
return
}
var userIDs []string
for _, t := range tuples {
sid := t.SubjectID
if len(sid) > 5 && sid[:5] == "User:" {
userIDs = append(userIDs, sid[5:])
} else {
userIDs = append(userIDs, sid)
}
}
if len(userIDs) > 0 {
members, err := s.userRepo.FindByIDs(ctx, userIDs)
if err != nil {
slog.Error("Failed to fetch member details from db", "error", err)
}
memberMap := make(map[string]domain.User)
for _, m := range members {
memberMap[m.ID] = m
}
var finalMembers []domain.User
for _, uid := range userIDs {
if m, ok := memberMap[uid]; ok {
finalMembers = append(finalMembers, m)
} else if s.kratos != nil {
identity, err := s.kratos.GetIdentity(ctx, uid)
if err == nil && identity != nil {
name, _ := identity.Traits["name"].(string)
email, _ := identity.Traits["email"].(string)
finalMembers = append(finalMembers, domain.User{
ID: uid,
Name: name,
Email: email,
})
}
}
}
group.Members = finalMembers
} else {
group.Members = []domain.User{}
}
}
func (s *userGroupService) Get(ctx context.Context, id string) (*domain.UserGroup, error) {
group, err := s.repo.FindByID(ctx, id)
if err != nil {
return nil, err
}
s.populateMembers(ctx, group)
return group, nil
}
func (s *userGroupService) List(ctx context.Context, tenantID string) ([]domain.UserGroup, error) {
groups, err := s.repo.ListByTenantID(ctx, tenantID)
if err != nil {
return nil, err
}
if s.ketoService == nil {
return groups, nil
}
for i := range groups {
s.populateMembers(ctx, &groups[i])
}
return groups, nil
}
func (s *userGroupService) AddMember(ctx context.Context, groupID, userID string) error {
// Validate group exists
group, err := s.repo.FindByID(ctx, groupID)
if err != nil {
return fmt.Errorf("user group not found: %w", err)
}
var tenant *domain.Tenant
if s.tenantRepo != nil {
tenant, _ = s.tenantRepo.FindByID(ctx, group.TenantID)
}
// Kratos는 identity SSOT이고 조직/부서 정보의 원장이 아니므로 AddMember에서 traits를 수정하지 않습니다.
if s.userRepo != nil && tenant != nil {
localUser, err := s.userRepo.FindByID(ctx, userID)
if err != nil || localUser == nil {
if s.kratos != nil {
identity, identityErr := s.kratos.GetIdentity(ctx, userID)
if identityErr == nil && identity != nil {
localUser = mapUserGroupKratosIdentityToLocalUser(*identity)
} else {
slog.Warn("Skipping local user sync during AddMember because identity read is unavailable", "user", userID, "error", identityErr)
}
} else {
slog.Warn("Skipping local user sync during AddMember because identity projection is unavailable", "user", userID, "error", err)
}
}
if localUser != nil {
localUser.TenantID = &tenant.ID
localUser.Department = group.Name
if err := s.userRepo.Update(ctx, localUser); err != nil {
slog.Error("Failed to sync local user during AddMember", "user", userID, "error", err)
} else if s.worksmobile != nil {
if err := s.worksmobile.EnqueueUserUpsertIfInScope(ctx, *localUser); err != nil {
slog.Warn("Failed to enqueue Worksmobile user sync during AddMember", "user", userID, "error", err)
}
}
}
}
// Keto via Outbox: Tenant:<groupID>#members@User:<userID>
if s.outboxRepo != nil {
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: groupID,
Relation: "members",
Subject: "User:" + userID,
Action: domain.KetoOutboxActionCreate,
})
// Also add direct Tenant membership to Keto for member counting
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: group.TenantID,
Relation: "members",
Subject: "User:" + userID,
Action: domain.KetoOutboxActionCreate,
})
}
return nil
}
func mapUserGroupKratosIdentityToLocalUser(identity KratosIdentity) *domain.User {
traits := identity.Traits
now := time.Now()
createdAt := identity.CreatedAt
if createdAt.IsZero() {
createdAt = now
}
updatedAt := identity.UpdatedAt
if updatedAt.IsZero() {
updatedAt = now
}
role, ok := domain.NormalizeRoleAlias(userGroupTraitString(traits, "role"))
if !ok {
role, ok = domain.NormalizeRoleAlias(userGroupTraitString(traits, "grade"))
if !ok {
role = domain.RoleUser
}
}
grade := userGroupTraitString(traits, "grade")
if _, ok := domain.NormalizeRoleAlias(grade); ok {
grade = ""
}
user := &domain.User{
ID: identity.ID,
Email: userGroupTraitString(traits, "email"),
Name: userGroupTraitString(traits, "name"),
Phone: domain.NormalizePhoneNumber(userGroupTraitString(traits, "phone_number")),
Role: role,
Status: userGroupIdentityStatus(identity.State),
Department: userGroupTraitString(traits, "department"),
Grade: grade,
Position: userGroupTraitString(traits, "position"),
JobTitle: userGroupTraitString(traits, "jobTitle"),
AffiliationType: userGroupTraitString(traits, "affiliationType"),
CreatedAt: createdAt,
UpdatedAt: updatedAt,
Metadata: make(domain.JSONMap),
}
if tenantID := userGroupTraitString(traits, "tenant_id"); tenantID != "" {
user.TenantID = &tenantID
}
if relyingPartyID := userGroupTraitString(traits, "relying_party_id"); relyingPartyID != "" {
user.RelyingPartyID = &relyingPartyID
}
coreTraits := map[string]bool{
"email": true, "name": true, "phone_number": true,
"grade": true, "role": true, "companyCode": true, "company_code": true,
"companyCodes": true, "tenant_id": true, "department": true,
"position": true, "jobTitle": true, "affiliationType": true,
"relying_party_id": true, "custom_login_ids": true, "id": true,
}
for key, value := range traits {
if !coreTraits[key] {
user.Metadata[key] = value
}
}
return user
}
func userGroupTraitString(traits map[string]any, key string) string {
if traits == nil {
return ""
}
value, ok := traits[key]
if !ok || value == nil {
return ""
}
if str, ok := value.(string); ok {
return str
}
return fmt.Sprint(value)
}
func userGroupTraitStringArray(traits map[string]any, key string) []string {
if traits == nil {
return nil
}
switch value := traits[key].(type) {
case []string:
return value
case []any:
items := make([]string, 0, len(value))
for _, item := range value {
if str, ok := item.(string); ok && str != "" {
items = append(items, str)
}
}
return items
default:
return nil
}
}
func userGroupIdentityStatus(state string) string {
return domain.NormalizeUserStatus(state)
}
func (s *userGroupService) RemoveMember(ctx context.Context, groupID, userID string) error {
// Validate group exists
if _, err := s.repo.FindByID(ctx, groupID); err != nil {
return fmt.Errorf("user group not found: %w", err)
}
// Keto via Outbox: Delete relation
if s.outboxRepo != nil {
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: groupID,
Relation: "members",
Subject: "User:" + userID,
Action: domain.KetoOutboxActionDelete,
})
}
return nil
}
func (s *userGroupService) ListRoles(ctx context.Context, groupID string) ([]domain.GroupRole, error) {
// Query: namespace=Tenant, subject=Tenant:groupID#members
subject := "Tenant:" + groupID + "#members"
tuples, err := s.ketoService.ListRelations(ctx, "Tenant", "", "", subject)
if err != nil {
slog.Error("Failed to fetch group roles from keto", "error", err, "group_id", groupID)
return nil, err
}
var roles []domain.GroupRole
tenantIDs := make([]string, 0, len(tuples))
for _, t := range tuples {
tenantIDs = append(tenantIDs, t.Object)
}
if len(tenantIDs) > 0 {
tenantList, err := s.tenantRepo.FindByIDs(ctx, tenantIDs)
if err != nil {
slog.Error("Failed to fetch tenant details for roles", "error", err)
}
tenantMap := make(map[string]string)
for _, t := range tenantList {
tenantMap[t.ID] = t.Name
}
for _, t := range tuples {
roles = append(roles, domain.GroupRole{
TenantID: t.Object,
TenantName: tenantMap[t.Object],
Relation: t.Relation,
})
}
}
return roles, nil
}
func (s *userGroupService) AssignRoleToTenant(ctx context.Context, groupID, tenantID, relation string) error {
// Validate group exists
if _, err := s.repo.FindByID(ctx, groupID); err != nil {
return fmt.Errorf("user group not found: %w", err)
}
// Keto via Outbox: Tenant:<tenantID>#<relation>@Tenant:<groupID>#members
if s.outboxRepo != nil {
subject := "Tenant:" + groupID + "#members"
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: tenantID,
Relation: relation,
Subject: subject,
Action: domain.KetoOutboxActionCreate,
})
}
return nil
}
func (s *userGroupService) RemoveRoleFromTenant(ctx context.Context, groupID, tenantID, relation string) error {
// Keto via Outbox: Delete relation
if s.outboxRepo != nil {
subject := "Tenant:" + groupID + "#members"
_ = s.outboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: tenantID,
Relation: relation,
Subject: subject,
Action: domain.KetoOutboxActionDelete,
})
}
return nil
}

View File

@@ -0,0 +1,103 @@
package service
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"gorm.io/gorm"
)
func TestUserGroupService_Create_InvalidParentID(t *testing.T) {
mockRepo := new(MockUserGroupRepository)
mockTenantRepo := new(MockTenantRepository)
mockKeto := new(MockKetoServiceShared)
mockOutbox := new(MockKetoOutboxRepositoryShared)
svc := NewUserGroupService(mockRepo, nil, mockTenantRepo, mockKeto, mockOutbox, nil)
tenantID := "company-1"
invalidParentID := "invalid-uuid"
name := "Invalid Parent Group"
description := ""
unitType := "Team"
// Mock: TenantRepo returns record not found for invalidParentID
mockTenantRepo.On("FindByID", mock.Anything, invalidParentID).Return(nil, gorm.ErrRecordNotFound).Once()
// No Create calls should happen on any repo if parent is invalid
mockRepo.AssertNotCalled(t, "Create")
mockTenantRepo.AssertNotCalled(t, "Create")
mockOutbox.AssertNotCalled(t, "Create")
group, err := svc.Create(context.Background(), tenantID, &invalidParentID, name, description, unitType)
assert.Error(t, err)
assert.Contains(t, err.Error(), "parent tenant not found or invalid")
assert.Nil(t, group)
mockTenantRepo.AssertExpectations(t)
}
func TestUserGroupService_AddMember_GroupNotFound(t *testing.T) {
mockOutbox := new(MockKetoOutboxRepositoryShared)
mockUserGroupRepo := new(MockUserGroupRepository)
svc := NewUserGroupService(mockUserGroupRepo, nil, nil, nil, mockOutbox, nil)
groupID := "non-existent-group"
userID := "user-1"
// Mock: Group does not exist
mockUserGroupRepo.On("FindByID", mock.Anything, groupID).Return(nil, gorm.ErrRecordNotFound)
// No Outbox call should happen if group is not found
mockOutbox.AssertNotCalled(t, "Create")
err := svc.AddMember(context.Background(), groupID, userID)
assert.Error(t, err)
assert.Contains(t, err.Error(), "user group not found")
mockUserGroupRepo.AssertExpectations(t)
}
func TestUserGroupService_RemoveMember_GroupNotFound(t *testing.T) {
mockOutbox := new(MockKetoOutboxRepositoryShared)
mockUserGroupRepo := new(MockUserGroupRepository)
svc := NewUserGroupService(mockUserGroupRepo, nil, nil, nil, mockOutbox, nil)
groupID := "non-existent-group"
userID := "user-1"
// Mock: Group does not exist
mockUserGroupRepo.On("FindByID", mock.Anything, groupID).Return(nil, gorm.ErrRecordNotFound)
// No Outbox call should happen if group is not found
mockOutbox.AssertNotCalled(t, "Create")
err := svc.RemoveMember(context.Background(), groupID, userID)
assert.Error(t, err)
assert.Contains(t, err.Error(), "user group not found")
mockUserGroupRepo.AssertExpectations(t)
}
func TestUserGroupService_AssignRoleToTenant_GroupNotFound(t *testing.T) {
mockOutbox := new(MockKetoOutboxRepositoryShared)
mockUserGroupRepo := new(MockUserGroupRepository)
svc := NewUserGroupService(mockUserGroupRepo, nil, nil, nil, mockOutbox, nil)
groupID := "non-existent-group"
tenantID := "tenant-alpha"
relation := "manage"
// Mock: Group does not exist
mockUserGroupRepo.On("FindByID", mock.Anything, groupID).Return(nil, gorm.ErrRecordNotFound)
// No Outbox call should happen if group is not found
mockOutbox.AssertNotCalled(t, "Create")
err := svc.AssignRoleToTenant(context.Background(), groupID, tenantID, relation)
assert.Error(t, err)
assert.Contains(t, err.Error(), "user group not found")
mockUserGroupRepo.AssertExpectations(t)
}

View File

@@ -0,0 +1,463 @@
package service
import (
"baron-sso-backend/internal/domain"
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"gorm.io/gorm"
)
// --- Mocks for Repositories ---
type MockUserGroupRepository struct {
mock.Mock
}
func (m *MockUserGroupRepository) Create(ctx context.Context, group *domain.UserGroup) error {
return m.Called(ctx, group).Error(0)
}
func (m *MockUserGroupRepository) Update(ctx context.Context, group *domain.UserGroup) error {
return m.Called(ctx, group).Error(0)
}
func (m *MockUserGroupRepository) Delete(ctx context.Context, id string) error {
return m.Called(ctx, id).Error(0)
}
func (m *MockUserGroupRepository) FindByID(ctx context.Context, id string) (*domain.UserGroup, error) {
args := m.Called(ctx, id)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.UserGroup), args.Error(1)
}
func (m *MockUserGroupRepository) ListByTenantID(ctx context.Context, tenantID string) ([]domain.UserGroup, error) {
args := m.Called(ctx, tenantID)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).([]domain.UserGroup), args.Error(1)
}
type MockUserRepository struct {
mock.Mock
updatedUsers []domain.User
}
func (m *MockUserRepository) Create(ctx context.Context, user *domain.User) error { return nil }
func (m *MockUserRepository) Update(ctx context.Context, user *domain.User) error {
copied := *user
m.updatedUsers = append(m.updatedUsers, copied)
return nil
}
func (m *MockUserRepository) Delete(ctx context.Context, id string) error {
return m.Called(ctx, id).Error(0)
}
func (m *MockUserRepository) FindByEmail(ctx context.Context, email string) (*domain.User, error) {
return nil, nil
}
func (m *MockUserRepository) FindByID(ctx context.Context, id string) (*domain.User, error) {
args := m.Called(ctx, id)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.User), args.Error(1)
}
func (m *MockUserRepository) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) {
args := m.Called(ctx, ids)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).([]domain.User), args.Error(1)
}
func (m *MockUserRepository) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
return nil, nil
}
func (m *MockUserRepository) List(ctx context.Context, offset, limit int, search string, tenantIDs []string, cursor string) ([]domain.User, int64, string, error) {
return nil, 0, "", nil
}
func (m *MockUserRepository) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
args := m.Called(tenantID)
return int64(args.Int(0)), args.Error(1)
}
func (m *MockUserRepository) FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error) {
args := m.Called(ctx, tenantIDs)
return args.Get(0).([]domain.User), args.Error(1)
}
func (m *MockUserRepository) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
args := m.Called(tenantIDs)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(map[string]int64), args.Error(1)
}
func (m *MockUserRepository) FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error) {
args := m.Called(ctx, codes)
return args.Get(0).([]domain.User), args.Error(1)
}
func (m *MockUserRepository) CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error) {
args := m.Called(ctx, codes)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(map[string]int64), args.Error(1)
}
func (m *MockUserRepository) UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error {
return nil
}
func (m *MockUserRepository) GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error) {
return nil, nil
}
func (m *MockUserRepository) IsLoginIDTaken(ctx context.Context, loginID string) (bool, error) {
return false, nil
}
func (m *MockUserRepository) FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error) {
return "", nil
}
func (m *MockUserRepository) DB() *gorm.DB {
return nil
}
type fakeUserGroupWorksmobileSyncer struct {
userUpserts []domain.User
}
func (f *fakeUserGroupWorksmobileSyncer) EnqueueTenantUpsertIfInScope(ctx context.Context, tenant domain.Tenant) error {
return nil
}
func (f *fakeUserGroupWorksmobileSyncer) EnqueueTenantDeleteIfInScope(ctx context.Context, tenant domain.Tenant) error {
return nil
}
func (f *fakeUserGroupWorksmobileSyncer) EnqueueUserUpsertIfInScope(ctx context.Context, user domain.User) error {
f.userUpserts = append(f.userUpserts, user)
return nil
}
func (f *fakeUserGroupWorksmobileSyncer) EnqueueUserDeleteIfInScope(ctx context.Context, user domain.User) error {
return nil
}
type MockKetoOutboxRepository struct {
mock.Mock
}
type MockTenantRepository struct {
mock.Mock
}
func (m *MockTenantRepository) Create(ctx context.Context, tenant *domain.Tenant) error {
return m.Called(ctx, tenant).Error(0)
}
func (m *MockTenantRepository) Update(ctx context.Context, tenant *domain.Tenant) error { return nil }
func (m *MockTenantRepository) FindByID(ctx context.Context, id string) (*domain.Tenant, error) {
args := m.Called(ctx, id)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.Tenant), args.Error(1)
}
func (m *MockTenantRepository) FindByIDs(ctx context.Context, ids []string) ([]domain.Tenant, error) {
args := m.Called(ctx, ids)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).([]domain.Tenant), args.Error(1)
}
func (m *MockTenantRepository) FindBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
return nil, nil
}
func (m *MockTenantRepository) FindByName(ctx context.Context, name string) (*domain.Tenant, error) {
return nil, nil
}
func (m *MockTenantRepository) FindByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
return nil, nil
}
func (m *MockTenantRepository) List(ctx context.Context, limit, offset int, parentID string, search string) ([]domain.Tenant, int64, error) {
return nil, 0, nil
}
func (m *MockTenantRepository) ListByType(ctx context.Context, tenantType string) ([]domain.Tenant, error) {
return nil, nil
}
func (m *MockTenantRepository) AddDomain(ctx context.Context, tenantID string, domainName string, verified bool) error {
return nil
}
func (m *MockTenantRepository) DeleteBulk(ctx context.Context, ids []string) error {
return nil
}
func TestUserGroupService_Create(t *testing.T) {
mockRepo := new(MockUserGroupRepository)
mockTenantRepo := new(MockTenantRepository)
mockKeto := new(MockKetoServiceShared)
mockOutbox := new(MockKetoOutboxRepositoryShared)
svc := NewUserGroupService(mockRepo, nil, mockTenantRepo, mockKeto, mockOutbox, nil)
tenantID := "company-1"
parentID := "parent-group-id"
name := "Test Group"
description := "Group Description"
unitType := "Team"
// Mock Tenant FindByID for parent check
mockTenantRepo.On("FindByID", mock.Anything, parentID).Return(&domain.Tenant{ID: parentID}, nil)
// Mock Tenant creation (Polymorphic)
mockTenantRepo.On("Create", mock.Anything, mock.MatchedBy(func(ten *domain.Tenant) bool {
return ten.Type == domain.TenantTypeOrganization && ten.Name == name && *ten.ParentID == parentID
})).Return(nil)
// Mock UserGroup creation
mockRepo.On("Create", mock.Anything, mock.MatchedBy(func(g *domain.UserGroup) bool {
return g.Name == name && *g.ParentID == parentID && g.TenantID == tenantID
})).Return(nil)
// Mock Keto sync via Outbox
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
return e.Namespace == "Tenant" && e.Relation == "parents" && e.Subject == "Tenant:"+parentID
})).Return(nil)
group, err := svc.Create(context.Background(), tenantID, &parentID, name, description, unitType)
assert.NoError(t, err)
assert.NotNil(t, group)
mockTenantRepo.AssertExpectations(t)
mockRepo.AssertExpectations(t)
mockOutbox.AssertExpectations(t)
}
func TestUserGroupService_AddMember(t *testing.T) {
mockOutbox := new(MockKetoOutboxRepositoryShared)
mockUserGroupRepo := new(MockUserGroupRepository)
mockUserRepo := new(MockUserRepository)
mockTenantRepo := new(MockTenantRepository)
mockKratos := new(MockKratosAdminServiceShared)
svc := NewUserGroupService(mockUserGroupRepo, mockUserRepo, mockTenantRepo, nil, mockOutbox, mockKratos)
groupID := "group-1"
userID := "user-1"
tenantID := "tenant-1"
tenantSlug := "tenant-slug"
mockUserGroupRepo.On("FindByID", mock.Anything, groupID).Return(&domain.UserGroup{ID: groupID, TenantID: tenantID, Name: "Sales"}, nil)
mockUserRepo.On("FindByID", mock.Anything, userID).Return(&domain.User{ID: userID}, nil)
mockTenantRepo.On("FindByID", mock.Anything, tenantID).Return(&domain.Tenant{ID: tenantID, Slug: tenantSlug}, nil)
// Mock local user repo update (Ignored since Update is hardcoded to return nil without calling m.Called)
// mockUserRepo.On("Update", mock.Anything, mock.MatchedBy(func(u *domain.User) bool {
// return u.CompanyCode == tenantSlug && *u.TenantID == tenantID && u.Department == "Sales"
// })).Return(nil)
// First Outbox Create for Group
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
return e.Namespace == "Tenant" && e.Object == groupID && e.Relation == "members" && e.Subject == "User:"+userID
})).Return(nil).Once()
// Second Outbox Create for Tenant
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
return e.Namespace == "Tenant" && e.Object == tenantID && e.Relation == "members" && e.Subject == "User:"+userID
})).Return(nil).Once()
err := svc.AddMember(context.Background(), groupID, userID)
assert.NoError(t, err)
mockOutbox.AssertExpectations(t)
mockKratos.AssertExpectations(t)
mockKratos.AssertNotCalled(t, "GetIdentity", mock.Anything, userID)
mockKratos.AssertNotCalled(t, "UpdateIdentity", mock.Anything, userID, mock.Anything, mock.Anything)
// mockUserRepo.AssertExpectations(t)
}
func TestUserGroupService_AddMemberUpsertsLocalReadModelWhenMissing(t *testing.T) {
mockOutbox := new(MockKetoOutboxRepositoryShared)
mockUserGroupRepo := new(MockUserGroupRepository)
mockUserRepo := new(MockUserRepository)
mockTenantRepo := new(MockTenantRepository)
mockKratos := new(MockKratosAdminServiceShared)
svc := NewUserGroupService(mockUserGroupRepo, mockUserRepo, mockTenantRepo, nil, mockOutbox, mockKratos)
groupID := "group-1"
userID := "user-1"
tenantID := "tenant-1"
tenantSlug := "tenant-slug"
mockUserGroupRepo.On("FindByID", mock.Anything, groupID).Return(&domain.UserGroup{ID: groupID, TenantID: tenantID, Name: "Sales"}, nil)
mockUserRepo.On("FindByID", mock.Anything, userID).Return(nil, gorm.ErrRecordNotFound)
mockTenantRepo.On("FindByID", mock.Anything, tenantID).Return(&domain.Tenant{ID: tenantID, Slug: tenantSlug}, nil)
mockKratos.On("GetIdentity", mock.Anything, userID).Return(&KratosIdentity{
ID: userID,
Traits: map[string]any{
"email": "user@test.com",
"name": "User Test",
},
State: "active",
}, nil)
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
return e.Namespace == "Tenant" && e.Object == groupID && e.Relation == "members" && e.Subject == "User:"+userID
})).Return(nil).Once()
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
return e.Namespace == "Tenant" && e.Object == tenantID && e.Relation == "members" && e.Subject == "User:"+userID
})).Return(nil).Once()
err := svc.AddMember(context.Background(), groupID, userID)
assert.NoError(t, err)
assert.Len(t, mockUserRepo.updatedUsers, 1)
assert.Equal(t, userID, mockUserRepo.updatedUsers[0].ID)
assert.Empty(t, mockUserRepo.updatedUsers[0].CompanyCode)
assert.NotNil(t, mockUserRepo.updatedUsers[0].TenantID)
assert.Equal(t, tenantID, *mockUserRepo.updatedUsers[0].TenantID)
assert.Equal(t, "Sales", mockUserRepo.updatedUsers[0].Department)
mockOutbox.AssertExpectations(t)
mockKratos.AssertExpectations(t)
mockKratos.AssertNotCalled(t, "UpdateIdentity", mock.Anything, userID, mock.Anything, mock.Anything)
}
func TestUserGroupService_AddMemberEnqueuesWorksmobileUserSync(t *testing.T) {
mockOutbox := new(MockKetoOutboxRepositoryShared)
mockUserGroupRepo := new(MockUserGroupRepository)
mockUserRepo := new(MockUserRepository)
mockTenantRepo := new(MockTenantRepository)
mockKratos := new(MockKratosAdminServiceShared)
worksmobile := &fakeUserGroupWorksmobileSyncer{}
svc := NewUserGroupService(mockUserGroupRepo, mockUserRepo, mockTenantRepo, nil, mockOutbox, mockKratos)
svc.SetWorksmobileSyncer(worksmobile)
groupID := "group-1"
userID := "user-1"
tenantID := "tenant-1"
mockUserGroupRepo.On("FindByID", mock.Anything, groupID).Return(&domain.UserGroup{ID: groupID, TenantID: tenantID, Name: "Sales"}, nil)
mockUserRepo.On("FindByID", mock.Anything, userID).Return(&domain.User{
ID: userID,
Email: "user@test.com",
Name: "User Test",
Status: "active",
}, nil)
mockTenantRepo.On("FindByID", mock.Anything, tenantID).Return(&domain.Tenant{ID: tenantID, Slug: "tenant-slug"}, nil)
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
return e.Namespace == "Tenant" && e.Object == groupID && e.Relation == "members" && e.Subject == "User:"+userID
})).Return(nil).Once()
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
return e.Namespace == "Tenant" && e.Object == tenantID && e.Relation == "members" && e.Subject == "User:"+userID
})).Return(nil).Once()
err := svc.AddMember(context.Background(), groupID, userID)
assert.NoError(t, err)
assert.Len(t, worksmobile.userUpserts, 1)
assert.Equal(t, userID, worksmobile.userUpserts[0].ID)
assert.NotNil(t, worksmobile.userUpserts[0].TenantID)
assert.Equal(t, tenantID, *worksmobile.userUpserts[0].TenantID)
assert.Equal(t, "Sales", worksmobile.userUpserts[0].Department)
mockOutbox.AssertExpectations(t)
mockKratos.AssertExpectations(t)
mockKratos.AssertNotCalled(t, "GetIdentity", mock.Anything, userID)
mockKratos.AssertNotCalled(t, "UpdateIdentity", mock.Anything, userID, mock.Anything, mock.Anything)
}
func TestUserGroupService_AssignRoleToTenant(t *testing.T) {
mockOutbox := new(MockKetoOutboxRepositoryShared)
mockUserGroupRepo := new(MockUserGroupRepository)
svc := NewUserGroupService(mockUserGroupRepo, nil, nil, nil, mockOutbox, nil)
groupID := "group-1"
tenantID := "tenant-alpha"
relation := "manage"
mockUserGroupRepo.On("FindByID", mock.Anything, groupID).Return(&domain.UserGroup{ID: groupID}, nil)
expectedSubject := "Tenant:" + groupID + "#members"
mockOutbox.On("Create", mock.Anything, mock.MatchedBy(func(e *domain.KetoOutbox) bool {
return e.Namespace == "Tenant" && e.Object == tenantID && e.Relation == relation && e.Subject == expectedSubject
})).Return(nil)
err := svc.AssignRoleToTenant(context.Background(), groupID, tenantID, relation)
assert.NoError(t, err)
mockOutbox.AssertExpectations(t)
}
func TestUserGroupService_ListRoles(t *testing.T) {
mockKeto := new(MockKetoServiceShared)
mockTenantRepo := new(MockTenantRepository)
mockUserGroupRepo := new(MockUserGroupRepository)
svc := NewUserGroupService(mockUserGroupRepo, nil, mockTenantRepo, mockKeto, nil, nil)
groupID := "group-1"
subject := "Tenant:" + groupID + "#members"
mockUserGroupRepo.On("FindByID", mock.Anything, groupID).Return(&domain.UserGroup{ID: groupID}, nil)
tuples := []RelationTuple{
{Object: "t1", Relation: "manage", SubjectID: subject},
{Object: "t2", Relation: "view", SubjectID: subject},
}
mockKeto.On("ListRelations", mock.Anything, "Tenant", "", "", subject).Return(tuples, nil)
tenants := []domain.Tenant{
{ID: "t1", Name: "Tenant One"},
{ID: "t2", Name: "Tenant Two"},
}
mockTenantRepo.On("FindByIDs", mock.Anything, []string{"t1", "t2"}).Return(tenants, nil)
roles, err := svc.ListRoles(context.Background(), groupID)
assert.NoError(t, err)
assert.Len(t, roles, 2)
}
func TestUserGroupService_Get_WithKratosFallback(t *testing.T) {
mockRepo := new(MockUserGroupRepository)
mockKeto := new(MockKetoServiceShared)
mockUserRepo := new(MockUserRepository)
mockKratos := new(MockKratosAdminServiceShared)
svc := NewUserGroupService(mockRepo, mockUserRepo, nil, mockKeto, nil, mockKratos)
groupID := "group-1"
mockRepo.On("FindByID", mock.Anything, groupID).Return(&domain.UserGroup{ID: groupID, Name: "Test"}, nil)
tuples := []RelationTuple{
{Object: groupID, Relation: "members", SubjectID: "User:u1"},
}
mockKeto.On("ListRelations", mock.Anything, "Tenant", groupID, "members", "").Return(tuples, nil)
mockUserRepo.On("FindByIDs", mock.Anything, []string{"u1"}).Return([]domain.User{}, nil)
mockKratos.On("GetIdentity", mock.Anything, "u1").Return(&KratosIdentity{
ID: "u1",
Traits: map[string]any{"name": "User One", "email": "user1@example.com"},
}, nil)
group, err := svc.Get(context.Background(), groupID)
assert.NoError(t, err)
assert.NotNil(t, group)
assert.Len(t, group.Members, 1)
assert.Equal(t, "User One", group.Members[0].Name)
}

View File

@@ -0,0 +1,153 @@
package service
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/repository"
"context"
"fmt"
"strings"
"time"
)
type UserProjectionSyncService struct {
kratos KratosAdminService
repo repository.UserProjectionRepository
}
type UserProjectionReconciler interface {
Reconcile(ctx context.Context) (int, error)
}
func NewUserProjectionSyncService(kratos KratosAdminService, repo repository.UserProjectionRepository) *UserProjectionSyncService {
return &UserProjectionSyncService{
kratos: kratos,
repo: repo,
}
}
func (s *UserProjectionSyncService) Reconcile(ctx context.Context) (int, error) {
if s == nil || s.kratos == nil || s.repo == nil {
return 0, fmt.Errorf("user projection sync dependencies are not configured")
}
identities, err := s.kratos.ListIdentities(ctx)
if err != nil {
_ = s.repo.MarkFailed(ctx, err)
return 0, err
}
users := make([]domain.User, 0, len(identities))
for _, identity := range identities {
users = append(users, MapKratosIdentityToLocalUser(identity))
}
if err := s.repo.ReplaceAllFromKratos(ctx, users); err != nil {
_ = s.repo.MarkFailed(ctx, err)
return 0, err
}
return len(users), nil
}
func MapKratosIdentityToLocalUser(identity KratosIdentity) domain.User {
traits := identity.Traits
now := time.Now()
createdAt := identity.CreatedAt
if createdAt.IsZero() {
createdAt = now
}
updatedAt := identity.UpdatedAt
if updatedAt.IsZero() {
updatedAt = now
}
role, ok := domain.NormalizeRoleAlias(kratosProjectionTraitString(traits, "role"))
if !ok {
role, ok = domain.NormalizeRoleAlias(kratosProjectionTraitString(traits, "grade"))
if !ok {
role = domain.RoleUser
}
}
grade := kratosProjectionTraitString(traits, "grade")
if _, ok := domain.NormalizeRoleAlias(grade); ok {
grade = ""
}
user := domain.User{
ID: identity.ID,
Email: kratosProjectionTraitString(traits, "email"),
Name: kratosProjectionTraitString(traits, "name"),
Phone: domain.NormalizePhoneNumber(kratosProjectionTraitString(traits, "phone_number")),
Role: role,
Status: normalizeProjectionStatus(identity.State),
Department: kratosProjectionTraitString(traits, "department"),
Grade: grade,
Position: kratosProjectionTraitString(traits, "position"),
JobTitle: kratosProjectionTraitString(traits, "jobTitle"),
AffiliationType: kratosProjectionTraitString(traits, "affiliationType"),
CreatedAt: createdAt,
UpdatedAt: updatedAt,
Metadata: make(domain.JSONMap),
}
if tenantID := kratosProjectionTraitString(traits, "tenant_id"); tenantID != "" {
user.TenantID = &tenantID
}
if relyingPartyID := kratosProjectionTraitString(traits, "relying_party_id"); relyingPartyID != "" {
user.RelyingPartyID = &relyingPartyID
}
coreTraits := map[string]bool{
"email": true, "name": true, "phone_number": true,
"grade": true, "role": true,
"companyCode": true, "company_code": true, "companyCodes": true,
"tenant_id": true, "department": true,
"position": true, "jobTitle": true, "affiliationType": true,
"relying_party_id": true, "custom_login_ids": true, "id": true,
}
for key, value := range traits {
if !coreTraits[key] {
user.Metadata[key] = value
}
}
return user
}
func kratosProjectionTraitString(traits map[string]any, key string) string {
if traits == nil {
return ""
}
value, ok := traits[key]
if !ok || value == nil {
return ""
}
if str, ok := value.(string); ok {
return str
}
return fmt.Sprint(value)
}
func kratosProjectionTraitStringArray(traits map[string]any, key string) []string {
if traits == nil {
return nil
}
switch value := traits[key].(type) {
case []string:
return value
case []any:
items := make([]string, 0, len(value))
for _, item := range value {
if str, ok := item.(string); ok && strings.TrimSpace(str) != "" {
items = append(items, str)
}
}
return items
default:
return nil
}
}
func normalizeProjectionStatus(state string) string {
normalized := domain.NormalizeUserStatus(state)
if normalized == "" {
return domain.UserStatusActive
}
return normalized
}

View File

@@ -0,0 +1,142 @@
package service
import (
"baron-sso-backend/internal/domain"
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type fakeUserProjectionRepo struct {
replacedUsers []domain.User
failedErr error
replaceErr error
}
func (f *fakeUserProjectionRepo) IsReady(ctx context.Context) (bool, error) {
return false, nil
}
func (f *fakeUserProjectionRepo) GetStatus(ctx context.Context) (domain.UserProjectionStatus, error) {
return domain.UserProjectionStatus{}, nil
}
func (f *fakeUserProjectionRepo) CountTenantMembers(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
return nil, nil
}
func (f *fakeUserProjectionRepo) CountTenantMembersRecursive(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
return nil, nil
}
func (f *fakeUserProjectionRepo) ReplaceAllFromKratos(ctx context.Context, users []domain.User) error {
f.replacedUsers = append([]domain.User(nil), users...)
return f.replaceErr
}
func (f *fakeUserProjectionRepo) MarkFailed(ctx context.Context, syncErr error) error {
f.failedErr = syncErr
return nil
}
func TestUserProjectionSyncService_ReconcileReplacesProjectionFromKratos(t *testing.T) {
ctx := context.Background()
kratos := new(MockKratosAdminServiceShared)
repo := &fakeUserProjectionRepo{}
svc := NewUserProjectionSyncService(kratos, repo)
tenantID := "00000000-0000-0000-0000-000000000001"
kratos.On("ListIdentities", ctx).Return([]KratosIdentity{
{
ID: "00000000-0000-0000-0000-000000000101",
Traits: map[string]any{
"email": "one@example.com",
"name": "One",
"phone_number": "+821012345678",
"companyCode": "saman",
"companyCodes": []any{"saman", "group-a"},
"tenant_id": tenantID,
"department": "DX",
"customAttr": "kept",
},
State: "active",
},
}, nil).Once()
count, err := svc.Reconcile(ctx)
require.NoError(t, err)
assert.Equal(t, 1, count)
require.Len(t, repo.replacedUsers, 1)
assert.Equal(t, "one@example.com", repo.replacedUsers[0].Email)
assert.Equal(t, "One", repo.replacedUsers[0].Name)
assert.Equal(t, "+821012345678", repo.replacedUsers[0].Phone)
assert.Empty(t, repo.replacedUsers[0].CompanyCode)
assert.Empty(t, repo.replacedUsers[0].CompanyCodes)
require.NotNil(t, repo.replacedUsers[0].TenantID)
assert.Equal(t, tenantID, *repo.replacedUsers[0].TenantID)
assert.Equal(t, "kept", repo.replacedUsers[0].Metadata["customAttr"])
assert.NoError(t, repo.failedErr)
kratos.AssertExpectations(t)
}
func TestUserProjectionSyncService_ReconcileDeduplicatesKoreanCountryCodePhone(t *testing.T) {
ctx := context.Background()
kratos := new(MockKratosAdminServiceShared)
repo := &fakeUserProjectionRepo{}
svc := NewUserProjectionSyncService(kratos, repo)
kratos.On("ListIdentities", ctx).Return([]KratosIdentity{
{
ID: "00000000-0000-0000-0000-000000000102",
Traits: map[string]any{
"email": "two@example.com",
"name": "Two",
"phone_number": "+82 +821091917771",
},
State: "active",
},
}, nil).Once()
count, err := svc.Reconcile(ctx)
require.NoError(t, err)
assert.Equal(t, 1, count)
require.Len(t, repo.replacedUsers, 1)
assert.Equal(t, "+821091917771", repo.replacedUsers[0].Phone)
kratos.AssertExpectations(t)
}
func TestUserProjectionSyncService_ReconcileMarksFailedWhenKratosFails(t *testing.T) {
ctx := context.Background()
kratos := new(MockKratosAdminServiceShared)
repo := &fakeUserProjectionRepo{}
svc := NewUserProjectionSyncService(kratos, repo)
expectedErr := errors.New("kratos down")
kratos.On("ListIdentities", ctx).Return([]KratosIdentity{}, expectedErr).Once()
count, err := svc.Reconcile(ctx)
assert.Equal(t, 0, count)
assert.ErrorIs(t, err, expectedErr)
assert.ErrorIs(t, repo.failedErr, expectedErr)
assert.Empty(t, repo.replacedUsers)
kratos.AssertExpectations(t)
}
func TestMapKratosIdentityToLocalUserPreservesArchivedStatus(t *testing.T) {
user := MapKratosIdentityToLocalUser(KratosIdentity{
ID: "00000000-0000-0000-0000-000000000201",
State: domain.UserStatusArchived,
Traits: map[string]any{
"email": "archived@example.com",
"name": "Archived User",
},
})
assert.Equal(t, domain.UserStatusArchived, user.Status)
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,972 @@
package service
import (
"baron-sso-backend/internal/domain"
"crypto/rand"
"encoding/json"
"errors"
"fmt"
"math/big"
"net/mail"
"os"
"sort"
"strconv"
"strings"
)
const (
WorksmobileUserActionUpsert = "UPSERT"
WorksmobileUserActionSuspend = "SUSPEND"
)
type WorksmobileOrgUnitPayload struct {
DomainID int64 `json:"domainId"`
OrgUnitName string `json:"orgUnitName"`
Email string `json:"email,omitempty"`
OrgUnitExternalKey string `json:"orgUnitExternalKey"`
ParentOrgUnitID string `json:"parentOrgUnitId,omitempty"`
DisplayOrder int `json:"displayOrder"`
}
type WorksmobileUserPayload struct {
DomainID int64 `json:"domainId"`
Email string `json:"email"`
UserExternalKey string `json:"userExternalKey,omitempty"`
UserName WorksmobileUserName `json:"userName"`
CellPhone string `json:"cellPhone,omitempty"`
EmployeeNumber string `json:"employeeNumber,omitempty"`
PrivateEmail string `json:"privateEmail,omitempty"`
AliasEmails []string `json:"aliasEmails,omitempty"`
Locale string `json:"locale,omitempty"`
PasswordConfig WorksmobilePasswordConfig `json:"passwordConfig,omitempty"`
Task string `json:"task,omitempty"`
Organizations []WorksmobileUserOrganization `json:"organizations,omitempty"`
}
type WorksmobileUserName struct {
LastName string `json:"lastName,omitempty"`
}
type WorksmobilePasswordConfig struct {
PasswordCreationType string `json:"passwordCreationType"`
Password string `json:"password"`
ChangePasswordAtNextLogin *bool `json:"changePasswordAtNextLogin,omitempty"`
}
func (c WorksmobilePasswordConfig) IsZero() bool {
return strings.TrimSpace(c.PasswordCreationType) == "" &&
strings.TrimSpace(c.Password) == "" &&
c.ChangePasswordAtNextLogin == nil
}
func (p WorksmobileUserPayload) MarshalJSON() ([]byte, error) {
type payloadJSON struct {
DomainID int64 `json:"domainId"`
Email string `json:"email"`
UserExternalKey string `json:"userExternalKey,omitempty"`
UserName WorksmobileUserName `json:"userName"`
CellPhone string `json:"cellPhone,omitempty"`
EmployeeNumber string `json:"employeeNumber,omitempty"`
PrivateEmail string `json:"privateEmail,omitempty"`
AliasEmails []string `json:"aliasEmails,omitempty"`
Locale string `json:"locale,omitempty"`
PasswordConfig *WorksmobilePasswordConfig `json:"passwordConfig,omitempty"`
Task string `json:"task,omitempty"`
Organizations []WorksmobileUserOrganization `json:"organizations,omitempty"`
}
var passwordConfig *WorksmobilePasswordConfig
if !p.PasswordConfig.IsZero() {
passwordConfig = &p.PasswordConfig
}
return json.Marshal(payloadJSON{
DomainID: p.DomainID,
Email: p.Email,
UserExternalKey: p.UserExternalKey,
UserName: p.UserName,
CellPhone: p.CellPhone,
EmployeeNumber: p.EmployeeNumber,
PrivateEmail: p.PrivateEmail,
AliasEmails: p.AliasEmails,
Locale: p.Locale,
PasswordConfig: passwordConfig,
Task: p.Task,
Organizations: p.Organizations,
})
}
type WorksmobilePasswordResetPayload struct {
Email string `json:"email"`
PasswordConfig WorksmobilePasswordConfig `json:"passwordConfig"`
}
type WorksmobileUserOrganization struct {
DomainID int64 `json:"domainId,omitempty"`
Email string `json:"email,omitempty"`
Primary bool `json:"primary"`
OrgUnits []WorksmobileUserOrgUnit `json:"orgUnits"`
}
type WorksmobileUserOrgUnit struct {
OrgUnitID string `json:"orgUnitId"`
Primary bool `json:"primary"`
PositionID string `json:"positionId,omitempty"`
IsManager *bool `json:"isManager,omitempty"`
}
func BuildWorksmobileOrgUnitPayload(tenant domain.Tenant, rootConfig domain.JSONMap, displayOrder int) (WorksmobileOrgUnitPayload, error) {
return BuildWorksmobileOrgUnitPayloadForDomainTenant(tenant, tenant, rootConfig, displayOrder)
}
func BuildWorksmobileOrgUnitPayloadForDomainTenant(tenant domain.Tenant, domainTenant domain.Tenant, rootConfig domain.JSONMap, displayOrder int) (WorksmobileOrgUnitPayload, error) {
if err := ValidateWorksmobileExternalKey(tenant.ID); err != nil {
return WorksmobileOrgUnitPayload{}, err
}
if displayOrder < 1 {
displayOrder = 1
}
domainID, err := ResolveWorksmobileDomainIDFromTenant(domainTenant, rootConfig)
if err != nil {
return WorksmobileOrgUnitPayload{}, err
}
payload := WorksmobileOrgUnitPayload{
DomainID: domainID,
OrgUnitName: strings.TrimSpace(tenant.Name),
Email: buildWorksmobileOrgUnitEmail(tenant, domainTenant),
OrgUnitExternalKey: tenant.ID,
DisplayOrder: displayOrder,
}
if tenant.ParentID != nil && *tenant.ParentID != "" {
if err := ValidateWorksmobileExternalKey(*tenant.ParentID); err != nil {
return WorksmobileOrgUnitPayload{}, err
}
payload.ParentOrgUnitID = "externalKey:" + *tenant.ParentID
}
return payload, nil
}
func buildWorksmobileOrgUnitEmail(tenant domain.Tenant, domainTenant domain.Tenant) string {
slug := strings.ToLower(strings.TrimSpace(tenant.Slug))
if slug == "" {
return ""
}
if domainName := worksmobileTenantMailDomain(domainTenant); domainName != "" {
return slug + "@" + domainName
}
for _, candidate := range append([]domain.TenantDomain{}, domainTenant.Domains...) {
domainName := strings.ToLower(strings.TrimSpace(candidate.Domain))
if domainName != "" {
return slug + "@" + domainName
}
}
for _, candidate := range tenant.Domains {
domainName := strings.ToLower(strings.TrimSpace(candidate.Domain))
if domainName != "" {
return slug + "@" + domainName
}
}
return ""
}
func worksmobileTenantMailDomain(tenant domain.Tenant) string {
envKey := strings.TrimSuffix(worksmobileTenantDomainIDEnvKey(tenant), "_DOMAIN_ID")
if domainName := strings.ToLower(strings.TrimSpace(os.Getenv("WORKS_DEFAULT_DOMAIN_" + envKey))); domainName != "" {
return domainName
}
if domainName := strings.ToLower(strings.TrimSpace(os.Getenv(envKey + "_MAIL_DOMAIN"))); domainName != "" {
return domainName
}
switch envKey {
case "SAMAN":
return "samaneng.com"
case "HANMAC":
return "hanmaceng.co.kr"
case "GPDTDC":
return "baroncs.co.kr"
case "HALLA":
return "hallasanup.com"
case "BARONGROUP":
return "brsw.kr"
default:
return ""
}
}
func BuildWorksmobileUserPayload(user domain.User, tenant domain.Tenant, rootConfig domain.JSONMap) (WorksmobileUserPayload, error) {
return BuildWorksmobileUserPayloadForDomainTenant(user, tenant, tenant, rootConfig)
}
func BuildWorksmobileUserPayloadForDomainTenant(user domain.User, tenant domain.Tenant, _ domain.Tenant, rootConfig domain.JSONMap) (WorksmobileUserPayload, error) {
return BuildWorksmobileUserPayloadForDomainTenants(user, tenant, map[string]domain.Tenant{tenant.ID: tenant}, rootConfig)
}
func BuildWorksmobileUserPayloadForDomainTenants(user domain.User, tenant domain.Tenant, tenantByID map[string]domain.Tenant, rootConfig domain.JSONMap) (WorksmobileUserPayload, error) {
if err := ValidateWorksmobileExternalKey(user.ID); err != nil {
return WorksmobileUserPayload{}, err
}
if tenant.ID == "" {
return WorksmobileUserPayload{}, errors.New("tenant is required")
}
if tenantByID == nil {
tenantByID = map[string]domain.Tenant{}
}
tenantByID[tenant.ID] = tenant
domainID, err := ResolveWorksmobileAccountDomainIDFromEmail(user.Email, tenant, rootConfig)
if err != nil {
return WorksmobileUserPayload{}, err
}
employeeNumber := metadataEmployeeNumber(user.Metadata)
organizations, task, err := buildWorksmobileUserOrganizations(user, tenant, tenantByID, rootConfig)
if err != nil {
return WorksmobileUserPayload{}, err
}
if task == "" {
task = strings.TrimSpace(user.JobTitle)
}
payload := WorksmobileUserPayload{
DomainID: domainID,
Email: strings.TrimSpace(user.Email),
UserExternalKey: user.ID,
UserName: WorksmobileUserName{LastName: strings.TrimSpace(user.Name)},
CellPhone: domain.NormalizePhoneNumber(user.Phone),
EmployeeNumber: employeeNumber,
Locale: "ko_KR",
Task: task,
Organizations: organizations,
}
payload.AliasEmails = BuildWorksmobileAliasEmails(user, tenant)
return payload, nil
}
type worksmobileAppointment struct {
TenantID string
IsPrimary bool
IsManager bool
HasManager bool
JobTitle string
PositionID string
Source string
}
func buildWorksmobileUserOrganizations(user domain.User, tenant domain.Tenant, tenantByID map[string]domain.Tenant, rootConfig domain.JSONMap) ([]WorksmobileUserOrganization, string, error) {
appointments := worksmobileAppointmentsFromMetadata(user.Metadata)
if len(appointments) == 0 {
appointments = []worksmobileAppointment{{TenantID: tenant.ID, IsPrimary: true}}
} else if !worksmobileAppointmentsContainTenant(appointments, tenant.ID) && !worksmobileAppointmentsHavePrimary(appointments) {
appointments = append([]worksmobileAppointment{{
TenantID: tenant.ID,
IsPrimary: true,
JobTitle: strings.TrimSpace(user.JobTitle),
PositionID: metadataString(user.Metadata, "worksmobilePositionId", "positionId", "position_id"),
}}, appointments...)
}
accountDomainTenant := worksmobileAccountDomainTenantFromEmail(user.Email, tenant, tenantByID)
accountDomainEnvKey := worksmobileTenantDomainIDEnvKey(accountDomainTenant)
if !worksmobileAppointmentsContainDomain(appointments, tenantByID, accountDomainEnvKey) && accountDomainTenant.ID != "" {
appointments = append([]worksmobileAppointment{{
TenantID: accountDomainTenant.ID,
IsPrimary: true,
JobTitle: strings.TrimSpace(user.JobTitle),
PositionID: metadataString(user.Metadata, "worksmobilePositionId", "positionId", "position_id"),
}}, appointments...)
}
organizations := make([]WorksmobileUserOrganization, 0)
organizationIndexByDomainID := map[int64]int{}
seen := map[string]bool{}
task := ""
for _, appointment := range appointments {
if appointment.TenantID == "" || seen[appointment.TenantID] {
continue
}
appointmentTenant, ok := tenantByID[appointment.TenantID]
if !ok {
continue
}
if worksmobileShouldSkipEmailDomainRootAppointment(appointment, appointmentTenant, appointments, tenantByID) {
seen[appointment.TenantID] = true
continue
}
if isWorksmobileDomainRootTenant(appointmentTenant) {
if appointment.IsPrimary && strings.TrimSpace(appointment.JobTitle) != "" && task == "" {
task = strings.TrimSpace(appointment.JobTitle)
}
seen[appointment.TenantID] = true
continue
}
if err := ValidateWorksmobileExternalKey(appointmentTenant.ID); err != nil {
return nil, "", err
}
domainTenant := worksmobileDomainClassificationTenant(appointmentTenant, tenantByID)
domainID, err := ResolveWorksmobileDomainIDFromTenant(domainTenant, rootConfig)
if err != nil {
return nil, "", err
}
isAccountDomain := worksmobileTenantDomainIDEnvKey(domainTenant) == accountDomainEnvKey
isPrimaryOrganization := isAccountDomain && !worksmobileOrganizationsHavePrimary(organizations)
organizationIndex, organizationExists := organizationIndexByDomainID[domainID]
orgUnit := WorksmobileUserOrgUnit{
OrgUnitID: "externalKey:" + appointmentTenant.ID,
Primary: !organizationExists,
PositionID: appointment.PositionID,
}
if appointment.HasManager {
isManager := appointment.IsManager
orgUnit.IsManager = &isManager
}
if organizationExists {
if isPrimaryOrganization {
organizations[organizationIndex].Primary = true
organizations[organizationIndex].Email = worksmobileOrganizationEmail(user, domainTenant)
}
organizations[organizationIndex].OrgUnits = append(organizations[organizationIndex].OrgUnits, orgUnit)
} else {
organizationIndexByDomainID[domainID] = len(organizations)
organizations = append(organizations, WorksmobileUserOrganization{
DomainID: domainID,
Email: worksmobileOrganizationEmail(user, domainTenant),
Primary: isPrimaryOrganization,
OrgUnits: []WorksmobileUserOrgUnit{orgUnit},
})
}
if isPrimaryOrganization && strings.TrimSpace(appointment.JobTitle) != "" {
task = strings.TrimSpace(appointment.JobTitle)
}
seen[appointment.TenantID] = true
}
if len(organizations) == 0 {
return nil, task, nil
}
if !worksmobileOrganizationsHavePrimary(organizations) {
organizations[0].Primary = true
if len(organizations[0].OrgUnits) > 0 {
organizations[0].OrgUnits[0].Primary = true
}
}
sortWorksmobileOrganizations(organizations)
return organizations, task, nil
}
func worksmobileAppointmentsContainTenant(appointments []worksmobileAppointment, tenantID string) bool {
tenantID = strings.TrimSpace(tenantID)
if tenantID == "" {
return false
}
for _, appointment := range appointments {
if strings.TrimSpace(appointment.TenantID) == tenantID {
return true
}
}
return false
}
func worksmobileAppointmentsHavePrimary(appointments []worksmobileAppointment) bool {
for _, appointment := range appointments {
if appointment.IsPrimary {
return true
}
}
return false
}
func worksmobileAppointmentsContainDomain(appointments []worksmobileAppointment, tenantByID map[string]domain.Tenant, envKey string) bool {
for _, appointment := range appointments {
tenant, ok := tenantByID[appointment.TenantID]
if !ok {
continue
}
domainTenant := worksmobileDomainClassificationTenant(tenant, tenantByID)
if worksmobileTenantDomainIDEnvKey(domainTenant) == envKey {
return true
}
}
return false
}
func worksmobileShouldSkipEmailDomainRootAppointment(appointment worksmobileAppointment, tenant domain.Tenant, appointments []worksmobileAppointment, tenantByID map[string]domain.Tenant) bool {
if strings.TrimSpace(appointment.Source) != "email_domain" || !isWorksmobileDomainRootTenant(tenant) {
return false
}
envKey := worksmobileTenantDomainIDEnvKey(tenant)
for _, candidate := range appointments {
if strings.TrimSpace(candidate.TenantID) == "" || strings.TrimSpace(candidate.TenantID) == tenant.ID {
continue
}
candidateTenant, ok := tenantByID[candidate.TenantID]
if !ok || isWorksmobileDomainRootTenant(candidateTenant) {
continue
}
if worksmobileTenantDomainIDEnvKey(worksmobileDomainClassificationTenant(candidateTenant, tenantByID)) == envKey {
return true
}
}
return false
}
func worksmobileOrganizationsHavePrimary(organizations []WorksmobileUserOrganization) bool {
for _, organization := range organizations {
if organization.Primary {
return true
}
}
return false
}
func worksmobileAppointmentsFromMetadata(metadata domain.JSONMap) []worksmobileAppointment {
rawAppointments, ok := metadata["additionalAppointments"].([]any)
if !ok {
return nil
}
appointments := make([]worksmobileAppointment, 0, len(rawAppointments))
for _, raw := range rawAppointments {
item, ok := raw.(map[string]any)
if !ok {
continue
}
appointment := worksmobileAppointment{
TenantID: metadataString(domain.JSONMap(item), "tenantId", "tenant_id"),
IsPrimary: metadataBool(domain.JSONMap(item), "isPrimary", "primary"),
JobTitle: metadataString(domain.JSONMap(item), "jobTitle", "job_title", "task"),
PositionID: metadataString(domain.JSONMap(item), "worksmobilePositionId", "positionId", "position_id"),
Source: metadataString(domain.JSONMap(item), "assignmentSource", "source"),
}
if isManager, ok := metadataOptionalBool(domain.JSONMap(item), "isManager", "lead", "isLead"); ok {
appointment.IsManager = isManager
appointment.HasManager = true
}
appointments = append(appointments, appointment)
}
return appointments
}
func sortWorksmobileOrganizations(organizations []WorksmobileUserOrganization) {
sort.SliceStable(organizations, func(i, j int) bool {
if organizations[i].Primary != organizations[j].Primary {
return organizations[i].Primary
}
left := ""
right := ""
if len(organizations[i].OrgUnits) > 0 {
left = organizations[i].OrgUnits[0].OrgUnitID
}
if len(organizations[j].OrgUnits) > 0 {
right = organizations[j].OrgUnits[0].OrgUnitID
}
return left < right
})
}
func BuildWorksmobileAliasEmails(user domain.User, tenant domain.Tenant) []string {
candidates := make([]string, 0)
for _, key := range []string{
"aliasEmails",
"alias_emails",
"worksmobileAliasEmails",
"sub_email",
"secondary_email",
"secondary_emails",
"additional_email",
"additional_emails",
"naverworks_sub_email",
} {
candidates = append(candidates, metadataStringList(user.Metadata, key)...)
}
employeeNumber := metadataEmployeeNumber(user.Metadata)
if isHanmacWorksmobileTenant(tenant) && employeeNumber != "" {
candidates = append(candidates, employeeNumber+"@hanmaceng.co.kr")
}
return normalizeWorksmobileAliasEmails(user.Email, candidates)
}
func normalizeWorksmobileAliasEmails(primaryEmail string, candidates []string) []string {
result := make([]string, 0, len(candidates))
seen := map[string]bool{}
primary := strings.ToLower(strings.TrimSpace(primaryEmail))
for _, candidate := range candidates {
normalized := strings.ToLower(strings.TrimSpace(candidate))
if normalized == "" || normalized == primary || seen[normalized] {
continue
}
if _, err := mail.ParseAddress(normalized); err != nil {
continue
}
seen[normalized] = true
result = append(result, normalized)
}
return result
}
func ValidateWorksmobileAliasEmails(primaryEmail string, aliasEmails []string, existingEmails map[string]string) error {
seen := map[string]string{strings.ToLower(strings.TrimSpace(primaryEmail)): primaryEmail}
for _, aliasEmail := range aliasEmails {
normalized := strings.ToLower(strings.TrimSpace(aliasEmail))
if _, err := mail.ParseAddress(normalized); err != nil {
return err
}
if previous, ok := seen[normalized]; ok {
return fmt.Errorf("worksmobile alias email duplicates: %s and %s", previous, aliasEmail)
}
if owner, ok := existingEmails[normalized]; ok {
return fmt.Errorf("worksmobile alias email %s는 이미 사용 중입니다: %s", normalized, owner)
}
seen[normalized] = aliasEmail
}
return nil
}
func GenerateWorksmobileInitialPassword() string {
digits := "0123456789"
letters := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
symbols := "!@#$%"
all := digits + letters + symbols
password := []byte{
randomChar(digits),
randomChar(letters),
randomChar(symbols),
}
for len(password) < 16 {
password = append(password, randomChar(all))
}
shuffleBytes(password)
return string(password)
}
func randomChar(chars string) byte {
if chars == "" {
return 'x'
}
index, err := rand.Int(rand.Reader, big.NewInt(int64(len(chars))))
if err != nil {
return chars[0]
}
return chars[index.Int64()]
}
func shuffleBytes(values []byte) {
for i := len(values) - 1; i > 0; i-- {
j, err := rand.Int(rand.Reader, big.NewInt(int64(i+1)))
if err != nil {
continue
}
values[i], values[j.Int64()] = values[j.Int64()], values[i]
}
}
func WorksmobileUserStatusAction(status string) string {
normalized := domain.NormalizeUserStatus(status)
if domain.IsWorksDeprovisionUserStatus(normalized) {
return domain.WorksmobileActionDelete
}
switch normalized {
case domain.UserStatusSuspended:
return WorksmobileUserActionSuspend
default:
return WorksmobileUserActionUpsert
}
}
func ValidateWorksmobileExternalKey(value string) error {
value = strings.TrimSpace(value)
if value == "" {
return errors.New("external key is required")
}
if strings.ContainsAny(value, `%\#/?`) {
return fmt.Errorf("external key contains unsupported character: %s", value)
}
return nil
}
func ResolveWorksmobileDomainIDFromTenant(tenant domain.Tenant, _ domain.JSONMap) (int64, error) {
envKey := worksmobileTenantDomainIDEnvKey(tenant)
if domainID, ok := worksmobileDomainIDFromEnv(envKey); ok {
return domainID, nil
}
return 0, fmt.Errorf("worksmobile domain id env is missing for tenant: %s", envKey)
}
func ResolveWorksmobileAccountDomainIDFromEmail(email string, fallbackTenant domain.Tenant, rootConfig domain.JSONMap) (int64, error) {
switch worksmobileEmailDomainName(email) {
case "samaneng.com":
if domainID, ok := worksmobileDomainIDFromEnv("SAMAN_DOMAIN_ID"); ok {
return domainID, nil
}
case "hanmaceng.co.kr":
if domainID, ok := worksmobileDomainIDFromEnv("HANMAC_DOMAIN_ID"); ok {
return domainID, nil
}
case "baroncs.co.kr":
if domainID, ok := worksmobileDomainIDFromEnv("GPDTDC_DOMAIN_ID"); ok {
return domainID, nil
}
case "hallasanup.com":
if domainID, ok := worksmobileDomainIDFromEnv("HALLA_DOMAIN_ID"); ok {
return domainID, nil
}
case "brsw.kr":
if domainID, ok := worksmobileDomainIDFromEnv("BARONGROUP_DOMAIN_ID"); ok {
return domainID, nil
}
}
return ResolveWorksmobileDomainIDFromTenant(fallbackTenant, rootConfig)
}
func worksmobileAccountDomainTenantFromEmail(email string, fallbackTenant domain.Tenant, tenantByID map[string]domain.Tenant) domain.Tenant {
envKey := worksmobileDomainIDEnvKeyFromEmail(email)
for _, tenant := range tenantByID {
if isWorksmobileDomainRootTenant(tenant) && worksmobileTenantDomainIDEnvKey(tenant) == envKey {
return tenant
}
}
for _, tenant := range tenantByID {
if worksmobileTenantDomainIDEnvKey(tenant) == envKey {
return worksmobileDomainClassificationTenant(tenant, tenantByID)
}
}
return worksmobileDomainClassificationTenant(fallbackTenant, tenantByID)
}
func worksmobileDomainIDEnvKeyFromEmail(email string) string {
switch worksmobileEmailDomainName(email) {
case "samaneng.com":
return "SAMAN_DOMAIN_ID"
case "hanmaceng.co.kr":
return "HANMAC_DOMAIN_ID"
case "baroncs.co.kr":
return "GPDTDC_DOMAIN_ID"
case "hallasanup.com":
return "HALLA_DOMAIN_ID"
case "brsw.kr":
return "BARONGROUP_DOMAIN_ID"
default:
return worksmobileTenantDomainIDEnvKey(domain.Tenant{})
}
}
func worksmobileEmailDomainName(email string) string {
address, err := mail.ParseAddress(strings.TrimSpace(email))
if err != nil {
return ""
}
parts := strings.Split(address.Address, "@")
if len(parts) != 2 {
return ""
}
return strings.ToLower(strings.TrimSpace(parts[1]))
}
func worksmobileOrganizationEmail(user domain.User, domainTenant domain.Tenant) string {
domainName := worksmobileTenantMailDomain(domainTenant)
if domainName == "" {
return ""
}
primaryEmail := strings.ToLower(strings.TrimSpace(user.Email))
if worksmobileEmailDomainName(primaryEmail) == domainName {
return primaryEmail
}
for _, alias := range BuildWorksmobileAliasEmails(user, domainTenant) {
if worksmobileEmailDomainName(alias) == domainName {
return alias
}
}
localPart, err := domain.ExtractNormalizedEmailLocalPart(primaryEmail)
if err != nil || localPart == "" {
return ""
}
return localPart + "@" + domainName
}
func worksmobileTenantDomainIDEnvKey(tenant domain.Tenant) string {
if tenantHasDomain(tenant, "samaneng.com") || tenantMatchesAny(tenant, "saman", "삼안") {
return "SAMAN_DOMAIN_ID"
}
if isHanmacWorksmobileTenant(tenant) {
return "HANMAC_DOMAIN_ID"
}
if tenantMatchesAny(tenant, "gpdtdc", "총괄", "기술개발센터", "기술개발") {
return "GPDTDC_DOMAIN_ID"
}
if isHallaWorksmobileTenant(tenant) {
return "HALLA_DOMAIN_ID"
}
return "BARONGROUP_DOMAIN_ID"
}
func worksmobileDomainIDFromEnv(key string) (int64, bool) {
if key == "" {
return 0, false
}
id, ok := parseDomainID(os.Getenv(key))
return id, ok
}
type worksmobileDomainEnvMapping struct {
Key string
Label string
}
func worksmobileDomainEnvMappings() []worksmobileDomainEnvMapping {
return []worksmobileDomainEnvMapping{
{Key: "SAMAN_DOMAIN_ID", Label: "삼안"},
{Key: "HANMAC_DOMAIN_ID", Label: "한맥기술"},
{Key: "GPDTDC_DOMAIN_ID", Label: "총괄기획&기술개발센터"},
{Key: "HALLA_DOMAIN_ID", Label: "한라산업개발"},
{Key: "BARONGROUP_DOMAIN_ID", Label: "바론그룹"},
}
}
func WorksmobileDomainIDsFromEnv() []int64 {
mappings := worksmobileDomainEnvMappings()
result := make([]int64, 0, len(mappings))
seen := map[int64]bool{}
for _, mapping := range mappings {
if id, ok := worksmobileDomainIDFromEnv(mapping.Key); ok && !seen[id] {
seen[id] = true
result = append(result, id)
}
}
return result
}
func WorksmobileDomainLabelForID(domainID int64) string {
for _, mapping := range worksmobileDomainEnvMappings() {
if id, ok := worksmobileDomainIDFromEnv(mapping.Key); ok && id == domainID {
return mapping.Label
}
}
return ""
}
func isHanmacWorksmobileTenant(tenant domain.Tenant) bool {
return tenantHasDomain(tenant, "hanmaceng.co.kr") || tenantMatchesAny(tenant, "hanmac", "한맥")
}
func isHallaWorksmobileTenant(tenant domain.Tenant) bool {
return tenantHasDomain(tenant, "hallasanup.com") || tenantMatchesAny(tenant, "halla", "hanlla", "한라산업개발")
}
func tenantHasDomain(tenant domain.Tenant, domainName string) bool {
domainName = strings.ToLower(strings.TrimSpace(domainName))
for _, d := range tenant.Domains {
if strings.EqualFold(strings.TrimSpace(d.Domain), domainName) {
return true
}
}
return false
}
func tenantMatchesAny(tenant domain.Tenant, needles ...string) bool {
haystack := strings.ToLower(strings.TrimSpace(tenant.Slug + " " + tenant.Name))
for _, needle := range needles {
if strings.Contains(haystack, strings.ToLower(strings.TrimSpace(needle))) {
return true
}
}
return false
}
func WorksmobileEnabled(rootConfig domain.JSONMap) bool {
rawWorksmobile, ok := rootConfig["worksmobile"].(map[string]any)
if !ok {
if raw, ok := rootConfig["worksmobile"].(domain.JSONMap); ok {
rawWorksmobile = map[string]any(raw)
} else {
return false
}
}
enabled, _ := rawWorksmobile["enabled"].(bool)
return enabled
}
func WorksmobileDomainMappings(rootConfig domain.JSONMap) map[string]int64 {
result := map[string]int64{}
rawWorksmobile, ok := rootConfig["worksmobile"].(map[string]any)
if !ok {
if raw, ok := rootConfig["worksmobile"].(domain.JSONMap); ok {
rawWorksmobile = map[string]any(raw)
} else {
return result
}
}
rawMappings, ok := rawWorksmobile["domainMappings"].(map[string]any)
if !ok {
if raw, ok := rawWorksmobile["domainMappings"].(domain.JSONMap); ok {
rawMappings = map[string]any(raw)
} else {
return result
}
}
for key, raw := range rawMappings {
if id, ok := parseDomainID(raw); ok {
result[strings.ToLower(strings.TrimSpace(key))] = id
}
}
return result
}
func parseDomainID(raw any) (int64, bool) {
switch value := raw.(type) {
case int:
return int64(value), value > 0
case int64:
return value, value > 0
case float64:
id := int64(value)
return id, id > 0
case string:
id, err := strconv.ParseInt(strings.TrimSpace(value), 10, 64)
return id, err == nil && id > 0
default:
return 0, false
}
}
func metadataString(metadata domain.JSONMap, keys ...string) string {
for _, key := range keys {
if value, ok := metadata[key]; ok {
switch v := value.(type) {
case string:
return strings.TrimSpace(v)
default:
return strings.TrimSpace(fmt.Sprint(v))
}
}
}
return ""
}
func metadataEmployeeNumber(metadata domain.JSONMap) string {
for _, key := range []string{"employee_id", "employeeNumber", "employee_number"} {
value, ok := metadata[key]
if !ok {
continue
}
if normalized := normalizeMetadataEmployeeNumber(value); normalized != "" {
return normalized
}
}
return ""
}
func normalizeMetadataEmployeeNumber(value any) string {
switch v := value.(type) {
case string:
return strings.TrimSpace(v)
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
return strings.TrimSpace(fmt.Sprint(v))
case map[string]any:
return normalizeMetadataCharacterMap(v)
case domain.JSONMap:
return normalizeMetadataCharacterMap(map[string]any(v))
case map[string]string:
converted := make(map[string]any, len(v))
for key, value := range v {
converted[key] = value
}
return normalizeMetadataCharacterMap(converted)
default:
return ""
}
}
func normalizeMetadataCharacterMap(value map[string]any) string {
type characterEntry struct {
index int
value string
}
entries := make([]characterEntry, 0, len(value))
for key, raw := range value {
index, err := strconv.Atoi(key)
if err != nil {
return ""
}
part, ok := raw.(string)
if !ok || part == "" {
return ""
}
entries = append(entries, characterEntry{index: index, value: part})
}
if len(entries) == 0 {
return ""
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].index < entries[j].index
})
var builder strings.Builder
for _, entry := range entries {
builder.WriteString(entry.value)
}
return strings.TrimSpace(builder.String())
}
func metadataBool(metadata domain.JSONMap, keys ...string) bool {
value, _ := metadataOptionalBool(metadata, keys...)
return value
}
func metadataOptionalBool(metadata domain.JSONMap, keys ...string) (bool, bool) {
for _, key := range keys {
value, ok := metadata[key]
if !ok {
continue
}
switch v := value.(type) {
case bool:
return v, true
case string:
normalized := strings.ToLower(strings.TrimSpace(v))
if normalized == "true" || normalized == "1" || normalized == "yes" {
return true, true
}
if normalized == "false" || normalized == "0" || normalized == "no" {
return false, true
}
case int:
return v != 0, true
case float64:
return v != 0, true
}
}
return false, false
}
func metadataStringList(metadata domain.JSONMap, keys ...string) []string {
for _, key := range keys {
value, ok := metadata[key]
if !ok {
continue
}
switch v := value.(type) {
case []string:
return splitWorksmobileAliasValues(v)
case []any:
values := make([]string, 0, len(v))
for _, item := range v {
values = append(values, strings.TrimSpace(fmt.Sprint(item)))
}
return splitWorksmobileAliasValues(values)
case string:
return splitWorksmobileAliasValues([]string{v})
default:
return splitWorksmobileAliasValues([]string{fmt.Sprint(v)})
}
}
return nil
}
func splitWorksmobileAliasValues(values []string) []string {
result := make([]string, 0, len(values))
for _, value := range values {
fields := strings.FieldsFunc(value, func(r rune) bool {
return r == ',' || r == ';' || r == '\n' || r == '\r' || r == '\t'
})
for _, field := range fields {
if trimmed := strings.TrimSpace(field); trimmed != "" {
result = append(result, trimmed)
}
}
}
return result
}

View File

@@ -0,0 +1,856 @@
package service
import (
"baron-sso-backend/internal/domain"
"encoding/json"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestBuildWorksmobileOrgUnitPayloadUsesTenantExternalKeyAndEnvDomainClassification(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "1001")
parentID := "11111111-1111-1111-1111-111111111111"
tenant := domain.Tenant{
ID: "22222222-2222-2222-2222-222222222222",
Slug: "tech-dev-center",
Name: "Saman Engineering",
ParentID: &parentID,
Domains: []domain.TenantDomain{
{Domain: "samaneng.com"},
},
}
rootConfig := domain.JSONMap{
"worksmobile": map[string]any{
"domainMappings": map[string]any{
"samaneng.com": float64(9999),
},
},
}
payload, err := BuildWorksmobileOrgUnitPayload(tenant, rootConfig, 7)
require.NoError(t, err)
require.Equal(t, int64(1001), payload.DomainID)
require.Equal(t, "Saman Engineering", payload.OrgUnitName)
require.Equal(t, "tech-dev-center@samaneng.com", payload.Email)
require.Equal(t, tenant.ID, payload.OrgUnitExternalKey)
require.Equal(t, "externalKey:"+parentID, payload.ParentOrgUnitID)
require.Equal(t, 7, payload.DisplayOrder)
}
func TestBuildWorksmobileOrgUnitPayloadUsesWorksmobileMailDomainForBarongroup(t *testing.T) {
t.Setenv("BARONGROUP_DOMAIN_ID", "1004")
tenant := domain.Tenant{
ID: "11111111-1111-1111-1111-111111111111",
Slug: "jangheon",
Name: "(주)장헌",
Type: domain.TenantTypeCompany,
Domains: []domain.TenantDomain{{Domain: "jangheon.com"}},
}
payload, err := BuildWorksmobileOrgUnitPayloadForDomainTenant(tenant, tenant, nil, 1)
require.NoError(t, err)
require.Equal(t, int64(1004), payload.DomainID)
require.Equal(t, "jangheon@brsw.kr", payload.Email)
}
func TestBuildWorksmobileOrgUnitPayloadDefaultsDisplayOrderToOne(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "1001")
tenant := domain.Tenant{
ID: "11111111-1111-1111-1111-111111111111",
Slug: "tech-dev-center",
Name: "기술개발센터",
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
}
payload, err := BuildWorksmobileOrgUnitPayload(tenant, nil, 0)
require.NoError(t, err)
require.Equal(t, 1, payload.DisplayOrder)
}
func TestNormalizeRootChildWorksmobileOrgUnitParentClearsCrossDomainParent(t *testing.T) {
rootID := "038326b6-954a-48a7-a85f-efd83f62b82a"
payload := WorksmobileOrgUnitPayload{ParentOrgUnitID: "externalKey:" + rootID}
tenant := domain.Tenant{ParentID: &rootID}
normalized := normalizeWorksmobileOrgUnitParent(payload, tenant, nil, rootID)
require.Empty(t, normalized.ParentOrgUnitID)
}
func TestBuildWorksmobileUserPayloadMapsBaronUserAndPrimaryTenant(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "1001")
rootTenantID := "11111111-1111-1111-1111-111111111111"
tenantID := "33333333-3333-3333-3333-333333333333"
user := domain.User{
ID: "44444444-4444-4444-4444-444444444444",
Email: "john1@samaneng.com",
Name: "John Doe",
Phone: "+19144812222",
Position: "Manager",
JobTitle: "Sales management",
TenantID: &tenantID,
Metadata: domain.JSONMap{
"employee_id": "AB001",
},
}
tenant := domain.Tenant{
ID: tenantID,
Slug: "sales",
Name: "Sales",
Type: domain.TenantTypeOrganization,
ParentID: &rootTenantID,
}
rootTenant := domain.Tenant{
ID: rootTenantID,
Slug: "saman",
Name: "삼안",
Type: domain.TenantTypeCompany,
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
}
rootConfig := domain.JSONMap{
"worksmobile": map[string]any{
"domainMappings": map[string]any{
"samaneng.com": int64(9999),
},
},
}
payload, err := BuildWorksmobileUserPayloadForDomainTenants(
user,
tenant,
map[string]domain.Tenant{
rootTenantID: rootTenant,
tenantID: tenant,
},
rootConfig,
)
require.NoError(t, err)
require.Equal(t, int64(1001), payload.DomainID)
require.Equal(t, "john1@samaneng.com", payload.Email)
require.Equal(t, user.ID, payload.UserExternalKey)
require.Equal(t, "John Doe", payload.UserName.LastName)
require.Equal(t, "+19144812222", payload.CellPhone)
require.Equal(t, "AB001", payload.EmployeeNumber)
require.Equal(t, "Sales management", payload.Task)
require.Empty(t, payload.PrivateEmail)
require.Empty(t, payload.AliasEmails)
require.Equal(t, "ko_KR", payload.Locale)
require.Empty(t, payload.PasswordConfig.PasswordCreationType)
require.Empty(t, payload.PasswordConfig.Password)
require.Len(t, payload.Organizations, 1)
require.Equal(t, int64(1001), payload.Organizations[0].DomainID)
require.True(t, payload.Organizations[0].Primary)
require.Equal(t, "externalKey:"+tenantID, payload.Organizations[0].OrgUnits[0].OrgUnitID)
}
func TestBuildWorksmobileUserPayloadDeduplicatesKoreanCountryCodeInCellPhone(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "1001")
tenantID := "33333333-3333-3333-3333-333333333333"
user := domain.User{
ID: "44444444-4444-4444-4444-444444444444",
Email: "john1@samaneng.com",
Name: "John Doe",
Phone: "+82 +821091917771",
TenantID: &tenantID,
}
tenant := domain.Tenant{
ID: tenantID,
Slug: "saman",
Name: "Saman",
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
}
payload, err := BuildWorksmobileUserPayload(user, tenant, nil)
require.NoError(t, err)
require.Equal(t, "+821091917771", payload.CellPhone)
}
func TestWorksmobileUserPayloadJSONOmitsEmptyPasswordConfig(t *testing.T) {
data, err := json.Marshal(WorksmobileUserPayload{
DomainID: 1001,
Email: "target@samaneng.com",
UserExternalKey: "user-1",
UserName: WorksmobileUserName{LastName: "Target"},
})
require.NoError(t, err)
require.NotContains(t, string(data), "passwordConfig")
}
func TestBuildWorksmobileUserPayloadOmitsOrganizationsForSamanRootTenant(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "1001")
tenantID := "33333333-3333-3333-3333-333333333333"
user := domain.User{
ID: "44444444-4444-4444-4444-444444444444",
Email: "root-user@samaneng.com",
Name: "Root User",
JobTitle: "Advisor",
TenantID: &tenantID,
}
tenant := domain.Tenant{
ID: tenantID,
Slug: "saman",
Name: "삼안",
Type: domain.TenantTypeCompany,
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
}
payload, err := BuildWorksmobileUserPayload(user, tenant, nil)
require.NoError(t, err)
require.Equal(t, int64(1001), payload.DomainID)
require.Equal(t, "root-user@samaneng.com", payload.Email)
require.Equal(t, "Advisor", payload.Task)
require.Empty(t, payload.Organizations)
}
func TestBuildWorksmobileUserPayloadNormalizesLegacyCharacterMapEmployeeID(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "1001")
tenantID := "33333333-3333-3333-3333-333333333333"
user := domain.User{
ID: "44444444-4444-4444-4444-444444444444",
Email: "john1@samaneng.com",
Name: "John Doe",
TenantID: &tenantID,
Metadata: domain.JSONMap{
"employee_id": map[string]any{
"0": "j",
"1": "o",
"2": "h",
"3": "n",
},
},
}
tenant := domain.Tenant{
ID: tenantID,
Slug: "saman",
Name: "Saman",
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
}
payload, err := BuildWorksmobileUserPayload(user, tenant, nil)
require.NoError(t, err)
require.Equal(t, "john", payload.EmployeeNumber)
}
func TestBuildWorksmobileUserPayloadMapsAdditionalAppointmentsToOrgUnits(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "1001")
t.Setenv("HANMAC_DOMAIN_ID", "1002")
samanRootID := "11111111-1111-1111-1111-111111111111"
hanmacRootID := "22222222-2222-2222-2222-222222222222"
primaryTenantID := "33333333-3333-3333-3333-333333333333"
secondaryTenantID := "55555555-5555-5555-5555-555555555555"
user := domain.User{
ID: "44444444-4444-4444-4444-444444444444",
Email: "john1@samaneng.com",
Name: "John Doe",
Phone: "+19144812222",
TenantID: &primaryTenantID,
Metadata: domain.JSONMap{
"additionalAppointments": []any{
map[string]any{
"tenantId": secondaryTenantID,
"isPrimary": false,
"isManager": true,
"jobTitle": "PM",
"position": "팀장",
},
map[string]any{
"tenantId": primaryTenantID,
"isPrimary": true,
"isOwner": true,
"jobTitle": "Engineering",
"position": "책임",
},
},
},
}
samanRoot := domain.Tenant{
ID: samanRootID,
Slug: "saman",
Name: "삼안",
Type: domain.TenantTypeCompany,
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
}
hanmacRoot := domain.Tenant{
ID: hanmacRootID,
Slug: "hanmac",
Name: "한맥기술",
Type: domain.TenantTypeCompany,
Domains: []domain.TenantDomain{{Domain: "hanmaceng.co.kr"}},
}
primaryTenant := domain.Tenant{
ID: primaryTenantID,
Slug: "saman-sales",
Name: "Saman Sales",
Type: domain.TenantTypeOrganization,
ParentID: &samanRootID,
}
secondaryTenant := domain.Tenant{
ID: secondaryTenantID,
Slug: "hanmac-sales",
Name: "Hanmac Sales",
Type: domain.TenantTypeOrganization,
ParentID: &hanmacRootID,
}
payload, err := BuildWorksmobileUserPayloadForDomainTenants(
user,
primaryTenant,
map[string]domain.Tenant{
samanRootID: samanRoot,
hanmacRootID: hanmacRoot,
primaryTenantID: primaryTenant,
secondaryTenantID: secondaryTenant,
},
nil,
)
require.NoError(t, err)
require.Equal(t, "Engineering", payload.Task)
require.Len(t, payload.Organizations, 2)
require.Equal(t, int64(1001), payload.Organizations[0].DomainID)
require.True(t, payload.Organizations[0].Primary)
require.Equal(t, "externalKey:"+primaryTenantID, payload.Organizations[0].OrgUnits[0].OrgUnitID)
require.True(t, payload.Organizations[0].OrgUnits[0].Primary)
require.Nil(t, payload.Organizations[0].OrgUnits[0].IsManager)
require.Equal(t, int64(1002), payload.Organizations[1].DomainID)
require.False(t, payload.Organizations[1].Primary)
require.Equal(t, "externalKey:"+secondaryTenantID, payload.Organizations[1].OrgUnits[0].OrgUnitID)
require.True(t, payload.Organizations[1].OrgUnits[0].Primary)
require.NotNil(t, payload.Organizations[1].OrgUnits[0].IsManager)
require.True(t, *payload.Organizations[1].OrgUnits[0].IsManager)
}
func TestBuildWorksmobileUserPayloadKeepsPrimaryTenantWhenEmailDomainAppointmentExists(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "1001")
rootTenantID := "9caf62e1-297d-4e8f-870b-61780998bbeb"
primaryTenantID := "1edc196d-020c-4519-9ec4-3d23b99076e6"
user := domain.User{
ID: "64231465-d5c0-4085-b4a2-603b90834f86",
Email: "evenlee@samaneng.com",
Name: "이용운",
JobTitle: "부사장",
TenantID: &primaryTenantID,
Metadata: domain.JSONMap{
"additionalAppointments": []any{
map[string]any{
"tenantId": rootTenantID,
"tenantSlug": "saman",
"tenantName": "삼안",
"assignmentSource": "email_domain",
"sourceDomain": "samaneng.com",
},
},
},
}
rootTenant := domain.Tenant{
ID: rootTenantID,
Slug: "saman",
Name: "삼안",
Type: domain.TenantTypeCompany,
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
}
primaryTenant := domain.Tenant{
ID: primaryTenantID,
Slug: "asset-management",
Name: "자산관리",
Type: domain.TenantTypeOrganization,
ParentID: &rootTenantID,
}
payload, err := BuildWorksmobileUserPayloadForDomainTenants(
user,
primaryTenant,
map[string]domain.Tenant{
rootTenantID: rootTenant,
primaryTenantID: primaryTenant,
},
nil,
)
require.NoError(t, err)
require.Len(t, payload.Organizations, 1)
require.Equal(t, int64(1001), payload.Organizations[0].DomainID)
require.True(t, payload.Organizations[0].Primary)
require.Len(t, payload.Organizations[0].OrgUnits, 1)
require.Equal(t, "externalKey:"+primaryTenantID, payload.Organizations[0].OrgUnits[0].OrgUnitID)
require.True(t, payload.Organizations[0].OrgUnits[0].Primary)
}
func TestBuildWorksmobileUserPayloadKeepsFirstAffiliationPrimaryWhenBaronRepresentativeIsGPDTDC(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "1001")
t.Setenv("GPDTDC_DOMAIN_ID", "1003")
samanRootID := "11111111-1111-1111-1111-111111111111"
gpdtdcID := "5530ca6e-c5e6-4bf0-84d6-76c6a8fb70ee"
firstTenantID := "33333333-3333-3333-3333-333333333333"
secondTenantID := "55555555-5555-5555-5555-555555555555"
user := domain.User{
ID: "44444444-4444-4444-4444-444444444444",
Email: "gpdtdc-dual@samaneng.com",
Name: "GPDTDC Dual User",
TenantID: &gpdtdcID,
Metadata: domain.JSONMap{
"additionalAppointments": []any{
map[string]any{
"tenantId": firstTenantID,
"isPrimary": true,
"jobTitle": "First affiliation task",
},
map[string]any{
"tenantId": secondTenantID,
"isPrimary": false,
"jobTitle": "Second affiliation task",
},
},
},
}
gpdtdcTenant := domain.Tenant{
ID: gpdtdcID,
Slug: "gpdtdc",
Name: "총괄기획&기술개발센터",
}
samanRoot := domain.Tenant{
ID: samanRootID,
Slug: "saman",
Name: "삼안",
Type: domain.TenantTypeCompany,
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
}
firstTenant := domain.Tenant{
ID: firstTenantID,
Slug: "rnd-center",
Name: "삼안기술개발센터",
Type: domain.TenantTypeOrganization,
ParentID: &samanRootID,
}
secondTenant := domain.Tenant{
ID: secondTenantID,
Slug: "tdc",
Name: "기술개발센터",
ParentID: &gpdtdcID,
}
payload, err := BuildWorksmobileUserPayloadForDomainTenants(
user,
gpdtdcTenant,
map[string]domain.Tenant{
samanRootID: samanRoot,
gpdtdcID: gpdtdcTenant,
firstTenantID: firstTenant,
secondTenantID: secondTenant,
},
nil,
)
require.NoError(t, err)
require.Equal(t, int64(1001), payload.DomainID)
require.Equal(t, "First affiliation task", payload.Task)
require.Len(t, payload.Organizations, 2)
require.Equal(t, int64(1001), payload.Organizations[0].DomainID)
require.True(t, payload.Organizations[0].Primary)
require.Equal(t, "externalKey:"+firstTenantID, payload.Organizations[0].OrgUnits[0].OrgUnitID)
require.True(t, payload.Organizations[0].OrgUnits[0].Primary)
require.Equal(t, int64(1003), payload.Organizations[1].DomainID)
require.False(t, payload.Organizations[1].Primary)
require.Equal(t, "externalKey:"+secondTenantID, payload.Organizations[1].OrgUnits[0].OrgUnitID)
require.True(t, payload.Organizations[1].OrgUnits[0].Primary)
}
func TestBuildWorksmobileUserPayloadUsesEmailDomainForAccountDomainWhenPrimaryOrgIsGPDTDC(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "1001")
t.Setenv("GPDTDC_DOMAIN_ID", "1003")
samanID := "11111111-1111-1111-1111-111111111111"
gpdtdcID := "5530ca6e-c5e6-4bf0-84d6-76c6a8fb70ee"
leafTenantID := "52f06c97-9d6f-4819-971b-43303062e193"
user := domain.User{
ID: "44444444-4444-4444-4444-444444444444",
Email: "dhlee@samaneng.com",
Name: "GPDTDC Saman User",
TenantID: &leafTenantID,
Metadata: domain.JSONMap{
"additionalAppointments": []any{
map[string]any{
"tenantId": leafTenantID,
"isPrimary": true,
},
},
},
}
samanTenant := domain.Tenant{
ID: samanID,
Slug: "saman",
Name: "삼안",
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
}
gpdtdcTenant := domain.Tenant{
ID: gpdtdcID,
Slug: "gpdtdc",
Name: "총괄기획&기술개발센터",
}
leafTenant := domain.Tenant{
ID: leafTenantID,
Slug: "infra-bim2",
Name: "인프라 BIM2",
ParentID: &gpdtdcID,
}
payload, err := BuildWorksmobileUserPayloadForDomainTenants(
user,
leafTenant,
map[string]domain.Tenant{
samanID: samanTenant,
gpdtdcID: gpdtdcTenant,
leafTenantID: leafTenant,
},
nil,
)
require.NoError(t, err)
require.Equal(t, int64(1001), payload.DomainID)
require.Len(t, payload.Organizations, 1)
require.Equal(t, int64(1003), payload.Organizations[0].DomainID)
require.True(t, payload.Organizations[0].Primary)
require.Equal(t, "dhlee@baroncs.co.kr", payload.Organizations[0].Email)
require.Equal(t, "externalKey:"+leafTenantID, payload.Organizations[0].OrgUnits[0].OrgUnitID)
require.True(t, payload.Organizations[0].OrgUnits[0].Primary)
}
func TestWorksmobileUserPayloadJSONIncludesFalsePrimaryFields(t *testing.T) {
payload := WorksmobileUserPayload{
Email: "user@samaneng.com",
Organizations: []WorksmobileUserOrganization{
{
DomainID: 1001,
Primary: true,
OrgUnits: []WorksmobileUserOrgUnit{
{OrgUnitID: "externalKey:primary", Primary: true},
},
},
{
DomainID: 1003,
Primary: false,
OrgUnits: []WorksmobileUserOrgUnit{
{OrgUnitID: "externalKey:secondary", Primary: false},
},
},
},
}
data, err := json.Marshal(payload)
require.NoError(t, err)
require.Contains(t, string(data), `"primary":false`)
require.Contains(t, string(data), `"orgUnitId":"externalKey:secondary","primary":false`)
}
func TestResolveWorksmobileDomainIDFromTenantIgnoresRootDomainMappings(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "1001")
rootConfig := domain.JSONMap{
"worksmobile": map[string]any{
"domainMappings": map[string]any{
"samaneng.com": int64(9999),
},
},
}
got, err := ResolveWorksmobileDomainIDFromTenant(
domain.Tenant{
Slug: "saman",
Name: "삼안",
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
},
rootConfig,
)
require.NoError(t, err)
require.Equal(t, int64(1001), got)
}
func TestResolveWorksmobileDomainIDFromTenantRequiresFamilyDomainEnv(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "")
rootConfig := domain.JSONMap{
"worksmobile": map[string]any{
"domainMappings": map[string]any{
"samaneng.com": int64(9999),
},
},
}
_, err := ResolveWorksmobileDomainIDFromTenant(
domain.Tenant{
Slug: "saman",
Name: "삼안",
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
},
rootConfig,
)
require.Error(t, err)
require.Contains(t, err.Error(), "SAMAN_DOMAIN_ID")
}
func TestResolveWorksmobileDomainIDUsesEnvFamilyFallbacks(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "1001")
t.Setenv("HANMAC_DOMAIN_ID", "1002")
t.Setenv("GPDTDC_DOMAIN_ID", "1003")
t.Setenv("HALLA_DOMAIN_ID", "1005")
t.Setenv("BARONGROUP_DOMAIN_ID", "1004")
tests := []struct {
name string
tenant domain.Tenant
want int64
}{
{
name: "saman",
tenant: domain.Tenant{Slug: "saman", Domains: []domain.TenantDomain{{Domain: "samaneng.com"}}},
want: 1001,
},
{
name: "hanmac",
tenant: domain.Tenant{Slug: "hanmac", Domains: []domain.TenantDomain{{Domain: "hanmaceng.co.kr"}}},
want: 1002,
},
{
name: "gpdtdc",
tenant: domain.Tenant{Slug: "gpdtdc", Name: "총괄기획&기술개발센터"},
want: 1003,
},
{
name: "halla",
tenant: domain.Tenant{Slug: "halla", Name: "한라산업개발", Domains: []domain.TenantDomain{{Domain: "hallasanup.com"}}},
want: 1005,
},
{
name: "hanlla legacy slug",
tenant: domain.Tenant{Slug: "hanlla", Name: "한라산업개발"},
want: 1005,
},
{
name: "barongroup fallback",
tenant: domain.Tenant{Slug: "family-company", Name: "기타 가족사"},
want: 1004,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ResolveWorksmobileDomainIDFromTenant(tt.tenant, nil)
require.NoError(t, err)
require.Equal(t, tt.want, got)
})
}
}
func TestResolveWorksmobileAccountDomainIDUsesHallaEmailDomain(t *testing.T) {
t.Setenv("HALLA_DOMAIN_ID", "1005")
t.Setenv("BARONGROUP_DOMAIN_ID", "1004")
tenant := domain.Tenant{
Slug: "halla",
Name: "한라산업개발",
Domains: []domain.TenantDomain{{Domain: "hallasanup.com"}},
}
got, err := ResolveWorksmobileAccountDomainIDFromEmail("user@hallasanup.com", tenant, nil)
require.NoError(t, err)
require.Equal(t, int64(1005), got)
}
func TestWorksmobileDomainIDsFromEnvIncludesHallaBeforeFallback(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "1001")
t.Setenv("HANMAC_DOMAIN_ID", "1002")
t.Setenv("GPDTDC_DOMAIN_ID", "1003")
t.Setenv("HALLA_DOMAIN_ID", "1005")
t.Setenv("BARONGROUP_DOMAIN_ID", "1004")
got := WorksmobileDomainIDsFromEnv()
require.Equal(t, []int64{1001, 1002, 1003, 1005, 1004}, got)
require.Equal(t, "한라산업개발", WorksmobileDomainLabelForID(1005))
}
func TestBuildWorksmobileUserPayloadUsesHallaDomain(t *testing.T) {
t.Setenv("HALLA_DOMAIN_ID", "1005")
t.Setenv("WORKS_DEFAULT_DOMAIN_HALLA", "hallasanup.com")
tenantID := "33333333-3333-3333-3333-333333333333"
user := domain.User{
ID: "44444444-4444-4444-4444-444444444444",
Email: "main@hallasanup.com",
Name: "Halla User",
TenantID: &tenantID,
}
tenant := domain.Tenant{
ID: tenantID,
Slug: "halla",
Name: "한라산업개발",
Domains: []domain.TenantDomain{{Domain: "hallasanup.com"}},
}
payload, err := BuildWorksmobileUserPayload(user, tenant, nil)
require.NoError(t, err)
require.Equal(t, int64(1005), payload.DomainID)
require.Equal(t, "main@hallasanup.com", payload.Email)
}
func TestBuildWorksmobileUserPayloadAddsHanmacEmployeeAlias(t *testing.T) {
t.Setenv("HANMAC_DOMAIN_ID", "1002")
tenantID := "33333333-3333-3333-3333-333333333333"
user := domain.User{
ID: "44444444-4444-4444-4444-444444444444",
Email: "main@hanmaceng.co.kr",
Name: "Hanmac User",
TenantID: &tenantID,
Metadata: domain.JSONMap{
"employee_id": "HM001",
"personal_email": "private@example.com",
},
}
tenant := domain.Tenant{
ID: tenantID,
Slug: "hanmac",
Name: "한맥",
Domains: []domain.TenantDomain{{Domain: "hanmaceng.co.kr"}},
}
payload, err := BuildWorksmobileUserPayload(user, tenant, nil)
require.NoError(t, err)
require.Equal(t, int64(1002), payload.DomainID)
require.Equal(t, []string{"hm001@hanmaceng.co.kr"}, payload.AliasEmails)
require.Empty(t, payload.PrivateEmail)
require.Equal(t, "ko_KR", payload.Locale)
}
func TestBuildWorksmobileUserPayloadAddsMultipleMetadataAliases(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "1001")
tenantID := "33333333-3333-3333-3333-333333333333"
user := domain.User{
ID: "44444444-4444-4444-4444-444444444444",
Email: "main@samaneng.com",
Name: "Saman User",
TenantID: &tenantID,
Metadata: domain.JSONMap{
"aliasEmails": []any{"alias1@samaneng.com", "alias2@samaneng.com", "main@samaneng.com"},
},
}
tenant := domain.Tenant{
ID: tenantID,
Slug: "saman",
Name: "삼안",
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
}
payload, err := BuildWorksmobileUserPayload(user, tenant, nil)
require.NoError(t, err)
require.Equal(t, []string{"alias1@samaneng.com", "alias2@samaneng.com"}, payload.AliasEmails)
}
func TestBuildWorksmobileUserPayloadAddsSubEmailMetadataAlias(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "1001")
tenantID := "33333333-3333-3333-3333-333333333333"
user := domain.User{
ID: "44444444-4444-4444-4444-444444444444",
Email: "main@samaneng.com",
Name: "Saman User",
TenantID: &tenantID,
Metadata: domain.JSONMap{
"sub_email": "alias1@hanmaceng.co.kr",
"secondary_emails": []any{"alias2@hanmaceng.co.kr"},
},
}
tenant := domain.Tenant{
ID: tenantID,
Slug: "saman",
Name: "삼안",
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
}
payload, err := BuildWorksmobileUserPayload(user, tenant, nil)
require.NoError(t, err)
require.Equal(t, []string{"alias1@hanmaceng.co.kr", "alias2@hanmaceng.co.kr"}, payload.AliasEmails)
}
func TestBuildWorksmobileUserPayloadKeepsSubEmailAliasWithPrimaryLocalPart(t *testing.T) {
t.Setenv("SAMAN_DOMAIN_ID", "1001")
tenantID := "33333333-3333-3333-3333-333333333333"
user := domain.User{
ID: "44444444-4444-4444-4444-444444444444",
Email: "ypshim@samaneng.com",
Name: "Saman User",
TenantID: &tenantID,
Metadata: domain.JSONMap{
"sub_email": "ypshim@hanmaceng.co.kr",
},
}
tenant := domain.Tenant{
ID: tenantID,
Slug: "saman",
Name: "삼안",
Domains: []domain.TenantDomain{{Domain: "samaneng.com"}},
}
payload, err := BuildWorksmobileUserPayload(user, tenant, nil)
require.NoError(t, err)
require.Equal(t, []string{"ypshim@hanmaceng.co.kr"}, payload.AliasEmails)
}
func TestValidateWorksmobileAliasEmailsAllowsSameLocalPartOnDifferentDomains(t *testing.T) {
err := ValidateWorksmobileAliasEmails(
"main@samaneng.com",
[]string{"main@hanmaceng.co.kr"},
map[string]string{},
)
require.NoError(t, err)
err = ValidateWorksmobileAliasEmails(
"main@samaneng.com",
[]string{"main@samaneng.com"},
map[string]string{},
)
require.Error(t, err)
require.Contains(t, err.Error(), "duplicates")
err = ValidateWorksmobileAliasEmails(
"main@samaneng.com",
[]string{"alias@hanmaceng.co.kr"},
map[string]string{"alias@hanmaceng.co.kr": "existing-user"},
)
require.Error(t, err)
require.Contains(t, err.Error(), "이미 사용 중")
}
func containsAny(value string, candidates string) bool {
return strings.ContainsAny(value, candidates)
}
func TestWorksmobileUserStatusAction(t *testing.T) {
require.Equal(t, WorksmobileUserActionUpsert, WorksmobileUserStatusAction(domain.UserStatusActive))
require.Equal(t, WorksmobileUserActionUpsert, WorksmobileUserStatusAction(domain.UserStatusTemporaryLeave))
require.Equal(t, WorksmobileUserActionSuspend, WorksmobileUserStatusAction(domain.UserStatusSuspended))
require.Equal(t, domain.WorksmobileActionDelete, WorksmobileUserStatusAction(domain.UserStatusExtendedLeave))
require.Equal(t, domain.WorksmobileActionDelete, WorksmobileUserStatusAction(domain.UserStatusBaronGuest))
require.Equal(t, domain.WorksmobileActionDelete, WorksmobileUserStatusAction(domain.UserStatusArchived))
require.Equal(t, WorksmobileUserActionUpsert, WorksmobileUserStatusAction("leave_of_absence"))
require.Equal(t, domain.WorksmobileActionDelete, WorksmobileUserStatusAction("baron_only"))
}
func TestValidateWorksmobileExternalKeyRejectsUnsupportedCharacters(t *testing.T) {
require.NoError(t, ValidateWorksmobileExternalKey("44444444-4444-4444-4444-444444444444"))
require.Error(t, ValidateWorksmobileExternalKey("user/with/slash"))
require.Error(t, ValidateWorksmobileExternalKey("user#with-hash"))
}

View File

@@ -0,0 +1,79 @@
package service
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"os"
"time"
"github.com/go-redis/redis/v8"
)
const (
worksmobileRelayLeaderLockKey = "baron:worksmobile:relay:leader"
worksmobileRelayLeaderLockTTL = 30 * time.Second
)
const worksmobileRelayLeaderRenewScript = `
if redis.call("GET", KEYS[1]) == ARGV[1] then
return redis.call("EXPIRE", KEYS[1], ARGV[2])
end
return 0
`
type WorksmobileRedisRelayLeaderLock struct {
client *redis.Client
key string
ttl time.Duration
ownerID string
}
func NewWorksmobileRedisRelayLeaderLock(redisService *RedisService) *WorksmobileRedisRelayLeaderLock {
if redisService == nil || redisService.Client == nil {
return nil
}
return &WorksmobileRedisRelayLeaderLock{
client: redisService.Client,
key: worksmobileRelayLeaderLockKey,
ttl: worksmobileRelayLeaderLockTTL,
ownerID: newWorksmobileRelayLeaderOwnerID(),
}
}
func (l *WorksmobileRedisRelayLeaderLock) EnsureLeadership(ctx context.Context) (bool, error) {
if l == nil || l.client == nil {
return true, nil
}
acquired, err := l.client.SetNX(ctx, l.key, l.ownerID, l.ttl).Result()
if err != nil {
return false, err
}
if acquired {
return true, nil
}
ttlSeconds := int64(l.ttl / time.Second)
if ttlSeconds <= 0 {
ttlSeconds = 30
}
result, err := l.client.Eval(ctx, worksmobileRelayLeaderRenewScript, []string{l.key}, l.ownerID, ttlSeconds).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func newWorksmobileRelayLeaderOwnerID() string {
hostname, _ := os.Hostname()
if hostname == "" {
hostname = "unknown-host"
}
randomBytes := make([]byte, 8)
if _, err := rand.Read(randomBytes); err != nil {
return fmt.Sprintf("%s:%d:%d", hostname, os.Getpid(), time.Now().UnixNano())
}
return fmt.Sprintf("%s:%d:%s", hostname, os.Getpid(), hex.EncodeToString(randomBytes))
}

View File

@@ -0,0 +1,302 @@
package service
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/repository"
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"sort"
"strings"
"time"
)
type WorksmobileRelayWorker struct {
repo repository.WorksmobileOutboxRepository
client WorksmobileDirectoryClient
leaderLock WorksmobileRelayLeaderLock
interval time.Duration
batchLimit int
}
type WorksmobileRelayLeaderLock interface {
EnsureLeadership(ctx context.Context) (bool, error)
}
func NewWorksmobileRelayWorker(repo repository.WorksmobileOutboxRepository, client WorksmobileDirectoryClient) *WorksmobileRelayWorker {
return &WorksmobileRelayWorker{
repo: repo,
client: client,
interval: 3 * time.Second,
batchLimit: 10,
}
}
func (w *WorksmobileRelayWorker) SetLeaderLock(lock WorksmobileRelayLeaderLock) {
w.leaderLock = lock
}
func (w *WorksmobileRelayWorker) SetBatchLimit(limit int) {
if limit <= 0 {
return
}
w.batchLimit = limit
}
func (w *WorksmobileRelayWorker) Start(ctx context.Context) {
if w.repo == nil || w.client == nil {
slog.Warn("Worksmobile relay worker disabled")
return
}
ticker := time.NewTicker(w.interval)
defer ticker.Stop()
for {
if err := w.ProcessOnce(ctx); err != nil && !errors.Is(err, context.Canceled) {
slog.Warn("Worksmobile relay tick failed", "error", err)
}
select {
case <-ctx.Done():
return
case <-ticker.C:
}
}
}
func (w *WorksmobileRelayWorker) ProcessOnce(ctx context.Context) (err error) {
defer func() {
if recovered := recover(); recovered != nil {
err = fmt.Errorf("worksmobile relay panic: %v", recovered)
}
}()
if w.leaderLock != nil {
isLeader, err := w.leaderLock.EnsureLeadership(ctx)
if err != nil {
return err
}
if !isLeader {
return nil
}
}
jobs, err := w.repo.ListReady(ctx, w.batchLimit)
if err != nil {
return err
}
jobs = sortWorksmobileReadyJobs(jobs)
for _, job := range jobs {
if err := w.processJob(ctx, job); err != nil {
slog.Warn("Worksmobile relay job failed", "jobID", job.ID, "resourceType", job.ResourceType, "resourceID", job.ResourceID, "error", err)
}
}
return nil
}
func (w *WorksmobileRelayWorker) processJob(ctx context.Context, job domain.WorksmobileOutbox) error {
claimed, err := w.repo.MarkProcessing(ctx, job.ID)
if err != nil {
return err
}
if !claimed {
return nil
}
err = w.dispatch(ctx, job)
if err != nil {
nextAttempt := time.Now().Add(worksmobileRetryDelay(job.RetryCount))
_ = w.repo.MarkFailed(ctx, job.ID, err.Error(), nextAttempt)
return err
}
return w.repo.MarkProcessed(ctx, job.ID)
}
func (w *WorksmobileRelayWorker) dispatch(ctx context.Context, job domain.WorksmobileOutbox) error {
if job.Action == domain.WorksmobileActionDryRun {
return nil
}
switch job.ResourceType {
case domain.WorksmobileResourceOrgUnit:
if job.Action == domain.WorksmobileActionDelete {
return w.client.DeleteOrgUnit(ctx, stringValue(job.Payload["worksmobileId"]))
}
if job.Action != domain.WorksmobileActionUpsert {
return nil
}
var payload WorksmobileOrgUnitPayload
if err := decodeWorksmobileRequest(job.Payload, &payload); err != nil {
return err
}
return w.client.UpsertOrgUnit(ctx, payload, stringValue(job.Payload["matchLocalPart"]))
case domain.WorksmobileResourceUser:
switch job.Action {
case domain.WorksmobileActionUpsert:
var payload WorksmobileUserPayload
if err := decodeWorksmobileRequest(job.Payload, &payload); err != nil {
return err
}
aliasEmails := append([]string(nil), payload.AliasEmails...)
payload.AliasEmails = nil
if err := w.client.UpsertUser(ctx, payload); err != nil {
return fmt.Errorf("worksmobile user upsert failed: %w", err)
}
for _, aliasEmail := range aliasEmails {
if err := w.client.AddUserAliasEmail(ctx, payload.Email, aliasEmail); err != nil {
return fmt.Errorf("worksmobile user alias add failed: %w", err)
}
}
if stringValue(job.Payload["baronStatus"]) == domain.UserStatusActive {
if err := w.client.SetUserActive(ctx, worksmobileOutboxUserIdentifier(job), true); err != nil {
if isWorksmobileSCIMTokenNotConfiguredError(err) {
return nil
}
return fmt.Errorf("worksmobile user set active failed: %w", err)
}
}
return nil
case domain.WorksmobileActionDelete:
return w.client.DeleteUser(ctx, worksmobileOutboxUserIdentifier(job))
case domain.WorksmobileActionSuspend:
return w.client.SetUserActive(ctx, worksmobileOutboxUserIdentifier(job), false)
case domain.WorksmobileActionPasswordReset:
var payload WorksmobilePasswordResetPayload
if err := decodeWorksmobileRequest(job.Payload, &payload); err != nil {
return err
}
identifier := strings.TrimSpace(payload.Email)
if identifier == "" {
identifier = worksmobileOutboxUserIdentifier(job)
}
return w.client.ResetUserPassword(ctx, identifier, payload.PasswordConfig.Password)
default:
return nil
}
default:
return nil
}
}
func isWorksmobileSCIMTokenNotConfiguredError(err error) bool {
return err != nil && strings.Contains(err.Error(), "worksmobile scim token is not configured")
}
func sortWorksmobileReadyJobs(jobs []domain.WorksmobileOutbox) []domain.WorksmobileOutbox {
sorted := append([]domain.WorksmobileOutbox(nil), jobs...)
depthByID := worksmobileOrgUnitDepths(sorted)
sort.SliceStable(sorted, func(i, j int) bool {
leftClass := worksmobileRelayOrderClass(sorted[i])
rightClass := worksmobileRelayOrderClass(sorted[j])
if leftClass != rightClass {
return leftClass < rightClass
}
leftDepth := depthByID[sorted[i].ID]
rightDepth := depthByID[sorted[j].ID]
if leftDepth != rightDepth {
return leftDepth < rightDepth
}
return sorted[i].CreatedAt.Before(sorted[j].CreatedAt)
})
return sorted
}
func worksmobileRelayOrderClass(job domain.WorksmobileOutbox) int {
if job.ResourceType == domain.WorksmobileResourceOrgUnit && job.Action == domain.WorksmobileActionUpsert {
return 0
}
if job.ResourceType == domain.WorksmobileResourceUser {
return 1
}
return 2
}
func worksmobileOrgUnitDepths(jobs []domain.WorksmobileOutbox) map[string]int {
type orgUnitJob struct {
jobID string
parentKey string
}
byExternalKey := map[string]orgUnitJob{}
for _, job := range jobs {
externalKey, parentKey := worksmobileOrgUnitExternalKeys(job)
if externalKey == "" {
continue
}
byExternalKey[externalKey] = orgUnitJob{jobID: job.ID, parentKey: parentKey}
}
depthByExternalKey := map[string]int{}
var depth func(externalKey string, seen map[string]bool) int
depth = func(externalKey string, seen map[string]bool) int {
if value, ok := depthByExternalKey[externalKey]; ok {
return value
}
job, ok := byExternalKey[externalKey]
if !ok || job.parentKey == "" || seen[externalKey] {
depthByExternalKey[externalKey] = 0
return 0
}
seen[externalKey] = true
value := depth(job.parentKey, seen) + 1
delete(seen, externalKey)
depthByExternalKey[externalKey] = value
return value
}
depthByJobID := map[string]int{}
for externalKey, job := range byExternalKey {
depthByJobID[job.jobID] = depth(externalKey, map[string]bool{})
}
return depthByJobID
}
func worksmobileOrgUnitExternalKeys(job domain.WorksmobileOutbox) (string, string) {
if job.ResourceType != domain.WorksmobileResourceOrgUnit || job.Action != domain.WorksmobileActionUpsert {
return "", ""
}
var payload WorksmobileOrgUnitPayload
if err := decodeWorksmobileRequest(job.Payload, &payload); err != nil {
return "", ""
}
parentKey := strings.TrimSpace(payload.ParentOrgUnitID)
if strings.HasPrefix(parentKey, "externalKey:") {
parentKey = strings.TrimSpace(strings.TrimPrefix(parentKey, "externalKey:"))
} else {
parentKey = ""
}
return strings.TrimSpace(payload.OrgUnitExternalKey), parentKey
}
func worksmobileOutboxUserIdentifier(job domain.WorksmobileOutbox) string {
userID := stringValue(job.Payload["loginEmail"])
if userID == "" {
userID = stringValue(job.Payload["userExternalKey"])
}
return userID
}
func decodeWorksmobileRequest(payload domain.JSONMap, target any) error {
raw := payload["request"]
if raw == nil {
return errors.New("worksmobile request payload is missing")
}
data, err := json.Marshal(raw)
if err != nil {
return err
}
decoder := json.NewDecoder(strings.NewReader(string(data)))
decoder.DisallowUnknownFields()
return decoder.Decode(target)
}
func worksmobileRetryDelay(retryCount int) time.Duration {
if retryCount < 0 {
retryCount = 0
}
if retryCount > 5 {
retryCount = 5
}
return time.Duration(1<<retryCount) * time.Minute
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff