첫 커밋: 로컬 프로젝트 업로드

This commit is contained in:
2026-06-10 15:51:34 +09:00
commit 6a8dbeb2e9
1211 changed files with 312864 additions and 0 deletions

View File

@@ -0,0 +1,203 @@
package bootstrap
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/repository"
"context"
"errors"
"fmt"
"net/mail"
"strings"
"time"
"gorm.io/gorm"
)
type SuperAdminIdentityAdmin interface {
FindIdentityIDByIdentifier(ctx context.Context, identifier string) (string, error)
CreateUser(ctx context.Context, user *domain.BrokerUser, password string) (string, error)
UpdateIdentityPassword(ctx context.Context, identityID, newPassword string) error
}
type SuperAdminStore interface {
FindUserByEmail(ctx context.Context, email string) (*domain.User, error)
CreateUser(ctx context.Context, user *domain.User) error
UpdateUserSuperAdmin(ctx context.Context, userID string, name string) (*domain.User, error)
EnqueueSuperAdminRelation(ctx context.Context, userID string) error
}
type EnsureSuperAdminOptions struct {
Email string
Password string
Name string
Source string
UpdatePassword bool
}
type EnsureSuperAdminResult struct {
Email string
IdentityID string
LocalUserID string
IdentityCreated bool
PasswordUpdated bool
LocalUserCreated bool
LocalUserUpdated bool
KetoRelationQueued bool
}
func EnsureSuperAdmin(ctx context.Context, identityAdmin SuperAdminIdentityAdmin, store SuperAdminStore, opts EnsureSuperAdminOptions) (EnsureSuperAdminResult, error) {
email := strings.ToLower(strings.TrimSpace(opts.Email))
name := strings.TrimSpace(opts.Name)
if name == "" {
name = "System Admin"
}
source := strings.TrimSpace(opts.Source)
if source == "" {
source = "admin_cli"
}
result := EnsureSuperAdminResult{Email: email}
if _, err := mail.ParseAddress(email); err != nil {
return result, fmt.Errorf("invalid admin email: %w", err)
}
if identityAdmin == nil {
return result, errors.New("identity admin is required")
}
if store == nil {
return result, errors.New("super admin store is required")
}
identityID, err := identityAdmin.FindIdentityIDByIdentifier(ctx, email)
if err != nil {
return result, fmt.Errorf("find admin identity: %w", err)
}
if identityID == "" {
if strings.TrimSpace(opts.Password) == "" {
return result, errors.New("admin password is required to create identity")
}
identityID, err = identityAdmin.CreateUser(ctx, buildSuperAdminBrokerUser(email, name), opts.Password)
if err != nil {
return result, fmt.Errorf("create admin identity: %w", err)
}
result.IdentityCreated = true
} else if opts.UpdatePassword {
if strings.TrimSpace(opts.Password) == "" {
return result, errors.New("admin password is required to update identity password")
}
if err := identityAdmin.UpdateIdentityPassword(ctx, identityID, opts.Password); err != nil {
return result, fmt.Errorf("update admin identity password: %w", err)
}
result.PasswordUpdated = true
}
result.IdentityID = identityID
user, err := store.FindUserByEmail(ctx, email)
if err != nil {
return result, fmt.Errorf("find local admin user: %w", err)
}
if user == nil {
if identityID == "" {
return result, errors.New("identity id is required to create local admin user")
}
user = &domain.User{
ID: identityID,
Email: email,
Name: name,
Role: domain.RoleSuperAdmin,
Status: domain.UserStatusActive,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
Metadata: domain.JSONMap{
"source": source,
},
}
if err := store.CreateUser(ctx, user); err != nil {
return result, fmt.Errorf("create local admin user: %w", err)
}
result.LocalUserCreated = true
} else if domain.NormalizeRole(user.Role) != domain.RoleSuperAdmin || user.Status != domain.UserStatusActive || (name != "" && user.Name != name) {
user, err = store.UpdateUserSuperAdmin(ctx, user.ID, name)
if err != nil {
return result, fmt.Errorf("update local admin user: %w", err)
}
result.LocalUserUpdated = true
}
result.LocalUserID = user.ID
if err := store.EnqueueSuperAdminRelation(ctx, user.ID); err != nil {
return result, fmt.Errorf("enqueue super admin keto relation: %w", err)
}
result.KetoRelationQueued = true
return result, nil
}
func buildSuperAdminBrokerUser(email, name string) *domain.BrokerUser {
return &domain.BrokerUser{
Email: email,
Name: name,
PhoneNumber: "",
Attributes: map[string]any{
"department": "Admin",
"affiliationType": "internal",
"grade": "",
"role": domain.RoleSuperAdmin,
},
}
}
type gormSuperAdminStore struct {
db *gorm.DB
outbox repository.KetoOutboxRepository
}
func NewGormSuperAdminStore(db *gorm.DB, outbox repository.KetoOutboxRepository) SuperAdminStore {
return &gormSuperAdminStore{db: db, outbox: outbox}
}
func (s *gormSuperAdminStore) FindUserByEmail(ctx context.Context, email string) (*domain.User, error) {
var user domain.User
if err := s.db.WithContext(ctx).Where("email = ?", email).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &user, nil
}
func (s *gormSuperAdminStore) CreateUser(ctx context.Context, user *domain.User) error {
return s.db.WithContext(ctx).Create(user).Error
}
func (s *gormSuperAdminStore) UpdateUserSuperAdmin(ctx context.Context, userID string, name string) (*domain.User, error) {
updates := map[string]any{
"role": domain.RoleSuperAdmin,
"status": domain.UserStatusActive,
"updated_at": time.Now(),
}
if strings.TrimSpace(name) != "" {
updates["name"] = strings.TrimSpace(name)
}
if err := s.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", userID).Updates(updates).Error; err != nil {
return nil, err
}
var user domain.User
if err := s.db.WithContext(ctx).Where("id = ?", userID).First(&user).Error; err != nil {
return nil, err
}
return &user, nil
}
func (s *gormSuperAdminStore) EnqueueSuperAdminRelation(ctx context.Context, userID string) error {
if s.outbox == nil {
return nil
}
return s.outbox.Create(ctx, &domain.KetoOutbox{
Namespace: "System",
Object: "global",
Relation: "super_admins",
Subject: "User:" + userID,
Action: domain.KetoOutboxActionCreate,
})
}

View File

@@ -0,0 +1,157 @@
package bootstrap
import (
"baron-sso-backend/internal/domain"
"context"
"errors"
"testing"
)
func TestEnsureSuperAdminCreatesIdentityLocalUserAndKetoRelation(t *testing.T) {
ctx := context.Background()
identityAdmin := &fakeSuperAdminIdentityAdmin{createdID: "identity-1"}
store := &fakeSuperAdminStore{}
result, err := EnsureSuperAdmin(ctx, identityAdmin, store, EnsureSuperAdminOptions{
Email: "new-admin@example.com",
Password: "Password!123",
Name: "New Admin",
Source: "test",
})
if err != nil {
t.Fatalf("EnsureSuperAdmin returned error: %v", err)
}
if !result.IdentityCreated {
t.Fatal("identity must be created")
}
if !result.LocalUserCreated {
t.Fatal("local user must be created")
}
if result.IdentityID != "identity-1" {
t.Fatalf("identity ID = %q, want identity-1", result.IdentityID)
}
if store.user == nil {
t.Fatal("local user was not stored")
}
if store.user.Email != "new-admin@example.com" {
t.Fatalf("local user email = %q", store.user.Email)
}
if store.user.Role != domain.RoleSuperAdmin {
t.Fatalf("local user role = %q, want %q", store.user.Role, domain.RoleSuperAdmin)
}
if len(store.ketoSubjects) != 1 || store.ketoSubjects[0] != "User:identity-1" {
t.Fatalf("keto subjects = %#v, want User:identity-1", store.ketoSubjects)
}
if identityAdmin.createdUser == nil || identityAdmin.createdUser.Attributes["role"] != domain.RoleSuperAdmin {
t.Fatalf("created identity attributes = %#v", identityAdmin.createdUser)
}
}
func TestEnsureSuperAdminPromotesExistingLocalUser(t *testing.T) {
ctx := context.Background()
identityAdmin := &fakeSuperAdminIdentityAdmin{existingID: "identity-1"}
store := &fakeSuperAdminStore{
user: &domain.User{
ID: "local-user-1",
Email: "existing@example.com",
Name: "Existing",
Role: domain.RoleUser,
Status: domain.UserStatusPreboarding,
},
}
result, err := EnsureSuperAdmin(ctx, identityAdmin, store, EnsureSuperAdminOptions{
Email: "existing@example.com",
Password: "Password!123",
Name: "Existing Admin",
Source: "test",
})
if err != nil {
t.Fatalf("EnsureSuperAdmin returned error: %v", err)
}
if result.IdentityCreated {
t.Fatal("existing identity must not be recreated")
}
if !result.LocalUserUpdated {
t.Fatal("local user must be promoted")
}
if store.user.Role != domain.RoleSuperAdmin {
t.Fatalf("local user role = %q, want %q", store.user.Role, domain.RoleSuperAdmin)
}
if store.user.Status != domain.UserStatusActive {
t.Fatalf("local user status = %q, want %q", store.user.Status, domain.UserStatusActive)
}
if len(store.ketoSubjects) != 1 || store.ketoSubjects[0] != "User:local-user-1" {
t.Fatalf("keto subjects = %#v, want User:local-user-1", store.ketoSubjects)
}
}
func TestEnsureSuperAdminRequiresPasswordForNewIdentity(t *testing.T) {
_, err := EnsureSuperAdmin(context.Background(), &fakeSuperAdminIdentityAdmin{}, &fakeSuperAdminStore{}, EnsureSuperAdminOptions{
Email: "new-admin@example.com",
Name: "New Admin",
})
if err == nil {
t.Fatal("expected error")
}
}
type fakeSuperAdminIdentityAdmin struct {
existingID string
createdID string
createdUser *domain.BrokerUser
createdSecret string
}
func (f *fakeSuperAdminIdentityAdmin) FindIdentityIDByIdentifier(ctx context.Context, identifier string) (string, error) {
return f.existingID, nil
}
func (f *fakeSuperAdminIdentityAdmin) CreateUser(ctx context.Context, user *domain.BrokerUser, password string) (string, error) {
if f.createdID == "" {
return "", errors.New("created id is not configured")
}
f.createdUser = user
f.createdSecret = password
return f.createdID, nil
}
func (f *fakeSuperAdminIdentityAdmin) UpdateIdentityPassword(ctx context.Context, identityID string, newPassword string) error {
return nil
}
type fakeSuperAdminStore struct {
user *domain.User
ketoSubjects []string
}
func (f *fakeSuperAdminStore) FindUserByEmail(ctx context.Context, email string) (*domain.User, error) {
if f.user == nil {
return nil, nil
}
return f.user, nil
}
func (f *fakeSuperAdminStore) CreateUser(ctx context.Context, user *domain.User) error {
copied := *user
f.user = &copied
return nil
}
func (f *fakeSuperAdminStore) UpdateUserSuperAdmin(ctx context.Context, userID string, name string) (*domain.User, error) {
if f.user == nil {
return nil, errors.New("user not found")
}
f.user.Role = domain.RoleSuperAdmin
f.user.Status = domain.UserStatusActive
if name != "" {
f.user.Name = name
}
return f.user, nil
}
func (f *fakeSuperAdminStore) EnqueueSuperAdminRelation(ctx context.Context, userID string) error {
f.ketoSubjects = append(f.ketoSubjects, "User:"+userID)
return nil
}

View File

@@ -0,0 +1,116 @@
package bootstrap
import (
"baron-sso-backend/internal/domain"
"fmt"
"log/slog"
"gorm.io/gorm"
)
// Run executes the application bootstrap logic (migrations, seeding, etc.)
func Run(db *gorm.DB) error {
slog.Info("[Bootstrap] Starting application bootstrap...")
// 1. Auto Migration
if err := migrateSchemas(db); err != nil {
return fmt.Errorf("migration failed: %w", err)
}
// 2. Seed Tenants
if err := SeedTenants(db); err != nil {
return fmt.Errorf("tenant seeding failed: %w", err)
}
// 3. Normalize staging seed/read-model data
if err := CanonicalizeLegacyUserStatuses(db); err != nil {
return fmt.Errorf("legacy user status canonicalization failed: %w", err)
}
if err := SanitizeLegacyUserMetadata(db); err != nil {
return fmt.Errorf("legacy user metadata sanitize failed: %w", err)
}
slog.Info("[Bootstrap] User seed skipped (Kratos is SoT)")
slog.Info("[Bootstrap] Bootstrap completed successfully.")
return nil
}
func migrateSchemas(db *gorm.DB) error {
slog.Info("[Bootstrap] Migrating database schemas...")
if err := dropLegacyTenantDomainUniqueIndex(db); err != nil {
return err
}
if err := dropLegacyUserCompanyColumns(db); err != nil {
return err
}
// Add all domain models here
return db.AutoMigrate(
&domain.Tenant{},
&domain.TenantDomain{},
&domain.User{},
&domain.UserLoginID{},
&domain.UserProjectionState{},
&domain.UserGroup{},
&domain.ApiKey{},
&domain.IdentityProviderConfig{},
&domain.ClientSecret{},
&domain.ClientConsent{},
&domain.KetoOutbox{},
&domain.RPUsageEvent{},
&domain.WorksmobileOutbox{},
&domain.WorksmobileResourceMapping{},
&domain.SharedLink{},
&domain.DeveloperRequest{},
&domain.RPUserMetadata{},
&domain.SystemSetting{},
// &domain.RelyingParty{}, // Removed: SSOT is Hydra + Keto
)
}
func CanonicalizeLegacyUserStatuses(db *gorm.DB) error {
if db == nil || !db.Migrator().HasTable(&domain.User{}) {
return nil
}
updates := map[string]string{
"inactive": domain.UserStatusPreboarding,
"leave_of_absence": domain.UserStatusTemporaryLeave,
"baron_only": domain.UserStatusBaronGuest,
}
for legacy, canonical := range updates {
if err := db.Model(&domain.User{}).
Where("status = ?", legacy).
Update("status", canonical).Error; err != nil {
return fmt.Errorf("failed to canonicalize users.status %s to %s: %w", legacy, canonical, err)
}
}
return nil
}
func dropLegacyUserCompanyColumns(db *gorm.DB) error {
if !db.Migrator().HasTable(&domain.User{}) {
return nil
}
for _, column := range []string{"company_code", "company_codes"} {
if !db.Migrator().HasColumn(&domain.User{}, column) {
continue
}
if err := db.Migrator().DropColumn(&domain.User{}, column); err != nil {
return fmt.Errorf("failed to drop legacy users.%s column: %w", column, err)
}
}
return nil
}
func dropLegacyTenantDomainUniqueIndex(db *gorm.DB) error {
if !db.Migrator().HasTable(&domain.TenantDomain{}) {
return nil
}
if !db.Migrator().HasIndex(&domain.TenantDomain{}, "idx_tenant_domains_domain") {
return nil
}
if err := db.Migrator().DropIndex(&domain.TenantDomain{}, "idx_tenant_domains_domain"); err != nil {
return fmt.Errorf("failed to drop legacy tenant domain unique index: %w", err)
}
return nil
}

View File

@@ -0,0 +1,84 @@
package bootstrap
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/repository"
"context"
"log/slog"
"gorm.io/gorm"
)
// SyncKetoRelations synchronizes all existing DB users, tenants and RPs to Ory Keto via Outbox.
// This ensures data consistency for existing data when ReBAC is introduced.
func SyncKetoRelations(db *gorm.DB, outbox repository.KetoOutboxRepository) error {
slog.Info("🚀 Starting Keto ReBAC relation synchronization (via Outbox)...")
ctx := context.Background()
// 1. Sync All Tenants
var tenants []domain.Tenant
if err := db.Find(&tenants).Error; err != nil {
return err
}
slog.Info("Syncing tenants to Keto Outbox", "count", len(tenants))
for _, t := range tenants {
// Global Super Admin access to every tenant
_ = outbox.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: t.ID,
Relation: "admins",
Subject: "System:global#super_admins",
Action: domain.KetoOutboxActionCreate,
})
if t.ParentID != nil {
_ = outbox.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: t.ID,
Relation: "parents",
Subject: "Tenant:" + *t.ParentID,
Action: domain.KetoOutboxActionCreate,
})
}
}
// 2. Sync All RelyingParties (if needed)
// Note: We'll need a way to list them from Hydra or local DB if we had them.
// Assuming they are in a table domain.RelyingParty (though it was removed, let's see)
// Actually, the comment said SSOT is Hydra. But we might have them in a local table for metadata.
// If not, we skip for now or fetch from Hydra.
// 3. Sync All Users Roles and Tenant Memberships
var users []domain.User
if err := db.Find(&users).Error; err != nil {
return err
}
slog.Info("Syncing users to Keto Outbox", "count", len(users))
for _, u := range users {
// Tenant Membership
if u.TenantID != nil {
_ = outbox.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: *u.TenantID,
Relation: "members",
Subject: "User:" + u.ID,
Action: domain.KetoOutboxActionCreate,
})
}
// Roles
role := domain.NormalizeRole(u.Role)
if role == domain.RoleSuperAdmin {
_ = outbox.Create(ctx, &domain.KetoOutbox{
Namespace: "System",
Object: "global",
Relation: "super_admins",
Subject: "User:" + u.ID,
Action: domain.KetoOutboxActionCreate,
})
}
}
slog.Info("✅ Keto ReBAC synchronization items added to Outbox.")
return nil
}

View File

@@ -0,0 +1,75 @@
package bootstrap
import (
"baron-sso-backend/internal/domain"
"log/slog"
"os"
"strings"
"time"
)
// SeedAdminIdentity creates the initial admin identity in the configured IDP.
// Returns the Kratos Identity ID and error.
func SeedAdminIdentity(idp domain.IdentityProvider) (string, error) {
if idp == nil {
return "", nil
}
adminEmail := strings.TrimSpace(os.Getenv("ADMIN_EMAIL"))
adminPassword := os.Getenv("ADMIN_PASSWORD")
if adminEmail == "" || adminPassword == "" {
slog.Warn("[Bootstrap] ADMIN_EMAIL or ADMIN_PASSWORD not set. Skipping admin identity seed.")
return "", nil
}
adminName := strings.TrimSpace(os.Getenv("ADMIN_NAME"))
if adminName == "" {
adminName = "System Admin"
}
user := &domain.BrokerUser{
Email: adminEmail,
Name: adminName,
PhoneNumber: "",
Attributes: map[string]any{
"department": "Admin",
"affiliationType": "internal",
"grade": "",
"role": "super_admin", // Explicitly set role for Kratos traits
},
}
// Retry logic for Kratos connection
maxRetries := 5
var err error
var identityID string
for i := range maxRetries {
identityID, err = idp.CreateUser(user, adminPassword)
if err == nil {
slog.Info("[Bootstrap] Admin identity created in IDP", "email", adminEmail, "idp", idp.Name(), "id", identityID)
return identityID, nil
}
if strings.Contains(err.Error(), "already exists") {
slog.Info("[Bootstrap] Admin identity already exists in IDP. Attempting to retrieve ID...", "email", adminEmail)
// Try to sign in to get the identity ID
authInfo, err := idp.SignIn(adminEmail, adminPassword)
if err == nil && authInfo != nil {
slog.Info("[Bootstrap] Retrieved existing admin identity ID", "id", authInfo.Subject)
return authInfo.Subject, nil
}
slog.Warn("[Bootstrap] Failed to retrieve existing admin identity ID via SignIn", "error", err)
return "", nil // Return nil error to avoid stopping bootstrap, but ID is missing
}
slog.Warn("[Bootstrap] Failed to seed admin identity (retrying...)",
"attempt", i+1,
"max_retries", maxRetries,
"error", err,
)
time.Sleep(2 * time.Second)
}
return "", err
}

View File

@@ -0,0 +1,77 @@
package bootstrap
import (
"baron-sso-backend/internal/domain"
"log/slog"
"os"
"strings"
"time"
"gorm.io/gorm"
)
// SyncAdminRole updates the role of the admin user in the local DB.
// It ensures the admin user exists in the local DB with the correct Kratos ID.
func SyncAdminRole(db *gorm.DB, kratosID string) error {
adminEmail := strings.TrimSpace(os.Getenv("ADMIN_EMAIL"))
if adminEmail == "" {
slog.Warn("[Bootstrap] ADMIN_EMAIL not set. Skipping admin role sync.")
return nil
}
adminName := strings.TrimSpace(os.Getenv("ADMIN_NAME"))
if adminName == "" {
adminName = "System Admin"
}
// Find user by email
var user domain.User
if err := db.Where("email = ?", adminEmail).First(&user).Error; err != nil {
if err == gorm.ErrRecordNotFound {
if kratosID == "" {
slog.Warn("[Bootstrap] Admin user not found in local DB and Kratos ID is missing. Cannot create local user.", "email", adminEmail)
return nil
}
// Create new admin user in local DB
newUser := domain.User{
ID: kratosID,
Email: adminEmail,
Name: adminName,
Role: domain.RoleSuperAdmin,
Status: "active",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
Metadata: domain.JSONMap{"source": "bootstrap_seed"},
}
if err := db.Create(&newUser).Error; err != nil {
return err
}
slog.Info("[Bootstrap] Created admin user in local DB", "email", adminEmail, "id", kratosID)
return nil
}
return err
}
// Update role if needed
updates := map[string]any{}
if user.Role != domain.RoleSuperAdmin {
updates["role"] = domain.RoleSuperAdmin
}
// Also ensure ID matches if it was somehow different (though changing PK is hard, at least log it)
if kratosID != "" && user.ID != kratosID {
slog.Warn("[Bootstrap] Admin user exists but ID mismatch with Kratos", "local_id", user.ID, "kratos_id", kratosID)
// We generally don't change UUID PKs, just warn.
}
if len(updates) > 0 {
if err := db.Model(&user).Updates(updates).Error; err != nil {
return err
}
slog.Info("[Bootstrap] Updated admin user role to super_admin", "email", adminEmail)
} else {
slog.Info("[Bootstrap] Admin user already has super_admin role", "email", adminEmail)
}
return nil
}

View File

@@ -0,0 +1,524 @@
package bootstrap
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/repository"
"baron-sso-backend/internal/service"
"baron-sso-backend/internal/utils"
"bytes"
"context"
"encoding/csv"
"errors"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"strings"
"gorm.io/gorm"
)
const seedTenantCSVPathEnv = "SEED_TENANT_CSV_PATH"
var seedTenantCSVPathCandidates = []string{
"adminfront/seed-tenant.csv",
"../adminfront/seed-tenant.csv",
"../../adminfront/seed-tenant.csv",
"../../../adminfront/seed-tenant.csv",
"/app/adminfront/seed-tenant.csv",
}
type InitialTenantConfig struct {
TenantID string
Name string
Slug string
Type string
ParentSlug string
Description string
Domains []string
Config domain.JSONMap
}
func SeedTenants(db *gorm.DB) error {
slog.Info("[Bootstrap] Checking initial tenant seed...")
configs, err := loadSeedTenantConfigs()
if err != nil {
return err
}
if len(configs) == 0 {
return errors.New("seed tenant csv has no tenant rows")
}
existingSlugs, existingIDs, err := loadExistingTenantIdentitySet(db)
if err != nil {
return err
}
missingConfigs := filterMissingSeedTenantConfigs(configs, existingSlugs, existingIDs)
if len(missingConfigs) == 0 {
slog.Info("[Bootstrap] Tenant seed skipped because all seed slugs already exist", "count", len(configs))
return nil
}
slog.Info(
"[Bootstrap] Tenant seed will create missing seed tenants",
"total", len(configs),
"missing", len(missingConfigs),
"existing", len(configs)-len(missingConfigs),
)
return seedTenantConfigs(db, missingConfigs)
}
func loadExistingTenantIdentitySet(db *gorm.DB) (map[string]bool, map[string]bool, error) {
var tenants []domain.Tenant
if err := db.Select("id", "slug").Find(&tenants).Error; err != nil {
return nil, nil, fmt.Errorf("load existing tenants before seed: %w", err)
}
slugs := make(map[string]bool, len(tenants))
ids := make(map[string]bool, len(tenants))
for _, tenant := range tenants {
slug := strings.TrimSpace(strings.ToLower(tenant.Slug))
if slug != "" {
slugs[slug] = true
}
id := strings.TrimSpace(strings.ToLower(tenant.ID))
if id != "" {
ids[id] = true
}
}
return slugs, ids, nil
}
func filterMissingSeedTenantConfigs(configs []InitialTenantConfig, existingSlugs map[string]bool, existingIDs map[string]bool) []InitialTenantConfig {
filtered := make([]InitialTenantConfig, 0, len(configs))
for _, config := range configs {
slug := strings.TrimSpace(strings.ToLower(config.Slug))
id := strings.TrimSpace(strings.ToLower(config.TenantID))
if slug == "" || existingSlugs[slug] || (id != "" && existingIDs[id]) {
continue
}
filtered = append(filtered, config)
existingSlugs[slug] = true
if id != "" {
existingIDs[id] = true
}
}
return filtered
}
func seedTenantConfigs(db *gorm.DB, configs []InitialTenantConfig) error {
slog.Info("[Bootstrap] Seeding initial tenants from CSV...", "count", len(configs))
repo := repository.NewTenantRepository(db)
userRepo := repository.NewUserRepository(db)
userGroupRepo := repository.NewUserGroupRepository(db)
outboxRepo := repository.NewKetoOutboxRepository(db)
svc := service.NewTenantService(repo, userRepo, userGroupRepo, outboxRepo)
ctx := context.Background()
for _, config := range orderSeedTenantConfigsByParentSlug(configs) {
tenantType := config.Type
if tenantType == "" {
tenantType = domain.TenantTypeCompany
}
var parentID *string
if config.ParentSlug != "" {
parent, err := repo.FindBySlug(ctx, config.ParentSlug)
if err != nil || parent == nil {
if err == nil {
err = errors.New("parent tenant not found")
}
slog.Error("Failed to resolve parent tenant for seed", "slug", config.Slug, "parentSlug", config.ParentSlug, "error", err)
return fmt.Errorf("resolve parent tenant %q for seed %q: %w", config.ParentSlug, config.Slug, err)
}
parentID = &parent.ID
}
slog.Info("[Bootstrap] Creating seed tenant", "name", config.Name, "slug", config.Slug)
var tenant *domain.Tenant
var err error
if config.TenantID != "" {
tenant, err = createSeedTenant(ctx, repo, outboxRepo, config, tenantType, parentID)
} else {
tenant, err = svc.RegisterTenant(ctx, config.Name, config.Slug, tenantType, config.Description, config.Domains, parentID, "")
}
if err != nil {
slog.Error("Failed to seed tenant", "slug", config.Slug, "error", err)
return err
}
tenant.Status = domain.TenantStatusActive
if len(config.Config) > 0 {
tenant.Config = config.Config
}
if err := db.Save(tenant).Error; err != nil {
return err
}
}
return nil
}
func loadSeedTenantConfigs() ([]InitialTenantConfig, error) {
path, err := findSeedTenantCSVPath()
if err != nil {
return nil, err
}
file, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("open seed tenant csv %q: %w", path, err)
}
defer file.Close()
configs, err := parseSeedTenantCSV(file)
if err != nil {
return nil, fmt.Errorf("parse seed tenant csv %q: %w", path, err)
}
return configs, nil
}
func SeedTenantSlugSet() (map[string]bool, error) {
configs, err := loadSeedTenantConfigs()
if err != nil {
return nil, err
}
slugs := make(map[string]bool, len(configs))
for _, config := range configs {
slug := strings.TrimSpace(strings.ToLower(config.Slug))
if slug != "" {
slugs[slug] = true
}
}
return slugs, nil
}
func IsSeedTenantSlug(slug string) bool {
normalized := strings.TrimSpace(strings.ToLower(slug))
if normalized == "" {
return false
}
slugs, err := SeedTenantSlugSet()
if err != nil {
slog.Warn("[Bootstrap] Failed to load seed tenant slug set", "error", err)
return false
}
return slugs[normalized]
}
func findSeedTenantCSVPath() (string, error) {
if configured := strings.TrimSpace(os.Getenv(seedTenantCSVPathEnv)); configured != "" {
return configured, nil
}
for _, candidate := range seedTenantCSVPathCandidates {
cleaned := filepath.Clean(candidate)
if _, err := os.Stat(cleaned); err == nil {
return cleaned, nil
}
}
return "", fmt.Errorf("seed tenant csv not found; set %s or add adminfront/seed-tenant.csv", seedTenantCSVPathEnv)
}
func parseSeedTenantCSV(r io.Reader) ([]InitialTenantConfig, error) {
data, err := io.ReadAll(r)
if err != nil {
return nil, errors.New("failed to read csv")
}
data = bytes.TrimPrefix(data, []byte{0xEF, 0xBB, 0xBF})
reader := csv.NewReader(bytes.NewReader(data))
reader.FieldsPerRecord = -1
rows, err := reader.ReadAll()
if err != nil {
return nil, fmt.Errorf("invalid csv: %w", err)
}
if len(rows) == 0 {
return nil, errors.New("csv is empty")
}
header := seedTenantCSVHeaderIndex(rows[0])
for _, key := range []string{"name", "type", "slug"} {
if _, ok := header[key]; !ok {
return nil, fmt.Errorf("missing required column: %s", key)
}
}
configs := make([]InitialTenantConfig, 0, len(rows)-1)
for i, row := range rows[1:] {
if seedTenantCSVRowIsEmpty(row) {
continue
}
name := seedTenantCSVValue(row, header, "name")
if name == "" {
return nil, fmt.Errorf("row %d: name is required", i+2)
}
tenantType := normalizeSeedTenantType(seedTenantCSVValue(row, header, "type"))
if tenantType == "" {
return nil, fmt.Errorf("row %d: invalid tenant type", i+2)
}
slug := utils.GenerateSlug(seedTenantCSVValue(row, header, "slug"))
if slug == "" {
return nil, fmt.Errorf("row %d: slug is required", i+2)
}
config, err := seedTenantCSVRecordConfig(row, header)
if err != nil {
return nil, fmt.Errorf("row %d: %w", i+2, err)
}
configs = append(configs, InitialTenantConfig{
TenantID: seedTenantCSVValue(row, header, "tenant_id"),
Name: name,
Type: tenantType,
ParentSlug: seedTenantCSVValue(row, header, "parent_tenant_slug"),
Slug: slug,
Description: seedTenantCSVValue(row, header, "memo"),
Domains: splitSeedTenantCSVDomains(seedTenantCSVValue(row, header, "email_domain")),
Config: config,
})
}
return configs, nil
}
func seedTenantCSVHeaderIndex(header []string) map[string]int {
index := make(map[string]int, len(header))
aliases := map[string]string{
"id": "tenant_id",
"tenantid": "tenant_id",
"tenant_id": "tenant_id",
"name": "name",
"type": "type",
"parenttenantslug": "parent_tenant_slug",
"parent_tenant_slug": "parent_tenant_slug",
"parent_slug": "parent_tenant_slug",
"slug": "slug",
"memo": "memo",
"description": "memo",
"email-domain": "email_domain",
"emaildomain": "email_domain",
"email_domain": "email_domain",
"domain": "email_domain",
"domains": "email_domain",
"visibility": "visibility",
"public_setting": "visibility",
"publicsetting": "visibility",
"org_unit_type": "org_unit_type",
"orgunittype": "org_unit_type",
"organization_type": "org_unit_type",
"organizationtype": "org_unit_type",
"worksmobile": "worksmobile_sync",
"worksmobilesync": "worksmobile_sync",
"worksmobile_sync": "worksmobile_sync",
"works_sync": "worksmobile_sync",
"works": "worksmobile_sync",
}
for i, column := range header {
key := strings.ToLower(strings.TrimSpace(column))
key = strings.ReplaceAll(key, " ", "_")
if canonical, ok := aliases[key]; ok {
index[canonical] = i
}
}
return index
}
func seedTenantCSVValue(row []string, header map[string]int, key string) string {
idx, ok := header[key]
if !ok || idx >= len(row) {
return ""
}
return strings.TrimSpace(row[idx])
}
func seedTenantCSVRecordConfig(row []string, header map[string]int) (domain.JSONMap, error) {
config := domain.JSONMap{}
visibility := strings.TrimSpace(seedTenantCSVValue(row, header, "visibility"))
if visibility != "" {
normalizedVisibility, err := normalizeSeedTenantVisibility(visibility)
if err != nil {
return nil, err
}
config["visibility"] = normalizedVisibility
}
orgUnitType := strings.TrimSpace(seedTenantCSVValue(row, header, "org_unit_type"))
if orgUnitType != "" {
if !isAllowedSeedTenantOrgUnitType(orgUnitType) {
return nil, errors.New("orgUnitType must be one of 실, 팀, TF, TF팀, 센터, 디비전, 셀, 본부, 지역본부, 부, 임원직속")
}
config["orgUnitType"] = orgUnitType
}
if worksmobileSync := strings.TrimSpace(seedTenantCSVValue(row, header, "worksmobile_sync")); worksmobileSync != "" {
excluded, err := normalizeSeedTenantWorksmobileExcluded(worksmobileSync)
if err != nil {
return nil, err
}
config["worksmobileExcluded"] = excluded
}
if len(config) == 0 {
return nil, nil
}
return config, nil
}
func normalizeSeedTenantWorksmobileExcluded(value string) (bool, error) {
switch strings.ToLower(strings.TrimSpace(value)) {
case "", "yes", "y", "true", "1", "on", "sync", "linked", "연동":
return false, nil
case "no", "n", "false", "0", "off", "none", "excluded", "exclude", "not_sync", "not-synced", "미연동", "연동안함", "제외":
return true, nil
default:
return false, errors.New("worksmobile_sync must be yes or no")
}
}
func normalizeSeedTenantVisibility(value string) (string, error) {
visibility := strings.ToLower(strings.TrimSpace(value))
if visibility == "" || visibility == "public" {
return "public", nil
}
if visibility != "internal" && visibility != "private" {
return "", errors.New("visibility must be public, internal, or private")
}
return visibility, nil
}
func isAllowedSeedTenantOrgUnitType(value string) bool {
switch strings.TrimSpace(value) {
case "실", "팀", "TF", "TF팀", "센터", "디비전", "셀", "본부", "지역본부", "부", "임원직속":
return true
default:
return false
}
}
func seedTenantCSVRowIsEmpty(row []string) bool {
for _, value := range row {
if strings.TrimSpace(value) != "" {
return false
}
}
return true
}
func normalizeSeedTenantType(value string) string {
switch strings.ToUpper(strings.TrimSpace(value)) {
case domain.TenantTypePersonal:
return domain.TenantTypePersonal
case domain.TenantTypeCompany:
return domain.TenantTypeCompany
case domain.TenantTypeCompanyGroup:
return domain.TenantTypeCompanyGroup
case domain.TenantTypeOrganization:
return domain.TenantTypeOrganization
case domain.TenantTypeUserGroup:
return domain.TenantTypeUserGroup
default:
return ""
}
}
func splitSeedTenantCSVDomains(value string) []string {
value = strings.ReplaceAll(value, "\n", ";")
value = strings.ReplaceAll(value, ",", ";")
parts := strings.Split(value, ";")
domains := make([]string, 0, len(parts))
seen := make(map[string]bool, len(parts))
for _, part := range parts {
domainName := strings.ToLower(strings.TrimSpace(part))
if domainName == "" || seen[domainName] {
continue
}
seen[domainName] = true
domains = append(domains, domainName)
}
return domains
}
func orderSeedTenantConfigsByParentSlug(configs []InitialTenantConfig) []InitialTenantConfig {
bySlug := make(map[string]InitialTenantConfig, len(configs))
for _, config := range configs {
bySlug[strings.ToLower(config.Slug)] = config
}
ordered := make([]InitialTenantConfig, 0, len(configs))
visited := make(map[string]bool, len(configs))
var visit func(config InitialTenantConfig)
visit = func(config InitialTenantConfig) {
key := strings.ToLower(config.Slug)
if visited[key] {
return
}
if config.ParentSlug != "" {
if parent, ok := bySlug[strings.ToLower(config.ParentSlug)]; ok {
visit(parent)
}
}
visited[key] = true
ordered = append(ordered, config)
}
for _, config := range configs {
visit(config)
}
return ordered
}
func createSeedTenant(
ctx context.Context,
repo repository.TenantRepository,
outboxRepo repository.KetoOutboxRepository,
config InitialTenantConfig,
tenantType string,
parentID *string,
) (*domain.Tenant, error) {
tenant := &domain.Tenant{
ID: config.TenantID,
Type: tenantType,
Name: config.Name,
Slug: config.Slug,
Description: config.Description,
Status: domain.TenantStatusActive,
ParentID: parentID,
Config: config.Config,
}
if err := repo.Create(ctx, tenant); err != nil {
return nil, err
}
if err := outboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: tenant.ID,
Relation: "admins",
Subject: "System:global#super_admins",
Action: domain.KetoOutboxActionCreate,
}); err != nil {
return nil, err
}
if tenant.ParentID != nil {
if err := outboxRepo.Create(ctx, &domain.KetoOutbox{
Namespace: "Tenant",
Object: tenant.ID,
Relation: "parents",
Subject: "Tenant:" + *tenant.ParentID,
Action: domain.KetoOutboxActionCreate,
}); err != nil {
return nil, err
}
}
for _, domainName := range config.Domains {
if err := repo.AddDomain(ctx, tenant.ID, domainName, true); err != nil {
slog.Error("Failed to add domain to seeded tenant", "tenant", config.Slug, "domain", domainName, "error", err)
}
}
return repo.FindBySlug(ctx, config.Slug)
}

View File

@@ -0,0 +1,388 @@
package bootstrap
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/testsupport"
"context"
"log"
"os"
"path/filepath"
"testing"
"time"
"github.com/testcontainers/testcontainers-go"
postgres_module "github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
gorm_postgres "gorm.io/driver/postgres"
"gorm.io/gorm"
)
func TestSeedTenantCSVDefinesWorksmobileDomainClassTenants(t *testing.T) {
configs, err := loadSeedTenantConfigs()
if err != nil {
t.Fatalf("loadSeedTenantConfigs returned error: %v", err)
}
expected := []struct {
name string
slug string
tenantType string
parentSlug string
domains []string
}{
{
name: "한맥가족",
slug: "hanmac-family",
tenantType: domain.TenantTypeCompanyGroup,
},
{
name: "삼안",
slug: "saman",
tenantType: domain.TenantTypeCompany,
parentSlug: "hanmac-family",
domains: []string{"samaneng.com"},
},
{
name: "한맥기술",
slug: "hanmac",
tenantType: domain.TenantTypeCompany,
parentSlug: "hanmac-family",
domains: []string{"hanmaceng.co.kr"},
},
{
name: "총괄기획&기술개발센터",
slug: "gpdtdc",
tenantType: domain.TenantTypeCompany,
parentSlug: "hanmac-family",
domains: []string{"baroncs.co.kr"},
},
{
name: "바론그룹",
slug: "baron-group",
tenantType: domain.TenantTypeCompanyGroup,
parentSlug: "hanmac-family",
domains: []string{"brsw.kr"},
},
{
name: "(주)장헌",
slug: "jangheon",
tenantType: domain.TenantTypeCompany,
parentSlug: "baron-group",
domains: []string{"jangheon.com"},
},
{
name: "장헌산업",
slug: "jangheon-sanup",
tenantType: domain.TenantTypeCompany,
parentSlug: "baron-group",
domains: []string{"jangheon.co.kr"},
},
{
name: "한라산업개발",
slug: "halla",
tenantType: domain.TenantTypeCompany,
parentSlug: "hanmac-family",
domains: []string{"hallasanup.com"},
},
{
name: "(주)피티씨",
slug: "ptc",
tenantType: domain.TenantTypeCompany,
parentSlug: "baron-group",
domains: []string{"pre-cast.co.kr"},
},
{
name: "Personal",
slug: "personal",
tenantType: domain.TenantTypePersonal,
},
}
if len(configs) < len(expected) {
t.Fatalf("expected at least %d seed tenants, got %d", len(expected), len(configs))
}
wantFamilyChildOrder := []string{
"gpdtdc",
"saman",
"hanmac",
"baron-group",
"halla",
}
policyFamilyChildSlugs := map[string]bool{}
for _, slug := range wantFamilyChildOrder {
policyFamilyChildSlugs[slug] = true
}
gotFamilyChildOrder := make([]string, 0, len(wantFamilyChildOrder))
for _, config := range configs {
if config.ParentSlug == "hanmac-family" && policyFamilyChildSlugs[config.Slug] {
gotFamilyChildOrder = append(gotFamilyChildOrder, config.Slug)
}
}
if len(gotFamilyChildOrder) != len(wantFamilyChildOrder) {
t.Fatalf("hanmac-family child order = %#v, want %#v", gotFamilyChildOrder, wantFamilyChildOrder)
}
for i, wantSlug := range wantFamilyChildOrder {
if gotFamilyChildOrder[i] != wantSlug {
t.Fatalf("hanmac-family child order[%d] = %q, want %q", i, gotFamilyChildOrder[i], wantSlug)
}
}
configBySlug := make(map[string]InitialTenantConfig, len(configs))
for _, config := range configs {
configBySlug[config.Slug] = config
}
for _, want := range expected {
got, ok := configBySlug[want.slug]
if !ok {
t.Fatalf("tenant slug %q not found in seed configs", want.slug)
}
if got.Name != want.name {
t.Fatalf("tenant[%s] name = %q, want %q", want.slug, got.Name, want.name)
}
if got.Slug != want.slug {
t.Fatalf("tenant[%s] slug = %q, want %q", want.slug, got.Slug, want.slug)
}
if got.Type != want.tenantType {
t.Fatalf("tenant[%s] type = %q, want %q", want.slug, got.Type, want.tenantType)
}
if got.ParentSlug != want.parentSlug {
t.Fatalf("tenant[%s] parent slug = %q, want %q", want.slug, got.ParentSlug, want.parentSlug)
}
if len(got.Domains) != len(want.domains) {
t.Fatalf("tenant[%s] domains = %#v, want %#v", want.slug, got.Domains, want.domains)
}
for j, wantDomain := range want.domains {
if got.Domains[j] != wantDomain {
t.Fatalf("tenant[%s] domain[%d] = %q, want %q", want.slug, j, got.Domains[j], wantDomain)
}
}
}
}
func TestNormalizeSeedTenantTypeAllowsOrganization(t *testing.T) {
if got := normalizeSeedTenantType("organization"); got != domain.TenantTypeOrganization {
t.Fatalf("normalizeSeedTenantType(organization) = %q, want %q", got, domain.TenantTypeOrganization)
}
}
func TestLoadSeedTenantConfigsUsesConfiguredCSVPath(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "seed-tenant.csv")
csv := "name,type,parent_tenant_slug,slug,memo,email_domain,visibility,org_unit_type,worksmobile_sync\n" +
"Root,COMPANY_GROUP,,root,Root memo,,,,\n" +
"Child,USER_GROUP,root,child,Child memo,child.example.com,private,팀,no\n"
if err := os.WriteFile(path, []byte(csv), 0o600); err != nil {
t.Fatalf("failed to write seed csv: %v", err)
}
t.Setenv(seedTenantCSVPathEnv, path)
configs, err := loadSeedTenantConfigs()
if err != nil {
t.Fatalf("loadSeedTenantConfigs returned error: %v", err)
}
if len(configs) != 2 {
t.Fatalf("expected 2 configs, got %d", len(configs))
}
if configs[1].ParentSlug != "root" {
t.Fatalf("child parent slug = %q, want root", configs[1].ParentSlug)
}
if len(configs[1].Domains) != 1 || configs[1].Domains[0] != "child.example.com" {
t.Fatalf("child domains = %#v, want child.example.com", configs[1].Domains)
}
if configs[1].Config["visibility"] != "private" {
t.Fatalf("child visibility = %#v, want private", configs[1].Config["visibility"])
}
if configs[1].Config["orgUnitType"] != "팀" {
t.Fatalf("child orgUnitType = %#v, want 팀", configs[1].Config["orgUnitType"])
}
if configs[1].Config["worksmobileExcluded"] != true {
t.Fatalf("child worksmobileExcluded = %#v, want true", configs[1].Config["worksmobileExcluded"])
}
}
func TestSeedTenantCSVDefinesMHDAsPrivateUserGroup(t *testing.T) {
configs, err := loadSeedTenantConfigs()
if err != nil {
t.Fatalf("loadSeedTenantConfigs returned error: %v", err)
}
configBySlug := make(map[string]InitialTenantConfig, len(configs))
for _, config := range configs {
configBySlug[config.Slug] = config
}
mhd, ok := configBySlug["mhd"]
if !ok {
t.Fatal("mhd seed tenant not found")
}
if mhd.Type != domain.TenantTypeUserGroup {
t.Fatalf("mhd type = %q, want %q", mhd.Type, domain.TenantTypeUserGroup)
}
if mhd.Config["visibility"] != "private" {
t.Fatalf("mhd visibility = %#v, want private", mhd.Config["visibility"])
}
if mhd.Config["worksmobileExcluded"] != true {
t.Fatalf("mhd worksmobileExcluded = %#v, want true", mhd.Config["worksmobileExcluded"])
}
}
func TestIsSeedTenantSlugUsesConfiguredCSVPath(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "seed-tenant.csv")
csv := "name,type,parent_tenant_slug,slug,memo,email_domain\n" +
"Root,COMPANY_GROUP,,protected-root,Root memo,\n"
if err := os.WriteFile(path, []byte(csv), 0o600); err != nil {
t.Fatalf("failed to write seed csv: %v", err)
}
t.Setenv(seedTenantCSVPathEnv, path)
if !IsSeedTenantSlug("protected-root") {
t.Fatal("protected-root must be detected as seed tenant")
}
if IsSeedTenantSlug("normal-tenant") {
t.Fatal("normal-tenant must not be detected as seed tenant")
}
}
func TestFilterMissingSeedTenantConfigsSkipsExistingSlugs(t *testing.T) {
configs := []InitialTenantConfig{
{TenantID: "existing-root-id", Name: "Existing Root", Slug: "existing-root"},
{Name: "Missing Child", Slug: "missing-child", ParentSlug: "existing-root"},
{TenantID: "existing-child-id", Name: "Existing Child", Slug: "existing-child", ParentSlug: "existing-root"},
{TenantID: "existing-other-id", Name: "Conflicting ID", Slug: "new-slug"},
}
existingSlugs := map[string]bool{
"existing-root": true,
"existing-child": true,
}
existingIDs := map[string]bool{
"existing-root-id": true,
"existing-child-id": true,
"existing-other-id": true,
}
filtered := filterMissingSeedTenantConfigs(configs, existingSlugs, existingIDs)
if len(filtered) != 1 {
t.Fatalf("filtered count = %d, want 1: %#v", len(filtered), filtered)
}
if filtered[0].Slug != "missing-child" {
t.Fatalf("filtered slug = %q, want missing-child", filtered[0].Slug)
}
}
func TestSeedTenantsCreatesMissingSeedRowsWithoutTouchingExistingSlugs(t *testing.T) {
if !testsupport.DockerAvailable() {
t.Skip("Docker provider is unavailable in this environment")
}
ctx := context.Background()
postgresContainer, err := postgres_module.Run(ctx,
"postgres:16-alpine",
postgres_module.WithDatabase("testdb"),
postgres_module.WithUsername("user"),
postgres_module.WithPassword("password"),
testcontainers.WithWaitStrategy(
wait.ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(30*time.Second),
),
)
if err != nil {
t.Fatalf("failed to start postgres container: %v", err)
}
t.Cleanup(func() {
if err := postgresContainer.Terminate(ctx); err != nil {
log.Printf("failed to terminate postgres container: %v", err)
}
})
connStr, err := postgresContainer.ConnectionString(ctx, "sslmode=disable")
if err != nil {
t.Fatalf("failed to get postgres connection string: %v", err)
}
db, err := gorm.Open(gorm_postgres.Open(connStr), &gorm.Config{})
if err != nil {
t.Fatalf("failed to open postgres connection: %v", err)
}
if err := db.AutoMigrate(&domain.Tenant{}, &domain.TenantDomain{}, &domain.KetoOutbox{}); err != nil {
t.Fatalf("failed to migrate seed test tables: %v", err)
}
existingRoot := domain.Tenant{
ID: "00000000-0000-0000-0000-000000000001",
Name: "Existing Root Name",
Slug: "existing-root",
Type: domain.TenantTypeCompanyGroup,
Description: "manual tenant must not be overwritten",
Status: domain.TenantStatusActive,
}
nonSeedTenant := domain.Tenant{
ID: "00000000-0000-0000-0000-000000000002",
Name: "Manual Tenant",
Slug: "manual-tenant",
Type: domain.TenantTypeCompany,
Status: domain.TenantStatusActive,
}
if err := db.Create(&existingRoot).Error; err != nil {
t.Fatalf("failed to create existing root tenant: %v", err)
}
if err := db.Create(&nonSeedTenant).Error; err != nil {
t.Fatalf("failed to create non-seed tenant: %v", err)
}
dir := t.TempDir()
path := filepath.Join(dir, "seed-tenant.csv")
csv := "id,name,type,parent_tenant_slug,slug,memo,email_domain\n" +
"10000000-0000-0000-0000-000000000001,Seed Root Name,COMPANY_GROUP,,existing-root,seed must be skipped,\n" +
"00000000-0000-0000-0000-000000000002,Conflicting ID,COMPANY,existing-root,conflicting-id,seed id must be skipped,\n" +
"10000000-0000-0000-0000-000000000002,Missing Child,COMPANY,existing-root,missing-child,created from seed,child.example.com\n"
if err := os.WriteFile(path, []byte(csv), 0o600); err != nil {
t.Fatalf("failed to write seed csv: %v", err)
}
t.Setenv(seedTenantCSVPathEnv, path)
if err := SeedTenants(db); err != nil {
t.Fatalf("SeedTenants returned error: %v", err)
}
var root domain.Tenant
if err := db.First(&root, "slug = ?", "existing-root").Error; err != nil {
t.Fatalf("failed to load existing root after seed: %v", err)
}
if root.ID != existingRoot.ID {
t.Fatalf("existing root ID = %q, want %q", root.ID, existingRoot.ID)
}
if root.Name != existingRoot.Name {
t.Fatalf("existing root name = %q, want untouched %q", root.Name, existingRoot.Name)
}
var child domain.Tenant
if err := db.Preload("Domains").First(&child, "slug = ?", "missing-child").Error; err != nil {
t.Fatalf("missing seed child was not created: %v", err)
}
if child.ParentID == nil || *child.ParentID != existingRoot.ID {
t.Fatalf("child parent ID = %v, want %q", child.ParentID, existingRoot.ID)
}
if len(child.Domains) != 1 || child.Domains[0].Domain != "child.example.com" {
t.Fatalf("child domains = %#v, want child.example.com", child.Domains)
}
var rootCount int64
if err := db.Model(&domain.Tenant{}).Where("slug = ?", "existing-root").Count(&rootCount).Error; err != nil {
t.Fatalf("failed to count existing root rows: %v", err)
}
if rootCount != 1 {
t.Fatalf("existing-root row count = %d, want 1", rootCount)
}
var conflictingIDCount int64
if err := db.Model(&domain.Tenant{}).Where("slug = ?", "conflicting-id").Count(&conflictingIDCount).Error; err != nil {
t.Fatalf("failed to count conflicting-id rows: %v", err)
}
if conflictingIDCount != 0 {
t.Fatalf("conflicting-id row count = %d, want 0", conflictingIDCount)
}
}

View File

@@ -0,0 +1,34 @@
package bootstrap
import (
"fmt"
"log/slog"
"gorm.io/gorm"
)
const sanitizeLegacyUserMetadataSQL = `
update users
set metadata = metadata - 'hanmacFamily' - 'userType',
updated_at = now()
where metadata ? 'hanmacFamily'
or metadata ? 'userType'
`
// SanitizeLegacyUserMetadata removes legacy UI classification flags from Baron user metadata.
func SanitizeLegacyUserMetadata(db *gorm.DB) error {
if db == nil {
return fmt.Errorf("database is not configured")
}
if !db.Migrator().HasTable("users") {
slog.Info("[Bootstrap] Legacy user metadata sanitize skipped because users table does not exist")
return nil
}
result := db.Exec(sanitizeLegacyUserMetadataSQL)
if result.Error != nil {
return fmt.Errorf("sanitize legacy user metadata: %w", result.Error)
}
slog.Info("[Bootstrap] Legacy user metadata sanitized", "rowsAffected", result.RowsAffected)
return nil
}

View File

@@ -0,0 +1,203 @@
package bootstrap
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/testsupport"
"context"
"log"
"os"
"path/filepath"
"testing"
"time"
"github.com/testcontainers/testcontainers-go"
postgres_module "github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
gorm_postgres "gorm.io/driver/postgres"
"gorm.io/gorm"
)
func TestSanitizeLegacyUserMetadataRemovesClassificationFlags(t *testing.T) {
db := openBootstrapPostgresTestDB(t)
if err := db.AutoMigrate(&domain.User{}); err != nil {
t.Fatalf("failed to migrate users table: %v", err)
}
user := domain.User{
ID: "10000000-0000-0000-0000-000000000001",
Email: "legacy@example.com",
Name: "Legacy User",
Role: domain.RoleUser,
Status: domain.UserStatusActive,
Metadata: domain.JSONMap{
"hanmacFamily": true,
"userType": "hanmac",
"employeeId": "E001",
"nested": map[string]any{
"userType": "must stay nested",
},
},
}
if err := db.Create(&user).Error; err != nil {
t.Fatalf("failed to create user: %v", err)
}
if err := SanitizeLegacyUserMetadata(db); err != nil {
t.Fatalf("SanitizeLegacyUserMetadata returned error: %v", err)
}
if err := SanitizeLegacyUserMetadata(db); err != nil {
t.Fatalf("SanitizeLegacyUserMetadata must be idempotent: %v", err)
}
var got domain.User
if err := db.First(&got, "id = ?", user.ID).Error; err != nil {
t.Fatalf("failed to load sanitized user: %v", err)
}
if _, ok := got.Metadata["hanmacFamily"]; ok {
t.Fatalf("hanmacFamily must be removed from metadata: %#v", got.Metadata)
}
if _, ok := got.Metadata["userType"]; ok {
t.Fatalf("userType must be removed from metadata: %#v", got.Metadata)
}
if got.Metadata["employeeId"] != "E001" {
t.Fatalf("employeeId = %#v, want E001", got.Metadata["employeeId"])
}
nested, ok := got.Metadata["nested"].(map[string]any)
if !ok || nested["userType"] != "must stay nested" {
t.Fatalf("nested metadata must be preserved: %#v", got.Metadata["nested"])
}
}
func TestCanonicalizeLegacyUserStatuses(t *testing.T) {
db := openBootstrapPostgresTestDB(t)
if err := db.AutoMigrate(&domain.User{}); err != nil {
t.Fatalf("failed to migrate users table: %v", err)
}
users := []domain.User{
{ID: "11000000-0000-0000-0000-000000000001", Email: "inactive@example.com", Name: "Inactive", Role: domain.RoleUser, Status: "inactive"},
{ID: "11000000-0000-0000-0000-000000000002", Email: "leave@example.com", Name: "Leave", Role: domain.RoleUser, Status: "leave_of_absence"},
{ID: "11000000-0000-0000-0000-000000000003", Email: "baron-only@example.com", Name: "Baron Only", Role: domain.RoleUser, Status: "baron_only"},
{ID: "11000000-0000-0000-0000-000000000004", Email: "active@example.com", Name: "Active", Role: domain.RoleUser, Status: domain.UserStatusActive},
}
if err := db.Create(&users).Error; err != nil {
t.Fatalf("failed to create users: %v", err)
}
if err := CanonicalizeLegacyUserStatuses(db); err != nil {
t.Fatalf("CanonicalizeLegacyUserStatuses returned error: %v", err)
}
if err := CanonicalizeLegacyUserStatuses(db); err != nil {
t.Fatalf("CanonicalizeLegacyUserStatuses must be idempotent: %v", err)
}
got := map[string]string{}
var loaded []domain.User
if err := db.Find(&loaded).Error; err != nil {
t.Fatalf("failed to load users: %v", err)
}
for _, user := range loaded {
got[user.Email] = user.Status
}
if got["inactive@example.com"] != domain.UserStatusPreboarding {
t.Fatalf("inactive status = %q, want %q", got["inactive@example.com"], domain.UserStatusPreboarding)
}
if got["leave@example.com"] != domain.UserStatusTemporaryLeave {
t.Fatalf("leave status = %q, want %q", got["leave@example.com"], domain.UserStatusTemporaryLeave)
}
if got["baron-only@example.com"] != domain.UserStatusBaronGuest {
t.Fatalf("baron_only status = %q, want %q", got["baron-only@example.com"], domain.UserStatusBaronGuest)
}
if got["active@example.com"] != domain.UserStatusActive {
t.Fatalf("active status = %q, want %q", got["active@example.com"], domain.UserStatusActive)
}
}
func TestRunSanitizesLegacyUserMetadata(t *testing.T) {
db := openBootstrapPostgresTestDB(t)
if err := db.AutoMigrate(&domain.User{}); err != nil {
t.Fatalf("failed to migrate users table: %v", err)
}
user := domain.User{
ID: "20000000-0000-0000-0000-000000000001",
Email: "run-legacy@example.com",
Name: "Run Legacy User",
Role: domain.RoleUser,
Status: domain.UserStatusActive,
Metadata: domain.JSONMap{
"hanmacFamily": true,
"userType": "external",
"employeeId": "E002",
},
}
if err := db.Create(&user).Error; err != nil {
t.Fatalf("failed to create user: %v", err)
}
dir := t.TempDir()
path := filepath.Join(dir, "seed-tenant.csv")
csv := "id,name,type,parent_tenant_slug,slug,memo,email_domain\n" +
"30000000-0000-0000-0000-000000000001,Seed Root,COMPANY_GROUP,,seed-root,seed root,\n"
if err := os.WriteFile(path, []byte(csv), 0o600); err != nil {
t.Fatalf("failed to write seed csv: %v", err)
}
t.Setenv(seedTenantCSVPathEnv, path)
if err := Run(db); err != nil {
t.Fatalf("Run returned error: %v", err)
}
var got domain.User
if err := db.First(&got, "id = ?", user.ID).Error; err != nil {
t.Fatalf("failed to load sanitized user: %v", err)
}
if _, ok := got.Metadata["hanmacFamily"]; ok {
t.Fatalf("Run must remove hanmacFamily from metadata: %#v", got.Metadata)
}
if _, ok := got.Metadata["userType"]; ok {
t.Fatalf("Run must remove userType from metadata: %#v", got.Metadata)
}
if got.Metadata["employeeId"] != "E002" {
t.Fatalf("employeeId = %#v, want E002", got.Metadata["employeeId"])
}
}
func openBootstrapPostgresTestDB(t *testing.T) *gorm.DB {
t.Helper()
if !testsupport.DockerAvailable() {
t.Skip("Docker provider is unavailable in this environment")
}
ctx := context.Background()
postgresContainer, err := postgres_module.Run(ctx,
"postgres:16-alpine",
postgres_module.WithDatabase("testdb"),
postgres_module.WithUsername("user"),
postgres_module.WithPassword("password"),
testcontainers.WithWaitStrategy(
wait.ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(30*time.Second),
),
)
if err != nil {
t.Fatalf("failed to start postgres container: %v", err)
}
t.Cleanup(func() {
if err := postgresContainer.Terminate(ctx); err != nil {
log.Printf("failed to terminate postgres container: %v", err)
}
})
connStr, err := postgresContainer.ConnectionString(ctx, "sslmode=disable")
if err != nil {
t.Fatalf("failed to get postgres connection string: %v", err)
}
db, err := gorm.Open(gorm_postgres.Open(connStr), &gorm.Config{})
if err != nil {
t.Fatalf("failed to open postgres connection: %v", err)
}
return db
}

View File

@@ -0,0 +1,30 @@
package domain
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// ApiKey represents an internal API key for Machine-to-Machine communication.
type ApiKey struct {
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
Name string `gorm:"not null" json:"name"`
ClientID string `gorm:"uniqueIndex;not null" json:"clientId"`
ClientSecretHash string `gorm:"not null" json:"-"`
Scopes string `json:"scopes"` // Space or comma separated
Status string `gorm:"default:'active'" json:"status"`
LastUsedAt *time.Time `json:"lastUsedAt"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
}
// BeforeCreate hook to generate UUID if not present.
func (k *ApiKey) BeforeCreate(tx *gorm.DB) (err error) {
if k.ID == "" {
k.ID = uuid.NewString()
}
return
}

View File

@@ -0,0 +1,123 @@
package domain
type EnchantedLinkInitRequest struct {
LoginID string `json:"loginId"`
URI string `json:"uri,omitempty"` // Redirect URI (optional for polling flow)
Method string `json:"method,omitempty"` // "email" or "sms"
CodeOnly bool `json:"codeOnly,omitempty"`
DryRun bool `json:"dryRun,omitempty"`
DrySend bool `json:"drySend,omitempty"`
}
type EnchantedLinkInitResponse struct {
LinkID string `json:"linkId"`
PendingRef string `json:"pendingRef"`
MaskedEmail string `json:"maskedEmail"`
}
type EnchantedLinkPollRequest struct {
PendingRef string `json:"pendingRef"`
}
type EnchantedLinkPollResponse struct {
SessionToken string `json:"sessionToken"` // JWT
RefreshToken string `json:"refreshToken"`
UserID string `json:"userId,omitempty"`
}
type MagicLinkVerifyRequest struct {
Token string `json:"token"`
VerifyOnly bool `json:"verifyOnly,omitempty"`
}
type QRInitResponse struct {
QRCode string `json:"qrCode"` // Base64 or URL
PendingRef string `json:"pendingRef"`
ExpiresIn int `json:"expiresIn"`
}
// Signup Flow Models
type CheckEmailRequest struct {
Email string `json:"email"`
}
type SendSignupCodeRequest struct {
Target string `json:"target"` // Email or Phone
Type string `json:"type"` // "email" or "phone"
}
type VerifySignupCodeRequest struct {
Target string `json:"target"` // Email or Phone
Type string `json:"type"` // "email" or "phone"
Code string `json:"code"`
}
type SignupRequest struct {
Email string `json:"email"`
LoginID string `json:"loginId,omitempty"`
Password string `json:"password"`
Name string `json:"name"`
Phone string `json:"phone"`
AffiliationType string `json:"affiliationType"` // "AFFILIATE" or "GENERAL"
TenantSlug string `json:"tenantSlug,omitempty"`
CompanyCode string `json:"companyCode,omitempty"`
Department string `json:"department"`
Metadata JSONMap `json:"metadata,omitempty"`
TermsAccepted bool `json:"termsAccepted"`
}
// User Profile Models
type UserProfileResponse struct {
ID string `json:"id"`
Email string `json:"email"`
LoginID string `json:"loginId,omitempty"`
Name string `json:"name"`
Phone string `json:"phone"`
Role string `json:"role"` // 추가
SessionAuthenticatedAt string `json:"sessionAuthenticatedAt,omitempty"`
Department string `json:"department"`
AffiliationType string `json:"affiliationType"`
CompanyCode string `json:"companyCode,omitempty"`
TenantID *string `json:"tenantId,omitempty"` // 추가
SessionTenantID *string `json:"sessionTenantId,omitempty"` // [New] 로그인에 사용된 식별자 기반 테넌트
RelyingPartyID *string `json:"relyingPartyId,omitempty"` // 추가
Metadata map[string]any `json:"metadata,omitempty"`
Tenant *Tenant `json:"tenant,omitempty"`
ManageableTenants []Tenant `json:"manageableTenants,omitempty"` // 추가: 관리 가능한 테넌트 목록
JoinedTenants []Tenant `json:"joinedTenants,omitempty"` // [New] 다중 소속 테넌트 목록
}
type UpdateUserRequest struct {
Name string `json:"name"`
Phone string `json:"phone"`
Department string `json:"department"`
VerificationCode string `json:"verificationCode,omitempty"` // For phone change
Metadata map[string]any `json:"metadata,omitempty"`
}
// PasswordResetInitiateRequest is the request body for initiating a password reset.
type PasswordResetInitiateRequest struct {
LoginID string `json:"loginId"`
DryRun bool `json:"dryRun,omitempty"`
DrySend bool `json:"drySend,omitempty"`
}
// PasswordResetCompleteRequest is the request body for completing a password reset.
type PasswordResetCompleteRequest struct {
LoginID string `json:"loginId"`
NewPassword string `json:"newPassword"`
}
// PasswordChangeRequest는 로그인 상태에서 비밀번호 변경 요청을 표현합니다.
type PasswordChangeRequest struct {
CurrentPassword string `json:"currentPassword"`
NewPassword string `json:"newPassword"`
}
type CheckLoginIDRequest struct {
LoginID string `json:"loginId"`
TenantSlug string `json:"tenantSlug,omitempty"`
CompanyCode string `json:"companyCode,omitempty"`
}

View File

@@ -0,0 +1,33 @@
package domain
import (
"time"
"github.com/google/uuid"
"github.com/lib/pq"
"gorm.io/gorm"
)
type ClientConsent struct {
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
ClientID string `gorm:"index;uniqueIndex:idx_client_subject;not null" json:"clientId"`
Subject string `gorm:"index;uniqueIndex:idx_client_subject;not null" json:"subject"` // User UUID
GrantedScopes pq.StringArray `gorm:"type:text[];not null" json:"grantedScopes"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
}
// ClientConsentWithTenantInfo is a struct to hold joined data for API responses
type ClientConsentWithTenantInfo struct {
ClientConsent
TenantID string `gorm:"column:tenant_id" json:"tenantId"`
TenantName string `gorm:"column:tenant_name" json:"tenantName"`
}
func (c *ClientConsent) BeforeCreate(tx *gorm.DB) (err error) {
if c.ID == "" {
c.ID = uuid.New().String()
}
return
}

View File

@@ -0,0 +1,21 @@
package domain
import (
"context"
"time"
)
// ClientSecret represents the stored client secret for OIDC clients.
// Since Hydra only returns the secret once during creation, we store it here.
type ClientSecret struct {
ClientID string `gorm:"primaryKey;column:client_id"`
ClientSecret string `gorm:"column:client_secret;not null"`
CreatedAt time.Time `gorm:"column:created_at"`
UpdatedAt time.Time `gorm:"column:updated_at"`
}
type ClientSecretRepository interface {
Upsert(ctx context.Context, clientID, secret string) error
GetByID(ctx context.Context, clientID string) (string, error)
Delete(ctx context.Context, clientID string) error
}

View File

@@ -0,0 +1,60 @@
package domain
import "time"
type DataIntegrityStatus string
const (
DataIntegrityStatusPass DataIntegrityStatus = "pass"
DataIntegrityStatusWarning DataIntegrityStatus = "warning"
DataIntegrityStatusFail DataIntegrityStatus = "fail"
)
type DataIntegrityReport struct {
Status DataIntegrityStatus `json:"status"`
CheckedAt time.Time `json:"checkedAt"`
Summary DataIntegritySummary `json:"summary"`
Sections []DataIntegritySection `json:"sections"`
}
type DataIntegritySummary struct {
TotalChecks int `json:"totalChecks"`
Passed int `json:"passed"`
Warnings int `json:"warnings"`
Failures int64 `json:"failures"`
}
type DataIntegritySection struct {
Key string `json:"key"`
Label string `json:"label"`
Status DataIntegrityStatus `json:"status"`
Checks []DataIntegrityCheck `json:"checks"`
}
type DataIntegrityCheck struct {
Key string `json:"key"`
Label string `json:"label"`
Description string `json:"description"`
Status DataIntegrityStatus `json:"status"`
Severity string `json:"severity"`
Count int64 `json:"count"`
}
type OrphanUserLoginID struct {
ID string `json:"id"`
UserID string `json:"userId"`
UserEmail string `json:"userEmail,omitempty"`
UserDeletedAt *time.Time `json:"userDeletedAt,omitempty"`
TenantID string `json:"tenantId"`
TenantSlug string `json:"tenantSlug,omitempty"`
TenantDeletedAt *time.Time `json:"tenantDeletedAt,omitempty"`
FieldKey string `json:"fieldKey"`
LoginID string `json:"loginId"`
Reasons []string `json:"reasons"`
}
type DeleteOrphanUserLoginIDsResult struct {
DeletedCount int64 `json:"deletedCount"`
Deleted []OrphanUserLoginID `json:"deleted"`
SkippedIDs []string `json:"skippedIds"`
}

View File

@@ -0,0 +1,29 @@
package domain
import (
"time"
)
const (
DeveloperRequestStatusPending = "pending"
DeveloperRequestStatusApproved = "approved"
DeveloperRequestStatusRejected = "rejected"
DeveloperRequestStatusCancelled = "cancelled"
)
// DeveloperRequest represents a user's application to become a developer.
type DeveloperRequest struct {
ID uint `gorm:"primaryKey" json:"id"`
UserID string `gorm:"index;not null" json:"userId"` // Kratos User ID
TenantID string `gorm:"index;not null" json:"tenantId"`
Name string `gorm:"not null" json:"name"`
Organization string `json:"organization"`
Email string `json:"email"`
Phone string `json:"phone"`
Role string `json:"role"`
Reason string `json:"reason"`
Status string `gorm:"default:'pending';not null" json:"status"` // pending, approved, rejected, cancelled
AdminNotes string `json:"adminNotes"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}

View File

@@ -0,0 +1,6 @@
package domain
// EmailService defines the interface for sending emails.
type EmailService interface {
SendEmail(to, subject, body string) error
}

View File

@@ -0,0 +1,50 @@
package domain
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// ProviderType defines the type of the identity provider.
type ProviderType string
const (
ProviderTypeOIDC ProviderType = "oidc"
ProviderTypeSAML ProviderType = "saml"
)
// IdentityProviderConfig stores the configuration for an external Identity Provider.
type IdentityProviderConfig struct {
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
ClientID string `gorm:"type:uuid;not null;index" json:"client_id"` // Replaces TenantID
ProviderType ProviderType `gorm:"type:varchar(10);not null" json:"provider_type"`
DisplayName string `gorm:"not null" json:"display_name"`
Status string `gorm:"default:'active'" json:"status"`
// OIDC Specific Fields
IssuerURL *string `gorm:"null" json:"issuer_url,omitempty"`
OIDCClientID *string `gorm:"null" json:"oidc_client_id,omitempty"` // Renamed from ClientID
OIDCClientSecret *string `gorm:"null" json:"oidc_client_secret,omitempty"` // Renamed from ClientSecret
// Scopes are space-separated
Scopes *string `gorm:"null" json:"scopes,omitempty"`
// SAML Specific Fields
MetadataURL *string `gorm:"null" json:"metadata_url,omitempty"`
MetadataXML *string `gorm:"type:text;null" json:"metadata_xml,omitempty"`
EntityID *string `gorm:"null" json:"entity_id,omitempty"`
AcsURL *string `gorm:"null" json:"acs_url,omitempty"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
}
// BeforeCreate hook to generate UUID if not present.
func (idc *IdentityProviderConfig) BeforeCreate(tx *gorm.DB) (err error) {
if idc.ID == "" {
idc.ID = uuid.NewString()
}
return
}

View File

@@ -0,0 +1,196 @@
package domain
import (
"errors"
"strings"
"unicode"
)
var hanmacSurnameRomanization = map[rune]string{
'한': "han",
'김': "kim",
'이': "lee",
'박': "park",
'최': "choi",
'정': "jung",
'조': "cho",
'강': "kang",
'윤': "yoon",
'장': "jang",
'임': "lim",
'림': "lim",
'신': "shin",
'오': "oh",
'서': "seo",
'권': "kwon",
'황': "hwang",
'안': "ahn",
'송': "song",
'전': "jeon",
'홍': "hong",
'유': "yoo",
'고': "ko",
'문': "moon",
'양': "yang",
'손': "son",
'배': "bae",
'백': "baek",
'허': "heo",
'남': "nam",
'심': "sim",
'노': "noh",
'하': "ha",
'곽': "kwak",
'성': "sung",
'차': "cha",
'주': "joo",
'우': "woo",
'구': "koo",
'민': "min",
'류': "ryu",
'나': "na",
'진': "jin",
'지': "ji",
'엄': "um",
'채': "chae",
'원': "won",
'천': "cheon",
'방': "bang",
'공': "gong",
'현': "hyun",
'함': "ham",
'여': "yeo",
'추': "choo",
'도': "do",
'소': "so",
'석': "seok",
'선': "sun",
'설': "seol",
'마': "ma",
'길': "gil",
'연': "yeon",
'위': "wi",
'표': "pyo",
'명': "myung",
'기': "ki",
'반': "ban",
'라': "ra",
'왕': "wang",
'금': "geum",
'옥': "ok",
'육': "yook",
'인': "in",
'맹': "maeng",
'제': "je",
'모': "mo",
'탁': "tak",
'국': "guk",
'어': "eo",
'은': "eun",
'편': "pyeon",
'용': "yong",
}
var hanmacInitialRomanization = []string{
"g", "g", "n", "d", "d", "r", "m", "b", "b", "s",
"s", "y", "j", "j", "c", "k", "t", "p", "h",
}
func SplitEmailDomain(email string) (string, string, error) {
normalized := strings.ToLower(strings.TrimSpace(email))
before, after, ok := strings.Cut(normalized, "@")
if !ok {
return "", "", errors.New("email must contain @")
}
if strings.Count(normalized, "@") != 1 {
return "", "", errors.New("email must contain one @")
}
localPart := strings.TrimSpace(before)
domainPart := strings.TrimSpace(after)
if domainPart == "" || !strings.Contains(domainPart, ".") {
return "", "", errors.New("email domain is invalid")
}
return localPart, domainPart, nil
}
func ExtractNormalizedEmailLocalPart(email string) (string, error) {
localPart, _, err := SplitEmailDomain(email)
if err != nil {
return "", err
}
if localPart == "" {
return "", errors.New("email local-part is empty")
}
return localPart, nil
}
func BuildKoreanNameEmailBase(name string) (string, bool, error) {
runes := compactNameRunes(name)
if len(runes) < 2 {
return "", true, nil
}
surname, ok := hanmacSurnameRomanization[runes[0]]
if !ok {
return "", true, nil
}
var builder strings.Builder
for _, r := range runes[1:] {
initial, ok := romanizedHangulInitial(r)
if !ok {
return "", true, nil
}
builder.WriteString(initial)
}
builder.WriteString(surname)
return builder.String(), false, nil
}
func MatchesSuggestedNameRule(localPart string, base string) bool {
localPart = strings.ToLower(strings.TrimSpace(localPart))
base = strings.ToLower(strings.TrimSpace(base))
if localPart == "" || base == "" {
return false
}
if localPart == base {
return true
}
if !strings.HasPrefix(localPart, base) {
return false
}
suffix := localPart[len(base):]
if suffix == "" {
return false
}
for _, r := range suffix {
if r < '0' || r > '9' {
return false
}
}
return true
}
func compactNameRunes(name string) []rune {
var runes []rune
for _, r := range strings.TrimSpace(name) {
if unicode.IsSpace(r) {
continue
}
runes = append(runes, r)
}
return runes
}
func romanizedHangulInitial(r rune) (string, bool) {
const hangulBase = 0xAC00
const hangulEnd = 0xD7A3
if r < hangulBase || r > hangulEnd {
return "", false
}
index := int(r-hangulBase) / 588
if index < 0 || index >= len(hanmacInitialRomanization) {
return "", false
}
return hanmacInitialRomanization[index], true
}

View File

@@ -0,0 +1,76 @@
package domain
import "testing"
func TestSplitEmailDomainAllowsDomainOnlyImportInput(t *testing.T) {
tests := []struct {
name string
email string
wantLocal string
wantDomain string
}{
{name: "full address", email: " Han@SamanEng.com ", wantLocal: "han", wantDomain: "samaneng.com"},
{name: "domain only", email: "@hanmaceng.co.kr", wantLocal: "", wantDomain: "hanmaceng.co.kr"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
local, domain, err := SplitEmailDomain(tt.email)
if err != nil {
t.Fatalf("SplitEmailDomain() error = %v", err)
}
if local != tt.wantLocal || domain != tt.wantDomain {
t.Fatalf("SplitEmailDomain() = (%q, %q), want (%q, %q)", local, domain, tt.wantLocal, tt.wantDomain)
}
})
}
}
func TestBuildKoreanNameEmailBase(t *testing.T) {
base, needsReview, err := BuildKoreanNameEmailBase("한치영")
if err != nil {
t.Fatalf("BuildKoreanNameEmailBase() error = %v", err)
}
if needsReview {
t.Fatalf("BuildKoreanNameEmailBase() needsReview = true")
}
if base != "cyhan" {
t.Fatalf("BuildKoreanNameEmailBase() = %q, want %q", base, "cyhan")
}
}
func TestBuildKoreanNameEmailBaseNeedsReviewForUnknownName(t *testing.T) {
base, needsReview, err := BuildKoreanNameEmailBase("A치영")
if err != nil {
t.Fatalf("BuildKoreanNameEmailBase() error = %v", err)
}
if base != "" {
t.Fatalf("BuildKoreanNameEmailBase() base = %q, want empty", base)
}
if !needsReview {
t.Fatalf("BuildKoreanNameEmailBase() needsReview = false")
}
}
func TestMatchesSuggestedNameRule(t *testing.T) {
tests := []struct {
localPart string
base string
want bool
}{
{localPart: "cyhan", base: "cyhan", want: true},
{localPart: "cyhan1", base: "cyhan", want: true},
{localPart: "cyhan20", base: "cyhan", want: true},
{localPart: "hcy", base: "cyhan", want: false},
{localPart: "han.cy", base: "cyhan", want: false},
{localPart: "cyhan-a", base: "cyhan", want: false},
}
for _, tt := range tests {
t.Run(tt.localPart, func(t *testing.T) {
if got := MatchesSuggestedNameRule(tt.localPart, tt.base); got != tt.want {
t.Fatalf("MatchesSuggestedNameRule(%q, %q) = %v, want %v", tt.localPart, tt.base, got, tt.want)
}
})
}
}

View File

@@ -0,0 +1,30 @@
package domain
import "time"
type HeadlessJWKSParsedKey struct {
Kid string `json:"kid,omitempty"`
Kty string `json:"kty,omitempty"`
Use string `json:"use,omitempty"`
Alg string `json:"alg,omitempty"`
N string `json:"n,omitempty"`
}
// HeadlessJWKSCacheState는 headless login용 JWKS 캐시 상태와 최근 동기화 결과를 나타냅니다.
type HeadlessJWKSCacheState struct {
ClientID string `json:"clientId"`
JWKSURI string `json:"jwksUri"`
CachedAt *time.Time `json:"cachedAt,omitempty"`
ExpiresAt *time.Time `json:"expiresAt,omitempty"`
LastCheckedAt *time.Time `json:"lastCheckedAt,omitempty"`
NextRetryAt *time.Time `json:"nextRetryAt,omitempty"`
LastSuccessfulVerificationAt *time.Time `json:"lastSuccessfulVerificationAt,omitempty"`
LastRefreshStatus string `json:"lastRefreshStatus,omitempty"`
LastError string `json:"lastError,omitempty"`
ConsecutiveFailures int `json:"consecutiveFailures,omitempty"`
CachedKids []string `json:"cachedKids,omitempty"`
ParsedKeys []HeadlessJWKSParsedKey `json:"parsedKeys,omitempty"`
ETag string `json:"etag,omitempty"`
LastModified string `json:"lastModified,omitempty"`
RawJWKS string `json:"-"`
}

View File

@@ -0,0 +1,145 @@
package domain
import (
"strings"
"time"
)
const (
MetadataHeadlessLoginEnabled = "headless_login_enabled"
MetadataHeadlessTokenEndpointAuthMethod = "headless_token_endpoint_auth_method"
MetadataHeadlessJWKSURI = "headless_jwks_uri"
MetadataHeadlessJWKS = "headless_jwks"
MetadataRequestObjectSigningAlg = "request_object_signing_alg"
MetadataIDTokenClaims = "id_token_claims"
MetadataBackChannelLogoutURI = "backchannel_logout_uri"
MetadataBackChannelLogoutSessionRequired = "backchannel_logout_session_required"
MetadataAutoLoginSupported = "auto_login_supported"
MetadataAutoLoginURL = "auto_login_url"
)
type HydraClient struct {
ClientID string `json:"client_id"`
ClientName string `json:"client_name,omitempty"`
ClientSecret string `json:"client_secret,omitempty"` // Added
ClientURI string `json:"client_uri,omitempty"`
RedirectURIs []string `json:"redirect_uris,omitempty"`
GrantTypes []string `json:"grant_types,omitempty"`
ResponseTypes []string `json:"response_types,omitempty"`
Scope string `json:"scope,omitempty"`
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
SkipConsent *bool `json:"skip_consent,omitempty"`
JWKSUri string `json:"jwks_uri,omitempty"`
JWKS any `json:"jwks,omitempty"`
BackChannelLogoutURI string `json:"backchannel_logout_uri,omitempty"`
BackChannelLogoutSessionRequired *bool `json:"backchannel_logout_session_required,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
}
func (c *HydraClient) SupportsHeadlessLogin() bool {
// Headless login now supports jwksUri only.
hasPublicKey := c.HeadlessJWKSURI() != ""
isPrivateKeyJwt := c.HeadlessTokenEndpointAuthMethod() == "private_key_jwt"
return hasPublicKey && isPrivateKeyJwt
}
func (c *HydraClient) HeadlessTokenEndpointAuthMethod() string {
if c.Metadata != nil {
if raw, ok := c.Metadata[MetadataHeadlessTokenEndpointAuthMethod].(string); ok {
if value := strings.TrimSpace(raw); value != "" {
return value
}
}
}
return strings.TrimSpace(c.TokenEndpointAuthMethod)
}
func (c *HydraClient) HeadlessJWKSURI() string {
if c.Metadata != nil {
if raw, ok := c.Metadata[MetadataHeadlessJWKSURI].(string); ok {
if value := strings.TrimSpace(raw); value != "" {
return value
}
}
}
return strings.TrimSpace(c.JWKSUri)
}
func (c *HydraClient) HeadlessJWKS() any {
if c.Metadata != nil {
if value, ok := c.Metadata[MetadataHeadlessJWKS]; ok && value != nil {
return value
}
}
return c.JWKS
}
func (c *HydraClient) IsHeadlessLoginEnabled() bool {
if !c.SupportsHeadlessLogin() {
return false
}
if c.Metadata == nil {
return false
}
val, ok := c.Metadata[MetadataHeadlessLoginEnabled]
if !ok {
return false
}
if b, ok := val.(bool); ok {
return b
}
return false
}
func (c *HydraClient) BackchannelLogoutURI() string {
if c.Metadata != nil {
if raw, ok := c.Metadata[MetadataBackChannelLogoutURI].(string); ok {
if value := strings.TrimSpace(raw); value != "" {
return value
}
}
}
return strings.TrimSpace(c.BackChannelLogoutURI)
}
func (c *HydraClient) BackchannelLogoutSessionRequiredValue() bool {
if c.Metadata != nil {
if raw, ok := c.Metadata[MetadataBackChannelLogoutSessionRequired].(bool); ok {
return raw
}
}
if c.BackChannelLogoutSessionRequired != nil {
return *c.BackChannelLogoutSessionRequired
}
return false
}
type HydraConsentRequest struct {
Challenge string `json:"challenge"`
RequestedScope []string `json:"requested_scope"`
RequestedAudience []string `json:"requested_access_token_audience"`
Skip bool `json:"skip"`
Subject string `json:"subject"`
Client HydraClient `json:"client"`
}
type HydraLoginRequest struct {
Challenge string `json:"challenge"`
Subject string `json:"subject"`
Skip bool `json:"skip"`
Client HydraClient `json:"client"`
}
type HydraConsentSession struct {
ConsentRequestID string `json:"consent_request_id,omitempty"`
Subject string `json:"subject,omitempty"`
GrantedScope []string `json:"grant_scope,omitempty"`
GrantedAudience []string `json:"grant_access_token_audience,omitempty"`
Remember bool `json:"remember"`
RememberFor int `json:"remember_for,omitempty"`
AuthenticatedAt *time.Time `json:"authenticated_at,omitempty"`
RequestedAt *time.Time `json:"requested_at,omitempty"`
HandledAt *time.Time `json:"handled_at,omitempty"`
Client HydraClient `json:"client"`
ConsentRequest *HydraConsentRequest `json:"consent_request,omitempty"`
}

View File

@@ -0,0 +1,182 @@
package domain
import (
"reflect"
"testing"
)
func TestHydraClient_HeadlessLoginFlags(t *testing.T) {
t.Run("metadata-backed headless login client is supported", func(t *testing.T) {
client := HydraClient{
TokenEndpointAuthMethod: "none",
Metadata: map[string]any{
"headless_login_enabled": true,
"headless_token_endpoint_auth_method": "private_key_jwt",
"headless_jwks_uri": "https://rp.example.com/.well-known/jwks.json",
},
}
if !client.SupportsHeadlessLogin() {
t.Fatalf("expected metadata-backed headless login client")
}
if !client.IsHeadlessLoginEnabled() {
t.Fatalf("expected metadata-backed headless login enabled")
}
})
t.Run("inline jwks without jwks uri does not support headless login", func(t *testing.T) {
client := HydraClient{
TokenEndpointAuthMethod: "private_key_jwt",
JWKS: map[string]any{
"keys": []map[string]any{{
"kty": "RSA",
}},
},
Metadata: map[string]any{
"headless_login_enabled": true,
},
}
if client.SupportsHeadlessLogin() {
t.Fatalf("expected headless login prerequisites to be missing")
}
if client.IsHeadlessLoginEnabled() {
t.Fatalf("expected headless login disabled without jwks uri")
}
})
t.Run("jwks uri without private_key_jwt does not support headless login", func(t *testing.T) {
client := HydraClient{
TokenEndpointAuthMethod: "none",
JWKSUri: "https://rp.example.com/.well-known/jwks.json",
Metadata: map[string]any{
"headless_login_enabled": true,
},
}
if client.SupportsHeadlessLogin() {
t.Fatalf("expected headless login prerequisites to be missing")
}
if client.IsHeadlessLoginEnabled() {
t.Fatalf("expected headless login disabled when prerequisites are missing")
}
})
t.Run("headless login client without boolean metadata flag is not enabled", func(t *testing.T) {
client := HydraClient{
TokenEndpointAuthMethod: "private_key_jwt",
JWKSUri: "https://rp.example.com/.well-known/jwks.json",
Metadata: map[string]any{
"headless_login_enabled": "true",
},
}
if !client.SupportsHeadlessLogin() {
t.Fatalf("expected headless login client")
}
if client.IsHeadlessLoginEnabled() {
t.Fatalf("expected headless login disabled for non-bool metadata")
}
})
}
func TestHydraClientHeadlessMetadataAccessors(t *testing.T) {
t.Run("metadata values override inline values", func(t *testing.T) {
metadataJWKS := map[string]any{"keys": []any{"metadata-key"}}
client := HydraClient{
TokenEndpointAuthMethod: "client_secret_post",
JWKSUri: "https://inline.example.com/jwks.json",
JWKS: map[string]any{"keys": []any{"inline-key"}},
Metadata: map[string]any{
MetadataHeadlessTokenEndpointAuthMethod: " private_key_jwt ",
MetadataHeadlessJWKSURI: " https://metadata.example.com/jwks.json ",
MetadataHeadlessJWKS: metadataJWKS,
},
}
if got := client.HeadlessTokenEndpointAuthMethod(); got != "private_key_jwt" {
t.Fatalf("unexpected auth method: %q", got)
}
if got := client.HeadlessJWKSURI(); got != "https://metadata.example.com/jwks.json" {
t.Fatalf("unexpected jwks uri: %q", got)
}
if got := client.HeadlessJWKS(); !reflect.DeepEqual(got, metadataJWKS) {
t.Fatalf("unexpected jwks value: %#v", got)
}
})
t.Run("blank or missing metadata values fall back to inline values", func(t *testing.T) {
inlineJWKS := map[string]any{"keys": []any{"inline-key"}}
client := HydraClient{
TokenEndpointAuthMethod: " private_key_jwt ",
JWKSUri: " https://inline.example.com/jwks.json ",
JWKS: inlineJWKS,
Metadata: map[string]any{
MetadataHeadlessTokenEndpointAuthMethod: " ",
MetadataHeadlessJWKSURI: " ",
MetadataHeadlessJWKS: nil,
},
}
if got := client.HeadlessTokenEndpointAuthMethod(); got != "private_key_jwt" {
t.Fatalf("unexpected auth method: %q", got)
}
if got := client.HeadlessJWKSURI(); got != "https://inline.example.com/jwks.json" {
t.Fatalf("unexpected jwks uri: %q", got)
}
if got := client.HeadlessJWKS(); !reflect.DeepEqual(got, inlineJWKS) {
t.Fatalf("unexpected jwks value: %#v", got)
}
})
}
func TestHydraClientBackchannelLogoutAccessors(t *testing.T) {
t.Run("metadata values override inline values", func(t *testing.T) {
inlineRequired := false
client := HydraClient{
BackChannelLogoutURI: "https://inline.example.com/logout",
BackChannelLogoutSessionRequired: &inlineRequired,
Metadata: map[string]any{
MetadataBackChannelLogoutURI: " https://metadata.example.com/logout ",
MetadataBackChannelLogoutSessionRequired: true,
},
}
if got := client.BackchannelLogoutURI(); got != "https://metadata.example.com/logout" {
t.Fatalf("unexpected logout uri: %q", got)
}
if !client.BackchannelLogoutSessionRequiredValue() {
t.Fatalf("expected metadata session_required value")
}
})
t.Run("blank or missing metadata values fall back to inline values", func(t *testing.T) {
inlineRequired := true
client := HydraClient{
BackChannelLogoutURI: " https://inline.example.com/logout ",
BackChannelLogoutSessionRequired: &inlineRequired,
Metadata: map[string]any{
MetadataBackChannelLogoutURI: " ",
MetadataBackChannelLogoutSessionRequired: "true",
},
}
if got := client.BackchannelLogoutURI(); got != "https://inline.example.com/logout" {
t.Fatalf("unexpected logout uri: %q", got)
}
if !client.BackchannelLogoutSessionRequiredValue() {
t.Fatalf("expected inline session_required value")
}
})
t.Run("missing session required defaults to false", func(t *testing.T) {
client := HydraClient{}
if got := client.BackchannelLogoutURI(); got != "" {
t.Fatalf("unexpected logout uri: %q", got)
}
if client.BackchannelLogoutSessionRequiredValue() {
t.Fatalf("expected default session_required false")
}
})
}

View File

@@ -0,0 +1,19 @@
package domain
import "time"
type IdentityCacheStatus struct {
Status string `json:"status"`
RedisReady bool `json:"redisReady"`
ObservedCount int64 `json:"observedCount"`
KeyCount int64 `json:"keyCount"`
LastRefreshedAt *time.Time `json:"lastRefreshedAt,omitempty"`
LastError string `json:"lastError,omitempty"`
UpdatedAt *time.Time `json:"updatedAt,omitempty"`
}
type IdentityCacheFlushResult struct {
Status string `json:"status"`
FlushedKeys int64 `json:"flushedKeys"`
UpdatedAt time.Time `json:"updatedAt"`
}

View File

@@ -0,0 +1,92 @@
package domain
import (
"errors"
"net/http"
"time"
)
// ErrNotSupported는 IDP가 특정 인증 흐름을 지원하지 않을 때 반환합니다.
var ErrNotSupported = errors.New("idp: not supported")
// BrokerUser is the standard user model used within Baron SSO business logic.
// It defines the canonical set of fields that must be supported by any underlying IDP.
type BrokerUser struct {
ID string `json:"id" required:"true"`
Email string `json:"email" required:"true"`
LoginID string `json:"login_id"`
CustomLoginIDs []string `json:"custom_login_ids"` // [New] 다중 로그인 ID
Name string `json:"name"`
PhoneNumber string `json:"phone_number"`
// Attributes stores custom user attributes.
// The "required_keys" tag specifies which keys MUST be present in the IDP's schema support.
Attributes map[string]any `json:"attributes" required_keys:"grade,department"`
}
// IDPMetadata represents the schema capabilities of an Identity Provider.
type IDPMetadata struct {
// SupportedFields lists the BrokerUser fields (json tag names) that the IDP supports.
// For custom attributes, use the key name directly (e.g., "grade").
SupportedFields []string
}
// PasswordPolicy는 비밀번호 정책 정보를 표현합니다.
type PasswordPolicy struct {
MinLength int
Lowercase bool
Uppercase bool
Number bool
NonAlphanumeric bool
MinCharacterTypes int
}
// Token represents a session or refresh token.
type Token struct {
JWT string
Expiration time.Time
SessionID string
}
// AuthInfo contains authentication information after a successful login.
type AuthInfo struct {
SessionToken *Token
RefreshToken *Token
// Subject는 IDP 세션이 대표하는 주체(예: Kratos identity.id)를 나타냅니다.
Subject string
SetCookies []*http.Cookie
}
// LinkLoginInit는 링크 로그인 초기화 결과입니다.
type LinkLoginInit struct {
FlowID string
ExpiresAt time.Time
// Mode는 링크 로그인 완료 후 세션 처리 방식입니다. (예: "cookie")
Mode string
// LoginID는 IDP에 실제 전달된 식별자입니다.
LoginID string
}
// IdentityProvider is the interface that all IDP adapters must implement.
type IdentityProvider interface {
Name() string
// GetMetadata returns the schema support information for this IDP.
// This is used for startup-time validation.
GetMetadata() (*IDPMetadata, error)
// CreateUser는 BrokerUser 스키마를 기반으로 신규 사용자를 생성하고 주체 ID(예: identity.id)를 반환합니다.
CreateUser(user *BrokerUser, password string) (string, error)
// SignIn은 로그인 ID/비밀번호로 인증해 세션 정보를 반환합니다.
SignIn(loginID, password string) (*AuthInfo, error)
// UserExists는 loginID 기준으로 사용자 존재 여부를 확인합니다.
UserExists(loginID string) (bool, error)
// IssueSession은 비밀번호 없이 세션을 발급해야 하는 흐름에서 사용합니다.
IssueSession(loginID string) (*AuthInfo, error)
// InitiateLinkLogin은 링크 기반 로그인 요청을 IDP에 전달합니다.
InitiateLinkLogin(loginID, returnTo string) (*LinkLoginInit, error)
// VerifyLoginCode는 링크/코드 기반 로그인에서 코드를 제출해 세션을 발급합니다.
VerifyLoginCode(loginID, flowID, code string) (*AuthInfo, error)
// GetPasswordPolicy는 IDP가 제공하는 비밀번호 정책을 반환합니다.
GetPasswordPolicy() (*PasswordPolicy, error)
InitiatePasswordReset(loginID, redirectUrl string) error
VerifyPasswordResetToken(token string) (*AuthInfo, error)
UpdateUserPassword(loginID, newPassword string, r *http.Request) error
}

View File

@@ -0,0 +1,42 @@
package domain
import (
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
)
// JSONMap is a custom type for handling map[string]any with PostgreSQL JSONB
type JSONMap map[string]any
// Value implements the driver.Valuer interface
func (m JSONMap) Value() (driver.Value, error) {
if m == nil {
return nil, nil
}
ba, err := json.Marshal(m)
return string(ba), err
}
// Scan implements the sql.Scanner interface
func (m *JSONMap) Scan(value any) error {
if value == nil {
*m = make(JSONMap)
return nil
}
var bytes []byte
switch v := value.(type) {
case []byte:
bytes = v
case string:
bytes = []byte(v)
default:
return errors.New(fmt.Sprintf("failed to scan JSONMap: %v", value))
}
result := make(JSONMap)
err := json.Unmarshal(bytes, &result)
*m = result
return err
}

View File

@@ -0,0 +1,93 @@
package domain
import (
"encoding/json"
"testing"
)
func TestJSONMapValue(t *testing.T) {
t.Run("nil map returns nil database value", func(t *testing.T) {
var payload JSONMap
value, err := payload.Value()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if value != nil {
t.Fatalf("expected nil value, got %v", value)
}
})
t.Run("map marshals to JSON string", func(t *testing.T) {
payload := JSONMap{"enabled": true, "name": "baron"}
value, err := payload.Value()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
raw, ok := value.(string)
if !ok {
t.Fatalf("expected string value, got %T", value)
}
var decoded map[string]any
if err := json.Unmarshal([]byte(raw), &decoded); err != nil {
t.Fatalf("value should be valid json: %v", err)
}
if decoded["enabled"] != true || decoded["name"] != "baron" {
t.Fatalf("unexpected decoded value: %#v", decoded)
}
})
}
func TestJSONMapScan(t *testing.T) {
t.Run("nil value becomes empty map", func(t *testing.T) {
var payload JSONMap
if err := payload.Scan(nil); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if payload == nil || len(payload) != 0 {
t.Fatalf("expected empty map, got %#v", payload)
}
})
t.Run("byte slice value decodes JSON", func(t *testing.T) {
var payload JSONMap
if err := payload.Scan([]byte(`{"count":2,"name":"baron"}`)); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if payload["count"] != float64(2) || payload["name"] != "baron" {
t.Fatalf("unexpected payload: %#v", payload)
}
})
t.Run("string value decodes JSON", func(t *testing.T) {
var payload JSONMap
if err := payload.Scan(`{"active":true}`); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if payload["active"] != true {
t.Fatalf("unexpected payload: %#v", payload)
}
})
t.Run("unsupported value type returns error", func(t *testing.T) {
var payload JSONMap
if err := payload.Scan(42); err == nil {
t.Fatalf("expected unsupported type error")
}
})
t.Run("invalid JSON returns error", func(t *testing.T) {
var payload JSONMap
if err := payload.Scan(`{invalid`); err == nil {
t.Fatalf("expected invalid JSON error")
}
})
}

View File

@@ -0,0 +1,48 @@
package domain
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// KetoOutbox status
const (
KetoOutboxStatusPending = "pending"
KetoOutboxStatusProcessed = "processed"
KetoOutboxStatusFailed = "failed"
)
// KetoOutbox action
const (
KetoOutboxActionCreate = "CREATE"
KetoOutboxActionDelete = "DELETE"
)
// KetoOutbox represents a Keto relationship tuple update event.
type KetoOutbox struct {
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
Namespace string `gorm:"not null" json:"namespace"`
Object string `gorm:"not null" json:"object"`
Relation string `gorm:"not null" json:"relation"`
Subject string `gorm:"not null" json:"subject"` // format: "User:ID" or "Tenant:ID#members"
Action string `gorm:"not null" json:"action"` // CREATE, DELETE
Status string `gorm:"default:'pending';index" json:"status"`
RetryCount int `gorm:"default:0" json:"retryCount"`
LastError string `json:"lastError,omitempty"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
ProcessedAt *time.Time `json:"processedAt,omitempty"`
}
func (ko *KetoOutbox) TableName() string {
return "keto_outbox"
}
func (ko *KetoOutbox) BeforeCreate(tx *gorm.DB) (err error) {
if ko.ID == "" {
ko.ID = uuid.NewString()
}
return
}

View File

@@ -0,0 +1,357 @@
package domain
import (
"testing"
"time"
"github.com/google/uuid"
)
func requireGeneratedUUID(t *testing.T, value string) {
t.Helper()
if value == "" {
t.Fatalf("expected generated uuid")
}
if _, err := uuid.Parse(value); err != nil {
t.Fatalf("expected valid uuid, got %q: %v", value, err)
}
}
func TestBeforeCreateGeneratesMissingIDs(t *testing.T) {
tests := []struct {
name string
run func(t *testing.T)
}{
{
name: "api key",
run: func(t *testing.T) {
model := ApiKey{}
if err := model.BeforeCreate(nil); err != nil {
t.Fatalf("unexpected error: %v", err)
}
requireGeneratedUUID(t, model.ID)
},
},
{
name: "client consent",
run: func(t *testing.T) {
model := ClientConsent{}
if err := model.BeforeCreate(nil); err != nil {
t.Fatalf("unexpected error: %v", err)
}
requireGeneratedUUID(t, model.ID)
},
},
{
name: "identity provider config",
run: func(t *testing.T) {
model := IdentityProviderConfig{}
if err := model.BeforeCreate(nil); err != nil {
t.Fatalf("unexpected error: %v", err)
}
requireGeneratedUUID(t, model.ID)
},
},
{
name: "keto outbox",
run: func(t *testing.T) {
model := KetoOutbox{}
if err := model.BeforeCreate(nil); err != nil {
t.Fatalf("unexpected error: %v", err)
}
requireGeneratedUUID(t, model.ID)
},
},
{
name: "tenant",
run: func(t *testing.T) {
model := Tenant{}
if err := model.BeforeCreate(nil); err != nil {
t.Fatalf("unexpected error: %v", err)
}
requireGeneratedUUID(t, model.ID)
},
},
{
name: "tenant domain",
run: func(t *testing.T) {
model := TenantDomain{}
if err := model.BeforeCreate(nil); err != nil {
t.Fatalf("unexpected error: %v", err)
}
requireGeneratedUUID(t, model.ID)
},
},
{
name: "user",
run: func(t *testing.T) {
model := User{}
if err := model.BeforeCreate(nil); err != nil {
t.Fatalf("unexpected error: %v", err)
}
requireGeneratedUUID(t, model.ID)
},
},
{
name: "user group",
run: func(t *testing.T) {
model := UserGroup{}
if err := model.BeforeCreate(nil); err != nil {
t.Fatalf("unexpected error: %v", err)
}
requireGeneratedUUID(t, model.ID)
},
},
{
name: "worksmobile resource mapping",
run: func(t *testing.T) {
model := WorksmobileResourceMapping{}
if err := model.BeforeCreate(nil); err != nil {
t.Fatalf("unexpected error: %v", err)
}
requireGeneratedUUID(t, model.ID)
},
},
}
for _, tc := range tests {
t.Run(tc.name, tc.run)
}
}
func TestBeforeCreatePreservesExistingIDs(t *testing.T) {
tests := []struct {
name string
run func(t *testing.T)
}{
{
name: "api key",
run: func(t *testing.T) {
model := ApiKey{ID: "existing-id"}
if err := model.BeforeCreate(nil); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if model.ID != "existing-id" {
t.Fatalf("expected existing id to be preserved")
}
},
},
{
name: "client consent",
run: func(t *testing.T) {
model := ClientConsent{ID: "existing-id"}
if err := model.BeforeCreate(nil); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if model.ID != "existing-id" {
t.Fatalf("expected existing id to be preserved")
}
},
},
{
name: "identity provider config",
run: func(t *testing.T) {
model := IdentityProviderConfig{ID: "existing-id"}
if err := model.BeforeCreate(nil); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if model.ID != "existing-id" {
t.Fatalf("expected existing id to be preserved")
}
},
},
{
name: "keto outbox",
run: func(t *testing.T) {
model := KetoOutbox{ID: "existing-id"}
if err := model.BeforeCreate(nil); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if model.ID != "existing-id" {
t.Fatalf("expected existing id to be preserved")
}
},
},
{
name: "tenant",
run: func(t *testing.T) {
model := Tenant{ID: "existing-id"}
if err := model.BeforeCreate(nil); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if model.ID != "existing-id" {
t.Fatalf("expected existing id to be preserved")
}
},
},
{
name: "tenant domain",
run: func(t *testing.T) {
model := TenantDomain{ID: "existing-id"}
if err := model.BeforeCreate(nil); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if model.ID != "existing-id" {
t.Fatalf("expected existing id to be preserved")
}
},
},
{
name: "user",
run: func(t *testing.T) {
model := User{ID: "existing-id"}
if err := model.BeforeCreate(nil); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if model.ID != "existing-id" {
t.Fatalf("expected existing id to be preserved")
}
},
},
{
name: "user group",
run: func(t *testing.T) {
model := UserGroup{ID: "existing-id"}
if err := model.BeforeCreate(nil); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if model.ID != "existing-id" {
t.Fatalf("expected existing id to be preserved")
}
},
},
{
name: "worksmobile resource mapping",
run: func(t *testing.T) {
model := WorksmobileResourceMapping{ID: "existing-id"}
if err := model.BeforeCreate(nil); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if model.ID != "existing-id" {
t.Fatalf("expected existing id to be preserved")
}
},
},
}
for _, tc := range tests {
t.Run(tc.name, tc.run)
}
}
func TestTableNames(t *testing.T) {
tests := []struct {
name string
got string
expected string
}{
{name: "keto outbox", got: (&KetoOutbox{}).TableName(), expected: "keto_outbox"},
{name: "rp usage event", got: (&RPUsageEvent{}).TableName(), expected: "rp_usage_outbox"},
{name: "rp user metadata", got: (RPUserMetadata{}).TableName(), expected: "rp_user_metadata"},
{name: "user group", got: (&UserGroup{}).TableName(), expected: "user_groups"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if tc.got != tc.expected {
t.Fatalf("unexpected table name: got=%s expected=%s", tc.got, tc.expected)
}
})
}
}
func TestTenantIsActive(t *testing.T) {
tests := []struct {
name string
status string
expected bool
}{
{name: "active", status: TenantStatusActive, expected: true},
{name: "pending", status: TenantStatusPending, expected: false},
{name: "suspended", status: TenantStatusSuspended, expected: false},
{name: "deleted", status: TenantStatusDeleted, expected: false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
tenant := Tenant{Status: tc.status}
if got := tenant.IsActive(); got != tc.expected {
t.Fatalf("unexpected active state: got=%v expected=%v", got, tc.expected)
}
})
}
}
func TestRPUsageEventBeforeCreateDefaults(t *testing.T) {
event := RPUsageEvent{}
if err := event.BeforeCreate(nil); err != nil {
t.Fatalf("unexpected error: %v", err)
}
requireGeneratedUUID(t, event.ID)
if event.Status != RPUsageOutboxStatusPending {
t.Fatalf("unexpected status: %s", event.Status)
}
if event.OccurredAt.IsZero() {
t.Fatalf("expected occurred_at default")
}
if event.Payload == nil {
t.Fatalf("expected empty payload default")
}
}
func TestRPUsageEventBeforeCreatePreservesExplicitValues(t *testing.T) {
occurredAt := time.Date(2026, 5, 29, 1, 2, 3, 0, time.UTC)
event := RPUsageEvent{
ID: "existing-id",
Status: RPUsageOutboxStatusProcessing,
OccurredAt: occurredAt,
Payload: JSONMap{"source": "test"},
}
if err := event.BeforeCreate(nil); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if event.ID != "existing-id" {
t.Fatalf("expected existing id to be preserved")
}
if event.Status != RPUsageOutboxStatusProcessing {
t.Fatalf("expected status to be preserved")
}
if !event.OccurredAt.Equal(occurredAt) {
t.Fatalf("expected occurred_at to be preserved")
}
if event.Payload["source"] != "test" {
t.Fatalf("expected payload to be preserved")
}
}
func TestWorksmobileOutboxBeforeCreateDefaults(t *testing.T) {
outbox := WorksmobileOutbox{}
if err := outbox.BeforeCreate(nil); err != nil {
t.Fatalf("unexpected error: %v", err)
}
requireGeneratedUUID(t, outbox.ID)
if outbox.Status != WorksmobileOutboxStatusPending {
t.Fatalf("unexpected status: %s", outbox.Status)
}
}
func TestWorksmobileOutboxBeforeCreatePreservesExplicitValues(t *testing.T) {
outbox := WorksmobileOutbox{
ID: "existing-id",
Status: WorksmobileOutboxStatusProcessing,
}
if err := outbox.BeforeCreate(nil); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if outbox.ID != "existing-id" {
t.Fatalf("expected existing id to be preserved")
}
if outbox.Status != WorksmobileOutboxStatusProcessing {
t.Fatalf("expected status to be preserved")
}
}

View File

@@ -0,0 +1,48 @@
package domain
import (
"context"
"time"
)
// AuditLog represents a single audit event
type AuditLog struct {
EventID string `json:"event_id"`
Timestamp time.Time `json:"timestamp"`
UserID string `json:"user_id"`
TenantID string `json:"tenant_id,omitempty"`
SessionID string `json:"session_id,omitempty"`
EventType string `json:"event_type"` // e.g., "login_success", "login_failed", "otp_sent"
Status string `json:"status"` // e.g., "success", "failure"
AuthMethod string `json:"auth_method,omitempty"`
IPAddress string `json:"ip_address"`
UserAgent string `json:"user_agent"`
DeviceID string `json:"device_id,omitempty"`
Details string `json:"details,omitempty"` // JSON string or simple text
}
// AuditRepository defines interface for storing logs
type AuditRepository interface {
Create(log *AuditLog) error
FindPage(ctx context.Context, limit int, cursor *AuditCursor, tenantID string) ([]AuditLog, error)
FindByUserAndEvents(ctx context.Context, userID string, eventTypes []string, limit int) ([]AuditLog, error)
CountEventsSince(ctx context.Context, since time.Time) (int64, error)
CountFailuresSince(ctx context.Context, since time.Time, tenantID string) (int64, error)
CountActiveSessionsSince(ctx context.Context, since time.Time, tenantID string) (int64, error)
Ping(ctx context.Context) error
}
type AuditCursor struct {
Timestamp time.Time
EventID string
}
// RedisRepository defines interface for KV storage (Redis)
type RedisRepository interface {
Set(key string, value string, expiration time.Duration) error
Get(key string) (string, error)
Delete(key string) error
StoreVerificationCode(phone, code string) error
GetVerificationCode(phone string) (string, error)
DeleteVerificationCode(phone string) error
}

View File

@@ -0,0 +1,31 @@
package domain
import (
"context"
"time"
)
type OathkeeperAccessLog struct {
Timestamp time.Time
RequestID string
Method string
Path string
Status int
LatencyMs int
ClientID string
RP string
Action string
Target string
Subject string
ClientIP string
UserAgent string
Decision string
TraceID string
SpanID string
Raw string
}
type OathkeeperLogRepository interface {
FindPageBySubject(ctx context.Context, subject string, limit int, cursor *AuditCursor) ([]OathkeeperAccessLog, error)
Ping(ctx context.Context) error
}

View File

@@ -0,0 +1,19 @@
package domain
import (
"time"
)
// RelyingParty represents an OAuth2 Client owner by a Tenant.
// It maps 1:1 to a Hydra Client.
type RelyingParty struct {
ClientID string `json:"clientId"` // Maps to Hydra Client ID
TenantID string `json:"tenantId"`
Name string `json:"name"` // Display name (can be same as Hydra Client Name)
Description string `json:"description"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
// DeletedAt removed as it's not a DB model anymore
}
// TableName removed

View File

@@ -0,0 +1,101 @@
package domain
import (
"context"
"time"
"github.com/google/uuid"
"github.com/lib/pq"
"gorm.io/gorm"
)
const (
RPUsageOutboxStatusPending = "pending"
RPUsageOutboxStatusProcessing = "processing"
RPUsageOutboxStatusProcessed = "processed"
RPUsageOutboxStatusFailed = "failed"
)
const (
RPUsageEventTypeAuthorizationGranted = "rp_usage.authorization_granted"
RPUsageEventTypeAuthorizationRevoked = "rp_usage.authorization_revoked"
)
const (
RPUsageTenantTypeCompany = TenantTypeCompany
RPUsageTenantTypeOrganization = TenantTypeOrganization
)
type RPUsageEvent struct {
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
EventType string `gorm:"not null;index:idx_rp_usage_outbox_event" json:"eventType"`
Subject string `gorm:"not null;index:idx_rp_usage_outbox_subject" json:"subject"`
TenantID string `gorm:"index:idx_rp_usage_outbox_tenant" json:"tenantId,omitempty"`
TenantType string `gorm:"index:idx_rp_usage_outbox_tenant" json:"tenantType,omitempty"`
ClientID string `gorm:"not null;index:idx_rp_usage_outbox_client" json:"clientId"`
ClientName string `json:"clientName,omitempty"`
SessionID string `gorm:"index" json:"sessionId,omitempty"`
Scopes pq.StringArray `gorm:"type:text[]" json:"scopes,omitempty"`
Source string `gorm:"not null;index" json:"source"`
CorrelationID string `gorm:"index" json:"correlationId,omitempty"`
Payload JSONMap `gorm:"type:jsonb" json:"payload,omitempty"`
DedupeKey string `gorm:"uniqueIndex" json:"dedupeKey"`
Status string `gorm:"default:'pending';index" json:"status"`
RetryCount int `gorm:"default:0" json:"retryCount"`
LastError string `json:"lastError,omitempty"`
NextAttemptAt *time.Time `json:"nextAttemptAt,omitempty"`
OccurredAt time.Time `gorm:"not null;index" json:"occurredAt"`
ProcessedAt *time.Time `json:"processedAt,omitempty"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
func (e *RPUsageEvent) TableName() string {
return "rp_usage_outbox"
}
func (e *RPUsageEvent) BeforeCreate(tx *gorm.DB) error {
if e.ID == "" {
e.ID = uuid.NewString()
}
if e.Status == "" {
e.Status = RPUsageOutboxStatusPending
}
if e.OccurredAt.IsZero() {
e.OccurredAt = time.Now()
}
if e.Payload == nil {
e.Payload = JSONMap{}
}
return nil
}
type RPUsageEventSink interface {
EmitRPUsageEvent(ctx context.Context, event RPUsageEvent) error
}
type RPUsageProjectionRepository interface {
CreateRPUsageEvent(ctx context.Context, event RPUsageEvent) error
}
type RPUsageDailyMetric struct {
Date string `json:"date"`
TenantID string `json:"tenantId"`
TenantType string `json:"tenantType"`
TenantName string `json:"tenantName,omitempty"`
ClientID string `json:"clientId"`
ClientName string `json:"clientName"`
LoginRequests uint64 `json:"loginRequests"`
OtherRequests uint64 `json:"otherRequests"`
UniqueSubjects uint64 `json:"uniqueSubjects"`
}
type RPUsageQuery struct {
Days int
Period string
TenantID string
}
type RPUsageQueryRepository interface {
FindRPUsage(ctx context.Context, query RPUsageQuery) ([]RPUsageDailyMetric, error)
}

View File

@@ -0,0 +1,16 @@
package domain
import "time"
type RPUserMetadata struct {
ClientID string `gorm:"column:client_id;primaryKey" json:"clientId"`
UserID string `gorm:"column:user_id;type:uuid;primaryKey" json:"userId"`
User *User `gorm:"foreignKey:UserID" json:"-"`
Metadata JSONMap `gorm:"column:metadata;type:jsonb" json:"metadata"`
CreatedAt time.Time `gorm:"column:created_at" json:"createdAt"`
UpdatedAt time.Time `gorm:"column:updated_at" json:"updatedAt"`
}
func (RPUserMetadata) TableName() string {
return "rp_user_metadata"
}

View File

@@ -0,0 +1,53 @@
package domain
import (
"crypto/rand"
"encoding/hex"
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
type SharedLink struct {
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
TenantID string `gorm:"type:uuid;not null;index" json:"tenantId"`
Token string `gorm:"uniqueIndex;not null" json:"token"`
Name string `gorm:"not null" json:"name"` // 링크 식별을 위한 이름 (예: "24년 상반기 채용공고용")
Description string `json:"description"`
AccessLevel string `gorm:"default:'READ_ONLY'" json:"accessLevel"`
IsActive bool `gorm:"default:true" json:"isActive"`
ExpiresAt *time.Time `json:"expiresAt"`
Password string `json:"-"` // 필요 시 비밀번호 (선택 사항)
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// Relation
Tenant Tenant `gorm:"foreignKey:TenantID" json:"-"`
}
func (s *SharedLink) BeforeCreate(tx *gorm.DB) (err error) {
if s.ID == "" {
s.ID = uuid.NewString()
}
if s.Token == "" {
// 32바이트(64자)의 강력한 난수 토큰 생성
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return err
}
s.Token = hex.EncodeToString(b)
}
return
}
func (s *SharedLink) IsValid() bool {
if !s.IsActive {
return false
}
if s.ExpiresAt != nil && s.ExpiresAt.Before(time.Now()) {
return false
}
return true
}

View File

@@ -0,0 +1,80 @@
package domain
import (
"encoding/hex"
"testing"
"time"
)
func TestSharedLinkBeforeCreate(t *testing.T) {
t.Run("generates id and token when missing", func(t *testing.T) {
link := SharedLink{}
if err := link.BeforeCreate(nil); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if link.ID == "" {
t.Fatalf("expected generated id")
}
if len(link.Token) != 64 {
t.Fatalf("expected 64-character token, got %q", link.Token)
}
if _, err := hex.DecodeString(link.Token); err != nil {
t.Fatalf("expected hex token: %v", err)
}
})
t.Run("preserves existing id and token", func(t *testing.T) {
link := SharedLink{
ID: "existing-id",
Token: "existing-token",
}
if err := link.BeforeCreate(nil); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if link.ID != "existing-id" || link.Token != "existing-token" {
t.Fatalf("expected existing fields to be preserved: %#v", link)
}
})
}
func TestSharedLinkIsValid(t *testing.T) {
future := time.Now().Add(time.Hour)
past := time.Now().Add(-time.Hour)
tests := []struct {
name string
link SharedLink
expected bool
}{
{
name: "active link without expiration is valid",
link: SharedLink{IsActive: true},
expected: true,
},
{
name: "active link with future expiration is valid",
link: SharedLink{IsActive: true, ExpiresAt: &future},
expected: true,
},
{
name: "inactive link is invalid",
link: SharedLink{IsActive: false, ExpiresAt: &future},
expected: false,
},
{
name: "expired link is invalid",
link: SharedLink{IsActive: true, ExpiresAt: &past},
expected: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := tc.link.IsValid(); got != tc.expected {
t.Fatalf("unexpected validity: got=%v expected=%v", got, tc.expected)
}
})
}
}

View File

@@ -0,0 +1,42 @@
package domain
// SmsService defines the interface for sending SMS messages.
type SmsService interface {
SendSms(to, content string) error
}
// NaverSmsRequest represents the request body for the Naver Cloud SMS API.
type NaverSmsRequest struct {
Type string `json:"type"`
ContentType string `json:"contentType"`
CountryCode string `json:"countryCode"`
From string `json:"from"`
Subject string `json:"subject,omitempty"`
Content string `json:"content"`
Messages []SmsMessage `json:"messages"`
}
// SmsMessage represents a single message to be sent.
type SmsMessage struct {
To string `json:"to"`
Content string `json:"content,omitempty"`
}
// NaverSmsResponse represents the response from the Naver Cloud SMS API.
type NaverSmsResponse struct {
RequestID string `json:"requestId"`
RequestTime string `json:"requestTime"`
StatusCode string `json:"statusCode"`
StatusName string `json:"statusName"`
}
// SmsRequest represents the request body for sending an SMS.
type SmsRequest struct {
PhoneNumber string `json:"phoneNumber"`
}
// SmsVerifyRequest represents the request body for verifying an SMS code.
type SmsVerifyRequest struct {
PhoneNumber string `json:"phoneNumber"`
Code string `json:"code"`
}

View File

@@ -0,0 +1,11 @@
package domain
import "time"
// SystemSetting stores small global configuration documents.
type SystemSetting struct {
Key string `gorm:"primaryKey;size:128" json:"key"`
Value JSONMap `gorm:"type:jsonb" json:"value"`
CreatedAt time.Time
UpdatedAt time.Time
}

View File

@@ -0,0 +1,53 @@
package domain
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// Tenant statuses
const (
TenantStatusPending = "pending"
TenantStatusActive = "active"
TenantStatusSuspended = "suspended"
TenantStatusDeleted = "deleted"
)
// Tenant types
const (
TenantTypePersonal = "PERSONAL"
TenantTypeCompany = "COMPANY"
TenantTypeCompanyGroup = "COMPANY_GROUP"
TenantTypeOrganization = "ORGANIZATION"
TenantTypeUserGroup = "USER_GROUP"
)
// Tenant represents a tenant model stored in PostgreSQL.
type Tenant struct {
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
Type string `gorm:"not null;default:'PERSONAL'" json:"type"` // PERSONAL, COMPANY, COMPANY_GROUP, ORGANIZATION, USER_GROUP
ParentID *string `gorm:"type:uuid;index" json:"parentId,omitempty"` // 부모 테넌트 ID
Name string `gorm:"not null" json:"name"`
Slug string `gorm:"uniqueIndex;not null" json:"slug"`
Description string `json:"description"`
Status string `gorm:"default:'pending'" json:"status"`
Domains []TenantDomain `gorm:"foreignKey:TenantID" json:"domains,omitempty"`
Config JSONMap `gorm:"type:jsonb" json:"config,omitempty"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
}
func (t *Tenant) IsActive() bool {
return t.Status == TenantStatusActive
}
// BeforeCreate hook to generate UUID if not present.
func (t *Tenant) BeforeCreate(tx *gorm.DB) (err error) {
if t.ID == "" {
t.ID = uuid.NewString()
}
return
}

View File

@@ -0,0 +1,27 @@
package domain
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// TenantDomain represents a domain associated with a tenant for auto-assignment.
type TenantDomain struct {
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
TenantID string `gorm:"type:uuid;not null;uniqueIndex:idx_tenant_domains_tenant_domain" json:"tenantId"`
Domain string `gorm:"not null;uniqueIndex:idx_tenant_domains_tenant_domain" json:"domain"` // e.g. "example.com"
Verified bool `gorm:"default:false" json:"verified"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
}
// BeforeCreate hook to generate UUID if not present.
func (td *TenantDomain) BeforeCreate(tx *gorm.DB) (err error) {
if td.ID == "" {
td.ID = uuid.NewString()
}
return
}

View File

@@ -0,0 +1,247 @@
package domain
import (
"fmt"
"slices"
"strings"
"time"
"github.com/google/uuid"
"github.com/lib/pq"
"gorm.io/gorm"
)
// User roles
const (
RoleSuperAdmin = "super_admin" // 시스템 전역 관리자
RoleUser = "user" // 일반 사용자
)
// User statuses
const (
UserStatusActive = "active"
UserStatusSuspended = "suspended"
UserStatusTemporaryLeave = "temporary_leave"
UserStatusPreboarding = "preboarding"
UserStatusBaronGuest = "baron_guest"
UserStatusExtendedLeave = "extended_leave"
UserStatusArchived = "archived"
)
func NormalizeUserStatus(status string) string {
switch strings.ToLower(strings.TrimSpace(status)) {
case "", UserStatusActive:
return UserStatusActive
case "blocked", UserStatusSuspended:
return UserStatusSuspended
case "inactive", UserStatusPreboarding:
return UserStatusPreboarding
case "leave_of_absence", UserStatusTemporaryLeave:
return UserStatusTemporaryLeave
case "baron_only", UserStatusBaronGuest:
return UserStatusBaronGuest
case UserStatusExtendedLeave:
return UserStatusExtendedLeave
case UserStatusArchived:
return UserStatusArchived
default:
return strings.ToLower(strings.TrimSpace(status))
}
}
func IsBaronActivityAllowedStatus(status string) bool {
switch NormalizeUserStatus(status) {
case UserStatusActive, UserStatusTemporaryLeave, UserStatusBaronGuest:
return true
default:
return false
}
}
func IsOrgVisibleUserStatus(status string) bool {
switch NormalizeUserStatus(status) {
case UserStatusActive, UserStatusTemporaryLeave, UserStatusSuspended:
return true
default:
return false
}
}
func IsWorksProvisionedUserStatus(status string) bool {
switch NormalizeUserStatus(status) {
case UserStatusActive, UserStatusTemporaryLeave, UserStatusSuspended:
return true
default:
return false
}
}
func IsWorksDeprovisionUserStatus(status string) bool {
switch NormalizeUserStatus(status) {
case UserStatusBaronGuest, UserStatusExtendedLeave, UserStatusArchived:
return true
default:
return false
}
}
// NormalizeRole maps legacy/synonym role values to canonical role keys.
func NormalizeRole(role string) string {
if normalized, ok := NormalizeRoleAlias(role); ok {
return normalized
}
return RoleUser
}
func NormalizeRoleAlias(role string) (string, bool) {
normalized := strings.ToLower(strings.TrimSpace(role))
switch normalized {
case RoleSuperAdmin, RoleUser:
return normalized, true
case "tenant_admin", "rp_admin", "tenant_member", "member", "admin", "tenantadmin", "tenant-admin":
return RoleUser, true
case "superadmin", "super-admin":
return RoleSuperAdmin, true
default:
return "", false
}
}
// User represents the user model stored in PostgreSQL
type User struct {
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
Email string `gorm:"uniqueIndex;not null" json:"email"`
PasswordHash *string `gorm:"column:password_hash" json:"-"`
Name string `gorm:"column:name;not null" json:"name"`
Phone string `gorm:"column:phone" json:"phone"`
Role string `gorm:"column:role;default:'user';not null" json:"role"` // super_admin, user
AffiliationType string `gorm:"column:affiliation_type" json:"affiliationType"`
CompanyCode string `gorm:"-" json:"companyCode,omitempty"`
CompanyCodes pq.StringArray `gorm:"-" json:"companyCodes,omitempty"`
TenantID *string `gorm:"column:tenant_id;type:uuid;index" json:"tenantId,omitempty"`
Tenant *Tenant `gorm:"foreignKey:TenantID" json:"tenant,omitempty"`
RelyingPartyID *string `gorm:"column:relying_party_id;type:uuid;index" json:"relyingPartyId,omitempty"` // RP Admin용
Department string `gorm:"column:department" json:"department"`
Grade string `gorm:"column:grade" json:"grade"` // 직급 (예: 수석, 책임, 선임)
Position string `gorm:"column:position" json:"position"` // 직책 (예: 팀장, 센터장)
JobTitle string `gorm:"column:job_title" json:"jobTitle"` // 직무 (예: 프론트엔드 개발, 기획)
Metadata JSONMap `gorm:"column:metadata;type:jsonb" json:"metadata,omitempty"`
Status string `gorm:"column:status;default:'active'" json:"status"`
CreatedAt time.Time `gorm:"column:created_at" json:"createdAt"`
UpdatedAt time.Time `gorm:"column:updated_at" json:"updatedAt"`
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;index" json:"-"`
// Multiple identifiers support
UserLoginIDs []UserLoginID `gorm:"foreignKey:UserID" json:"userLoginIds,omitempty"`
}
// UserLoginID represents multiple custom identifiers for a user
type UserLoginID struct {
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
UserID string `gorm:"type:uuid;not null;index" json:"userId"`
TenantID string `gorm:"type:uuid;not null;index" json:"tenantId"` // 발급 테넌트
FieldKey string `gorm:"not null" json:"fieldKey"` // 스키마 필드 키 (예: emp_id)
LoginID string `gorm:"uniqueIndex;not null" json:"loginId"` // 실제 값 (예: EMP001)
}
// BeforeCreate hook to generate UUID if not present
func (u *User) BeforeCreate(tx *gorm.DB) (err error) {
if u.ID == "" {
u.ID = uuid.New().String()
}
return
}
// ValidateLoginID checks if the loginID violates any collision, length, or security rules.
func ValidateLoginID(loginID string, emails []string, phone string) error {
loginID = strings.TrimSpace(loginID)
if loginID == "" {
return nil
}
if len(loginID) < 4 || len(loginID) > 30 {
return fmt.Errorf("ID must be between 4 and 30 characters")
}
if strings.Contains(loginID, "@") {
return fmt.Errorf("ID cannot be an email format")
}
for _, email := range emails {
if email != "" && strings.EqualFold(loginID, email) {
return fmt.Errorf("ID cannot be the same as the email address")
}
}
if phone != "" {
normalizedPhone := NormalizePhoneNumber(phone)
if loginID == phone || loginID == normalizedPhone {
return fmt.Errorf("ID cannot be the same as the phone number")
}
}
isPureNumber := true
loginIDDigits := strings.ReplaceAll(loginID, "-", "")
loginIDDigits = strings.ReplaceAll(loginIDDigits, " ", "")
for _, c := range loginIDDigits {
if (c < '0' || c > '9') && c != '+' {
isPureNumber = false
break
}
}
if isPureNumber && len(loginIDDigits) >= 10 && len(loginIDDigits) <= 12 {
if strings.HasPrefix(loginIDDigits, "010") || strings.HasPrefix(loginIDDigits, "82") || strings.HasPrefix(loginIDDigits, "+82") {
return fmt.Errorf("ID cannot be a phone number format")
}
}
reserved := []string{"admin", "system", "root", "master", "superuser", "guest", "operator"}
lowerID := strings.ToLower(loginID)
if slices.Contains(reserved, lowerID) {
return fmt.Errorf("reserved ID cannot be used")
}
return nil
}
func NormalizePhoneNumber(phone string) string {
trimmed := strings.TrimSpace(phone)
if trimmed == "" {
return ""
}
hasLeadingPlus := false
digits := strings.Builder{}
for _, r := range trimmed {
switch {
case r >= '0' && r <= '9':
digits.WriteRune(r)
case r == '+' && digits.Len() == 0 && !hasLeadingPlus:
hasLeadingPlus = true
}
}
number := digits.String()
if number == "" {
return ""
}
if strings.HasPrefix(number, "010") {
return "+82" + number[1:]
}
if strings.HasPrefix(number, "82") {
rest := number[2:]
for strings.HasPrefix(rest, "82") {
rest = rest[2:]
}
if strings.HasPrefix(rest, "0") {
rest = rest[1:]
}
return "+82" + rest
}
if hasLeadingPlus {
return "+" + number
}
return number
}

View File

@@ -0,0 +1,50 @@
package domain
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// UserGroup represents a collection of users within a tenant.
type UserGroup struct {
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
TenantID string `gorm:"type:uuid;index;not null" json:"tenantId"`
ParentID *string `gorm:"type:uuid;index" json:"parentId,omitempty"` // 상위 조직 ID
Name string `gorm:"not null" json:"name"`
Slug string `gorm:"index" json:"slug"` // 추가
Description string `json:"description"`
UnitType string `json:"unitType"` // 부, 국, 팀, 셀 등
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// Relationships
Parent *UserGroup `gorm:"foreignKey:ParentID" json:"parent,omitempty"`
Members []User `gorm:"-" json:"members,omitempty"`
}
type GroupCreateRequest struct {
Name string `json:"name"`
ParentID *string `json:"parentId"`
Description string `json:"description"`
UnitType string `json:"unitType"`
}
type GroupRole struct {
TenantID string `json:"tenantId"`
TenantName string `json:"tenantName"`
Relation string `json:"relation"`
}
func (ug *UserGroup) TableName() string {
return "user_groups"
}
func (ug *UserGroup) BeforeCreate(tx *gorm.DB) (err error) {
if ug.ID == "" {
ug.ID = uuid.NewString()
}
return
}

View File

@@ -0,0 +1,29 @@
package domain
import "time"
const (
UserProjectionNameKratos = "kratos_users"
UserProjectionStatusSyncing = "syncing"
UserProjectionStatusReady = "ready"
UserProjectionStatusFailed = "failed"
)
type UserProjectionState struct {
Name string `gorm:"primaryKey;column:name" json:"name"`
Status string `gorm:"column:status;not null" json:"status"`
LastSyncedAt *time.Time `gorm:"column:last_synced_at" json:"lastSyncedAt,omitempty"`
LastError string `gorm:"column:last_error;type:text" json:"lastError,omitempty"`
UpdatedAt time.Time `gorm:"column:updated_at" json:"updatedAt"`
}
type UserProjectionStatus struct {
Name string `json:"name"`
Status string `json:"status"`
Ready bool `json:"ready"`
LastSyncedAt *time.Time `json:"lastSyncedAt,omitempty"`
LastError string `json:"lastError,omitempty"`
UpdatedAt *time.Time `json:"updatedAt,omitempty"`
ProjectedUsers int64 `json:"projectedUsers"`
}

View File

@@ -0,0 +1,73 @@
package domain
import "testing"
func TestNormalizeRole(t *testing.T) {
tests := []struct {
name string
in string
want string
}{
{name: "super admin unchanged", in: "super_admin", want: RoleSuperAdmin},
{name: "tenant admin mapped to user", in: "tenant_admin", want: RoleUser},
{name: "rp admin mapped to user", in: "rp_admin", want: RoleUser},
{name: "user unchanged", in: "user", want: RoleUser},
{name: "super admin hyphen alias", in: "super-admin", want: RoleSuperAdmin},
{name: "super admin compact alias", in: "superadmin", want: RoleSuperAdmin},
{name: "legacy admin mapped to user", in: "admin", want: RoleUser},
{name: "legacy tenant member", in: "tenant_member", want: RoleUser},
{name: "trim and lower", in: " ADMIN ", want: RoleUser},
{name: "unknown role mapped to user", in: "custom_role", want: RoleUser},
{name: "empty string mapped to user", in: " ", want: RoleUser},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := NormalizeRole(tc.in); got != tc.want {
t.Fatalf("NormalizeRole(%q)=%q, want %q", tc.in, got, tc.want)
}
})
}
}
func TestUserStatusPolicy(t *testing.T) {
tests := []struct {
status string
normalized string
baronAllowed bool
orgVisible bool
worksProvisioned bool
worksDeprovisioned bool
}{
{status: UserStatusActive, normalized: UserStatusActive, baronAllowed: true, orgVisible: true, worksProvisioned: true},
{status: UserStatusTemporaryLeave, normalized: UserStatusTemporaryLeave, baronAllowed: true, orgVisible: true, worksProvisioned: true},
{status: UserStatusSuspended, normalized: UserStatusSuspended, orgVisible: true, worksProvisioned: true},
{status: UserStatusPreboarding, normalized: UserStatusPreboarding},
{status: UserStatusBaronGuest, normalized: UserStatusBaronGuest, baronAllowed: true, worksDeprovisioned: true},
{status: UserStatusExtendedLeave, normalized: UserStatusExtendedLeave, worksDeprovisioned: true},
{status: UserStatusArchived, normalized: UserStatusArchived, worksDeprovisioned: true},
{status: "inactive", normalized: UserStatusPreboarding},
{status: "leave_of_absence", normalized: UserStatusTemporaryLeave, baronAllowed: true, orgVisible: true, worksProvisioned: true},
{status: "BARON_ONLY", normalized: UserStatusBaronGuest, baronAllowed: true, worksDeprovisioned: true},
}
for _, tc := range tests {
t.Run(tc.status, func(t *testing.T) {
if got := NormalizeUserStatus(tc.status); got != tc.normalized {
t.Fatalf("NormalizeUserStatus(%q)=%q, want %q", tc.status, got, tc.normalized)
}
if got := IsBaronActivityAllowedStatus(tc.status); got != tc.baronAllowed {
t.Fatalf("IsBaronActivityAllowedStatus(%q)=%v, want %v", tc.status, got, tc.baronAllowed)
}
if got := IsOrgVisibleUserStatus(tc.status); got != tc.orgVisible {
t.Fatalf("IsOrgVisibleUserStatus(%q)=%v, want %v", tc.status, got, tc.orgVisible)
}
if got := IsWorksProvisionedUserStatus(tc.status); got != tc.worksProvisioned {
t.Fatalf("IsWorksProvisionedUserStatus(%q)=%v, want %v", tc.status, got, tc.worksProvisioned)
}
if got := IsWorksDeprovisionUserStatus(tc.status); got != tc.worksDeprovisioned {
t.Fatalf("IsWorksDeprovisionUserStatus(%q)=%v, want %v", tc.status, got, tc.worksDeprovisioned)
}
})
}
}

View File

@@ -0,0 +1,64 @@
package domain
import (
"testing"
)
func TestValidateLoginID(t *testing.T) {
tests := []struct {
name string
loginID string
emails []string
phone string
wantErr bool
}{
{"Empty", "", []string{"test@email.com"}, "01012345678", false},
{"Valid alphanumeric", "user123", []string{"test@email.com"}, "01012345678", false},
{"Too short", "us", []string{"test@email.com"}, "01012345678", true},
{"Too long", "thisisaverylongloginidthatiswayoverthirtycharacters", []string{"test@email.com"}, "01012345678", true},
{"Email format", "user@domain.com", []string{"test@email.com"}, "01012345678", true},
{"Exact email match", "Test@Email.Com", []string{"test@email.com"}, "01012345678", true},
{"Secondary email match", "sub@test.com", []string{"test@email.com", "sub@test.com"}, "01012345678", true},
{"Phone number match", "010-1234-5678", []string{"test@email.com"}, "01012345678", true},
{"Phone number match +82", "+821012345678", []string{"test@email.com"}, "01012345678", true},
{"Phone number match digits", "01012345678", []string{"test@email.com"}, "01012345678", true},
{"Phone format (11 digits)", "01098765432", []string{"test@email.com"}, "01012345678", true},
{"Valid pure digits (employee ID)", "20230001", []string{"test@email.com"}, "01012345678", false},
{"Valid pure digits long", "123456789", []string{"test@email.com"}, "01012345678", false},
{"Valid pure digits 10 chars", "1234567890", []string{"test@email.com"}, "01012345678", false},
{"Reserved word admin", "ADMIN", []string{"test@email.com"}, "01012345678", true},
{"Reserved word root", "root", []string{"test@email.com"}, "01012345678", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateLoginID(tt.loginID, tt.emails, tt.phone)
if (err != nil) != tt.wantErr {
t.Errorf("ValidateLoginID() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestNormalizePhoneNumberDeduplicatesKoreanCountryCode(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{"Local mobile", "010-9191-7771", "+821091917771"},
{"Korean country code", "+82 10-9191-7771", "+821091917771"},
{"Duplicate plus Korean country code", "+82 +821091917771", "+821091917771"},
{"Duplicate compact Korean country code", "+82821091917771", "+821091917771"},
{"Duplicate spaced Korean country code", "+82 8210 9191 7771", "+821091917771"},
{"Non Korean international phone preserved", "+1 914 481 2222", "+19144812222"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := NormalizePhoneNumber(tt.input); got != tt.want {
t.Fatalf("NormalizePhoneNumber(%q)=%q, want %q", tt.input, got, tt.want)
}
})
}
}

View File

@@ -0,0 +1,73 @@
package domain
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
const (
WorksmobileOutboxStatusPending = "pending"
WorksmobileOutboxStatusProcessing = "processing"
WorksmobileOutboxStatusProcessed = "processed"
WorksmobileOutboxStatusFailed = "failed"
)
const (
WorksmobileResourceOrgUnit = "ORGUNIT"
WorksmobileResourceUser = "USER"
)
const (
WorksmobileActionUpsert = "UPSERT"
WorksmobileActionDelete = "DELETE"
WorksmobileActionDryRun = "DRY_RUN"
WorksmobileActionSuspend = "SUSPEND"
WorksmobileActionPasswordReset = "PASSWORD_RESET"
)
type WorksmobileOutbox struct {
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
ResourceType string `gorm:"not null;index:idx_worksmobile_outbox_resource" json:"resourceType"`
ResourceID string `gorm:"not null;index:idx_worksmobile_outbox_resource" json:"resourceId"`
Action string `gorm:"not null" json:"action"`
Payload JSONMap `gorm:"type:jsonb" json:"payload,omitempty"`
DedupeKey string `gorm:"uniqueIndex" json:"dedupeKey"`
Status string `gorm:"default:'pending';index" json:"status"`
RetryCount int `gorm:"default:0" json:"retryCount"`
LastError string `json:"lastError,omitempty"`
NextAttemptAt *time.Time `json:"nextAttemptAt,omitempty"`
ProcessedAt *time.Time `json:"processedAt,omitempty"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
func (w *WorksmobileOutbox) BeforeCreate(tx *gorm.DB) error {
if w.ID == "" {
w.ID = uuid.NewString()
}
if w.Status == "" {
w.Status = WorksmobileOutboxStatusPending
}
return nil
}
type WorksmobileResourceMapping struct {
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
BaronResourceType string `gorm:"not null;uniqueIndex:idx_worksmobile_mapping_baron" json:"baronResourceType"`
BaronResourceID string `gorm:"not null;uniqueIndex:idx_worksmobile_mapping_baron" json:"baronResourceId"`
ExternalKey string `gorm:"not null;uniqueIndex" json:"externalKey"`
WorksmobileResourceID string `json:"worksmobileResourceId,omitempty"`
DomainID int64 `json:"domainId"`
LastSyncedAt *time.Time `json:"lastSyncedAt,omitempty"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
func (w *WorksmobileResourceMapping) BeforeCreate(tx *gorm.DB) error {
if w.ID == "" {
w.ID = uuid.NewString()
}
return nil
}

View File

@@ -0,0 +1,491 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/repository"
"baron-sso-backend/internal/service"
"context"
"runtime"
"strconv"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"gorm.io/gorm"
)
type adminHydraClientLister interface {
ListClients(ctx context.Context, limit, offset int) ([]domain.HydraClient, error)
}
type identityCacheAdmin interface {
GetIdentityCacheStatus(ctx context.Context) (domain.IdentityCacheStatus, error)
FlushIdentityCache(ctx context.Context) (domain.IdentityCacheFlushResult, error)
}
type AdminHandler struct {
DB *gorm.DB
Keto service.KetoService
KetoOutbox repository.KetoOutboxRepository
RPUsageQueries domain.RPUsageQueryRepository
TenantRepo repository.TenantRepository
Hydra adminHydraClientLister
AuditRepo domain.AuditRepository
UserProjectionRepo repository.UserProjectionRepository
IdentityCache identityCacheAdmin
IntegrityChecker repository.DataIntegrityChecker
}
const globalCustomClaimsSettingKey = "global_custom_claim_definitions"
type globalCustomClaimDefinition struct {
Key string `json:"key"`
Label string `json:"label"`
ValueType string `json:"valueType"`
ReadPermission string `json:"readPermission"`
WritePermission string `json:"writePermission"`
Description string `json:"description,omitempty"`
}
type globalCustomClaimDefinitionsResponse struct {
Items []globalCustomClaimDefinition `json:"items"`
}
func NewAdminHandler(keto service.KetoService, ketoOutbox repository.KetoOutboxRepository) *AdminHandler {
return &AdminHandler{
Keto: keto,
KetoOutbox: ketoOutbox,
}
}
func (h *AdminHandler) GetRPUsageDaily(c *fiber.Ctx) error {
if h == nil || h.RPUsageQueries == nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{
"error": "rp usage query service unavailable",
})
}
days := 14
if raw := c.Query("days"); raw != "" {
if parsed, err := strconv.Atoi(raw); err == nil {
days = parsed
}
}
period := normalizeRPUsagePeriod(c.Query("period"))
tenantID, allowed := h.authorizedRPUsageTenantID(c, strings.TrimSpace(c.Query("tenantId")))
if !allowed {
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
"error": "forbidden: tenant rp usage stats permission denied",
})
}
items, err := h.RPUsageQueries.FindRPUsage(c.Context(), domain.RPUsageQuery{
Days: days,
Period: period,
TenantID: tenantID,
})
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"error": err.Error(),
})
}
return c.JSON(fiber.Map{
"items": items,
"days": days,
"period": period,
"tenantId": tenantID,
})
}
func normalizeRPUsagePeriod(period string) string {
switch strings.ToLower(strings.TrimSpace(period)) {
case "week":
return "week"
case "month":
return "month"
default:
return "day"
}
}
func (h *AdminHandler) authorizedRPUsageTenantID(c *fiber.Ctx, requestedTenantID string) (string, bool) {
profile, _ := c.Locals("user_profile").(*domain.UserProfileResponse)
if profile != nil && domain.NormalizeRole(profile.Role) == domain.RoleSuperAdmin {
return requestedTenantID, true
}
tenantID := requestedTenantID
if tenantID == "" && profile != nil && profile.TenantID != nil {
tenantID = strings.TrimSpace(*profile.TenantID)
}
if tenantID == "" {
return "", false
}
if h == nil || h.Keto == nil || profile == nil || strings.TrimSpace(profile.ID) == "" {
return "", false
}
allowed, err := h.Keto.CheckPermission(c.Context(), "User:"+profile.ID, "Tenant", tenantID, "view_rp_usage_stats")
if err != nil || !allowed {
return "", false
}
return tenantID, true
}
func (h *AdminHandler) CheckAuth(c *fiber.Ctx) error {
return c.Status(fiber.StatusOK).JSON(fiber.Map{"status": "ok"})
}
func (h *AdminHandler) GetGlobalCustomClaimDefinitions(c *fiber.Ctx) error {
if h == nil || h.DB == nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{
"error": "settings store unavailable",
})
}
var setting domain.SystemSetting
if err := h.DB.WithContext(c.Context()).First(&setting, "key = ?", globalCustomClaimsSettingKey).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return c.JSON(globalCustomClaimDefinitionsResponse{Items: []globalCustomClaimDefinition{}})
}
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(globalCustomClaimDefinitionsResponse{
Items: normalizeGlobalCustomClaimDefinitions(setting.Value["items"]),
})
}
func (h *AdminHandler) UpdateGlobalCustomClaimDefinitions(c *fiber.Ctx) error {
if h == nil || h.DB == nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{
"error": "settings store unavailable",
})
}
var req globalCustomClaimDefinitionsResponse
if err := c.BodyParser(&req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid request body"})
}
items, err := validateGlobalCustomClaimDefinitions(req.Items)
if err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
}
setting := domain.SystemSetting{
Key: globalCustomClaimsSettingKey,
Value: domain.JSONMap{"items": globalCustomClaimDefinitionsToJSON(items)},
}
if err := h.DB.WithContext(c.Context()).Save(&setting).Error; err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(globalCustomClaimDefinitionsResponse{Items: items})
}
func normalizeGlobalCustomClaimDefinitions(value any) []globalCustomClaimDefinition {
rawItems, ok := value.([]any)
if !ok {
return []globalCustomClaimDefinition{}
}
items := make([]globalCustomClaimDefinition, 0, len(rawItems))
for _, item := range rawItems {
raw, ok := item.(map[string]any)
if !ok {
continue
}
def := globalCustomClaimDefinition{
Key: strings.TrimSpace(stringValue(raw["key"])),
Label: strings.TrimSpace(stringValue(raw["label"])),
ValueType: normalizeGlobalCustomClaimType(stringValue(raw["valueType"])),
ReadPermission: adminNormalizeCustomClaimPermission(stringValue(raw["readPermission"])),
WritePermission: adminNormalizeCustomClaimPermission(stringValue(raw["writePermission"])),
Description: strings.TrimSpace(stringValue(raw["description"])),
}
if def.Key != "" {
items = append(items, def)
}
}
return items
}
func validateGlobalCustomClaimDefinitions(items []globalCustomClaimDefinition) ([]globalCustomClaimDefinition, error) {
seen := map[string]struct{}{}
normalized := make([]globalCustomClaimDefinition, 0, len(items))
for _, item := range items {
key := strings.TrimSpace(item.Key)
if key == "" {
continue
}
if !isValidCustomClaimKey(key) {
return nil, fiber.NewError(fiber.StatusBadRequest, "claim key must use letters, numbers, underscore, dot, or hyphen")
}
if _, exists := seen[key]; exists {
return nil, fiber.NewError(fiber.StatusBadRequest, "duplicate claim key: "+key)
}
seen[key] = struct{}{}
normalized = append(normalized, globalCustomClaimDefinition{
Key: key,
Label: strings.TrimSpace(item.Label),
ValueType: normalizeGlobalCustomClaimType(item.ValueType),
ReadPermission: adminNormalizeCustomClaimPermission(item.ReadPermission),
WritePermission: adminNormalizeCustomClaimPermission(item.WritePermission),
Description: strings.TrimSpace(item.Description),
})
}
return normalized, nil
}
func globalCustomClaimDefinitionsToJSON(items []globalCustomClaimDefinition) []any {
values := make([]any, 0, len(items))
for _, item := range items {
values = append(values, map[string]any{
"key": item.Key,
"label": item.Label,
"valueType": item.ValueType,
"readPermission": item.ReadPermission,
"writePermission": item.WritePermission,
"description": item.Description,
})
}
return values
}
func normalizeGlobalCustomClaimType(value string) string {
switch strings.ToLower(strings.TrimSpace(value)) {
case "number", "boolean", "array", "object", "date", "datetime":
return strings.ToLower(strings.TrimSpace(value))
default:
return "text"
}
}
func adminNormalizeCustomClaimPermission(value string) string {
if strings.TrimSpace(value) == "user_and_admin" {
return "user_and_admin"
}
return "admin_only"
}
func isValidCustomClaimKey(value string) bool {
for _, r := range value {
if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' || r == '_' || r == '-' || r == '.' {
continue
}
return false
}
return true
}
func stringValue(value any) string {
if text, ok := value.(string); ok {
return text
}
return ""
}
func requireSuperAdminProfile(c *fiber.Ctx) bool {
profile, _ := c.Locals("user_profile").(*domain.UserProfileResponse)
if profile == nil || domain.NormalizeRole(profile.Role) != domain.RoleSuperAdmin {
_ = c.Status(fiber.StatusForbidden).JSON(fiber.Map{"error": "forbidden: super_admin required"})
return false
}
return true
}
func (h *AdminHandler) GetUserProjectionStatus(c *fiber.Ctx) error {
if !requireSuperAdminProfile(c) {
return nil
}
if h == nil || h.UserProjectionRepo == nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "user projection service unavailable"})
}
status, err := h.UserProjectionRepo.GetStatus(c.Context())
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(status)
}
func (h *AdminHandler) GetOrySSOTSystemStatus(c *fiber.Ctx) error {
if !requireSuperAdminProfile(c) {
return nil
}
if h == nil || h.UserProjectionRepo == nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "user projection service unavailable"})
}
projectionStatus, err := h.UserProjectionRepo.GetStatus(c.Context())
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
cacheStatus := domain.IdentityCacheStatus{
Status: "unavailable",
RedisReady: false,
LastError: "identity cache service unavailable",
}
if h.IdentityCache != nil {
cacheStatus, err = h.IdentityCache.GetIdentityCacheStatus(c.Context())
if err != nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": err.Error()})
}
}
return c.JSON(fiber.Map{
"userProjection": projectionStatus,
"identityCache": cacheStatus,
})
}
func (h *AdminHandler) FlushIdentityCache(c *fiber.Ctx) error {
if !requireSuperAdminProfile(c) {
return nil
}
if h == nil || h.IdentityCache == nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "identity cache service unavailable"})
}
result, err := h.IdentityCache.FlushIdentityCache(c.Context())
if err != nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(result)
}
func (h *AdminHandler) GetDataIntegrity(c *fiber.Ctx) error {
if !requireSuperAdminProfile(c) {
return nil
}
if h == nil || h.IntegrityChecker == nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "data integrity checker unavailable"})
}
report, err := h.IntegrityChecker.CheckDataIntegrity(c.Context())
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(report)
}
func (h *AdminHandler) ListOrphanUserLoginIDs(c *fiber.Ctx) error {
if !requireSuperAdminProfile(c) {
return nil
}
if h == nil || h.IntegrityChecker == nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "data integrity checker unavailable"})
}
items, err := h.IntegrityChecker.ListOrphanUserLoginIDs(c.Context())
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(fiber.Map{
"items": items,
"total": len(items),
})
}
func (h *AdminHandler) DeleteOrphanUserLoginIDs(c *fiber.Ctx) error {
if !requireSuperAdminProfile(c) {
return nil
}
if h == nil || h.IntegrityChecker == nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "data integrity checker unavailable"})
}
var req struct {
IDs []string `json:"ids"`
}
if err := c.BodyParser(&req); err != nil {
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
}
result, err := h.IntegrityChecker.DeleteOrphanUserLoginIDs(c.Context(), req.IDs)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(result)
}
// GetSystemStats returns runtime statistics for monitoring
func (h *AdminHandler) GetSystemStats(c *fiber.Ctx) error {
var m runtime.MemStats
runtime.ReadMemStats(&m)
ctx := c.Context()
stats := fiber.Map{
"totalTenants": h.countTenants(ctx),
"totalUsers": h.countUsers(ctx),
"oidcClients": h.countOIDCClients(ctx),
"auditEvents24h": h.countAuditEventsSince(ctx, time.Now().UTC().Add(-24*time.Hour)),
"goroutines": runtime.NumGoroutine(),
"cpus": runtime.NumCPU(),
"memory": fiber.Map{
"alloc": m.Alloc,
"totalAlign": m.TotalAlloc,
"sys": m.Sys,
"numGC": m.NumGC,
},
"timestamp": time.Now(),
}
return c.Status(fiber.StatusOK).JSON(stats)
}
func (h *AdminHandler) countTenants(ctx context.Context) int64 {
if h == nil || h.TenantRepo == nil {
return 0
}
_, total, err := h.TenantRepo.List(ctx, 1, 0, "", "")
if err != nil {
return 0
}
return total
}
func (h *AdminHandler) countUsers(ctx context.Context) int64 {
if h == nil || h.UserProjectionRepo == nil {
return 0
}
status, err := h.UserProjectionRepo.GetStatus(ctx)
if err != nil {
return 0
}
return status.ProjectedUsers
}
func (h *AdminHandler) countOIDCClients(ctx context.Context) int64 {
if h == nil || h.Hydra == nil {
return 0
}
const pageSize = 500
var total int64
for offset := 0; ; offset += pageSize {
clients, err := h.Hydra.ListClients(ctx, pageSize, offset)
if err != nil {
return total
}
for _, client := range clients {
if isHiddenSystemClient(client) {
continue
}
total++
}
if len(clients) < pageSize {
break
}
}
return total
}
func (h *AdminHandler) countAuditEventsSince(ctx context.Context, since time.Time) int64 {
if h == nil || h.AuditRepo == nil {
return 0
}
count, err := h.AuditRepo.CountEventsSince(ctx, since)
if err == nil && count > 0 {
return count
}
logs, pageErr := h.AuditRepo.FindPage(ctx, 10000, nil, "")
if pageErr != nil {
return count
}
var fallbackCount int64
for _, log := range logs {
if !log.Timestamp.Before(since) {
fallbackCount++
}
}
return fallbackCount
}

View File

@@ -0,0 +1,343 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/require"
)
type fakeRPUsageQueryRepo struct {
query domain.RPUsageQuery
items []domain.RPUsageDailyMetric
}
func (f *fakeRPUsageQueryRepo) FindRPUsage(ctx context.Context, query domain.RPUsageQuery) ([]domain.RPUsageDailyMetric, error) {
f.query = query
return f.items, nil
}
type fakeAdminKeto struct {
allowed bool
subject string
object string
relation string
}
func (f *fakeAdminKeto) CheckPermission(ctx context.Context, subject, namespace, object, relation string) (bool, error) {
f.subject = subject
f.object = object
f.relation = relation
return f.allowed, nil
}
func (f *fakeAdminKeto) CreateRelation(ctx context.Context, namespace, object, relation, subject string) error {
return nil
}
func (f *fakeAdminKeto) DeleteRelation(ctx context.Context, namespace, object, relation, subject string) error {
return nil
}
func (f *fakeAdminKeto) ListRelations(ctx context.Context, namespace, object, relation, subject string) ([]service.RelationTuple, error) {
return nil, nil
}
func (f *fakeAdminKeto) ListObjects(ctx context.Context, namespace, relation, subject string) ([]string, error) {
return nil, nil
}
type fakeOverviewAuditRepo struct {
mockAuditRepo
since time.Time
count int64
}
func (f *fakeOverviewAuditRepo) CountEventsSince(ctx context.Context, since time.Time) (int64, error) {
f.since = since
return f.count, nil
}
type fakeAdminUserProjectionRepo struct {
status domain.UserProjectionStatus
}
func (f *fakeAdminUserProjectionRepo) IsReady(ctx context.Context) (bool, error) {
return f.status.Ready, nil
}
func (f *fakeAdminUserProjectionRepo) CountTenantMembers(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
return nil, nil
}
func (f *fakeAdminUserProjectionRepo) CountTenantMembersRecursive(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
return nil, nil
}
func (f *fakeAdminUserProjectionRepo) ReplaceAllFromKratos(ctx context.Context, users []domain.User) error {
return nil
}
func (f *fakeAdminUserProjectionRepo) MarkFailed(ctx context.Context, syncErr error) error {
return nil
}
func (f *fakeAdminUserProjectionRepo) GetStatus(ctx context.Context) (domain.UserProjectionStatus, error) {
return f.status, nil
}
type fakeIdentityCacheAdmin struct {
status domain.IdentityCacheStatus
flush domain.IdentityCacheFlushResult
err error
statusHit int
flushCalls int
}
func (f *fakeIdentityCacheAdmin) GetIdentityCacheStatus(ctx context.Context) (domain.IdentityCacheStatus, error) {
f.statusHit++
return f.status, f.err
}
func (f *fakeIdentityCacheAdmin) FlushIdentityCache(ctx context.Context) (domain.IdentityCacheFlushResult, error) {
f.flushCalls++
return f.flush, f.err
}
func TestAdminHandler_GetRPUsageDaily(t *testing.T) {
repo := &fakeRPUsageQueryRepo{
items: []domain.RPUsageDailyMetric{
{
Date: "2026-05-06",
TenantID: "tenant-1",
TenantType: domain.TenantTypeCompany,
ClientID: "orgfront",
ClientName: "OrgFront",
LoginRequests: 12,
OtherRequests: 4,
UniqueSubjects: 8,
},
},
}
h := &AdminHandler{RPUsageQueries: repo}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Get("/api/v1/admin/rp-usage/daily", h.GetRPUsageDaily)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/rp-usage/daily?days=7&period=week&tenantId=tenant-1", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, 7, repo.query.Days)
require.Equal(t, "week", repo.query.Period)
require.Equal(t, "tenant-1", repo.query.TenantID)
var body struct {
Items []domain.RPUsageDailyMetric `json:"items"`
Days int `json:"days"`
Period string `json:"period"`
TenantID string `json:"tenantId"`
}
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, 7, body.Days)
require.Equal(t, "week", body.Period)
require.Equal(t, "tenant-1", body.TenantID)
require.Len(t, body.Items, 1)
require.Equal(t, "orgfront", body.Items[0].ClientID)
require.Equal(t, uint64(12), body.Items[0].LoginRequests)
}
func TestAdminHandler_UserProjectionStatusRequiresSuperAdmin(t *testing.T) {
h := &AdminHandler{
UserProjectionRepo: &fakeAdminUserProjectionRepo{
status: domain.UserProjectionStatus{Name: domain.UserProjectionNameKratos, Status: domain.UserProjectionStatusReady, Ready: true},
},
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "tenant-admin", Role: "tenant_admin"})
return c.Next()
})
app.Get("/api/v1/admin/projections/users", h.GetUserProjectionStatus)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/projections/users", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusForbidden, resp.StatusCode)
}
func TestAdminHandler_UserProjectionStatusReturnsProjectionStateForSuperAdmin(t *testing.T) {
syncedAt := time.Date(2026, 5, 11, 3, 0, 0, 0, time.UTC)
h := &AdminHandler{
UserProjectionRepo: &fakeAdminUserProjectionRepo{
status: domain.UserProjectionStatus{
Name: domain.UserProjectionNameKratos,
Status: domain.UserProjectionStatusReady,
Ready: true,
LastSyncedAt: &syncedAt,
ProjectedUsers: 152,
},
},
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Get("/api/v1/admin/projections/users", h.GetUserProjectionStatus)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/projections/users", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
var body domain.UserProjectionStatus
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, domain.UserProjectionNameKratos, body.Name)
require.Equal(t, domain.UserProjectionStatusReady, body.Status)
require.True(t, body.Ready)
require.Equal(t, int64(152), body.ProjectedUsers)
}
func TestAdminHandler_GetOrySSOTSystemStatusReturnsProjectionAndIdentityCache(t *testing.T) {
syncedAt := time.Date(2026, 5, 11, 3, 0, 0, 0, time.UTC)
cache := &fakeIdentityCacheAdmin{
status: domain.IdentityCacheStatus{
Status: "ready",
RedisReady: true,
ObservedCount: 151,
KeyCount: 153,
LastRefreshedAt: &syncedAt,
UpdatedAt: &syncedAt,
},
}
h := &AdminHandler{
UserProjectionRepo: &fakeAdminUserProjectionRepo{
status: domain.UserProjectionStatus{
Name: domain.UserProjectionNameKratos,
Status: domain.UserProjectionStatusReady,
Ready: true,
LastSyncedAt: &syncedAt,
ProjectedUsers: 152,
},
},
IdentityCache: cache,
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Get("/api/v1/admin/ory/ssot", h.GetOrySSOTSystemStatus)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/ory/ssot", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
var body struct {
UserProjection domain.UserProjectionStatus `json:"userProjection"`
IdentityCache domain.IdentityCacheStatus `json:"identityCache"`
}
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, int64(152), body.UserProjection.ProjectedUsers)
require.True(t, body.IdentityCache.RedisReady)
require.Equal(t, int64(151), body.IdentityCache.ObservedCount)
require.Equal(t, int64(153), body.IdentityCache.KeyCount)
require.Equal(t, 1, cache.statusHit)
}
func TestAdminHandler_FlushIdentityCacheRequiresSuperAdminAndFlushesCacheOnly(t *testing.T) {
cache := &fakeIdentityCacheAdmin{
flush: domain.IdentityCacheFlushResult{
Status: "success",
FlushedKeys: 7,
UpdatedAt: time.Date(2026, 5, 11, 3, 2, 0, 0, time.UTC),
},
}
h := &AdminHandler{
IdentityCache: cache,
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Post("/api/v1/admin/ory/ssot/identity-cache/flush", h.FlushIdentityCache)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/ory/ssot/identity-cache/flush", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
var body domain.IdentityCacheFlushResult
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, int64(7), body.FlushedKeys)
require.Equal(t, 1, cache.flushCalls)
}
func TestAdminHandler_GetRPUsageDailyChecksTenantPermission(t *testing.T) {
repo := &fakeRPUsageQueryRepo{}
keto := &fakeAdminKeto{allowed: true}
h := &AdminHandler{RPUsageQueries: repo, Keto: keto}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{
ID: "user-1",
Role: "tenant_admin",
})
return c.Next()
})
app.Get("/api/v1/admin/rp-usage/daily", h.GetRPUsageDaily)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/rp-usage/daily?tenantId=tenant-allowed", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, "User:user-1", keto.subject)
require.Equal(t, "tenant-allowed", keto.object)
require.Equal(t, "view_rp_usage_stats", keto.relation)
require.Equal(t, "tenant-allowed", repo.query.TenantID)
}
func TestAdminHandler_GetSystemStatsIncludesOverviewMetrics(t *testing.T) {
auditRepo := &fakeOverviewAuditRepo{count: 22}
h := &AdminHandler{
AuditRepo: auditRepo,
UserProjectionRepo: &fakeAdminUserProjectionRepo{
status: domain.UserProjectionStatus{
Name: domain.UserProjectionNameKratos,
Status: domain.UserProjectionStatusReady,
Ready: true,
ProjectedUsers: 152,
},
},
}
app := fiber.New()
app.Get("/api/v1/admin/stats", h.GetSystemStats)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/stats", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
var body map[string]any
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Contains(t, body, "totalTenants")
require.Contains(t, body, "totalUsers")
require.Contains(t, body, "oidcClients")
require.Contains(t, body, "auditEvents24h")
require.Equal(t, float64(152), body["totalUsers"])
require.Equal(t, float64(22), body["auditEvents24h"])
require.Equal(t, time.UTC, auditRepo.since.Location())
}

View File

@@ -0,0 +1,196 @@
package handler
import (
"baron-sso-backend/internal/domain"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/require"
)
type fakeDataIntegrityChecker struct {
calls int
listCalls int
deleteCalls int
deletedIDs []string
report domain.DataIntegrityReport
orphans []domain.OrphanUserLoginID
deleteResult domain.DeleteOrphanUserLoginIDsResult
err error
}
func (f *fakeDataIntegrityChecker) CheckDataIntegrity(ctx context.Context) (domain.DataIntegrityReport, error) {
f.calls++
return f.report, f.err
}
func (f *fakeDataIntegrityChecker) ListOrphanUserLoginIDs(ctx context.Context) ([]domain.OrphanUserLoginID, error) {
f.listCalls++
return f.orphans, f.err
}
func (f *fakeDataIntegrityChecker) DeleteOrphanUserLoginIDs(ctx context.Context, ids []string) (domain.DeleteOrphanUserLoginIDsResult, error) {
f.deleteCalls++
f.deletedIDs = append([]string(nil), ids...)
return f.deleteResult, f.err
}
func TestAdminHandler_GetDataIntegrityRequiresSuperAdmin(t *testing.T) {
checker := &fakeDataIntegrityChecker{}
h := &AdminHandler{IntegrityChecker: checker}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "tenant-admin", Role: "tenant_admin"})
return c.Next()
})
app.Get("/api/v1/admin/integrity", h.GetDataIntegrity)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/integrity", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusForbidden, resp.StatusCode)
require.Equal(t, 0, checker.calls)
}
func TestAdminHandler_GetDataIntegrityReturnsReportForSuperAdmin(t *testing.T) {
checkedAt := time.Date(2026, 5, 14, 0, 0, 0, 0, time.UTC)
checker := &fakeDataIntegrityChecker{
report: domain.DataIntegrityReport{
Status: domain.DataIntegrityStatusFail,
CheckedAt: checkedAt,
Summary: domain.DataIntegritySummary{
TotalChecks: 1,
Failures: 1,
},
Sections: []domain.DataIntegritySection{
{
Key: "tenant_integrity",
Label: "테넌트 정합성",
Status: domain.DataIntegrityStatusFail,
Checks: []domain.DataIntegrityCheck{
{
Key: "duplicate_tenant_slugs",
Label: "중복 테넌트 slug",
Status: domain.DataIntegrityStatusFail,
Count: 1,
Severity: "error",
},
},
},
},
},
}
h := &AdminHandler{IntegrityChecker: checker}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Get("/api/v1/admin/integrity", h.GetDataIntegrity)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/integrity", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, 1, checker.calls)
var body domain.DataIntegrityReport
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, domain.DataIntegrityStatusFail, body.Status)
require.Equal(t, int64(1), body.Summary.Failures)
require.Len(t, body.Sections, 1)
require.Equal(t, "tenant_integrity", body.Sections[0].Key)
}
func TestAdminHandler_ListOrphanUserLoginIDsReturnsTargetsForSuperAdmin(t *testing.T) {
checker := &fakeDataIntegrityChecker{
orphans: []domain.OrphanUserLoginID{
{
ID: "login-id-1",
UserID: "user-1",
TenantID: "tenant-1",
FieldKey: "emp_id",
LoginID: "EMP001",
Reasons: []string{"missing_tenant"},
},
},
}
h := &AdminHandler{IntegrityChecker: checker}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Get("/api/v1/admin/integrity/orphan-user-login-ids", h.ListOrphanUserLoginIDs)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/integrity/orphan-user-login-ids", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, 1, checker.listCalls)
var body struct {
Items []domain.OrphanUserLoginID `json:"items"`
Total int `json:"total"`
}
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, 1, body.Total)
require.Equal(t, "login-id-1", body.Items[0].ID)
require.Equal(t, []string{"missing_tenant"}, body.Items[0].Reasons)
}
func TestAdminHandler_DeleteOrphanUserLoginIDsRequiresSuperAdminAndDeletesSelectedTargets(t *testing.T) {
checker := &fakeDataIntegrityChecker{
deleteResult: domain.DeleteOrphanUserLoginIDsResult{
DeletedCount: 1,
Deleted: []domain.OrphanUserLoginID{
{ID: "login-id-1", LoginID: "EMP001", Reasons: []string{"missing_user"}},
},
SkippedIDs: []string{"valid-login-id"},
},
}
h := &AdminHandler{IntegrityChecker: checker}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Delete("/api/v1/admin/integrity/orphan-user-login-ids", h.DeleteOrphanUserLoginIDs)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/admin/integrity/orphan-user-login-ids", strings.NewReader(`{"ids":["login-id-1","valid-login-id"]}`))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, 1, checker.deleteCalls)
require.Equal(t, []string{"login-id-1", "valid-login-id"}, checker.deletedIDs)
var body domain.DeleteOrphanUserLoginIDsResult
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, int64(1), body.DeletedCount)
require.Equal(t, []string{"valid-login-id"}, body.SkippedIDs)
}
func TestAdminHandler_DeleteOrphanUserLoginIDsRejectsTenantAdmin(t *testing.T) {
checker := &fakeDataIntegrityChecker{}
h := &AdminHandler{IntegrityChecker: checker}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "tenant-admin", Role: "tenant_admin"})
return c.Next()
})
app.Delete("/api/v1/admin/integrity/orphan-user-login-ids", h.DeleteOrphanUserLoginIDs)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/admin/integrity/orphan-user-login-ids", strings.NewReader(`{"ids":["login-id-1"]}`))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusForbidden, resp.StatusCode)
require.Equal(t, 0, checker.deleteCalls)
}

View File

@@ -0,0 +1,288 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/pagination"
"errors"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
type ApiKeyHandler struct {
DB *gorm.DB
}
func NewApiKeyHandler(db *gorm.DB) *ApiKeyHandler {
return &ApiKeyHandler{DB: db}
}
type apiKeySummary struct {
ID string `json:"id"`
Name string `json:"name"`
ClientID string `json:"client_id"`
Scopes []string `json:"scopes"`
Status string `json:"status"`
LastUsedAt *string `json:"lastUsedAt"`
CreatedAt time.Time `json:"createdAt"`
}
type apiKeyListResponse struct {
Items []apiKeySummary `json:"items"`
Total int64 `json:"total"`
Limit int `json:"limit"`
Offset int `json:"offset"`
Cursor string `json:"cursor,omitempty"`
NextCursor string `json:"nextCursor,omitempty"`
}
func apiKeyToSummary(k domain.ApiKey) apiKeySummary {
lastUsed := ""
if k.LastUsedAt != nil {
lastUsed = k.LastUsedAt.Format(time.RFC3339)
}
return apiKeySummary{
ID: k.ID,
Name: k.Name,
ClientID: k.ClientID,
Scopes: strings.Fields(strings.ReplaceAll(k.Scopes, ",", " ")),
Status: k.Status,
LastUsedAt: &lastUsed,
CreatedAt: k.CreatedAt,
}
}
func apiKeyWithUpdatedScopes(k domain.ApiKey, scopes []string) domain.ApiKey {
k.Scopes = strings.Join(normalizeApiKeyScopes(scopes), " ")
return k
}
func apiKeyWithRotatedSecretHash(k domain.ApiKey, hashedSecret string) domain.ApiKey {
k.ClientSecretHash = hashedSecret
return k
}
func normalizeApiKeyScopes(scopes []string) []string {
seen := make(map[string]struct{}, len(scopes))
normalized := make([]string, 0, len(scopes))
for _, scope := range scopes {
scope = strings.TrimSpace(scope)
if scope == "" {
continue
}
if _, exists := seen[scope]; exists {
continue
}
seen[scope] = struct{}{}
normalized = append(normalized, scope)
}
return normalized
}
func (h *ApiKeyHandler) ListApiKeys(c *fiber.Ctx) error {
if h.DB == nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "database not available")
}
limit := c.QueryInt("limit", 50)
offset := c.QueryInt("offset", 0)
cursorRaw := strings.TrimSpace(c.Query("cursor"))
if limit <= 0 {
limit = 50
}
if offset < 0 {
offset = 0
}
var total int64
if err := h.DB.Model(&domain.ApiKey{}).Count(&total).Error; err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
var keys []domain.ApiKey
query := h.DB.Order("created_at desc, id desc").Limit(limit + 1)
if cursorRaw != "" {
cursor, err := pagination.Decode(cursorRaw)
if err != nil {
return errorJSON(c, fiber.StatusBadRequest, "invalid cursor")
}
query = pagination.ApplyCreatedAtIDCursor(query, cursor, "created_at", "id")
offset = 0
} else {
query = query.Offset(offset)
}
if err := query.Find(&keys).Error; err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
nextCursor := ""
hasMore := len(keys) > limit
if len(keys) > limit {
keys = keys[:limit]
}
if cursorRaw == "" && total > int64(offset+len(keys)) {
hasMore = true
}
if hasMore && len(keys) > 0 {
last := keys[len(keys)-1]
nextCursor = pagination.Encode(last.CreatedAt, last.ID)
}
items := make([]apiKeySummary, 0, len(keys))
for _, k := range keys {
items = append(items, apiKeyToSummary(k))
}
return c.JSON(apiKeyListResponse{
Items: items,
Total: total,
Limit: limit,
Offset: offset,
Cursor: cursorRaw,
NextCursor: nextCursor,
})
}
func (h *ApiKeyHandler) CreateApiKey(c *fiber.Ctx) error {
if h.DB == nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "database not available")
}
var req struct {
Name string `json:"name"`
Scopes []string `json:"scopes"`
}
if err := c.BodyParser(&req); err != nil {
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
}
if strings.TrimSpace(req.Name) == "" {
return errorJSON(c, fiber.StatusBadRequest, "name is required")
}
req.Scopes = normalizeApiKeyScopes(req.Scopes)
if len(req.Scopes) == 0 {
return errorJSON(c, fiber.StatusBadRequest, "at least one scope is required")
}
// Generate Client ID (16 chars hex)
clientID := GenerateSecureToken(8)
// Generate plain secret (16 chars hex)
plainSecret := GenerateSecureToken(8)
hashedSecret, err := bcrypt.GenerateFromPassword([]byte(plainSecret), bcrypt.DefaultCost)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, "failed to hash secret")
}
apiKey := domain.ApiKey{
Name: req.Name,
ClientID: clientID,
ClientSecretHash: string(hashedSecret),
Scopes: strings.Join(req.Scopes, " "),
Status: "active",
}
if err := h.DB.Create(&apiKey).Error; err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
// Return summary + PLAIN SECRET (only this time)
return c.Status(fiber.StatusCreated).JSON(fiber.Map{
"apiKey": apiKeyToSummary(apiKey),
"clientSecret": plainSecret, // VERY IMPORTANT: user must save this now
})
}
func (h *ApiKeyHandler) UpdateApiKey(c *fiber.Ctx) error {
if h.DB == nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "database not available")
}
id := c.Params("id")
if id == "" {
return errorJSON(c, fiber.StatusBadRequest, "id is required")
}
var req struct {
Scopes []string `json:"scopes"`
}
if err := c.BodyParser(&req); err != nil {
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
}
req.Scopes = normalizeApiKeyScopes(req.Scopes)
if len(req.Scopes) == 0 {
return errorJSON(c, fiber.StatusBadRequest, "at least one scope is required")
}
var apiKey domain.ApiKey
if err := h.DB.First(&apiKey, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errorJSON(c, fiber.StatusNotFound, "api key not found")
}
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
apiKey = apiKeyWithUpdatedScopes(apiKey, req.Scopes)
if err := h.DB.Model(&domain.ApiKey{}).Where("id = ?", id).Update("scopes", apiKey.Scopes).Error; err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.JSON(apiKeyToSummary(apiKey))
}
func (h *ApiKeyHandler) RotateApiKeySecret(c *fiber.Ctx) error {
if h.DB == nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "database not available")
}
id := c.Params("id")
if id == "" {
return errorJSON(c, fiber.StatusBadRequest, "id is required")
}
var apiKey domain.ApiKey
if err := h.DB.First(&apiKey, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errorJSON(c, fiber.StatusNotFound, "api key not found")
}
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
plainSecret := GenerateSecureToken(8)
hashedSecret, err := bcrypt.GenerateFromPassword([]byte(plainSecret), bcrypt.DefaultCost)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, "failed to hash secret")
}
apiKey = apiKeyWithRotatedSecretHash(apiKey, string(hashedSecret))
if err := h.DB.Model(&domain.ApiKey{}).Where("id = ?", id).Update("client_secret_hash", apiKey.ClientSecretHash).Error; err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.JSON(fiber.Map{
"apiKey": apiKeyToSummary(apiKey),
"clientSecret": plainSecret,
})
}
func (h *ApiKeyHandler) DeleteApiKey(c *fiber.Ctx) error {
if h.DB == nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "database not available")
}
id := c.Params("id")
if id == "" {
return errorJSON(c, fiber.StatusBadRequest, "id is required")
}
if err := h.DB.Delete(&domain.ApiKey{}, "id = ?", id).Error; err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.SendStatus(fiber.StatusNoContent)
}

View File

@@ -0,0 +1,133 @@
package handler
import (
"baron-sso-backend/internal/domain"
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
)
// Mock DB for ApiKey tests using a real GORM instance but with a hijacked connection
// or just a simple mock if we only check nil.
// For ApiKeyHandler, it uses DB for Create/List/Delete.
func TestApiKeyHandler_CreateApiKey(t *testing.T) {
app := fiber.New()
// ApiKeyHandler requires a valid DB connection to perform h.DB.Create
// Since we don't have a real DB here, we'll check if it fails gracefully
// or we can use sqlite in-memory for a more realistic test.
h := &ApiKeyHandler{DB: nil} // Testing ServiceUnavailable
app.Post("/api-keys", h.CreateApiKey)
input := map[string]any{
"name": "M2M Test",
"scopes": []string{"read", "write"},
}
body, _ := json.Marshal(input)
req := httptest.NewRequest("POST", "/api-keys", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req)
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
}
func TestApiKeyHandler_Validation(t *testing.T) {
app := fiber.New()
// Using a dummy DB pointer to pass the nil check
h := &ApiKeyHandler{DB: &gorm.DB{}}
app.Post("/api-keys", h.CreateApiKey)
// Missing name
input := map[string]any{
"scopes": []string{"read"},
}
body, _ := json.Marshal(input)
req := httptest.NewRequest("POST", "/api-keys", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
}
func TestApiKeyHandler_UpdateApiKeyScopesRequiresDatabase(t *testing.T) {
app := fiber.New()
h := &ApiKeyHandler{DB: nil}
app.Patch("/api-keys/:id", h.UpdateApiKey)
body, _ := json.Marshal(map[string]any{
"scopes": []string{"org-context:read"},
})
req := httptest.NewRequest("PATCH", "/api-keys/api-key-id", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req)
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
}
func TestApiKeyHandler_RotateApiKeySecretRequiresDatabase(t *testing.T) {
app := fiber.New()
h := &ApiKeyHandler{DB: nil}
app.Post("/api-keys/:id/secret/rotate", h.RotateApiKeySecret)
req := httptest.NewRequest("POST", "/api-keys/api-key-id/secret/rotate", nil)
resp, _ := app.Test(req)
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
}
func TestApiKeyWithUpdatedScopesPreservesClientID(t *testing.T) {
key := domain.ApiKey{
ID: "api-key-id",
Name: "M2M Test",
ClientID: "client-id-stable",
ClientSecretHash: "old-secret-hash",
Scopes: "audit:read",
Status: "active",
}
updated := apiKeyWithUpdatedScopes(key, []string{"audit:read", "org-context:read"})
assert.Equal(t, "client-id-stable", updated.ClientID)
assert.Equal(t, "old-secret-hash", updated.ClientSecretHash)
assert.Equal(t, "audit:read org-context:read", updated.Scopes)
}
func TestApiKeyWithRotatedSecretHashPreservesClientIDAndScopes(t *testing.T) {
key := domain.ApiKey{
ID: "api-key-id",
Name: "M2M Test",
ClientID: "client-id-stable",
ClientSecretHash: "old-secret-hash",
Scopes: "audit:read org-context:read",
Status: "active",
}
updated := apiKeyWithRotatedSecretHash(key, "new-secret-hash")
assert.Equal(t, "client-id-stable", updated.ClientID)
assert.Equal(t, "audit:read org-context:read", updated.Scopes)
assert.Equal(t, "new-secret-hash", updated.ClientSecretHash)
}
func TestNormalizeApiKeyScopesTrimsAndDeduplicates(t *testing.T) {
scopes := normalizeApiKeyScopes([]string{
" audit:read ",
"",
"org-context:read",
"audit:read",
})
assert.Equal(t, []string{"audit:read", "org-context:read"}, scopes)
}

View File

@@ -0,0 +1,139 @@
package handler
import (
"baron-sso-backend/internal/domain"
"encoding/base64"
"errors"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
)
type AuditHandler struct {
repo domain.AuditRepository
}
func NewAuditHandler(repo domain.AuditRepository) *AuditHandler {
return &AuditHandler{repo: repo}
}
// CreateLog handles POST /api/v1/audit
func (h *AuditHandler) CreateLog(c *fiber.Ctx) error {
var req domain.AuditLog
if err := c.BodyParser(&req); err != nil {
return errorJSON(c, fiber.StatusBadRequest, "Cannot parse JSON")
}
// Auto-fill metadata if missing
if req.IPAddress == "" {
req.IPAddress = c.IP()
}
if req.UserAgent == "" {
req.UserAgent = c.Get("User-Agent")
}
if req.Timestamp.IsZero() {
req.Timestamp = time.Now()
}
if req.EventID == "" {
req.EventID = ensureRequestID(c)
}
if h.repo == nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "Audit service unavailable")
}
if err := h.repo.Create(&req); err != nil {
// Log internal error but don't expose details
return errorJSON(c, fiber.StatusInternalServerError, "Failed to save audit log")
}
return c.Status(fiber.StatusCreated).JSON(fiber.Map{
"message": "Audit log saved",
})
}
// ListLogs handles GET /api/v1/audit
func (h *AuditHandler) ListLogs(c *fiber.Ctx) error {
limit := c.QueryInt("limit", 50)
cursorRaw := c.Query("cursor")
requestedTenantID := c.Query("tenantId")
cursor, err := parseAuditCursor(cursorRaw)
if err != nil {
return errorJSON(c, fiber.StatusBadRequest, "Invalid cursor")
}
if h.repo == nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "Audit service unavailable")
}
// [New] Role-based Filtering
profile, _ := c.Locals("user_profile").(*domain.UserProfileResponse)
var filterTenantID string
if profile != nil {
if profile.Role == domain.RoleSuperAdmin {
// Super Admin can see everything or filter by a specific tenant if requested
filterTenantID = requestedTenantID
} else {
return errorJSON(c, fiber.StatusForbidden, "forbidden")
}
}
logs, err := h.repo.FindPage(c.Context(), limit+1, cursor, filterTenantID)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, "Failed to retrieve audit logs")
}
nextCursor := ""
if len(logs) > limit {
last := logs[limit-1]
nextCursor = encodeAuditCursor(last)
logs = logs[:limit]
}
return c.JSON(fiber.Map{
"items": logs,
"limit": limit,
"cursor": cursorRaw,
"next_cursor": nextCursor,
})
}
func ensureRequestID(c *fiber.Ctx) string {
reqID := c.Get("X-Request-Id")
if reqID == "" {
reqID = uuid.New().String()
c.Set("X-Request-Id", reqID)
}
return reqID
}
func parseAuditCursor(raw string) (*domain.AuditCursor, error) {
if raw == "" {
return nil, nil
}
decoded, err := base64.RawURLEncoding.DecodeString(raw)
if err != nil {
return nil, err
}
parts := strings.SplitN(string(decoded), "|", 2)
if len(parts) != 2 {
return nil, errors.New("invalid cursor")
}
ts, err := time.Parse(time.RFC3339Nano, parts[0])
if err != nil {
return nil, err
}
return &domain.AuditCursor{
Timestamp: ts,
EventID: parts[1],
}, nil
}
func encodeAuditCursor(log domain.AuditLog) string {
payload := log.Timestamp.UTC().Format(time.RFC3339Nano) + "|" + log.EventID
return base64.RawURLEncoding.EncodeToString([]byte(payload))
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,371 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"gorm.io/gorm"
)
// --- Async Test Mocks ---
type AsyncMockIdpProvider struct {
mock.Mock
}
func (m *AsyncMockIdpProvider) Name() string { return "mock-idp" }
func (m *AsyncMockIdpProvider) GetMetadata() (*domain.IDPMetadata, error) {
return &domain.IDPMetadata{}, nil
}
func (m *AsyncMockIdpProvider) UserExists(loginID string) (bool, error) {
args := m.Called(loginID)
return args.Bool(0), args.Error(1)
}
func (m *AsyncMockIdpProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
args := m.Called(user, password)
return args.String(0), args.Error(1)
}
func (m *AsyncMockIdpProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
return nil, nil
}
func (m *AsyncMockIdpProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
return nil, nil
}
func (m *AsyncMockIdpProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
return nil, nil
}
func (m *AsyncMockIdpProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
return nil, nil
}
func (m *AsyncMockIdpProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
return &domain.PasswordPolicy{MinLength: 12}, nil
}
func (m *AsyncMockIdpProvider) InitiatePasswordReset(loginID, redirectUrl string) error { return nil }
func (m *AsyncMockIdpProvider) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
return nil, nil
}
func (m *AsyncMockIdpProvider) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
return nil
}
type AsyncMockUserRepo struct {
mock.Mock
createCalled chan bool
}
func (m *AsyncMockUserRepo) Create(ctx context.Context, user *domain.User) error {
// Simulate DB latency
time.Sleep(50 * time.Millisecond)
args := m.Called(ctx, user)
if m.createCalled != nil {
m.createCalled <- true
}
return args.Error(0)
}
func (m *AsyncMockUserRepo) Update(ctx context.Context, user *domain.User) error {
args := m.Called(ctx, user)
if m.createCalled != nil {
m.createCalled <- true
}
return args.Error(0)
}
func (m *AsyncMockUserRepo) Delete(ctx context.Context, id string) error { return nil }
func (m *AsyncMockUserRepo) FindByEmail(ctx context.Context, email string) (*domain.User, error) {
return nil, nil
}
func (m *AsyncMockUserRepo) FindByID(ctx context.Context, id string) (*domain.User, error) {
return nil, nil
}
func (m *AsyncMockUserRepo) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) {
return nil, nil
}
func (m *AsyncMockUserRepo) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
return nil, nil
}
func (m *AsyncMockUserRepo) List(ctx context.Context, offset, limit int, search string, tenantIDs []string, cursor string) ([]domain.User, int64, string, error) {
return nil, 0, "", nil
}
func (m *AsyncMockUserRepo) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
return 0, nil
}
func (m *AsyncMockUserRepo) FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error) {
args := m.Called(ctx, tenantIDs)
return args.Get(0).([]domain.User), args.Error(1)
}
func (m *AsyncMockUserRepo) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
return nil, nil
}
func (m *AsyncMockUserRepo) FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error) {
args := m.Called(ctx, codes)
return args.Get(0).([]domain.User), args.Error(1)
}
func (m *AsyncMockUserRepo) CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error) {
return nil, nil
}
func (m *AsyncMockUserRepo) DB() *gorm.DB {
return nil
}
func (m *AsyncMockUserRepo) UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error {
return nil
}
func (m *AsyncMockUserRepo) GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error) {
return nil, nil
}
func (m *AsyncMockUserRepo) IsLoginIDTaken(ctx context.Context, loginID string) (bool, error) {
return false, nil
}
func (m *AsyncMockUserRepo) FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error) {
return "", nil
}
type AsyncMockRedisRepo struct {
mock.Mock
}
func (m *AsyncMockRedisRepo) Set(key string, value string, expiration time.Duration) error {
args := m.Called(key, value, expiration)
return args.Error(0)
}
func (m *AsyncMockRedisRepo) Get(key string) (string, error) {
args := m.Called(key)
return args.String(0), args.Error(1)
}
func (m *AsyncMockRedisRepo) Delete(key string) error {
args := m.Called(key)
return args.Error(0)
}
func (m *AsyncMockRedisRepo) StoreVerificationCode(phone, code string) error { return nil }
func (m *AsyncMockRedisRepo) GetVerificationCode(phone string) (string, error) { return "", nil }
func (m *AsyncMockRedisRepo) DeleteVerificationCode(phone string) error { return nil }
type AsyncMockTenantService struct {
mock.Mock
}
func (m *AsyncMockTenantService) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string, creatorID string) (*domain.Tenant, error) {
args := m.Called(ctx, name, slug, tenantType, description, domains, parentID, creatorID)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.Tenant), args.Error(1)
}
func (m *AsyncMockTenantService) RequestRegistration(ctx context.Context, name, slug, description string, domainName string, adminEmail string) (*domain.Tenant, error) {
return nil, nil
}
func (m *AsyncMockTenantService) GetTenantByDomain(ctx context.Context, emailDomain string) (*domain.Tenant, error) {
args := m.Called(ctx, emailDomain)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.Tenant), args.Error(1)
}
func (m *AsyncMockTenantService) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
return nil, nil
}
func (m *AsyncMockTenantService) GetTenant(ctx context.Context, id string) (*domain.Tenant, error) {
return nil, nil
}
func (m *AsyncMockTenantService) ListTenants(ctx context.Context, limit, offset int, parentID string, search string) ([]domain.Tenant, int64, error) {
return nil, 0, nil
}
func (m *AsyncMockTenantService) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
return nil, nil
}
func (m *AsyncMockTenantService) IsDomainAllowed(ctx context.Context, domainName string) (bool, error) {
return false, nil
}
func (m *AsyncMockTenantService) ApproveTenant(ctx context.Context, id string) error { return nil }
func (m *AsyncMockTenantService) ProvisionTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
return nil, nil
}
func (m *AsyncMockTenantService) SetKetoService(keto service.KetoService) {}
func (m *AsyncMockTenantService) AddTenantAdmin(ctx context.Context, tenantID, userID string) error {
return nil
}
func (m *AsyncMockTenantService) RemoveTenantAdmin(ctx context.Context, tenantID, userID string) error {
return nil
}
func (m *AsyncMockTenantService) ListTenantAdmins(ctx context.Context, tenantID string) ([]string, error) {
return nil, nil
}
func (m *AsyncMockTenantService) DeleteTenantsBulk(ctx context.Context, ids []string) error {
return nil
}
func (m *AsyncMockTenantService) ListJoinedTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
args := m.Called(ctx, userID)
if args.Get(0) != nil {
return args.Get(0).([]domain.Tenant), args.Error(1)
}
return nil, args.Error(1)
}
type AsyncMockKetoService struct {
mock.Mock
}
func (m *AsyncMockKetoService) CreateRelation(ctx context.Context, namespace, object, relation, subject string) error {
args := m.Called(ctx, namespace, object, relation, subject)
return args.Error(0)
}
func (m *AsyncMockKetoService) DeleteRelation(ctx context.Context, namespace, object, relation, subject string) error {
return nil
}
func (m *AsyncMockKetoService) CheckPermission(ctx context.Context, namespace, object, relation, subject string) (bool, error) {
return false, nil
}
func (m *AsyncMockKetoService) ListObjects(ctx context.Context, namespace, relation, subject string) ([]string, error) {
return nil, nil
}
func (m *AsyncMockKetoService) ListRelations(ctx context.Context, namespace, object, relation, subject string) ([]service.RelationTuple, error) {
return nil, nil
}
// --- Tests ---
func TestSignup_AsyncDB_Isolation(t *testing.T) {
mockIdp := new(AsyncMockIdpProvider)
mockUserRepo := new(AsyncMockUserRepo)
mockRedis := new(AsyncMockRedisRepo)
mockTenant := new(AsyncMockTenantService)
mockKeto := new(AsyncMockKetoService)
h := &AuthHandler{
IdpProvider: mockIdp,
UserRepo: mockUserRepo,
RedisService: mockRedis,
TenantService: mockTenant,
KetoService: mockKeto,
}
app := fiber.New()
app.Post("/signup", h.Signup)
t.Run("SoT_DB_Failure_Ignored_And_Async", func(t *testing.T) {
email := "test@example.com"
phone := "010-1234-5678"
emailKey := "signup:email:" + email
phoneKey := "signup:phone:" + "01012345678"
// Redis Mocks
mockRedis.On("Get", emailKey).Return(`{"verified": true, "expires_at": 9999999999}`, nil)
mockRedis.On("Get", phoneKey).Return(`{"verified": true, "expires_at": 9999999999}`, nil)
mockRedis.On("Delete", emailKey).Return(nil)
mockRedis.On("Delete", phoneKey).Return(nil)
// Tenant Mocks
personalTenant := &domain.Tenant{ID: "personal-t1", Slug: "personal-test", Type: domain.TenantTypePersonal, Status: domain.TenantStatusActive}
mockTenant.On("GetTenantByDomain", mock.Anything, "example.com").Return(nil, nil)
mockTenant.On(
"RegisterTenant",
mock.Anything,
"Personal - test@example.com",
mock.MatchedBy(func(slug string) bool { return strings.HasPrefix(slug, "personal-") }),
domain.TenantTypePersonal,
"Automatically provisioned personal tenant",
[]string(nil),
(*string)(nil),
"",
).Return(personalTenant, nil)
mockTenant.On("GetTenant", mock.Anything, "personal-t1").Return(personalTenant, nil)
// Kratos Mocks (Success)
mockIdp.On("CreateUser", mock.Anything, "Password123!").Return("new-user-uuid", nil)
// UserRepo Mocks (Async & Failure)
mockUserRepo.createCalled = make(chan bool, 1)
mockUserRepo.On("Update", mock.Anything, mock.MatchedBy(func(u *domain.User) bool {
return u.Email == email
})).Return(errors.New("db connection error"))
// Keto Mocks (Optional, since it's also async)
// We won't block on this either
body, _ := json.Marshal(domain.SignupRequest{
Email: email,
Password: "Password123!",
Name: "Test User",
Phone: phone,
TermsAccepted: true,
})
req := httptest.NewRequest("POST", "/signup", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
start := time.Now()
resp, err := app.Test(req, 5000)
elapsed := time.Since(start)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
assert.Equal(t, 200, resp.StatusCode)
// Ensure API responded faster than DB latency (50ms)
assert.Less(t, int64(elapsed), int64(60*time.Millisecond), "API should return before DB timeout")
// Wait for async execution
select {
case <-mockUserRepo.createCalled:
// Pass
case <-time.After(2 * time.Second):
t.Fatal("UserRepo.Create was not called asynchronously")
}
mockRedis.AssertExpectations(t)
mockIdp.AssertExpectations(t)
mockUserRepo.AssertExpectations(t)
})
}

View File

@@ -0,0 +1,200 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"baron-sso-backend/internal/utils"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
)
func TestRevokeLinkedRp_Success(t *testing.T) {
// Mock Hydra transport for revocation
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
// 1. Kratos whoami
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{"id": "user-123"},
}), nil
}
// 2. Hydra Revoke
if r.Method == http.MethodDelete && r.URL.Path == "/oauth2/auth/sessions/consent" {
assert.Equal(t, "user-123", r.URL.Query().Get("subject"))
assert.Equal(t, "app-1", r.URL.Query().Get("client"))
return httpResponse(r, http.StatusNoContent, ""), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() { http.DefaultClient = origDefault }()
auditRepo := &mockAuditRepo{}
rpUsageSink := &mockRPUsageEventSink{}
consentRepo := &mockConsentRepo{
consents: []domain.ClientConsent{
{
ClientID: "app-1",
Subject: "user-123",
GrantedScopes: []string{"openid", "profile"},
},
},
}
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
AuditRepo: auditRepo,
ConsentRepo: consentRepo,
RPUsageSink: rpUsageSink,
}
app := fiber.New()
app.Delete("/api/v1/user/rp/linked/:id", h.RevokeLinkedRp)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/rp/linked/app-1", nil)
req.Header.Set("Cookie", "valid")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, 1, len(auditRepo.logs))
assert.Equal(t, "consent.revoked", auditRepo.logs[0].EventType)
assert.Equal(t, "user-123", auditRepo.logs[0].UserID)
assert.Equal(t, "success", auditRepo.logs[0].Status)
auditDetails, err := utils.ParseAuditDetails(auditRepo.logs[0].Details)
assert.NoError(t, err)
assert.Equal(t, "app-1", auditDetails["client_id"])
assert.Equal(t, 1, len(rpUsageSink.events))
assert.Equal(t, domain.RPUsageEventTypeAuthorizationRevoked, rpUsageSink.events[0].EventType)
assert.Equal(t, "user-123", rpUsageSink.events[0].Subject)
assert.Equal(t, "app-1", rpUsageSink.events[0].ClientID)
remaining, err := consentRepo.Find(req.Context(), "app-1", "user-123")
assert.NoError(t, err)
assert.Nil(t, remaining)
}
func TestRevokeLinkedRp_SendsBackchannelLogoutTokenWhenConfigured(t *testing.T) {
t.Setenv("BACKCHANNEL_LOGOUT_ISSUER", "https://sso.example.com/oidc")
var receivedBody string
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{"id": "user-123"},
}), nil
}
if r.URL.Host == "hydra.test" && r.Method == http.MethodDelete && r.URL.Path == "/oauth2/auth/sessions/consent" {
return httpResponse(r, http.StatusNoContent, ""), nil
}
if r.URL.Host == "hydra.test" && r.Method == http.MethodGet && r.URL.Path == "/clients/app-1" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "app-1",
"backchannel_logout_uri": "https://rp.example.com/backchannel-logout",
}), nil
}
if r.URL.Host == "rp.example.com" && r.Method == http.MethodPost && r.URL.Path == "/backchannel-logout" {
raw, _ := io.ReadAll(r.Body)
receivedBody = string(raw)
return httpResponse(r, http.StatusNoContent, ""), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() { http.DefaultClient = origDefault }()
backchannelLogout, err := service.NewBackchannelLogoutService()
assert.NoError(t, err)
backchannelLogout.HTTPClient = client
auditRepo := &mockAuditRepo{}
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
BackchannelLogout: backchannelLogout,
AuditRepo: auditRepo,
}
app := fiber.New()
app.Delete("/api/v1/user/rp/linked/:id", h.RevokeLinkedRp)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/rp/linked/app-1", nil)
req.Header.Set("Cookie", "valid")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.True(t, strings.Contains(receivedBody, "logout_token="))
values, err := url.ParseQuery(receivedBody)
assert.NoError(t, err)
assert.NotEmpty(t, values.Get("logout_token"))
assert.Len(t, auditRepo.logs, 2)
assert.Equal(t, "backchannel_logout.sent", auditRepo.logs[1].EventType)
}
func TestListRpHistory_Aggregation(t *testing.T) {
now := time.Now()
auditRepo := &mockAuditRepo{
logs: []domain.AuditLog{
{
UserID: "user-123",
EventType: "consent.revoked", // Newest
Timestamp: now,
Details: `{"client_id":"app-1"}`,
},
{
UserID: "user-123",
EventType: "consent.granted", // Oldest
Timestamp: now.Add(-1 * time.Hour),
Details: `{"client_id":"app-1", "client_name":"App One"}`,
},
},
}
h := &AuthHandler{
AuditRepo: auditRepo,
}
app := fiber.New()
app.Get("/api/v1/user/rp/history", h.ListRpHistory)
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{"id": "user-123"},
}), nil
})
http.DefaultClient = &http.Client{Transport: transport}
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/rp/history", nil)
req.Header.Set("Cookie", "valid")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var res struct {
Items []struct {
ClientID string `json:"client_id"`
Status string `json:"status"`
} `json:"items"`
}
json.NewDecoder(resp.Body).Decode(&res)
assert.Equal(t, 1, len(res.Items))
assert.Equal(t, "app-1", res.Items[0].ClientID)
// Newest event (revoked) should win
assert.Equal(t, "revoked", res.Items[0].Status)
}

View File

@@ -0,0 +1,515 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"baron-sso-backend/internal/utils"
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
// --- Mocks ---
type MockKratosAdminServiceForConsent struct {
mock.Mock
}
func (m *MockKratosAdminServiceForConsent) FindIdentityIDByIdentifier(ctx context.Context, identifier string) (string, error) {
args := m.Called(ctx, identifier)
return args.String(0), args.Error(1)
}
func (m *MockKratosAdminServiceForConsent) GetIdentity(ctx context.Context, id string) (*service.KratosIdentity, error) {
args := m.Called(ctx, id)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*service.KratosIdentity), args.Error(1)
}
func (m *MockKratosAdminServiceForConsent) ListIdentities(ctx context.Context) ([]service.KratosIdentity, error) {
return nil, nil
}
func (m *MockKratosAdminServiceForConsent) UpdateIdentity(ctx context.Context, identityID string, traits map[string]any, state string) (*service.KratosIdentity, error) {
return nil, nil
}
func (m *MockKratosAdminServiceForConsent) CreateIdentity(ctx context.Context, traits map[string]any) (*service.KratosIdentity, error) {
return nil, nil
}
func (m *MockKratosAdminServiceForConsent) DeleteIdentity(ctx context.Context, identityID string) error {
return nil
}
func (m *MockKratosAdminServiceForConsent) UpdateIdentityPassword(ctx context.Context, identityID, newPassword string) error {
return nil
}
func (m *MockKratosAdminServiceForConsent) ListIdentitySessions(ctx context.Context, identityID string) ([]service.KratosSession, error) {
return nil, nil
}
func (m *MockKratosAdminServiceForConsent) GetSession(ctx context.Context, sessionID string) (*service.KratosSession, error) {
return nil, nil
}
func (m *MockKratosAdminServiceForConsent) DeleteSession(ctx context.Context, sessionID string) error {
return nil
}
func (m *MockKratosAdminServiceForConsent) CreateUser(ctx context.Context, user *domain.BrokerUser, password string) (string, error) {
return "", nil
}
type MockTenantServiceForConsent struct {
mock.Mock
}
func (m *MockTenantServiceForConsent) GetTenant(ctx context.Context, id string) (*domain.Tenant, error) {
args := m.Called(ctx, id)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.Tenant), args.Error(1)
}
func (m *MockTenantServiceForConsent) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
return nil, nil
}
func (m *MockTenantServiceForConsent) GetTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
return nil, nil
}
func (m *MockTenantServiceForConsent) ListTenants(ctx context.Context, limit, offset int, parentID string, search string) ([]domain.Tenant, int64, error) {
return nil, 0, nil
}
func (m *MockTenantServiceForConsent) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string, creatorID string) (*domain.Tenant, error) {
return nil, nil
}
func (m *MockTenantServiceForConsent) RequestRegistration(ctx context.Context, name, slug, description string, domainName string, adminEmail string) (*domain.Tenant, error) {
return nil, nil
}
func (m *MockTenantServiceForConsent) ApproveTenant(ctx context.Context, id string) error {
return nil
}
func (m *MockTenantServiceForConsent) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
args := m.Called(ctx, userID)
return args.Get(0).([]domain.Tenant), args.Error(1)
}
func (m *MockTenantServiceForConsent) ListJoinedTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
args := m.Called(ctx, userID)
return args.Get(0).([]domain.Tenant), args.Error(1)
}
func (m *MockTenantServiceForConsent) IsDomainAllowed(ctx context.Context, domainName string) (bool, error) {
return false, nil
}
func (m *MockTenantServiceForConsent) ProvisionTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
return nil, nil
}
func (m *MockTenantServiceForConsent) SetKetoService(keto service.KetoService) {}
func (m *MockTenantServiceForConsent) DeleteTenantsBulk(ctx context.Context, ids []string) error {
return nil
}
// --- Test Helpers ---
func newConsentTestApp(h *AuthHandler) *fiber.App {
app := fiber.New()
app.Get("/api/v1/auth/consent", h.GetConsentRequest)
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
return app
}
// --- Tests ---
func TestGetConsentRequest_Normal(t *testing.T) {
// Mock Hydra transport
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-123" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-123",
"requested_scope": []string{"openid", "profile"},
"skip": false,
"subject": "user-123",
"client": map[string]any{
"client_id": "client-app",
"client_name": "Test App",
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() { http.DefaultClient = origDefault }()
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
}
app := newConsentTestApp(h)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-123", nil)
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var body map[string]any
json.NewDecoder(resp.Body).Decode(&body)
assert.Equal(t, "challenge-123", body["challenge"])
assert.Equal(t, false, body["skip"])
}
func TestGetConsentRequest_AddsMandatoryTenantScope(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-tenant-scope" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-tenant-scope",
"requested_scope": []string{"openid", "profile"},
"skip": false,
"subject": "user-123",
"client": map[string]any{
"client_id": "client-app",
"client_name": "Test App",
"metadata": map[string]any{
"tenant_access_restricted": true,
"allowed_tenants": []string{"tenant-allow"},
"structured_scopes": []map[string]any{
{"name": "openid", "mandatory": true},
{"name": "tenant", "mandatory": true, "locked": true},
{"name": "profile", "mandatory": false},
},
},
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() { http.DefaultClient = origDefault }()
mockTenantSvc := &MockTenantServiceForConsent{}
mockKratosAdmin := &MockKratosAdminServiceForConsent{}
// Mock profile resolution to allow tenant access
mockKratosAdmin.On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
ID: "user-123",
Traits: map[string]any{
"email": "user@example.com",
},
}, nil)
mockTenantSvc.On("GetTenant", mock.Anything, "tenant-allow").Return(&domain.Tenant{
ID: "tenant-allow",
Slug: "tenant-allow",
Name: "Allowed Tenant",
}, nil)
// Mock hydration calls
mockTenantSvc.On("ListJoinedTenants", mock.Anything, mock.Anything).Return([]domain.Tenant{
{ID: "tenant-allow", Slug: "tenant-allow", Name: "Allowed Tenant"},
}, nil)
mockTenantSvc.On("ListManageableTenants", mock.Anything, mock.Anything).Return([]domain.Tenant{}, nil)
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
TenantService: mockTenantSvc,
KratosAdmin: mockKratosAdmin,
}
app := newConsentTestApp(h)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-tenant-scope", nil)
req.Header.Set("X-Mock-Role", "user")
req.Header.Set("X-Tenant-ID", "tenant-allow")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var body map[string]any
json.NewDecoder(resp.Body).Decode(&body)
assert.Equal(t, []any{"openid", "tenant", "profile"}, body["requested_scope"])
scopeDetails := body["scope_details"].(map[string]any)
tenantDetail := scopeDetails["tenant"].(map[string]any)
assert.Equal(t, true, tenantDetail["mandatory"])
}
func TestGetConsentRequest_Skip_AutoAccept(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
// Hydra: Get Consent Request
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-skip" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-skip",
"requested_scope": []string{"openid"},
"skip": true,
"subject": "user-123",
"client": map[string]any{
"client_id": "client-app",
},
}), nil
}
// Kratos: Get Identity
if r.URL.Path == "/admin/identities/user-123" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@test.com",
},
}), nil
}
// Hydra: Accept Consent Request
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-skip" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"redirect_to": "http://rp/cb",
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() { http.DefaultClient = origDefault }()
consentRepo := &mockConsentRepo{}
rpUsageSink := &mockRPUsageEventSink{}
mockKratosAdmin := &MockKratosAdminServiceForConsent{}
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: mockKratosAdmin,
ConsentRepo: consentRepo,
RPUsageSink: rpUsageSink,
}
mockKratosAdmin.On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
ID: "user-123",
Traits: map[string]any{
"email": "user@test.com",
},
}, nil)
app := newConsentTestApp(h)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-skip", nil)
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var body map[string]any
json.NewDecoder(resp.Body).Decode(&body)
assert.Equal(t, "http://rp/cb", body["redirectTo"])
assert.Equal(t, 1, len(rpUsageSink.events))
assert.Equal(t, domain.RPUsageEventTypeAuthorizationGranted, rpUsageSink.events[0].EventType)
assert.Equal(t, "client-app", rpUsageSink.events[0].ClientID)
assert.Equal(t, "challenge-skip", rpUsageSink.events[0].CorrelationID)
assert.Equal(t, true, rpUsageSink.events[0].Payload["auto_accepted"])
}
func TestAcceptConsentRequest_Normal(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-accept" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-accept",
"requested_scope": []string{"openid", "profile"},
"subject": "user-123",
"client": map[string]any{
"client_id": "client-app",
"client_name": "Test App",
},
}), nil
}
if r.URL.Path == "/admin/identities/user-123" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@test.com",
},
}), nil
}
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-accept" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"redirect_to": "http://rp/cb",
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() { http.DefaultClient = origDefault }()
auditRepo := &mockAuditRepo{}
consentRepo := &mockConsentRepo{}
rpUsageSink := &mockRPUsageEventSink{}
mockKratosAdmin := &MockKratosAdminServiceForConsent{}
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: mockKratosAdmin,
AuditRepo: auditRepo,
ConsentRepo: consentRepo,
RPUsageSink: rpUsageSink,
}
mockKratosAdmin.On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
ID: "user-123",
Traits: map[string]any{
"email": "user@test.com",
},
}, nil)
app := newConsentTestApp(h)
body, _ := json.Marshal(map[string]any{
"consent_challenge": "challenge-accept",
"grant_scope": []string{"openid"},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, 1, len(auditRepo.logs))
assert.Equal(t, "consent.granted", auditRepo.logs[0].EventType)
assert.Equal(t, "user-123", auditRepo.logs[0].UserID)
assert.Equal(t, "success", auditRepo.logs[0].Status)
auditDetails, err := utils.ParseAuditDetails(auditRepo.logs[0].Details)
assert.NoError(t, err)
assert.Equal(t, "client-app", auditDetails["client_id"])
assert.Equal(t, "Test App", auditDetails["client_name"])
assert.Equal(t, []any{"openid"}, auditDetails["scopes"])
assert.Equal(t, 1, len(rpUsageSink.events))
assert.Equal(t, domain.RPUsageEventTypeAuthorizationGranted, rpUsageSink.events[0].EventType)
assert.Equal(t, "user-123", rpUsageSink.events[0].Subject)
assert.Equal(t, "client-app", rpUsageSink.events[0].ClientID)
assert.Equal(t, "Test App", rpUsageSink.events[0].ClientName)
assert.Equal(t, []string{"openid"}, []string(rpUsageSink.events[0].Scopes))
assert.Equal(t, "hydra_consent", rpUsageSink.events[0].Source)
}
func TestAcceptConsentRequest_EnforcesMandatoryTenantScope(t *testing.T) {
t.Setenv("APP_ENV", "dev")
var capturedGrantScopes []string
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-tenant-accept" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-tenant-accept",
"requested_scope": []string{"openid", "profile"},
"subject": "user-123",
"client": map[string]any{
"client_id": "client-app",
"metadata": map[string]any{
"tenant_id": "tenant-abc",
"tenant_access_restricted": true,
"allowed_tenants": []string{"tenant-abc"},
"structured_scopes": []map[string]any{
{"name": "openid", "mandatory": true},
{"name": "tenant", "mandatory": true, "locked": true},
{"name": "profile", "mandatory": false},
},
},
},
}), nil
}
if r.URL.Path == "/admin/identities/user-123" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@test.com",
},
}), nil
}
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-tenant-accept" {
var payload map[string]any
assert.NoError(t, json.NewDecoder(r.Body).Decode(&payload))
for _, scope := range payload["grant_scope"].([]any) {
capturedGrantScopes = append(capturedGrantScopes, scope.(string))
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"redirect_to": "http://rp/cb",
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() { http.DefaultClient = origDefault }()
mockKratosAdmin := &MockKratosAdminServiceForConsent{}
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: mockKratosAdmin,
}
mockKratosAdmin.On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
ID: "user-123",
Traits: map[string]any{
"email": "user@test.com",
},
}, nil)
app := newConsentTestApp(h)
body, _ := json.Marshal(map[string]any{
"consent_challenge": "challenge-tenant-accept",
"grant_scope": []string{"openid", "profile"},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Mock-Role", "user")
req.Header.Set("X-Tenant-ID", "tenant-abc")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, []string{"openid", "tenant", "profile"}, capturedGrantScopes)
}

View File

@@ -0,0 +1,829 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func TestBuildOidcClaimsFromTraits_DynamicClaims(t *testing.T) {
traits := map[string]any{
"email": "user@baron.com",
"name": "홍길동",
"tenant_id": "primary-tenant-999", // Added primary tenant
"tenant-1": map[string]any{
"department": "개발팀",
"grade": "선임",
},
"tenant-2": map[string]any{
"department": "재무팀",
"grade": "팀장",
},
}
scopes := []string{"openid", "profile"}
t.Run("No tenantID", func(t *testing.T) {
claims := buildOidcClaimsFromTraits(traits, scopes, "")
assert.Equal(t, "user@baron.com", claims["email"])
assert.Equal(t, "홍길동", claims["name"])
assert.Equal(t, "primary-tenant-999", claims["tenant_id"])
assert.Nil(t, claims["department"])
assert.Nil(t, claims["grade"])
assert.Nil(t, claims["tenants"])
assert.Contains(t, claims["joined_tenants"], "tenant-1")
assert.Contains(t, claims["joined_tenants"], "tenant-2")
assert.Contains(t, claims["joined_tenants"], "primary-tenant-999") // Should contain primary
})
t.Run("With tenant-1", func(t *testing.T) {
claims := buildOidcClaimsFromTraits(traits, scopes, "tenant-1")
assert.Equal(t, "user@baron.com", claims["email"])
assert.Equal(t, "홍길동", claims["name"])
assert.Equal(t, "tenant-1", claims["tenant_id"])
assert.Nil(t, claims["department"])
assert.Nil(t, claims["grade"])
assert.Nil(t, claims["tenants"])
assert.Contains(t, claims["joined_tenants"], "tenant-1")
assert.Contains(t, claims["joined_tenants"], "tenant-2")
assert.Contains(t, claims["joined_tenants"], "primary-tenant-999")
})
t.Run("With tenant-2", func(t *testing.T) {
claims := buildOidcClaimsFromTraits(traits, scopes, "tenant-2")
assert.Equal(t, "user@baron.com", claims["email"])
assert.Equal(t, "홍길동", claims["name"])
assert.Equal(t, "tenant-2", claims["tenant_id"])
assert.Nil(t, claims["department"])
assert.Nil(t, claims["grade"])
assert.Nil(t, claims["tenants"])
assert.Contains(t, claims["joined_tenants"], "primary-tenant-999")
})
t.Run("With non-existent tenant", func(t *testing.T) {
claims := buildOidcClaimsFromTraits(traits, scopes, "tenant-3")
assert.Equal(t, "user@baron.com", claims["email"])
assert.Equal(t, "홍길동", claims["name"])
assert.Equal(t, "tenant-3", claims["tenant_id"])
assert.Nil(t, claims["department"])
assert.Nil(t, claims["grade"])
assert.Nil(t, claims["tenants"])
assert.Contains(t, claims["joined_tenants"], "tenant-1")
assert.Contains(t, claims["joined_tenants"], "primary-tenant-999")
})
t.Run("Tenant scope includes detailed tenant metadata", func(t *testing.T) {
claims := buildOidcClaimsFromTraits(traits, []string{"openid", "profile", "tenant"}, "tenant-1")
assert.Equal(t, "tenant-1", claims["tenant_id"])
assert.Equal(t, "개발팀", claims["department"])
assert.Equal(t, "선임", claims["grade"])
assert.NotNil(t, claims["tenants"])
assert.Contains(t, claims["joined_tenants"], "tenant-1")
assert.Contains(t, claims["joined_tenants"], "tenant-2")
assert.Contains(t, claims["joined_tenants"], "primary-tenant-999")
})
}
func TestRepresentativeTenantIDFromTraits(t *testing.T) {
t.Run("explicit tenant_id wins", func(t *testing.T) {
traits := map[string]any{
"tenant_id": "01970f0a-5c28-74d8-a73a-f6e9e9a7b210",
"additionalAppointments": []any{
map[string]any{"tenantId": "01970f0b-3448-7bb8-bdc7-16b6a1d2e661", "isPrimary": true},
},
}
assert.Equal(t, "01970f0a-5c28-74d8-a73a-f6e9e9a7b210", representativeTenantIDFromTraits(traits))
})
t.Run("primary appointment wins when tenant_id is absent", func(t *testing.T) {
traits := map[string]any{
"additionalAppointments": []any{
map[string]any{"tenantId": "01970f0b-3448-7bb8-bdc7-16b6a1d2e661"},
map[string]any{"tenantId": "01970f0c-8c44-7069-9f20-7d28c0b8e630", "representative": true},
},
}
assert.Equal(t, "01970f0c-8c44-7069-9f20-7d28c0b8e630", representativeTenantIDFromTraits(traits))
})
t.Run("first appointment is fallback", func(t *testing.T) {
traits := map[string]any{
"additionalAppointments": []any{
map[string]any{"tenantId": "01970f0b-3448-7bb8-bdc7-16b6a1d2e661"},
map[string]any{"tenantId": "01970f0c-8c44-7069-9f20-7d28c0b8e630"},
},
}
assert.Equal(t, "01970f0b-3448-7bb8-bdc7-16b6a1d2e661", representativeTenantIDFromTraits(traits))
})
}
func TestAcceptConsentRequest_DynamicClaims(t *testing.T) {
var capturedClaims map[string]any
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
// Hydra: Get Consent Request
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-dynamic" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-dynamic",
"requested_scope": []string{"openid", "profile", "tenant"},
"subject": "user-123",
"client": map[string]any{
"client_id": "client-app",
"metadata": map[string]any{
"tenant_id": "tenant-abc",
},
},
}), nil
}
// Kratos: Get Identity
if r.URL.Path == "/admin/identities/user-123" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@test.com",
"name": "Test User",
"tenant-abc": map[string]any{
"department": "Innovation",
"position": "Architect",
},
},
}), nil
}
// Hydra: Accept Consent Request
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-dynamic" {
// Capture the claims sent to Hydra
body, _ := io.ReadAll(r.Body)
var acceptReq map[string]any
json.Unmarshal(body, &acceptReq)
if session, ok := acceptReq["session"].(map[string]any); ok {
capturedClaims = session["id_token"].(map[string]any)
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"redirect_to": "http://rp/cb",
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() { http.DefaultClient = origDefault }()
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: new(MockKratosAdminService),
}
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
ID: "user-123",
Traits: map[string]any{
"email": "user@test.com",
"name": "Test User",
"tenant-abc": map[string]any{
"department": "Innovation",
"position": "Architect",
},
},
}, nil)
app := fiber.New()
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
reqBody, _ := json.Marshal(map[string]any{
"consent_challenge": "challenge-dynamic",
"grant_scope": []string{"openid", "profile", "tenant"},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(reqBody))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
// Verify captured claims
assert.NotNil(t, capturedClaims)
assert.Equal(t, "user@test.com", capturedClaims["email"])
assert.Equal(t, "tenant-abc", capturedClaims["tenant_id"])
assert.Equal(t, "Innovation", capturedClaims["department"])
assert.Equal(t, "Architect", capturedClaims["position"])
}
func TestAcceptConsentRequest_UsesRepresentativeTenantIDInsteadOfClientTenantContext(t *testing.T) {
var capturedClaims map[string]any
representativeTenantID := "01970f0a-5c28-74d8-a73a-f6e9e9a7b210"
rpContextTenantID := "01970f0b-3448-7bb8-bdc7-16b6a1d2e661"
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-representative-tenant" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-representative-tenant",
"requested_scope": []string{"openid", "profile", "tenant"},
"subject": "user-representative",
"client": map[string]any{
"client_id": "client-app",
"metadata": map[string]any{
"tenant_id": rpContextTenantID,
},
},
}), nil
}
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-representative-tenant" {
body, _ := io.ReadAll(r.Body)
var acceptReq map[string]any
json.Unmarshal(body, &acceptReq)
if session, ok := acceptReq["session"].(map[string]any); ok {
capturedClaims = session["id_token"].(map[string]any)
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"redirect_to": "http://rp/cb",
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: new(MockKratosAdminService),
}
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-representative").Return(&service.KratosIdentity{
ID: "user-representative",
Traits: map[string]any{
"email": "user@test.com",
"name": "Test User",
"additionalAppointments": []any{
map[string]any{"tenantId": representativeTenantID, "isPrimary": true},
map[string]any{"tenantId": rpContextTenantID},
},
},
}, nil)
app := fiber.New()
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
reqBody, _ := json.Marshal(map[string]any{
"consent_challenge": "challenge-representative-tenant",
"grant_scope": []string{"openid", "profile"},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(reqBody))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.NotNil(t, capturedClaims)
assert.Equal(t, representativeTenantID, capturedClaims["tenant_id"])
assert.Contains(t, capturedClaims["joined_tenants"], representativeTenantID)
assert.Contains(t, capturedClaims["joined_tenants"], rpContextTenantID)
assert.Nil(t, capturedClaims["tenants"])
}
func TestAcceptConsentRequest_IncludesHanmacFamilyTenantClaimDetails(t *testing.T) {
var capturedClaims map[string]any
deptID := "01970f0a-5c28-74d8-a73a-f6e9e9a7b210"
secondDeptID := "01970f0b-3448-7bb8-bdc7-16b6a1d2e661"
companyID := "01970f08-91da-7286-bd19-882fb98d1f2c"
rootID := "01970f07-4f01-7d9a-a71e-b53ad508f345"
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-hanmac-tenant-claim" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-hanmac-tenant-claim",
"requested_scope": []string{"openid", "profile", "tenant"},
"subject": "user-hanmac",
"client": map[string]any{
"client_id": "hanmac-rp",
"metadata": map[string]any{
"tenant_id": deptID,
},
},
}), nil
}
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-hanmac-tenant-claim" {
body, _ := io.ReadAll(r.Body)
var acceptReq map[string]any
json.Unmarshal(body, &acceptReq)
if session, ok := acceptReq["session"].(map[string]any); ok {
capturedClaims = session["id_token"].(map[string]any)
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"redirect_to": "http://rp/cb",
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: new(MockKratosAdminService),
}
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-hanmac").Return(&service.KratosIdentity{
ID: "user-hanmac",
Traits: map[string]any{
"email": "hanmac-user@example.com",
"name": "한맥 사용자",
"additionalAppointments": []any{
map[string]any{
"tenantId": deptID,
"isPrimary": true,
"isOwner": true,
"grade": "책임",
"jobTitle": "기술기획",
"position": "팀장",
},
map[string]any{
"tenantId": secondDeptID,
"isPrimary": false,
"isOwner": false,
"grade": "선임",
"jobTitle": "품질관리",
"position": "파트원",
},
},
},
}, nil)
mockTenantSvc := new(MockTenantService)
mockTenantSvc.On("ListJoinedTenants", mock.Anything, "user-hanmac").Return([]domain.Tenant{}, nil)
mockTenantSvc.On("GetTenant", mock.Anything, deptID).Return(&domain.Tenant{
ID: deptID,
Slug: "tech-planning",
Name: "기술기획팀",
Type: domain.TenantTypeUserGroup,
ParentID: &companyID,
}, nil)
mockTenantSvc.On("GetTenant", mock.Anything, secondDeptID).Return(&domain.Tenant{
ID: secondDeptID,
Slug: "quality",
Name: "품질관리팀",
Type: domain.TenantTypeUserGroup,
ParentID: &companyID,
}, nil)
mockTenantSvc.On("GetTenant", mock.Anything, companyID).Return(&domain.Tenant{
ID: companyID,
Slug: "hanmac",
Name: "한맥기술",
Type: domain.TenantTypeCompany,
ParentID: &rootID,
}, nil)
mockTenantSvc.On("GetTenant", mock.Anything, rootID).Return(&domain.Tenant{
ID: rootID,
Slug: "hanmac-family",
Name: "한맥가족",
Type: domain.TenantTypeCompanyGroup,
}, nil)
h.TenantService = mockTenantSvc
app := fiber.New()
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
reqBody, _ := json.Marshal(map[string]any{
"consent_challenge": "challenge-hanmac-tenant-claim",
"grant_scope": []string{"openid", "profile", "tenant"},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(reqBody))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.NotNil(t, capturedClaims)
assert.Equal(t, []any{deptID}, capturedClaims["lead_tenants"])
assert.ElementsMatch(t, []any{deptID, secondDeptID}, capturedClaims["joined_tenants"])
tenants := capturedClaims["tenants"].(map[string]any)
dept := tenants[deptID].(map[string]any)
assert.Equal(t, true, dept["lead"])
assert.Equal(t, true, dept["representative"])
assert.Equal(t, "책임", dept["grade"])
assert.Equal(t, "기술기획", dept["jobTitle"])
assert.Equal(t, "팀장", dept["position"])
assert.Equal(t, companyID, dept["parentTenantId"])
assert.NotContains(t, dept, "parentTenant")
ancestors := dept["ancestors"].([]any)
assert.Len(t, ancestors, 2)
companyAncestor := ancestors[0].(map[string]any)
assert.Equal(t, companyID, companyAncestor["id"])
assert.Equal(t, "hanmac", companyAncestor["slug"])
assert.Equal(t, rootID, companyAncestor["parentTenantId"])
assert.NotContains(t, companyAncestor, "parentTenant")
rootAncestor := ancestors[1].(map[string]any)
assert.Equal(t, rootID, rootAncestor["id"])
assert.Equal(t, "hanmac-family", rootAncestor["slug"])
assert.Contains(t, rootAncestor, "parentTenantId")
assert.Nil(t, rootAncestor["parentTenantId"])
assert.NotContains(t, rootAncestor, "parentTenant")
secondDept := tenants[secondDeptID].(map[string]any)
assert.Equal(t, false, secondDept["lead"])
assert.Equal(t, false, secondDept["representative"])
assert.Equal(t, "선임", secondDept["grade"])
assert.Equal(t, "품질관리", secondDept["jobTitle"])
assert.Equal(t, "파트원", secondDept["position"])
assert.Equal(t, companyID, secondDept["parentTenantId"])
}
func TestWithHanmacFamilyTenantClaims_DefaultClaimsOnlyWithoutTenantScope(t *testing.T) {
deptID := "01970f0a-5c28-74d8-a73a-f6e9e9a7b210"
secondDeptID := "01970f0b-3448-7bb8-bdc7-16b6a1d2e661"
companyID := "01970f08-91da-7286-bd19-882fb98d1f2c"
rootID := "01970f07-4f01-7d9a-a71e-b53ad508f345"
mockTenantSvc := new(MockTenantService)
mockTenantSvc.On("GetTenant", mock.Anything, deptID).Return(&domain.Tenant{
ID: deptID,
Slug: "tech-planning",
Name: "기술기획팀",
Type: domain.TenantTypeUserGroup,
ParentID: &companyID,
}, nil)
mockTenantSvc.On("GetTenant", mock.Anything, secondDeptID).Return(&domain.Tenant{
ID: secondDeptID,
Slug: "quality",
Name: "품질관리팀",
Type: domain.TenantTypeUserGroup,
ParentID: &companyID,
}, nil)
mockTenantSvc.On("GetTenant", mock.Anything, companyID).Return(&domain.Tenant{
ID: companyID,
Slug: "hanmac",
Name: "한맥기술",
Type: domain.TenantTypeCompany,
ParentID: &rootID,
}, nil)
mockTenantSvc.On("GetTenant", mock.Anything, rootID).Return(&domain.Tenant{
ID: rootID,
Slug: "hanmac-family",
Name: "한맥가족",
Type: domain.TenantTypeCompanyGroup,
}, nil)
h := &AuthHandler{TenantService: mockTenantSvc}
claims := map[string]any{"tenant_id": deptID}
traits := map[string]any{
"additionalAppointments": []any{
map[string]any{
"tenantId": deptID,
"isPrimary": true,
"isOwner": true,
"grade": "책임",
},
map[string]any{
"tenantId": secondDeptID,
"grade": "선임",
},
},
}
claims = h.withHanmacFamilyTenantClaims(context.Background(), claims, traits, []string{"openid", "profile"})
assert.Equal(t, deptID, claims["tenant_id"])
assert.ElementsMatch(t, []string{deptID, secondDeptID}, claims["joined_tenants"])
assert.NotContains(t, claims, "tenants")
assert.NotContains(t, claims, "lead_tenants")
}
func TestAcceptConsentRequest_IncludesRPProfileClaims(t *testing.T) {
var capturedClaims map[string]any
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-rp-profile" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-rp-profile",
"requested_scope": []string{"openid", "profile", "tenant"},
"subject": "user-123",
"client": map[string]any{
"client_id": "client-app",
"metadata": map[string]any{
"customUserSchema": []map[string]any{
{
"key": "approvalLevel",
"label": "승인 등급",
"type": "text",
"claimEnabled": true,
},
{
"key": "internalMemo",
"label": "내부 메모",
"type": "text",
"claimEnabled": false,
},
},
},
},
}), nil
}
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-rp-profile" {
body, _ := io.ReadAll(r.Body)
var acceptReq map[string]any
json.Unmarshal(body, &acceptReq)
if session, ok := acceptReq["session"].(map[string]any); ok {
capturedClaims = session["id_token"].(map[string]any)
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"redirect_to": "http://rp/cb",
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: new(MockKratosAdminService),
}
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
ID: "user-123",
Traits: map[string]any{
"email": "user@test.com",
"name": "Test User",
},
}, nil)
repo := new(devMockRPUserMetadataRepo)
repo.On("Get", mock.Anything, "client-app", "user-123").Return(&domain.RPUserMetadata{
ClientID: "client-app",
UserID: "user-123",
Metadata: domain.JSONMap{
"approvalLevel": "A",
"internalMemo": "관리자 전용",
},
}, nil).Once()
h.RPUserMetadataRepo = repo
app := fiber.New()
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
reqBody, _ := json.Marshal(map[string]any{
"consent_challenge": "challenge-rp-profile",
"grant_scope": []string{"openid", "profile"},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(reqBody))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.NotNil(t, capturedClaims)
rpProfiles, ok := capturedClaims["rp_profiles"].([]any)
assert.True(t, ok)
assert.Len(t, rpProfiles, 1)
profile := rpProfiles[0].(map[string]any)
assert.Equal(t, "client-app", profile["client_id"])
fields := profile["fields"].(map[string]any)
assert.Equal(t, "A", fields["approvalLevel"])
assert.NotContains(t, fields, "internalMemo")
repo.AssertExpectations(t)
}
func TestGetConsentRequest_Skip_DynamicClaims(t *testing.T) {
var capturedClaims map[string]any
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
// Hydra: Get Consent Request
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-skip-dynamic" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-skip-dynamic",
"requested_scope": []string{"openid", "profile", "tenant"},
"skip": true,
"subject": "user-456",
"client": map[string]any{
"client_id": "skip-app",
"metadata": map[string]any{
"tenant_id": "tenant-xyz",
},
},
}), nil
}
// Kratos: Get Identity
if r.URL.Path == "/admin/identities/user-456" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "user-456",
"traits": map[string]any{
"email": "skip@test.com",
"tenant-xyz": map[string]any{
"department": "Security",
"position": "Officer",
},
},
}), nil
}
// Hydra: Accept Consent Request
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-skip-dynamic" {
// Capture the claims sent to Hydra
body, _ := io.ReadAll(r.Body)
var acceptReq map[string]any
json.Unmarshal(body, &acceptReq)
if session, ok := acceptReq["session"].(map[string]any); ok {
capturedClaims = session["id_token"].(map[string]any)
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"redirect_to": "http://rp/cb",
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() { http.DefaultClient = origDefault }()
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: new(MockKratosAdminService),
}
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-456").Return(&service.KratosIdentity{
ID: "user-456",
Traits: map[string]any{
"email": "skip@test.com",
"tenant-xyz": map[string]any{
"department": "Security",
"position": "Officer",
},
},
}, nil)
app := fiber.New()
app.Get("/api/v1/auth/consent", h.GetConsentRequest)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-skip-dynamic", nil)
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
// Verify captured claims
assert.NotNil(t, capturedClaims)
assert.Equal(t, "skip@test.com", capturedClaims["email"])
assert.Equal(t, "tenant-xyz", capturedClaims["tenant_id"])
assert.Equal(t, "Security", capturedClaims["department"])
assert.Equal(t, "Officer", capturedClaims["position"])
}
func TestBuildOidcClaimsFromTraits_IncludesGlobalCustomClaims(t *testing.T) {
claims := buildOidcClaimsFromTraits(map[string]any{
"email": "user@test.com",
"name": "Test User",
"global_custom_claims": map[string]any{
"contract_date": "2026-06-09",
"approved_at": "2026-06-09T09:30:00+09:00",
"email": "override@test.com",
"rp_claims": "reserved",
},
"global_custom_claim_permissions": map[string]any{
"contract_date": map[string]any{
"readPermission": "user_and_admin",
"writePermission": "admin_only",
},
},
}, []string{"openid", "profile", "email"}, "")
assert.Equal(t, "2026-06-09", claims["contract_date"])
assert.Equal(t, "2026-06-09T09:30:00+09:00", claims["approved_at"])
assert.Equal(t, "user@test.com", claims["email"])
assert.NotEqual(t, "reserved", claims["rp_claims"])
assert.NotContains(t, claims, "global_custom_claim_permissions")
}
func TestAcceptConsentRequest_AppliesConfiguredIDTokenClaims(t *testing.T) {
var capturedClaims map[string]any
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-configured-claims" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-configured-claims",
"requested_scope": []string{"openid", "profile"},
"subject": "user-789",
"client": map[string]any{
"client_id": "client-configured-claims",
"metadata": map[string]any{
"tenant_id": "tenant-claims",
"id_token_claims": []map[string]any{
{
"namespace": "top_level",
"key": "locale",
"value": "ko-KR",
"valueType": "text",
},
{
"namespace": "top_level",
"key": "email",
"value": "should-not-override@example.com",
"valueType": "text",
},
{
"namespace": "rp_claims",
"key": "tier",
"value": "2",
"valueType": "number",
},
{
"namespace": "rp_claims",
"key": "features",
"value": "[\"sso\",\"claims\"]",
"valueType": "array",
},
},
},
},
}), nil
}
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-configured-claims" {
body, _ := io.ReadAll(r.Body)
var acceptReq map[string]any
json.Unmarshal(body, &acceptReq)
if session, ok := acceptReq["session"].(map[string]any); ok {
capturedClaims = session["id_token"].(map[string]any)
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"redirect_to": "http://rp/cb",
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() { http.DefaultClient = origDefault }()
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: new(MockKratosAdminService),
}
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-789").Return(&service.KratosIdentity{
ID: "user-789",
Traits: map[string]any{
"email": "real-user@example.com",
"name": "Configured User",
"tenant-claims": map[string]any{
"department": "Platform",
},
},
}, nil)
app := fiber.New()
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
reqBody, _ := json.Marshal(map[string]any{
"consent_challenge": "challenge-configured-claims",
"grant_scope": []string{"openid", "profile"},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(reqBody))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.NotNil(t, capturedClaims)
assert.Equal(t, "real-user@example.com", capturedClaims["email"])
assert.Equal(t, "ko-KR", capturedClaims["locale"])
assert.Equal(t, "tenant-claims", capturedClaims["tenant_id"])
rpClaims, ok := capturedClaims["rp_claims"].(map[string]any)
if assert.True(t, ok) {
assert.Equal(t, float64(2), rpClaims["tier"])
assert.Equal(t, []any{"sso", "claims"}, rpClaims["features"])
}
}

View File

@@ -0,0 +1,904 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"baron-sso-backend/internal/testsupport"
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
// Mock services
type mockEmailService struct {
lastTo string
lastSubject string
lastBody string
}
func (m *mockEmailService) SendEmail(to, subject, body string) error {
m.lastTo = to
m.lastSubject = subject
m.lastBody = body
return nil
}
type mockSmsService struct {
lastTo string
lastContent string
}
func (m *mockSmsService) SendSms(to, content string) error {
m.lastTo = to
m.lastContent = content
return nil
}
func newHeadlessLinkTestApp(h *AuthHandler) *fiber.App {
app := fiber.New()
app.Post("/api/v1/auth/headless/link/init", h.HeadlessLinkInit)
app.Post("/api/v1/auth/headless/link/poll", h.HeadlessLinkPoll)
app.Post("/api/v1/auth/magic-link/verify", h.VerifyMagicLink)
return app
}
func newKratosWhoamiTestServer(t *testing.T, identityID string) *httptest.Server {
t.Helper()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/sessions/whoami" {
http.NotFound(w, r)
return
}
if r.Header.Get("Cookie") == "" && r.Header.Get("X-Session-Token") == "" {
http.Error(w, "missing session", http.StatusUnauthorized)
return
}
_ = json.NewEncoder(w).Encode(map[string]any{
"id": "session-123",
"authenticated_at": "2026-05-21T00:00:00Z",
"identity": map[string]any{
"id": identityID,
"traits": map[string]any{
"email": "user@example.com",
},
},
})
}))
origDefaultClient := http.DefaultClient
http.DefaultClient = server.Client()
t.Cleanup(func() {
http.DefaultClient = origDefaultClient
})
t.Cleanup(server.Close)
return server
}
func TestEnchantedLinkFlow_Email_Success(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
// Force "Not Supported" for InitiateLinkLogin only to trigger custom Enchanted Link logic
idp := &mockIdpProvider{
userExists: true,
initiateLinkErr: domain.ErrNotSupported,
}
h := &AuthHandler{
RedisService: redis,
IdpProvider: idp,
EmailService: &mockEmailService{},
SmsService: &mockSmsService{},
}
app := fiber.New()
app.Post("/api/v1/auth/enchanted-link/init", h.InitEnchantedLink)
app.Post("/api/v1/auth/enchanted-link/poll", h.PollEnchantedLink)
app.Post("/api/v1/auth/magic-link/verify", h.VerifyMagicLink)
t.Setenv("USERFRONT_URL", "http://userfront.test")
// 1. Init Enchanted Link (Email)
body, _ := json.Marshal(map[string]string{
"loginId": "user@example.com",
"method": "email",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/enchanted-link/init", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var initResp map[string]any
json.NewDecoder(resp.Body).Decode(&initResp)
pendingRef := initResp["pendingRef"].(string)
assert.NotEmpty(t, pendingRef)
// Find the token key "enchanted_token:..." in mock redis
var token string
for k := range redis.data {
if len(k) > 16 && k[:16] == "enchanted_token:" {
token = k[16:]
break
}
}
assert.NotEmpty(t, token)
// 2. Verify Magic Link
verifyBody, _ := json.Marshal(map[string]any{
"token": token,
"verifyOnly": true,
})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(verifyBody))
req.Header.Set("Content-Type", "application/json")
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
// 3. Poll (Success)
pollBody, _ := json.Marshal(map[string]string{"pendingRef": pendingRef})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/enchanted-link/poll", bytes.NewReader(pollBody))
req.Header.Set("Content-Type", "application/json")
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var pollResp map[string]any
json.NewDecoder(resp.Body).Decode(&pollResp)
assert.Equal(t, "ok", pollResp["status"])
assert.Equal(t, "valid-jwt", pollResp["sessionJwt"])
}
func TestEnchantedLinkFlow_Sms_Success(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
idp := &mockIdpProvider{
userExists: true,
initiateLinkErr: domain.ErrNotSupported,
}
h := &AuthHandler{
RedisService: redis,
IdpProvider: idp,
SmsService: &mockSmsService{},
}
app := fiber.New()
app.Post("/api/v1/auth/enchanted-link/init", h.InitEnchantedLink)
// 1. Init Enchanted Link (SMS)
body, _ := json.Marshal(map[string]string{
"loginId": "010-1234-5678",
"method": "sms",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/enchanted-link/init", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var initResp map[string]any
json.NewDecoder(resp.Body).Decode(&initResp)
assert.NotEmpty(t, initResp["userCode"])
}
func TestVerifyMagicLink_VerifyOnlyWithoutSharedBrowserSessionApprovesOnly(t *testing.T) {
redis := &mockRedisRepo{data: map[string]string{
prefixToken + "token-123": `{"pendingRef":"pending-123","loginId":"user@example.com"}`,
}}
h := &AuthHandler{
RedisService: redis,
IdpProvider: &mockIdpProvider{},
}
app := fiber.New()
app.Post("/api/v1/auth/magic-link/verify", h.VerifyMagicLink)
body, _ := json.Marshal(map[string]any{
"token": "token-123",
"verifyOnly": true,
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Empty(t, resp.Cookies())
var got map[string]any
_ = json.NewDecoder(resp.Body).Decode(&got)
assert.Equal(t, "approved", got["status"])
assert.Nil(t, got["sessionJwt"])
assert.Nil(t, got["token"])
}
func TestVerifyMagicLink_VerifyOnlySharedBrowserSameSubjectApprovesOnly(t *testing.T) {
redis := &mockRedisRepo{data: map[string]string{
prefixToken + "token-123": `{"pendingRef":"pending-123","loginId":"user@example.com"}`,
}}
kratosPublic := newKratosWhoamiTestServer(t, "kratos-user-1")
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
h := &AuthHandler{
RedisService: redis,
IdpProvider: &mockIdpProvider{},
}
app := fiber.New()
app.Post("/api/v1/auth/magic-link/verify", h.VerifyMagicLink)
body, _ := json.Marshal(map[string]any{
"token": "token-123",
"verifyOnly": true,
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Cookie", "ory_kratos_session=shared-browser-session")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Empty(t, resp.Cookies())
var got map[string]any
_ = json.NewDecoder(resp.Body).Decode(&got)
assert.Equal(t, "approved", got["status"])
assert.Nil(t, got["sessionJwt"])
assert.Nil(t, got["token"])
}
func TestVerifyMagicLink_VerifyOnlySharedBrowserDifferentSubjectApprovesOnly(t *testing.T) {
redis := &mockRedisRepo{data: map[string]string{
prefixToken + "token-123": `{"pendingRef":"pending-123","loginId":"user@example.com"}`,
}}
kratosPublic := newKratosWhoamiTestServer(t, "kratos-other-user")
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
h := &AuthHandler{
RedisService: redis,
IdpProvider: &mockIdpProvider{},
}
app := fiber.New()
app.Post("/api/v1/auth/magic-link/verify", h.VerifyMagicLink)
body, _ := json.Marshal(map[string]any{
"token": "token-123",
"verifyOnly": true,
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Cookie", "ory_kratos_session=shared-browser-session")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Empty(t, resp.Cookies())
var got map[string]any
_ = json.NewDecoder(resp.Body).Decode(&got)
assert.Equal(t, "approved", got["status"])
assert.Nil(t, got["sessionJwt"])
assert.Nil(t, got["token"])
assert.Contains(t, redis.data[prefixSession+"pending-123"], "approved")
}
func TestResolveUserfrontURL_DevLocalhostUsesConfiguredPort(t *testing.T) {
t.Setenv("APP_ENV", "dev")
t.Setenv("USERFRONT_URL", "http://localhost:5000")
h := &AuthHandler{}
app := fiber.New()
app.Get("/probe", func(c *fiber.Ctx) error {
return c.SendString(h.resolveUserfrontURL(c))
})
req := httptest.NewRequest(http.MethodGet, "http://localhost/probe", nil)
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
body, _ := io.ReadAll(resp.Body)
assert.Equal(t, "http://localhost:5000", string(body))
}
func TestVerifyLoginCode_VerifyOnlySharedBrowserDifferentSubjectApprovesOnly(t *testing.T) {
redis := &mockRedisRepo{data: map[string]string{
prefixLoginCode + "user@example.com": "flow-123",
prefixLoginCodePending + "user@example.com": "pending-123",
prefixLoginCodeValue + "pending-123": "569765",
}}
kratosPublic := newKratosWhoamiTestServer(t, "kratos-other-user")
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
h := &AuthHandler{
RedisService: redis,
IdpProvider: &mockIdpProvider{},
}
app := fiber.New()
app.Post("/api/v1/auth/login/code/verify", h.VerifyLoginCode)
body, _ := json.Marshal(map[string]any{
"loginId": "user@example.com",
"code": "569765",
"pendingRef": "pending-123",
"verifyOnly": true,
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Cookie", "ory_kratos_session=shared-browser-session")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Empty(t, resp.Cookies())
var got map[string]any
_ = json.NewDecoder(resp.Body).Decode(&got)
assert.Equal(t, "approved", got["status"])
assert.Nil(t, got["sessionJwt"])
assert.Nil(t, got["token"])
assert.Contains(t, redis.data[prefixSession+"pending-123"], "approved")
}
func TestVerifyLoginCode_MapsSmsPhoneBeforeFlowLookup(t *testing.T) {
redis := &mockRedisRepo{data: map[string]string{
prefixLoginCode + "su-@samaneng.com": "flow-123",
prefixLoginCodePending + "su-@samaneng.com": "pending-123",
prefixLoginCodeSmsLookup + "+821041585840": "su-@samaneng.com",
prefixLoginCodeSmsTarget + "su-@samaneng.com": "+821041585840",
prefixLoginCodeValue + "pending-123": "569765",
}}
h := &AuthHandler{
RedisService: redis,
IdpProvider: &mockIdpProvider{},
}
app := fiber.New()
app.Post("/api/v1/auth/login/code/verify", h.VerifyLoginCode)
body, _ := json.Marshal(map[string]any{
"loginId": "01041585840",
"code": "569765",
"pendingRef": "pending-123",
"verifyOnly": true,
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var got map[string]any
_ = json.NewDecoder(resp.Body).Decode(&got)
assert.Equal(t, "approved", got["status"])
assert.Equal(t, "pending-123", got["pendingRef"])
}
func TestPollEnchantedLink_ExpiredToken_ReturnsCode(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
h := &AuthHandler{
RedisService: redis,
}
app := fiber.New()
app.Post("/api/v1/auth/enchanted-link/poll", h.PollEnchantedLink)
body, _ := json.Marshal(map[string]string{
"pendingRef": "missing-ref",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/enchanted-link/poll", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var got map[string]any
json.NewDecoder(resp.Body).Decode(&got)
assert.Equal(t, "expired_token", got["error"])
assert.Equal(t, "expired_token", got["code"])
}
func TestPollEnchantedLink_SharedBrowserSameSubjectIssuesSession(t *testing.T) {
redis := &mockRedisRepo{data: map[string]string{
prefixSession + "pending-123": `{"status":"approved","loginId":"user@example.com"}`,
}}
kratosPublic := newKratosWhoamiTestServer(t, "kratos-user-1")
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
h := &AuthHandler{
RedisService: redis,
IdpProvider: &mockIdpProvider{
issueSession: &domain.AuthInfo{
SessionToken: &domain.Token{JWT: "valid-jwt", SessionID: "new-session-id"},
Subject: "kratos-user-1",
},
},
}
app := fiber.New()
app.Post("/api/v1/auth/enchanted-link/poll", h.PollEnchantedLink)
body, _ := json.Marshal(map[string]string{"pendingRef": "pending-123"})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/enchanted-link/poll", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Cookie", "ory_kratos_session=shared-browser-session")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var got map[string]any
_ = json.NewDecoder(resp.Body).Decode(&got)
assert.Equal(t, "ok", got["status"])
assert.Equal(t, "valid-jwt", got["sessionJwt"])
}
func TestPollEnchantedLink_SharedBrowserDifferentSubjectConflicts(t *testing.T) {
redis := &mockRedisRepo{data: map[string]string{
prefixSession + "pending-123": `{"status":"approved","loginId":"user@example.com"}`,
}}
kratosPublic := newKratosWhoamiTestServer(t, "kratos-other-user")
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
h := &AuthHandler{
RedisService: redis,
IdpProvider: &mockIdpProvider{
issueSession: &domain.AuthInfo{
SessionToken: &domain.Token{JWT: "valid-jwt", SessionID: "new-session-id"},
Subject: "kratos-user-1",
},
},
}
app := fiber.New()
app.Post("/api/v1/auth/enchanted-link/poll", h.PollEnchantedLink)
body, _ := json.Marshal(map[string]string{"pendingRef": "pending-123"})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/enchanted-link/poll", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Cookie", "ory_kratos_session=shared-browser-session")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusConflict, resp.StatusCode)
var got map[string]any
_ = json.NewDecoder(resp.Body).Decode(&got)
assert.Equal(t, "session_subject_conflict", got["code"])
assert.NotContains(t, redis.data[prefixSession+"pending-123"], "valid-jwt")
}
func TestHeadlessLinkInit_HeadlessLoginClientSuccess(t *testing.T) {
t.Setenv("BACKEND_PUBLIC_URL", "")
if !testsupport.PortBindingAvailable() {
t.Skip("skipping headless link tests because this environment cannot bind local TCP listeners")
}
redis := &mockRedisRepo{data: make(map[string]string)}
privateKey, jwks := mustHeadlessRSAJWK(t)
jwksBody, _ := json.Marshal(jwks)
jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(jwksBody)
}))
defer jwksServer.Close()
idp := &mockIdpProvider{
userExists: true,
initiateLinkErr: domain.ErrNotSupported,
}
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet {
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "headless-login-client",
TokenEndpointAuthMethod: "none",
Metadata: map[string]any{
"status": "active",
"headless_login_enabled": true,
"headless_token_endpoint_auth_method": "private_key_jwt",
"headless_jwks_uri": jwksServer.URL + "/.well-known/jwks.json",
},
},
})
return
}
http.NotFound(w, r)
})
h := &AuthHandler{
RedisService: redis,
IdpProvider: idp,
SmsService: &mockSmsService{},
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
},
}
app := newHeadlessLinkTestApp(h)
t.Setenv("USERFRONT_URL", "http://userfront.test")
body, _ := json.Marshal(map[string]string{
"client_id": "headless-login-client",
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/init"),
"loginId": "010-1234-5678",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var got map[string]any
_ = json.NewDecoder(resp.Body).Decode(&got)
assert.NotEmpty(t, got["pendingRef"])
_, hasUserCode := got["userCode"]
assert.False(t, hasUserCode)
}
func TestHeadlessLinkPoll_AfterApprovalReturnsRedirect(t *testing.T) {
t.Setenv("BACKEND_PUBLIC_URL", "")
if !testsupport.PortBindingAvailable() {
t.Skip("skipping headless link tests because this environment cannot bind local TCP listeners")
}
redis := &mockRedisRepo{data: make(map[string]string)}
privateKey, jwks := mustHeadlessRSAJWK(t)
jwksBody, _ := json.Marshal(jwks)
idp := &mockIdpProvider{
userExists: true,
initiateLinkErr: domain.ErrNotSupported,
}
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "headless-login-client",
ClientName: "local-demo-rp",
TokenEndpointAuthMethod: "none",
Metadata: map[string]any{
"status": "active",
"headless_login_enabled": true,
"headless_token_endpoint_auth_method": "private_key_jwt",
"headless_jwks_uri": "https://rp.example.com/.well-known/jwks.json",
},
},
})
return
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
return
}
http.NotFound(w, r)
})
mockKratos := new(MockKratosAdminService)
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "+821012345678").Return("kratos-identity-id", nil)
auditRepo := &mockAuditRepo{}
headlessClient := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Host == "rp.example.com" && r.URL.Path == "/.well-known/jwks.json" {
return httpResponse(r, http.StatusOK, string(jwksBody)), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})}
h := &AuthHandler{
RedisService: redis,
IdpProvider: idp,
SmsService: &mockSmsService{},
KratosAdmin: mockKratos,
AuditRepo: auditRepo,
HeadlessJWKS: service.NewHeadlessJWKSCacheService(nil, headlessClient),
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
},
}
app := newHeadlessLinkTestApp(h)
t.Setenv("USERFRONT_URL", "http://userfront.test")
initBody, _ := json.Marshal(map[string]string{
"client_id": "headless-login-client",
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/init"),
"loginId": "010-1234-5678",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(initBody))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var initResp map[string]any
_ = json.NewDecoder(resp.Body).Decode(&initResp)
pendingRef := initResp["pendingRef"].(string)
assert.NotEmpty(t, pendingRef)
var token string
for k := range redis.data {
if len(k) > 16 && k[:16] == "enchanted_token:" {
token = k[16:]
break
}
}
assert.NotEmpty(t, token)
verifyBody, _ := json.Marshal(map[string]any{
"token": token,
"verifyOnly": true,
})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(verifyBody))
req.Header.Set("Content-Type", "application/json")
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
pollBody, _ := json.Marshal(map[string]string{
"client_id": "headless-login-client",
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/poll"),
"pendingRef": pendingRef,
})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/poll", bytes.NewReader(pollBody))
req.Header.Set("Content-Type", "application/json")
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var pollResp map[string]any
_ = json.NewDecoder(resp.Body).Decode(&pollResp)
assert.Equal(t, "http://rp/cb", pollResp["redirectTo"])
assert.Equal(t, "ok", pollResp["status"])
assert.Nil(t, pollResp["sessionJwt"])
assert.Nil(t, pollResp["token"])
assert.Empty(t, resp.Cookies())
if assert.Len(t, auditRepo.logs, 1) {
assert.Contains(t, auditRepo.logs[0].EventType, "/api/v1/auth/")
details, err := parseAuditDetails(auditRepo.logs[0].Details)
if err != nil {
t.Fatalf("failed to parse audit details: %v", err)
}
assert.Equal(t, "headless-login-client", details["client_id"])
assert.Equal(t, "local-demo-rp", details["client_name"])
assert.Equal(t, "challenge-123", details["login_challenge"])
}
}
func TestHeadlessLinkPoll_ApproverSubjectConflictBlocksMixedRP(t *testing.T) {
t.Setenv("BACKEND_PUBLIC_URL", "")
if !testsupport.PortBindingAvailable() {
t.Skip("skipping headless link tests because this environment cannot bind local TCP listeners")
}
redis := &mockRedisRepo{data: make(map[string]string)}
privateKey, jwks := mustHeadlessRSAJWK(t)
jwksBody, _ := json.Marshal(jwks)
acceptCalled := false
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "headless-login-client",
ClientName: "local-demo-rp",
TokenEndpointAuthMethod: "none",
Metadata: map[string]any{
"status": "active",
"headless_login_enabled": true,
"headless_token_endpoint_auth_method": "private_key_jwt",
"headless_jwks_uri": "https://rp.example.com/.well-known/jwks.json",
},
},
})
return
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
acceptCalled = true
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
return
}
http.NotFound(w, r)
})
mockKratos := new(MockKratosAdminService)
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "+821012345678").Return("kratos-target-b", nil)
auditRepo := &mockAuditRepo{}
headlessClient := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Host == "rp.example.com" && r.URL.Path == "/.well-known/jwks.json" {
return httpResponse(r, http.StatusOK, string(jwksBody)), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})}
h := &AuthHandler{
RedisService: redis,
IdpProvider: &mockIdpProvider{
userExists: true,
initiateLinkErr: domain.ErrNotSupported,
},
SmsService: &mockSmsService{},
KratosAdmin: mockKratos,
AuditRepo: auditRepo,
HeadlessJWKS: service.NewHeadlessJWKSCacheService(nil, headlessClient),
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
},
}
app := newHeadlessLinkTestApp(h)
t.Setenv("USERFRONT_URL", "http://userfront.test")
initBody, _ := json.Marshal(map[string]string{
"client_id": "headless-login-client",
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/init"),
"loginId": "010-1234-5678",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(initBody))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var initResp map[string]any
_ = json.NewDecoder(resp.Body).Decode(&initResp)
pendingRef := initResp["pendingRef"].(string)
assert.NotEmpty(t, pendingRef)
var token string
for k := range redis.data {
if len(k) > 16 && k[:16] == "enchanted_token:" {
token = k[16:]
break
}
}
assert.NotEmpty(t, token)
kratosPublic := newKratosWhoamiTestServer(t, "kratos-userfront-a")
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
verifyBody, _ := json.Marshal(map[string]any{
"token": token,
"verifyOnly": true,
})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(verifyBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Cookie", "ory_kratos_session=userfront-a-session")
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
pollBody, _ := json.Marshal(map[string]string{
"client_id": "headless-login-client",
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/poll"),
"pendingRef": pendingRef,
})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/poll", bytes.NewReader(pollBody))
req.Header.Set("Content-Type", "application/json")
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusConflict, resp.StatusCode)
assert.False(t, acceptCalled)
assert.Empty(t, resp.Cookies())
var got map[string]any
_ = json.NewDecoder(resp.Body).Decode(&got)
assert.Equal(t, "oidc_subject_conflict", got["code"])
assert.Equal(t, "redirect_to_userfront_login", got["recommendedAction"])
assert.Equal(t, "kratos-userfront-a", got["currentSubject"])
assert.Equal(t, "kratos-target-b", got["targetSubject"])
assert.Empty(t, auditRepo.logs)
}
func TestHeadlessLinkPoll_RequestCookieSubjectConflictBlocksMixedRP(t *testing.T) {
t.Setenv("BACKEND_PUBLIC_URL", "")
if !testsupport.PortBindingAvailable() {
t.Skip("skipping headless link tests because this environment cannot bind local TCP listeners")
}
redis := &mockRedisRepo{data: make(map[string]string)}
privateKey, jwks := mustHeadlessRSAJWK(t)
jwksBody, _ := json.Marshal(jwks)
acceptCalled := false
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "headless-login-client",
TokenEndpointAuthMethod: "none",
Metadata: map[string]any{
"status": "active",
"headless_login_enabled": true,
"headless_token_endpoint_auth_method": "private_key_jwt",
"headless_jwks_uri": "https://rp.example.com/.well-known/jwks.json",
},
},
})
return
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
acceptCalled = true
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
return
}
http.NotFound(w, r)
})
mockKratos := new(MockKratosAdminService)
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "+821012345678").Return("kratos-target-b", nil)
headlessClient := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Host == "rp.example.com" && r.URL.Path == "/.well-known/jwks.json" {
return httpResponse(r, http.StatusOK, string(jwksBody)), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})}
h := &AuthHandler{
RedisService: redis,
IdpProvider: &mockIdpProvider{
userExists: true,
initiateLinkErr: domain.ErrNotSupported,
},
SmsService: &mockSmsService{},
KratosAdmin: mockKratos,
HeadlessJWKS: service.NewHeadlessJWKSCacheService(nil, headlessClient),
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
},
}
app := newHeadlessLinkTestApp(h)
t.Setenv("USERFRONT_URL", "http://userfront.test")
initBody, _ := json.Marshal(map[string]string{
"client_id": "headless-login-client",
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/init"),
"loginId": "010-1234-5678",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(initBody))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var initResp map[string]any
_ = json.NewDecoder(resp.Body).Decode(&initResp)
pendingRef := initResp["pendingRef"].(string)
assert.NotEmpty(t, pendingRef)
var token string
for k := range redis.data {
if len(k) > 16 && k[:16] == "enchanted_token:" {
token = k[16:]
break
}
}
assert.NotEmpty(t, token)
verifyBody, _ := json.Marshal(map[string]any{
"token": token,
"verifyOnly": true,
})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(verifyBody))
req.Header.Set("Content-Type", "application/json")
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
kratosPublic := newKratosWhoamiTestServer(t, "kratos-userfront-a")
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
pollBody, _ := json.Marshal(map[string]string{
"client_id": "headless-login-client",
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/poll"),
"pendingRef": pendingRef,
})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/poll", bytes.NewReader(pollBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Cookie", "ory_kratos_session=userfront-a-session")
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusConflict, resp.StatusCode)
assert.False(t, acceptCalled)
assert.Empty(t, resp.Cookies())
var got map[string]any
_ = json.NewDecoder(resp.Body).Decode(&got)
assert.Equal(t, "oidc_subject_conflict", got["code"])
assert.Equal(t, "kratos-userfront-a", got["currentSubject"])
assert.Equal(t, "kratos-target-b", got["targetSubject"])
}

View File

@@ -0,0 +1,287 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
)
// --- Helper ---
func newLinkedRpTestApp(h *AuthHandler) *fiber.App {
app := fiber.New()
app.Get("/api/v1/user/rp/linked", h.ListLinkedRps)
return app
}
// --- Tests ---
func TestListLinkedRps_PriorityAndAggregation(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
switch r.URL.Host {
case "kratos.test":
if r.URL.Path == "/sessions/whoami" {
if r.Header.Get("X-Session-Token") == "" && r.Header.Get("Cookie") == "" {
return httpResponse(r, http.StatusUnauthorized, "unauthorized"), nil
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@test.com",
},
},
}), nil
}
case "hydra.test":
if r.URL.Path == "/oauth2/auth/sessions/consent" {
return httpJSONAny(r, http.StatusOK, []map[string]any{
{
"client": map[string]any{
"client_id": "devfront",
"client_name": "DevFront",
"redirect_uris": []string{
"https://active.example.com/callback",
},
},
"grant_scope": []string{"openid", "profile"},
"handled_at": time.Now().Format(time.RFC3339),
},
{
"client": map[string]any{
"client_id": "orgfront",
"client_name": "OrgFront",
"metadata": map[string]any{
"auto_login_supported": true,
"auto_login_url": "http://localhost:5175/login",
},
"redirect_uris": []string{
"http://localhost:5175/auth/callback",
},
},
"grant_scope": []string{"openid", "profile"},
"handled_at": time.Now().Format(time.RFC3339),
},
}), nil
}
if r.URL.Path == "/admin/clients/client-audit" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "client-audit",
"client_name": "Audit App",
}), nil
}
if r.URL.Path == "/admin/clients/client-consent" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "client-consent",
"client_name": "Consent App",
}), nil
}
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() {
http.DefaultClient = origDefault
}()
auditRepo := &mockAuditRepo{
logs: []domain.AuditLog{
{
UserID: "user-123",
EventType: "consent.granted",
Timestamp: time.Now().Add(-10 * time.Hour),
Details: `{"client_id":"client-audit", "scopes":["audit_scope"]}`,
},
},
}
consentRepo := &mockConsentRepo{
consents: []domain.ClientConsent{
{
Subject: "user-123",
ClientID: "client-consent",
GrantedScopes: []string{"consent_scope"},
UpdatedAt: time.Now().Add(-2 * time.Hour),
},
},
}
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
AuditRepo: auditRepo,
ConsentRepo: consentRepo,
KratosAdmin: new(MockKratosAdminService),
}
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
t.Setenv("KRATOS_ADMIN_URL", "http://kratos.test")
t.Setenv("HYDRA_PUBLIC_URL", "https://sso.example.com/oidc")
t.Setenv("DEVFRONT_URL", "http://localhost:5174")
app := newLinkedRpTestApp(h)
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/rp/linked", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var res struct {
Items []struct {
ID string `json:"id"`
Name string `json:"name"`
Status string `json:"status"`
Scopes []string `json:"scopes"`
InitURL string `json:"init_url"`
AutoLoginSupported bool `json:"auto_login_supported"`
AutoLoginURL string `json:"auto_login_url"`
} `json:"items"`
}
json.NewDecoder(resp.Body).Decode(&res)
assert.Equal(t, 4, len(res.Items))
statusMap := make(map[string]string)
for _, item := range res.Items {
statusMap[item.ID] = item.Status
}
assert.Equal(t, "active", statusMap["devfront"])
assert.Equal(t, "active", statusMap["orgfront"])
assert.Equal(t, "inactive", statusMap["client-consent"])
assert.Equal(t, "inactive", statusMap["client-audit"])
var activeInitURL string
for _, item := range res.Items {
if item.ID == "devfront" {
activeInitURL = item.InitURL
break
}
}
parsedInitURL, err := url.Parse(activeInitURL)
assert.NoError(t, err)
assert.Equal(t, "http", parsedInitURL.Scheme)
assert.Equal(t, "localhost:5174", parsedInitURL.Host)
assert.Equal(t, "/login", parsedInitURL.Path)
assert.Equal(t, "1", parsedInitURL.Query().Get("auto"))
assert.Equal(t, "/clients", parsedInitURL.Query().Get("returnTo"))
var orgfrontItem struct {
InitURL string
AutoLoginSupported bool
AutoLoginURL string
}
for _, item := range res.Items {
if item.ID == "orgfront" {
orgfrontItem.InitURL = item.InitURL
orgfrontItem.AutoLoginSupported = item.AutoLoginSupported
orgfrontItem.AutoLoginURL = item.AutoLoginURL
break
}
}
assert.True(t, orgfrontItem.AutoLoginSupported)
assert.Equal(t, "http://localhost:5175/login?auto=1", orgfrontItem.AutoLoginURL)
assert.Equal(t, orgfrontItem.AutoLoginURL, orgfrontItem.InitURL)
}
func TestListLinkedRps_EnrichesLogoFromHydraClientWhenConsentSessionOmitsMetadata(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
switch r.URL.Host {
case "kratos.test":
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{
"id": "user-123",
},
}), nil
}
case "hydra.test":
if r.URL.Path == "/oauth2/auth/sessions/consent" {
return httpJSONAny(r, http.StatusOK, []map[string]any{
{
"client": map[string]any{
"client_id": "gitea-client",
"client_name": "Gitea",
"redirect_uris": []string{
"https://gitea.example.com/callback",
},
},
"grant_scope": []string{"openid", "profile"},
"handled_at": time.Now().Format(time.RFC3339),
},
}), nil
}
if r.URL.Path == "/clients/gitea-client" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "gitea-client",
"client_name": "Gitea",
"redirect_uris": []string{
"https://gitea.example.com/callback",
},
"metadata": map[string]any{
"logo_url": "https://cdn.example.com/gitea.svg",
},
}), nil
}
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() {
http.DefaultClient = origDefault
}()
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: new(MockKratosAdminService),
}
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
t.Setenv("KRATOS_ADMIN_URL", "http://kratos.test")
t.Setenv("HYDRA_PUBLIC_URL", "https://sso.example.com/oidc")
app := newLinkedRpTestApp(h)
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/rp/linked", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var res struct {
Items []struct {
ID string `json:"id"`
Logo string `json:"logo"`
} `json:"items"`
}
json.NewDecoder(resp.Body).Decode(&res)
assert.Len(t, res.Items, 1)
assert.Equal(t, "gitea-client", res.Items[0].ID)
assert.Equal(t, "https://cdn.example.com/gitea.svg", res.Items[0].Logo)
}

View File

@@ -0,0 +1,206 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
)
func newVerifyLoginCodeTestApp(h *AuthHandler) *fiber.App {
app := fiber.New()
app.Post("/api/v1/auth/login/code/verify", h.VerifyLoginCode)
app.Post("/api/v1/auth/login/code/verify-short", h.VerifyLoginShortCode)
return app
}
func decodeJSONBody(t *testing.T, resp *http.Response) map[string]any {
t.Helper()
var got map[string]any
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
t.Fatalf("failed to decode response body: %v", err)
}
return got
}
func TestVerifyLoginCode_InvalidBody_ReturnsExplicitCode(t *testing.T) {
h := &AuthHandler{}
app := newVerifyLoginCodeTestApp(h)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify", bytes.NewBufferString("{"))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.StatusCode)
}
got := decodeJSONBody(t, resp)
if got["code"] != "bad_request" {
t.Fatalf("expected code=bad_request, got %v", got["code"])
}
}
func TestVerifyLoginCode_IdpUnavailable_ReturnsExplicitCode(t *testing.T) {
h := &AuthHandler{}
app := newVerifyLoginCodeTestApp(h)
body, _ := json.Marshal(map[string]any{
"loginId": "user@example.com",
"code": "AA-111111",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusServiceUnavailable {
t.Fatalf("expected 503, got %d", resp.StatusCode)
}
got := decodeJSONBody(t, resp)
if got["code"] != "service_unavailable" {
t.Fatalf("expected code=service_unavailable, got %v", got["code"])
}
}
func TestVerifyLoginCode_VerifyOnlyInvalidCode_ReturnsExplicitCode(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
redis.data[prefixLoginCode+"user@example.com"] = "flow-1"
redis.data[prefixLoginCodePending+"user@example.com"] = "pending-1"
redis.data[prefixLoginCodeValue+"pending-1"] = "AB-123"
h := &AuthHandler{
RedisService: redis,
IdpProvider: &mockIdpProvider{},
}
app := newVerifyLoginCodeTestApp(h)
body, _ := json.Marshal(map[string]any{
"loginId": "user@example.com",
"code": "ZZ-999",
"verifyOnly": true,
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", resp.StatusCode)
}
got := decodeJSONBody(t, resp)
if got["code"] != "invalid_code" {
t.Fatalf("expected code=invalid_code, got %v", got["code"])
}
}
func TestVerifyLoginShortCode_MissingShortCode_ReturnsExplicitCode(t *testing.T) {
h := &AuthHandler{
RedisService: &mockRedisRepo{data: make(map[string]string)},
}
app := newVerifyLoginCodeTestApp(h)
body, _ := json.Marshal(map[string]any{
"shortCode": "",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify-short", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.StatusCode)
}
got := decodeJSONBody(t, resp)
if got["code"] != "bad_request" {
t.Fatalf("expected code=bad_request, got %v", got["code"])
}
}
func TestVerifyLoginShortCode_InvalidOrExpired_ReturnsExplicitCode(t *testing.T) {
h := &AuthHandler{
RedisService: &mockRedisRepo{data: make(map[string]string)},
}
app := newVerifyLoginCodeTestApp(h)
body, _ := json.Marshal(map[string]any{
"shortCode": "AB-123456",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify-short", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", resp.StatusCode)
}
got := decodeJSONBody(t, resp)
if got["code"] != "invalid_or_expired_code" {
t.Fatalf("expected code=invalid_or_expired_code, got %v", got["code"])
}
}
func TestVerifyLoginShortCode_VerifyOnlyMissingPendingRef_ReturnsExplicitCode(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
payload, _ := json.Marshal(shortLoginCodePayload{
LoginID: "user@example.com",
Code: "AB-123",
})
redis.data[prefixLoginCodeShort+"AB-123456"] = string(payload)
h := &AuthHandler{
RedisService: redis,
}
app := newVerifyLoginCodeTestApp(h)
body, _ := json.Marshal(map[string]any{
"shortCode": "AB-123456",
"verifyOnly": true,
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify-short", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.StatusCode)
}
got := decodeJSONBody(t, resp)
if got["code"] != "invalid_session_reference" {
t.Fatalf("expected code=invalid_session_reference, got %v", got["code"])
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,178 @@
package handler
import (
"baron-sso-backend/internal/service"
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
)
func newOidcLoginTestApp(h *AuthHandler) *fiber.App {
app := fiber.New()
app.Post("/api/v1/auth/oidc/login/accept", h.AcceptOidcLoginRequest)
return app
}
func TestAcceptOidcLoginRequest_CookieOnly(t *testing.T) {
var gotSubject string
var gotChallenge string
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
switch r.URL.Host {
case "kratos.test":
if r.URL.Path != "/sessions/whoami" {
return httpResponse(r, http.StatusNotFound, "not found"), nil
}
if r.Header.Get("X-Session-Token") != "" {
return httpResponse(r, http.StatusUnauthorized, "invalid token"), nil
}
if r.Header.Get("Cookie") == "" {
return httpResponse(r, http.StatusUnauthorized, "missing cookie"), nil
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{
"id": "kratos-123",
"traits": map[string]any{},
},
}), nil
case "hydra.test":
if r.URL.Path != "/oauth2/auth/requests/login/accept" {
return httpResponse(r, http.StatusNotFound, "not found"), nil
}
gotChallenge = r.URL.Query().Get("login_challenge")
body, _ := io.ReadAll(r.Body)
var payload map[string]any
_ = json.Unmarshal(body, &payload)
if subject, ok := payload["subject"].(string); ok {
gotSubject = subject
}
return httpResponse(r, http.StatusOK, `{"redirect_to":"http://rp/cb"}`), nil
default:
return httpResponse(r, http.StatusNotFound, "not found"), nil
}
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() {
http.DefaultClient = origDefault
}()
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
}
app := newOidcLoginTestApp(h)
body, _ := json.Marshal(map[string]string{
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oidc/login/accept", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Cookie", "ory_kratos_session=abc123")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
var got map[string]string
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if got["redirectTo"] != "http://rp/cb" {
t.Fatalf("unexpected redirectTo: %v", got["redirectTo"])
}
if gotSubject != "kratos-123" {
t.Fatalf("unexpected subject: %v", gotSubject)
}
if gotChallenge != "challenge-123" {
t.Fatalf("unexpected login_challenge: %v", gotChallenge)
}
}
func TestAcceptOidcLoginRequest_TokenFallbackToCookie(t *testing.T) {
var gotSubject string
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
switch r.URL.Host {
case "kratos.test":
if r.URL.Path != "/sessions/whoami" {
return httpResponse(r, http.StatusNotFound, "not found"), nil
}
if r.Header.Get("X-Session-Token") != "" {
return httpResponse(r, http.StatusUnauthorized, "invalid token"), nil
}
if r.Header.Get("Cookie") == "" {
return httpResponse(r, http.StatusUnauthorized, "missing cookie"), nil
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{
"id": "kratos-456",
"traits": map[string]any{},
},
}), nil
case "hydra.test":
if r.URL.Path != "/oauth2/auth/requests/login/accept" {
return httpResponse(r, http.StatusNotFound, "not found"), nil
}
body, _ := io.ReadAll(r.Body)
var payload map[string]any
_ = json.Unmarshal(body, &payload)
if subject, ok := payload["subject"].(string); ok {
gotSubject = subject
}
return httpResponse(r, http.StatusOK, `{"redirect_to":"http://rp/cb"}`), nil
default:
return httpResponse(r, http.StatusNotFound, "not found"), nil
}
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() {
http.DefaultClient = origDefault
}()
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
}
app := newOidcLoginTestApp(h)
body, _ := json.Marshal(map[string]string{
"login_challenge": "challenge-456",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oidc/login/accept", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer invalid-token")
req.Header.Set("Cookie", "ory_kratos_session=def456")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
if gotSubject != "kratos-456" {
t.Fatalf("unexpected subject: %v", gotSubject)
}
}

View File

@@ -0,0 +1,110 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
)
func TestHandleKratosCourierRelay_Email(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
emailSvc := &mockEmailService{}
h := &AuthHandler{
RedisService: redis,
EmailService: emailSvc,
}
app := fiber.New()
app.Post("/api/v1/auth/kratos/courier", h.HandleKratosCourierRelay)
// Simulate Kratos Courier Request for Email
reqBody := map[string]any{
"recipient": "user@example.com",
"template_type": "verification_code",
"template_data": map[string]any{
"verification_code": "123456",
},
"subject": "Verify your email",
"body": "Your code is 123456",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/kratos/courier", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
}
func TestVerifySignupCode_Success(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
h := &AuthHandler{
RedisService: redis,
}
app := fiber.New()
app.Post("/api/v1/auth/signup/verify", h.VerifySignupCode)
// Mock stored code in redis
// signup:email:user@test.com -> {"code":"654321", "verified":false, "expires_at":...}
state := map[string]any{
"code": "654321",
"verified": false,
"expires_at": 9999999999, // far future
}
stateJSON, _ := json.Marshal(state)
redis.data["signup:email:user@test.com"] = string(stateJSON)
// Verify Code
verifyBody := map[string]string{
"type": "email",
"target": "user@test.com",
"code": "654321",
}
body, _ := json.Marshal(verifyBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/signup/verify", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var res map[string]any
json.NewDecoder(resp.Body).Decode(&res)
assert.True(t, res["success"].(bool))
// Check redis state updated to verified
val, _ := redis.Get("signup:email:user@test.com")
var updatedState map[string]any
json.Unmarshal([]byte(val), &updatedState)
assert.True(t, updatedState["verified"].(bool))
}
func TestVerifySignupCode_Invalid(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
h := &AuthHandler{
RedisService: redis,
}
app := fiber.New()
app.Post("/api/v1/auth/signup/verify", h.VerifySignupCode)
stateJSON, _ := json.Marshal(map[string]any{
"code": "111111",
"expires_at": 9999999999,
})
redis.data["signup:email:user@test.com"] = string(stateJSON)
verifyBody := map[string]string{
"type": "email",
"target": "user@test.com",
"code": "222222", // wrong code
}
body, _ := json.Marshal(verifyBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/signup/verify", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
}

View File

@@ -0,0 +1,244 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes"
"context"
"encoding/json"
"maps"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/require"
)
type recordingUpdateMeUserRepo struct {
MockUserRepoForHandler
updated *domain.User
loginIDs []domain.UserLoginID
}
func (r *recordingUpdateMeUserRepo) Update(ctx context.Context, user *domain.User) error {
copied := *user
r.updated = &copied
return nil
}
func (r *recordingUpdateMeUserRepo) UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error {
r.loginIDs = append([]domain.UserLoginID(nil), loginIDs...)
return nil
}
type recordingUpdateMeKratosAdmin struct {
MockKratosAdminService
updatedIdentityID string
updatedTraits map[string]any
updatedState string
storedTraits map[string]any
}
func (r *recordingUpdateMeKratosAdmin) UpdateIdentity(ctx context.Context, identityID string, traits map[string]any, state string) (*service.KratosIdentity, error) {
r.updatedIdentityID = identityID
r.updatedTraits = maps.Clone(traits)
r.updatedState = state
if r.storedTraits != nil {
maps.Copy(r.storedTraits, traits)
}
return &service.KratosIdentity{
ID: identityID,
Traits: traits,
State: state,
}, nil
}
func TestUpdateMe_InvalidatesProfileCacheForTokenSession(t *testing.T) {
token := "token-abc"
identityID := "user-1"
traits := map[string]any{
"email": "qa@example.com",
"name": "QA User",
"phone_number": "+821012345678",
"department": "Old Dept",
"affiliationType": "employee",
"companyCode": "",
"role": domain.RoleUser,
}
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
switch {
case r.URL.Host == "kratos.test" &&
r.URL.Path == "/sessions/whoami" &&
r.Method == http.MethodGet:
if r.Header.Get("X-Session-Token") != token {
return httpResponse(r, http.StatusUnauthorized, `{"error":"invalid token"}`), nil
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{
"id": identityID,
"traits": traits,
},
}), nil
case r.URL.Host == "kratos.test" &&
r.URL.Path == "/admin/identities/"+identityID &&
r.Method == http.MethodPut:
var payload struct {
Traits map[string]any `json:"traits"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
return httpResponse(r, http.StatusBadRequest, `{"error":"invalid body"}`), nil
}
maps.Copy(traits, payload.Traits)
return httpResponse(r, http.StatusOK, `{"ok":true}`), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
setDefaultHTTPClientForTest(t, transport)
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
t.Setenv("KRATOS_ADMIN_URL", "http://kratos.test")
redis := &mockRedisRepo{data: make(map[string]string)}
kratosAdmin := &recordingUpdateMeKratosAdmin{storedTraits: traits}
h := &AuthHandler{
RedisService: redis,
KratosAdmin: kratosAdmin,
}
app := fiber.New()
app.Get("/api/v1/user/me", h.GetMe)
app.Put("/api/v1/user/me", h.UpdateMe)
// 1) 첫 조회로 Old Dept가 캐시에 저장됨
getReq1 := httptest.NewRequest(http.MethodGet, "/api/v1/user/me", nil)
getReq1.Header.Set("Authorization", "Bearer "+token)
getResp1, err := app.Test(getReq1, -1)
require.NoError(t, err)
require.Equal(t, http.StatusOK, getResp1.StatusCode)
var profile1 map[string]any
require.NoError(t, json.NewDecoder(getResp1.Body).Decode(&profile1))
require.Equal(t, "Old Dept", profile1["department"])
// 2) 소속을 New Dept로 변경
updateBody, _ := json.Marshal(map[string]string{
"name": "QA User",
"phone": "01012345678",
"department": "New Dept",
})
updateReq := httptest.NewRequest(
http.MethodPut,
"/api/v1/user/me",
bytes.NewReader(updateBody),
)
updateReq.Header.Set("Content-Type", "application/json")
updateReq.Header.Set("Authorization", "Bearer "+token)
updateResp, err := app.Test(updateReq, -1)
require.NoError(t, err)
require.Equal(t, http.StatusOK, updateResp.StatusCode)
require.Equal(t, "New Dept", traits["department"])
require.Equal(t, identityID, kratosAdmin.updatedIdentityID)
require.Equal(t, "New Dept", kratosAdmin.updatedTraits["department"])
// 3) 새로고침 재조회 시 New Dept가 보여야 함(캐시 무효화 회귀 방지)
getReq2 := httptest.NewRequest(http.MethodGet, "/api/v1/user/me", nil)
getReq2.Header.Set("Authorization", "Bearer "+token)
getResp2, err := app.Test(getReq2, -1)
require.NoError(t, err)
require.Equal(t, http.StatusOK, getResp2.StatusCode)
var profile2 map[string]any
require.NoError(t, json.NewDecoder(getResp2.Body).Decode(&profile2))
require.Equal(t, "New Dept", profile2["department"])
}
func TestUpdateMe_SyncsLocalReadModelFields(t *testing.T) {
token := "token-sync"
identityID := "user-sync"
traits := map[string]any{
"email": "sync@example.com",
"name": "Old Name",
"phone_number": "+821012345678",
"department": "Old Dept",
"affiliationType": "employee",
"companyCode": "saman",
"tenant_id": "11111111-1111-1111-1111-111111111111",
"role": domain.RoleUser,
}
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
switch {
case r.URL.Host == "kratos.test" &&
r.URL.Path == "/sessions/whoami" &&
r.Method == http.MethodGet:
if r.Header.Get("X-Session-Token") != token {
return httpResponse(r, http.StatusUnauthorized, `{"error":"invalid token"}`), nil
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{
"id": identityID,
"traits": traits,
},
}), nil
case r.URL.Host == "kratos.test" &&
r.URL.Path == "/admin/identities/"+identityID &&
r.Method == http.MethodPut:
var payload struct {
Traits map[string]any `json:"traits"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
return httpResponse(r, http.StatusBadRequest, `{"error":"invalid body"}`), nil
}
maps.Copy(traits, payload.Traits)
return httpResponse(r, http.StatusOK, `{"ok":true}`), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
setDefaultHTTPClientForTest(t, transport)
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
t.Setenv("KRATOS_ADMIN_URL", "http://kratos.test")
redis := &mockRedisRepo{data: map[string]string{
"verify_update_phone:" + identityID + ":+821087654321": "verified",
}}
userRepo := &recordingUpdateMeUserRepo{}
kratosAdmin := &recordingUpdateMeKratosAdmin{storedTraits: traits}
h := &AuthHandler{
RedisService: redis,
UserRepo: userRepo,
KratosAdmin: kratosAdmin,
}
app := fiber.New()
app.Put("/api/v1/user/me", h.UpdateMe)
updateBody, _ := json.Marshal(map[string]any{
"name": "New Name",
"phone": "01087654321",
"department": "New Dept",
})
updateReq := httptest.NewRequest(
http.MethodPut,
"/api/v1/user/me",
bytes.NewReader(updateBody),
)
updateReq.Header.Set("Content-Type", "application/json")
updateReq.Header.Set("Authorization", "Bearer "+token)
updateResp, err := app.Test(updateReq, -1)
require.NoError(t, err)
require.Equal(t, http.StatusOK, updateResp.StatusCode)
require.Equal(t, identityID, kratosAdmin.updatedIdentityID)
require.Equal(t, "New Name", kratosAdmin.updatedTraits["name"])
require.Equal(t, "+821087654321", kratosAdmin.updatedTraits["phone_number"])
require.NotNil(t, userRepo.updated)
require.Equal(t, identityID, userRepo.updated.ID)
require.Equal(t, "sync@example.com", userRepo.updated.Email)
require.Equal(t, "New Name", userRepo.updated.Name)
require.Equal(t, "+821087654321", userRepo.updated.Phone)
require.Equal(t, "New Dept", userRepo.updated.Department)
require.Empty(t, userRepo.updated.CompanyCode)
require.NotNil(t, userRepo.updated.TenantID)
require.Equal(t, "11111111-1111-1111-1111-111111111111", *userRepo.updated.TenantID)
}

View File

@@ -0,0 +1,206 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
)
// --- Mock Redis ---
type mockRedisRepo struct {
data map[string]string
}
func (m *mockRedisRepo) Set(key, value string, ttl time.Duration) error {
if m.data == nil {
m.data = make(map[string]string)
}
m.data[key] = value
return nil
}
func (m *mockRedisRepo) Get(key string) (string, error) {
// Bypass rate limiting for tests
if strings.HasPrefix(key, "poll_meta:") {
return "", nil
}
return m.data[key], nil
}
func (m *mockRedisRepo) Delete(key string) error {
delete(m.data, key)
return nil
}
func (m *mockRedisRepo) StoreVerificationCode(phone, code string) error {
return m.Set("sms:"+phone, code, time.Minute)
}
func (m *mockRedisRepo) GetVerificationCode(phone string) (string, error) {
return m.Get("sms:" + phone)
}
func (m *mockRedisRepo) DeleteVerificationCode(phone string) error {
return m.Delete("sms:" + phone)
}
// --- Tests ---
func TestQRLoginFlow_Success(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
h := &AuthHandler{
RedisService: redis,
}
app := fiber.New()
app.Post("/api/v1/auth/qr/init", h.InitQRLogin)
app.Post("/api/v1/auth/qr/poll", h.PollQRLogin)
// 1. Init QR Login
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/qr/init", nil)
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var initResp map[string]any
json.NewDecoder(resp.Body).Decode(&initResp)
pendingRef := initResp["pendingRef"].(string)
// 2. Poll (Pending)
body, _ := json.Marshal(map[string]string{"pendingRef": pendingRef})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/qr/poll", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ = app.Test(req, -1)
// Expect authorization_pending (400)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
var pollResp map[string]any
json.NewDecoder(resp.Body).Decode(&pollResp)
assert.Equal(t, "authorization_pending", pollResp["error"])
assert.Equal(t, "authorization_pending", pollResp["code"])
// 3. Mock Approval
sessionData, _ := json.Marshal(map[string]string{
"status": "success",
"jwt": "mock-session-jwt",
})
redis.data["enchanted_session:"+pendingRef] = string(sessionData)
// 4. Poll (Success)
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/qr/poll", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var successResp map[string]any
json.NewDecoder(resp.Body).Decode(&successResp)
assert.Equal(t, "ok", successResp["status"])
assert.Equal(t, "mock-session-jwt", successResp["sessionJwt"])
}
func TestScanQRLogin_Success(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
idp := &mockIdpProvider{userExists: true}
h := &AuthHandler{
RedisService: redis,
IdpProvider: idp,
}
app := fiber.New()
app.Post("/api/v1/auth/qr/approve", h.ScanQRLogin)
pendingRef := "test-ref"
redis.data["enchanted_session:"+pendingRef] = `{"status":"pending"}`
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@example.com",
},
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
origDefault := http.DefaultClient
http.DefaultClient = &http.Client{Transport: transport}
defer func() { http.DefaultClient = origDefault }()
body, _ := json.Marshal(map[string]string{
"pendingRef": pendingRef,
"token": "valid-token",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/qr/approve", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
}
func TestResolveConsentSubjects_TokenAndCookie(t *testing.T) {
h := &AuthHandler{}
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.Header.Get("X-Session-Token") == "token-123" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{
"id": "user-token",
"traits": map[string]any{
"email": "token@test.com",
},
},
}), nil
}
if r.Header.Get("Cookie") == "ory_kratos_session=cookie-123" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{
"id": "user-cookie",
"traits": map[string]any{
"email": "cookie@test.com",
"phone": "010-1234-5678",
},
},
}), nil
}
return httpResponse(r, http.StatusUnauthorized, "unauthorized"), nil
})
origDefault := http.DefaultClient
http.DefaultClient = &http.Client{Transport: transport}
defer func() { http.DefaultClient = origDefault }()
app := fiber.New()
// Token case
app.Get("/test-token", func(c *fiber.Ctx) error {
subjects, err := h.resolveConsentSubjects(c)
assert.NoError(t, err)
assert.Contains(t, subjects, "user-token")
return c.SendStatus(200)
})
req := httptest.NewRequest("GET", "/test-token", nil)
req.Header.Set("Authorization", "Bearer token-123")
app.Test(req, -1)
// Cookie case
app.Get("/test-cookie", func(c *fiber.Ctx) error {
subjects, err := h.resolveConsentSubjects(c)
assert.NoError(t, err)
assert.Contains(t, subjects, "user-cookie")
return c.SendStatus(200)
})
req = httptest.NewRequest("GET", "/test-cookie", nil)
req.Header.Set("Cookie", "ory_kratos_session=cookie-123")
app.Test(req, -1)
}

View File

@@ -0,0 +1,105 @@
package handler
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/require"
)
func TestGetMe_IncludesSessionAuthenticatedAtFromKratosSession(t *testing.T) {
const (
token = "token-session"
identityID = "user-session"
sessionAuthenticated = "2026-03-23T15:30:00Z"
)
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Host == "kratos.test" &&
r.URL.Path == "/sessions/whoami" &&
r.Method == http.MethodGet {
require.Equal(t, token, r.Header.Get("X-Session-Token"))
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "kratos-session-1",
"authenticated_at": sessionAuthenticated,
"identity": map[string]any{
"id": identityID,
"traits": map[string]any{
"email": "qa@example.com",
"name": "QA User",
"department": "Platform",
"affiliationType": "GENERAL",
},
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
setDefaultHTTPClientForTest(t, transport)
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
h := &AuthHandler{}
app := fiber.New()
app.Get("/api/v1/user/me", h.GetMe)
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/me", nil)
req.Header.Set("Authorization", "Bearer "+token)
resp, err := app.Test(req, -1)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
var profile map[string]any
require.NoError(t, json.NewDecoder(resp.Body).Decode(&profile))
require.Equal(t, sessionAuthenticated, profile["sessionAuthenticatedAt"])
}
func TestGetMe_IncludesSessionAuthenticatedAtForCookieSession(t *testing.T) {
const (
cookieHeader = "ory_kratos_session=session-cookie"
identityID = "user-cookie"
sessionAuthenticated = "2026-03-24T01:20:00Z"
)
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Host == "kratos.test" &&
r.URL.Path == "/sessions/whoami" &&
r.Method == http.MethodGet {
require.Equal(t, cookieHeader, r.Header.Get("Cookie"))
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "kratos-session-cookie",
"authenticated_at": sessionAuthenticated,
"identity": map[string]any{
"id": identityID,
"traits": map[string]any{
"email": "cookie@example.com",
"name": "Cookie User",
"department": "Platform",
"affiliationType": "GENERAL",
},
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
setDefaultHTTPClientForTest(t, transport)
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
h := &AuthHandler{}
app := fiber.New()
app.Get("/api/v1/user/me", h.GetMe)
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/me", nil)
req.Header.Set("Cookie", cookieHeader)
resp, err := app.Test(req, -1)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
var profile map[string]any
require.NoError(t, json.NewDecoder(resp.Body).Decode(&profile))
require.Equal(t, sessionAuthenticated, profile["sessionAuthenticatedAt"])
}

View File

@@ -0,0 +1,944 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func TestListMySessions_Success(t *testing.T) {
now := time.Date(2026, 4, 2, 1, 2, 3, 0, time.UTC)
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "current-sid",
"authenticated_at": now.Format(time.RFC3339),
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@example.com",
"name": "User",
"role": "user",
},
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
}))
mockKratos := new(MockKratosAdminService)
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
{
ID: "current-sid",
Active: true,
AuthenticatedAt: now,
ExpiresAt: now.Add(24 * time.Hour),
},
{
ID: "other-sid",
Active: true,
AuthenticatedAt: now.Add(-2 * time.Hour),
ExpiresAt: now.Add(22 * time.Hour),
},
}, nil).Once()
auditRepo := &mockAuditRepo{
logs: []domain.AuditLog{
{
UserID: "user-123",
EventType: "login_success",
SessionID: "other-sid",
Timestamp: now.Add(-30 * time.Minute),
IPAddress: "203.0.113.10",
UserAgent: "Mozilla/5.0",
},
},
}
h := &AuthHandler{
KratosAdmin: mockKratos,
AuditRepo: auditRepo,
}
app := fiber.New()
app.Get("/api/v1/user/sessions", h.ListMySessions)
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/sessions", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var body struct {
Items []struct {
SessionID string `json:"session_id"`
IsCurrent bool `json:"is_current"`
IsActive bool `json:"is_active"`
IPAddress string `json:"ip_address"`
UserAgent string `json:"user_agent"`
} `json:"items"`
}
err = json.NewDecoder(resp.Body).Decode(&body)
assert.NoError(t, err)
if assert.Len(t, body.Items, 2) {
assert.Equal(t, "current-sid", body.Items[0].SessionID)
assert.True(t, body.Items[0].IsCurrent)
assert.Equal(t, "other-sid", body.Items[1].SessionID)
assert.True(t, body.Items[1].IsActive)
assert.Equal(t, "203.0.113.10", body.Items[1].IPAddress)
assert.Equal(t, "Mozilla/5.0", body.Items[1].UserAgent)
}
mockKratos.AssertExpectations(t)
}
func TestListMySessions_UsesConsentGrantForAppName(t *testing.T) {
now := time.Date(2026, 4, 2, 4, 40, 0, 0, time.UTC)
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "current-sid",
"authenticated_at": now.Format(time.RFC3339),
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@example.com",
"name": "User",
"role": "user",
},
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
}))
mockKratos := new(MockKratosAdminService)
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
{
ID: "current-sid",
Active: true,
AuthenticatedAt: now,
ExpiresAt: now.Add(24 * time.Hour),
},
{
ID: "c7c721ea-session",
Active: true,
AuthenticatedAt: now.Add(-5 * time.Minute),
ExpiresAt: now.Add(23*time.Hour + 55*time.Minute),
},
}, nil).Once()
auditRepo := &mockAuditRepo{
logs: []domain.AuditLog{
{
UserID: "user-123",
EventType: "consent.granted",
SessionID: "c7c721ea-session",
Timestamp: now,
Details: `{"client_id":"devfront","client_name":"DevFront","session_id":"c7c721ea-session","approved_session_id":"c7c721ea-session"}`,
},
},
}
h := &AuthHandler{
KratosAdmin: mockKratos,
AuditRepo: auditRepo,
}
app := fiber.New()
app.Get("/api/v1/user/sessions", h.ListMySessions)
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/sessions", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var body struct {
Items []struct {
SessionID string `json:"session_id"`
AppName string `json:"app_name"`
ClientID string `json:"client_id"`
} `json:"items"`
}
err = json.NewDecoder(resp.Body).Decode(&body)
assert.NoError(t, err)
if assert.Len(t, body.Items, 2) {
assert.Equal(t, "c7c721ea-session", body.Items[1].SessionID)
assert.Equal(t, "DevFront", body.Items[1].AppName)
assert.Equal(t, "devfront", body.Items[1].ClientID)
}
mockKratos.AssertExpectations(t)
}
func TestListMySessions_PreservesAppNameFromOlderConsentGrant(t *testing.T) {
now := time.Date(2026, 4, 2, 4, 40, 0, 0, time.UTC)
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "current-sid",
"authenticated_at": now.Format(time.RFC3339),
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@example.com",
"name": "User",
"role": "user",
},
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
}))
mockKratos := new(MockKratosAdminService)
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
{
ID: "current-sid",
Active: true,
AuthenticatedAt: now,
ExpiresAt: now.Add(24 * time.Hour),
},
{
ID: "c7c721ea-session",
Active: true,
AuthenticatedAt: now.Add(-5 * time.Minute),
ExpiresAt: now.Add(23*time.Hour + 55*time.Minute),
},
}, nil).Once()
auditRepo := &mockAuditRepo{
logs: []domain.AuditLog{
{
UserID: "user-123",
EventType: "consent.granted",
SessionID: "c7c721ea-session",
Timestamp: now.Add(-30 * time.Second),
IPAddress: "203.0.113.10",
Details: `{"client_id":"devfront","client_name":"DevFront","session_id":"c7c721ea-session"}`,
},
{
UserID: "user-123",
EventType: "login_success",
SessionID: "c7c721ea-session",
Timestamp: now,
IPAddress: "10.0.0.12",
UserAgent: "Mozilla/5.0",
},
},
}
h := &AuthHandler{
KratosAdmin: mockKratos,
AuditRepo: auditRepo,
}
app := fiber.New()
app.Get("/api/v1/user/sessions", h.ListMySessions)
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/sessions", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var body struct {
Items []struct {
SessionID string `json:"session_id"`
AppName string `json:"app_name"`
ClientID string `json:"client_id"`
IPAddress string `json:"ip_address"`
} `json:"items"`
}
err = json.NewDecoder(resp.Body).Decode(&body)
assert.NoError(t, err)
if assert.Len(t, body.Items, 2) {
assert.Equal(t, "c7c721ea-session", body.Items[1].SessionID)
assert.Equal(t, "DevFront", body.Items[1].AppName)
assert.Equal(t, "devfront", body.Items[1].ClientID)
assert.Equal(t, "203.0.113.10", body.Items[1].IPAddress)
}
mockKratos.AssertExpectations(t)
}
func TestListMySessions_CurrentSessionFallsBackToRequestMetadata(t *testing.T) {
now := time.Date(2026, 4, 6, 1, 2, 3, 0, time.UTC)
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "current-sid",
"authenticated_at": now.Format(time.RFC3339),
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@example.com",
"name": "User",
"role": "user",
},
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
}))
mockKratos := new(MockKratosAdminService)
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
{
ID: "current-sid",
Active: true,
AuthenticatedAt: now,
ExpiresAt: now.Add(24 * time.Hour),
},
}, nil).Once()
h := &AuthHandler{
KratosAdmin: mockKratos,
AuditRepo: &mockAuditRepo{},
}
app := fiber.New()
app.Get("/api/v1/user/sessions", h.ListMySessions)
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/sessions", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
req.Header.Set("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) Chrome/146.0.0.0 Safari/537.36")
req.Header.Set("X-Forwarded-For", "100.100.100.1, 203.0.113.25")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var body struct {
Items []struct {
SessionID string `json:"session_id"`
IsCurrent bool `json:"is_current"`
IPAddress string `json:"ip_address"`
UserAgent string `json:"user_agent"`
ClientID string `json:"client_id"`
AppName string `json:"app_name"`
} `json:"items"`
}
err = json.NewDecoder(resp.Body).Decode(&body)
assert.NoError(t, err)
if assert.Len(t, body.Items, 1) {
assert.Equal(t, "current-sid", body.Items[0].SessionID)
assert.True(t, body.Items[0].IsCurrent)
assert.Equal(t, "203.0.113.25", body.Items[0].IPAddress)
assert.Contains(t, body.Items[0].UserAgent, "Mozilla/5.0")
assert.Equal(t, "userfront", body.Items[0].ClientID)
assert.Equal(t, "UserFront", body.Items[0].AppName)
}
mockKratos.AssertExpectations(t)
}
func TestDeleteMySession_Success(t *testing.T) {
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
var hydraRevokeCalls int
client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
switch r.URL.Host {
case "kratos.test":
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "current-sid",
"authenticated_at": time.Now().UTC().Format(time.RFC3339),
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@example.com",
"name": "User",
"role": "user",
},
},
}), nil
}
case "hydra.test":
if r.Method == http.MethodDelete && r.URL.Path == "/oauth2/auth/sessions/consent" {
if r.URL.Query().Get("subject") != "user-123" {
t.Fatalf("unexpected revoke subject: %s", r.URL.Query().Get("subject"))
}
if r.URL.Query().Get("client") != "devfront" {
t.Fatalf("unexpected revoke client: %s", r.URL.Query().Get("client"))
}
hydraRevokeCalls++
return httpResponse(r, http.StatusNoContent, ""), nil
}
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})}
setDefaultHTTPClientForTest(t, client.Transport)
mockKratos := new(MockKratosAdminService)
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
{ID: "target-sid", Active: true},
}, nil).Once()
mockKratos.On("GetSession", mock.Anything, "target-sid").Return(&service.KratosSession{
ID: "target-sid",
Active: true,
}, nil).Once()
mockKratos.On("DeleteSession", mock.Anything, "target-sid").Return(nil).Once()
auditRepo := &mockAuditRepo{}
h := &AuthHandler{
KratosAdmin: mockKratos,
AuditRepo: auditRepo,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
}
auditRepo.logs = append(auditRepo.logs, domain.AuditLog{
UserID: "user-123",
EventType: "POST /api/v1/auth/oidc/login/accept",
SessionID: "target-sid",
Details: `{"client_id":"devfront","client_name":"Devfront"}`,
})
app := fiber.New()
app.Delete("/api/v1/user/sessions/:id", h.DeleteMySession)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/sessions/target-sid", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
req.Header.Set("User-Agent", "session-test-agent")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
if assert.Len(t, auditRepo.logs, 2) {
assert.Equal(t, "session.revoked", auditRepo.logs[len(auditRepo.logs)-1].EventType)
assert.Equal(t, "user-123", auditRepo.logs[len(auditRepo.logs)-1].UserID)
assert.Equal(t, "current-sid", auditRepo.logs[len(auditRepo.logs)-1].SessionID)
assert.Contains(t, auditRepo.logs[len(auditRepo.logs)-1].Details, "target-sid")
}
assert.Equal(t, 1, hydraRevokeCalls)
mockKratos.AssertExpectations(t)
}
func TestDeleteMySession_DoesNotRevokeAllHydraSessionsWhenClientBindingMissing(t *testing.T) {
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
var hydraRevokeCalls int
client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
switch r.URL.Host {
case "kratos.test":
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "current-sid",
"authenticated_at": time.Now().UTC().Format(time.RFC3339),
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@example.com",
"name": "User",
"role": "user",
},
},
}), nil
}
case "hydra.test":
if r.Method == http.MethodDelete && r.URL.Path == "/oauth2/auth/sessions/consent" {
hydraRevokeCalls++
return httpResponse(r, http.StatusNoContent, ""), nil
}
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})}
setDefaultHTTPClientForTest(t, client.Transport)
mockKratos := new(MockKratosAdminService)
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
{ID: "target-sid", Active: true},
}, nil).Once()
mockKratos.On("GetSession", mock.Anything, "target-sid").Return(&service.KratosSession{
ID: "target-sid",
Active: true,
}, nil).Once()
mockKratos.On("DeleteSession", mock.Anything, "target-sid").Return(nil).Once()
auditRepo := &mockAuditRepo{}
h := &AuthHandler{
KratosAdmin: mockKratos,
AuditRepo: auditRepo,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
}
app := fiber.New()
app.Delete("/api/v1/user/sessions/:id", h.DeleteMySession)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/sessions/target-sid", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
req.Header.Set("User-Agent", "session-test-agent")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, 0, hydraRevokeCalls)
if assert.Len(t, auditRepo.logs, 1) {
assert.Equal(t, "session.revoked", auditRepo.logs[0].EventType)
assert.Equal(t, "user-123", auditRepo.logs[0].UserID)
assert.Contains(t, auditRepo.logs[0].Details, "target-sid")
}
mockKratos.AssertExpectations(t)
}
func TestDeleteMySession_SendsBackchannelLogoutTokenWhenClientConfigured(t *testing.T) {
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
t.Setenv("BACKCHANNEL_LOGOUT_ISSUER", "https://sso.example.com/oidc")
var receivedBody string
client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
switch r.URL.Host {
case "kratos.test":
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "current-sid",
"authenticated_at": time.Now().UTC().Format(time.RFC3339),
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@example.com",
"name": "User",
"role": "user",
},
},
}), nil
}
case "hydra.test":
if r.Method == http.MethodDelete && r.URL.Path == "/oauth2/auth/sessions/consent" {
return httpResponse(r, http.StatusNoContent, ""), nil
}
if r.Method == http.MethodGet && r.URL.Path == "/clients/devfront" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "devfront",
"backchannel_logout_uri": "https://rp.example.com/backchannel-logout",
}), nil
}
case "rp.example.com":
if r.Method == http.MethodPost && r.URL.Path == "/backchannel-logout" {
raw, _ := io.ReadAll(r.Body)
receivedBody = string(raw)
return httpResponse(r, http.StatusNoContent, ""), nil
}
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})}
setDefaultHTTPClientForTest(t, client.Transport)
mockKratos := new(MockKratosAdminService)
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
{ID: "target-sid", Active: true},
}, nil).Once()
mockKratos.On("GetSession", mock.Anything, "target-sid").Return(&service.KratosSession{
ID: "target-sid",
Active: true,
}, nil).Once()
mockKratos.On("DeleteSession", mock.Anything, "target-sid").Return(nil).Once()
backchannelLogout, err := service.NewBackchannelLogoutService()
assert.NoError(t, err)
backchannelLogout.HTTPClient = client
auditRepo := &mockAuditRepo{}
h := &AuthHandler{
KratosAdmin: mockKratos,
AuditRepo: auditRepo,
BackchannelLogout: backchannelLogout,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
}
auditRepo.logs = append(auditRepo.logs, domain.AuditLog{
UserID: "user-123",
EventType: "POST /api/v1/auth/oidc/login/accept",
SessionID: "target-sid",
Details: `{"client_id":"devfront","client_name":"Devfront"}`,
})
app := fiber.New()
app.Delete("/api/v1/user/sessions/:id", h.DeleteMySession)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/sessions/target-sid", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
req.Header.Set("User-Agent", "session-test-agent")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.True(t, strings.Contains(receivedBody, "logout_token="))
values, err := url.ParseQuery(receivedBody)
assert.NoError(t, err)
assert.NotEmpty(t, values.Get("logout_token"))
foundBackchannelAudit := false
for _, log := range auditRepo.logs {
if log.EventType == "backchannel_logout.sent" {
foundBackchannelAudit = true
break
}
}
assert.True(t, foundBackchannelAudit)
mockKratos.AssertExpectations(t)
}
func TestDeleteMySession_RevokesHydraClientBoundFromPasswordLoginAudit(t *testing.T) {
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
var hydraRevokeCalls int
var revokedClient string
client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
switch r.URL.Host {
case "kratos.test":
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "current-sid",
"authenticated_at": time.Now().UTC().Format(time.RFC3339),
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@example.com",
"name": "User",
"role": "user",
},
},
}), nil
}
case "hydra.test":
if r.Method == http.MethodDelete && r.URL.Path == "/oauth2/auth/sessions/consent" {
revokedClient = r.URL.Query().Get("client")
hydraRevokeCalls++
return httpResponse(r, http.StatusNoContent, ""), nil
}
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})}
setDefaultHTTPClientForTest(t, client.Transport)
mockKratos := new(MockKratosAdminService)
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
{ID: "target-sid", Active: true},
}, nil).Once()
mockKratos.On("GetSession", mock.Anything, "target-sid").Return(&service.KratosSession{
ID: "target-sid",
Active: true,
}, nil).Once()
mockKratos.On("DeleteSession", mock.Anything, "target-sid").Return(nil).Once()
auditRepo := &mockAuditRepo{}
h := &AuthHandler{
KratosAdmin: mockKratos,
AuditRepo: auditRepo,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
}
auditRepo.logs = append(auditRepo.logs, domain.AuditLog{
UserID: "user-123",
EventType: "POST /api/v1/auth/password/login",
SessionID: "target-sid",
Details: `{"client_id":"adminfront","client_name":"AdminFront","session_id":"target-sid"}`,
})
app := fiber.New()
app.Delete("/api/v1/user/sessions/:id", h.DeleteMySession)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/sessions/target-sid", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
req.Header.Set("User-Agent", "session-test-agent")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, 1, hydraRevokeCalls)
assert.Equal(t, "adminfront", revokedClient)
mockKratos.AssertExpectations(t)
}
func TestGetHydraProfile_RejectsInactiveLinkedSession(t *testing.T) {
client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Host == "hydra.test" && r.URL.Path == "/oauth2/introspect" {
body, _ := io.ReadAll(r.Body)
if string(body) != "token=opaque-token" {
t.Fatalf("unexpected introspect body: %s", string(body))
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"active": true,
"sub": "user-123",
"client_id": "devfront",
"ext": map[string]any{
"session_id": "target-sid",
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})}
mockKratos := new(MockKratosAdminService)
mockKratos.On("GetSession", mock.Anything, "target-sid").Return(&service.KratosSession{
ID: "target-sid",
Active: false,
Identity: &service.KratosIdentity{
ID: "user-123",
},
}, nil).Once()
h := &AuthHandler{
KratosAdmin: mockKratos,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
}
profile, err := h.getHydraProfile(context.Background(), "opaque-token")
assert.Nil(t, profile)
assert.Error(t, err)
assert.Contains(t, err.Error(), "inactive")
mockKratos.AssertExpectations(t)
}
func TestGetAuthTimeline_FillsSessionIDFromOathkeeperRaw(t *testing.T) {
now := time.Date(2026, 4, 7, 4, 39, 0, 0, time.UTC)
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "current-sid",
"authenticated_at": now.Format(time.RFC3339),
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@example.com",
"name": "User",
"role": "user",
},
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
}))
h := &AuthHandler{
AuditRepo: &mockAuditRepo{},
OathkeeperRepo: &mockOathkeeperRepo{
logs: []domain.OathkeeperAccessLog{
{
Timestamp: now,
RequestID: "req-1",
Method: http.MethodGet,
Path: "/api/v1/dev/sessions",
Status: http.StatusOK,
Subject: "user-123",
ClientIP: "203.0.113.7",
UserAgent: "Mozilla/5.0",
Raw: `{"request":{"url":"https://devfront.example.com/callback?client_id=devfront"},"extra":{"session_id":"target-sid"}}`,
},
},
},
}
app := fiber.New()
app.Get("/api/v1/audit/auth/timeline", h.GetAuthTimeline)
req := httptest.NewRequest(http.MethodGet, "/api/v1/audit/auth/timeline", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var body struct {
Items []struct {
SessionID string `json:"session_id"`
ClientID string `json:"client_id"`
AppName string `json:"app_name"`
Source string `json:"source"`
} `json:"items"`
}
err = json.NewDecoder(resp.Body).Decode(&body)
assert.NoError(t, err)
if assert.Len(t, body.Items, 1) {
assert.Equal(t, "target-sid", body.Items[0].SessionID)
assert.Equal(t, "devfront", body.Items[0].ClientID)
assert.Equal(t, "devfront", body.Items[0].AppName)
assert.Equal(t, "oathkeeper", body.Items[0].Source)
}
}
func TestGetAuthTimeline_IncludesHeadlessPasswordLogin(t *testing.T) {
now := time.Date(2026, 4, 7, 5, 10, 0, 0, time.UTC)
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "current-sid",
"authenticated_at": now.Format(time.RFC3339),
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@example.com",
"name": "User",
"role": "user",
},
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
}))
h := &AuthHandler{
AuditRepo: &mockAuditRepo{
logs: []domain.AuditLog{
{
EventID: "audit-1",
Timestamp: now,
UserID: "user-123",
SessionID: "headless-session-1",
EventType: "POST /api/v1/auth/headless/password/login",
Status: "success",
IPAddress: "203.0.113.20",
UserAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/146.0.0.0 Safari/537.36",
Details: `{"client_id":"headless-login-client","client_name":"Headless Login Portal","session_id":"headless-session-1","login_id":"user@example.com","login_challenge":"challenge-123"}`,
},
},
},
}
app := fiber.New()
app.Get("/api/v1/audit/auth/timeline", h.GetAuthTimeline)
req := httptest.NewRequest(http.MethodGet, "/api/v1/audit/auth/timeline", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var body struct {
Items []struct {
SessionID string `json:"session_id"`
ClientID string `json:"client_id"`
AppName string `json:"app_name"`
AuthMethod string `json:"auth_method"`
EventType string `json:"event_type"`
} `json:"items"`
}
err = json.NewDecoder(resp.Body).Decode(&body)
assert.NoError(t, err)
if assert.Len(t, body.Items, 1) {
assert.Equal(t, "headless-session-1", body.Items[0].SessionID)
assert.Equal(t, "headless-login-client", body.Items[0].ClientID)
assert.Equal(t, "Headless Login Portal", body.Items[0].AppName)
assert.Equal(t, "비밀번호(Email)", body.Items[0].AuthMethod)
assert.Equal(t, "POST /api/v1/auth/headless/password/login", body.Items[0].EventType)
}
}
func TestListMySessions_UsesHeadlessPasswordLoginForClientBinding(t *testing.T) {
now := time.Date(2026, 4, 7, 5, 35, 0, 0, time.UTC)
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "current-sid",
"authenticated_at": now.Format(time.RFC3339),
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@example.com",
"name": "User",
"role": "user",
},
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
}))
mockKratos := new(MockKratosAdminService)
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
{
ID: "current-sid",
Active: true,
AuthenticatedAt: now,
ExpiresAt: now.Add(24 * time.Hour),
},
{
ID: "headless-session-1",
Active: true,
AuthenticatedAt: now.Add(-10 * time.Minute),
ExpiresAt: now.Add(23*time.Hour + 50*time.Minute),
},
}, nil).Once()
auditRepo := &mockAuditRepo{
logs: []domain.AuditLog{
{
UserID: "user-123",
EventType: "POST /api/v1/auth/headless/password/login",
SessionID: "headless-session-1",
Timestamp: now,
IPAddress: "203.0.113.20",
UserAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/146.0.0.0 Safari/537.36",
Details: `{"client_id":"headless-login-client","client_name":"Headless Login Portal","session_id":"headless-session-1"}`,
},
},
}
h := &AuthHandler{
KratosAdmin: mockKratos,
AuditRepo: auditRepo,
}
app := fiber.New()
app.Get("/api/v1/user/sessions", h.ListMySessions)
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/sessions", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var body struct {
Items []struct {
SessionID string `json:"session_id"`
AppName string `json:"app_name"`
ClientID string `json:"client_id"`
IPAddress string `json:"ip_address"`
UserAgent string `json:"user_agent"`
} `json:"items"`
}
err = json.NewDecoder(resp.Body).Decode(&body)
assert.NoError(t, err)
if assert.Len(t, body.Items, 2) {
assert.Equal(t, "headless-session-1", body.Items[1].SessionID)
assert.Equal(t, "Headless Login Portal", body.Items[1].AppName)
assert.Equal(t, "headless-login-client", body.Items[1].ClientID)
assert.Equal(t, "203.0.113.20", body.Items[1].IPAddress)
assert.Equal(t, "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/146.0.0.0 Safari/537.36", body.Items[1].UserAgent)
}
mockKratos.AssertExpectations(t)
}

View File

@@ -0,0 +1,144 @@
package handler
import (
"baron-sso-backend/internal/domain"
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
// --- Local Mocks for Signup Test ---
type MockRedisForSignup struct {
mock.Mock
}
func (m *MockRedisForSignup) Set(key string, value string, ttl time.Duration) error {
return m.Called(key, value, ttl).Error(0)
}
func (m *MockRedisForSignup) Get(key string) (string, error) {
args := m.Called(key)
return args.String(0), args.Error(1)
}
func (m *MockRedisForSignup) Delete(key string) error {
return m.Called(key).Error(0)
}
func (m *MockRedisForSignup) StoreVerificationCode(phone, code string) error { return nil }
func (m *MockRedisForSignup) GetVerificationCode(phone string) (string, error) { return "", nil }
func (m *MockRedisForSignup) DeleteVerificationCode(phone string) error { return nil }
func (m *MockRedisForSignup) Ping(ctx context.Context) error { return nil }
type MockIdpForSignup struct {
mock.Mock
}
func (m *MockIdpForSignup) Name() string { return "mock-idp" }
func (m *MockIdpForSignup) GetMetadata() (*domain.IDPMetadata, error) {
return &domain.IDPMetadata{SupportedFields: []string{"email", "name", "phoneNumber", "grade", "department"}}, nil
}
func (m *MockIdpForSignup) CreateUser(user *domain.BrokerUser, password string) (string, error) {
args := m.Called(user, password)
return args.String(0), args.Error(1)
}
func (m *MockIdpForSignup) SignIn(loginID, password string) (*domain.AuthInfo, error) {
return nil, nil
}
func (m *MockIdpForSignup) UserExists(loginID string) (bool, error) { return false, nil }
func (m *MockIdpForSignup) IssueSession(loginID string) (*domain.AuthInfo, error) { return nil, nil }
func (m *MockIdpForSignup) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
return nil, nil
}
func (m *MockIdpForSignup) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
return nil, nil
}
func (m *MockIdpForSignup) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
return &domain.PasswordPolicy{MinLength: 12}, nil
}
func (m *MockIdpForSignup) InitiatePasswordReset(loginID, redirectUrl string) error { return nil }
func (m *MockIdpForSignup) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
return nil, nil
}
func (m *MockIdpForSignup) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
return nil
}
func TestSignup_TenantSlugValidation(t *testing.T) {
app := fiber.New()
mockTenantSvc := new(MockTenantService)
mockRedis := new(MockRedisForSignup)
mockIdp := new(MockIdpForSignup)
h := &AuthHandler{
TenantService: mockTenantSvc,
RedisService: mockRedis,
IdpProvider: mockIdp,
}
app.Post("/signup", h.Signup)
// Prepare mock state (already verified email/phone)
verifiedState, _ := json.Marshal(map[string]any{
"verified": true,
"expires_at": time.Now().Add(time.Hour).Unix(),
})
mockRedis.On("Get", mock.Anything).Return(string(verifiedState), nil)
t.Run("Rejects legacy CompanyCode", func(t *testing.T) {
reqBody := domain.SignupRequest{
Email: "user@gmail.com",
Password: "StrongPass123!",
Name: "Test User",
Phone: "010-1234-5678",
TermsAccepted: true,
CompanyCode: "new-slug",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/signup", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
})
t.Run("Active Tenant Slug", func(t *testing.T) {
reqBody := domain.SignupRequest{
Email: "user@hanmaceng.co.kr",
Password: "StrongPass123!",
Name: "Test User",
Phone: "010-1234-5678",
TermsAccepted: true,
TenantSlug: "hanmac",
}
body, _ := json.Marshal(reqBody)
validTenant := &domain.Tenant{ID: "t1", Slug: "hanmac", Status: domain.TenantStatusActive}
mockTenantSvc.On("GetTenantByDomain", mock.Anything, "hanmaceng.co.kr").Return(&domain.Tenant{Slug: "hanmac"}, nil).Once()
mockTenantSvc.On("ProvisionTenantByDomain", mock.Anything, "hanmaceng.co.kr").Return(validTenant, nil).Maybe()
mockTenantSvc.On("GetTenantBySlug", mock.Anything, "hanmac").Return(validTenant, nil).Once()
mockTenantSvc.On("GetTenant", mock.Anything, "t1").Return(validTenant, nil).Once()
mockIdp.On("CreateUser", mock.Anything, mock.Anything).Return("user-id", nil).Once()
mockRedis.On("Delete", mock.Anything).Return(nil)
req := httptest.NewRequest("POST", "/signup", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req)
assert.Equal(t, http.StatusOK, resp.StatusCode)
})
}

View File

@@ -0,0 +1,521 @@
package handler
import (
"baron-sso-backend/internal/middleware"
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gofiber/fiber/v2"
)
// helper to build a Fiber app with the handler route mounted.
func newTestApp(h *AuthHandler) *fiber.App {
app := fiber.New()
app.Post("/api/v1/auth/password/reset/complete", h.CompletePasswordReset)
return app
}
func newResetFlowTestApp(h *AuthHandler) *fiber.App {
app := fiber.New()
app.Post("/api/v1/auth/password/reset/verify", h.ProcessPasswordResetToken)
app.Post("/api/v1/auth/password/reset/complete", h.CompletePasswordReset)
return app
}
func newResetInitAppWithErrorCodeEnricher(h *AuthHandler) *fiber.App {
app := fiber.New()
app.Use(middleware.ErrorCodeEnricher())
app.Post("/api/v1/auth/password/reset/init", h.InitiatePasswordReset)
return app
}
type testRedisRepo struct {
values map[string]string
}
func (m *testRedisRepo) Set(key string, value string, expiration time.Duration) error {
if m.values == nil {
m.values = map[string]string{}
}
m.values[key] = value
return nil
}
func (m *testRedisRepo) Get(key string) (string, error) {
if m.values == nil {
return "", nil
}
return m.values[key], nil
}
func (m *testRedisRepo) Delete(key string) error {
if m.values != nil {
delete(m.values, key)
}
return nil
}
func (m *testRedisRepo) StoreVerificationCode(phone, code string) error {
return m.Set("sms:"+phone, code, time.Minute)
}
func (m *testRedisRepo) GetVerificationCode(phone string) (string, error) {
return m.Get("sms:" + phone)
}
func (m *testRedisRepo) DeleteVerificationCode(phone string) error {
return m.Delete("sms:" + phone)
}
func TestCompletePasswordReset_MissingLoginID(t *testing.T) {
h := &AuthHandler{}
app := newTestApp(h)
body, _ := json.Marshal(map[string]string{
"newPassword": "Password1!",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/password/reset/complete", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for missing loginId, got %d", resp.StatusCode)
}
var got map[string]string
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if got["error"] != "Login ID and new password are required" {
t.Fatalf("unexpected error message: %v", got["error"])
}
}
func TestCompletePasswordReset_InvalidPasswordPolicy(t *testing.T) {
h := &AuthHandler{}
app := newTestApp(h)
body, _ := json.Marshal(map[string]string{
"newPassword": "short", // too short + missing complexity
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/password/reset/complete?loginId=user@example.com", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for weak password, got %d", resp.StatusCode)
}
var got map[string]string
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if got["error"] != "비밀번호는 최소 12자 이상이어야 합니다" {
t.Fatalf("unexpected error message: %v", got["error"])
}
}
func TestCompletePasswordReset_NilIDPProvider(t *testing.T) {
h := &AuthHandler{} // IdpProvider intentionally nil to hit the configuration error branch
app := newTestApp(h)
body, _ := json.Marshal(map[string]string{
"newPassword": "StrongPass1!",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/password/reset/complete?loginId=user@example.com", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusInternalServerError {
t.Fatalf("expected 500 when IDP provider is nil, got %d", resp.StatusCode)
}
var got map[string]string
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if got["error"] != "Authentication service not configured" {
t.Fatalf("unexpected error message: %v", got["error"])
}
}
func TestCompletePasswordReset_TokenValueOverridesLoginIDQuery(t *testing.T) {
const resetToken = "tok-reset-1"
const tokenLoginID = "user@example.com"
const wrongLoginID = "wrong@example.com"
const newPassword = "StrongPass1!"
redis := &testRedisRepo{
values: map[string]string{
prefixPwdResetToken + resetToken: tokenLoginID,
},
}
idp := &mockIdpProvider{
userExists: true,
err: nil,
}
h := &AuthHandler{
RedisService: redis,
IdpProvider: idp,
}
app := newResetFlowTestApp(h)
body, _ := json.Marshal(map[string]string{
"newPassword": newPassword,
})
url := fmt.Sprintf(
"/api/v1/auth/password/reset/complete?loginId=%s&token=%s",
wrongLoginID,
resetToken,
)
req := httptest.NewRequest(http.MethodPost, url, bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
if !idp.updateCalled {
t.Fatal("expected UpdateUserPassword to be called")
}
if idp.updatedLoginID != tokenLoginID {
t.Fatalf("expected loginId from token(%s), got %s", tokenLoginID, idp.updatedLoginID)
}
if idp.updatedPassword != newPassword {
t.Fatalf("expected newPassword propagated, got %s", idp.updatedPassword)
}
}
func TestCompletePasswordReset_InvalidTokenRejectedEvenWhenLoginIDExists(t *testing.T) {
const resetToken = "invalid-token"
redis := &testRedisRepo{
values: map[string]string{},
}
idp := &mockIdpProvider{
userExists: true,
err: nil,
}
h := &AuthHandler{
RedisService: redis,
IdpProvider: idp,
}
app := newResetFlowTestApp(h)
body, _ := json.Marshal(map[string]string{
"newPassword": "StrongPass1!",
})
req := httptest.NewRequest(
http.MethodPost,
"/api/v1/auth/password/reset/complete?loginId=user@example.com&token="+resetToken,
bytes.NewReader(body),
)
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("expected 401 for invalid token, got %d", resp.StatusCode)
}
if idp.updateCalled {
t.Fatal("UpdateUserPassword must not be called when token is invalid")
}
}
func TestCompletePasswordReset_DuplicateTokenSubmitIsIdempotent(t *testing.T) {
const resetToken = "dup-token"
const loginID = "user@example.com"
const newPassword = "StrongPass1!"
redis := &testRedisRepo{
values: map[string]string{
prefixPwdResetToken + resetToken: loginID,
},
}
idp := &mockIdpProvider{
userExists: true,
err: nil,
}
h := &AuthHandler{
RedisService: redis,
IdpProvider: idp,
}
app := newResetFlowTestApp(h)
body, _ := json.Marshal(map[string]string{
"newPassword": newPassword,
})
url := fmt.Sprintf(
"/api/v1/auth/password/reset/complete?token=%s",
resetToken,
)
firstReq := httptest.NewRequest(http.MethodPost, url, bytes.NewReader(body))
firstReq.Header.Set("Content-Type", "application/json")
firstResp, err := app.Test(firstReq)
if err != nil {
t.Fatalf("first request failed: %v", err)
}
defer firstResp.Body.Close()
if firstResp.StatusCode != http.StatusOK {
t.Fatalf("expected first response to be 200, got %d", firstResp.StatusCode)
}
if idp.updateCallCount != 1 {
t.Fatalf("expected first request to update password once, got %d", idp.updateCallCount)
}
secondReq := httptest.NewRequest(http.MethodPost, url, bytes.NewReader(body))
secondReq.Header.Set("Content-Type", "application/json")
secondResp, err := app.Test(secondReq)
if err != nil {
t.Fatalf("second request failed: %v", err)
}
defer secondResp.Body.Close()
if secondResp.StatusCode != http.StatusOK {
t.Fatalf("expected duplicate response to be 200, got %d", secondResp.StatusCode)
}
if idp.updateCallCount != 1 {
t.Fatalf("expected duplicate request not to update password again, got %d", idp.updateCallCount)
}
}
func TestProcessPasswordResetToken_EncodesLoginIDInRedirect(t *testing.T) {
const token = "tok-enc"
const loginID = "user+alias@example.com"
t.Setenv("USERFRONT_URL", "https://sss.hmac.kr")
redis := &testRedisRepo{
values: map[string]string{
prefixPwdResetToken + token: loginID,
},
}
h := &AuthHandler{
RedisService: redis,
}
app := newResetFlowTestApp(h)
req := httptest.NewRequest(
http.MethodPost,
"/api/v1/auth/password/reset/verify?token="+token,
nil,
)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusFound {
t.Fatalf("expected 302, got %d", resp.StatusCode)
}
location := resp.Header.Get("Location")
if location == "" {
t.Fatal("missing redirect location")
}
redirectReq := httptest.NewRequest(http.MethodGet, location, nil)
gotLoginID := redirectReq.URL.Query().Get("loginId")
if gotLoginID != loginID {
t.Fatalf("expected encoded loginId round-trip=%s, got %s (location=%s)", loginID, gotLoginID, location)
}
}
func TestPasswordResetVerifyAlias_AcceptsShortVePath(t *testing.T) {
const token = "tok-ve"
const loginID = "user@example.com"
redis := &testRedisRepo{
values: map[string]string{
prefixPwdResetToken + token: loginID,
},
}
h := &AuthHandler{
RedisService: redis,
}
app := fiber.New()
app.Get("/api/v1/auth/password/reset/ve", h.VerifyPasswordResetPage)
app.Post("/api/v1/auth/password/reset/ve", h.ProcessPasswordResetToken)
getReq := httptest.NewRequest(
http.MethodGet,
"/api/v1/auth/password/reset/ve?token="+token,
nil,
)
getResp, err := app.Test(getReq)
if err != nil {
t.Fatalf("get request failed: %v", err)
}
defer getResp.Body.Close()
if getResp.StatusCode != http.StatusOK {
t.Fatalf("expected alias GET to return 200, got %d", getResp.StatusCode)
}
postReq := httptest.NewRequest(
http.MethodPost,
"/api/v1/auth/password/reset/ve?token="+token,
nil,
)
postResp, err := app.Test(postReq)
if err != nil {
t.Fatalf("post request failed: %v", err)
}
defer postResp.Body.Close()
if postResp.StatusCode != http.StatusFound {
t.Fatalf("expected alias POST to return 302, got %d", postResp.StatusCode)
}
}
func TestPasswordResetVerifyPathToken_AcceptsShortVPath(t *testing.T) {
const token = "tok-path"
const loginID = "user@example.com"
redis := &testRedisRepo{
values: map[string]string{
prefixPwdResetToken + token: loginID,
},
}
h := &AuthHandler{
RedisService: redis,
}
app := fiber.New()
app.Get("/api/v1/auth/password/reset/v/:token", h.VerifyPasswordResetPage)
app.Post("/api/v1/auth/password/reset/v/:token", h.ProcessPasswordResetToken)
getReq := httptest.NewRequest(
http.MethodGet,
"/api/v1/auth/password/reset/v/"+token,
nil,
)
getResp, err := app.Test(getReq)
if err != nil {
t.Fatalf("get request failed: %v", err)
}
defer getResp.Body.Close()
if getResp.StatusCode != http.StatusOK {
t.Fatalf("expected path-token GET to return 200, got %d", getResp.StatusCode)
}
postReq := httptest.NewRequest(
http.MethodPost,
"/api/v1/auth/password/reset/v/"+token,
nil,
)
postResp, err := app.Test(postReq)
if err != nil {
t.Fatalf("post request failed: %v", err)
}
defer postResp.Body.Close()
if postResp.StatusCode != http.StatusFound {
t.Fatalf("expected path-token POST to return 302, got %d", postResp.StatusCode)
}
}
func TestPasswordResetInit_LegacyErrorResponseHasCodeViaMiddleware(t *testing.T) {
h := &AuthHandler{}
app := newResetInitAppWithErrorCodeEnricher(h)
body, _ := json.Marshal(map[string]string{
"loginId": "",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/password/reset/init", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.StatusCode)
}
var got map[string]any
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if got["error"] != "Login ID is required" {
t.Fatalf("unexpected error message: %v", got["error"])
}
if got["code"] != "bad_request" {
t.Fatalf("expected code=bad_request, got %v", got["code"])
}
}
func TestInitiatePasswordReset_SmsContainsVerifyLink(t *testing.T) {
t.Setenv("USERFRONT_URL", "https://sss.hmac.kr")
redis := &testRedisRepo{values: map[string]string{}}
smsSvc := &mockSmsService{}
h := &AuthHandler{
RedisService: redis,
IdpProvider: &mockIdpProvider{},
SmsService: smsSvc,
}
app := fiber.New()
app.Post("/api/v1/auth/password/reset/init", h.InitiatePasswordReset)
body, _ := json.Marshal(map[string]string{
"loginId": "01012345678",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/password/reset/init", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
if !strings.Contains(smsSvc.lastContent, "/api/v1/auth/password/reset/v/") {
t.Fatalf("expected SMS to contain short path verify link, got %q", smsSvc.lastContent)
}
if strings.Contains(smsSvc.lastContent, "/reset-password?token=") {
t.Fatalf("expected direct reset-password link to be removed, got %q", smsSvc.lastContent)
}
}

View File

@@ -0,0 +1,537 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/response"
"baron-sso-backend/internal/service"
"encoding/json"
"errors"
"sort"
"strings"
"github.com/gofiber/fiber/v2"
)
const (
clientTenantAccessRestrictedKey = "tenant_access_restricted"
clientAllowedTenantsKey = "allowed_tenants"
)
func normalizeClientTenantAccessMetadata(metadata map[string]any) (map[string]any, error) {
if metadata == nil {
metadata = map[string]any{}
}
restricted := readMetadataBoolValue(metadata, clientTenantAccessRestrictedKey)
allowedTenants := normalizeMetadataStringSlice(metadata[clientAllowedTenantsKey])
ownerTenantID := normalizeMetadataString(metadata["tenant_id"])
if len(allowedTenants) > 0 {
restricted = true
}
if !restricted {
delete(metadata, clientAllowedTenantsKey)
metadata[clientTenantAccessRestrictedKey] = false
return metadata, nil
}
if ownerTenantID != "" {
allowedTenants = append(allowedTenants, ownerTenantID)
}
allowedTenants = uniqueSortedStrings(allowedTenants)
if len(allowedTenants) == 0 {
return nil, errors.New("allowed_tenants is required when tenant_access_restricted is enabled")
}
metadata[clientTenantAccessRestrictedKey] = true
metadata[clientAllowedTenantsKey] = allowedTenants
return metadata, nil
}
func clientTenantAccessRestricted(metadata map[string]any) bool {
if metadata == nil {
return false
}
if readMetadataBoolValue(metadata, clientTenantAccessRestrictedKey) {
return true
}
return len(normalizeMetadataStringSlice(metadata[clientAllowedTenantsKey])) > 0
}
func clientAllowedTenants(metadata map[string]any) []string {
if metadata == nil {
return nil
}
if !clientTenantAccessRestricted(metadata) {
return nil
}
return uniqueSortedStrings(normalizeMetadataStringSlice(metadata[clientAllowedTenantsKey]))
}
func normalizeMetadataStringSlice(raw any) []string {
switch value := raw.(type) {
case []string:
return uniqueSortedStrings(value)
case []any:
items := make([]string, 0, len(value))
for _, item := range value {
if s, ok := item.(string); ok {
items = append(items, s)
}
}
return uniqueSortedStrings(items)
default:
return nil
}
}
func normalizeMetadataString(raw any) string {
s, ok := raw.(string)
if !ok {
return ""
}
return strings.TrimSpace(s)
}
func uniqueSortedStrings(values []string) []string {
if len(values) == 0 {
return nil
}
seen := make(map[string]struct{}, len(values))
out := make([]string, 0, len(values))
for _, value := range values {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
continue
}
if _, ok := seen[trimmed]; ok {
continue
}
seen[trimmed] = struct{}{}
out = append(out, trimmed)
}
sort.Strings(out)
return out
}
func clientTenantAccessAllowed(profile *domain.UserProfileResponse, client domain.HydraClient) bool {
if !clientTenantAccessRestricted(client.Metadata) {
return true
}
allowed := clientAllowedTenants(client.Metadata)
if len(allowed) == 0 {
return false
}
keys := manageableTenantKeysFromProfile(profile)
if len(keys) == 0 {
return false
}
for _, tenantID := range allowed {
if _, ok := keys[strings.ToLower(strings.TrimSpace(tenantID))]; ok {
return true
}
}
return false
}
func clientTenantAccessAllowedForSubtree(c *fiber.Ctx, tenantSvc service.TenantService, profile *domain.UserProfileResponse, client domain.HydraClient) bool {
if clientTenantAccessAllowed(profile, client) {
return true
}
if tenantSvc == nil || profile == nil {
return false
}
allowedTenants := make([]domain.Tenant, 0)
for _, identifier := range clientAllowedTenants(client.Metadata) {
if tenant, ok := resolveTenantAccessTenant(c, tenantSvc, domain.Tenant{ID: identifier, Slug: identifier}); ok {
allowedTenants = append(allowedTenants, tenant)
}
}
if len(allowedTenants) == 0 {
return false
}
for _, candidate := range tenantAccessProfileTenants(profile) {
resolvedCandidate, ok := resolveTenantAccessTenant(c, tenantSvc, candidate)
if !ok {
continue
}
for _, allowed := range allowedTenants {
if tenantMatchesOrDescendsFrom(c, tenantSvc, resolvedCandidate, allowed) {
return true
}
}
}
return false
}
func tenantAccessProfileTenants(profile *domain.UserProfileResponse) []domain.Tenant {
if profile == nil {
return nil
}
seen := make(map[string]struct{})
tenants := make([]domain.Tenant, 0, len(profile.ManageableTenants)+len(profile.JoinedTenants)+2)
add := func(tenant domain.Tenant) {
key := strings.ToLower(firstNonEmptyString(tenant.ID, tenant.Slug, tenant.Name))
if key == "" {
return
}
if _, ok := seen[key]; ok {
return
}
seen[key] = struct{}{}
tenants = append(tenants, tenant)
}
if profile.Tenant != nil {
add(*profile.Tenant)
}
if profile.TenantID != nil {
add(domain.Tenant{ID: strings.TrimSpace(*profile.TenantID)})
}
for _, tenant := range profile.ManageableTenants {
add(tenant)
}
for _, tenant := range profile.JoinedTenants {
add(tenant)
}
return tenants
}
func resolveTenantAccessTenant(c *fiber.Ctx, tenantSvc service.TenantService, tenant domain.Tenant) (domain.Tenant, bool) {
if tenantSvc == nil {
return tenant, firstNonEmptyString(tenant.ID, tenant.Slug) != ""
}
if strings.TrimSpace(tenant.ID) != "" {
if resolved, err := tenantSvc.GetTenant(c.Context(), strings.TrimSpace(tenant.ID)); err == nil && resolved != nil {
return *resolved, true
}
}
if strings.TrimSpace(tenant.Slug) != "" {
if resolved, err := tenantSvc.GetTenantBySlug(c.Context(), strings.TrimSpace(tenant.Slug)); err == nil && resolved != nil {
return *resolved, true
}
}
return tenant, firstNonEmptyString(tenant.ID, tenant.Slug) != ""
}
func tenantMatchesOrDescendsFrom(c *fiber.Ctx, tenantSvc service.TenantService, tenant domain.Tenant, ancestor domain.Tenant) bool {
if tenantAccessTenantMatches(tenant, ancestor) {
return true
}
if tenantSvc == nil {
return false
}
visited := make(map[string]struct{})
current := tenant
for current.ParentID != nil && strings.TrimSpace(*current.ParentID) != "" {
parentID := strings.TrimSpace(*current.ParentID)
if _, ok := visited[parentID]; ok {
return false
}
visited[parentID] = struct{}{}
parent, err := tenantSvc.GetTenant(c.Context(), parentID)
if err != nil || parent == nil {
return false
}
if tenantAccessTenantMatches(*parent, ancestor) {
return true
}
current = *parent
}
return false
}
func tenantAccessTenantMatches(left, right domain.Tenant) bool {
leftID := strings.ToLower(strings.TrimSpace(left.ID))
rightID := strings.ToLower(strings.TrimSpace(right.ID))
if leftID != "" && rightID != "" && leftID == rightID {
return true
}
leftSlug := strings.ToLower(strings.TrimSpace(left.Slug))
rightSlug := strings.ToLower(strings.TrimSpace(right.Slug))
return leftSlug != "" && rightSlug != "" && leftSlug == rightSlug
}
type tenantAccessDeniedDetails struct {
Account tenantAccessDeniedAccount `json:"account"`
CurrentTenant tenantAccessDeniedTenant `json:"current_tenant"`
AffiliatedTenants []tenantAccessDeniedTenant `json:"affiliated_tenants,omitempty"`
AllowedTenants []tenantAccessDeniedTenant `json:"allowed_tenants,omitempty"`
}
type tenantAccessDeniedAccount struct {
Email string `json:"email,omitempty"`
}
type tenantAccessDeniedTenant struct {
ID string `json:"id,omitempty"`
Slug string `json:"slug,omitempty"`
Name string `json:"name,omitempty"`
Identifier string `json:"identifier,omitempty"`
}
func tenantNotAllowedError(c *fiber.Ctx, details tenantAccessDeniedDetails) error {
return response.ErrorWithDetails(
c,
fiber.StatusForbidden,
"tenant_not_allowed",
"허용되지 않은 테넌트입니다.",
details,
)
}
func isClientTenantAccessAllowed(profile *domain.UserProfileResponse, client domain.HydraClient) bool {
if profile == nil {
return false
}
return clientTenantAccessAllowed(profile, client)
}
func enforceClientTenantAccess(c *fiber.Ctx, tenantSvc service.TenantService, client domain.HydraClient, profile *domain.UserProfileResponse, resolveErr error) bool {
if !clientTenantAccessRestricted(client.Metadata) {
return false
}
details := buildTenantAccessDeniedDetails(c, tenantSvc, client, profile)
if resolveErr != nil || profile == nil {
_ = tenantNotAllowedError(c, details)
return true
}
if !clientTenantAccessAllowedForSubtree(c, tenantSvc, profile, client) {
_ = tenantNotAllowedError(c, details)
return true
}
return false
}
func buildTenantAccessDeniedDetails(c *fiber.Ctx, tenantSvc service.TenantService, client domain.HydraClient, profile *domain.UserProfileResponse) tenantAccessDeniedDetails {
details := tenantAccessDeniedDetails{
Account: tenantAccessDeniedAccount{Email: strings.TrimSpace(profileEmail(profile))},
CurrentTenant: resolveCurrentTenantDetails(c, tenantSvc, profile),
AffiliatedTenants: resolveAffiliatedTenantDetails(c, tenantSvc, profile),
}
for _, identifier := range clientAllowedTenants(client.Metadata) {
details.AllowedTenants = append(details.AllowedTenants, resolveAllowedTenantDetails(c, tenantSvc, identifier))
}
return details
}
func resolveAffiliatedTenantDetails(c *fiber.Ctx, tenantSvc service.TenantService, profile *domain.UserProfileResponse) []tenantAccessDeniedTenant {
if profile == nil {
return nil
}
seen := make(map[string]struct{})
out := make([]tenantAccessDeniedTenant, 0, len(profile.JoinedTenants)+1)
appendTenant := func(tenant tenantAccessDeniedTenant) {
key := strings.ToLower(firstNonEmptyString(tenant.ID, tenant.Slug, tenant.Identifier, tenant.Name))
if key == "" {
return
}
if _, ok := seen[key]; ok {
return
}
seen[key] = struct{}{}
out = append(out, tenant)
}
appendTenant(resolveCurrentTenantDetails(c, tenantSvc, profile))
for _, joined := range profile.JoinedTenants {
appendTenant(tenantAccessDeniedTenant{
ID: strings.TrimSpace(joined.ID),
Slug: strings.TrimSpace(joined.Slug),
Name: strings.TrimSpace(joined.Name),
Identifier: firstNonEmptyString(strings.TrimSpace(joined.Slug), strings.TrimSpace(joined.ID)),
})
}
return out
}
func resolveCurrentTenantDetails(c *fiber.Ctx, tenantSvc service.TenantService, profile *domain.UserProfileResponse) tenantAccessDeniedTenant {
if profile == nil {
return tenantAccessDeniedTenant{}
}
if profile.Tenant != nil {
return tenantAccessDeniedTenant{
ID: strings.TrimSpace(profile.Tenant.ID),
Slug: strings.TrimSpace(profile.Tenant.Slug),
Name: strings.TrimSpace(profile.Tenant.Name),
Identifier: firstNonEmptyString(strings.TrimSpace(profile.Tenant.Slug), strings.TrimSpace(profile.Tenant.ID)),
}
}
if tenantSvc != nil {
if profile.TenantID != nil && strings.TrimSpace(*profile.TenantID) != "" {
if tenant, err := tenantSvc.GetTenant(c.Context(), strings.TrimSpace(*profile.TenantID)); err == nil && tenant != nil {
return tenantAccessDeniedTenant{
ID: strings.TrimSpace(tenant.ID),
Slug: strings.TrimSpace(tenant.Slug),
Name: strings.TrimSpace(tenant.Name),
Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID)),
}
}
}
}
return tenantAccessDeniedTenant{
ID: strings.TrimSpace(pointerValue(profile.TenantID)),
Identifier: strings.TrimSpace(pointerValue(profile.TenantID)),
}
}
func resolveAllowedTenantDetails(c *fiber.Ctx, tenantSvc service.TenantService, identifier string) tenantAccessDeniedTenant {
identifier = strings.TrimSpace(identifier)
if identifier == "" {
return tenantAccessDeniedTenant{}
}
if tenantSvc != nil {
if tenant, err := tenantSvc.GetTenant(c.Context(), identifier); err == nil && tenant != nil {
return tenantAccessDeniedTenant{
ID: strings.TrimSpace(tenant.ID),
Slug: strings.TrimSpace(tenant.Slug),
Name: strings.TrimSpace(tenant.Name),
Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID), identifier),
}
}
if tenant, err := tenantSvc.GetTenantBySlug(c.Context(), identifier); err == nil && tenant != nil {
return tenantAccessDeniedTenant{
ID: strings.TrimSpace(tenant.ID),
Slug: strings.TrimSpace(tenant.Slug),
Name: strings.TrimSpace(tenant.Name),
Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID), identifier),
}
}
}
return tenantAccessDeniedTenant{Identifier: identifier}
}
func profileEmail(profile *domain.UserProfileResponse) string {
if profile == nil {
return ""
}
return profile.Email
}
func pointerValue(value *string) string {
if value == nil {
return ""
}
return *value
}
func firstNonEmptyString(values ...string) string {
for _, value := range values {
if strings.TrimSpace(value) != "" {
return strings.TrimSpace(value)
}
}
return ""
}
type clientStructuredScope struct {
Name string `json:"name"`
Mandatory bool `json:"mandatory"`
Locked bool `json:"locked"`
}
func mergeRequestedScopesWithClientRequirements(client domain.HydraClient, requested []string) []string {
combined := make([]string, 0, len(requested)+2)
combined = append(combined, requested...)
combined = append(combined, requiredClientScopes(client)...)
return normalizeScopesInConsentOrder(combined)
}
func normalizeScopesInConsentOrder(scopes []string) []string {
combined := make([]string, 0, len(scopes))
combined = append(combined, scopes...)
seen := make(map[string]struct{}, len(combined))
out := make([]string, 0, len(combined))
appendIfPresent := func(scope string) {
scope = strings.TrimSpace(scope)
if scope == "" {
return
}
if _, ok := seen[scope]; ok {
return
}
for _, candidate := range combined {
if strings.TrimSpace(candidate) != scope {
continue
}
seen[scope] = struct{}{}
out = append(out, scope)
return
}
}
appendIfPresent("openid")
appendIfPresent("tenant")
for _, scope := range combined {
scope = strings.TrimSpace(scope)
if scope == "" {
continue
}
if _, ok := seen[scope]; ok {
continue
}
seen[scope] = struct{}{}
out = append(out, scope)
}
return out
}
func requiredClientScopes(client domain.HydraClient) []string {
required := make([]string, 0, 4)
if clientTenantAccessRestricted(client.Metadata) {
required = append(required, "tenant")
}
if client.Metadata == nil {
return normalizeScopesInConsentOrder(required)
}
rawStructuredScopes, ok := client.Metadata["structured_scopes"]
if !ok || rawStructuredScopes == nil {
return normalizeScopesInConsentOrder(required)
}
rawBytes, err := json.Marshal(rawStructuredScopes)
if err != nil {
return normalizeScopesInConsentOrder(required)
}
var scopes []clientStructuredScope
if err := json.Unmarshal(rawBytes, &scopes); err != nil {
return normalizeScopesInConsentOrder(required)
}
for _, scope := range scopes {
name := strings.TrimSpace(scope.Name)
if name == "" {
continue
}
if scope.Mandatory || scope.Locked {
required = append(required, name)
}
}
return normalizeScopesInConsentOrder(required)
}

View File

@@ -0,0 +1,458 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func TestCreateClient_NormalizesTenantAccessMetadata(t *testing.T) {
var captured domain.HydraClient
ownerTenantID := "tenant-owner"
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.Method == http.MethodPost && r.URL.Path == "/clients" {
body, err := io.ReadAll(r.Body)
assert.NoError(t, err)
assert.NoError(t, json.Unmarshal(body, &captured))
return httpJSONAny(r, http.StatusCreated, map[string]any{
"client_id": captured.ClientID,
"client_name": captured.ClientName,
"redirect_uris": captured.RedirectURIs,
"grant_types": captured.GrantTypes,
"response_types": captured.ResponseTypes,
"scope": captured.Scope,
"token_endpoint_auth_method": captured.TokenEndpointAuthMethod,
"metadata": captured.Metadata,
}), nil
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
})
h := &DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: transport},
},
Keto: new(devMockKetoService),
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{
ID: "user-1",
Role: domain.RoleSuperAdmin,
TenantID: &ownerTenantID,
})
return c.Next()
})
app.Post("/api/v1/dev/clients", h.CreateClient)
body, _ := json.Marshal(map[string]any{
"id": "client-tenant",
"name": "Tenant Client",
"type": "pkce",
"redirectUris": []string{"https://rp.example.com/cb"},
"metadata": map[string]any{
"tenant_access_restricted": true,
"allowed_tenants": []string{"tenant-b", "tenant-a", "tenant-b"},
},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/dev/clients", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusCreated, resp.StatusCode)
assert.True(t, clientTenantAccessRestricted(captured.Metadata))
assert.Equal(t, []string{"tenant-a", "tenant-b", "tenant-owner"}, clientAllowedTenants(captured.Metadata))
}
func TestCreateClient_RejectsTenantAccessWithoutAllowedTenants(t *testing.T) {
hydraCalled := false
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.Method == http.MethodPost && r.URL.Path == "/clients" {
hydraCalled = true
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
})
h := &DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: transport},
},
Keto: new(devMockKetoService),
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "user-1", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Post("/api/v1/dev/clients", h.CreateClient)
body, _ := json.Marshal(map[string]any{
"id": "client-tenant",
"name": "Tenant Client",
"type": "pkce",
"redirectUris": []string{"https://rp.example.com/cb"},
"metadata": map[string]any{
"tenant_access_restricted": true,
},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/dev/clients", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
assert.False(t, hydraCalled)
}
func TestMergeRequestedScopesWithClientRequirements_AddsTenantScope(t *testing.T) {
client := domain.HydraClient{
Metadata: map[string]any{
"tenant_access_restricted": true,
"structured_scopes": []map[string]any{
{"name": "openid", "mandatory": true},
{"name": "tenant", "mandatory": true, "locked": true},
{"name": "profile", "mandatory": false},
},
},
}
merged := mergeRequestedScopesWithClientRequirements(client, []string{"openid", "profile"})
assert.Equal(t, []string{"openid", "tenant", "profile"}, merged)
}
func TestGetConsentRequest_DeniesTenantAccess(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
switch {
case r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-tenant":
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-tenant",
"requested_scope": []string{"openid", "profile"},
"skip": false,
"subject": "user-123",
"client": map[string]any{
"client_id": "client-tenant",
"metadata": map[string]any{
"tenant_access_restricted": true,
"allowed_tenants": []string{"tenant-b"},
},
},
}), nil
case r.URL.Host == "kratos.test" && r.URL.Path == "/sessions/whoami":
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@test.com",
"tenant_id": "tenant-a",
},
},
}), nil
default:
return httpJSONAny(r, http.StatusNotFound, nil), nil
}
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() { http.DefaultClient = origDefault }()
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
}
app := fiber.New()
app.Get("/api/v1/auth/consent", h.GetConsentRequest)
t.Setenv("APP_ENV", "dev")
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-tenant", nil)
req.Header.Set("X-Mock-Role", "user")
req.Header.Set("X-Tenant-ID", "tenant-a")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
var body map[string]any
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
assert.Equal(t, "tenant_not_allowed", body["code"])
details, ok := body["details"].(map[string]any)
assert.True(t, ok)
account, ok := details["account"].(map[string]any)
assert.True(t, ok)
assert.NotEmpty(t, account["email"])
currentTenant, ok := details["current_tenant"].(map[string]any)
assert.True(t, ok)
assert.NotEmpty(t, currentTenant["identifier"])
}
func TestGetConsentRequest_DeniesRestrictedClientWhenProfileResolutionFails(t *testing.T) {
acceptCalled := false
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
switch {
case r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-profile-missing":
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-profile-missing",
"requested_scope": []string{"openid", "profile"},
"skip": false,
"subject": "user-123",
"client": map[string]any{
"client_id": "client-tenant",
"metadata": map[string]any{
"tenant_access_restricted": true,
"allowed_tenants": []string{"tenant-b"},
},
},
}), nil
case r.URL.Path == "/oauth2/auth/requests/consent/accept":
acceptCalled = true
return httpJSONAny(r, http.StatusOK, map[string]any{
"redirect_to": "http://rp/cb",
}), nil
default:
return httpJSONAny(r, http.StatusNotFound, nil), nil
}
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() { http.DefaultClient = origDefault }()
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: func() service.KratosAdminService {
mockKratos := new(MockKratosAdminService)
mockKratos.On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
ID: "user-123",
Traits: map[string]any{
"email": "user@test.com",
"tenant_id": "tenant-a",
"companyCode": "tenant-a",
},
}, nil).Once()
return mockKratos
}(),
TenantService: func() service.TenantService {
tenantSvc := new(MockTenantService)
tenantSvc.On("GetTenant", mock.Anything, "tenant-a").Return(&domain.Tenant{
ID: "tenant-a",
Slug: "tenant-a",
Name: "Tenant A",
}, nil)
tenantSvc.On("GetTenant", mock.Anything, "tenant-c").Return(&domain.Tenant{
ID: "tenant-c",
Slug: "tenant-c",
Name: "Tenant C",
}, nil)
tenantSvc.On("ListJoinedTenants", mock.Anything, "user-123").Return([]domain.Tenant{
{ID: "tenant-a", Slug: "tenant-a", Name: "Tenant A"},
{ID: "tenant-c", Slug: "tenant-c", Name: "Tenant C"},
}, nil).Once()
tenantSvc.On("GetTenant", mock.Anything, "tenant-b").Return(nil, assert.AnError)
tenantSvc.On("GetTenantBySlug", mock.Anything, "tenant-b").Return(&domain.Tenant{
ID: "tenant-b-id",
Slug: "tenant-b",
Name: "Tenant B",
}, nil)
return tenantSvc
}(),
ConsentRepo: &mockConsentRepo{
consents: []domain.ClientConsent{
{
ClientID: "client-tenant",
Subject: "user-123",
GrantedScopes: []string{"openid", "profile"},
},
},
},
}
app := fiber.New()
app.Get("/api/v1/auth/consent", h.GetConsentRequest)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-profile-missing", nil)
req.Header.Set("Cookie", "ory_kratos_session=invalid-session")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
assert.False(t, acceptCalled)
var body map[string]any
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
assert.Equal(t, "tenant_not_allowed", body["code"])
details, ok := body["details"].(map[string]any)
assert.True(t, ok)
account, ok := details["account"].(map[string]any)
assert.True(t, ok)
assert.Equal(t, "user@test.com", account["email"])
currentTenant, ok := details["current_tenant"].(map[string]any)
assert.True(t, ok)
assert.Equal(t, "Tenant A", currentTenant["name"])
affiliatedTenants, ok := details["affiliated_tenants"].([]any)
assert.True(t, ok)
assert.Len(t, affiliatedTenants, 2)
}
func TestAcceptOidcLoginRequest_DeniesTenantAccess(t *testing.T) {
app := fiber.New()
app.Get("/deny", func(c *fiber.Ctx) error {
tenantID := "tenant-a"
profile := &domain.UserProfileResponse{
ID: "user-123",
Role: domain.RoleUser,
Email: "user@test.com",
TenantID: &tenantID,
CompanyCode: "tenant-a",
JoinedTenants: []domain.Tenant{
{ID: "tenant-a", Slug: "tenant-a", Name: "Tenant A"},
{ID: "tenant-c", Slug: "tenant-c", Name: "Tenant C"},
},
}
client := domain.HydraClient{
ClientID: "client-tenant",
Metadata: map[string]any{
"tenant_access_restricted": true,
"allowed_tenants": []string{"tenant-b"},
},
}
tenantSvc := new(MockTenantService)
tenantSvc.On("GetTenant", mock.Anything, "tenant-a").Return(&domain.Tenant{
ID: "tenant-a",
Slug: "tenant-a",
Name: "Tenant A",
}, nil)
tenantSvc.On("GetTenant", mock.Anything, "tenant-c").Return(&domain.Tenant{
ID: "tenant-c",
Slug: "tenant-c",
Name: "Tenant C",
}, nil)
tenantSvc.On("GetTenant", mock.Anything, "tenant-b").Return(nil, assert.AnError)
tenantSvc.On("GetTenantBySlug", mock.Anything, "tenant-b").Return(&domain.Tenant{
ID: "tenant-b-id",
Slug: "tenant-b",
Name: "Tenant B",
}, nil)
enforceClientTenantAccess(c, tenantSvc, client, profile, nil)
return nil
})
req := httptest.NewRequest(http.MethodGet, "/deny", nil)
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
var body map[string]any
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
assert.Equal(t, "tenant_not_allowed", body["code"])
details, ok := body["details"].(map[string]any)
assert.True(t, ok)
account, ok := details["account"].(map[string]any)
assert.True(t, ok)
assert.Equal(t, "user@test.com", account["email"])
currentTenant, ok := details["current_tenant"].(map[string]any)
assert.True(t, ok)
assert.Equal(t, "Tenant A", currentTenant["name"])
affiliatedTenants, ok := details["affiliated_tenants"].([]any)
assert.True(t, ok)
assert.Len(t, affiliatedTenants, 2)
allowedTenants, ok := details["allowed_tenants"].([]any)
assert.True(t, ok)
assert.Len(t, allowedTenants, 1)
allowedTenant, ok := allowedTenants[0].(map[string]any)
assert.True(t, ok)
assert.Equal(t, "Tenant B", allowedTenant["name"])
}
func TestAcceptOidcLoginRequest_AllowsRestrictedClientForHanmacFamilyDescendant(t *testing.T) {
app := fiber.New()
app.Get("/allow-descendant", func(c *fiber.Ctx) error {
hanmacFamilyID := "hanmac-family-id"
samanID := "saman-id"
profile := &domain.UserProfileResponse{
ID: "user-123",
Role: domain.RoleUser,
Email: "user@samaneng.com",
TenantID: &samanID,
Tenant: &domain.Tenant{
ID: samanID,
Slug: "saman",
Name: "삼안",
ParentID: &hanmacFamilyID,
},
JoinedTenants: []domain.Tenant{
{
ID: samanID,
Slug: "saman",
Name: "삼안",
ParentID: &hanmacFamilyID,
},
},
}
client := domain.HydraClient{
ClientID: "orgfront",
Metadata: map[string]any{
"tenant_access_restricted": true,
"allowed_tenants": []string{"hanmac-family"},
},
}
tenantSvc := new(MockTenantService)
tenantSvc.On("GetTenant", mock.Anything, "hanmac-family").Return(nil, assert.AnError).Maybe()
tenantSvc.On("GetTenantBySlug", mock.Anything, "hanmac-family").Return(&domain.Tenant{
ID: hanmacFamilyID,
Slug: "hanmac-family",
Name: "한맥가족",
}, nil).Maybe()
tenantSvc.On("GetTenant", mock.Anything, samanID).Return(&domain.Tenant{
ID: samanID,
Slug: "saman",
Name: "삼안",
ParentID: &hanmacFamilyID,
}, nil).Maybe()
tenantSvc.On("GetTenant", mock.Anything, hanmacFamilyID).Return(&domain.Tenant{
ID: hanmacFamilyID,
Slug: "hanmac-family",
Name: "한맥가족",
}, nil).Maybe()
blocked := enforceClientTenantAccess(c, tenantSvc, client, profile, nil)
assert.False(t, blocked)
return c.SendStatus(http.StatusNoContent)
})
req := httptest.NewRequest(http.MethodGet, "/allow-descendant", nil)
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusNoContent, resp.StatusCode)
}

View File

@@ -0,0 +1,303 @@
package handler
import (
"baron-sso-backend/internal/domain"
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"slices"
"time"
)
// --- Mock IDP Provider ---
type mockIdpProvider struct {
userExists bool
name string
signInInfo *domain.AuthInfo
issueSession *domain.AuthInfo
verifyCodeInfo *domain.AuthInfo
err error
initiateLinkErr error
updateCalled bool
updateCallCount int
updatedLoginID string
updatedPassword string
}
func (m *mockIdpProvider) Name() string {
if m.name != "" {
return m.name
}
return "mock-idp"
}
func (m *mockIdpProvider) GetMetadata() (*domain.IDPMetadata, error) { return nil, m.err }
func (m *mockIdpProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
return "mock-user-id", m.err
}
func (m *mockIdpProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
return m.signInInfo, m.err
}
func (m *mockIdpProvider) UserExists(loginID string) (bool, error) { return m.userExists, m.err }
func (m *mockIdpProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
if m.issueSession != nil {
return m.issueSession, m.err
}
return &domain.AuthInfo{
SessionToken: &domain.Token{JWT: "valid-jwt", SessionID: "valid-sid"},
}, m.err
}
func (m *mockIdpProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
if m.initiateLinkErr != nil {
return nil, m.initiateLinkErr
}
return &domain.LinkLoginInit{FlowID: "mock-flow-id", Mode: "code"}, m.err
}
func (m *mockIdpProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
return m.verifyCodeInfo, m.err
}
func (m *mockIdpProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) { return nil, m.err }
func (m *mockIdpProvider) InitiatePasswordReset(loginID, redirectUrl string) error { return m.err }
func (m *mockIdpProvider) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
return nil, m.err
}
func (m *mockIdpProvider) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
m.updateCalled = true
m.updateCallCount++
m.updatedLoginID = loginID
m.updatedPassword = newPassword
return m.err
}
// --- Mock Audit Repository ---
type mockAuditRepo struct {
logs []domain.AuditLog
}
func (m *mockAuditRepo) Create(log *domain.AuditLog) error {
m.logs = append(m.logs, *log)
return nil
}
func (m *mockAuditRepo) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor, tenantID string) ([]domain.AuditLog, error) {
return m.logs, nil
}
func (m *mockAuditRepo) FindByUserAndEvents(ctx context.Context, userID string, eventTypes []string, limit int) ([]domain.AuditLog, error) {
var results []domain.AuditLog
for _, log := range m.logs {
if log.UserID == userID {
if slices.Contains(eventTypes, log.EventType) {
results = append(results, log)
}
}
}
return results, nil
}
func (m *mockAuditRepo) CountFailuresSince(ctx context.Context, since time.Time, tenantID string) (int64, error) {
return 0, nil
}
func (m *mockAuditRepo) CountEventsSince(ctx context.Context, since time.Time) (int64, error) {
return 0, nil
}
func (m *mockAuditRepo) CountActiveSessionsSince(ctx context.Context, since time.Time, tenantID string) (int64, error) {
return 0, nil
}
func (m *mockAuditRepo) Ping(ctx context.Context) error { return nil }
type mockRPUsageEventSink struct {
events []domain.RPUsageEvent
err error
}
func (m *mockRPUsageEventSink) EmitRPUsageEvent(ctx context.Context, event domain.RPUsageEvent) error {
if m.err != nil {
return m.err
}
m.events = append(m.events, event)
return nil
}
type mockOathkeeperRepo struct {
logs []domain.OathkeeperAccessLog
}
func (m *mockOathkeeperRepo) FindPageBySubject(ctx context.Context, subject string, limit int, cursor *domain.AuditCursor) ([]domain.OathkeeperAccessLog, error) {
if subject == "" {
return m.logs, nil
}
results := make([]domain.OathkeeperAccessLog, 0, len(m.logs))
for _, log := range m.logs {
if log.Subject == subject {
results = append(results, log)
}
}
return results, nil
}
func (m *mockOathkeeperRepo) Ping(ctx context.Context) error { return nil }
// --- Mock Consent Repository ---
type mockConsentRepo struct {
consents []domain.ClientConsent
}
func (m *mockConsentRepo) Upsert(ctx context.Context, consent *domain.ClientConsent) error {
m.consents = append(m.consents, *consent)
return nil
}
func (m *mockConsentRepo) ListBySubject(ctx context.Context, subject string) ([]domain.ClientConsent, error) {
var results []domain.ClientConsent
for _, c := range m.consents {
if c.Subject == subject {
results = append(results, c)
}
}
return results, nil
}
func (m *mockConsentRepo) ListSubjectsByClient(ctx context.Context, clientID string) ([]string, error) {
seen := map[string]struct{}{}
subjects := make([]string, 0, len(m.consents))
for _, consent := range m.consents {
if consent.ClientID != clientID {
continue
}
if _, ok := seen[consent.Subject]; ok {
continue
}
seen[consent.Subject] = struct{}{}
subjects = append(subjects, consent.Subject)
}
return subjects, nil
}
func (m *mockConsentRepo) Find(ctx context.Context, clientID, subject string) (*domain.ClientConsent, error) {
for _, consent := range m.consents {
if consent.ClientID == clientID && consent.Subject == subject {
found := consent
return &found, nil
}
}
return nil, nil
}
func (m *mockConsentRepo) Delete(ctx context.Context, subject, clientID string) error {
filtered := m.consents[:0]
for _, consent := range m.consents {
if consent.Subject == subject && (clientID == "" || consent.ClientID == clientID) {
continue
}
filtered = append(filtered, consent)
}
m.consents = filtered
return nil
}
func (m *mockConsentRepo) DeleteByClient(ctx context.Context, clientID string) error {
filtered := m.consents[:0]
for _, consent := range m.consents {
if consent.ClientID != clientID {
filtered = append(filtered, consent)
}
}
m.consents = filtered
return nil
}
func (m *mockConsentRepo) List(ctx context.Context, clientID string, limit, offset int) ([]domain.ClientConsentWithTenantInfo, int64, error) {
results := make([]domain.ClientConsentWithTenantInfo, 0, len(m.consents))
for _, consent := range m.consents {
if consent.ClientID == clientID {
results = append(results, domain.ClientConsentWithTenantInfo{ClientConsent: consent})
}
}
return results, int64(len(results)), nil
}
func (m *mockConsentRepo) ListByTenant(ctx context.Context, clientID, tenantID string, limit, offset int) ([]domain.ClientConsentWithTenantInfo, int64, error) {
results := make([]domain.ClientConsentWithTenantInfo, 0, len(m.consents))
for _, consent := range m.consents {
if consent.ClientID == clientID {
results = append(results, domain.ClientConsentWithTenantInfo{
ClientConsent: consent,
TenantID: tenantID,
})
}
}
return results, int64(len(results)), nil
}
// --- Mock Secret Repository ---
type mockSecretRepo struct {
secrets map[string]string
}
func (m *mockSecretRepo) Upsert(ctx context.Context, clientID, secret string) error {
if m.secrets == nil {
m.secrets = make(map[string]string)
}
m.secrets[clientID] = secret
return nil
}
func (m *mockSecretRepo) GetByID(ctx context.Context, clientID string) (string, error) {
return m.secrets[clientID], nil
}
func (m *mockSecretRepo) Delete(ctx context.Context, clientID string) error {
delete(m.secrets, clientID)
return nil
}
// --- HTTP Mock Helpers ---
type roundTripFunc func(req *http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func setDefaultHTTPClientForTest(t interface{ Cleanup(func()) }, transport http.RoundTripper) {
origDefault := http.DefaultClient
http.DefaultClient = &http.Client{Transport: transport}
t.Cleanup(func() {
http.DefaultClient = origDefault
})
}
func httpResponse(r *http.Request, code int, body string) *http.Response {
return &http.Response{
StatusCode: code,
Header: make(http.Header),
Body: io.NopCloser(bytes.NewBufferString(body)),
Request: r,
}
}
func httpJSONAny(r *http.Request, code int, data any) *http.Response {
body, _ := json.Marshal(data)
return &http.Response{
StatusCode: code,
Header: http.Header{
"Content-Type": []string{"application/json"},
},
Body: io.NopCloser(bytes.NewBuffer(body)),
Request: r,
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,242 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func TestDevHandler_Isolation(t *testing.T) {
createHandler := func(mockKeto *devMockKetoService) *DevHandler {
return &DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.Method == http.MethodGet && r.URL.Path == "/clients" {
return httpJSONAny(r, http.StatusOK, []map[string]any{
{
"client_id": "client-tenant-a",
"client_name": "App Tenant A",
"token_endpoint_auth_method": "none", // PKCE
"metadata": map[string]any{"tenant_id": "tenant-a"},
},
{
"client_id": "client-tenant-b",
"client_name": "App Tenant B",
"token_endpoint_auth_method": "none", // PKCE
"metadata": map[string]any{"tenant_id": "tenant-b"},
},
}), nil
}
if (r.Method == http.MethodGet || r.Method == http.MethodPut) && strings.HasPrefix(r.URL.Path, "/clients/") {
id := strings.TrimPrefix(r.URL.Path, "/clients/")
tenantID := "tenant-a"
if id == "client-tenant-b" {
tenantID = "tenant-b"
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": id,
"client_name": "App " + id,
"token_endpoint_auth_method": "none",
"metadata": map[string]any{"tenant_id": tenantID},
}), nil
}
if r.Method == http.MethodPost && r.URL.Path == "/clients" {
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
return httpJSONAny(r, http.StatusCreated, body), nil
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
}),
},
},
Keto: mockKeto,
}
}
t.Run("Local bypass should be removed", func(t *testing.T) {
mockKeto := new(devMockKetoService)
h := createHandler(mockKeto)
app := fiber.New()
app.Get("/api/v1/dev/clients", h.ListClients)
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients", nil)
req.Header.Set("Origin", "http://localhost:5174")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
})
t.Run("ListClients should show all for SuperAdmin", func(t *testing.T) {
mockKeto := new(devMockKetoService)
h := createHandler(mockKeto)
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{
ID: "super-user",
Role: domain.RoleSuperAdmin,
})
return c.Next()
})
app.Get("/api/v1/dev/clients", h.ListClients)
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients", nil)
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var res struct {
Items []clientSummary `json:"items"`
}
json.NewDecoder(resp.Body).Decode(&res)
// Should see both clients
assert.Equal(t, 2, len(res.Items))
})
t.Run("ListClients should filter by permit for non-SuperAdmin", func(t *testing.T) {
mockKeto := new(devMockKetoService)
h := createHandler(mockKeto)
app := fiber.New()
tenantA := "tenant-a"
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{
ID: "user-a",
Role: domain.RoleUser,
TenantID: &tenantA,
})
return c.Next()
})
app.Get("/api/v1/dev/clients", h.ListClients)
// Explicit permission for private client check bypass
mockKeto.On("CheckPermission", mock.Anything, "User:user-a", "System", "global", "manage_all").Return(true, nil).Maybe()
// Mock permit for the specific client
mockKeto.On("CheckPermission", mock.Anything, "User:user-a", "RelyingParty", "client-tenant-a", "view").Return(true, nil).Maybe()
// Deny for other clients
mockKeto.On("CheckPermission", mock.Anything, "User:user-a", "RelyingParty", "client-tenant-b", "view").Return(false, nil).Maybe()
mockKeto.On("ListRelations", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]service.RelationTuple{}, nil).Maybe()
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients", nil)
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var res struct {
Items []clientSummary `json:"items"`
}
json.NewDecoder(resp.Body).Decode(&res)
// Should only see client-tenant-a (tenant permit)
assert.Equal(t, 1, len(res.Items))
assert.Equal(t, "client-tenant-a", res.Items[0].ID)
})
t.Run("Tenant member should see empty list from DevFront clients if no relation", func(t *testing.T) {
mockKeto := new(devMockKetoService)
h := createHandler(mockKeto)
app := fiber.New()
tenantA := "tenant-a"
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{
ID: "user-member",
Role: domain.RoleUser,
TenantID: &tenantA,
})
return c.Next()
})
app.Get("/api/v1/dev/clients", h.ListClients)
// Deny all by default
mockKeto.On("CheckPermission", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(false, nil).Maybe()
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients", nil)
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var res struct {
Items []clientSummary `json:"items"`
}
json.NewDecoder(resp.Body).Decode(&res)
// Empty list because we didn't mock any specific 'view' permissions for this user
assert.Equal(t, 0, len(res.Items))
})
t.Run("GetClient should enforce isolation for non-SuperAdmin", func(t *testing.T) {
mockKeto := new(devMockKetoService)
h := createHandler(mockKeto)
app := fiber.New()
tenantA := "tenant-a"
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{
ID: "user-a",
Role: domain.RoleUser,
TenantID: &tenantA,
})
return c.Next()
})
app.Get("/api/v1/dev/clients/:id", h.GetClient)
// Case 1: Same tenant BUT no permit (Normal users need permit now)
mockKeto.On("CheckPermission", mock.Anything, "User:user-a", "RelyingParty", "client-tenant-a", "view").Return(false, nil).Once()
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-tenant-a", nil)
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
// Case 2: Same tenant WITH permit
mockKeto.On("CheckPermission", mock.Anything, "User:user-a", "RelyingParty", "client-tenant-a", "view").Return(true, nil).Maybe()
mockKeto.On("ListRelations", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]service.RelationTuple{}, nil).Maybe()
req = httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-tenant-a", nil)
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
// Case 3: Different tenant
mockKeto.On("CheckPermission", mock.Anything, "User:user-a", "RelyingParty", "client-tenant-b", "view").Return(false, nil).Maybe()
req = httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-tenant-b", nil)
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
})
t.Run("CreateClient should record user_id and tenant_id", func(t *testing.T) {
mockKeto := new(devMockKetoService)
h := createHandler(mockKeto)
app := fiber.New()
tenantA := "tenant-a"
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{
ID: "user-a",
Role: domain.RoleSuperAdmin, // Bypass for creation permission
TenantID: &tenantA,
})
return c.Next()
})
app.Post("/api/v1/dev/clients", h.CreateClient)
body, _ := json.Marshal(map[string]any{
"client_name": "New App",
"type": "pkce",
"redirectUris": []string{"http://localhost/cb"},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/dev/clients", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Tenant-ID", "tenant-a")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusCreated, resp.StatusCode)
var res clientDetailResponse
json.NewDecoder(resp.Body).Decode(&res)
assert.Equal(t, "tenant-a", res.Client.Metadata["tenant_id"])
assert.Equal(t, "user-a", res.Client.Metadata["user_id"])
})
}

View File

@@ -0,0 +1,278 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
type devMockRPUserMetadataRepo struct {
mock.Mock
}
func (m *devMockRPUserMetadataRepo) Get(ctx context.Context, clientID, userID string) (*domain.RPUserMetadata, error) {
args := m.Called(ctx, clientID, userID)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.RPUserMetadata), args.Error(1)
}
func (m *devMockRPUserMetadataRepo) Upsert(ctx context.Context, metadata *domain.RPUserMetadata) error {
args := m.Called(ctx, metadata)
return args.Error(0)
}
func TestDevHandler_RPUserMetadataRoundTrip(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/clients/client-1" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "client-1",
"client_name": "Client One",
"metadata": map[string]any{
"tenant_id": "tenant-1",
"id_token_claims": []map[string]any{
{
"namespace": "rp_claims",
"key": "approvalLevel",
"valueType": "text",
"value": "A",
},
},
},
}), nil
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
})
repo := new(devMockRPUserMetadataRepo)
repo.On("Upsert", mock.Anything, mock.MatchedBy(func(row *domain.RPUserMetadata) bool {
return row.ClientID == "client-1" &&
row.UserID == "user-1" &&
row.Metadata["approvalLevel"] == "A" &&
row.Metadata["approvalLevel_permissions"].(map[string]any)["readPermission"] == "admin_only" &&
row.Metadata["approvalLevel_permissions"].(map[string]any)["writePermission"] == "user_and_admin"
})).Return(nil).Once()
repo.On("Get", mock.Anything, "client-1", "user-1").Return(&domain.RPUserMetadata{
ClientID: "client-1",
UserID: "user-1",
Metadata: domain.JSONMap{"approvalLevel": "A"},
}, nil).Once()
h := &DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: transport},
},
RPUserMetadataRepo: repo,
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "admin", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Put("/api/v1/dev/clients/:id/users/:userId/metadata", h.UpsertRPUserMetadata)
app.Get("/api/v1/dev/clients/:id/users/:userId/metadata", h.GetRPUserMetadata)
body, _ := json.Marshal(map[string]any{
"metadata": map[string]any{
"approvalLevel": "A",
"approvalLevel_permissions": map[string]any{
"writePermission": "user_and_admin",
},
},
})
putReq := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-1/users/user-1/metadata", bytes.NewReader(body))
putReq.Header.Set("Content-Type", "application/json")
putResp, _ := app.Test(putReq, -1)
assert.Equal(t, http.StatusOK, putResp.StatusCode)
getReq := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-1/users/user-1/metadata", nil)
getResp, _ := app.Test(getReq, -1)
assert.Equal(t, http.StatusOK, getResp.StatusCode)
var got map[string]any
assert.NoError(t, json.NewDecoder(getResp.Body).Decode(&got))
assert.Equal(t, "client-1", got["clientId"])
assert.Equal(t, "user-1", got["userId"])
assert.Equal(t, "A", got["metadata"].(map[string]any)["approvalLevel"])
repo.AssertExpectations(t)
}
func TestDevHandler_RPUserMetadataMirrorsToKratosTraits(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/clients/client-1" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "client-1",
"client_name": "Client One",
"metadata": map[string]any{
"tenant_id": "tenant-1",
"id_token_claims": []map[string]any{
{
"namespace": "rp_claims",
"key": "approvalLevel",
"valueType": "text",
"value": "A",
"readPermission": "user_and_admin",
"writePermission": "admin_only",
},
},
},
}), nil
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
})
repo := new(devMockRPUserMetadataRepo)
repo.On("Upsert", mock.Anything, mock.AnythingOfType("*domain.RPUserMetadata")).Return(nil).Once()
kratos := new(MockKratosAdmin)
kratos.On("GetIdentity", mock.Anything, "user-1").Return(&service.KratosIdentity{
ID: "user-1",
State: "active",
Traits: map[string]any{
"email": "user@example.com",
"name": "User One",
},
}, nil).Once()
var capturedTraits map[string]any
kratos.On("UpdateIdentity", mock.Anything, "user-1", mock.Anything, "active").Run(func(args mock.Arguments) {
capturedTraits = args.Get(2).(map[string]any)
}).Return(&service.KratosIdentity{ID: "user-1", State: "active", Traits: map[string]any{}}, nil).Once()
h := &DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: transport},
},
KratosAdmin: kratos,
RPUserMetadataRepo: repo,
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "admin", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Put("/api/v1/dev/clients/:id/users/:userId/metadata", h.UpsertRPUserMetadata)
body, _ := json.Marshal(map[string]any{
"metadata": map[string]any{"approvalLevel": "B"},
})
req := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-1/users/user-1/metadata", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
require.Equal(t, http.StatusOK, resp.StatusCode)
rpClaims := capturedTraits["rp_custom_claims"].(map[string]any)
clientClaims := rpClaims["client-1"].(domain.JSONMap)
require.Equal(t, "B", clientClaims["approvalLevel"])
require.Equal(t, map[string]any{
"readPermission": "user_and_admin",
"writePermission": "admin_only",
}, clientClaims["approvalLevel_permissions"])
repo.AssertExpectations(t)
kratos.AssertExpectations(t)
}
func TestDevHandler_RPUserMetadataRejectsUndefinedClaimKey(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/clients/client-1" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "client-1",
"client_name": "Client One",
"metadata": map[string]any{
"id_token_claims": []map[string]any{
{
"namespace": "rp_claims",
"key": "contract_date",
"valueType": "date",
"value": "2026-06-09",
},
},
},
}), nil
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
})
repo := new(devMockRPUserMetadataRepo)
h := &DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: transport},
},
RPUserMetadataRepo: repo,
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "admin", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Put("/api/v1/dev/clients/:id/users/:userId/metadata", h.UpsertRPUserMetadata)
body, _ := json.Marshal(map[string]any{
"metadata": map[string]any{"unknown_claim": "A"},
})
req := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-1/users/user-1/metadata", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
repo.AssertNotCalled(t, "Upsert", mock.Anything, mock.Anything)
}
func TestDevHandler_RPUserMetadataRejectsInvalidTypedClaimValue(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/clients/client-1" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "client-1",
"client_name": "Client One",
"metadata": map[string]any{
"id_token_claims": []map[string]any{
{
"namespace": "rp_claims",
"key": "contract_date",
"valueType": "date",
"value": "2026-06-09",
},
},
},
}), nil
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
})
repo := new(devMockRPUserMetadataRepo)
h := &DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: transport},
},
RPUserMetadataRepo: repo,
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "admin", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Put("/api/v1/dev/clients/:id/users/:userId/metadata", h.UpsertRPUserMetadata)
body, _ := json.Marshal(map[string]any{
"metadata": map[string]any{"contract_date": "2026/06/09"},
})
req := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-1/users/user-1/metadata", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
repo.AssertNotCalled(t, "Upsert", mock.Anything, mock.Anything)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,17 @@
package handler
import (
"baron-sso-backend/internal/response"
"github.com/gofiber/fiber/v2"
)
// errorJSON은 기존 error 필드를 유지하면서 기계 판독용 code를 명시적으로 추가합니다.
func errorJSON(c *fiber.Ctx, status int, message string) error {
return response.Error(c, status, response.StatusCode(status), message)
}
// errorJSONCode는 상태코드 기반 매핑만으로 부족한 경우 명시 코드를 강제할 때 사용합니다.
func errorJSONCode(c *fiber.Ctx, status int, code, message string) error {
return response.Error(c, status, code, message)
}

View File

@@ -0,0 +1,161 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/repository"
"baron-sso-backend/internal/service"
"errors"
"github.com/gofiber/fiber/v2"
"gorm.io/gorm"
)
// FederationHandler handles API requests for IdP federation.
type FederationHandler struct {
fedSvc *service.FederationService
repo repository.FederationRepository // For IdP Config CRUD
db *gorm.DB // For tenant existence checks, etc. in CRUD
}
// NewFederationHandler creates a new FederationHandler.
func NewFederationHandler(fedSvc *service.FederationService, repo repository.FederationRepository, db *gorm.DB) *FederationHandler {
return &FederationHandler{
fedSvc: fedSvc,
repo: repo,
db: db,
}
}
// InitiateOIDCLogin handles the start of the OIDC login flow.
// It expects `provider_id` and `login_challenge` as query parameters.
func (h *FederationHandler) InitiateOIDCLogin(c *fiber.Ctx) error {
providerID := c.Query("provider_id")
loginChallenge := c.Query("login_challenge")
if providerID == "" || loginChallenge == "" {
return errorJSON(c, fiber.StatusBadRequest, "provider_id and login_challenge are required")
}
redirectURL, err := h.fedSvc.InitiateOIDCLogin(c.Context(), providerID, loginChallenge)
if err != nil {
// Log the error properly in a real application
return errorJSON(c, fiber.StatusInternalServerError, "failed to initiate OIDC login")
}
return c.Redirect(redirectURL, fiber.StatusFound)
}
// HandleOIDCCallback handles the OIDC callback from the IdP.
func (h *FederationHandler) HandleOIDCCallback(c *fiber.Ctx) error {
code := c.Query("code")
state := c.Query("state")
if code == "" || state == "" {
return errorJSON(c, fiber.StatusBadRequest, "code and state are required")
}
redirectURL, err := h.fedSvc.HandleOIDCCallback(c.Context(), code, state)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, "failed to handle OIDC callback")
}
return c.Redirect(redirectURL, fiber.StatusFound)
}
// --- New Client-based IdP Config Methods ---
// ListIdpConfigsForClient handles listing all IdP configurations for a client.
func (h *FederationHandler) ListIdpConfigsForClient(c *fiber.Ctx) error {
clientID := c.Params("clientId")
if clientID == "" {
return errorJSON(c, fiber.StatusBadRequest, "clientId is required")
}
var configs []domain.IdentityProviderConfig
if err := h.db.Where("client_id = ?", clientID).Find(&configs).Error; err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.JSON(configs)
}
// CreateIdpConfigForClient handles the creation of a new IdP configuration for a client.
func (h *FederationHandler) CreateIdpConfigForClient(c *fiber.Ctx) error {
clientID := c.Params("clientId")
if clientID == "" {
return errorJSON(c, fiber.StatusBadRequest, "clientId is required in path")
}
var req domain.IdentityProviderConfig
if err := c.BodyParser(&req); err != nil {
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
}
// Assign clientID from path parameter
req.ClientID = clientID
// Basic validation
if req.DisplayName == "" || req.ProviderType == "" {
return errorJSON(c, fiber.StatusBadRequest, "display_name and provider_type are required")
}
// TODO: Optionally, validate if the clientID exists in Hydra
// Create in DB
if err := h.db.Create(&req).Error; err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.Status(fiber.StatusCreated).JSON(req)
}
// --- Deprecated Tenant-based IdP Config Methods ---
// ListIdpConfigsForTenant handles listing all IdP configurations for a tenant.
func (h *FederationHandler) ListIdpConfigsForTenant(c *fiber.Ctx) error {
tenantID := c.Params("tenantId")
if tenantID == "" {
return errorJSON(c, fiber.StatusBadRequest, "tenantId is required")
}
// This is a temporary solution. We should create a proper method in the repository.
var configs []domain.IdentityProviderConfig
// Note: This now queries client_id, which is incorrect for tenants.
// This method is deprecated.
if err := h.db.Where("tenant_id = ?", tenantID).Find(&configs).Error; err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.JSON(configs)
}
// CreateIdpConfig handles the creation of a new IdP configuration.
func (h *FederationHandler) CreateIdpConfig(c *fiber.Ctx) error {
var req domain.IdentityProviderConfig
if err := c.BodyParser(&req); err != nil {
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
}
// Basic validation - This is the old validation logic
if req.ClientID == "" || req.DisplayName == "" || req.ProviderType == "" {
return errorJSON(c, fiber.StatusBadRequest, "client_id, display_name, and provider_type are required")
}
// This check is now incorrect and deprecated.
var tenant domain.Tenant
if err := h.db.First(&tenant, "id = ?", req.ClientID).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errorJSON(c, fiber.StatusBadRequest, "tenant not found")
}
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
// Create in DB
if err := h.db.Create(&req).Error; err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.Status(fiber.StatusCreated).JSON(req)
}
// TODO: Re-implement Update, Delete handlers for IdP Configs for Clients

View File

@@ -0,0 +1,243 @@
package handler
import (
"baron-sso-backend/internal/domain"
"context"
"fmt"
"slices"
"strings"
)
const hanmacFamilyTenantSlug = "hanmac-family"
type hanmacEmailScope struct {
TenantIDs map[string]bool
Slugs map[string]bool
IDList []string
SlugList []string
}
type hanmacEmailEvaluation struct {
Email string
OriginalEmail string
SuggestedEmail string
Status string
Warnings []string
Message string
Blocking bool
LocalPart string
}
func (h *UserHandler) evaluateHanmacImportEmail(ctx context.Context, item bulkUserItem, scope *hanmacEmailScope, usedLocalParts map[string]bool) hanmacEmailEvaluation {
originalEmail := strings.TrimSpace(item.Email)
name := strings.TrimSpace(item.Name)
evaluation := hanmacEmailEvaluation{
Email: originalEmail,
OriginalEmail: originalEmail,
Status: "valid",
}
localPart, domainPart, err := domain.SplitEmailDomain(originalEmail)
if err != nil {
evaluation.Status = "blockingError"
evaluation.Message = "invalid email format"
evaluation.Blocking = true
return evaluation
}
base, needsReview, _ := domain.BuildKoreanNameEmailBase(name)
if needsReview {
evaluation.Warnings = append(evaluation.Warnings, "needsReview")
evaluation.Status = "needsReview"
}
if localPart == "" {
if base == "" {
evaluation.Status = "blockingError"
evaluation.Message = "이름으로 이메일 ID를 제안할 수 없습니다."
evaluation.Blocking = true
return evaluation
}
nextLocalPart := nextAvailableHanmacLocalPart(base, usedLocalParts)
evaluation.Email = nextLocalPart + "@" + domainPart
evaluation.SuggestedEmail = evaluation.Email
evaluation.LocalPart = nextLocalPart
evaluation.Status = "suggested"
evaluation.Warnings = appendUniqueString(evaluation.Warnings, "suggested")
return evaluation
}
evaluation.LocalPart = localPart
if usedLocalParts[localPart] {
evaluation.Status = "blockingError"
evaluation.Message = "한맥가족 내에서 이미 사용 중인 이메일 ID입니다."
evaluation.Blocking = true
return evaluation
}
if base != "" && !domain.MatchesSuggestedNameRule(localPart, base) {
evaluation.Status = "ruleMismatch"
evaluation.Warnings = appendUniqueString(evaluation.Warnings, "ruleMismatch")
}
if evaluation.Status == "needsReview" && len(evaluation.Warnings) == 0 {
evaluation.Warnings = append(evaluation.Warnings, "needsReview")
}
_ = scope
return evaluation
}
func (h *UserHandler) ensureHanmacCreateEmailAllowed(ctx context.Context, email string, tenantSlug string, tenantID string) error {
scope, err := h.resolveHanmacEmailScope(ctx)
if err != nil || scope == nil || !scope.ContainsTenant(tenantID, tenantSlug) {
return nil
}
localPart, err := domain.ExtractNormalizedEmailLocalPart(email)
if err != nil {
return err
}
usedLocalParts, err := h.loadHanmacLocalParts(ctx, scope)
if err != nil {
return err
}
if usedLocalParts[localPart] {
return fmt.Errorf("한맥가족 내에서 이미 사용 중인 이메일 ID입니다.")
}
return nil
}
func (h *UserHandler) resolveHanmacEmailScope(ctx context.Context) (*hanmacEmailScope, error) {
if h.TenantService == nil {
return nil, nil
}
tenants, _, err := h.TenantService.ListTenants(ctx, 10000, 0, "", "")
if err != nil {
return nil, err
}
var rootID string
for _, tenant := range tenants {
if strings.EqualFold(strings.TrimSpace(tenant.Slug), hanmacFamilyTenantSlug) {
rootID = tenant.ID
break
}
}
if rootID == "" {
return nil, nil
}
tenantByID := make(map[string]domain.Tenant, len(tenants))
for _, tenant := range tenants {
tenantByID[tenant.ID] = tenant
}
scope := &hanmacEmailScope{
TenantIDs: make(map[string]bool),
Slugs: make(map[string]bool),
}
for _, tenant := range tenants {
if isTenantDescendantOf(tenant, rootID, tenantByID) {
scope.TenantIDs[tenant.ID] = true
scope.Slugs[strings.ToLower(strings.TrimSpace(tenant.Slug))] = true
scope.IDList = append(scope.IDList, tenant.ID)
scope.SlugList = append(scope.SlugList, tenant.Slug)
}
}
return scope, nil
}
func (h *UserHandler) loadHanmacLocalParts(ctx context.Context, scope *hanmacEmailScope) (map[string]bool, error) {
used := make(map[string]bool)
if h.UserRepo == nil || scope == nil {
return used, nil
}
if len(scope.IDList) > 0 {
users, err := h.UserRepo.FindByTenantIDs(ctx, scope.IDList)
if err != nil {
return nil, err
}
addUserEmailLocalParts(used, users)
}
if len(scope.SlugList) > 0 {
users, err := h.UserRepo.FindByCompanyCodes(ctx, scope.SlugList)
if err != nil {
return nil, err
}
addUserEmailLocalParts(used, users)
}
return used, nil
}
func (s *hanmacEmailScope) ContainsTenant(tenantID string, slug string) bool {
if s == nil {
return false
}
if tenantID != "" && s.TenantIDs[tenantID] {
return true
}
return s.Slugs[strings.ToLower(strings.TrimSpace(slug))]
}
func isTenantDescendantOf(tenant domain.Tenant, rootID string, tenantByID map[string]domain.Tenant) bool {
if tenant.ID == rootID {
return true
}
visited := make(map[string]bool)
parentID := ""
if tenant.ParentID != nil {
parentID = *tenant.ParentID
}
for parentID != "" {
if parentID == rootID {
return true
}
if visited[parentID] {
return false
}
visited[parentID] = true
parent, ok := tenantByID[parentID]
if !ok || parent.ParentID == nil {
return false
}
parentID = *parent.ParentID
}
return false
}
func addUserEmailLocalParts(target map[string]bool, users []domain.User) {
for _, user := range users {
localPart, err := domain.ExtractNormalizedEmailLocalPart(user.Email)
if err == nil && localPart != "" {
target[localPart] = true
}
}
}
func nextAvailableHanmacLocalPart(base string, usedLocalParts map[string]bool) string {
base = strings.ToLower(strings.TrimSpace(base))
if base == "" {
return ""
}
if !usedLocalParts[base] {
return base
}
for index := 1; ; index++ {
candidate := fmt.Sprintf("%s%d", base, index)
if !usedLocalParts[candidate] {
return candidate
}
}
}
func appendUniqueString(values []string, value string) []string {
if slices.Contains(values, value) {
return values
}
return append(values, value)
}

View File

@@ -0,0 +1,98 @@
package handler
import (
"baron-sso-backend/internal/domain"
"crypto/rand"
"encoding/binary"
"testing"
"unicode"
)
// 정책을 받아 필수 요구사항을 모두 포함하는 비밀번호를 생성한다.
func generatePasswordFromPolicy(policy *domain.PasswordPolicy) string {
minLen := policy.MinLength
if minLen < 8 {
minLen = 12 // 안전한 기본값
}
pwd := make([]rune, 0, minLen)
if policy.Lowercase {
pwd = append(pwd, 'a')
}
if policy.Uppercase {
pwd = append(pwd, 'B')
}
if policy.Number {
pwd = append(pwd, '3')
}
if policy.NonAlphanumeric {
pwd = append(pwd, '!')
}
charset := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*"
for len(pwd) < minLen {
pwd = append(pwd, rune(charset[randomInt(len(charset))]))
}
// 섞어서 예측 가능성을 낮춘다.
for i := range pwd {
j := randomInt(len(pwd))
pwd[i], pwd[j] = pwd[j], pwd[i]
}
return string(pwd)
}
func randomInt(n int) int {
if n <= 0 {
return 0
}
var b [8]byte
if _, err := rand.Read(b[:]); err != nil {
return 0
}
return int(binary.BigEndian.Uint64(b[:]) % uint64(n))
}
func TestGeneratePasswordUsesNonAlphanumericRequirement(t *testing.T) {
policy := &domain.PasswordPolicy{
MinLength: 8,
Lowercase: true,
Uppercase: true,
Number: true,
NonAlphanumeric: true,
}
pwd := generatePasswordFromPolicy(policy)
if len(pwd) < policy.MinLength {
t.Fatalf("비밀번호 길이가 정책 최소 길이 미만: got %d, want >= %d", len(pwd), policy.MinLength)
}
var hasLower, hasUpper, hasNumber, hasSymbol bool
for _, r := range pwd {
switch {
case unicode.IsLower(r):
hasLower = true
case unicode.IsUpper(r):
hasUpper = true
case unicode.IsNumber(r):
hasNumber = true
case !unicode.IsLetter(r) && !unicode.IsNumber(r):
hasSymbol = true
}
}
if policy.Lowercase && !hasLower {
t.Fatalf("소문자 요구사항 미충족: %q", pwd)
}
if policy.Uppercase && !hasUpper {
t.Fatalf("대문자 요구사항 미충족: %q", pwd)
}
if policy.Number && !hasNumber {
t.Fatalf("숫자 요구사항 미충족: %q", pwd)
}
if policy.NonAlphanumeric && !hasSymbol {
t.Fatalf("비영문자 요구사항 미충족: %q", pwd)
}
}

View File

@@ -0,0 +1,114 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"log/slog"
"github.com/gofiber/fiber/v2"
)
type RelyingPartyHandler struct {
Service service.RelyingPartyService
KratosAdmin service.KratosAdminService
}
func NewRelyingPartyHandler(s service.RelyingPartyService, kratos service.KratosAdminService) *RelyingPartyHandler {
return &RelyingPartyHandler{Service: s, KratosAdmin: kratos}
}
func (h *RelyingPartyHandler) Create(c *fiber.Ctx) error {
tenantID := c.Params("tenantId")
if tenantID == "" {
return errorJSON(c, fiber.StatusBadRequest, "tenantId is required")
}
var req domain.HydraClient
if err := c.BodyParser(&req); err != nil {
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
}
rp, err := h.Service.Create(c.Context(), tenantID, req)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.Status(fiber.StatusCreated).JSON(rp)
}
func (h *RelyingPartyHandler) ListAll(c *fiber.Ctx) error {
profile, ok := c.Locals("user_profile").(*domain.UserProfileResponse)
if !ok {
return errorJSON(c, fiber.StatusUnauthorized, "unauthorized: user profile not found in context")
}
var rps []domain.RelyingParty
var err error
role := domain.NormalizeRole(profile.Role)
if role == domain.RoleSuperAdmin {
rps, err = h.Service.ListAll(c.Context())
} else if role == "tenant_admin" && profile.TenantID != nil {
rps, err = h.Service.List(c.Context(), *profile.TenantID)
} else {
slog.Warn("Forbidden access to all applications", "userID", profile.ID, "role", role)
return errorJSON(c, fiber.StatusForbidden, "forbidden: insufficient role to list all applications")
}
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.JSON(rps)
}
func (h *RelyingPartyHandler) List(c *fiber.Ctx) error {
tenantID := c.Params("tenantId")
if tenantID == "" {
return errorJSON(c, fiber.StatusBadRequest, "tenantId is required")
}
rps, err := h.Service.List(c.Context(), tenantID)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.JSON(rps)
}
func (h *RelyingPartyHandler) Get(c *fiber.Ctx) error {
id := c.Params("id")
rp, hydraClient, err := h.Service.Get(c.Context(), id)
if err != nil {
return errorJSON(c, fiber.StatusNotFound, "relying party not found")
}
return c.JSON(fiber.Map{
"relyingParty": rp,
"oauth2Config": hydraClient,
})
}
func (h *RelyingPartyHandler) Update(c *fiber.Ctx) error {
id := c.Params("id")
var req domain.HydraClient
if err := c.BodyParser(&req); err != nil {
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
}
rp, err := h.Service.Update(c.Context(), id, req)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.JSON(rp)
}
func (h *RelyingPartyHandler) Delete(c *fiber.Ctx) error {
id := c.Params("id")
if err := h.Service.Delete(c.Context(), id); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.SendStatus(fiber.StatusNoContent)
}

View File

@@ -0,0 +1,255 @@
package handler
import (
"html"
"os"
"strings"
"github.com/gofiber/fiber/v2"
)
type RPManifestHandler struct{}
const rpObjectLookupMermaid = `flowchart TD
A[RP request] --> B{obj_id supplied?}
B -->|yes| C[Normalize object type and obj_id]
B -->|no| D{Route has client_id?}
D -->|yes| E[obj_id = RelyingParty:<client_id>]
D -->|no| F{Route has tenant_id?}
F -->|yes| G[obj_id = Tenant:<tenant_id>]
F -->|no| H[Reject: explicit obj_id required]
C --> I[Check Keto relation]
E --> I
G --> I
I --> J{allowed?}
J -->|yes| K[Inject trusted Baron headers]
J -->|no| L[Reject request]
K --> M[Write audit with obj_id, relation, client_id, X-Request-Id]`
const rpExternalKeyMermaid = `flowchart TD
A[User authenticates through Baron SSO] --> B[Baron resolves internal identity]
B --> C[Baron derives or loads Baron-issued alias]
C --> D[Baron injects X-Baron-External-Key]
D --> E[Baron injects X-Baron-Subject]
E --> I[RP receives trusted headers from Baron gateway]
I --> F[RP upserts local user with provider + X-Baron-External-Key]
F --> G[RP stores the full external key as opaque value]
G --> H[RP never parses or stores raw kratos_identity_id]`
func NewRPManifestHandler() *RPManifestHandler {
return &RPManifestHandler{}
}
func (h *RPManifestHandler) GetJSON(c *fiber.Ctx) error {
c.Set(fiber.HeaderCacheControl, "public, max-age=300")
return c.JSON(buildRPManifest(c))
}
func (h *RPManifestHandler) GetSchema(c *fiber.Ctx) error {
c.Set(fiber.HeaderCacheControl, "public, max-age=300")
return c.JSON(rpManifestSchema())
}
func (h *RPManifestHandler) GetHTML(c *fiber.Ctx) error {
manifest := buildRPManifest(c)
issuer, _ := manifest["issuer"].(string)
c.Set(fiber.HeaderCacheControl, "public, max-age=300")
c.Type("html", "utf-8")
return c.SendString(`<!doctype html>
<html lang="ko">
<head>
<meta charset="utf-8">
<title>Baron RP IAM Manifest</title>
<style>
body { font-family: system-ui, sans-serif; margin: 2rem; line-height: 1.6; max-width: 920px; }
code, pre { background: #f5f5f5; border-radius: 4px; padding: .1rem .3rem; }
pre { padding: 1rem; overflow: auto; }
table { border-collapse: collapse; width: 100%; }
th, td { border: 1px solid #ddd; padding: .5rem; text-align: left; }
</style>
</head>
<body>
<h1>Baron RP IAM Manifest</h1>
<p>외부 RP가 Baron SSO/Ory Stack/Keto 기반 공용 IAM을 연동하기 위한 공개 규격입니다.</p>
<ul>
<li>Machine-readable manifest: <a href="/.well-known/baron-rp-manifest.json">/.well-known/baron-rp-manifest.json</a></li>
<li>JSON schema: <a href="/.well-known/baron-rp-manifest.schema.json">/.well-known/baron-rp-manifest.schema.json</a></li>
</ul>
<h2>Issuer</h2>
<pre>` + html.EscapeString(issuer) + `</pre>
<h2>Identity Contract</h2>
<table>
<tr><th>용도</th><th>Header</th><th>정책</th></tr>
<tr><td>Keto subject</td><td><code>X-Baron-Subject</code></td><td><code>User:&lt;baron_identity_id&gt;</code> 전체 문자열을 opaque subject로 취급합니다.</td></tr>
<tr><td>RP upsert key</td><td><code>X-Baron-External-Key</code></td><td>Baron-issued alias입니다. RP가 만들거나 제출하지 않고, Baron이 주입한 전체 문자열을 local user external key로 저장합니다.</td></tr>
<tr><td>RP client</td><td><code>X-Baron-Client-ID</code></td><td>현재 접근 중인 RP client id입니다.</td></tr>
</table>
<h2>External Key Flow</h2>
<p><code>X-Baron-External-Key</code>는 RP 입력값이 아니라 Baron이 인증된 subject에서 발급/조회해 주입하는 opaque alias입니다. RP upserts local user from the Baron-issued alias.</p>
<pre>` + "```mermaid\n" + html.EscapeString(rpExternalKeyMermaid) + "\n```" + `</pre>
<h2>Object Lookup</h2>
<pre>check(User:abc, viewers, RelyingParty:&lt;client_id&gt;)
check(User:abc, members, Tenant:&lt;tenant_id&gt;)
check(User:abc, viewers, Resource:&lt;resource_type&gt;:&lt;resource_id&gt;)</pre>
<h2>audit_contract</h2>
<p>권한과 설정을 변경하는 command는 sync audit write에 실패하면 요청도 실패해야 합니다. Read audit은 allowlist된 조회에 한해 best effort로 취급합니다.</p>
<pre>{
"mutating_command_mode": "fail_closed_sync",
"missing_audit_sink_behavior": "reject_mutation",
"correlation_header": "X-Request-Id"
}</pre>
<h2>Object Lookup Flow</h2>
<pre>` + "```mermaid\n" + html.EscapeString(rpObjectLookupMermaid) + "\n```" + `</pre>
</body>
</html>`)
}
func buildRPManifest(c *fiber.Ctx) map[string]any {
issuer := resolvePublicRequestBaseURL(c, os.Getenv("BACKEND_PUBLIC_URL"))
if issuer == "" {
issuer = strings.TrimRight(os.Getenv("USERFRONT_URL"), "/")
}
if issuer == "" {
issuer = "https://sso.hmac.kr"
}
issuer = strings.TrimRight(issuer, "/")
return map[string]any{
"version": "2026-05-11",
"issuer": issuer,
"oidc": map[string]any{
"discovery_url": issuer + "/.well-known/openid-configuration",
"jwks_url": issuer + "/.well-known/jwks.json",
"supported_flows": []string{"authorization_code_pkce"},
"required_scopes": []string{"openid", "profile", "email"},
},
"iam": map[string]any{
"authorization_engine": "ory-keto",
"subject_format": "User:<baron_identity_id>",
"target_object_patterns": []string{
"RelyingParty:<client_id>",
"Tenant:<tenant_id>",
"Resource:<resource_type>:<resource_id>",
},
"supported_relations": []string{
"admins",
"users",
"viewers",
"operators",
"members",
"owners",
"editors",
},
},
"identity_contract": map[string]any{
"subject_header": "X-Baron-Subject",
"external_key_header": "X-Baron-External-Key",
"external_key_is_opaque": true,
"external_key_issuer": "baron",
"external_key_delivery": "baron_injected_header",
"external_key_lifecycle": "issued_or_loaded_after_successful_authentication_before_rp_request",
"rp_supplied_external_key_allowed": false,
"rp_user_upsert_source": "rp_must_upsert_from_header_value",
"raw_kratos_identity_id_exposed": false,
"rp_user_upsert_key": "provider + external_key",
"email_is_stable_primary_key": false,
"initial_external_key_expression": "X-Baron-External-Key",
"fallback_to_subject_allowed": false,
},
"trusted_headers": map[string]any{
"subject": "X-Baron-Subject",
"external_key": "X-Baron-External-Key",
"email": "X-Baron-Email",
"tenant": "X-Baron-Tenant",
"relations": "X-Baron-Relations",
"client_id": "X-Baron-Client-ID",
},
"object_lookup": map[string]any{
"rp_level": map[string]any{
"object": "RelyingParty:<client_id>",
"relations": []string{"viewers", "users", "operators", "admins"},
"example": "check(User:abc, viewers, RelyingParty:mh-dashboard)",
},
"tenant_level": map[string]any{
"object": "Tenant:<tenant_id>",
"relations": []string{"members", "admins", "owners"},
"example": "check(User:abc, members, Tenant:9caf62e1-297d-4e8f-870b-61780998bbe)",
},
"resource_level": map[string]any{
"object": "Resource:<resource_type>:<resource_id>",
"relations": []string{"viewers", "editors", "owners"},
"example": "check(User:abc, viewers, Resource:dashboard:mh-monthly-2026-05)",
},
"recommended_order": []string{
"authenticated",
"rp_level",
"tenant_or_resource_level",
"trusted_header_injection",
},
},
"object_lookup_flow": map[string]any{
"format": "mermaid",
"mermaid": rpObjectLookupMermaid,
},
"external_key_flow": map[string]any{
"format": "mermaid",
"mermaid": rpExternalKeyMermaid,
},
"audit_contract": map[string]any{
"mutating_command_mode": "fail_closed_sync",
"missing_audit_sink_behavior": "reject_mutation",
"read_audit_mode": "best_effort_allowlisted",
"correlation_header": "X-Request-Id",
"rp_business_audit_required": true,
"baron_gateway_audit_required": true,
"required_detail_fields": []string{
"obj_id",
"relation",
"client_id",
"subject",
"decision",
},
"guarantee_scope": "Baron-mediated IAM mutations fail closed on audit write failure; RP-owned business events must be emitted by the RP with the same correlation header.",
},
"security_requirements": map[string]any{
"strip_external_identity_headers": true,
"backend_direct_exposure_allowed": false,
"static_snapshot_requires_auth": true,
"email_as_primary_key_allowed": false,
},
}
}
func rpManifestSchema() map[string]any {
return map[string]any{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"title": "Baron RP IAM Manifest",
"type": "object",
"required": []string{
"version",
"issuer",
"oidc",
"iam",
"trusted_headers",
"identity_contract",
"object_lookup",
"object_lookup_flow",
"external_key_flow",
"audit_contract",
"security_requirements",
},
"properties": map[string]any{
"version": map[string]any{"type": "string"},
"issuer": map[string]any{"type": "string", "format": "uri"},
"oidc": map[string]any{"type": "object"},
"iam": map[string]any{"type": "object"},
"trusted_headers": map[string]any{"type": "object"},
"identity_contract": map[string]any{"type": "object"},
"object_lookup": map[string]any{"type": "object"},
"object_lookup_flow": map[string]any{"type": "object"},
"external_key_flow": map[string]any{"type": "object"},
"audit_contract": map[string]any{"type": "object"},
"security_requirements": map[string]any{"type": "object"},
},
}
}

View File

@@ -0,0 +1,125 @@
package handler
import (
"encoding/json"
"io"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/require"
)
func TestRPManifestJSONIncludesIAMAndExternalKeyContract(t *testing.T) {
t.Setenv("BACKEND_PUBLIC_URL", "")
app := fiber.New()
h := NewRPManifestHandler()
app.Get("/.well-known/baron-rp-manifest.json", h.GetJSON)
req := httptest.NewRequest("GET", "/.well-known/baron-rp-manifest.json", nil)
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("X-Forwarded-Host", "sso.hmac.kr")
resp, err := app.Test(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, fiber.StatusOK, resp.StatusCode)
require.Contains(t, resp.Header.Get("Content-Type"), "application/json")
var body map[string]any
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, "https://sso.hmac.kr", body["issuer"])
oidc := body["oidc"].(map[string]any)
require.Equal(t, "https://sso.hmac.kr/.well-known/openid-configuration", oidc["discovery_url"])
require.Equal(t, "https://sso.hmac.kr/.well-known/jwks.json", oidc["jwks_url"])
iam := body["iam"].(map[string]any)
require.Equal(t, "ory-keto", iam["authorization_engine"])
require.Equal(t, "User:<baron_identity_id>", iam["subject_format"])
require.Contains(t, iam["target_object_patterns"].([]any), "RelyingParty:<client_id>")
require.Contains(t, iam["target_object_patterns"].([]any), "Tenant:<tenant_id>")
require.Contains(t, iam["target_object_patterns"].([]any), "Resource:<resource_type>:<resource_id>")
identity := body["identity_contract"].(map[string]any)
require.Equal(t, "X-Baron-External-Key", identity["external_key_header"])
require.Equal(t, true, identity["external_key_is_opaque"])
require.Equal(t, false, identity["raw_kratos_identity_id_exposed"])
require.Equal(t, "baron", identity["external_key_issuer"])
require.Equal(t, "baron_injected_header", identity["external_key_delivery"])
require.Equal(t, false, identity["rp_supplied_external_key_allowed"])
require.Equal(t, "rp_must_upsert_from_header_value", identity["rp_user_upsert_source"])
headers := body["trusted_headers"].(map[string]any)
require.Equal(t, "X-Baron-Subject", headers["subject"])
require.Equal(t, "X-Baron-External-Key", headers["external_key"])
require.Equal(t, "X-Baron-Client-ID", headers["client_id"])
security := body["security_requirements"].(map[string]any)
require.Equal(t, true, security["strip_external_identity_headers"])
require.Equal(t, false, security["backend_direct_exposure_allowed"])
audit := body["audit_contract"].(map[string]any)
require.Equal(t, "fail_closed_sync", audit["mutating_command_mode"])
require.Equal(t, "reject_mutation", audit["missing_audit_sink_behavior"])
require.Equal(t, "X-Request-Id", audit["correlation_header"])
require.Contains(t, audit["required_detail_fields"].([]any), "obj_id")
require.Contains(t, audit["required_detail_fields"].([]any), "client_id")
flow := body["object_lookup_flow"].(map[string]any)
require.Contains(t, flow["mermaid"].(string), "flowchart TD")
require.Contains(t, flow["mermaid"].(string), "obj_id")
aliasFlow := body["external_key_flow"].(map[string]any)
require.Contains(t, aliasFlow["mermaid"].(string), "Baron resolves internal identity")
require.Contains(t, aliasFlow["mermaid"].(string), "Baron injects X-Baron-External-Key")
require.Contains(t, aliasFlow["mermaid"].(string), "RP upserts local user")
require.NotContains(t, aliasFlow["mermaid"].(string), "RP creates external key")
}
func TestRPManifestSchemaRequiresLookupAndIdentityContracts(t *testing.T) {
app := fiber.New()
h := NewRPManifestHandler()
app.Get("/.well-known/baron-rp-manifest.schema.json", h.GetSchema)
resp, err := app.Test(httptest.NewRequest("GET", "/.well-known/baron-rp-manifest.schema.json", nil))
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, fiber.StatusOK, resp.StatusCode)
var body map[string]any
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
required := body["required"].([]any)
require.Contains(t, required, "iam")
require.Contains(t, required, "trusted_headers")
require.Contains(t, required, "identity_contract")
require.Contains(t, required, "object_lookup")
require.Contains(t, required, "audit_contract")
require.Contains(t, required, "object_lookup_flow")
require.Contains(t, required, "external_key_flow")
}
func TestRPManifestHTMLLinksMachineReadableManifest(t *testing.T) {
app := fiber.New()
h := NewRPManifestHandler()
app.Get("/.well-known/baron-rp-manifest", h.GetHTML)
resp, err := app.Test(httptest.NewRequest("GET", "/.well-known/baron-rp-manifest", nil))
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, fiber.StatusOK, resp.StatusCode)
require.Contains(t, resp.Header.Get("Content-Type"), "text/html")
raw, err := io.ReadAll(resp.Body)
require.NoError(t, err)
text := string(raw)
require.Contains(t, text, "/.well-known/baron-rp-manifest.json")
require.Contains(t, text, "X-Baron-External-Key")
require.Contains(t, text, "RelyingParty:&lt;client_id&gt;")
require.Contains(t, text, "```mermaid")
require.Contains(t, text, "audit_contract")
require.Contains(t, text, "Baron-issued alias")
require.Contains(t, text, "RP upserts local user")
}

View File

@@ -0,0 +1,154 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/repository"
"baron-sso-backend/internal/service"
"context"
"fmt"
"log/slog"
"maps"
"strings"
)
const tenantAccessCleanupClientPageSize = 500
func cleanupDeletedTenantReferences(ctx context.Context, hydra *service.HydraAdminService, consentRepo repository.ClientConsentRepository, ketoOutbox repository.KetoOutboxRepository, deletedTenantIDs []string) error {
if hydra == nil {
return nil
}
deletedTenantSet := make(map[string]struct{}, len(deletedTenantIDs))
for _, tenantID := range deletedTenantIDs {
tenantID = strings.TrimSpace(tenantID)
if tenantID == "" {
continue
}
deletedTenantSet[tenantID] = struct{}{}
}
if len(deletedTenantSet) == 0 {
return nil
}
for offset := 0; ; offset += tenantAccessCleanupClientPageSize {
clients, err := hydra.ListClients(ctx, tenantAccessCleanupClientPageSize, offset)
if err != nil {
return fmt.Errorf("failed to list hydra clients for tenant cleanup: %w", err)
}
for _, client := range clients {
beforeMetadata := maps.Clone(client.Metadata)
updatedMetadata, changed, removedOwnerTenantID := pruneDeletedTenantReferences(beforeMetadata, deletedTenantSet)
if !changed {
continue
}
updatedClient := client
updatedClient.Metadata = updatedMetadata
if _, err := hydra.UpdateClient(ctx, client.ClientID, updatedClient); err != nil {
return fmt.Errorf("failed to update hydra client %s during tenant cleanup: %w", client.ClientID, err)
}
if removedOwnerTenantID != "" {
if err := enqueueDeletedTenantRelyingPartyParentCleanup(ctx, ketoOutbox, client.ClientID, removedOwnerTenantID); err != nil {
return fmt.Errorf("failed to cleanup RP parent relation for client %s during tenant cleanup: %w", client.ClientID, err)
}
}
if tenantAccessPolicyChanged(beforeMetadata, updatedMetadata) {
if err := revokeClientConsentsForPolicyChange(ctx, hydra, consentRepo, client.ClientID); err != nil {
return fmt.Errorf("failed to revoke consent sessions for client %s during tenant cleanup: %w", client.ClientID, err)
}
}
}
if len(clients) < tenantAccessCleanupClientPageSize {
return nil
}
}
}
func pruneDeletedTenantReferences(metadata map[string]any, deletedTenantSet map[string]struct{}) (map[string]any, bool, string) {
if len(deletedTenantSet) == 0 {
return metadata, false, ""
}
ownerTenantID := normalizeMetadataString(metadata["tenant_id"])
_, ownerDeleted := deletedTenantSet[ownerTenantID]
allowedTenants := normalizeMetadataStringSlice(metadata[clientAllowedTenantsKey])
filtered := make([]string, 0, len(allowedTenants))
for _, tenantID := range allowedTenants {
if _, ok := deletedTenantSet[tenantID]; ok {
continue
}
filtered = append(filtered, tenantID)
}
allowedChanged := len(filtered) != len(allowedTenants)
if !ownerDeleted && !allowedChanged {
return metadata, false, ""
}
updated := maps.Clone(metadata)
if ownerDeleted {
delete(updated, "tenant_id")
}
if len(filtered) == 0 {
delete(updated, clientAllowedTenantsKey)
updated[clientTenantAccessRestrictedKey] = false
return updated, true, ownerTenantID
}
updated[clientAllowedTenantsKey] = uniqueSortedStrings(filtered)
updated[clientTenantAccessRestrictedKey] = true
return updated, true, ownerTenantID
}
func enqueueDeletedTenantRelyingPartyParentCleanup(ctx context.Context, ketoOutbox repository.KetoOutboxRepository, clientID, tenantID string) error {
if ketoOutbox == nil {
return nil
}
clientID = strings.TrimSpace(clientID)
tenantID = strings.TrimSpace(tenantID)
if clientID == "" || tenantID == "" {
return nil
}
return ketoOutbox.Create(ctx, &domain.KetoOutbox{
Namespace: "RelyingParty",
Object: clientID,
Relation: "parents",
Subject: "Tenant:" + tenantID,
Action: domain.KetoOutboxActionDelete,
})
}
func revokeClientConsentsForPolicyChange(ctx context.Context, hydra *service.HydraAdminService, consentRepo repository.ClientConsentRepository, clientID string) error {
if consentRepo == nil || hydra == nil {
return nil
}
subjects, err := consentRepo.ListSubjectsByClient(ctx, clientID)
if err != nil {
return err
}
for _, subject := range subjects {
subject = strings.TrimSpace(subject)
if subject == "" {
continue
}
if err := hydra.RevokeConsentSessions(ctx, subject, clientID); err != nil {
return err
}
}
return consentRepo.DeleteByClient(ctx, clientID)
}
func logTenantCleanupFailure(err error, deletedTenantIDs []string) {
if err == nil {
return
}
slog.Error("Failed to cleanup RP tenant restrictions after tenant deletion", "tenant_ids", deletedTenantIDs, "error", err)
}

View File

@@ -0,0 +1,178 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/repository"
"baron-sso-backend/internal/service"
"context"
"encoding/json"
"net/http"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
)
func TestPruneDeletedTenantReferences_PreservesOtherAllowedTenants(t *testing.T) {
metadata := map[string]any{
clientTenantAccessRestrictedKey: true,
clientAllowedTenantsKey: []string{"keep-tenant", "deleted-tenant"},
"tenant_id": "deleted-tenant",
}
updated, changed, removedOwnerTenantID := pruneDeletedTenantReferences(metadata, map[string]struct{}{
"deleted-tenant": {},
})
require.True(t, changed)
assert.Equal(t, "deleted-tenant", removedOwnerTenantID)
assert.Equal(t, true, updated[clientTenantAccessRestrictedKey])
assert.Equal(t, []string{"keep-tenant"}, updated[clientAllowedTenantsKey])
_, exists := updated["tenant_id"]
assert.False(t, exists)
}
func TestPruneDeletedTenantReferences_DisablesRestrictionWhenLastTenantRemoved(t *testing.T) {
metadata := map[string]any{
clientTenantAccessRestrictedKey: true,
clientAllowedTenantsKey: []string{"deleted-tenant"},
"tenant_id": "deleted-tenant",
}
updated, changed, removedOwnerTenantID := pruneDeletedTenantReferences(metadata, map[string]struct{}{
"deleted-tenant": {},
})
require.True(t, changed)
assert.Equal(t, "deleted-tenant", removedOwnerTenantID)
assert.Equal(t, false, updated[clientTenantAccessRestrictedKey])
_, exists := updated[clientAllowedTenantsKey]
assert.False(t, exists)
_, exists = updated["tenant_id"]
assert.False(t, exists)
}
func TestCleanupDeletedTenantReferences_PrunesClientsAndRevokesConsents(t *testing.T) {
var (
mu sync.Mutex
page0Called bool
updated = map[string]map[string]any{}
revokes []string
)
transport := roundTripFunc(func(req *http.Request) (*http.Response, error) {
mu.Lock()
defer mu.Unlock()
switch {
case req.Method == http.MethodGet && req.URL.Path == "/clients":
switch req.URL.Query().Get("offset") {
case "":
page0Called = true
return httpJSONAny(req, http.StatusOK, []domain.HydraClient{
{
ClientID: "client-keep",
Metadata: map[string]any{
clientTenantAccessRestrictedKey: true,
clientAllowedTenantsKey: []string{"keep-tenant", "deleted-tenant"},
"tenant_id": "deleted-tenant",
},
},
{
ClientID: "client-drop",
Metadata: map[string]any{
clientTenantAccessRestrictedKey: true,
clientAllowedTenantsKey: []string{"deleted-tenant"},
"tenant_id": "deleted-tenant",
},
},
}), nil
default:
return httpResponse(req, http.StatusBadRequest, "unexpected offset"), nil
}
case req.Method == http.MethodPut && strings.HasPrefix(req.URL.Path, "/clients/"):
var client domain.HydraClient
require.NoError(t, json.NewDecoder(req.Body).Decode(&client))
updated[client.ClientID] = client.Metadata
return httpJSONAny(req, http.StatusOK, client), nil
case req.Method == http.MethodDelete && req.URL.Path == "/oauth2/auth/sessions/consent":
revokes = append(revokes, req.URL.Query().Get("subject")+"|"+req.URL.Query().Get("client"))
return httpResponse(req, http.StatusNoContent, ""), nil
default:
return httpResponse(req, http.StatusNotFound, "unexpected request"), nil
}
})
hydra := &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: transport},
}
consentRepo := &mockConsentRepo{
consents: []domain.ClientConsent{
{ClientID: "client-keep", Subject: "user-a"},
{ClientID: "client-drop", Subject: "user-b"},
},
}
outbox := &tenantCleanupMockKetoOutboxRepository{}
err := cleanupDeletedTenantReferences(context.Background(), hydra, consentRepo, outbox, []string{"deleted-tenant"})
require.NoError(t, err)
assert.True(t, page0Called)
assert.Equal(t, map[string]any{
clientTenantAccessRestrictedKey: true,
clientAllowedTenantsKey: []any{"keep-tenant"},
}, updated["client-keep"])
assert.Equal(t, map[string]any{
clientTenantAccessRestrictedKey: false,
}, updated["client-drop"])
assert.ElementsMatch(t, []string{"user-a|client-keep", "user-b|client-drop"}, revokes)
assert.Empty(t, consentRepo.consents)
require.Len(t, outbox.entries, 2)
assert.ElementsMatch(t, []string{"client-keep", "client-drop"}, []string{outbox.entries[0].Object, outbox.entries[1].Object})
for _, entry := range outbox.entries {
assert.Equal(t, "RelyingParty", entry.Namespace)
assert.Equal(t, "parents", entry.Relation)
assert.Equal(t, "Tenant:deleted-tenant", entry.Subject)
assert.Equal(t, domain.KetoOutboxActionDelete, entry.Action)
}
}
type tenantCleanupMockKetoOutboxRepository struct {
entries []domain.KetoOutbox
}
var _ repository.KetoOutboxRepository = (*tenantCleanupMockKetoOutboxRepository)(nil)
func (m *tenantCleanupMockKetoOutboxRepository) Create(ctx context.Context, entry *domain.KetoOutbox) error {
if entry == nil {
return nil
}
m.entries = append(m.entries, *entry)
return nil
}
func (m *tenantCleanupMockKetoOutboxRepository) CreateWithTx(tx *gorm.DB, entry *domain.KetoOutbox) error {
return m.Create(context.Background(), entry)
}
func (m *tenantCleanupMockKetoOutboxRepository) FindPending(ctx context.Context, limit int) ([]domain.KetoOutbox, error) {
return nil, nil
}
func (m *tenantCleanupMockKetoOutboxRepository) ListCurrentBySubject(ctx context.Context, namespace, subject string) ([]domain.KetoOutbox, error) {
return nil, nil
}
func (m *tenantCleanupMockKetoOutboxRepository) UpdateStatus(ctx context.Context, id string, status string, retryCount int, lastError string) error {
return nil
}
func (m *tenantCleanupMockKetoOutboxRepository) MarkProcessed(ctx context.Context, id string) error {
return nil
}

View File

@@ -0,0 +1,155 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"context"
"errors"
"fmt"
"strings"
"github.com/google/uuid"
)
func representativeTenantIDFromTraits(traits map[string]any) string {
if value := tenantClaimString(traits, "tenant_id"); value != "" {
return value
}
if value := tenantClaimString(traits, "primaryTenantId"); value != "" {
return value
}
if metadata, ok := traits["metadata"].(map[string]any); ok {
if value := tenantClaimString(metadata, "primaryTenantId"); value != "" {
return value
}
}
appointments := tenantAssignmentAppointmentsFromTraits(traits)
for _, appointment := range appointments {
if tenantAssignmentBool(appointment, "isPrimary", "primary", "representative", "isRepresentative") {
if value := tenantAssignmentTenantID(appointment); value != "" {
return value
}
}
}
for _, appointment := range appointments {
if value := tenantAssignmentTenantID(appointment); value != "" {
return value
}
}
for _, tenantID := range tenantNamespaceIDsFromTraits(traits) {
return tenantID
}
return ""
}
func joinedTenantIDsFromTraits(traits map[string]any, representativeTenantID string) []string {
values := make([]string, 0)
if representativeTenantID != "" {
values = append(values, representativeTenantID)
}
if value := tenantClaimString(traits, "tenant_id"); value != "" {
values = append(values, value)
}
for _, appointment := range tenantAssignmentAppointmentsFromTraits(traits) {
if value := tenantAssignmentTenantID(appointment); value != "" {
values = append(values, value)
}
}
values = append(values, tenantNamespaceIDsFromTraits(traits)...)
return uniqueSortedStrings(values)
}
func tenantAssignmentAppointmentsFromTraits(traits map[string]any) []map[string]any {
raw := rawAdditionalAppointments(traits)
switch values := raw.(type) {
case []any:
appointments := make([]map[string]any, 0, len(values))
for _, item := range values {
if appointment, ok := item.(map[string]any); ok {
appointments = append(appointments, appointment)
}
}
return appointments
case []map[string]any:
return values
default:
return nil
}
}
func tenantAssignmentTenantID(appointment map[string]any) string {
for _, key := range []string{"tenantId", "tenant_id"} {
if value := tenantClaimString(appointment, key); value != "" {
return value
}
}
return ""
}
func tenantAssignmentBool(values map[string]any, keys ...string) bool {
for _, key := range keys {
raw, ok := values[key]
if !ok || raw == nil {
continue
}
switch value := raw.(type) {
case bool:
if value {
return true
}
case string:
normalized := strings.ToLower(strings.TrimSpace(value))
if normalized == "true" || normalized == "1" || normalized == "yes" {
return true
}
}
}
return false
}
func tenantNamespaceIDsFromTraits(traits map[string]any) []string {
if traits == nil {
return nil
}
ids := make([]string, 0)
for key, value := range traits {
if key == "" || key == "metadata" {
continue
}
switch value.(type) {
case map[string]any:
ids = append(ids, key)
}
}
return uniqueSortedStrings(ids)
}
func createPersonalTenantForUser(ctx context.Context, tenantService service.TenantService, email string) (*domain.Tenant, error) {
if tenantService == nil {
return nil, errors.New("tenant service unavailable")
}
normalizedEmail := strings.ToLower(strings.TrimSpace(email))
if normalizedEmail == "" {
normalizedEmail = "user"
}
slug := "personal-" + strings.ReplaceAll(uuid.NewString(), "-", "")
tenant, err := tenantService.RegisterTenant(
ctx,
fmt.Sprintf("Personal - %s", normalizedEmail),
slug,
domain.TenantTypePersonal,
"Automatically provisioned personal tenant",
nil,
nil,
"",
)
if err != nil {
return nil, err
}
if tenant == nil {
return nil, errors.New("personal tenant not created")
}
return tenant, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,159 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/testsupport"
"bytes"
"context"
"encoding/json"
"log"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/testcontainers/testcontainers-go"
postgres_module "github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
gorm_postgres "gorm.io/driver/postgres"
"gorm.io/gorm"
)
func newTenantHandlerSeedDeleteDB(t *testing.T) *gorm.DB {
t.Helper()
if !testsupport.DockerAvailable() {
t.Skip("Docker provider is unavailable in this environment")
}
ctx := context.Background()
postgresContainer, err := postgres_module.Run(ctx,
"postgres:16-alpine",
postgres_module.WithDatabase("testdb"),
postgres_module.WithUsername("user"),
postgres_module.WithPassword("password"),
testcontainers.WithWaitStrategy(
wait.ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(30*time.Second)),
)
if err != nil {
t.Fatalf("failed to start postgres container: %v", err)
}
t.Cleanup(func() {
if err := postgresContainer.Terminate(ctx); err != nil {
log.Printf("failed to terminate postgres container: %v", err)
}
})
connStr, err := postgresContainer.ConnectionString(ctx, "sslmode=disable")
if err != nil {
t.Fatalf("failed to get postgres connection string: %v", err)
}
db, err := gorm.Open(gorm_postgres.Open(connStr), &gorm.Config{})
if err != nil {
t.Fatalf("failed to open postgres connection: %v", err)
}
if err := db.AutoMigrate(&domain.Tenant{}); err != nil {
t.Fatalf("failed to migrate tenants: %v", err)
}
return db
}
func setSeedTenantCSVForDeleteGuard(t *testing.T, slug string) {
t.Helper()
dir := t.TempDir()
path := filepath.Join(dir, "seed-tenant.csv")
csv := "name,type,parent_tenant_slug,slug,memo,email_domain\n" +
"Protected,COMPANY_GROUP,," + slug + ",Protected seed,\n"
if err := os.WriteFile(path, []byte(csv), 0o600); err != nil {
t.Fatalf("failed to write seed csv: %v", err)
}
t.Setenv("SEED_TENANT_CSV_PATH", path)
}
func TestTenantHandlerDeleteTenantRejectsSeedTenant(t *testing.T) {
setSeedTenantCSVForDeleteGuard(t, "protected-root")
db := newTenantHandlerSeedDeleteDB(t)
tenant := domain.Tenant{
ID: "00000000-0000-0000-0000-000000000001",
Name: "Protected",
Slug: "protected-root",
Type: domain.TenantTypeCompanyGroup,
Status: domain.TenantStatusActive,
}
if err := db.Create(&tenant).Error; err != nil {
t.Fatalf("failed to create tenant: %v", err)
}
app := fiber.New()
app.Delete("/tenants/:id", (&TenantHandler{DB: db}).DeleteTenant)
req := httptest.NewRequest(http.MethodDelete, "/tenants/"+tenant.ID, nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
if resp.StatusCode != http.StatusConflict {
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusConflict)
}
var count int64
if err := db.Model(&domain.Tenant{}).Where("id = ?", tenant.ID).Count(&count).Error; err != nil {
t.Fatalf("count tenant: %v", err)
}
if count != 1 {
t.Fatalf("seed tenant count = %d, want 1", count)
}
}
func TestTenantHandlerDeleteTenantsBulkRejectsSeedTenant(t *testing.T) {
setSeedTenantCSVForDeleteGuard(t, "protected-root")
db := newTenantHandlerSeedDeleteDB(t)
seed := domain.Tenant{
ID: "00000000-0000-0000-0000-000000000011",
Name: "Protected",
Slug: "protected-root",
Type: domain.TenantTypeCompanyGroup,
Status: domain.TenantStatusActive,
}
normal := domain.Tenant{
ID: "00000000-0000-0000-0000-000000000012",
Name: "Normal",
Slug: "normal",
Type: domain.TenantTypeCompany,
Status: domain.TenantStatusActive,
}
if err := db.Create(&seed).Error; err != nil {
t.Fatalf("failed to create seed tenant: %v", err)
}
if err := db.Create(&normal).Error; err != nil {
t.Fatalf("failed to create normal tenant: %v", err)
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Delete("/tenants/bulk", (&TenantHandler{DB: db}).DeleteTenantsBulk)
body, _ := json.Marshal(map[string][]string{"ids": {seed.ID, normal.ID}})
req := httptest.NewRequest(http.MethodDelete, "/tenants/bulk", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
if resp.StatusCode != http.StatusConflict {
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusConflict)
}
var count int64
if err := db.Model(&domain.Tenant{}).Where("id IN ?", []string{seed.ID, normal.ID}).Count(&count).Error; err != nil {
t.Fatalf("count tenants: %v", err)
}
if count != 2 {
t.Fatalf("remaining tenant count = %d, want 2", count)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,133 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"github.com/gofiber/fiber/v2"
)
type UserGroupHandler struct {
Service service.UserGroupService
}
func NewUserGroupHandler(s service.UserGroupService) *UserGroupHandler {
return &UserGroupHandler{Service: s}
}
func (h *UserGroupHandler) List(c *fiber.Ctx) error {
tenantID := c.Params("tenantId")
groups, err := h.Service.List(c.Context(), tenantID)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.JSON(groups)
}
func (h *UserGroupHandler) Create(c *fiber.Ctx) error {
tenantID := c.Params("tenantId")
var req domain.GroupCreateRequest
if err := c.BodyParser(&req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid body"})
}
group, err := h.Service.Create(c.Context(), tenantID, req.ParentID, req.Name, req.Description, req.UnitType)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.Status(fiber.StatusCreated).JSON(group)
}
func (h *UserGroupHandler) Get(c *fiber.Ctx) error {
id := c.Params("id")
group, err := h.Service.Get(c.Context(), id)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, "failed to get group: "+err.Error())
}
return c.JSON(group)
}
func (h *UserGroupHandler) Update(c *fiber.Ctx) error {
tenantID := c.Params("tenantId")
groupID := c.Params("id")
var req domain.GroupCreateRequest // Using create request for update fields
if err := c.BodyParser(&req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid body"})
}
group, err := h.Service.Update(c.Context(), tenantID, groupID, req.Name, req.Description, req.UnitType, req.ParentID)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(group)
}
func (h *UserGroupHandler) Delete(c *fiber.Ctx) error {
tenantID := c.Params("tenantId")
groupID := c.Params("id")
if err := h.Service.Delete(c.Context(), tenantID, groupID); err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.SendStatus(fiber.StatusNoContent)
}
func (h *UserGroupHandler) AddMember(c *fiber.Ctx) error {
groupID := c.Params("id")
var req struct {
UserID string `json:"userId"`
}
if err := c.BodyParser(&req); err != nil {
return errorJSON(c, fiber.StatusBadRequest, "userId is required")
}
if err := h.Service.AddMember(c.Context(), groupID, req.UserID); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.SendStatus(fiber.StatusOK)
}
func (h *UserGroupHandler) RemoveMember(c *fiber.Ctx) error {
groupID := c.Params("id")
userID := c.Params("userId")
if err := h.Service.RemoveMember(c.Context(), groupID, userID); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.SendStatus(fiber.StatusNoContent)
}
func (h *UserGroupHandler) AssignRole(c *fiber.Ctx) error {
groupID := c.Params("id")
var req struct {
TenantID string `json:"tenantId"`
Relation string `json:"relation"`
}
if err := c.BodyParser(&req); err != nil {
return errorJSON(c, fiber.StatusBadRequest, "invalid body")
}
if err := h.Service.AssignRoleToTenant(c.Context(), groupID, req.TenantID, req.Relation); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.SendStatus(fiber.StatusOK)
}
func (h *UserGroupHandler) ListRoles(c *fiber.Ctx) error {
groupID := c.Params("id")
roles, err := h.Service.ListRoles(c.Context(), groupID)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.JSON(roles)
}
func (h *UserGroupHandler) RemoveRole(c *fiber.Ctx) error {
groupID := c.Params("id")
tenantID := c.Params("tenantId")
relation := c.Params("relation")
if err := h.Service.RemoveRoleFromTenant(c.Context(), groupID, tenantID, relation); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.SendStatus(fiber.StatusNoContent)
}

View File

@@ -0,0 +1,155 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
// --- Mocks ---
type MockUserGroupService struct {
mock.Mock
}
func (m *MockUserGroupService) Create(ctx context.Context, tenantID string, parentID *string, name, description, unitType string) (*domain.UserGroup, error) {
args := m.Called(ctx, tenantID, parentID, name, description, unitType)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.UserGroup), args.Error(1)
}
func (m *MockUserGroupService) Update(ctx context.Context, tenantID, groupID string, name, description, unitType string, parentID *string) (*domain.UserGroup, error) {
args := m.Called(ctx, tenantID, groupID, name, description, unitType, parentID)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.UserGroup), args.Error(1)
}
func (m *MockUserGroupService) Delete(ctx context.Context, tenantID, groupID string) error {
return m.Called(ctx, tenantID, groupID).Error(0)
}
func (m *MockUserGroupService) Get(ctx context.Context, id string) (*domain.UserGroup, error) {
args := m.Called(ctx, id)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.UserGroup), args.Error(1)
}
func (m *MockUserGroupService) List(ctx context.Context, tenantID string) ([]domain.UserGroup, error) {
args := m.Called(ctx, tenantID)
return args.Get(0).([]domain.UserGroup), args.Error(1)
}
func (m *MockUserGroupService) SetWorksmobileSyncer(syncer service.WorksmobileSyncer) {}
func (m *MockUserGroupService) AddMember(ctx context.Context, groupID, userID string) error {
return m.Called(ctx, groupID, userID).Error(0)
}
func (m *MockUserGroupService) RemoveMember(ctx context.Context, groupID, userID string) error {
return m.Called(ctx, groupID, userID).Error(0)
}
func (m *MockUserGroupService) ListRoles(ctx context.Context, groupID string) ([]domain.GroupRole, error) {
args := m.Called(ctx, groupID)
return args.Get(0).([]domain.GroupRole), args.Error(1)
}
func (m *MockUserGroupService) AssignRoleToTenant(ctx context.Context, groupID, tenantID, relation string) error {
return m.Called(ctx, groupID, tenantID, relation).Error(0)
}
func (m *MockUserGroupService) RemoveRoleFromTenant(ctx context.Context, groupID, tenantID, relation string) error {
return m.Called(ctx, groupID, tenantID, relation).Error(0)
}
// --- Tests ---
func TestUserGroupHandler_List(t *testing.T) {
mockSvc := new(MockUserGroupService)
h := NewUserGroupHandler(mockSvc)
app := fiber.New()
app.Get("/tenants/:tenantId/user-groups", h.List)
tenantID := "t1"
groups := []domain.UserGroup{{ID: "g1", Name: "Group 1"}}
mockSvc.On("List", mock.Anything, tenantID).Return(groups, nil)
req := httptest.NewRequest("GET", "/tenants/t1/user-groups", nil)
resp, _ := app.Test(req)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var result []domain.UserGroup
json.NewDecoder(resp.Body).Decode(&result)
assert.Len(t, result, 1)
assert.Equal(t, "Group 1", result[0].Name)
}
func TestUserGroupHandler_Create(t *testing.T) {
mockSvc := new(MockUserGroupService)
h := NewUserGroupHandler(mockSvc)
app := fiber.New()
app.Post("/tenants/:tenantId/user-groups", h.Create)
body, _ := json.Marshal(map[string]string{"name": "New Group"})
mockSvc.On("Create", mock.Anything, "t1", mock.Anything, "New Group", mock.Anything, mock.Anything).Return(&domain.UserGroup{ID: "g1", Name: "New Group"}, nil)
req := httptest.NewRequest("POST", "/tenants/t1/user-groups", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req)
assert.Equal(t, http.StatusCreated, resp.StatusCode)
}
func TestUserGroupHandler_AddMember(t *testing.T) {
mockSvc := new(MockUserGroupService)
h := NewUserGroupHandler(mockSvc)
app := fiber.New()
app.Post("/user-groups/:id/members", h.AddMember)
groupID := "g1"
userID := "u1"
body, _ := json.Marshal(map[string]string{"userId": userID})
mockSvc.On("AddMember", mock.Anything, groupID, userID).Return(nil)
req := httptest.NewRequest("POST", "/user-groups/g1/members", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req)
assert.Equal(t, http.StatusOK, resp.StatusCode)
}
func TestUserGroupHandler_AssignRole(t *testing.T) {
mockSvc := new(MockUserGroupService)
h := NewUserGroupHandler(mockSvc)
app := fiber.New()
app.Post("/user-groups/:id/roles", h.AssignRole)
groupID := "g1"
targetTenantID := "t2"
relation := "manage"
body, _ := json.Marshal(map[string]string{"tenantId": targetTenantID, "relation": relation})
mockSvc.On("AssignRoleToTenant", mock.Anything, groupID, targetTenantID, relation).Return(nil)
req := httptest.NewRequest("POST", "/user-groups/g1/roles", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req)
assert.Equal(t, http.StatusOK, resp.StatusCode)
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,68 @@
package handler
import (
"baron-sso-backend/internal/domain"
"bytes"
"encoding/json"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func TestUserHandler_BulkCreateUsers_UUIDImportPolicy(t *testing.T) {
tests := []struct {
name string
field string
}{
{name: "id 필드 차단", field: "id"},
{name: "uuid 필드 차단", field: "uuid"},
{name: "userId 필드 차단", field: "userId"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
app := fiber.New()
mockKratos := new(MockKratosAdmin)
mockOry := new(MockOryProvider)
h := &UserHandler{
KratosAdmin: mockKratos,
OryProvider: mockOry,
}
app.Post("/users/bulk", h.BulkCreateUsers)
mockOry.On("GetPasswordPolicy").Return(&domain.PasswordPolicy{MinLength: 8}, nil).Once()
payload := map[string]any{
"users": []map[string]any{
{
"email": "uuid-import@test.com",
"name": "UUID Import User",
tt.field: "550e8400-e29b-41d4-a716-446655440000",
"metadata": map[string]any{},
},
},
}
body, _ := json.Marshal(payload)
req := httptest.NewRequest("POST", "/users/bulk", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req)
assert.Equal(t, 200, resp.StatusCode)
var result struct {
Results []bulkUserResult `json:"results"`
}
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&result))
assert.Len(t, result.Results, 1)
assert.False(t, result.Results[0].Success)
assert.Contains(t, result.Results[0].Message, "사용자 UUID 가져오기는 지원하지 않습니다")
mockOry.AssertExpectations(t)
mockKratos.AssertNotCalled(t, "FindIdentityIDByIdentifier", mock.Anything, mock.Anything)
})
}
}

View File

@@ -0,0 +1,16 @@
package handler
import (
"baron-sso-backend/internal/repository"
"context"
"log/slog"
)
func markUserProjectionFailed(ctx context.Context, repo repository.UserProjectionRepository, syncErr error) {
if repo == nil || syncErr == nil {
return
}
if err := repo.MarkFailed(ctx, syncErr); err != nil {
slog.Error("Failed to mark user projection as failed", "syncError", syncErr, "error", err)
}
}

View File

@@ -0,0 +1,217 @@
package handler
import (
"baron-sso-backend/internal/service"
"bytes"
"context"
"encoding/csv"
"errors"
"log/slog"
"strings"
"github.com/gofiber/fiber/v2"
)
type WorksmobileHandler struct {
Service service.WorksmobileAdminService
}
func NewWorksmobileHandler(svc service.WorksmobileAdminService) *WorksmobileHandler {
return &WorksmobileHandler{Service: svc}
}
func (h *WorksmobileHandler) GetOverview(c *fiber.Ctx) error {
overview, err := h.Service.GetTenantOverview(c.Context(), strings.TrimSpace(c.Params("tenantId")))
if err != nil {
return worksmobileGuardError(c, err, "get_overview")
}
if !worksmobileOverviewAllowed(overview) {
return errorJSON(c, fiber.StatusNotFound, "worksmobile is only available for hanmac-family root tenant")
}
return c.JSON(overview)
}
func (h *WorksmobileHandler) GetComparison(c *fiber.Ctx) error {
includeMatched := strings.EqualFold(strings.TrimSpace(c.Query("includeMatched")), "true")
comparison, err := h.Service.GetComparison(c.Context(), strings.TrimSpace(c.Params("tenantId")), includeMatched)
if err != nil {
return worksmobileGuardError(c, err, "get_comparison")
}
return c.JSON(comparison)
}
func (h *WorksmobileHandler) OAuthCallback(c *fiber.Ctx) error {
return c.Type("html").SendString("<!doctype html><html><body>Worksmobile OAuth callback reachable</body></html>")
}
func (h *WorksmobileHandler) BackfillDryRun(c *fiber.Ctx) error {
result, err := h.Service.EnqueueBackfillDryRun(c.Context(), strings.TrimSpace(c.Params("tenantId")))
if err != nil {
return worksmobileGuardError(c, err, "backfill_dry_run")
}
return c.JSON(result)
}
func (h *WorksmobileHandler) SyncOrgUnit(c *fiber.Ctx) error {
orgUnitID := strings.TrimSpace(c.Params("orgUnitId"))
job, err := h.Service.EnqueueOrgUnitSync(c.Context(), strings.TrimSpace(c.Params("tenantId")), orgUnitID)
if err != nil {
return worksmobileGuardError(c, err, "sync_orgunit", "org_unit_id", orgUnitID)
}
return c.Status(fiber.StatusAccepted).JSON(job)
}
func (h *WorksmobileHandler) DeleteOrgUnit(c *fiber.Ctx) error {
orgUnitID := strings.TrimSpace(c.Params("orgUnitId"))
job, err := h.Service.EnqueueOrgUnitDelete(c.Context(), strings.TrimSpace(c.Params("tenantId")), orgUnitID)
if err != nil {
return worksmobileGuardError(c, err, "delete_orgunit", "org_unit_id", orgUnitID)
}
return c.Status(fiber.StatusAccepted).JSON(job)
}
func (h *WorksmobileHandler) SyncUser(c *fiber.Ctx) error {
userID := strings.TrimSpace(c.Params("userId"))
credentialRequest, err := parseWorksmobileCredentialRequest(c)
if err != nil {
return errorJSON(c, fiber.StatusBadRequest, err.Error())
}
job, err := h.Service.EnqueueUserSync(
c.Context(),
strings.TrimSpace(c.Params("tenantId")),
userID,
credentialRequest.CredentialBatchID,
credentialRequest.InitialPassword,
)
if err != nil {
return worksmobileGuardError(c, err, "sync_user", "user_id", userID)
}
return c.Status(fiber.StatusAccepted).JSON(job)
}
func (h *WorksmobileHandler) ResetUserPassword(c *fiber.Ctx) error {
userID := strings.TrimSpace(c.Params("userId"))
credentialBatchID, err := parseWorksmobileCredentialBatchID(c)
if err != nil {
return errorJSON(c, fiber.StatusBadRequest, err.Error())
}
job, err := h.Service.EnqueueUserPasswordReset(c.Context(), strings.TrimSpace(c.Params("tenantId")), userID, credentialBatchID)
if err != nil {
return worksmobileGuardError(c, err, "reset_user_password", "user_id", userID)
}
return c.Status(fiber.StatusAccepted).JSON(job)
}
func (h *WorksmobileHandler) RetryJob(c *fiber.Ctx) error {
jobID := strings.TrimSpace(c.Params("jobId"))
job, err := h.Service.RetryJob(c.Context(), strings.TrimSpace(c.Params("tenantId")), jobID)
if err != nil {
return worksmobileGuardError(c, err, "retry_job", "job_id", jobID)
}
return c.JSON(job)
}
func (h *WorksmobileHandler) DeletePendingJobs(c *fiber.Ctx) error {
result, err := h.Service.DeletePendingJobs(c.Context(), strings.TrimSpace(c.Params("tenantId")))
if err != nil {
return worksmobileGuardError(c, err, "delete_pending_jobs")
}
return c.JSON(result)
}
func (h *WorksmobileHandler) DownloadInitialPasswordsCSV(c *fiber.Ctx) error {
credentials, err := h.Service.ListInitialPasswordCredentials(c.Context(), strings.TrimSpace(c.Params("tenantId")), strings.TrimSpace(c.Query("batchId")))
if err != nil {
return worksmobileGuardError(c, err, "download_initial_passwords")
}
var buf bytes.Buffer
writer := csv.NewWriter(&buf)
if err := writer.Write([]string{"email", "name", "primaryLeafOrgName", "initialPassword", "status", "lastError"}); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
for _, credential := range credentials {
if err := writer.Write([]string{credential.Email, credential.Name, credential.PrimaryLeafOrgName, credential.InitialPassword, credential.Status, credential.LastError}); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
}
writer.Flush()
if err := writer.Error(); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
c.Set(fiber.HeaderContentType, "text/csv; charset=utf-8")
c.Set(fiber.HeaderContentDisposition, `attachment; filename="worksmobile_initial_passwords.csv"`)
return c.Send(buf.Bytes())
}
func (h *WorksmobileHandler) ListCredentialBatches(c *fiber.Ctx) error {
batches, err := h.Service.ListCredentialBatches(c.Context(), strings.TrimSpace(c.Params("tenantId")))
if err != nil {
return worksmobileGuardError(c, err, "list_credential_batches")
}
return c.JSON(batches)
}
func (h *WorksmobileHandler) DeleteCredentialBatchPasswords(c *fiber.Ctx) error {
batchID := strings.TrimSpace(c.Params("batchId"))
batch, err := h.Service.DeleteCredentialBatchPasswords(c.Context(), strings.TrimSpace(c.Params("tenantId")), batchID)
if err != nil {
return worksmobileGuardError(c, err, "delete_credential_batch_passwords", "batch_id", batchID)
}
return c.JSON(batch)
}
type worksmobileCredentialBatchRequest struct {
CredentialBatchID string `json:"credentialBatchId"`
InitialPassword string `json:"initialPassword"`
}
func parseWorksmobileCredentialBatchID(c *fiber.Ctx) (string, error) {
req, err := parseWorksmobileCredentialRequest(c)
return req.CredentialBatchID, err
}
func parseWorksmobileCredentialRequest(c *fiber.Ctx) (worksmobileCredentialBatchRequest, error) {
batchID := strings.TrimSpace(c.Query("credentialBatchId"))
req := worksmobileCredentialBatchRequest{CredentialBatchID: batchID}
if len(bytes.TrimSpace(c.Body())) == 0 {
return req, nil
}
if err := c.BodyParser(&req); err != nil {
return worksmobileCredentialBatchRequest{}, err
}
req.InitialPassword = strings.TrimSpace(req.InitialPassword)
if bodyBatchID := strings.TrimSpace(req.CredentialBatchID); bodyBatchID != "" {
req.CredentialBatchID = bodyBatchID
return req, nil
}
req.CredentialBatchID = batchID
return req, nil
}
func worksmobileOverviewAllowed(overview service.WorksmobileTenantOverview) bool {
return overview.Tenant.Slug == service.HanmacFamilyTenantSlug && overview.Tenant.ParentID == nil
}
func worksmobileGuardError(c *fiber.Ctx, err error, operation string, attrs ...any) error {
if err == nil {
return nil
}
logAttrs := []any{
"operation", operation,
"tenant_id", strings.TrimSpace(c.Params("tenantId")),
"path", c.Path(),
"error", err,
}
logAttrs = append(logAttrs, attrs...)
if errors.Is(err, context.Canceled) {
slog.Warn("worksmobile admin operation failed", logAttrs...)
return errorJSON(c, fiber.StatusRequestTimeout, err.Error())
}
slog.Error("worksmobile admin operation failed", logAttrs...)
if strings.Contains(err.Error(), "hanmac-family root") {
return errorJSON(c, fiber.StatusNotFound, err.Error())
}
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}

View File

@@ -0,0 +1,267 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes"
"context"
"errors"
"io"
"log/slog"
"net/http/httptest"
"strings"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/require"
)
func TestWorksmobileHandlerRejectsNonHanmacTenant(t *testing.T) {
h := NewWorksmobileHandler(&fakeWorksmobileAdminService{
overview: service.WorksmobileTenantOverview{
Tenant: domain.Tenant{ID: "tenant-1", Slug: "other"},
},
})
app := fiber.New()
app.Get("/tenants/:tenantId/worksmobile", h.GetOverview)
resp, err := app.Test(httptest.NewRequest("GET", "/tenants/tenant-1/worksmobile", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
}
func TestWorksmobileHandlerReturnsOverviewForHanmacTenant(t *testing.T) {
h := NewWorksmobileHandler(&fakeWorksmobileAdminService{
overview: service.WorksmobileTenantOverview{
Tenant: domain.Tenant{ID: "hanmac-id", Slug: "hanmac-family"},
Config: service.WorksmobileConfigSummary{
Enabled: true,
},
},
})
app := fiber.New()
app.Get("/tenants/:tenantId/worksmobile", h.GetOverview)
resp, err := app.Test(httptest.NewRequest("GET", "/tenants/hanmac-id/worksmobile", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
}
func TestWorksmobileHandlerDownloadsInitialPasswordCSV(t *testing.T) {
h := NewWorksmobileHandler(&fakeWorksmobileAdminService{
credentials: []service.WorksmobileInitialPasswordCredential{
{
Email: "user@hanmaceng.co.kr",
Name: "홍길동",
PrimaryLeafOrgName: "인재성장",
InitialPassword: "Aa1!Aa1!Aa1!Aa1!",
Status: "processed",
},
},
})
app := fiber.New()
app.Get("/tenants/:tenantId/worksmobile/initial-passwords.csv", h.DownloadInitialPasswordsCSV)
resp, err := app.Test(httptest.NewRequest("GET", "/tenants/hanmac-id/worksmobile/initial-passwords.csv", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
require.Contains(t, resp.Header.Get("Content-Disposition"), "worksmobile_initial_passwords.csv")
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Contains(t, string(body), "email,name,primaryLeafOrgName,initialPassword,status,lastError")
require.Contains(t, string(body), "user@hanmaceng.co.kr,홍길동,인재성장,Aa1!Aa1!Aa1!Aa1!,processed,")
}
func TestWorksmobileHandlerPassesInitialPasswordBatchID(t *testing.T) {
fakeService := &fakeWorksmobileAdminService{
credentials: []service.WorksmobileInitialPasswordCredential{
{Email: "batch-user@hanmaceng.co.kr", InitialPassword: "BatchPass1!", Status: "pending"},
},
}
h := NewWorksmobileHandler(fakeService)
app := fiber.New()
app.Get("/tenants/:tenantId/worksmobile/initial-passwords.csv", h.DownloadInitialPasswordsCSV)
resp, err := app.Test(httptest.NewRequest("GET", "/tenants/hanmac-id/worksmobile/initial-passwords.csv?batchId=batch-1", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
require.Equal(t, "batch-1", fakeService.downloadCredentialBatchID)
}
func TestWorksmobileHandlerPassesSyncUserCredentialBatchID(t *testing.T) {
fakeService := &fakeWorksmobileAdminService{}
h := NewWorksmobileHandler(fakeService)
app := fiber.New()
app.Post("/tenants/:tenantId/worksmobile/users/:userId/sync", h.SyncUser)
req := httptest.NewRequest("POST", "/tenants/hanmac-id/worksmobile/users/user-1/sync", strings.NewReader(`{"credentialBatchId":"batch-1","initialPassword":"InputPass1!"}`))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusAccepted, resp.StatusCode)
require.Equal(t, "batch-1", fakeService.syncUserCredentialBatchID)
require.Equal(t, "InputPass1!", fakeService.syncUserInitialPassword)
}
func TestWorksmobileHandlerPassesPasswordResetCredentialBatchID(t *testing.T) {
fakeService := &fakeWorksmobileAdminService{}
h := NewWorksmobileHandler(fakeService)
app := fiber.New()
app.Post("/tenants/:tenantId/worksmobile/users/:userId/password/reset", h.ResetUserPassword)
req := httptest.NewRequest("POST", "/tenants/hanmac-id/worksmobile/users/user-1/password/reset", strings.NewReader(`{"credentialBatchId":"batch-1"}`))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusAccepted, resp.StatusCode)
require.Equal(t, "batch-1", fakeService.resetPasswordCredentialBatchID)
}
func TestWorksmobileHandlerReturnsCredentialBatchHistory(t *testing.T) {
h := NewWorksmobileHandler(&fakeWorksmobileAdminService{
credentialBatches: []service.WorksmobileCredentialBatch{
{BatchID: "batch-1", UserCount: 2, HasPasswords: true},
},
})
app := fiber.New()
app.Get("/tenants/:tenantId/worksmobile/credential-batches", h.ListCredentialBatches)
resp, err := app.Test(httptest.NewRequest("GET", "/tenants/hanmac-id/worksmobile/credential-batches", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Contains(t, string(body), `"batchId":"batch-1"`)
require.Contains(t, string(body), `"userCount":2`)
}
func TestWorksmobileHandlerDeletesCredentialBatchPasswords(t *testing.T) {
fakeService := &fakeWorksmobileAdminService{}
h := NewWorksmobileHandler(fakeService)
app := fiber.New()
app.Delete("/tenants/:tenantId/worksmobile/credential-batches/:batchId/passwords", h.DeleteCredentialBatchPasswords)
resp, err := app.Test(httptest.NewRequest("DELETE", "/tenants/hanmac-id/worksmobile/credential-batches/batch-1/passwords", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
require.Equal(t, "batch-1", fakeService.deletedCredentialBatchID)
}
func TestWorksmobileHandlerDeletesPendingJobs(t *testing.T) {
fakeService := &fakeWorksmobileAdminService{
pendingJobsDeleteResult: service.WorksmobilePendingJobDeleteResult{DeletedCount: 3},
}
h := NewWorksmobileHandler(fakeService)
app := fiber.New()
app.Delete("/tenants/:tenantId/worksmobile/jobs/pending", h.DeletePendingJobs)
resp, err := app.Test(httptest.NewRequest("DELETE", "/tenants/hanmac-id/worksmobile/jobs/pending", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
require.Equal(t, "hanmac-id", fakeService.deletedPendingJobsTenantID)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Contains(t, string(body), `"deletedCount":3`)
}
func TestWorksmobileHandlerLogsActionFailures(t *testing.T) {
var logs bytes.Buffer
previous := slog.Default()
slog.SetDefault(slog.New(slog.NewJSONHandler(&logs, nil)))
t.Cleanup(func() {
slog.SetDefault(previous)
})
h := NewWorksmobileHandler(&fakeWorksmobileAdminService{
syncUserErr: errors.New("works user sync failed"),
})
app := fiber.New()
app.Post("/tenants/:tenantId/worksmobile/users/:userId/sync", h.SyncUser)
resp, err := app.Test(httptest.NewRequest("POST", "/tenants/hanmac-id/worksmobile/users/user-1/sync", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
require.Contains(t, logs.String(), "worksmobile admin operation failed")
require.Contains(t, logs.String(), "sync_user")
require.Contains(t, logs.String(), "works user sync failed")
}
type fakeWorksmobileAdminService struct {
overview service.WorksmobileTenantOverview
credentials []service.WorksmobileInitialPasswordCredential
syncUserErr error
syncUserCredentialBatchID string
syncUserInitialPassword string
resetPasswordCredentialBatchID string
downloadCredentialBatchID string
deletedCredentialBatchID string
deletedPendingJobsTenantID string
pendingJobsDeleteResult service.WorksmobilePendingJobDeleteResult
credentialBatches []service.WorksmobileCredentialBatch
}
func (f *fakeWorksmobileAdminService) GetTenantOverview(ctx context.Context, tenantID string) (service.WorksmobileTenantOverview, error) {
return f.overview, nil
}
func (f *fakeWorksmobileAdminService) GetComparison(ctx context.Context, tenantID string, includeMatched bool) (service.WorksmobileComparison, error) {
return service.WorksmobileComparison{}, nil
}
func (f *fakeWorksmobileAdminService) EnqueueBackfillDryRun(ctx context.Context, tenantID string) (service.WorksmobileBackfillDryRun, error) {
return service.WorksmobileBackfillDryRun{}, nil
}
func (f *fakeWorksmobileAdminService) EnqueueOrgUnitSync(ctx context.Context, tenantID, orgUnitID string) (*domain.WorksmobileOutbox, error) {
return &domain.WorksmobileOutbox{ID: "job-orgunit", ResourceID: orgUnitID}, nil
}
func (f *fakeWorksmobileAdminService) EnqueueOrgUnitDelete(ctx context.Context, tenantID, orgUnitID string) (*domain.WorksmobileOutbox, error) {
return &domain.WorksmobileOutbox{ID: "job-orgunit-delete", ResourceID: orgUnitID, Action: domain.WorksmobileActionDelete}, nil
}
func (f *fakeWorksmobileAdminService) EnqueueUserSync(ctx context.Context, tenantID, userID, credentialBatchID, initialPassword string) (*domain.WorksmobileOutbox, error) {
f.syncUserCredentialBatchID = credentialBatchID
f.syncUserInitialPassword = initialPassword
if f.syncUserErr != nil {
return nil, f.syncUserErr
}
return &domain.WorksmobileOutbox{ID: "job-user", ResourceID: userID}, nil
}
func (f *fakeWorksmobileAdminService) EnqueueUserPasswordReset(ctx context.Context, tenantID, userID, credentialBatchID string) (*domain.WorksmobileOutbox, error) {
f.resetPasswordCredentialBatchID = credentialBatchID
return &domain.WorksmobileOutbox{ID: "job-user-password-reset", ResourceID: userID}, nil
}
func (f *fakeWorksmobileAdminService) RetryJob(ctx context.Context, tenantID, jobID string) (*domain.WorksmobileOutbox, error) {
return &domain.WorksmobileOutbox{ID: jobID}, nil
}
func (f *fakeWorksmobileAdminService) ListInitialPasswordCredentials(ctx context.Context, tenantID, credentialBatchID string) ([]service.WorksmobileInitialPasswordCredential, error) {
f.downloadCredentialBatchID = credentialBatchID
return f.credentials, nil
}
func (f *fakeWorksmobileAdminService) ListCredentialBatches(ctx context.Context, tenantID string) ([]service.WorksmobileCredentialBatch, error) {
return f.credentialBatches, nil
}
func (f *fakeWorksmobileAdminService) DeleteCredentialBatchPasswords(ctx context.Context, tenantID, credentialBatchID string) (service.WorksmobileCredentialBatch, error) {
f.deletedCredentialBatchID = credentialBatchID
return service.WorksmobileCredentialBatch{BatchID: credentialBatchID}, nil
}
func (f *fakeWorksmobileAdminService) DeletePendingJobs(ctx context.Context, tenantID string) (service.WorksmobilePendingJobDeleteResult, error) {
f.deletedPendingJobsTenantID = tenantID
return f.pendingJobsDeleteResult, nil
}

View File

@@ -0,0 +1,560 @@
package handlerregression
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/handler"
"baron-sso-backend/internal/service"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"maps"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gofiber/fiber/v2"
)
type roundTripFunc func(req *http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
type mockSecretRepo struct {
secrets map[string]string
}
func (m *mockSecretRepo) Upsert(ctx context.Context, clientID, secret string) error {
if m.secrets == nil {
m.secrets = make(map[string]string)
}
m.secrets[clientID] = secret
return nil
}
func (m *mockSecretRepo) GetByID(ctx context.Context, clientID string) (string, error) {
return m.secrets[clientID], nil
}
func (m *mockSecretRepo) Delete(ctx context.Context, clientID string) error {
delete(m.secrets, clientID)
return nil
}
type mockRedisRepo struct {
data map[string]string
}
func (m *mockRedisRepo) Set(key, value string, exp time.Duration) error {
if m.data == nil {
m.data = make(map[string]string)
}
m.data[key] = value
return nil
}
func (m *mockRedisRepo) Get(key string) (string, error) {
v, ok := m.data[key]
if !ok {
return "", fmt.Errorf("not found")
}
return v, nil
}
func (m *mockRedisRepo) Delete(key string) error {
delete(m.data, key)
return nil
}
func (m *mockRedisRepo) StoreVerificationCode(p, c string) error { return nil }
func (m *mockRedisRepo) GetVerificationCode(p string) (string, error) { return "", nil }
func (m *mockRedisRepo) DeleteVerificationCode(p string) error { return nil }
func httpJSONAny(r *http.Request, code int, payload any) *http.Response {
body, _ := json.Marshal(payload)
return &http.Response{
StatusCode: code,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(bytes.NewReader(body)),
Request: r,
}
}
func newDevHandlerApp(h *handler.DevHandler) *fiber.App {
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "test-user", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Post("/api/v1/dev/clients", h.CreateClient)
app.Get("/api/v1/dev/clients/:id", h.GetClient)
app.Put("/api/v1/dev/clients/:id", h.UpdateClient)
app.Post("/api/v1/dev/clients/:id/secret/rotate", h.RotateClientSecret)
return app
}
func decodeClientSecret(t *testing.T, resp *http.Response) string {
t.Helper()
var payload struct {
Client struct {
ClientSecret string `json:"clientSecret"`
} `json:"client"`
}
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
t.Fatalf("decode response: %v", err)
}
return payload.Client.ClientSecret
}
func TestGetClient_ClientSecretFallbackPaths(t *testing.T) {
tests := []struct {
name string
hydraClient map[string]any
initialRedis map[string]string
initialSecrets map[string]string
expectedSecret string
expectedRedisAfter string
expectRedisAfterSet bool
}{
{
name: "uses hydra client_secret directly",
hydraClient: map[string]any{
"client_id": "client-direct",
"client_name": "Direct App",
"client_secret": "hydra-secret",
"redirect_uris": []string{"https://direct.example.com/callback"},
"grant_types": []string{"authorization_code", "refresh_token"},
"response_types": []string{"code"},
"scope": "openid profile",
"token_endpoint_auth_method": "client_secret_basic",
"metadata": map[string]any{"status": "active"},
},
expectedSecret: "hydra-secret",
},
{
name: "falls back to metadata client_secret",
hydraClient: map[string]any{
"client_id": "client-metadata",
"client_name": "Metadata App",
"redirect_uris": []string{"https://metadata.example.com/callback"},
"grant_types": []string{"authorization_code", "refresh_token"},
"response_types": []string{"code"},
"scope": "openid profile",
"token_endpoint_auth_method": "client_secret_basic",
"metadata": map[string]any{
"status": "active",
"client_secret": "metadata-secret",
},
},
expectedSecret: "metadata-secret",
},
{
name: "falls back to redis cache",
hydraClient: map[string]any{
"client_id": "client-redis",
"client_name": "Redis App",
"redirect_uris": []string{"https://redis.example.com/callback"},
"grant_types": []string{"authorization_code", "refresh_token"},
"response_types": []string{"code"},
"scope": "openid profile",
"token_endpoint_auth_method": "client_secret_basic",
"metadata": map[string]any{"status": "active"},
},
initialRedis: map[string]string{"client_secret:client-redis": "redis-secret"},
expectedSecret: "redis-secret",
},
{
name: "falls back to postgres and warms redis",
hydraClient: map[string]any{
"client_id": "client-postgres",
"client_name": "Postgres App",
"redirect_uris": []string{"https://postgres.example.com/callback"},
"grant_types": []string{"authorization_code", "refresh_token"},
"response_types": []string{"code"},
"scope": "openid profile",
"token_endpoint_auth_method": "client_secret_basic",
"metadata": map[string]any{"status": "active"},
},
initialSecrets: map[string]string{"client-postgres": "postgres-secret"},
expectedSecret: "postgres-secret",
expectedRedisAfter: "postgres-secret",
expectRedisAfterSet: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.Method == http.MethodGet && r.URL.Path == "/clients/"+tt.hydraClient["client_id"].(string) {
return httpJSONAny(r, http.StatusOK, tt.hydraClient), nil
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
})
secretRepo := &mockSecretRepo{secrets: map[string]string{}}
maps.Copy(secretRepo.secrets, tt.initialSecrets)
redisRepo := &mockRedisRepo{data: map[string]string{}}
maps.Copy(redisRepo.data, tt.initialRedis)
h := &handler.DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
PublicURL: "http://hydra.public",
HTTPClient: &http.Client{Transport: transport},
},
SecretRepo: secretRepo,
Redis: redisRepo,
}
app := newDevHandlerApp(h)
clientID := tt.hydraClient["client_id"].(string)
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/"+clientID, nil)
resp, err := app.Test(req, -1)
if err != nil {
t.Fatalf("get request failed: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected get 200, got %d", resp.StatusCode)
}
if secret := decodeClientSecret(t, resp); secret != tt.expectedSecret {
t.Fatalf("expected secret %q, got %q", tt.expectedSecret, secret)
}
if tt.expectRedisAfterSet {
redisSecret, err := redisRepo.Get("client_secret:" + clientID)
if err != nil {
t.Fatalf("expected warmed redis secret, got error: %v", err)
}
if redisSecret != tt.expectedRedisAfter {
t.Fatalf("expected warmed redis secret %q, got %q", tt.expectedRedisAfter, redisSecret)
}
}
})
}
}
func TestCreateClient_PersistsSecretForLaterDetailFetch(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.Method == http.MethodPost && r.URL.Path == "/clients" {
return httpJSONAny(r, http.StatusCreated, map[string]any{
"client_id": "client-created",
"client_name": "Created App",
"client_secret": "created-secret",
"redirect_uris": []string{"https://created.example.com/callback"},
"grant_types": []string{"authorization_code", "refresh_token"},
"response_types": []string{"code"},
"scope": "openid profile",
"token_endpoint_auth_method": "client_secret_basic",
"metadata": map[string]any{"status": "active"},
}), nil
}
if r.Method == http.MethodGet && r.URL.Path == "/clients/client-created" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "client-created",
"client_name": "Created App",
"redirect_uris": []string{"https://created.example.com/callback"},
"grant_types": []string{"authorization_code", "refresh_token"},
"response_types": []string{"code"},
"scope": "openid profile",
"token_endpoint_auth_method": "client_secret_basic",
"metadata": map[string]any{"status": "active"},
}), nil
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
})
secretRepo := &mockSecretRepo{secrets: make(map[string]string)}
redisRepo := &mockRedisRepo{data: make(map[string]string)}
h := &handler.DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
PublicURL: "http://hydra.public",
HTTPClient: &http.Client{Transport: transport},
},
SecretRepo: secretRepo,
Redis: redisRepo,
}
app := newDevHandlerApp(h)
createBody, _ := json.Marshal(map[string]any{
"name": "Created App",
"type": "private",
"redirectUris": []string{"https://created.example.com/callback"},
"scopes": []string{"openid", "profile"},
})
createReq := httptest.NewRequest(http.MethodPost, "/api/v1/dev/clients", bytes.NewReader(createBody))
createReq.Header.Set("Content-Type", "application/json")
createResp, err := app.Test(createReq, -1)
if err != nil {
t.Fatalf("create request failed: %v", err)
}
if createResp.StatusCode != http.StatusCreated {
t.Fatalf("expected create 201, got %d", createResp.StatusCode)
}
if secret := decodeClientSecret(t, createResp); secret != "created-secret" {
t.Fatalf("expected create secret created-secret, got %q", secret)
}
storedSecret, _ := secretRepo.GetByID(context.Background(), "client-created")
if storedSecret != "created-secret" {
t.Fatalf("expected postgres secret created-secret, got %q", storedSecret)
}
redisSecret, err := redisRepo.Get("client_secret:client-created")
if err != nil {
t.Fatalf("expected redis secret after create, got error: %v", err)
}
if redisSecret != "created-secret" {
t.Fatalf("expected redis secret created-secret, got %q", redisSecret)
}
getReq := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-created", nil)
getResp, err := app.Test(getReq, -1)
if err != nil {
t.Fatalf("get request failed: %v", err)
}
if getResp.StatusCode != http.StatusOK {
t.Fatalf("expected get 200, got %d", getResp.StatusCode)
}
if secret := decodeClientSecret(t, getResp); secret != "created-secret" {
t.Fatalf("expected detail secret created-secret, got %q", secret)
}
}
func TestRotateClientSecret_PersistsForLaterDetailFetch(t *testing.T) {
getCount := 0
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.Method == http.MethodGet && r.URL.Path == "/clients/client-rotate" {
getCount++
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "client-rotate",
"client_name": "Rotate App",
"redirect_uris": []string{"https://rotate.example.com/callback"},
"grant_types": []string{"authorization_code", "refresh_token"},
"response_types": []string{"code"},
"scope": "openid profile",
"token_endpoint_auth_method": "client_secret_basic",
"metadata": map[string]any{"status": "active"},
}), nil
}
if r.Method == http.MethodPut && r.URL.Path == "/clients/client-rotate" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "client-rotate",
"client_name": "Rotate App",
"redirect_uris": []string{"https://rotate.example.com/callback"},
"grant_types": []string{"authorization_code", "refresh_token"},
"response_types": []string{"code"},
"scope": "openid profile",
"token_endpoint_auth_method": "client_secret_basic",
"metadata": map[string]any{"status": "active"},
}), nil
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
})
secretRepo := &mockSecretRepo{secrets: make(map[string]string)}
redisRepo := &mockRedisRepo{data: make(map[string]string)}
h := &handler.DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
PublicURL: "http://hydra.public",
HTTPClient: &http.Client{Transport: transport},
},
SecretRepo: secretRepo,
Redis: redisRepo,
}
app := newDevHandlerApp(h)
rotateReq := httptest.NewRequest(http.MethodPost, "/api/v1/dev/clients/client-rotate/secret/rotate", nil)
rotateResp, err := app.Test(rotateReq, -1)
if err != nil {
t.Fatalf("rotate request failed: %v", err)
}
if rotateResp.StatusCode != http.StatusOK {
t.Fatalf("expected rotate 200, got %d", rotateResp.StatusCode)
}
rotatedSecret := decodeClientSecret(t, rotateResp)
if rotatedSecret == "" {
t.Fatalf("expected rotated secret to be present")
}
storedSecret, _ := secretRepo.GetByID(context.Background(), "client-rotate")
if storedSecret != rotatedSecret {
t.Fatalf("expected postgres secret %q, got %q", rotatedSecret, storedSecret)
}
redisSecret, err := redisRepo.Get("client_secret:client-rotate")
if err != nil {
t.Fatalf("expected redis secret after rotate, got error: %v", err)
}
if redisSecret != rotatedSecret {
t.Fatalf("expected redis secret %q, got %q", rotatedSecret, redisSecret)
}
getReq := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-rotate", nil)
getResp, err := app.Test(getReq, -1)
if err != nil {
t.Fatalf("get request failed: %v", err)
}
if getResp.StatusCode != http.StatusOK {
t.Fatalf("expected get 200, got %d", getResp.StatusCode)
}
if secret := decodeClientSecret(t, getResp); secret != rotatedSecret {
t.Fatalf("expected detail secret %q, got %q", rotatedSecret, secret)
}
}
func TestUpdateClient_HeadlessLoginSecretPersistsForLaterDetailFetch(t *testing.T) {
getCount := 0
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.Method == http.MethodGet && r.URL.Path == "/clients/client-headless-login" {
getCount++
if getCount == 1 {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "client-headless-login",
"client_name": "Headless Login Before",
"redirect_uris": []string{"https://before.example.com/callback"},
"grant_types": []string{"authorization_code", "refresh_token"},
"response_types": []string{"code"},
"scope": "openid profile",
"token_endpoint_auth_method": "none",
"metadata": map[string]any{
"status": "active",
},
}), nil
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "client-headless-login",
"client_name": "Headless Login After",
"redirect_uris": []string{"https://headless.example.com/callback"},
"grant_types": []string{"authorization_code", "refresh_token"},
"response_types": []string{"code"},
"scope": "openid profile",
"token_endpoint_auth_method": "private_key_jwt",
"jwks_uri": "https://headless.example.com/jwks.json",
"metadata": map[string]any{
"status": "active",
"headless_login_enabled": true,
"request_object_signing_alg": "RS256",
},
}), nil
}
if r.Method == http.MethodPut && r.URL.Path == "/clients/client-headless-login" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "client-headless-login",
"client_name": "Headless Login After",
"client_secret": "headless-secret",
"redirect_uris": []string{"https://headless.example.com/callback"},
"grant_types": []string{"authorization_code", "refresh_token"},
"response_types": []string{"code"},
"scope": "openid profile",
"token_endpoint_auth_method": "private_key_jwt",
"jwks_uri": "https://headless.example.com/jwks.json",
"metadata": map[string]any{
"status": "active",
"headless_login_enabled": true,
"request_object_signing_alg": "RS256",
},
}), nil
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
})
secretRepo := &mockSecretRepo{secrets: make(map[string]string)}
redisRepo := &mockRedisRepo{data: make(map[string]string)}
h := &handler.DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
PublicURL: "http://hydra.public",
HTTPClient: &http.Client{Transport: transport},
},
SecretRepo: secretRepo,
Redis: redisRepo,
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "test-user", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Put("/api/v1/dev/clients/:id", h.UpdateClient)
app.Get("/api/v1/dev/clients/:id", h.GetClient)
updateBody, _ := json.Marshal(map[string]any{
"name": "Headless Login After",
"redirectUris": []string{"https://headless.example.com/callback"},
"tokenEndpointAuthMethod": "private_key_jwt",
"jwksUri": "https://headless.example.com/jwks.json",
"metadata": map[string]any{
"headless_login_enabled": true,
"request_object_signing_alg": "RS256",
},
})
updateReq := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-headless-login", bytes.NewReader(updateBody))
updateReq.Header.Set("Content-Type", "application/json")
updateResp, err := app.Test(updateReq, -1)
if err != nil {
t.Fatalf("update request failed: %v", err)
}
if updateResp.StatusCode != http.StatusOK {
t.Fatalf("expected update 200, got %d", updateResp.StatusCode)
}
storedSecret, _ := secretRepo.GetByID(context.Background(), "client-headless-login")
if storedSecret != "headless-secret" {
t.Fatalf("expected postgres secret headless-secret, got %q", storedSecret)
}
redisSecret, err := redisRepo.Get("client_secret:client-headless-login")
if err != nil {
t.Fatalf("expected redis secret, got error: %v", err)
}
if redisSecret != "headless-secret" {
t.Fatalf("expected redis secret headless-secret, got %q", redisSecret)
}
getReq := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-headless-login", nil)
getResp, err := app.Test(getReq, -1)
if err != nil {
t.Fatalf("get request failed: %v", err)
}
if getResp.StatusCode != http.StatusOK {
t.Fatalf("expected get 200, got %d", getResp.StatusCode)
}
var payload struct {
Client struct {
ClientSecret string `json:"clientSecret"`
} `json:"client"`
}
if err := json.NewDecoder(getResp.Body).Decode(&payload); err != nil {
t.Fatalf("decode response: %v", err)
}
if payload.Client.ClientSecret != "headless-secret" {
t.Fatalf("expected detail secret headless-secret, got %q", payload.Client.ClientSecret)
}
}

View File

@@ -0,0 +1,305 @@
package idp
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"errors"
"fmt"
"log/slog"
"net/http"
"os"
"strings"
)
// Ory 계열(kratos/hydra) 공급자 문자열을 정규화하기 위한 매핑.
var providerAliases = map[string]string{
"ory": "ory",
"hydra": "ory",
"kratos": "ory",
"ory-kratos": "ory",
"ory_hydra": "ory",
"ory_kratos": "ory",
}
// getEnv는 환경 변수를 읽거나 대체 값을 반환하는 헬퍼 함수입니다.
func getEnv(key, fallback string) string {
if value, ok := os.LookupEnv(key); ok {
return value
}
return fallback
}
// InitializeProvider는 환경 설정을 기반으로 IDP 공급자를 생성하고 반환합니다.
// 이것은 IdentityProvider 인터페이스의 팩토리 역할을 합니다.
func InitializeProvider() (domain.IdentityProvider, error) {
rawProviders := getEnv("IDP_PROVIDER", "ory")
providers := strings.Split(rawProviders, ",")
slog.Info("Initializing IDP chain", "providers", rawProviders)
var initialized []domain.IdentityProvider
for _, p := range providers {
providerName := strings.TrimSpace(strings.ToLower(p))
if canonical, ok := providerAliases[providerName]; ok {
providerName = canonical
}
switch providerName {
case "ory":
// Kratos/Hydra 주 공급자
oryProvider := service.NewOryProvider()
initialized = append(initialized, oryProvider)
default:
// 알 수 없는 공급자는 건너뛰고 다음 후보를 시도
slog.Warn("Skipping unsupported IDP provider entry", "provider", providerName)
}
}
if len(initialized) == 0 {
return nil, fmt.Errorf("no valid IDP_PROVIDER entries configured from: %s", rawProviders)
}
if len(initialized) == 1 {
slog.Info("Initialized IDP provider", "provider", initialized[0].Name())
return initialized[0], nil
}
chain := newChainedProvider(initialized)
slog.Info("Initialized IDP provider chain", "providers", chain.Name())
return chain, nil
}
// newChainedProvider는 우선순위 순으로 IDP를 시도하는 체인을 생성합니다.
func newChainedProvider(providers []domain.IdentityProvider) domain.IdentityProvider {
names := make([]string, len(providers))
for i, p := range providers {
names[i] = p.Name()
}
return &chainedProvider{
providers: providers,
names: names,
}
}
// chainedProvider는 다중 IDP를 우선순위대로 호출하며 실패 시 폴백합니다.
type chainedProvider struct {
providers []domain.IdentityProvider
names []string
}
func (c *chainedProvider) Name() string {
return strings.Join(c.names, " > ")
}
func (c *chainedProvider) GetMetadata() (*domain.IDPMetadata, error) {
supported := make([]string, 0)
seen := make(map[string]bool)
for _, p := range c.providers {
meta, err := p.GetMetadata()
if err != nil {
return nil, fmt.Errorf("failed to fetch metadata from %s: %w", p.Name(), err)
}
for _, field := range meta.SupportedFields {
if !seen[field] {
seen[field] = true
supported = append(supported, field)
}
}
}
return &domain.IDPMetadata{SupportedFields: supported}, nil
}
func (c *chainedProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
for _, p := range c.providers {
id, err := p.CreateUser(user, password)
if err != nil {
if errors.Is(err, domain.ErrNotSupported) {
continue
}
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "CreateUser", "error", err)
return "", err
}
return id, nil
}
return "", domain.ErrNotSupported
}
func (c *chainedProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
for _, p := range c.providers {
info, err := p.SignIn(loginID, password)
if err != nil {
if errors.Is(err, domain.ErrNotSupported) {
continue
}
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "SignIn", "error", err)
return nil, err
}
return info, nil
}
return nil, domain.ErrNotSupported
}
func (c *chainedProvider) UserExists(loginID string) (bool, error) {
var errs []error
for _, p := range c.providers {
exists, err := p.UserExists(loginID)
if err != nil {
if errors.Is(err, domain.ErrNotSupported) {
continue
}
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
continue
}
if exists {
return true, nil
}
}
if len(errs) == 0 {
return false, nil
}
return false, fmt.Errorf("all IDP providers failed for UserExists: %w", errors.Join(errs...))
}
func (c *chainedProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
var errs []error
for idx, p := range c.providers {
info, err := p.IssueSession(loginID)
if err != nil {
if errors.Is(err, domain.ErrNotSupported) {
continue
}
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "IssueSession", "error", err)
continue
}
if idx > 0 {
slog.Info("IDP fallback succeeded", "operation", "IssueSession", "provider", p.Name())
}
return info, nil
}
if len(errs) == 0 {
return nil, domain.ErrNotSupported
}
return nil, fmt.Errorf("all IDP providers failed for IssueSession: %w", errors.Join(errs...))
}
func (c *chainedProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
var errs []error
for idx, p := range c.providers {
info, err := p.InitiateLinkLogin(loginID, returnTo)
if err != nil {
if errors.Is(err, domain.ErrNotSupported) {
continue
}
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "InitiateLinkLogin", "error", err)
continue
}
if idx > 0 {
slog.Info("IDP fallback succeeded", "operation", "InitiateLinkLogin", "provider", p.Name())
}
return info, nil
}
if len(errs) == 0 {
return nil, domain.ErrNotSupported
}
return nil, fmt.Errorf("all IDP providers failed for InitiateLinkLogin: %w", errors.Join(errs...))
}
func (c *chainedProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
var errs []error
for idx, p := range c.providers {
info, err := p.VerifyLoginCode(loginID, flowID, code)
if err != nil {
if errors.Is(err, domain.ErrNotSupported) {
continue
}
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "VerifyLoginCode", "error", err)
continue
}
if idx > 0 {
slog.Info("IDP fallback succeeded", "operation", "VerifyLoginCode", "provider", p.Name())
}
return info, nil
}
if len(errs) == 0 {
return nil, domain.ErrNotSupported
}
return nil, fmt.Errorf("all IDP providers failed for VerifyLoginCode: %w", errors.Join(errs...))
}
func (c *chainedProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
var errs []error
for _, p := range c.providers {
policy, err := p.GetPasswordPolicy()
if err != nil {
if errors.Is(err, domain.ErrNotSupported) {
continue
}
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
continue
}
if policy != nil {
return policy, nil
}
}
if len(errs) == 0 {
return nil, domain.ErrNotSupported
}
return nil, fmt.Errorf("all IDP providers failed for GetPasswordPolicy: %w", errors.Join(errs...))
}
func (c *chainedProvider) InitiatePasswordReset(loginID, redirectUrl string) error {
return c.tryProviders("InitiatePasswordReset", func(p domain.IdentityProvider) error {
return p.InitiatePasswordReset(loginID, redirectUrl)
})
}
func (c *chainedProvider) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
var errs []error
for idx, p := range c.providers {
info, err := p.VerifyPasswordResetToken(token)
if err != nil {
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
slog.Warn("IDP VerifyPasswordResetToken failed", "provider", p.Name(), "error", err)
continue
}
if idx > 0 {
slog.Info("IDP fallback succeeded", "operation", "VerifyPasswordResetToken", "provider", p.Name())
}
return info, nil
}
if len(errs) == 0 {
return nil, fmt.Errorf("no IDP providers available for VerifyPasswordResetToken")
}
return nil, fmt.Errorf("all IDP providers failed for VerifyPasswordResetToken: %w", errors.Join(errs...))
}
func (c *chainedProvider) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
return c.tryProviders("UpdateUserPassword", func(p domain.IdentityProvider) error {
return p.UpdateUserPassword(loginID, newPassword, r)
})
}
func (c *chainedProvider) tryProviders(operation string, fn func(domain.IdentityProvider) error) error {
var errs []error
for idx, p := range c.providers {
if err := fn(p); err != nil {
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", operation, "error", err)
continue
}
if idx > 0 {
slog.Info("IDP fallback succeeded", "operation", operation, "provider", p.Name())
}
return nil
}
if len(errs) == 0 {
return fmt.Errorf("no IDP providers available for %s", operation)
}
return fmt.Errorf("all IDP providers failed for %s: %w", operation, errors.Join(errs...))
}

View File

@@ -0,0 +1,167 @@
package idp
import (
"baron-sso-backend/internal/domain"
"errors"
"net/http"
"reflect"
"strings"
"testing"
)
type stubProvider struct {
name string
metadata []string
createErr error
initiateErr error
verifyErr error
updateErr error
signInErr error
userExistsErr error
issueErr error
linkInitErr error
verifyCodeErr error
policyErr error
initiateCalls int
verifyCalls int
updateCalls int
signInCalls int
createCalls int
userExistsCalls int
issueCalls int
linkInitCalls int
verifyCodeCalls int
policyCalls int
verifyResponse *domain.AuthInfo
userExists bool
policy *domain.PasswordPolicy
}
func (s *stubProvider) Name() string { return s.name }
func (s *stubProvider) GetMetadata() (*domain.IDPMetadata, error) {
return &domain.IDPMetadata{SupportedFields: s.metadata}, nil
}
func (s *stubProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
s.createCalls++
if s.createErr != nil {
return "", s.createErr
}
return "created-id", nil
}
func (s *stubProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
s.signInCalls++
if s.signInErr != nil {
return nil, s.signInErr
}
return &domain.AuthInfo{Subject: "subject-123"}, nil
}
func (s *stubProvider) UserExists(loginID string) (bool, error) {
s.userExistsCalls++
if s.userExistsErr != nil {
return false, s.userExistsErr
}
return s.userExists, nil
}
func (s *stubProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
s.issueCalls++
if s.issueErr != nil {
return nil, s.issueErr
}
return &domain.AuthInfo{Subject: "issue-subject"}, nil
}
func (s *stubProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
s.linkInitCalls++
if s.linkInitErr != nil {
return nil, s.linkInitErr
}
return &domain.LinkLoginInit{FlowID: "flow-123", Mode: "cookie"}, nil
}
func (s *stubProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
s.verifyCodeCalls++
if s.verifyCodeErr != nil {
return nil, s.verifyCodeErr
}
return &domain.AuthInfo{Subject: "verify-code-subject"}, nil
}
func (s *stubProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
s.policyCalls++
if s.policyErr != nil {
return nil, s.policyErr
}
return s.policy, nil
}
func (s *stubProvider) InitiatePasswordReset(loginID, redirectUrl string) error {
s.initiateCalls++
return s.initiateErr
}
func (s *stubProvider) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
s.verifyCalls++
if s.verifyErr != nil {
return nil, s.verifyErr
}
if s.verifyResponse != nil {
return s.verifyResponse, nil
}
return &domain.AuthInfo{}, nil
}
func (s *stubProvider) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
s.updateCalls++
return s.updateErr
}
func TestChainedProviderMetadataUnion(t *testing.T) {
p1 := &stubProvider{name: "primary", metadata: []string{"id", "email"}}
p2 := &stubProvider{name: "backup", metadata: []string{"email", "phone_number", "grade"}}
chain := newChainedProvider([]domain.IdentityProvider{p1, p2})
meta, err := chain.GetMetadata()
if err != nil {
t.Fatalf("GetMetadata returned error: %v", err)
}
expected := []string{"id", "email", "phone_number", "grade"}
if !reflect.DeepEqual(meta.SupportedFields, expected) {
t.Fatalf("metadata mismatch: got %v, want %v", meta.SupportedFields, expected)
}
}
func TestChainedProviderUpdateUserPasswordFallback(t *testing.T) {
p1 := &stubProvider{name: "primary", metadata: []string{"id"}, updateErr: errors.New("boom")}
p2 := &stubProvider{name: "backup", metadata: []string{"id"}}
chain := newChainedProvider([]domain.IdentityProvider{p1, p2})
if err := chain.UpdateUserPassword("user@example.com", "Sup3r!Pass123", nil); err != nil {
t.Fatalf("expected fallback to succeed, got error: %v", err)
}
if p1.updateCalls != 1 || p2.updateCalls != 1 {
t.Fatalf("unexpected call counts: p1=%d p2=%d", p1.updateCalls, p2.updateCalls)
}
}
func TestChainedProviderUpdateUserPasswordAllFail(t *testing.T) {
p1 := &stubProvider{name: "primary", metadata: []string{"id"}, updateErr: errors.New("fail1")}
p2 := &stubProvider{name: "backup", metadata: []string{"id"}, updateErr: errors.New("fail2")}
chain := newChainedProvider([]domain.IdentityProvider{p1, p2})
err := chain.UpdateUserPassword("user@example.com", "Sup3r!Pass123", nil)
if err == nil {
t.Fatalf("expected error when all providers fail")
}
if !strings.Contains(err.Error(), "all IDP providers failed") {
t.Fatalf("unexpected error: %v", err)
}
if p1.updateCalls != 1 || p2.updateCalls != 1 {
t.Fatalf("unexpected call counts: p1=%d p2=%d", p1.updateCalls, p2.updateCalls)
}
}

Some files were not shown because too many files have changed in this diff Show More