첫 커밋: 로컬 프로젝트 업로드
This commit is contained in:
203
baron-sso/backend/internal/bootstrap/admin_account.go
Normal file
203
baron-sso/backend/internal/bootstrap/admin_account.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/repository"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type SuperAdminIdentityAdmin interface {
|
||||
FindIdentityIDByIdentifier(ctx context.Context, identifier string) (string, error)
|
||||
CreateUser(ctx context.Context, user *domain.BrokerUser, password string) (string, error)
|
||||
UpdateIdentityPassword(ctx context.Context, identityID, newPassword string) error
|
||||
}
|
||||
|
||||
type SuperAdminStore interface {
|
||||
FindUserByEmail(ctx context.Context, email string) (*domain.User, error)
|
||||
CreateUser(ctx context.Context, user *domain.User) error
|
||||
UpdateUserSuperAdmin(ctx context.Context, userID string, name string) (*domain.User, error)
|
||||
EnqueueSuperAdminRelation(ctx context.Context, userID string) error
|
||||
}
|
||||
|
||||
type EnsureSuperAdminOptions struct {
|
||||
Email string
|
||||
Password string
|
||||
Name string
|
||||
Source string
|
||||
UpdatePassword bool
|
||||
}
|
||||
|
||||
type EnsureSuperAdminResult struct {
|
||||
Email string
|
||||
IdentityID string
|
||||
LocalUserID string
|
||||
IdentityCreated bool
|
||||
PasswordUpdated bool
|
||||
LocalUserCreated bool
|
||||
LocalUserUpdated bool
|
||||
KetoRelationQueued bool
|
||||
}
|
||||
|
||||
func EnsureSuperAdmin(ctx context.Context, identityAdmin SuperAdminIdentityAdmin, store SuperAdminStore, opts EnsureSuperAdminOptions) (EnsureSuperAdminResult, error) {
|
||||
email := strings.ToLower(strings.TrimSpace(opts.Email))
|
||||
name := strings.TrimSpace(opts.Name)
|
||||
if name == "" {
|
||||
name = "System Admin"
|
||||
}
|
||||
source := strings.TrimSpace(opts.Source)
|
||||
if source == "" {
|
||||
source = "admin_cli"
|
||||
}
|
||||
result := EnsureSuperAdminResult{Email: email}
|
||||
|
||||
if _, err := mail.ParseAddress(email); err != nil {
|
||||
return result, fmt.Errorf("invalid admin email: %w", err)
|
||||
}
|
||||
if identityAdmin == nil {
|
||||
return result, errors.New("identity admin is required")
|
||||
}
|
||||
if store == nil {
|
||||
return result, errors.New("super admin store is required")
|
||||
}
|
||||
|
||||
identityID, err := identityAdmin.FindIdentityIDByIdentifier(ctx, email)
|
||||
if err != nil {
|
||||
return result, fmt.Errorf("find admin identity: %w", err)
|
||||
}
|
||||
if identityID == "" {
|
||||
if strings.TrimSpace(opts.Password) == "" {
|
||||
return result, errors.New("admin password is required to create identity")
|
||||
}
|
||||
identityID, err = identityAdmin.CreateUser(ctx, buildSuperAdminBrokerUser(email, name), opts.Password)
|
||||
if err != nil {
|
||||
return result, fmt.Errorf("create admin identity: %w", err)
|
||||
}
|
||||
result.IdentityCreated = true
|
||||
} else if opts.UpdatePassword {
|
||||
if strings.TrimSpace(opts.Password) == "" {
|
||||
return result, errors.New("admin password is required to update identity password")
|
||||
}
|
||||
if err := identityAdmin.UpdateIdentityPassword(ctx, identityID, opts.Password); err != nil {
|
||||
return result, fmt.Errorf("update admin identity password: %w", err)
|
||||
}
|
||||
result.PasswordUpdated = true
|
||||
}
|
||||
result.IdentityID = identityID
|
||||
|
||||
user, err := store.FindUserByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return result, fmt.Errorf("find local admin user: %w", err)
|
||||
}
|
||||
if user == nil {
|
||||
if identityID == "" {
|
||||
return result, errors.New("identity id is required to create local admin user")
|
||||
}
|
||||
user = &domain.User{
|
||||
ID: identityID,
|
||||
Email: email,
|
||||
Name: name,
|
||||
Role: domain.RoleSuperAdmin,
|
||||
Status: domain.UserStatusActive,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
Metadata: domain.JSONMap{
|
||||
"source": source,
|
||||
},
|
||||
}
|
||||
if err := store.CreateUser(ctx, user); err != nil {
|
||||
return result, fmt.Errorf("create local admin user: %w", err)
|
||||
}
|
||||
result.LocalUserCreated = true
|
||||
} else if domain.NormalizeRole(user.Role) != domain.RoleSuperAdmin || user.Status != domain.UserStatusActive || (name != "" && user.Name != name) {
|
||||
user, err = store.UpdateUserSuperAdmin(ctx, user.ID, name)
|
||||
if err != nil {
|
||||
return result, fmt.Errorf("update local admin user: %w", err)
|
||||
}
|
||||
result.LocalUserUpdated = true
|
||||
}
|
||||
result.LocalUserID = user.ID
|
||||
|
||||
if err := store.EnqueueSuperAdminRelation(ctx, user.ID); err != nil {
|
||||
return result, fmt.Errorf("enqueue super admin keto relation: %w", err)
|
||||
}
|
||||
result.KetoRelationQueued = true
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func buildSuperAdminBrokerUser(email, name string) *domain.BrokerUser {
|
||||
return &domain.BrokerUser{
|
||||
Email: email,
|
||||
Name: name,
|
||||
PhoneNumber: "",
|
||||
Attributes: map[string]any{
|
||||
"department": "Admin",
|
||||
"affiliationType": "internal",
|
||||
"grade": "",
|
||||
"role": domain.RoleSuperAdmin,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type gormSuperAdminStore struct {
|
||||
db *gorm.DB
|
||||
outbox repository.KetoOutboxRepository
|
||||
}
|
||||
|
||||
func NewGormSuperAdminStore(db *gorm.DB, outbox repository.KetoOutboxRepository) SuperAdminStore {
|
||||
return &gormSuperAdminStore{db: db, outbox: outbox}
|
||||
}
|
||||
|
||||
func (s *gormSuperAdminStore) FindUserByEmail(ctx context.Context, email string) (*domain.User, error) {
|
||||
var user domain.User
|
||||
if err := s.db.WithContext(ctx).Where("email = ?", email).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *gormSuperAdminStore) CreateUser(ctx context.Context, user *domain.User) error {
|
||||
return s.db.WithContext(ctx).Create(user).Error
|
||||
}
|
||||
|
||||
func (s *gormSuperAdminStore) UpdateUserSuperAdmin(ctx context.Context, userID string, name string) (*domain.User, error) {
|
||||
updates := map[string]any{
|
||||
"role": domain.RoleSuperAdmin,
|
||||
"status": domain.UserStatusActive,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
if strings.TrimSpace(name) != "" {
|
||||
updates["name"] = strings.TrimSpace(name)
|
||||
}
|
||||
if err := s.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", userID).Updates(updates).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var user domain.User
|
||||
if err := s.db.WithContext(ctx).Where("id = ?", userID).First(&user).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *gormSuperAdminStore) EnqueueSuperAdminRelation(ctx context.Context, userID string) error {
|
||||
if s.outbox == nil {
|
||||
return nil
|
||||
}
|
||||
return s.outbox.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "System",
|
||||
Object: "global",
|
||||
Relation: "super_admins",
|
||||
Subject: "User:" + userID,
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
})
|
||||
}
|
||||
157
baron-sso/backend/internal/bootstrap/admin_account_test.go
Normal file
157
baron-sso/backend/internal/bootstrap/admin_account_test.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEnsureSuperAdminCreatesIdentityLocalUserAndKetoRelation(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
identityAdmin := &fakeSuperAdminIdentityAdmin{createdID: "identity-1"}
|
||||
store := &fakeSuperAdminStore{}
|
||||
|
||||
result, err := EnsureSuperAdmin(ctx, identityAdmin, store, EnsureSuperAdminOptions{
|
||||
Email: "new-admin@example.com",
|
||||
Password: "Password!123",
|
||||
Name: "New Admin",
|
||||
Source: "test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("EnsureSuperAdmin returned error: %v", err)
|
||||
}
|
||||
if !result.IdentityCreated {
|
||||
t.Fatal("identity must be created")
|
||||
}
|
||||
if !result.LocalUserCreated {
|
||||
t.Fatal("local user must be created")
|
||||
}
|
||||
if result.IdentityID != "identity-1" {
|
||||
t.Fatalf("identity ID = %q, want identity-1", result.IdentityID)
|
||||
}
|
||||
if store.user == nil {
|
||||
t.Fatal("local user was not stored")
|
||||
}
|
||||
if store.user.Email != "new-admin@example.com" {
|
||||
t.Fatalf("local user email = %q", store.user.Email)
|
||||
}
|
||||
if store.user.Role != domain.RoleSuperAdmin {
|
||||
t.Fatalf("local user role = %q, want %q", store.user.Role, domain.RoleSuperAdmin)
|
||||
}
|
||||
if len(store.ketoSubjects) != 1 || store.ketoSubjects[0] != "User:identity-1" {
|
||||
t.Fatalf("keto subjects = %#v, want User:identity-1", store.ketoSubjects)
|
||||
}
|
||||
if identityAdmin.createdUser == nil || identityAdmin.createdUser.Attributes["role"] != domain.RoleSuperAdmin {
|
||||
t.Fatalf("created identity attributes = %#v", identityAdmin.createdUser)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureSuperAdminPromotesExistingLocalUser(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
identityAdmin := &fakeSuperAdminIdentityAdmin{existingID: "identity-1"}
|
||||
store := &fakeSuperAdminStore{
|
||||
user: &domain.User{
|
||||
ID: "local-user-1",
|
||||
Email: "existing@example.com",
|
||||
Name: "Existing",
|
||||
Role: domain.RoleUser,
|
||||
Status: domain.UserStatusPreboarding,
|
||||
},
|
||||
}
|
||||
|
||||
result, err := EnsureSuperAdmin(ctx, identityAdmin, store, EnsureSuperAdminOptions{
|
||||
Email: "existing@example.com",
|
||||
Password: "Password!123",
|
||||
Name: "Existing Admin",
|
||||
Source: "test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("EnsureSuperAdmin returned error: %v", err)
|
||||
}
|
||||
if result.IdentityCreated {
|
||||
t.Fatal("existing identity must not be recreated")
|
||||
}
|
||||
if !result.LocalUserUpdated {
|
||||
t.Fatal("local user must be promoted")
|
||||
}
|
||||
if store.user.Role != domain.RoleSuperAdmin {
|
||||
t.Fatalf("local user role = %q, want %q", store.user.Role, domain.RoleSuperAdmin)
|
||||
}
|
||||
if store.user.Status != domain.UserStatusActive {
|
||||
t.Fatalf("local user status = %q, want %q", store.user.Status, domain.UserStatusActive)
|
||||
}
|
||||
if len(store.ketoSubjects) != 1 || store.ketoSubjects[0] != "User:local-user-1" {
|
||||
t.Fatalf("keto subjects = %#v, want User:local-user-1", store.ketoSubjects)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureSuperAdminRequiresPasswordForNewIdentity(t *testing.T) {
|
||||
_, err := EnsureSuperAdmin(context.Background(), &fakeSuperAdminIdentityAdmin{}, &fakeSuperAdminStore{}, EnsureSuperAdminOptions{
|
||||
Email: "new-admin@example.com",
|
||||
Name: "New Admin",
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
}
|
||||
|
||||
type fakeSuperAdminIdentityAdmin struct {
|
||||
existingID string
|
||||
createdID string
|
||||
createdUser *domain.BrokerUser
|
||||
createdSecret string
|
||||
}
|
||||
|
||||
func (f *fakeSuperAdminIdentityAdmin) FindIdentityIDByIdentifier(ctx context.Context, identifier string) (string, error) {
|
||||
return f.existingID, nil
|
||||
}
|
||||
|
||||
func (f *fakeSuperAdminIdentityAdmin) CreateUser(ctx context.Context, user *domain.BrokerUser, password string) (string, error) {
|
||||
if f.createdID == "" {
|
||||
return "", errors.New("created id is not configured")
|
||||
}
|
||||
f.createdUser = user
|
||||
f.createdSecret = password
|
||||
return f.createdID, nil
|
||||
}
|
||||
|
||||
func (f *fakeSuperAdminIdentityAdmin) UpdateIdentityPassword(ctx context.Context, identityID string, newPassword string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeSuperAdminStore struct {
|
||||
user *domain.User
|
||||
ketoSubjects []string
|
||||
}
|
||||
|
||||
func (f *fakeSuperAdminStore) FindUserByEmail(ctx context.Context, email string) (*domain.User, error) {
|
||||
if f.user == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return f.user, nil
|
||||
}
|
||||
|
||||
func (f *fakeSuperAdminStore) CreateUser(ctx context.Context, user *domain.User) error {
|
||||
copied := *user
|
||||
f.user = &copied
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeSuperAdminStore) UpdateUserSuperAdmin(ctx context.Context, userID string, name string) (*domain.User, error) {
|
||||
if f.user == nil {
|
||||
return nil, errors.New("user not found")
|
||||
}
|
||||
f.user.Role = domain.RoleSuperAdmin
|
||||
f.user.Status = domain.UserStatusActive
|
||||
if name != "" {
|
||||
f.user.Name = name
|
||||
}
|
||||
return f.user, nil
|
||||
}
|
||||
|
||||
func (f *fakeSuperAdminStore) EnqueueSuperAdminRelation(ctx context.Context, userID string) error {
|
||||
f.ketoSubjects = append(f.ketoSubjects, "User:"+userID)
|
||||
return nil
|
||||
}
|
||||
116
baron-sso/backend/internal/bootstrap/bootstrap.go
Normal file
116
baron-sso/backend/internal/bootstrap/bootstrap.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Run executes the application bootstrap logic (migrations, seeding, etc.)
|
||||
func Run(db *gorm.DB) error {
|
||||
slog.Info("[Bootstrap] Starting application bootstrap...")
|
||||
|
||||
// 1. Auto Migration
|
||||
if err := migrateSchemas(db); err != nil {
|
||||
return fmt.Errorf("migration failed: %w", err)
|
||||
}
|
||||
|
||||
// 2. Seed Tenants
|
||||
if err := SeedTenants(db); err != nil {
|
||||
return fmt.Errorf("tenant seeding failed: %w", err)
|
||||
}
|
||||
|
||||
// 3. Normalize staging seed/read-model data
|
||||
if err := CanonicalizeLegacyUserStatuses(db); err != nil {
|
||||
return fmt.Errorf("legacy user status canonicalization failed: %w", err)
|
||||
}
|
||||
if err := SanitizeLegacyUserMetadata(db); err != nil {
|
||||
return fmt.Errorf("legacy user metadata sanitize failed: %w", err)
|
||||
}
|
||||
|
||||
slog.Info("[Bootstrap] User seed skipped (Kratos is SoT)")
|
||||
slog.Info("[Bootstrap] Bootstrap completed successfully.")
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateSchemas(db *gorm.DB) error {
|
||||
slog.Info("[Bootstrap] Migrating database schemas...")
|
||||
if err := dropLegacyTenantDomainUniqueIndex(db); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := dropLegacyUserCompanyColumns(db); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add all domain models here
|
||||
return db.AutoMigrate(
|
||||
&domain.Tenant{},
|
||||
&domain.TenantDomain{},
|
||||
&domain.User{},
|
||||
&domain.UserLoginID{},
|
||||
&domain.UserProjectionState{},
|
||||
&domain.UserGroup{},
|
||||
&domain.ApiKey{},
|
||||
&domain.IdentityProviderConfig{},
|
||||
&domain.ClientSecret{},
|
||||
&domain.ClientConsent{},
|
||||
&domain.KetoOutbox{},
|
||||
&domain.RPUsageEvent{},
|
||||
&domain.WorksmobileOutbox{},
|
||||
&domain.WorksmobileResourceMapping{},
|
||||
&domain.SharedLink{},
|
||||
&domain.DeveloperRequest{},
|
||||
&domain.RPUserMetadata{},
|
||||
&domain.SystemSetting{},
|
||||
// &domain.RelyingParty{}, // Removed: SSOT is Hydra + Keto
|
||||
)
|
||||
}
|
||||
|
||||
func CanonicalizeLegacyUserStatuses(db *gorm.DB) error {
|
||||
if db == nil || !db.Migrator().HasTable(&domain.User{}) {
|
||||
return nil
|
||||
}
|
||||
updates := map[string]string{
|
||||
"inactive": domain.UserStatusPreboarding,
|
||||
"leave_of_absence": domain.UserStatusTemporaryLeave,
|
||||
"baron_only": domain.UserStatusBaronGuest,
|
||||
}
|
||||
for legacy, canonical := range updates {
|
||||
if err := db.Model(&domain.User{}).
|
||||
Where("status = ?", legacy).
|
||||
Update("status", canonical).Error; err != nil {
|
||||
return fmt.Errorf("failed to canonicalize users.status %s to %s: %w", legacy, canonical, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func dropLegacyUserCompanyColumns(db *gorm.DB) error {
|
||||
if !db.Migrator().HasTable(&domain.User{}) {
|
||||
return nil
|
||||
}
|
||||
for _, column := range []string{"company_code", "company_codes"} {
|
||||
if !db.Migrator().HasColumn(&domain.User{}, column) {
|
||||
continue
|
||||
}
|
||||
if err := db.Migrator().DropColumn(&domain.User{}, column); err != nil {
|
||||
return fmt.Errorf("failed to drop legacy users.%s column: %w", column, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func dropLegacyTenantDomainUniqueIndex(db *gorm.DB) error {
|
||||
if !db.Migrator().HasTable(&domain.TenantDomain{}) {
|
||||
return nil
|
||||
}
|
||||
if !db.Migrator().HasIndex(&domain.TenantDomain{}, "idx_tenant_domains_domain") {
|
||||
return nil
|
||||
}
|
||||
if err := db.Migrator().DropIndex(&domain.TenantDomain{}, "idx_tenant_domains_domain"); err != nil {
|
||||
return fmt.Errorf("failed to drop legacy tenant domain unique index: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
84
baron-sso/backend/internal/bootstrap/keto_sync.go
Normal file
84
baron-sso/backend/internal/bootstrap/keto_sync.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/repository"
|
||||
"context"
|
||||
"log/slog"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SyncKetoRelations synchronizes all existing DB users, tenants and RPs to Ory Keto via Outbox.
|
||||
// This ensures data consistency for existing data when ReBAC is introduced.
|
||||
func SyncKetoRelations(db *gorm.DB, outbox repository.KetoOutboxRepository) error {
|
||||
slog.Info("🚀 Starting Keto ReBAC relation synchronization (via Outbox)...")
|
||||
ctx := context.Background()
|
||||
|
||||
// 1. Sync All Tenants
|
||||
var tenants []domain.Tenant
|
||||
if err := db.Find(&tenants).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
slog.Info("Syncing tenants to Keto Outbox", "count", len(tenants))
|
||||
for _, t := range tenants {
|
||||
// Global Super Admin access to every tenant
|
||||
_ = outbox.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "Tenant",
|
||||
Object: t.ID,
|
||||
Relation: "admins",
|
||||
Subject: "System:global#super_admins",
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
})
|
||||
|
||||
if t.ParentID != nil {
|
||||
_ = outbox.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "Tenant",
|
||||
Object: t.ID,
|
||||
Relation: "parents",
|
||||
Subject: "Tenant:" + *t.ParentID,
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Sync All RelyingParties (if needed)
|
||||
// Note: We'll need a way to list them from Hydra or local DB if we had them.
|
||||
// Assuming they are in a table domain.RelyingParty (though it was removed, let's see)
|
||||
// Actually, the comment said SSOT is Hydra. But we might have them in a local table for metadata.
|
||||
// If not, we skip for now or fetch from Hydra.
|
||||
|
||||
// 3. Sync All Users Roles and Tenant Memberships
|
||||
var users []domain.User
|
||||
if err := db.Find(&users).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
slog.Info("Syncing users to Keto Outbox", "count", len(users))
|
||||
for _, u := range users {
|
||||
// Tenant Membership
|
||||
if u.TenantID != nil {
|
||||
_ = outbox.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "Tenant",
|
||||
Object: *u.TenantID,
|
||||
Relation: "members",
|
||||
Subject: "User:" + u.ID,
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
})
|
||||
}
|
||||
|
||||
// Roles
|
||||
role := domain.NormalizeRole(u.Role)
|
||||
if role == domain.RoleSuperAdmin {
|
||||
_ = outbox.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "System",
|
||||
Object: "global",
|
||||
Relation: "super_admins",
|
||||
Subject: "User:" + u.ID,
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
slog.Info("✅ Keto ReBAC synchronization items added to Outbox.")
|
||||
return nil
|
||||
}
|
||||
75
baron-sso/backend/internal/bootstrap/kratos_seed.go
Normal file
75
baron-sso/backend/internal/bootstrap/kratos_seed.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SeedAdminIdentity creates the initial admin identity in the configured IDP.
|
||||
// Returns the Kratos Identity ID and error.
|
||||
func SeedAdminIdentity(idp domain.IdentityProvider) (string, error) {
|
||||
if idp == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
adminEmail := strings.TrimSpace(os.Getenv("ADMIN_EMAIL"))
|
||||
adminPassword := os.Getenv("ADMIN_PASSWORD")
|
||||
if adminEmail == "" || adminPassword == "" {
|
||||
slog.Warn("[Bootstrap] ADMIN_EMAIL or ADMIN_PASSWORD not set. Skipping admin identity seed.")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
adminName := strings.TrimSpace(os.Getenv("ADMIN_NAME"))
|
||||
if adminName == "" {
|
||||
adminName = "System Admin"
|
||||
}
|
||||
|
||||
user := &domain.BrokerUser{
|
||||
Email: adminEmail,
|
||||
Name: adminName,
|
||||
PhoneNumber: "",
|
||||
Attributes: map[string]any{
|
||||
"department": "Admin",
|
||||
"affiliationType": "internal",
|
||||
"grade": "",
|
||||
"role": "super_admin", // Explicitly set role for Kratos traits
|
||||
},
|
||||
}
|
||||
|
||||
// Retry logic for Kratos connection
|
||||
maxRetries := 5
|
||||
var err error
|
||||
var identityID string
|
||||
|
||||
for i := range maxRetries {
|
||||
identityID, err = idp.CreateUser(user, adminPassword)
|
||||
if err == nil {
|
||||
slog.Info("[Bootstrap] Admin identity created in IDP", "email", adminEmail, "idp", idp.Name(), "id", identityID)
|
||||
return identityID, nil
|
||||
}
|
||||
|
||||
if strings.Contains(err.Error(), "already exists") {
|
||||
slog.Info("[Bootstrap] Admin identity already exists in IDP. Attempting to retrieve ID...", "email", adminEmail)
|
||||
// Try to sign in to get the identity ID
|
||||
authInfo, err := idp.SignIn(adminEmail, adminPassword)
|
||||
if err == nil && authInfo != nil {
|
||||
slog.Info("[Bootstrap] Retrieved existing admin identity ID", "id", authInfo.Subject)
|
||||
return authInfo.Subject, nil
|
||||
}
|
||||
slog.Warn("[Bootstrap] Failed to retrieve existing admin identity ID via SignIn", "error", err)
|
||||
return "", nil // Return nil error to avoid stopping bootstrap, but ID is missing
|
||||
}
|
||||
|
||||
slog.Warn("[Bootstrap] Failed to seed admin identity (retrying...)",
|
||||
"attempt", i+1,
|
||||
"max_retries", maxRetries,
|
||||
"error", err,
|
||||
)
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
|
||||
return "", err
|
||||
}
|
||||
77
baron-sso/backend/internal/bootstrap/sync_admin.go
Normal file
77
baron-sso/backend/internal/bootstrap/sync_admin.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SyncAdminRole updates the role of the admin user in the local DB.
|
||||
// It ensures the admin user exists in the local DB with the correct Kratos ID.
|
||||
func SyncAdminRole(db *gorm.DB, kratosID string) error {
|
||||
adminEmail := strings.TrimSpace(os.Getenv("ADMIN_EMAIL"))
|
||||
if adminEmail == "" {
|
||||
slog.Warn("[Bootstrap] ADMIN_EMAIL not set. Skipping admin role sync.")
|
||||
return nil
|
||||
}
|
||||
|
||||
adminName := strings.TrimSpace(os.Getenv("ADMIN_NAME"))
|
||||
if adminName == "" {
|
||||
adminName = "System Admin"
|
||||
}
|
||||
|
||||
// Find user by email
|
||||
var user domain.User
|
||||
if err := db.Where("email = ?", adminEmail).First(&user).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if kratosID == "" {
|
||||
slog.Warn("[Bootstrap] Admin user not found in local DB and Kratos ID is missing. Cannot create local user.", "email", adminEmail)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create new admin user in local DB
|
||||
newUser := domain.User{
|
||||
ID: kratosID,
|
||||
Email: adminEmail,
|
||||
Name: adminName,
|
||||
Role: domain.RoleSuperAdmin,
|
||||
Status: "active",
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
Metadata: domain.JSONMap{"source": "bootstrap_seed"},
|
||||
}
|
||||
if err := db.Create(&newUser).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
slog.Info("[Bootstrap] Created admin user in local DB", "email", adminEmail, "id", kratosID)
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Update role if needed
|
||||
updates := map[string]any{}
|
||||
if user.Role != domain.RoleSuperAdmin {
|
||||
updates["role"] = domain.RoleSuperAdmin
|
||||
}
|
||||
// Also ensure ID matches if it was somehow different (though changing PK is hard, at least log it)
|
||||
if kratosID != "" && user.ID != kratosID {
|
||||
slog.Warn("[Bootstrap] Admin user exists but ID mismatch with Kratos", "local_id", user.ID, "kratos_id", kratosID)
|
||||
// We generally don't change UUID PKs, just warn.
|
||||
}
|
||||
|
||||
if len(updates) > 0 {
|
||||
if err := db.Model(&user).Updates(updates).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
slog.Info("[Bootstrap] Updated admin user role to super_admin", "email", adminEmail)
|
||||
} else {
|
||||
slog.Info("[Bootstrap] Admin user already has super_admin role", "email", adminEmail)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
524
baron-sso/backend/internal/bootstrap/tenant_seed.go
Normal file
524
baron-sso/backend/internal/bootstrap/tenant_seed.go
Normal file
@@ -0,0 +1,524 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/repository"
|
||||
"baron-sso-backend/internal/service"
|
||||
"baron-sso-backend/internal/utils"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/csv"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const seedTenantCSVPathEnv = "SEED_TENANT_CSV_PATH"
|
||||
|
||||
var seedTenantCSVPathCandidates = []string{
|
||||
"adminfront/seed-tenant.csv",
|
||||
"../adminfront/seed-tenant.csv",
|
||||
"../../adminfront/seed-tenant.csv",
|
||||
"../../../adminfront/seed-tenant.csv",
|
||||
"/app/adminfront/seed-tenant.csv",
|
||||
}
|
||||
|
||||
type InitialTenantConfig struct {
|
||||
TenantID string
|
||||
Name string
|
||||
Slug string
|
||||
Type string
|
||||
ParentSlug string
|
||||
Description string
|
||||
Domains []string
|
||||
Config domain.JSONMap
|
||||
}
|
||||
|
||||
func SeedTenants(db *gorm.DB) error {
|
||||
slog.Info("[Bootstrap] Checking initial tenant seed...")
|
||||
|
||||
configs, err := loadSeedTenantConfigs()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(configs) == 0 {
|
||||
return errors.New("seed tenant csv has no tenant rows")
|
||||
}
|
||||
|
||||
existingSlugs, existingIDs, err := loadExistingTenantIdentitySet(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
missingConfigs := filterMissingSeedTenantConfigs(configs, existingSlugs, existingIDs)
|
||||
if len(missingConfigs) == 0 {
|
||||
slog.Info("[Bootstrap] Tenant seed skipped because all seed slugs already exist", "count", len(configs))
|
||||
return nil
|
||||
}
|
||||
|
||||
slog.Info(
|
||||
"[Bootstrap] Tenant seed will create missing seed tenants",
|
||||
"total", len(configs),
|
||||
"missing", len(missingConfigs),
|
||||
"existing", len(configs)-len(missingConfigs),
|
||||
)
|
||||
return seedTenantConfigs(db, missingConfigs)
|
||||
}
|
||||
|
||||
func loadExistingTenantIdentitySet(db *gorm.DB) (map[string]bool, map[string]bool, error) {
|
||||
var tenants []domain.Tenant
|
||||
if err := db.Select("id", "slug").Find(&tenants).Error; err != nil {
|
||||
return nil, nil, fmt.Errorf("load existing tenants before seed: %w", err)
|
||||
}
|
||||
|
||||
slugs := make(map[string]bool, len(tenants))
|
||||
ids := make(map[string]bool, len(tenants))
|
||||
for _, tenant := range tenants {
|
||||
slug := strings.TrimSpace(strings.ToLower(tenant.Slug))
|
||||
if slug != "" {
|
||||
slugs[slug] = true
|
||||
}
|
||||
id := strings.TrimSpace(strings.ToLower(tenant.ID))
|
||||
if id != "" {
|
||||
ids[id] = true
|
||||
}
|
||||
}
|
||||
return slugs, ids, nil
|
||||
}
|
||||
|
||||
func filterMissingSeedTenantConfigs(configs []InitialTenantConfig, existingSlugs map[string]bool, existingIDs map[string]bool) []InitialTenantConfig {
|
||||
filtered := make([]InitialTenantConfig, 0, len(configs))
|
||||
for _, config := range configs {
|
||||
slug := strings.TrimSpace(strings.ToLower(config.Slug))
|
||||
id := strings.TrimSpace(strings.ToLower(config.TenantID))
|
||||
if slug == "" || existingSlugs[slug] || (id != "" && existingIDs[id]) {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, config)
|
||||
existingSlugs[slug] = true
|
||||
if id != "" {
|
||||
existingIDs[id] = true
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func seedTenantConfigs(db *gorm.DB, configs []InitialTenantConfig) error {
|
||||
slog.Info("[Bootstrap] Seeding initial tenants from CSV...", "count", len(configs))
|
||||
repo := repository.NewTenantRepository(db)
|
||||
userRepo := repository.NewUserRepository(db)
|
||||
userGroupRepo := repository.NewUserGroupRepository(db)
|
||||
outboxRepo := repository.NewKetoOutboxRepository(db)
|
||||
svc := service.NewTenantService(repo, userRepo, userGroupRepo, outboxRepo)
|
||||
ctx := context.Background()
|
||||
|
||||
for _, config := range orderSeedTenantConfigsByParentSlug(configs) {
|
||||
tenantType := config.Type
|
||||
if tenantType == "" {
|
||||
tenantType = domain.TenantTypeCompany
|
||||
}
|
||||
|
||||
var parentID *string
|
||||
if config.ParentSlug != "" {
|
||||
parent, err := repo.FindBySlug(ctx, config.ParentSlug)
|
||||
if err != nil || parent == nil {
|
||||
if err == nil {
|
||||
err = errors.New("parent tenant not found")
|
||||
}
|
||||
slog.Error("Failed to resolve parent tenant for seed", "slug", config.Slug, "parentSlug", config.ParentSlug, "error", err)
|
||||
return fmt.Errorf("resolve parent tenant %q for seed %q: %w", config.ParentSlug, config.Slug, err)
|
||||
}
|
||||
parentID = &parent.ID
|
||||
}
|
||||
|
||||
slog.Info("[Bootstrap] Creating seed tenant", "name", config.Name, "slug", config.Slug)
|
||||
var tenant *domain.Tenant
|
||||
var err error
|
||||
if config.TenantID != "" {
|
||||
tenant, err = createSeedTenant(ctx, repo, outboxRepo, config, tenantType, parentID)
|
||||
} else {
|
||||
tenant, err = svc.RegisterTenant(ctx, config.Name, config.Slug, tenantType, config.Description, config.Domains, parentID, "")
|
||||
}
|
||||
if err != nil {
|
||||
slog.Error("Failed to seed tenant", "slug", config.Slug, "error", err)
|
||||
return err
|
||||
}
|
||||
tenant.Status = domain.TenantStatusActive
|
||||
if len(config.Config) > 0 {
|
||||
tenant.Config = config.Config
|
||||
}
|
||||
if err := db.Save(tenant).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadSeedTenantConfigs() ([]InitialTenantConfig, error) {
|
||||
path, err := findSeedTenantCSVPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open seed tenant csv %q: %w", path, err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
configs, err := parseSeedTenantCSV(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse seed tenant csv %q: %w", path, err)
|
||||
}
|
||||
return configs, nil
|
||||
}
|
||||
|
||||
func SeedTenantSlugSet() (map[string]bool, error) {
|
||||
configs, err := loadSeedTenantConfigs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
slugs := make(map[string]bool, len(configs))
|
||||
for _, config := range configs {
|
||||
slug := strings.TrimSpace(strings.ToLower(config.Slug))
|
||||
if slug != "" {
|
||||
slugs[slug] = true
|
||||
}
|
||||
}
|
||||
return slugs, nil
|
||||
}
|
||||
|
||||
func IsSeedTenantSlug(slug string) bool {
|
||||
normalized := strings.TrimSpace(strings.ToLower(slug))
|
||||
if normalized == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
slugs, err := SeedTenantSlugSet()
|
||||
if err != nil {
|
||||
slog.Warn("[Bootstrap] Failed to load seed tenant slug set", "error", err)
|
||||
return false
|
||||
}
|
||||
return slugs[normalized]
|
||||
}
|
||||
|
||||
func findSeedTenantCSVPath() (string, error) {
|
||||
if configured := strings.TrimSpace(os.Getenv(seedTenantCSVPathEnv)); configured != "" {
|
||||
return configured, nil
|
||||
}
|
||||
|
||||
for _, candidate := range seedTenantCSVPathCandidates {
|
||||
cleaned := filepath.Clean(candidate)
|
||||
if _, err := os.Stat(cleaned); err == nil {
|
||||
return cleaned, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("seed tenant csv not found; set %s or add adminfront/seed-tenant.csv", seedTenantCSVPathEnv)
|
||||
}
|
||||
|
||||
func parseSeedTenantCSV(r io.Reader) ([]InitialTenantConfig, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to read csv")
|
||||
}
|
||||
data = bytes.TrimPrefix(data, []byte{0xEF, 0xBB, 0xBF})
|
||||
|
||||
reader := csv.NewReader(bytes.NewReader(data))
|
||||
reader.FieldsPerRecord = -1
|
||||
rows, err := reader.ReadAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid csv: %w", err)
|
||||
}
|
||||
if len(rows) == 0 {
|
||||
return nil, errors.New("csv is empty")
|
||||
}
|
||||
|
||||
header := seedTenantCSVHeaderIndex(rows[0])
|
||||
for _, key := range []string{"name", "type", "slug"} {
|
||||
if _, ok := header[key]; !ok {
|
||||
return nil, fmt.Errorf("missing required column: %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
configs := make([]InitialTenantConfig, 0, len(rows)-1)
|
||||
for i, row := range rows[1:] {
|
||||
if seedTenantCSVRowIsEmpty(row) {
|
||||
continue
|
||||
}
|
||||
|
||||
name := seedTenantCSVValue(row, header, "name")
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("row %d: name is required", i+2)
|
||||
}
|
||||
|
||||
tenantType := normalizeSeedTenantType(seedTenantCSVValue(row, header, "type"))
|
||||
if tenantType == "" {
|
||||
return nil, fmt.Errorf("row %d: invalid tenant type", i+2)
|
||||
}
|
||||
|
||||
slug := utils.GenerateSlug(seedTenantCSVValue(row, header, "slug"))
|
||||
if slug == "" {
|
||||
return nil, fmt.Errorf("row %d: slug is required", i+2)
|
||||
}
|
||||
|
||||
config, err := seedTenantCSVRecordConfig(row, header)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("row %d: %w", i+2, err)
|
||||
}
|
||||
|
||||
configs = append(configs, InitialTenantConfig{
|
||||
TenantID: seedTenantCSVValue(row, header, "tenant_id"),
|
||||
Name: name,
|
||||
Type: tenantType,
|
||||
ParentSlug: seedTenantCSVValue(row, header, "parent_tenant_slug"),
|
||||
Slug: slug,
|
||||
Description: seedTenantCSVValue(row, header, "memo"),
|
||||
Domains: splitSeedTenantCSVDomains(seedTenantCSVValue(row, header, "email_domain")),
|
||||
Config: config,
|
||||
})
|
||||
}
|
||||
|
||||
return configs, nil
|
||||
}
|
||||
|
||||
func seedTenantCSVHeaderIndex(header []string) map[string]int {
|
||||
index := make(map[string]int, len(header))
|
||||
aliases := map[string]string{
|
||||
"id": "tenant_id",
|
||||
"tenantid": "tenant_id",
|
||||
"tenant_id": "tenant_id",
|
||||
"name": "name",
|
||||
"type": "type",
|
||||
"parenttenantslug": "parent_tenant_slug",
|
||||
"parent_tenant_slug": "parent_tenant_slug",
|
||||
"parent_slug": "parent_tenant_slug",
|
||||
"slug": "slug",
|
||||
"memo": "memo",
|
||||
"description": "memo",
|
||||
"email-domain": "email_domain",
|
||||
"emaildomain": "email_domain",
|
||||
"email_domain": "email_domain",
|
||||
"domain": "email_domain",
|
||||
"domains": "email_domain",
|
||||
"visibility": "visibility",
|
||||
"public_setting": "visibility",
|
||||
"publicsetting": "visibility",
|
||||
"org_unit_type": "org_unit_type",
|
||||
"orgunittype": "org_unit_type",
|
||||
"organization_type": "org_unit_type",
|
||||
"organizationtype": "org_unit_type",
|
||||
"worksmobile": "worksmobile_sync",
|
||||
"worksmobilesync": "worksmobile_sync",
|
||||
"worksmobile_sync": "worksmobile_sync",
|
||||
"works_sync": "worksmobile_sync",
|
||||
"works": "worksmobile_sync",
|
||||
}
|
||||
for i, column := range header {
|
||||
key := strings.ToLower(strings.TrimSpace(column))
|
||||
key = strings.ReplaceAll(key, " ", "_")
|
||||
if canonical, ok := aliases[key]; ok {
|
||||
index[canonical] = i
|
||||
}
|
||||
}
|
||||
return index
|
||||
}
|
||||
|
||||
func seedTenantCSVValue(row []string, header map[string]int, key string) string {
|
||||
idx, ok := header[key]
|
||||
if !ok || idx >= len(row) {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(row[idx])
|
||||
}
|
||||
|
||||
func seedTenantCSVRecordConfig(row []string, header map[string]int) (domain.JSONMap, error) {
|
||||
config := domain.JSONMap{}
|
||||
visibility := strings.TrimSpace(seedTenantCSVValue(row, header, "visibility"))
|
||||
if visibility != "" {
|
||||
normalizedVisibility, err := normalizeSeedTenantVisibility(visibility)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config["visibility"] = normalizedVisibility
|
||||
}
|
||||
orgUnitType := strings.TrimSpace(seedTenantCSVValue(row, header, "org_unit_type"))
|
||||
if orgUnitType != "" {
|
||||
if !isAllowedSeedTenantOrgUnitType(orgUnitType) {
|
||||
return nil, errors.New("orgUnitType must be one of 실, 팀, TF, TF팀, 센터, 디비전, 셀, 본부, 지역본부, 부, 임원직속")
|
||||
}
|
||||
config["orgUnitType"] = orgUnitType
|
||||
}
|
||||
if worksmobileSync := strings.TrimSpace(seedTenantCSVValue(row, header, "worksmobile_sync")); worksmobileSync != "" {
|
||||
excluded, err := normalizeSeedTenantWorksmobileExcluded(worksmobileSync)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config["worksmobileExcluded"] = excluded
|
||||
}
|
||||
if len(config) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func normalizeSeedTenantWorksmobileExcluded(value string) (bool, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||
case "", "yes", "y", "true", "1", "on", "sync", "linked", "연동":
|
||||
return false, nil
|
||||
case "no", "n", "false", "0", "off", "none", "excluded", "exclude", "not_sync", "not-synced", "미연동", "연동안함", "제외":
|
||||
return true, nil
|
||||
default:
|
||||
return false, errors.New("worksmobile_sync must be yes or no")
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeSeedTenantVisibility(value string) (string, error) {
|
||||
visibility := strings.ToLower(strings.TrimSpace(value))
|
||||
if visibility == "" || visibility == "public" {
|
||||
return "public", nil
|
||||
}
|
||||
if visibility != "internal" && visibility != "private" {
|
||||
return "", errors.New("visibility must be public, internal, or private")
|
||||
}
|
||||
return visibility, nil
|
||||
}
|
||||
|
||||
func isAllowedSeedTenantOrgUnitType(value string) bool {
|
||||
switch strings.TrimSpace(value) {
|
||||
case "실", "팀", "TF", "TF팀", "센터", "디비전", "셀", "본부", "지역본부", "부", "임원직속":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func seedTenantCSVRowIsEmpty(row []string) bool {
|
||||
for _, value := range row {
|
||||
if strings.TrimSpace(value) != "" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func normalizeSeedTenantType(value string) string {
|
||||
switch strings.ToUpper(strings.TrimSpace(value)) {
|
||||
case domain.TenantTypePersonal:
|
||||
return domain.TenantTypePersonal
|
||||
case domain.TenantTypeCompany:
|
||||
return domain.TenantTypeCompany
|
||||
case domain.TenantTypeCompanyGroup:
|
||||
return domain.TenantTypeCompanyGroup
|
||||
case domain.TenantTypeOrganization:
|
||||
return domain.TenantTypeOrganization
|
||||
case domain.TenantTypeUserGroup:
|
||||
return domain.TenantTypeUserGroup
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func splitSeedTenantCSVDomains(value string) []string {
|
||||
value = strings.ReplaceAll(value, "\n", ";")
|
||||
value = strings.ReplaceAll(value, ",", ";")
|
||||
parts := strings.Split(value, ";")
|
||||
domains := make([]string, 0, len(parts))
|
||||
seen := make(map[string]bool, len(parts))
|
||||
for _, part := range parts {
|
||||
domainName := strings.ToLower(strings.TrimSpace(part))
|
||||
if domainName == "" || seen[domainName] {
|
||||
continue
|
||||
}
|
||||
seen[domainName] = true
|
||||
domains = append(domains, domainName)
|
||||
}
|
||||
return domains
|
||||
}
|
||||
|
||||
func orderSeedTenantConfigsByParentSlug(configs []InitialTenantConfig) []InitialTenantConfig {
|
||||
bySlug := make(map[string]InitialTenantConfig, len(configs))
|
||||
for _, config := range configs {
|
||||
bySlug[strings.ToLower(config.Slug)] = config
|
||||
}
|
||||
|
||||
ordered := make([]InitialTenantConfig, 0, len(configs))
|
||||
visited := make(map[string]bool, len(configs))
|
||||
var visit func(config InitialTenantConfig)
|
||||
visit = func(config InitialTenantConfig) {
|
||||
key := strings.ToLower(config.Slug)
|
||||
if visited[key] {
|
||||
return
|
||||
}
|
||||
if config.ParentSlug != "" {
|
||||
if parent, ok := bySlug[strings.ToLower(config.ParentSlug)]; ok {
|
||||
visit(parent)
|
||||
}
|
||||
}
|
||||
visited[key] = true
|
||||
ordered = append(ordered, config)
|
||||
}
|
||||
|
||||
for _, config := range configs {
|
||||
visit(config)
|
||||
}
|
||||
return ordered
|
||||
}
|
||||
|
||||
func createSeedTenant(
|
||||
ctx context.Context,
|
||||
repo repository.TenantRepository,
|
||||
outboxRepo repository.KetoOutboxRepository,
|
||||
config InitialTenantConfig,
|
||||
tenantType string,
|
||||
parentID *string,
|
||||
) (*domain.Tenant, error) {
|
||||
tenant := &domain.Tenant{
|
||||
ID: config.TenantID,
|
||||
Type: tenantType,
|
||||
Name: config.Name,
|
||||
Slug: config.Slug,
|
||||
Description: config.Description,
|
||||
Status: domain.TenantStatusActive,
|
||||
ParentID: parentID,
|
||||
Config: config.Config,
|
||||
}
|
||||
|
||||
if err := repo.Create(ctx, tenant); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := outboxRepo.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "Tenant",
|
||||
Object: tenant.ID,
|
||||
Relation: "admins",
|
||||
Subject: "System:global#super_admins",
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tenant.ParentID != nil {
|
||||
if err := outboxRepo.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "Tenant",
|
||||
Object: tenant.ID,
|
||||
Relation: "parents",
|
||||
Subject: "Tenant:" + *tenant.ParentID,
|
||||
Action: domain.KetoOutboxActionCreate,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
for _, domainName := range config.Domains {
|
||||
if err := repo.AddDomain(ctx, tenant.ID, domainName, true); err != nil {
|
||||
slog.Error("Failed to add domain to seeded tenant", "tenant", config.Slug, "domain", domainName, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return repo.FindBySlug(ctx, config.Slug)
|
||||
}
|
||||
388
baron-sso/backend/internal/bootstrap/tenant_seed_test.go
Normal file
388
baron-sso/backend/internal/bootstrap/tenant_seed_test.go
Normal file
@@ -0,0 +1,388 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/testsupport"
|
||||
"context"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
postgres_module "github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
gorm_postgres "gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestSeedTenantCSVDefinesWorksmobileDomainClassTenants(t *testing.T) {
|
||||
configs, err := loadSeedTenantConfigs()
|
||||
if err != nil {
|
||||
t.Fatalf("loadSeedTenantConfigs returned error: %v", err)
|
||||
}
|
||||
|
||||
expected := []struct {
|
||||
name string
|
||||
slug string
|
||||
tenantType string
|
||||
parentSlug string
|
||||
domains []string
|
||||
}{
|
||||
{
|
||||
name: "한맥가족",
|
||||
slug: "hanmac-family",
|
||||
tenantType: domain.TenantTypeCompanyGroup,
|
||||
},
|
||||
{
|
||||
name: "삼안",
|
||||
slug: "saman",
|
||||
tenantType: domain.TenantTypeCompany,
|
||||
parentSlug: "hanmac-family",
|
||||
domains: []string{"samaneng.com"},
|
||||
},
|
||||
{
|
||||
name: "한맥기술",
|
||||
slug: "hanmac",
|
||||
tenantType: domain.TenantTypeCompany,
|
||||
parentSlug: "hanmac-family",
|
||||
domains: []string{"hanmaceng.co.kr"},
|
||||
},
|
||||
{
|
||||
name: "총괄기획&기술개발센터",
|
||||
slug: "gpdtdc",
|
||||
tenantType: domain.TenantTypeCompany,
|
||||
parentSlug: "hanmac-family",
|
||||
domains: []string{"baroncs.co.kr"},
|
||||
},
|
||||
{
|
||||
name: "바론그룹",
|
||||
slug: "baron-group",
|
||||
tenantType: domain.TenantTypeCompanyGroup,
|
||||
parentSlug: "hanmac-family",
|
||||
domains: []string{"brsw.kr"},
|
||||
},
|
||||
{
|
||||
name: "(주)장헌",
|
||||
slug: "jangheon",
|
||||
tenantType: domain.TenantTypeCompany,
|
||||
parentSlug: "baron-group",
|
||||
domains: []string{"jangheon.com"},
|
||||
},
|
||||
{
|
||||
name: "장헌산업",
|
||||
slug: "jangheon-sanup",
|
||||
tenantType: domain.TenantTypeCompany,
|
||||
parentSlug: "baron-group",
|
||||
domains: []string{"jangheon.co.kr"},
|
||||
},
|
||||
{
|
||||
name: "한라산업개발",
|
||||
slug: "halla",
|
||||
tenantType: domain.TenantTypeCompany,
|
||||
parentSlug: "hanmac-family",
|
||||
domains: []string{"hallasanup.com"},
|
||||
},
|
||||
{
|
||||
name: "(주)피티씨",
|
||||
slug: "ptc",
|
||||
tenantType: domain.TenantTypeCompany,
|
||||
parentSlug: "baron-group",
|
||||
domains: []string{"pre-cast.co.kr"},
|
||||
},
|
||||
{
|
||||
name: "Personal",
|
||||
slug: "personal",
|
||||
tenantType: domain.TenantTypePersonal,
|
||||
},
|
||||
}
|
||||
|
||||
if len(configs) < len(expected) {
|
||||
t.Fatalf("expected at least %d seed tenants, got %d", len(expected), len(configs))
|
||||
}
|
||||
|
||||
wantFamilyChildOrder := []string{
|
||||
"gpdtdc",
|
||||
"saman",
|
||||
"hanmac",
|
||||
"baron-group",
|
||||
"halla",
|
||||
}
|
||||
policyFamilyChildSlugs := map[string]bool{}
|
||||
for _, slug := range wantFamilyChildOrder {
|
||||
policyFamilyChildSlugs[slug] = true
|
||||
}
|
||||
gotFamilyChildOrder := make([]string, 0, len(wantFamilyChildOrder))
|
||||
for _, config := range configs {
|
||||
if config.ParentSlug == "hanmac-family" && policyFamilyChildSlugs[config.Slug] {
|
||||
gotFamilyChildOrder = append(gotFamilyChildOrder, config.Slug)
|
||||
}
|
||||
}
|
||||
if len(gotFamilyChildOrder) != len(wantFamilyChildOrder) {
|
||||
t.Fatalf("hanmac-family child order = %#v, want %#v", gotFamilyChildOrder, wantFamilyChildOrder)
|
||||
}
|
||||
for i, wantSlug := range wantFamilyChildOrder {
|
||||
if gotFamilyChildOrder[i] != wantSlug {
|
||||
t.Fatalf("hanmac-family child order[%d] = %q, want %q", i, gotFamilyChildOrder[i], wantSlug)
|
||||
}
|
||||
}
|
||||
|
||||
configBySlug := make(map[string]InitialTenantConfig, len(configs))
|
||||
for _, config := range configs {
|
||||
configBySlug[config.Slug] = config
|
||||
}
|
||||
|
||||
for _, want := range expected {
|
||||
got, ok := configBySlug[want.slug]
|
||||
if !ok {
|
||||
t.Fatalf("tenant slug %q not found in seed configs", want.slug)
|
||||
}
|
||||
if got.Name != want.name {
|
||||
t.Fatalf("tenant[%s] name = %q, want %q", want.slug, got.Name, want.name)
|
||||
}
|
||||
if got.Slug != want.slug {
|
||||
t.Fatalf("tenant[%s] slug = %q, want %q", want.slug, got.Slug, want.slug)
|
||||
}
|
||||
if got.Type != want.tenantType {
|
||||
t.Fatalf("tenant[%s] type = %q, want %q", want.slug, got.Type, want.tenantType)
|
||||
}
|
||||
if got.ParentSlug != want.parentSlug {
|
||||
t.Fatalf("tenant[%s] parent slug = %q, want %q", want.slug, got.ParentSlug, want.parentSlug)
|
||||
}
|
||||
if len(got.Domains) != len(want.domains) {
|
||||
t.Fatalf("tenant[%s] domains = %#v, want %#v", want.slug, got.Domains, want.domains)
|
||||
}
|
||||
for j, wantDomain := range want.domains {
|
||||
if got.Domains[j] != wantDomain {
|
||||
t.Fatalf("tenant[%s] domain[%d] = %q, want %q", want.slug, j, got.Domains[j], wantDomain)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeSeedTenantTypeAllowsOrganization(t *testing.T) {
|
||||
if got := normalizeSeedTenantType("organization"); got != domain.TenantTypeOrganization {
|
||||
t.Fatalf("normalizeSeedTenantType(organization) = %q, want %q", got, domain.TenantTypeOrganization)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSeedTenantConfigsUsesConfiguredCSVPath(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "seed-tenant.csv")
|
||||
csv := "name,type,parent_tenant_slug,slug,memo,email_domain,visibility,org_unit_type,worksmobile_sync\n" +
|
||||
"Root,COMPANY_GROUP,,root,Root memo,,,,\n" +
|
||||
"Child,USER_GROUP,root,child,Child memo,child.example.com,private,팀,no\n"
|
||||
if err := os.WriteFile(path, []byte(csv), 0o600); err != nil {
|
||||
t.Fatalf("failed to write seed csv: %v", err)
|
||||
}
|
||||
t.Setenv(seedTenantCSVPathEnv, path)
|
||||
|
||||
configs, err := loadSeedTenantConfigs()
|
||||
if err != nil {
|
||||
t.Fatalf("loadSeedTenantConfigs returned error: %v", err)
|
||||
}
|
||||
if len(configs) != 2 {
|
||||
t.Fatalf("expected 2 configs, got %d", len(configs))
|
||||
}
|
||||
if configs[1].ParentSlug != "root" {
|
||||
t.Fatalf("child parent slug = %q, want root", configs[1].ParentSlug)
|
||||
}
|
||||
if len(configs[1].Domains) != 1 || configs[1].Domains[0] != "child.example.com" {
|
||||
t.Fatalf("child domains = %#v, want child.example.com", configs[1].Domains)
|
||||
}
|
||||
if configs[1].Config["visibility"] != "private" {
|
||||
t.Fatalf("child visibility = %#v, want private", configs[1].Config["visibility"])
|
||||
}
|
||||
if configs[1].Config["orgUnitType"] != "팀" {
|
||||
t.Fatalf("child orgUnitType = %#v, want 팀", configs[1].Config["orgUnitType"])
|
||||
}
|
||||
if configs[1].Config["worksmobileExcluded"] != true {
|
||||
t.Fatalf("child worksmobileExcluded = %#v, want true", configs[1].Config["worksmobileExcluded"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeedTenantCSVDefinesMHDAsPrivateUserGroup(t *testing.T) {
|
||||
configs, err := loadSeedTenantConfigs()
|
||||
if err != nil {
|
||||
t.Fatalf("loadSeedTenantConfigs returned error: %v", err)
|
||||
}
|
||||
|
||||
configBySlug := make(map[string]InitialTenantConfig, len(configs))
|
||||
for _, config := range configs {
|
||||
configBySlug[config.Slug] = config
|
||||
}
|
||||
|
||||
mhd, ok := configBySlug["mhd"]
|
||||
if !ok {
|
||||
t.Fatal("mhd seed tenant not found")
|
||||
}
|
||||
if mhd.Type != domain.TenantTypeUserGroup {
|
||||
t.Fatalf("mhd type = %q, want %q", mhd.Type, domain.TenantTypeUserGroup)
|
||||
}
|
||||
if mhd.Config["visibility"] != "private" {
|
||||
t.Fatalf("mhd visibility = %#v, want private", mhd.Config["visibility"])
|
||||
}
|
||||
if mhd.Config["worksmobileExcluded"] != true {
|
||||
t.Fatalf("mhd worksmobileExcluded = %#v, want true", mhd.Config["worksmobileExcluded"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSeedTenantSlugUsesConfiguredCSVPath(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "seed-tenant.csv")
|
||||
csv := "name,type,parent_tenant_slug,slug,memo,email_domain\n" +
|
||||
"Root,COMPANY_GROUP,,protected-root,Root memo,\n"
|
||||
if err := os.WriteFile(path, []byte(csv), 0o600); err != nil {
|
||||
t.Fatalf("failed to write seed csv: %v", err)
|
||||
}
|
||||
t.Setenv(seedTenantCSVPathEnv, path)
|
||||
|
||||
if !IsSeedTenantSlug("protected-root") {
|
||||
t.Fatal("protected-root must be detected as seed tenant")
|
||||
}
|
||||
if IsSeedTenantSlug("normal-tenant") {
|
||||
t.Fatal("normal-tenant must not be detected as seed tenant")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterMissingSeedTenantConfigsSkipsExistingSlugs(t *testing.T) {
|
||||
configs := []InitialTenantConfig{
|
||||
{TenantID: "existing-root-id", Name: "Existing Root", Slug: "existing-root"},
|
||||
{Name: "Missing Child", Slug: "missing-child", ParentSlug: "existing-root"},
|
||||
{TenantID: "existing-child-id", Name: "Existing Child", Slug: "existing-child", ParentSlug: "existing-root"},
|
||||
{TenantID: "existing-other-id", Name: "Conflicting ID", Slug: "new-slug"},
|
||||
}
|
||||
existingSlugs := map[string]bool{
|
||||
"existing-root": true,
|
||||
"existing-child": true,
|
||||
}
|
||||
existingIDs := map[string]bool{
|
||||
"existing-root-id": true,
|
||||
"existing-child-id": true,
|
||||
"existing-other-id": true,
|
||||
}
|
||||
|
||||
filtered := filterMissingSeedTenantConfigs(configs, existingSlugs, existingIDs)
|
||||
|
||||
if len(filtered) != 1 {
|
||||
t.Fatalf("filtered count = %d, want 1: %#v", len(filtered), filtered)
|
||||
}
|
||||
if filtered[0].Slug != "missing-child" {
|
||||
t.Fatalf("filtered slug = %q, want missing-child", filtered[0].Slug)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeedTenantsCreatesMissingSeedRowsWithoutTouchingExistingSlugs(t *testing.T) {
|
||||
if !testsupport.DockerAvailable() {
|
||||
t.Skip("Docker provider is unavailable in this environment")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
postgresContainer, err := postgres_module.Run(ctx,
|
||||
"postgres:16-alpine",
|
||||
postgres_module.WithDatabase("testdb"),
|
||||
postgres_module.WithUsername("user"),
|
||||
postgres_module.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(
|
||||
wait.ForLog("database system is ready to accept connections").
|
||||
WithOccurrence(2).
|
||||
WithStartupTimeout(30*time.Second),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start postgres container: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := postgresContainer.Terminate(ctx); err != nil {
|
||||
log.Printf("failed to terminate postgres container: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
connStr, err := postgresContainer.ConnectionString(ctx, "sslmode=disable")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get postgres connection string: %v", err)
|
||||
}
|
||||
db, err := gorm.Open(gorm_postgres.Open(connStr), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open postgres connection: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(&domain.Tenant{}, &domain.TenantDomain{}, &domain.KetoOutbox{}); err != nil {
|
||||
t.Fatalf("failed to migrate seed test tables: %v", err)
|
||||
}
|
||||
|
||||
existingRoot := domain.Tenant{
|
||||
ID: "00000000-0000-0000-0000-000000000001",
|
||||
Name: "Existing Root Name",
|
||||
Slug: "existing-root",
|
||||
Type: domain.TenantTypeCompanyGroup,
|
||||
Description: "manual tenant must not be overwritten",
|
||||
Status: domain.TenantStatusActive,
|
||||
}
|
||||
nonSeedTenant := domain.Tenant{
|
||||
ID: "00000000-0000-0000-0000-000000000002",
|
||||
Name: "Manual Tenant",
|
||||
Slug: "manual-tenant",
|
||||
Type: domain.TenantTypeCompany,
|
||||
Status: domain.TenantStatusActive,
|
||||
}
|
||||
if err := db.Create(&existingRoot).Error; err != nil {
|
||||
t.Fatalf("failed to create existing root tenant: %v", err)
|
||||
}
|
||||
if err := db.Create(&nonSeedTenant).Error; err != nil {
|
||||
t.Fatalf("failed to create non-seed tenant: %v", err)
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "seed-tenant.csv")
|
||||
csv := "id,name,type,parent_tenant_slug,slug,memo,email_domain\n" +
|
||||
"10000000-0000-0000-0000-000000000001,Seed Root Name,COMPANY_GROUP,,existing-root,seed must be skipped,\n" +
|
||||
"00000000-0000-0000-0000-000000000002,Conflicting ID,COMPANY,existing-root,conflicting-id,seed id must be skipped,\n" +
|
||||
"10000000-0000-0000-0000-000000000002,Missing Child,COMPANY,existing-root,missing-child,created from seed,child.example.com\n"
|
||||
if err := os.WriteFile(path, []byte(csv), 0o600); err != nil {
|
||||
t.Fatalf("failed to write seed csv: %v", err)
|
||||
}
|
||||
t.Setenv(seedTenantCSVPathEnv, path)
|
||||
|
||||
if err := SeedTenants(db); err != nil {
|
||||
t.Fatalf("SeedTenants returned error: %v", err)
|
||||
}
|
||||
|
||||
var root domain.Tenant
|
||||
if err := db.First(&root, "slug = ?", "existing-root").Error; err != nil {
|
||||
t.Fatalf("failed to load existing root after seed: %v", err)
|
||||
}
|
||||
if root.ID != existingRoot.ID {
|
||||
t.Fatalf("existing root ID = %q, want %q", root.ID, existingRoot.ID)
|
||||
}
|
||||
if root.Name != existingRoot.Name {
|
||||
t.Fatalf("existing root name = %q, want untouched %q", root.Name, existingRoot.Name)
|
||||
}
|
||||
|
||||
var child domain.Tenant
|
||||
if err := db.Preload("Domains").First(&child, "slug = ?", "missing-child").Error; err != nil {
|
||||
t.Fatalf("missing seed child was not created: %v", err)
|
||||
}
|
||||
if child.ParentID == nil || *child.ParentID != existingRoot.ID {
|
||||
t.Fatalf("child parent ID = %v, want %q", child.ParentID, existingRoot.ID)
|
||||
}
|
||||
if len(child.Domains) != 1 || child.Domains[0].Domain != "child.example.com" {
|
||||
t.Fatalf("child domains = %#v, want child.example.com", child.Domains)
|
||||
}
|
||||
|
||||
var rootCount int64
|
||||
if err := db.Model(&domain.Tenant{}).Where("slug = ?", "existing-root").Count(&rootCount).Error; err != nil {
|
||||
t.Fatalf("failed to count existing root rows: %v", err)
|
||||
}
|
||||
if rootCount != 1 {
|
||||
t.Fatalf("existing-root row count = %d, want 1", rootCount)
|
||||
}
|
||||
|
||||
var conflictingIDCount int64
|
||||
if err := db.Model(&domain.Tenant{}).Where("slug = ?", "conflicting-id").Count(&conflictingIDCount).Error; err != nil {
|
||||
t.Fatalf("failed to count conflicting-id rows: %v", err)
|
||||
}
|
||||
if conflictingIDCount != 0 {
|
||||
t.Fatalf("conflicting-id row count = %d, want 0", conflictingIDCount)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const sanitizeLegacyUserMetadataSQL = `
|
||||
update users
|
||||
set metadata = metadata - 'hanmacFamily' - 'userType',
|
||||
updated_at = now()
|
||||
where metadata ? 'hanmacFamily'
|
||||
or metadata ? 'userType'
|
||||
`
|
||||
|
||||
// SanitizeLegacyUserMetadata removes legacy UI classification flags from Baron user metadata.
|
||||
func SanitizeLegacyUserMetadata(db *gorm.DB) error {
|
||||
if db == nil {
|
||||
return fmt.Errorf("database is not configured")
|
||||
}
|
||||
if !db.Migrator().HasTable("users") {
|
||||
slog.Info("[Bootstrap] Legacy user metadata sanitize skipped because users table does not exist")
|
||||
return nil
|
||||
}
|
||||
|
||||
result := db.Exec(sanitizeLegacyUserMetadataSQL)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("sanitize legacy user metadata: %w", result.Error)
|
||||
}
|
||||
slog.Info("[Bootstrap] Legacy user metadata sanitized", "rowsAffected", result.RowsAffected)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,203 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/testsupport"
|
||||
"context"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
postgres_module "github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
gorm_postgres "gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestSanitizeLegacyUserMetadataRemovesClassificationFlags(t *testing.T) {
|
||||
db := openBootstrapPostgresTestDB(t)
|
||||
if err := db.AutoMigrate(&domain.User{}); err != nil {
|
||||
t.Fatalf("failed to migrate users table: %v", err)
|
||||
}
|
||||
|
||||
user := domain.User{
|
||||
ID: "10000000-0000-0000-0000-000000000001",
|
||||
Email: "legacy@example.com",
|
||||
Name: "Legacy User",
|
||||
Role: domain.RoleUser,
|
||||
Status: domain.UserStatusActive,
|
||||
Metadata: domain.JSONMap{
|
||||
"hanmacFamily": true,
|
||||
"userType": "hanmac",
|
||||
"employeeId": "E001",
|
||||
"nested": map[string]any{
|
||||
"userType": "must stay nested",
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := db.Create(&user).Error; err != nil {
|
||||
t.Fatalf("failed to create user: %v", err)
|
||||
}
|
||||
|
||||
if err := SanitizeLegacyUserMetadata(db); err != nil {
|
||||
t.Fatalf("SanitizeLegacyUserMetadata returned error: %v", err)
|
||||
}
|
||||
if err := SanitizeLegacyUserMetadata(db); err != nil {
|
||||
t.Fatalf("SanitizeLegacyUserMetadata must be idempotent: %v", err)
|
||||
}
|
||||
|
||||
var got domain.User
|
||||
if err := db.First(&got, "id = ?", user.ID).Error; err != nil {
|
||||
t.Fatalf("failed to load sanitized user: %v", err)
|
||||
}
|
||||
if _, ok := got.Metadata["hanmacFamily"]; ok {
|
||||
t.Fatalf("hanmacFamily must be removed from metadata: %#v", got.Metadata)
|
||||
}
|
||||
if _, ok := got.Metadata["userType"]; ok {
|
||||
t.Fatalf("userType must be removed from metadata: %#v", got.Metadata)
|
||||
}
|
||||
if got.Metadata["employeeId"] != "E001" {
|
||||
t.Fatalf("employeeId = %#v, want E001", got.Metadata["employeeId"])
|
||||
}
|
||||
nested, ok := got.Metadata["nested"].(map[string]any)
|
||||
if !ok || nested["userType"] != "must stay nested" {
|
||||
t.Fatalf("nested metadata must be preserved: %#v", got.Metadata["nested"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanonicalizeLegacyUserStatuses(t *testing.T) {
|
||||
db := openBootstrapPostgresTestDB(t)
|
||||
if err := db.AutoMigrate(&domain.User{}); err != nil {
|
||||
t.Fatalf("failed to migrate users table: %v", err)
|
||||
}
|
||||
|
||||
users := []domain.User{
|
||||
{ID: "11000000-0000-0000-0000-000000000001", Email: "inactive@example.com", Name: "Inactive", Role: domain.RoleUser, Status: "inactive"},
|
||||
{ID: "11000000-0000-0000-0000-000000000002", Email: "leave@example.com", Name: "Leave", Role: domain.RoleUser, Status: "leave_of_absence"},
|
||||
{ID: "11000000-0000-0000-0000-000000000003", Email: "baron-only@example.com", Name: "Baron Only", Role: domain.RoleUser, Status: "baron_only"},
|
||||
{ID: "11000000-0000-0000-0000-000000000004", Email: "active@example.com", Name: "Active", Role: domain.RoleUser, Status: domain.UserStatusActive},
|
||||
}
|
||||
if err := db.Create(&users).Error; err != nil {
|
||||
t.Fatalf("failed to create users: %v", err)
|
||||
}
|
||||
|
||||
if err := CanonicalizeLegacyUserStatuses(db); err != nil {
|
||||
t.Fatalf("CanonicalizeLegacyUserStatuses returned error: %v", err)
|
||||
}
|
||||
if err := CanonicalizeLegacyUserStatuses(db); err != nil {
|
||||
t.Fatalf("CanonicalizeLegacyUserStatuses must be idempotent: %v", err)
|
||||
}
|
||||
|
||||
got := map[string]string{}
|
||||
var loaded []domain.User
|
||||
if err := db.Find(&loaded).Error; err != nil {
|
||||
t.Fatalf("failed to load users: %v", err)
|
||||
}
|
||||
for _, user := range loaded {
|
||||
got[user.Email] = user.Status
|
||||
}
|
||||
|
||||
if got["inactive@example.com"] != domain.UserStatusPreboarding {
|
||||
t.Fatalf("inactive status = %q, want %q", got["inactive@example.com"], domain.UserStatusPreboarding)
|
||||
}
|
||||
if got["leave@example.com"] != domain.UserStatusTemporaryLeave {
|
||||
t.Fatalf("leave status = %q, want %q", got["leave@example.com"], domain.UserStatusTemporaryLeave)
|
||||
}
|
||||
if got["baron-only@example.com"] != domain.UserStatusBaronGuest {
|
||||
t.Fatalf("baron_only status = %q, want %q", got["baron-only@example.com"], domain.UserStatusBaronGuest)
|
||||
}
|
||||
if got["active@example.com"] != domain.UserStatusActive {
|
||||
t.Fatalf("active status = %q, want %q", got["active@example.com"], domain.UserStatusActive)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunSanitizesLegacyUserMetadata(t *testing.T) {
|
||||
db := openBootstrapPostgresTestDB(t)
|
||||
if err := db.AutoMigrate(&domain.User{}); err != nil {
|
||||
t.Fatalf("failed to migrate users table: %v", err)
|
||||
}
|
||||
|
||||
user := domain.User{
|
||||
ID: "20000000-0000-0000-0000-000000000001",
|
||||
Email: "run-legacy@example.com",
|
||||
Name: "Run Legacy User",
|
||||
Role: domain.RoleUser,
|
||||
Status: domain.UserStatusActive,
|
||||
Metadata: domain.JSONMap{
|
||||
"hanmacFamily": true,
|
||||
"userType": "external",
|
||||
"employeeId": "E002",
|
||||
},
|
||||
}
|
||||
if err := db.Create(&user).Error; err != nil {
|
||||
t.Fatalf("failed to create user: %v", err)
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "seed-tenant.csv")
|
||||
csv := "id,name,type,parent_tenant_slug,slug,memo,email_domain\n" +
|
||||
"30000000-0000-0000-0000-000000000001,Seed Root,COMPANY_GROUP,,seed-root,seed root,\n"
|
||||
if err := os.WriteFile(path, []byte(csv), 0o600); err != nil {
|
||||
t.Fatalf("failed to write seed csv: %v", err)
|
||||
}
|
||||
t.Setenv(seedTenantCSVPathEnv, path)
|
||||
|
||||
if err := Run(db); err != nil {
|
||||
t.Fatalf("Run returned error: %v", err)
|
||||
}
|
||||
|
||||
var got domain.User
|
||||
if err := db.First(&got, "id = ?", user.ID).Error; err != nil {
|
||||
t.Fatalf("failed to load sanitized user: %v", err)
|
||||
}
|
||||
if _, ok := got.Metadata["hanmacFamily"]; ok {
|
||||
t.Fatalf("Run must remove hanmacFamily from metadata: %#v", got.Metadata)
|
||||
}
|
||||
if _, ok := got.Metadata["userType"]; ok {
|
||||
t.Fatalf("Run must remove userType from metadata: %#v", got.Metadata)
|
||||
}
|
||||
if got.Metadata["employeeId"] != "E002" {
|
||||
t.Fatalf("employeeId = %#v, want E002", got.Metadata["employeeId"])
|
||||
}
|
||||
}
|
||||
|
||||
func openBootstrapPostgresTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
if !testsupport.DockerAvailable() {
|
||||
t.Skip("Docker provider is unavailable in this environment")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
postgresContainer, err := postgres_module.Run(ctx,
|
||||
"postgres:16-alpine",
|
||||
postgres_module.WithDatabase("testdb"),
|
||||
postgres_module.WithUsername("user"),
|
||||
postgres_module.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(
|
||||
wait.ForLog("database system is ready to accept connections").
|
||||
WithOccurrence(2).
|
||||
WithStartupTimeout(30*time.Second),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start postgres container: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := postgresContainer.Terminate(ctx); err != nil {
|
||||
log.Printf("failed to terminate postgres container: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
connStr, err := postgresContainer.ConnectionString(ctx, "sslmode=disable")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get postgres connection string: %v", err)
|
||||
}
|
||||
db, err := gorm.Open(gorm_postgres.Open(connStr), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open postgres connection: %v", err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
30
baron-sso/backend/internal/domain/api_key.go
Normal file
30
baron-sso/backend/internal/domain/api_key.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ApiKey represents an internal API key for Machine-to-Machine communication.
|
||||
type ApiKey struct {
|
||||
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
ClientID string `gorm:"uniqueIndex;not null" json:"clientId"`
|
||||
ClientSecretHash string `gorm:"not null" json:"-"`
|
||||
Scopes string `json:"scopes"` // Space or comma separated
|
||||
Status string `gorm:"default:'active'" json:"status"`
|
||||
LastUsedAt *time.Time `json:"lastUsedAt"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
||||
}
|
||||
|
||||
// BeforeCreate hook to generate UUID if not present.
|
||||
func (k *ApiKey) BeforeCreate(tx *gorm.DB) (err error) {
|
||||
if k.ID == "" {
|
||||
k.ID = uuid.NewString()
|
||||
}
|
||||
return
|
||||
}
|
||||
123
baron-sso/backend/internal/domain/auth_models.go
Normal file
123
baron-sso/backend/internal/domain/auth_models.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package domain
|
||||
|
||||
type EnchantedLinkInitRequest struct {
|
||||
LoginID string `json:"loginId"`
|
||||
URI string `json:"uri,omitempty"` // Redirect URI (optional for polling flow)
|
||||
Method string `json:"method,omitempty"` // "email" or "sms"
|
||||
CodeOnly bool `json:"codeOnly,omitempty"`
|
||||
DryRun bool `json:"dryRun,omitempty"`
|
||||
DrySend bool `json:"drySend,omitempty"`
|
||||
}
|
||||
|
||||
type EnchantedLinkInitResponse struct {
|
||||
LinkID string `json:"linkId"`
|
||||
PendingRef string `json:"pendingRef"`
|
||||
MaskedEmail string `json:"maskedEmail"`
|
||||
}
|
||||
|
||||
type EnchantedLinkPollRequest struct {
|
||||
PendingRef string `json:"pendingRef"`
|
||||
}
|
||||
|
||||
type EnchantedLinkPollResponse struct {
|
||||
SessionToken string `json:"sessionToken"` // JWT
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
UserID string `json:"userId,omitempty"`
|
||||
}
|
||||
|
||||
type MagicLinkVerifyRequest struct {
|
||||
Token string `json:"token"`
|
||||
VerifyOnly bool `json:"verifyOnly,omitempty"`
|
||||
}
|
||||
|
||||
type QRInitResponse struct {
|
||||
QRCode string `json:"qrCode"` // Base64 or URL
|
||||
PendingRef string `json:"pendingRef"`
|
||||
ExpiresIn int `json:"expiresIn"`
|
||||
}
|
||||
|
||||
// Signup Flow Models
|
||||
|
||||
type CheckEmailRequest struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
type SendSignupCodeRequest struct {
|
||||
Target string `json:"target"` // Email or Phone
|
||||
Type string `json:"type"` // "email" or "phone"
|
||||
}
|
||||
|
||||
type VerifySignupCodeRequest struct {
|
||||
Target string `json:"target"` // Email or Phone
|
||||
Type string `json:"type"` // "email" or "phone"
|
||||
Code string `json:"code"`
|
||||
}
|
||||
|
||||
type SignupRequest struct {
|
||||
Email string `json:"email"`
|
||||
LoginID string `json:"loginId,omitempty"`
|
||||
Password string `json:"password"`
|
||||
Name string `json:"name"`
|
||||
Phone string `json:"phone"`
|
||||
AffiliationType string `json:"affiliationType"` // "AFFILIATE" or "GENERAL"
|
||||
TenantSlug string `json:"tenantSlug,omitempty"`
|
||||
CompanyCode string `json:"companyCode,omitempty"`
|
||||
Department string `json:"department"`
|
||||
Metadata JSONMap `json:"metadata,omitempty"`
|
||||
TermsAccepted bool `json:"termsAccepted"`
|
||||
}
|
||||
|
||||
// User Profile Models
|
||||
|
||||
type UserProfileResponse struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
LoginID string `json:"loginId,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Phone string `json:"phone"`
|
||||
Role string `json:"role"` // 추가
|
||||
SessionAuthenticatedAt string `json:"sessionAuthenticatedAt,omitempty"`
|
||||
Department string `json:"department"`
|
||||
AffiliationType string `json:"affiliationType"`
|
||||
CompanyCode string `json:"companyCode,omitempty"`
|
||||
TenantID *string `json:"tenantId,omitempty"` // 추가
|
||||
SessionTenantID *string `json:"sessionTenantId,omitempty"` // [New] 로그인에 사용된 식별자 기반 테넌트
|
||||
RelyingPartyID *string `json:"relyingPartyId,omitempty"` // 추가
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
Tenant *Tenant `json:"tenant,omitempty"`
|
||||
ManageableTenants []Tenant `json:"manageableTenants,omitempty"` // 추가: 관리 가능한 테넌트 목록
|
||||
JoinedTenants []Tenant `json:"joinedTenants,omitempty"` // [New] 다중 소속 테넌트 목록
|
||||
}
|
||||
|
||||
type UpdateUserRequest struct {
|
||||
Name string `json:"name"`
|
||||
Phone string `json:"phone"`
|
||||
Department string `json:"department"`
|
||||
VerificationCode string `json:"verificationCode,omitempty"` // For phone change
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// PasswordResetInitiateRequest is the request body for initiating a password reset.
|
||||
type PasswordResetInitiateRequest struct {
|
||||
LoginID string `json:"loginId"`
|
||||
DryRun bool `json:"dryRun,omitempty"`
|
||||
DrySend bool `json:"drySend,omitempty"`
|
||||
}
|
||||
|
||||
// PasswordResetCompleteRequest is the request body for completing a password reset.
|
||||
type PasswordResetCompleteRequest struct {
|
||||
LoginID string `json:"loginId"`
|
||||
NewPassword string `json:"newPassword"`
|
||||
}
|
||||
|
||||
// PasswordChangeRequest는 로그인 상태에서 비밀번호 변경 요청을 표현합니다.
|
||||
type PasswordChangeRequest struct {
|
||||
CurrentPassword string `json:"currentPassword"`
|
||||
NewPassword string `json:"newPassword"`
|
||||
}
|
||||
|
||||
type CheckLoginIDRequest struct {
|
||||
LoginID string `json:"loginId"`
|
||||
TenantSlug string `json:"tenantSlug,omitempty"`
|
||||
CompanyCode string `json:"companyCode,omitempty"`
|
||||
}
|
||||
33
baron-sso/backend/internal/domain/client_consent.go
Normal file
33
baron-sso/backend/internal/domain/client_consent.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ClientConsent struct {
|
||||
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
|
||||
ClientID string `gorm:"index;uniqueIndex:idx_client_subject;not null" json:"clientId"`
|
||||
Subject string `gorm:"index;uniqueIndex:idx_client_subject;not null" json:"subject"` // User UUID
|
||||
GrantedScopes pq.StringArray `gorm:"type:text[];not null" json:"grantedScopes"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
||||
}
|
||||
|
||||
// ClientConsentWithTenantInfo is a struct to hold joined data for API responses
|
||||
type ClientConsentWithTenantInfo struct {
|
||||
ClientConsent
|
||||
TenantID string `gorm:"column:tenant_id" json:"tenantId"`
|
||||
TenantName string `gorm:"column:tenant_name" json:"tenantName"`
|
||||
}
|
||||
|
||||
func (c *ClientConsent) BeforeCreate(tx *gorm.DB) (err error) {
|
||||
if c.ID == "" {
|
||||
c.ID = uuid.New().String()
|
||||
}
|
||||
return
|
||||
}
|
||||
21
baron-sso/backend/internal/domain/client_secret.go
Normal file
21
baron-sso/backend/internal/domain/client_secret.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ClientSecret represents the stored client secret for OIDC clients.
|
||||
// Since Hydra only returns the secret once during creation, we store it here.
|
||||
type ClientSecret struct {
|
||||
ClientID string `gorm:"primaryKey;column:client_id"`
|
||||
ClientSecret string `gorm:"column:client_secret;not null"`
|
||||
CreatedAt time.Time `gorm:"column:created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at"`
|
||||
}
|
||||
|
||||
type ClientSecretRepository interface {
|
||||
Upsert(ctx context.Context, clientID, secret string) error
|
||||
GetByID(ctx context.Context, clientID string) (string, error)
|
||||
Delete(ctx context.Context, clientID string) error
|
||||
}
|
||||
60
baron-sso/backend/internal/domain/data_integrity.go
Normal file
60
baron-sso/backend/internal/domain/data_integrity.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package domain
|
||||
|
||||
import "time"
|
||||
|
||||
type DataIntegrityStatus string
|
||||
|
||||
const (
|
||||
DataIntegrityStatusPass DataIntegrityStatus = "pass"
|
||||
DataIntegrityStatusWarning DataIntegrityStatus = "warning"
|
||||
DataIntegrityStatusFail DataIntegrityStatus = "fail"
|
||||
)
|
||||
|
||||
type DataIntegrityReport struct {
|
||||
Status DataIntegrityStatus `json:"status"`
|
||||
CheckedAt time.Time `json:"checkedAt"`
|
||||
Summary DataIntegritySummary `json:"summary"`
|
||||
Sections []DataIntegritySection `json:"sections"`
|
||||
}
|
||||
|
||||
type DataIntegritySummary struct {
|
||||
TotalChecks int `json:"totalChecks"`
|
||||
Passed int `json:"passed"`
|
||||
Warnings int `json:"warnings"`
|
||||
Failures int64 `json:"failures"`
|
||||
}
|
||||
|
||||
type DataIntegritySection struct {
|
||||
Key string `json:"key"`
|
||||
Label string `json:"label"`
|
||||
Status DataIntegrityStatus `json:"status"`
|
||||
Checks []DataIntegrityCheck `json:"checks"`
|
||||
}
|
||||
|
||||
type DataIntegrityCheck struct {
|
||||
Key string `json:"key"`
|
||||
Label string `json:"label"`
|
||||
Description string `json:"description"`
|
||||
Status DataIntegrityStatus `json:"status"`
|
||||
Severity string `json:"severity"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
|
||||
type OrphanUserLoginID struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"userId"`
|
||||
UserEmail string `json:"userEmail,omitempty"`
|
||||
UserDeletedAt *time.Time `json:"userDeletedAt,omitempty"`
|
||||
TenantID string `json:"tenantId"`
|
||||
TenantSlug string `json:"tenantSlug,omitempty"`
|
||||
TenantDeletedAt *time.Time `json:"tenantDeletedAt,omitempty"`
|
||||
FieldKey string `json:"fieldKey"`
|
||||
LoginID string `json:"loginId"`
|
||||
Reasons []string `json:"reasons"`
|
||||
}
|
||||
|
||||
type DeleteOrphanUserLoginIDsResult struct {
|
||||
DeletedCount int64 `json:"deletedCount"`
|
||||
Deleted []OrphanUserLoginID `json:"deleted"`
|
||||
SkippedIDs []string `json:"skippedIds"`
|
||||
}
|
||||
29
baron-sso/backend/internal/domain/developer_request.go
Normal file
29
baron-sso/backend/internal/domain/developer_request.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
DeveloperRequestStatusPending = "pending"
|
||||
DeveloperRequestStatusApproved = "approved"
|
||||
DeveloperRequestStatusRejected = "rejected"
|
||||
DeveloperRequestStatusCancelled = "cancelled"
|
||||
)
|
||||
|
||||
// DeveloperRequest represents a user's application to become a developer.
|
||||
type DeveloperRequest struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
UserID string `gorm:"index;not null" json:"userId"` // Kratos User ID
|
||||
TenantID string `gorm:"index;not null" json:"tenantId"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
Organization string `json:"organization"`
|
||||
Email string `json:"email"`
|
||||
Phone string `json:"phone"`
|
||||
Role string `json:"role"`
|
||||
Reason string `json:"reason"`
|
||||
Status string `gorm:"default:'pending';not null" json:"status"` // pending, approved, rejected, cancelled
|
||||
AdminNotes string `json:"adminNotes"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
6
baron-sso/backend/internal/domain/email_models.go
Normal file
6
baron-sso/backend/internal/domain/email_models.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package domain
|
||||
|
||||
// EmailService defines the interface for sending emails.
|
||||
type EmailService interface {
|
||||
SendEmail(to, subject, body string) error
|
||||
}
|
||||
50
baron-sso/backend/internal/domain/federation_models.go
Normal file
50
baron-sso/backend/internal/domain/federation_models.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ProviderType defines the type of the identity provider.
|
||||
type ProviderType string
|
||||
|
||||
const (
|
||||
ProviderTypeOIDC ProviderType = "oidc"
|
||||
ProviderTypeSAML ProviderType = "saml"
|
||||
)
|
||||
|
||||
// IdentityProviderConfig stores the configuration for an external Identity Provider.
|
||||
type IdentityProviderConfig struct {
|
||||
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
|
||||
ClientID string `gorm:"type:uuid;not null;index" json:"client_id"` // Replaces TenantID
|
||||
ProviderType ProviderType `gorm:"type:varchar(10);not null" json:"provider_type"`
|
||||
DisplayName string `gorm:"not null" json:"display_name"`
|
||||
Status string `gorm:"default:'active'" json:"status"`
|
||||
|
||||
// OIDC Specific Fields
|
||||
IssuerURL *string `gorm:"null" json:"issuer_url,omitempty"`
|
||||
OIDCClientID *string `gorm:"null" json:"oidc_client_id,omitempty"` // Renamed from ClientID
|
||||
OIDCClientSecret *string `gorm:"null" json:"oidc_client_secret,omitempty"` // Renamed from ClientSecret
|
||||
// Scopes are space-separated
|
||||
Scopes *string `gorm:"null" json:"scopes,omitempty"`
|
||||
|
||||
// SAML Specific Fields
|
||||
MetadataURL *string `gorm:"null" json:"metadata_url,omitempty"`
|
||||
MetadataXML *string `gorm:"type:text;null" json:"metadata_xml,omitempty"`
|
||||
EntityID *string `gorm:"null" json:"entity_id,omitempty"`
|
||||
AcsURL *string `gorm:"null" json:"acs_url,omitempty"`
|
||||
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
||||
}
|
||||
|
||||
// BeforeCreate hook to generate UUID if not present.
|
||||
func (idc *IdentityProviderConfig) BeforeCreate(tx *gorm.DB) (err error) {
|
||||
if idc.ID == "" {
|
||||
idc.ID = uuid.NewString()
|
||||
}
|
||||
return
|
||||
}
|
||||
196
baron-sso/backend/internal/domain/hanmac_email.go
Normal file
196
baron-sso/backend/internal/domain/hanmac_email.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
var hanmacSurnameRomanization = map[rune]string{
|
||||
'한': "han",
|
||||
'김': "kim",
|
||||
'이': "lee",
|
||||
'박': "park",
|
||||
'최': "choi",
|
||||
'정': "jung",
|
||||
'조': "cho",
|
||||
'강': "kang",
|
||||
'윤': "yoon",
|
||||
'장': "jang",
|
||||
'임': "lim",
|
||||
'림': "lim",
|
||||
'신': "shin",
|
||||
'오': "oh",
|
||||
'서': "seo",
|
||||
'권': "kwon",
|
||||
'황': "hwang",
|
||||
'안': "ahn",
|
||||
'송': "song",
|
||||
'전': "jeon",
|
||||
'홍': "hong",
|
||||
'유': "yoo",
|
||||
'고': "ko",
|
||||
'문': "moon",
|
||||
'양': "yang",
|
||||
'손': "son",
|
||||
'배': "bae",
|
||||
'백': "baek",
|
||||
'허': "heo",
|
||||
'남': "nam",
|
||||
'심': "sim",
|
||||
'노': "noh",
|
||||
'하': "ha",
|
||||
'곽': "kwak",
|
||||
'성': "sung",
|
||||
'차': "cha",
|
||||
'주': "joo",
|
||||
'우': "woo",
|
||||
'구': "koo",
|
||||
'민': "min",
|
||||
'류': "ryu",
|
||||
'나': "na",
|
||||
'진': "jin",
|
||||
'지': "ji",
|
||||
'엄': "um",
|
||||
'채': "chae",
|
||||
'원': "won",
|
||||
'천': "cheon",
|
||||
'방': "bang",
|
||||
'공': "gong",
|
||||
'현': "hyun",
|
||||
'함': "ham",
|
||||
'여': "yeo",
|
||||
'추': "choo",
|
||||
'도': "do",
|
||||
'소': "so",
|
||||
'석': "seok",
|
||||
'선': "sun",
|
||||
'설': "seol",
|
||||
'마': "ma",
|
||||
'길': "gil",
|
||||
'연': "yeon",
|
||||
'위': "wi",
|
||||
'표': "pyo",
|
||||
'명': "myung",
|
||||
'기': "ki",
|
||||
'반': "ban",
|
||||
'라': "ra",
|
||||
'왕': "wang",
|
||||
'금': "geum",
|
||||
'옥': "ok",
|
||||
'육': "yook",
|
||||
'인': "in",
|
||||
'맹': "maeng",
|
||||
'제': "je",
|
||||
'모': "mo",
|
||||
'탁': "tak",
|
||||
'국': "guk",
|
||||
'어': "eo",
|
||||
'은': "eun",
|
||||
'편': "pyeon",
|
||||
'용': "yong",
|
||||
}
|
||||
|
||||
var hanmacInitialRomanization = []string{
|
||||
"g", "g", "n", "d", "d", "r", "m", "b", "b", "s",
|
||||
"s", "y", "j", "j", "c", "k", "t", "p", "h",
|
||||
}
|
||||
|
||||
func SplitEmailDomain(email string) (string, string, error) {
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
before, after, ok := strings.Cut(normalized, "@")
|
||||
if !ok {
|
||||
return "", "", errors.New("email must contain @")
|
||||
}
|
||||
if strings.Count(normalized, "@") != 1 {
|
||||
return "", "", errors.New("email must contain one @")
|
||||
}
|
||||
localPart := strings.TrimSpace(before)
|
||||
domainPart := strings.TrimSpace(after)
|
||||
if domainPart == "" || !strings.Contains(domainPart, ".") {
|
||||
return "", "", errors.New("email domain is invalid")
|
||||
}
|
||||
return localPart, domainPart, nil
|
||||
}
|
||||
|
||||
func ExtractNormalizedEmailLocalPart(email string) (string, error) {
|
||||
localPart, _, err := SplitEmailDomain(email)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if localPart == "" {
|
||||
return "", errors.New("email local-part is empty")
|
||||
}
|
||||
return localPart, nil
|
||||
}
|
||||
|
||||
func BuildKoreanNameEmailBase(name string) (string, bool, error) {
|
||||
runes := compactNameRunes(name)
|
||||
if len(runes) < 2 {
|
||||
return "", true, nil
|
||||
}
|
||||
|
||||
surname, ok := hanmacSurnameRomanization[runes[0]]
|
||||
if !ok {
|
||||
return "", true, nil
|
||||
}
|
||||
|
||||
var builder strings.Builder
|
||||
for _, r := range runes[1:] {
|
||||
initial, ok := romanizedHangulInitial(r)
|
||||
if !ok {
|
||||
return "", true, nil
|
||||
}
|
||||
builder.WriteString(initial)
|
||||
}
|
||||
builder.WriteString(surname)
|
||||
return builder.String(), false, nil
|
||||
}
|
||||
|
||||
func MatchesSuggestedNameRule(localPart string, base string) bool {
|
||||
localPart = strings.ToLower(strings.TrimSpace(localPart))
|
||||
base = strings.ToLower(strings.TrimSpace(base))
|
||||
if localPart == "" || base == "" {
|
||||
return false
|
||||
}
|
||||
if localPart == base {
|
||||
return true
|
||||
}
|
||||
if !strings.HasPrefix(localPart, base) {
|
||||
return false
|
||||
}
|
||||
suffix := localPart[len(base):]
|
||||
if suffix == "" {
|
||||
return false
|
||||
}
|
||||
for _, r := range suffix {
|
||||
if r < '0' || r > '9' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func compactNameRunes(name string) []rune {
|
||||
var runes []rune
|
||||
for _, r := range strings.TrimSpace(name) {
|
||||
if unicode.IsSpace(r) {
|
||||
continue
|
||||
}
|
||||
runes = append(runes, r)
|
||||
}
|
||||
return runes
|
||||
}
|
||||
|
||||
func romanizedHangulInitial(r rune) (string, bool) {
|
||||
const hangulBase = 0xAC00
|
||||
const hangulEnd = 0xD7A3
|
||||
if r < hangulBase || r > hangulEnd {
|
||||
return "", false
|
||||
}
|
||||
index := int(r-hangulBase) / 588
|
||||
if index < 0 || index >= len(hanmacInitialRomanization) {
|
||||
return "", false
|
||||
}
|
||||
return hanmacInitialRomanization[index], true
|
||||
}
|
||||
76
baron-sso/backend/internal/domain/hanmac_email_test.go
Normal file
76
baron-sso/backend/internal/domain/hanmac_email_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package domain
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSplitEmailDomainAllowsDomainOnlyImportInput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
wantLocal string
|
||||
wantDomain string
|
||||
}{
|
||||
{name: "full address", email: " Han@SamanEng.com ", wantLocal: "han", wantDomain: "samaneng.com"},
|
||||
{name: "domain only", email: "@hanmaceng.co.kr", wantLocal: "", wantDomain: "hanmaceng.co.kr"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
local, domain, err := SplitEmailDomain(tt.email)
|
||||
if err != nil {
|
||||
t.Fatalf("SplitEmailDomain() error = %v", err)
|
||||
}
|
||||
if local != tt.wantLocal || domain != tt.wantDomain {
|
||||
t.Fatalf("SplitEmailDomain() = (%q, %q), want (%q, %q)", local, domain, tt.wantLocal, tt.wantDomain)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildKoreanNameEmailBase(t *testing.T) {
|
||||
base, needsReview, err := BuildKoreanNameEmailBase("한치영")
|
||||
if err != nil {
|
||||
t.Fatalf("BuildKoreanNameEmailBase() error = %v", err)
|
||||
}
|
||||
if needsReview {
|
||||
t.Fatalf("BuildKoreanNameEmailBase() needsReview = true")
|
||||
}
|
||||
if base != "cyhan" {
|
||||
t.Fatalf("BuildKoreanNameEmailBase() = %q, want %q", base, "cyhan")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildKoreanNameEmailBaseNeedsReviewForUnknownName(t *testing.T) {
|
||||
base, needsReview, err := BuildKoreanNameEmailBase("A치영")
|
||||
if err != nil {
|
||||
t.Fatalf("BuildKoreanNameEmailBase() error = %v", err)
|
||||
}
|
||||
if base != "" {
|
||||
t.Fatalf("BuildKoreanNameEmailBase() base = %q, want empty", base)
|
||||
}
|
||||
if !needsReview {
|
||||
t.Fatalf("BuildKoreanNameEmailBase() needsReview = false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchesSuggestedNameRule(t *testing.T) {
|
||||
tests := []struct {
|
||||
localPart string
|
||||
base string
|
||||
want bool
|
||||
}{
|
||||
{localPart: "cyhan", base: "cyhan", want: true},
|
||||
{localPart: "cyhan1", base: "cyhan", want: true},
|
||||
{localPart: "cyhan20", base: "cyhan", want: true},
|
||||
{localPart: "hcy", base: "cyhan", want: false},
|
||||
{localPart: "han.cy", base: "cyhan", want: false},
|
||||
{localPart: "cyhan-a", base: "cyhan", want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.localPart, func(t *testing.T) {
|
||||
if got := MatchesSuggestedNameRule(tt.localPart, tt.base); got != tt.want {
|
||||
t.Fatalf("MatchesSuggestedNameRule(%q, %q) = %v, want %v", tt.localPart, tt.base, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
30
baron-sso/backend/internal/domain/headless_jwks_cache.go
Normal file
30
baron-sso/backend/internal/domain/headless_jwks_cache.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package domain
|
||||
|
||||
import "time"
|
||||
|
||||
type HeadlessJWKSParsedKey struct {
|
||||
Kid string `json:"kid,omitempty"`
|
||||
Kty string `json:"kty,omitempty"`
|
||||
Use string `json:"use,omitempty"`
|
||||
Alg string `json:"alg,omitempty"`
|
||||
N string `json:"n,omitempty"`
|
||||
}
|
||||
|
||||
// HeadlessJWKSCacheState는 headless login용 JWKS 캐시 상태와 최근 동기화 결과를 나타냅니다.
|
||||
type HeadlessJWKSCacheState struct {
|
||||
ClientID string `json:"clientId"`
|
||||
JWKSURI string `json:"jwksUri"`
|
||||
CachedAt *time.Time `json:"cachedAt,omitempty"`
|
||||
ExpiresAt *time.Time `json:"expiresAt,omitempty"`
|
||||
LastCheckedAt *time.Time `json:"lastCheckedAt,omitempty"`
|
||||
NextRetryAt *time.Time `json:"nextRetryAt,omitempty"`
|
||||
LastSuccessfulVerificationAt *time.Time `json:"lastSuccessfulVerificationAt,omitempty"`
|
||||
LastRefreshStatus string `json:"lastRefreshStatus,omitempty"`
|
||||
LastError string `json:"lastError,omitempty"`
|
||||
ConsecutiveFailures int `json:"consecutiveFailures,omitempty"`
|
||||
CachedKids []string `json:"cachedKids,omitempty"`
|
||||
ParsedKeys []HeadlessJWKSParsedKey `json:"parsedKeys,omitempty"`
|
||||
ETag string `json:"etag,omitempty"`
|
||||
LastModified string `json:"lastModified,omitempty"`
|
||||
RawJWKS string `json:"-"`
|
||||
}
|
||||
145
baron-sso/backend/internal/domain/hydra_models.go
Normal file
145
baron-sso/backend/internal/domain/hydra_models.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
MetadataHeadlessLoginEnabled = "headless_login_enabled"
|
||||
MetadataHeadlessTokenEndpointAuthMethod = "headless_token_endpoint_auth_method"
|
||||
MetadataHeadlessJWKSURI = "headless_jwks_uri"
|
||||
MetadataHeadlessJWKS = "headless_jwks"
|
||||
MetadataRequestObjectSigningAlg = "request_object_signing_alg"
|
||||
MetadataIDTokenClaims = "id_token_claims"
|
||||
MetadataBackChannelLogoutURI = "backchannel_logout_uri"
|
||||
MetadataBackChannelLogoutSessionRequired = "backchannel_logout_session_required"
|
||||
MetadataAutoLoginSupported = "auto_login_supported"
|
||||
MetadataAutoLoginURL = "auto_login_url"
|
||||
)
|
||||
|
||||
type HydraClient struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
ClientSecret string `json:"client_secret,omitempty"` // Added
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
RedirectURIs []string `json:"redirect_uris,omitempty"`
|
||||
GrantTypes []string `json:"grant_types,omitempty"`
|
||||
ResponseTypes []string `json:"response_types,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
|
||||
SkipConsent *bool `json:"skip_consent,omitempty"`
|
||||
JWKSUri string `json:"jwks_uri,omitempty"`
|
||||
JWKS any `json:"jwks,omitempty"`
|
||||
BackChannelLogoutURI string `json:"backchannel_logout_uri,omitempty"`
|
||||
BackChannelLogoutSessionRequired *bool `json:"backchannel_logout_session_required,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
func (c *HydraClient) SupportsHeadlessLogin() bool {
|
||||
// Headless login now supports jwksUri only.
|
||||
hasPublicKey := c.HeadlessJWKSURI() != ""
|
||||
isPrivateKeyJwt := c.HeadlessTokenEndpointAuthMethod() == "private_key_jwt"
|
||||
return hasPublicKey && isPrivateKeyJwt
|
||||
}
|
||||
|
||||
func (c *HydraClient) HeadlessTokenEndpointAuthMethod() string {
|
||||
if c.Metadata != nil {
|
||||
if raw, ok := c.Metadata[MetadataHeadlessTokenEndpointAuthMethod].(string); ok {
|
||||
if value := strings.TrimSpace(raw); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(c.TokenEndpointAuthMethod)
|
||||
}
|
||||
|
||||
func (c *HydraClient) HeadlessJWKSURI() string {
|
||||
if c.Metadata != nil {
|
||||
if raw, ok := c.Metadata[MetadataHeadlessJWKSURI].(string); ok {
|
||||
if value := strings.TrimSpace(raw); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(c.JWKSUri)
|
||||
}
|
||||
|
||||
func (c *HydraClient) HeadlessJWKS() any {
|
||||
if c.Metadata != nil {
|
||||
if value, ok := c.Metadata[MetadataHeadlessJWKS]; ok && value != nil {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return c.JWKS
|
||||
}
|
||||
|
||||
func (c *HydraClient) IsHeadlessLoginEnabled() bool {
|
||||
if !c.SupportsHeadlessLogin() {
|
||||
return false
|
||||
}
|
||||
if c.Metadata == nil {
|
||||
return false
|
||||
}
|
||||
val, ok := c.Metadata[MetadataHeadlessLoginEnabled]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if b, ok := val.(bool); ok {
|
||||
return b
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *HydraClient) BackchannelLogoutURI() string {
|
||||
if c.Metadata != nil {
|
||||
if raw, ok := c.Metadata[MetadataBackChannelLogoutURI].(string); ok {
|
||||
if value := strings.TrimSpace(raw); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(c.BackChannelLogoutURI)
|
||||
}
|
||||
|
||||
func (c *HydraClient) BackchannelLogoutSessionRequiredValue() bool {
|
||||
if c.Metadata != nil {
|
||||
if raw, ok := c.Metadata[MetadataBackChannelLogoutSessionRequired].(bool); ok {
|
||||
return raw
|
||||
}
|
||||
}
|
||||
if c.BackChannelLogoutSessionRequired != nil {
|
||||
return *c.BackChannelLogoutSessionRequired
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type HydraConsentRequest struct {
|
||||
Challenge string `json:"challenge"`
|
||||
RequestedScope []string `json:"requested_scope"`
|
||||
RequestedAudience []string `json:"requested_access_token_audience"`
|
||||
Skip bool `json:"skip"`
|
||||
Subject string `json:"subject"`
|
||||
Client HydraClient `json:"client"`
|
||||
}
|
||||
|
||||
type HydraLoginRequest struct {
|
||||
Challenge string `json:"challenge"`
|
||||
Subject string `json:"subject"`
|
||||
Skip bool `json:"skip"`
|
||||
Client HydraClient `json:"client"`
|
||||
}
|
||||
|
||||
type HydraConsentSession struct {
|
||||
ConsentRequestID string `json:"consent_request_id,omitempty"`
|
||||
Subject string `json:"subject,omitempty"`
|
||||
GrantedScope []string `json:"grant_scope,omitempty"`
|
||||
GrantedAudience []string `json:"grant_access_token_audience,omitempty"`
|
||||
Remember bool `json:"remember"`
|
||||
RememberFor int `json:"remember_for,omitempty"`
|
||||
AuthenticatedAt *time.Time `json:"authenticated_at,omitempty"`
|
||||
RequestedAt *time.Time `json:"requested_at,omitempty"`
|
||||
HandledAt *time.Time `json:"handled_at,omitempty"`
|
||||
Client HydraClient `json:"client"`
|
||||
ConsentRequest *HydraConsentRequest `json:"consent_request,omitempty"`
|
||||
}
|
||||
182
baron-sso/backend/internal/domain/hydra_models_test.go
Normal file
182
baron-sso/backend/internal/domain/hydra_models_test.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHydraClient_HeadlessLoginFlags(t *testing.T) {
|
||||
t.Run("metadata-backed headless login client is supported", func(t *testing.T) {
|
||||
client := HydraClient{
|
||||
TokenEndpointAuthMethod: "none",
|
||||
Metadata: map[string]any{
|
||||
"headless_login_enabled": true,
|
||||
"headless_token_endpoint_auth_method": "private_key_jwt",
|
||||
"headless_jwks_uri": "https://rp.example.com/.well-known/jwks.json",
|
||||
},
|
||||
}
|
||||
|
||||
if !client.SupportsHeadlessLogin() {
|
||||
t.Fatalf("expected metadata-backed headless login client")
|
||||
}
|
||||
if !client.IsHeadlessLoginEnabled() {
|
||||
t.Fatalf("expected metadata-backed headless login enabled")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("inline jwks without jwks uri does not support headless login", func(t *testing.T) {
|
||||
client := HydraClient{
|
||||
TokenEndpointAuthMethod: "private_key_jwt",
|
||||
JWKS: map[string]any{
|
||||
"keys": []map[string]any{{
|
||||
"kty": "RSA",
|
||||
}},
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
"headless_login_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
if client.SupportsHeadlessLogin() {
|
||||
t.Fatalf("expected headless login prerequisites to be missing")
|
||||
}
|
||||
if client.IsHeadlessLoginEnabled() {
|
||||
t.Fatalf("expected headless login disabled without jwks uri")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("jwks uri without private_key_jwt does not support headless login", func(t *testing.T) {
|
||||
client := HydraClient{
|
||||
TokenEndpointAuthMethod: "none",
|
||||
JWKSUri: "https://rp.example.com/.well-known/jwks.json",
|
||||
Metadata: map[string]any{
|
||||
"headless_login_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
if client.SupportsHeadlessLogin() {
|
||||
t.Fatalf("expected headless login prerequisites to be missing")
|
||||
}
|
||||
if client.IsHeadlessLoginEnabled() {
|
||||
t.Fatalf("expected headless login disabled when prerequisites are missing")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("headless login client without boolean metadata flag is not enabled", func(t *testing.T) {
|
||||
client := HydraClient{
|
||||
TokenEndpointAuthMethod: "private_key_jwt",
|
||||
JWKSUri: "https://rp.example.com/.well-known/jwks.json",
|
||||
Metadata: map[string]any{
|
||||
"headless_login_enabled": "true",
|
||||
},
|
||||
}
|
||||
|
||||
if !client.SupportsHeadlessLogin() {
|
||||
t.Fatalf("expected headless login client")
|
||||
}
|
||||
if client.IsHeadlessLoginEnabled() {
|
||||
t.Fatalf("expected headless login disabled for non-bool metadata")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHydraClientHeadlessMetadataAccessors(t *testing.T) {
|
||||
t.Run("metadata values override inline values", func(t *testing.T) {
|
||||
metadataJWKS := map[string]any{"keys": []any{"metadata-key"}}
|
||||
client := HydraClient{
|
||||
TokenEndpointAuthMethod: "client_secret_post",
|
||||
JWKSUri: "https://inline.example.com/jwks.json",
|
||||
JWKS: map[string]any{"keys": []any{"inline-key"}},
|
||||
Metadata: map[string]any{
|
||||
MetadataHeadlessTokenEndpointAuthMethod: " private_key_jwt ",
|
||||
MetadataHeadlessJWKSURI: " https://metadata.example.com/jwks.json ",
|
||||
MetadataHeadlessJWKS: metadataJWKS,
|
||||
},
|
||||
}
|
||||
|
||||
if got := client.HeadlessTokenEndpointAuthMethod(); got != "private_key_jwt" {
|
||||
t.Fatalf("unexpected auth method: %q", got)
|
||||
}
|
||||
if got := client.HeadlessJWKSURI(); got != "https://metadata.example.com/jwks.json" {
|
||||
t.Fatalf("unexpected jwks uri: %q", got)
|
||||
}
|
||||
if got := client.HeadlessJWKS(); !reflect.DeepEqual(got, metadataJWKS) {
|
||||
t.Fatalf("unexpected jwks value: %#v", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("blank or missing metadata values fall back to inline values", func(t *testing.T) {
|
||||
inlineJWKS := map[string]any{"keys": []any{"inline-key"}}
|
||||
client := HydraClient{
|
||||
TokenEndpointAuthMethod: " private_key_jwt ",
|
||||
JWKSUri: " https://inline.example.com/jwks.json ",
|
||||
JWKS: inlineJWKS,
|
||||
Metadata: map[string]any{
|
||||
MetadataHeadlessTokenEndpointAuthMethod: " ",
|
||||
MetadataHeadlessJWKSURI: " ",
|
||||
MetadataHeadlessJWKS: nil,
|
||||
},
|
||||
}
|
||||
|
||||
if got := client.HeadlessTokenEndpointAuthMethod(); got != "private_key_jwt" {
|
||||
t.Fatalf("unexpected auth method: %q", got)
|
||||
}
|
||||
if got := client.HeadlessJWKSURI(); got != "https://inline.example.com/jwks.json" {
|
||||
t.Fatalf("unexpected jwks uri: %q", got)
|
||||
}
|
||||
if got := client.HeadlessJWKS(); !reflect.DeepEqual(got, inlineJWKS) {
|
||||
t.Fatalf("unexpected jwks value: %#v", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHydraClientBackchannelLogoutAccessors(t *testing.T) {
|
||||
t.Run("metadata values override inline values", func(t *testing.T) {
|
||||
inlineRequired := false
|
||||
client := HydraClient{
|
||||
BackChannelLogoutURI: "https://inline.example.com/logout",
|
||||
BackChannelLogoutSessionRequired: &inlineRequired,
|
||||
Metadata: map[string]any{
|
||||
MetadataBackChannelLogoutURI: " https://metadata.example.com/logout ",
|
||||
MetadataBackChannelLogoutSessionRequired: true,
|
||||
},
|
||||
}
|
||||
|
||||
if got := client.BackchannelLogoutURI(); got != "https://metadata.example.com/logout" {
|
||||
t.Fatalf("unexpected logout uri: %q", got)
|
||||
}
|
||||
if !client.BackchannelLogoutSessionRequiredValue() {
|
||||
t.Fatalf("expected metadata session_required value")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("blank or missing metadata values fall back to inline values", func(t *testing.T) {
|
||||
inlineRequired := true
|
||||
client := HydraClient{
|
||||
BackChannelLogoutURI: " https://inline.example.com/logout ",
|
||||
BackChannelLogoutSessionRequired: &inlineRequired,
|
||||
Metadata: map[string]any{
|
||||
MetadataBackChannelLogoutURI: " ",
|
||||
MetadataBackChannelLogoutSessionRequired: "true",
|
||||
},
|
||||
}
|
||||
|
||||
if got := client.BackchannelLogoutURI(); got != "https://inline.example.com/logout" {
|
||||
t.Fatalf("unexpected logout uri: %q", got)
|
||||
}
|
||||
if !client.BackchannelLogoutSessionRequiredValue() {
|
||||
t.Fatalf("expected inline session_required value")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing session required defaults to false", func(t *testing.T) {
|
||||
client := HydraClient{}
|
||||
|
||||
if got := client.BackchannelLogoutURI(); got != "" {
|
||||
t.Fatalf("unexpected logout uri: %q", got)
|
||||
}
|
||||
if client.BackchannelLogoutSessionRequiredValue() {
|
||||
t.Fatalf("expected default session_required false")
|
||||
}
|
||||
})
|
||||
}
|
||||
19
baron-sso/backend/internal/domain/identity_cache.go
Normal file
19
baron-sso/backend/internal/domain/identity_cache.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package domain
|
||||
|
||||
import "time"
|
||||
|
||||
type IdentityCacheStatus struct {
|
||||
Status string `json:"status"`
|
||||
RedisReady bool `json:"redisReady"`
|
||||
ObservedCount int64 `json:"observedCount"`
|
||||
KeyCount int64 `json:"keyCount"`
|
||||
LastRefreshedAt *time.Time `json:"lastRefreshedAt,omitempty"`
|
||||
LastError string `json:"lastError,omitempty"`
|
||||
UpdatedAt *time.Time `json:"updatedAt,omitempty"`
|
||||
}
|
||||
|
||||
type IdentityCacheFlushResult struct {
|
||||
Status string `json:"status"`
|
||||
FlushedKeys int64 `json:"flushedKeys"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
92
baron-sso/backend/internal/domain/idp_models.go
Normal file
92
baron-sso/backend/internal/domain/idp_models.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrNotSupported는 IDP가 특정 인증 흐름을 지원하지 않을 때 반환합니다.
|
||||
var ErrNotSupported = errors.New("idp: not supported")
|
||||
|
||||
// BrokerUser is the standard user model used within Baron SSO business logic.
|
||||
// It defines the canonical set of fields that must be supported by any underlying IDP.
|
||||
type BrokerUser struct {
|
||||
ID string `json:"id" required:"true"`
|
||||
Email string `json:"email" required:"true"`
|
||||
LoginID string `json:"login_id"`
|
||||
CustomLoginIDs []string `json:"custom_login_ids"` // [New] 다중 로그인 ID
|
||||
Name string `json:"name"`
|
||||
PhoneNumber string `json:"phone_number"`
|
||||
// Attributes stores custom user attributes.
|
||||
// The "required_keys" tag specifies which keys MUST be present in the IDP's schema support.
|
||||
Attributes map[string]any `json:"attributes" required_keys:"grade,department"`
|
||||
}
|
||||
|
||||
// IDPMetadata represents the schema capabilities of an Identity Provider.
|
||||
type IDPMetadata struct {
|
||||
// SupportedFields lists the BrokerUser fields (json tag names) that the IDP supports.
|
||||
// For custom attributes, use the key name directly (e.g., "grade").
|
||||
SupportedFields []string
|
||||
}
|
||||
|
||||
// PasswordPolicy는 비밀번호 정책 정보를 표현합니다.
|
||||
type PasswordPolicy struct {
|
||||
MinLength int
|
||||
Lowercase bool
|
||||
Uppercase bool
|
||||
Number bool
|
||||
NonAlphanumeric bool
|
||||
MinCharacterTypes int
|
||||
}
|
||||
|
||||
// Token represents a session or refresh token.
|
||||
type Token struct {
|
||||
JWT string
|
||||
Expiration time.Time
|
||||
SessionID string
|
||||
}
|
||||
|
||||
// AuthInfo contains authentication information after a successful login.
|
||||
type AuthInfo struct {
|
||||
SessionToken *Token
|
||||
RefreshToken *Token
|
||||
// Subject는 IDP 세션이 대표하는 주체(예: Kratos identity.id)를 나타냅니다.
|
||||
Subject string
|
||||
SetCookies []*http.Cookie
|
||||
}
|
||||
|
||||
// LinkLoginInit는 링크 로그인 초기화 결과입니다.
|
||||
type LinkLoginInit struct {
|
||||
FlowID string
|
||||
ExpiresAt time.Time
|
||||
// Mode는 링크 로그인 완료 후 세션 처리 방식입니다. (예: "cookie")
|
||||
Mode string
|
||||
// LoginID는 IDP에 실제 전달된 식별자입니다.
|
||||
LoginID string
|
||||
}
|
||||
|
||||
// IdentityProvider is the interface that all IDP adapters must implement.
|
||||
type IdentityProvider interface {
|
||||
Name() string
|
||||
// GetMetadata returns the schema support information for this IDP.
|
||||
// This is used for startup-time validation.
|
||||
GetMetadata() (*IDPMetadata, error)
|
||||
// CreateUser는 BrokerUser 스키마를 기반으로 신규 사용자를 생성하고 주체 ID(예: identity.id)를 반환합니다.
|
||||
CreateUser(user *BrokerUser, password string) (string, error)
|
||||
// SignIn은 로그인 ID/비밀번호로 인증해 세션 정보를 반환합니다.
|
||||
SignIn(loginID, password string) (*AuthInfo, error)
|
||||
// UserExists는 loginID 기준으로 사용자 존재 여부를 확인합니다.
|
||||
UserExists(loginID string) (bool, error)
|
||||
// IssueSession은 비밀번호 없이 세션을 발급해야 하는 흐름에서 사용합니다.
|
||||
IssueSession(loginID string) (*AuthInfo, error)
|
||||
// InitiateLinkLogin은 링크 기반 로그인 요청을 IDP에 전달합니다.
|
||||
InitiateLinkLogin(loginID, returnTo string) (*LinkLoginInit, error)
|
||||
// VerifyLoginCode는 링크/코드 기반 로그인에서 코드를 제출해 세션을 발급합니다.
|
||||
VerifyLoginCode(loginID, flowID, code string) (*AuthInfo, error)
|
||||
// GetPasswordPolicy는 IDP가 제공하는 비밀번호 정책을 반환합니다.
|
||||
GetPasswordPolicy() (*PasswordPolicy, error)
|
||||
InitiatePasswordReset(loginID, redirectUrl string) error
|
||||
VerifyPasswordResetToken(token string) (*AuthInfo, error)
|
||||
UpdateUserPassword(loginID, newPassword string, r *http.Request) error
|
||||
}
|
||||
42
baron-sso/backend/internal/domain/json_map.go
Normal file
42
baron-sso/backend/internal/domain/json_map.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// JSONMap is a custom type for handling map[string]any with PostgreSQL JSONB
|
||||
type JSONMap map[string]any
|
||||
|
||||
// Value implements the driver.Valuer interface
|
||||
func (m JSONMap) Value() (driver.Value, error) {
|
||||
if m == nil {
|
||||
return nil, nil
|
||||
}
|
||||
ba, err := json.Marshal(m)
|
||||
return string(ba), err
|
||||
}
|
||||
|
||||
// Scan implements the sql.Scanner interface
|
||||
func (m *JSONMap) Scan(value any) error {
|
||||
if value == nil {
|
||||
*m = make(JSONMap)
|
||||
return nil
|
||||
}
|
||||
var bytes []byte
|
||||
switch v := value.(type) {
|
||||
case []byte:
|
||||
bytes = v
|
||||
case string:
|
||||
bytes = []byte(v)
|
||||
default:
|
||||
return errors.New(fmt.Sprintf("failed to scan JSONMap: %v", value))
|
||||
}
|
||||
|
||||
result := make(JSONMap)
|
||||
err := json.Unmarshal(bytes, &result)
|
||||
*m = result
|
||||
return err
|
||||
}
|
||||
93
baron-sso/backend/internal/domain/json_map_test.go
Normal file
93
baron-sso/backend/internal/domain/json_map_test.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestJSONMapValue(t *testing.T) {
|
||||
t.Run("nil map returns nil database value", func(t *testing.T) {
|
||||
var payload JSONMap
|
||||
|
||||
value, err := payload.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if value != nil {
|
||||
t.Fatalf("expected nil value, got %v", value)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("map marshals to JSON string", func(t *testing.T) {
|
||||
payload := JSONMap{"enabled": true, "name": "baron"}
|
||||
|
||||
value, err := payload.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
raw, ok := value.(string)
|
||||
if !ok {
|
||||
t.Fatalf("expected string value, got %T", value)
|
||||
}
|
||||
|
||||
var decoded map[string]any
|
||||
if err := json.Unmarshal([]byte(raw), &decoded); err != nil {
|
||||
t.Fatalf("value should be valid json: %v", err)
|
||||
}
|
||||
if decoded["enabled"] != true || decoded["name"] != "baron" {
|
||||
t.Fatalf("unexpected decoded value: %#v", decoded)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestJSONMapScan(t *testing.T) {
|
||||
t.Run("nil value becomes empty map", func(t *testing.T) {
|
||||
var payload JSONMap
|
||||
|
||||
if err := payload.Scan(nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if payload == nil || len(payload) != 0 {
|
||||
t.Fatalf("expected empty map, got %#v", payload)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("byte slice value decodes JSON", func(t *testing.T) {
|
||||
var payload JSONMap
|
||||
|
||||
if err := payload.Scan([]byte(`{"count":2,"name":"baron"}`)); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if payload["count"] != float64(2) || payload["name"] != "baron" {
|
||||
t.Fatalf("unexpected payload: %#v", payload)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("string value decodes JSON", func(t *testing.T) {
|
||||
var payload JSONMap
|
||||
|
||||
if err := payload.Scan(`{"active":true}`); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if payload["active"] != true {
|
||||
t.Fatalf("unexpected payload: %#v", payload)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unsupported value type returns error", func(t *testing.T) {
|
||||
var payload JSONMap
|
||||
|
||||
if err := payload.Scan(42); err == nil {
|
||||
t.Fatalf("expected unsupported type error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid JSON returns error", func(t *testing.T) {
|
||||
var payload JSONMap
|
||||
|
||||
if err := payload.Scan(`{invalid`); err == nil {
|
||||
t.Fatalf("expected invalid JSON error")
|
||||
}
|
||||
})
|
||||
}
|
||||
48
baron-sso/backend/internal/domain/keto_outbox.go
Normal file
48
baron-sso/backend/internal/domain/keto_outbox.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// KetoOutbox status
|
||||
const (
|
||||
KetoOutboxStatusPending = "pending"
|
||||
KetoOutboxStatusProcessed = "processed"
|
||||
KetoOutboxStatusFailed = "failed"
|
||||
)
|
||||
|
||||
// KetoOutbox action
|
||||
const (
|
||||
KetoOutboxActionCreate = "CREATE"
|
||||
KetoOutboxActionDelete = "DELETE"
|
||||
)
|
||||
|
||||
// KetoOutbox represents a Keto relationship tuple update event.
|
||||
type KetoOutbox struct {
|
||||
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
|
||||
Namespace string `gorm:"not null" json:"namespace"`
|
||||
Object string `gorm:"not null" json:"object"`
|
||||
Relation string `gorm:"not null" json:"relation"`
|
||||
Subject string `gorm:"not null" json:"subject"` // format: "User:ID" or "Tenant:ID#members"
|
||||
Action string `gorm:"not null" json:"action"` // CREATE, DELETE
|
||||
Status string `gorm:"default:'pending';index" json:"status"`
|
||||
RetryCount int `gorm:"default:0" json:"retryCount"`
|
||||
LastError string `json:"lastError,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
ProcessedAt *time.Time `json:"processedAt,omitempty"`
|
||||
}
|
||||
|
||||
func (ko *KetoOutbox) TableName() string {
|
||||
return "keto_outbox"
|
||||
}
|
||||
|
||||
func (ko *KetoOutbox) BeforeCreate(tx *gorm.DB) (err error) {
|
||||
if ko.ID == "" {
|
||||
ko.ID = uuid.NewString()
|
||||
}
|
||||
return
|
||||
}
|
||||
357
baron-sso/backend/internal/domain/model_hooks_test.go
Normal file
357
baron-sso/backend/internal/domain/model_hooks_test.go
Normal file
@@ -0,0 +1,357 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func requireGeneratedUUID(t *testing.T, value string) {
|
||||
t.Helper()
|
||||
|
||||
if value == "" {
|
||||
t.Fatalf("expected generated uuid")
|
||||
}
|
||||
if _, err := uuid.Parse(value); err != nil {
|
||||
t.Fatalf("expected valid uuid, got %q: %v", value, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBeforeCreateGeneratesMissingIDs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
run func(t *testing.T)
|
||||
}{
|
||||
{
|
||||
name: "api key",
|
||||
run: func(t *testing.T) {
|
||||
model := ApiKey{}
|
||||
if err := model.BeforeCreate(nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
requireGeneratedUUID(t, model.ID)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "client consent",
|
||||
run: func(t *testing.T) {
|
||||
model := ClientConsent{}
|
||||
if err := model.BeforeCreate(nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
requireGeneratedUUID(t, model.ID)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "identity provider config",
|
||||
run: func(t *testing.T) {
|
||||
model := IdentityProviderConfig{}
|
||||
if err := model.BeforeCreate(nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
requireGeneratedUUID(t, model.ID)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "keto outbox",
|
||||
run: func(t *testing.T) {
|
||||
model := KetoOutbox{}
|
||||
if err := model.BeforeCreate(nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
requireGeneratedUUID(t, model.ID)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tenant",
|
||||
run: func(t *testing.T) {
|
||||
model := Tenant{}
|
||||
if err := model.BeforeCreate(nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
requireGeneratedUUID(t, model.ID)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tenant domain",
|
||||
run: func(t *testing.T) {
|
||||
model := TenantDomain{}
|
||||
if err := model.BeforeCreate(nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
requireGeneratedUUID(t, model.ID)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user",
|
||||
run: func(t *testing.T) {
|
||||
model := User{}
|
||||
if err := model.BeforeCreate(nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
requireGeneratedUUID(t, model.ID)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user group",
|
||||
run: func(t *testing.T) {
|
||||
model := UserGroup{}
|
||||
if err := model.BeforeCreate(nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
requireGeneratedUUID(t, model.ID)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "worksmobile resource mapping",
|
||||
run: func(t *testing.T) {
|
||||
model := WorksmobileResourceMapping{}
|
||||
if err := model.BeforeCreate(nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
requireGeneratedUUID(t, model.ID)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, tc.run)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBeforeCreatePreservesExistingIDs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
run func(t *testing.T)
|
||||
}{
|
||||
{
|
||||
name: "api key",
|
||||
run: func(t *testing.T) {
|
||||
model := ApiKey{ID: "existing-id"}
|
||||
if err := model.BeforeCreate(nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if model.ID != "existing-id" {
|
||||
t.Fatalf("expected existing id to be preserved")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "client consent",
|
||||
run: func(t *testing.T) {
|
||||
model := ClientConsent{ID: "existing-id"}
|
||||
if err := model.BeforeCreate(nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if model.ID != "existing-id" {
|
||||
t.Fatalf("expected existing id to be preserved")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "identity provider config",
|
||||
run: func(t *testing.T) {
|
||||
model := IdentityProviderConfig{ID: "existing-id"}
|
||||
if err := model.BeforeCreate(nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if model.ID != "existing-id" {
|
||||
t.Fatalf("expected existing id to be preserved")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "keto outbox",
|
||||
run: func(t *testing.T) {
|
||||
model := KetoOutbox{ID: "existing-id"}
|
||||
if err := model.BeforeCreate(nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if model.ID != "existing-id" {
|
||||
t.Fatalf("expected existing id to be preserved")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tenant",
|
||||
run: func(t *testing.T) {
|
||||
model := Tenant{ID: "existing-id"}
|
||||
if err := model.BeforeCreate(nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if model.ID != "existing-id" {
|
||||
t.Fatalf("expected existing id to be preserved")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tenant domain",
|
||||
run: func(t *testing.T) {
|
||||
model := TenantDomain{ID: "existing-id"}
|
||||
if err := model.BeforeCreate(nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if model.ID != "existing-id" {
|
||||
t.Fatalf("expected existing id to be preserved")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user",
|
||||
run: func(t *testing.T) {
|
||||
model := User{ID: "existing-id"}
|
||||
if err := model.BeforeCreate(nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if model.ID != "existing-id" {
|
||||
t.Fatalf("expected existing id to be preserved")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user group",
|
||||
run: func(t *testing.T) {
|
||||
model := UserGroup{ID: "existing-id"}
|
||||
if err := model.BeforeCreate(nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if model.ID != "existing-id" {
|
||||
t.Fatalf("expected existing id to be preserved")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "worksmobile resource mapping",
|
||||
run: func(t *testing.T) {
|
||||
model := WorksmobileResourceMapping{ID: "existing-id"}
|
||||
if err := model.BeforeCreate(nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if model.ID != "existing-id" {
|
||||
t.Fatalf("expected existing id to be preserved")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, tc.run)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTableNames(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
got string
|
||||
expected string
|
||||
}{
|
||||
{name: "keto outbox", got: (&KetoOutbox{}).TableName(), expected: "keto_outbox"},
|
||||
{name: "rp usage event", got: (&RPUsageEvent{}).TableName(), expected: "rp_usage_outbox"},
|
||||
{name: "rp user metadata", got: (RPUserMetadata{}).TableName(), expected: "rp_user_metadata"},
|
||||
{name: "user group", got: (&UserGroup{}).TableName(), expected: "user_groups"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.got != tc.expected {
|
||||
t.Fatalf("unexpected table name: got=%s expected=%s", tc.got, tc.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTenantIsActive(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status string
|
||||
expected bool
|
||||
}{
|
||||
{name: "active", status: TenantStatusActive, expected: true},
|
||||
{name: "pending", status: TenantStatusPending, expected: false},
|
||||
{name: "suspended", status: TenantStatusSuspended, expected: false},
|
||||
{name: "deleted", status: TenantStatusDeleted, expected: false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tenant := Tenant{Status: tc.status}
|
||||
if got := tenant.IsActive(); got != tc.expected {
|
||||
t.Fatalf("unexpected active state: got=%v expected=%v", got, tc.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRPUsageEventBeforeCreateDefaults(t *testing.T) {
|
||||
event := RPUsageEvent{}
|
||||
|
||||
if err := event.BeforeCreate(nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
requireGeneratedUUID(t, event.ID)
|
||||
if event.Status != RPUsageOutboxStatusPending {
|
||||
t.Fatalf("unexpected status: %s", event.Status)
|
||||
}
|
||||
if event.OccurredAt.IsZero() {
|
||||
t.Fatalf("expected occurred_at default")
|
||||
}
|
||||
if event.Payload == nil {
|
||||
t.Fatalf("expected empty payload default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRPUsageEventBeforeCreatePreservesExplicitValues(t *testing.T) {
|
||||
occurredAt := time.Date(2026, 5, 29, 1, 2, 3, 0, time.UTC)
|
||||
event := RPUsageEvent{
|
||||
ID: "existing-id",
|
||||
Status: RPUsageOutboxStatusProcessing,
|
||||
OccurredAt: occurredAt,
|
||||
Payload: JSONMap{"source": "test"},
|
||||
}
|
||||
|
||||
if err := event.BeforeCreate(nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if event.ID != "existing-id" {
|
||||
t.Fatalf("expected existing id to be preserved")
|
||||
}
|
||||
if event.Status != RPUsageOutboxStatusProcessing {
|
||||
t.Fatalf("expected status to be preserved")
|
||||
}
|
||||
if !event.OccurredAt.Equal(occurredAt) {
|
||||
t.Fatalf("expected occurred_at to be preserved")
|
||||
}
|
||||
if event.Payload["source"] != "test" {
|
||||
t.Fatalf("expected payload to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorksmobileOutboxBeforeCreateDefaults(t *testing.T) {
|
||||
outbox := WorksmobileOutbox{}
|
||||
|
||||
if err := outbox.BeforeCreate(nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
requireGeneratedUUID(t, outbox.ID)
|
||||
if outbox.Status != WorksmobileOutboxStatusPending {
|
||||
t.Fatalf("unexpected status: %s", outbox.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorksmobileOutboxBeforeCreatePreservesExplicitValues(t *testing.T) {
|
||||
outbox := WorksmobileOutbox{
|
||||
ID: "existing-id",
|
||||
Status: WorksmobileOutboxStatusProcessing,
|
||||
}
|
||||
|
||||
if err := outbox.BeforeCreate(nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if outbox.ID != "existing-id" {
|
||||
t.Fatalf("expected existing id to be preserved")
|
||||
}
|
||||
if outbox.Status != WorksmobileOutboxStatusProcessing {
|
||||
t.Fatalf("expected status to be preserved")
|
||||
}
|
||||
}
|
||||
48
baron-sso/backend/internal/domain/models.go
Normal file
48
baron-sso/backend/internal/domain/models.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AuditLog represents a single audit event
|
||||
type AuditLog struct {
|
||||
EventID string `json:"event_id"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
UserID string `json:"user_id"`
|
||||
TenantID string `json:"tenant_id,omitempty"`
|
||||
SessionID string `json:"session_id,omitempty"`
|
||||
EventType string `json:"event_type"` // e.g., "login_success", "login_failed", "otp_sent"
|
||||
Status string `json:"status"` // e.g., "success", "failure"
|
||||
AuthMethod string `json:"auth_method,omitempty"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
DeviceID string `json:"device_id,omitempty"`
|
||||
Details string `json:"details,omitempty"` // JSON string or simple text
|
||||
}
|
||||
|
||||
// AuditRepository defines interface for storing logs
|
||||
type AuditRepository interface {
|
||||
Create(log *AuditLog) error
|
||||
FindPage(ctx context.Context, limit int, cursor *AuditCursor, tenantID string) ([]AuditLog, error)
|
||||
FindByUserAndEvents(ctx context.Context, userID string, eventTypes []string, limit int) ([]AuditLog, error)
|
||||
CountEventsSince(ctx context.Context, since time.Time) (int64, error)
|
||||
CountFailuresSince(ctx context.Context, since time.Time, tenantID string) (int64, error)
|
||||
CountActiveSessionsSince(ctx context.Context, since time.Time, tenantID string) (int64, error)
|
||||
Ping(ctx context.Context) error
|
||||
}
|
||||
|
||||
type AuditCursor struct {
|
||||
Timestamp time.Time
|
||||
EventID string
|
||||
}
|
||||
|
||||
// RedisRepository defines interface for KV storage (Redis)
|
||||
type RedisRepository interface {
|
||||
Set(key string, value string, expiration time.Duration) error
|
||||
Get(key string) (string, error)
|
||||
Delete(key string) error
|
||||
StoreVerificationCode(phone, code string) error
|
||||
GetVerificationCode(phone string) (string, error)
|
||||
DeleteVerificationCode(phone string) error
|
||||
}
|
||||
31
baron-sso/backend/internal/domain/oathkeeper_models.go
Normal file
31
baron-sso/backend/internal/domain/oathkeeper_models.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type OathkeeperAccessLog struct {
|
||||
Timestamp time.Time
|
||||
RequestID string
|
||||
Method string
|
||||
Path string
|
||||
Status int
|
||||
LatencyMs int
|
||||
ClientID string
|
||||
RP string
|
||||
Action string
|
||||
Target string
|
||||
Subject string
|
||||
ClientIP string
|
||||
UserAgent string
|
||||
Decision string
|
||||
TraceID string
|
||||
SpanID string
|
||||
Raw string
|
||||
}
|
||||
|
||||
type OathkeeperLogRepository interface {
|
||||
FindPageBySubject(ctx context.Context, subject string, limit int, cursor *AuditCursor) ([]OathkeeperAccessLog, error)
|
||||
Ping(ctx context.Context) error
|
||||
}
|
||||
19
baron-sso/backend/internal/domain/relying_party.go
Normal file
19
baron-sso/backend/internal/domain/relying_party.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// RelyingParty represents an OAuth2 Client owner by a Tenant.
|
||||
// It maps 1:1 to a Hydra Client.
|
||||
type RelyingParty struct {
|
||||
ClientID string `json:"clientId"` // Maps to Hydra Client ID
|
||||
TenantID string `json:"tenantId"`
|
||||
Name string `json:"name"` // Display name (can be same as Hydra Client Name)
|
||||
Description string `json:"description"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
// DeletedAt removed as it's not a DB model anymore
|
||||
}
|
||||
|
||||
// TableName removed
|
||||
101
baron-sso/backend/internal/domain/rp_usage_event.go
Normal file
101
baron-sso/backend/internal/domain/rp_usage_event.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
RPUsageOutboxStatusPending = "pending"
|
||||
RPUsageOutboxStatusProcessing = "processing"
|
||||
RPUsageOutboxStatusProcessed = "processed"
|
||||
RPUsageOutboxStatusFailed = "failed"
|
||||
)
|
||||
|
||||
const (
|
||||
RPUsageEventTypeAuthorizationGranted = "rp_usage.authorization_granted"
|
||||
RPUsageEventTypeAuthorizationRevoked = "rp_usage.authorization_revoked"
|
||||
)
|
||||
|
||||
const (
|
||||
RPUsageTenantTypeCompany = TenantTypeCompany
|
||||
RPUsageTenantTypeOrganization = TenantTypeOrganization
|
||||
)
|
||||
|
||||
type RPUsageEvent struct {
|
||||
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
|
||||
EventType string `gorm:"not null;index:idx_rp_usage_outbox_event" json:"eventType"`
|
||||
Subject string `gorm:"not null;index:idx_rp_usage_outbox_subject" json:"subject"`
|
||||
TenantID string `gorm:"index:idx_rp_usage_outbox_tenant" json:"tenantId,omitempty"`
|
||||
TenantType string `gorm:"index:idx_rp_usage_outbox_tenant" json:"tenantType,omitempty"`
|
||||
ClientID string `gorm:"not null;index:idx_rp_usage_outbox_client" json:"clientId"`
|
||||
ClientName string `json:"clientName,omitempty"`
|
||||
SessionID string `gorm:"index" json:"sessionId,omitempty"`
|
||||
Scopes pq.StringArray `gorm:"type:text[]" json:"scopes,omitempty"`
|
||||
Source string `gorm:"not null;index" json:"source"`
|
||||
CorrelationID string `gorm:"index" json:"correlationId,omitempty"`
|
||||
Payload JSONMap `gorm:"type:jsonb" json:"payload,omitempty"`
|
||||
DedupeKey string `gorm:"uniqueIndex" json:"dedupeKey"`
|
||||
Status string `gorm:"default:'pending';index" json:"status"`
|
||||
RetryCount int `gorm:"default:0" json:"retryCount"`
|
||||
LastError string `json:"lastError,omitempty"`
|
||||
NextAttemptAt *time.Time `json:"nextAttemptAt,omitempty"`
|
||||
OccurredAt time.Time `gorm:"not null;index" json:"occurredAt"`
|
||||
ProcessedAt *time.Time `json:"processedAt,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
func (e *RPUsageEvent) TableName() string {
|
||||
return "rp_usage_outbox"
|
||||
}
|
||||
|
||||
func (e *RPUsageEvent) BeforeCreate(tx *gorm.DB) error {
|
||||
if e.ID == "" {
|
||||
e.ID = uuid.NewString()
|
||||
}
|
||||
if e.Status == "" {
|
||||
e.Status = RPUsageOutboxStatusPending
|
||||
}
|
||||
if e.OccurredAt.IsZero() {
|
||||
e.OccurredAt = time.Now()
|
||||
}
|
||||
if e.Payload == nil {
|
||||
e.Payload = JSONMap{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type RPUsageEventSink interface {
|
||||
EmitRPUsageEvent(ctx context.Context, event RPUsageEvent) error
|
||||
}
|
||||
|
||||
type RPUsageProjectionRepository interface {
|
||||
CreateRPUsageEvent(ctx context.Context, event RPUsageEvent) error
|
||||
}
|
||||
|
||||
type RPUsageDailyMetric struct {
|
||||
Date string `json:"date"`
|
||||
TenantID string `json:"tenantId"`
|
||||
TenantType string `json:"tenantType"`
|
||||
TenantName string `json:"tenantName,omitempty"`
|
||||
ClientID string `json:"clientId"`
|
||||
ClientName string `json:"clientName"`
|
||||
LoginRequests uint64 `json:"loginRequests"`
|
||||
OtherRequests uint64 `json:"otherRequests"`
|
||||
UniqueSubjects uint64 `json:"uniqueSubjects"`
|
||||
}
|
||||
|
||||
type RPUsageQuery struct {
|
||||
Days int
|
||||
Period string
|
||||
TenantID string
|
||||
}
|
||||
|
||||
type RPUsageQueryRepository interface {
|
||||
FindRPUsage(ctx context.Context, query RPUsageQuery) ([]RPUsageDailyMetric, error)
|
||||
}
|
||||
16
baron-sso/backend/internal/domain/rp_user_metadata.go
Normal file
16
baron-sso/backend/internal/domain/rp_user_metadata.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package domain
|
||||
|
||||
import "time"
|
||||
|
||||
type RPUserMetadata struct {
|
||||
ClientID string `gorm:"column:client_id;primaryKey" json:"clientId"`
|
||||
UserID string `gorm:"column:user_id;type:uuid;primaryKey" json:"userId"`
|
||||
User *User `gorm:"foreignKey:UserID" json:"-"`
|
||||
Metadata JSONMap `gorm:"column:metadata;type:jsonb" json:"metadata"`
|
||||
CreatedAt time.Time `gorm:"column:created_at" json:"createdAt"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at" json:"updatedAt"`
|
||||
}
|
||||
|
||||
func (RPUserMetadata) TableName() string {
|
||||
return "rp_user_metadata"
|
||||
}
|
||||
53
baron-sso/backend/internal/domain/shared_link.go
Normal file
53
baron-sso/backend/internal/domain/shared_link.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type SharedLink struct {
|
||||
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
|
||||
TenantID string `gorm:"type:uuid;not null;index" json:"tenantId"`
|
||||
Token string `gorm:"uniqueIndex;not null" json:"token"`
|
||||
Name string `gorm:"not null" json:"name"` // 링크 식별을 위한 이름 (예: "24년 상반기 채용공고용")
|
||||
Description string `json:"description"`
|
||||
AccessLevel string `gorm:"default:'READ_ONLY'" json:"accessLevel"`
|
||||
IsActive bool `gorm:"default:true" json:"isActive"`
|
||||
ExpiresAt *time.Time `json:"expiresAt"`
|
||||
Password string `json:"-"` // 필요 시 비밀번호 (선택 사항)
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
||||
|
||||
// Relation
|
||||
Tenant Tenant `gorm:"foreignKey:TenantID" json:"-"`
|
||||
}
|
||||
|
||||
func (s *SharedLink) BeforeCreate(tx *gorm.DB) (err error) {
|
||||
if s.ID == "" {
|
||||
s.ID = uuid.NewString()
|
||||
}
|
||||
if s.Token == "" {
|
||||
// 32바이트(64자)의 강력한 난수 토큰 생성
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return err
|
||||
}
|
||||
s.Token = hex.EncodeToString(b)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *SharedLink) IsValid() bool {
|
||||
if !s.IsActive {
|
||||
return false
|
||||
}
|
||||
if s.ExpiresAt != nil && s.ExpiresAt.Before(time.Now()) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
80
baron-sso/backend/internal/domain/shared_link_test.go
Normal file
80
baron-sso/backend/internal/domain/shared_link_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSharedLinkBeforeCreate(t *testing.T) {
|
||||
t.Run("generates id and token when missing", func(t *testing.T) {
|
||||
link := SharedLink{}
|
||||
|
||||
if err := link.BeforeCreate(nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if link.ID == "" {
|
||||
t.Fatalf("expected generated id")
|
||||
}
|
||||
if len(link.Token) != 64 {
|
||||
t.Fatalf("expected 64-character token, got %q", link.Token)
|
||||
}
|
||||
if _, err := hex.DecodeString(link.Token); err != nil {
|
||||
t.Fatalf("expected hex token: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserves existing id and token", func(t *testing.T) {
|
||||
link := SharedLink{
|
||||
ID: "existing-id",
|
||||
Token: "existing-token",
|
||||
}
|
||||
|
||||
if err := link.BeforeCreate(nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if link.ID != "existing-id" || link.Token != "existing-token" {
|
||||
t.Fatalf("expected existing fields to be preserved: %#v", link)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSharedLinkIsValid(t *testing.T) {
|
||||
future := time.Now().Add(time.Hour)
|
||||
past := time.Now().Add(-time.Hour)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
link SharedLink
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "active link without expiration is valid",
|
||||
link: SharedLink{IsActive: true},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "active link with future expiration is valid",
|
||||
link: SharedLink{IsActive: true, ExpiresAt: &future},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "inactive link is invalid",
|
||||
link: SharedLink{IsActive: false, ExpiresAt: &future},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "expired link is invalid",
|
||||
link: SharedLink{IsActive: true, ExpiresAt: &past},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := tc.link.IsValid(); got != tc.expected {
|
||||
t.Fatalf("unexpected validity: got=%v expected=%v", got, tc.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
42
baron-sso/backend/internal/domain/sms_models.go
Normal file
42
baron-sso/backend/internal/domain/sms_models.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package domain
|
||||
|
||||
// SmsService defines the interface for sending SMS messages.
|
||||
type SmsService interface {
|
||||
SendSms(to, content string) error
|
||||
}
|
||||
|
||||
// NaverSmsRequest represents the request body for the Naver Cloud SMS API.
|
||||
type NaverSmsRequest struct {
|
||||
Type string `json:"type"`
|
||||
ContentType string `json:"contentType"`
|
||||
CountryCode string `json:"countryCode"`
|
||||
From string `json:"from"`
|
||||
Subject string `json:"subject,omitempty"`
|
||||
Content string `json:"content"`
|
||||
Messages []SmsMessage `json:"messages"`
|
||||
}
|
||||
|
||||
// SmsMessage represents a single message to be sent.
|
||||
type SmsMessage struct {
|
||||
To string `json:"to"`
|
||||
Content string `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
// NaverSmsResponse represents the response from the Naver Cloud SMS API.
|
||||
type NaverSmsResponse struct {
|
||||
RequestID string `json:"requestId"`
|
||||
RequestTime string `json:"requestTime"`
|
||||
StatusCode string `json:"statusCode"`
|
||||
StatusName string `json:"statusName"`
|
||||
}
|
||||
|
||||
// SmsRequest represents the request body for sending an SMS.
|
||||
type SmsRequest struct {
|
||||
PhoneNumber string `json:"phoneNumber"`
|
||||
}
|
||||
|
||||
// SmsVerifyRequest represents the request body for verifying an SMS code.
|
||||
type SmsVerifyRequest struct {
|
||||
PhoneNumber string `json:"phoneNumber"`
|
||||
Code string `json:"code"`
|
||||
}
|
||||
11
baron-sso/backend/internal/domain/system_setting.go
Normal file
11
baron-sso/backend/internal/domain/system_setting.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package domain
|
||||
|
||||
import "time"
|
||||
|
||||
// SystemSetting stores small global configuration documents.
|
||||
type SystemSetting struct {
|
||||
Key string `gorm:"primaryKey;size:128" json:"key"`
|
||||
Value JSONMap `gorm:"type:jsonb" json:"value"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
53
baron-sso/backend/internal/domain/tenant.go
Normal file
53
baron-sso/backend/internal/domain/tenant.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Tenant statuses
|
||||
const (
|
||||
TenantStatusPending = "pending"
|
||||
TenantStatusActive = "active"
|
||||
TenantStatusSuspended = "suspended"
|
||||
TenantStatusDeleted = "deleted"
|
||||
)
|
||||
|
||||
// Tenant types
|
||||
const (
|
||||
TenantTypePersonal = "PERSONAL"
|
||||
TenantTypeCompany = "COMPANY"
|
||||
TenantTypeCompanyGroup = "COMPANY_GROUP"
|
||||
TenantTypeOrganization = "ORGANIZATION"
|
||||
TenantTypeUserGroup = "USER_GROUP"
|
||||
)
|
||||
|
||||
// Tenant represents a tenant model stored in PostgreSQL.
|
||||
type Tenant struct {
|
||||
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
|
||||
Type string `gorm:"not null;default:'PERSONAL'" json:"type"` // PERSONAL, COMPANY, COMPANY_GROUP, ORGANIZATION, USER_GROUP
|
||||
ParentID *string `gorm:"type:uuid;index" json:"parentId,omitempty"` // 부모 테넌트 ID
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
Slug string `gorm:"uniqueIndex;not null" json:"slug"`
|
||||
Description string `json:"description"`
|
||||
Status string `gorm:"default:'pending'" json:"status"`
|
||||
Domains []TenantDomain `gorm:"foreignKey:TenantID" json:"domains,omitempty"`
|
||||
Config JSONMap `gorm:"type:jsonb" json:"config,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
||||
}
|
||||
|
||||
func (t *Tenant) IsActive() bool {
|
||||
return t.Status == TenantStatusActive
|
||||
}
|
||||
|
||||
// BeforeCreate hook to generate UUID if not present.
|
||||
func (t *Tenant) BeforeCreate(tx *gorm.DB) (err error) {
|
||||
if t.ID == "" {
|
||||
t.ID = uuid.NewString()
|
||||
}
|
||||
return
|
||||
}
|
||||
27
baron-sso/backend/internal/domain/tenant_domain.go
Normal file
27
baron-sso/backend/internal/domain/tenant_domain.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TenantDomain represents a domain associated with a tenant for auto-assignment.
|
||||
type TenantDomain struct {
|
||||
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
|
||||
TenantID string `gorm:"type:uuid;not null;uniqueIndex:idx_tenant_domains_tenant_domain" json:"tenantId"`
|
||||
Domain string `gorm:"not null;uniqueIndex:idx_tenant_domains_tenant_domain" json:"domain"` // e.g. "example.com"
|
||||
Verified bool `gorm:"default:false" json:"verified"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
||||
}
|
||||
|
||||
// BeforeCreate hook to generate UUID if not present.
|
||||
func (td *TenantDomain) BeforeCreate(tx *gorm.DB) (err error) {
|
||||
if td.ID == "" {
|
||||
td.ID = uuid.NewString()
|
||||
}
|
||||
return
|
||||
}
|
||||
247
baron-sso/backend/internal/domain/user.go
Normal file
247
baron-sso/backend/internal/domain/user.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// User roles
|
||||
const (
|
||||
RoleSuperAdmin = "super_admin" // 시스템 전역 관리자
|
||||
RoleUser = "user" // 일반 사용자
|
||||
)
|
||||
|
||||
// User statuses
|
||||
const (
|
||||
UserStatusActive = "active"
|
||||
UserStatusSuspended = "suspended"
|
||||
UserStatusTemporaryLeave = "temporary_leave"
|
||||
UserStatusPreboarding = "preboarding"
|
||||
UserStatusBaronGuest = "baron_guest"
|
||||
UserStatusExtendedLeave = "extended_leave"
|
||||
UserStatusArchived = "archived"
|
||||
)
|
||||
|
||||
func NormalizeUserStatus(status string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(status)) {
|
||||
case "", UserStatusActive:
|
||||
return UserStatusActive
|
||||
case "blocked", UserStatusSuspended:
|
||||
return UserStatusSuspended
|
||||
case "inactive", UserStatusPreboarding:
|
||||
return UserStatusPreboarding
|
||||
case "leave_of_absence", UserStatusTemporaryLeave:
|
||||
return UserStatusTemporaryLeave
|
||||
case "baron_only", UserStatusBaronGuest:
|
||||
return UserStatusBaronGuest
|
||||
case UserStatusExtendedLeave:
|
||||
return UserStatusExtendedLeave
|
||||
case UserStatusArchived:
|
||||
return UserStatusArchived
|
||||
default:
|
||||
return strings.ToLower(strings.TrimSpace(status))
|
||||
}
|
||||
}
|
||||
|
||||
func IsBaronActivityAllowedStatus(status string) bool {
|
||||
switch NormalizeUserStatus(status) {
|
||||
case UserStatusActive, UserStatusTemporaryLeave, UserStatusBaronGuest:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func IsOrgVisibleUserStatus(status string) bool {
|
||||
switch NormalizeUserStatus(status) {
|
||||
case UserStatusActive, UserStatusTemporaryLeave, UserStatusSuspended:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func IsWorksProvisionedUserStatus(status string) bool {
|
||||
switch NormalizeUserStatus(status) {
|
||||
case UserStatusActive, UserStatusTemporaryLeave, UserStatusSuspended:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func IsWorksDeprovisionUserStatus(status string) bool {
|
||||
switch NormalizeUserStatus(status) {
|
||||
case UserStatusBaronGuest, UserStatusExtendedLeave, UserStatusArchived:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// NormalizeRole maps legacy/synonym role values to canonical role keys.
|
||||
func NormalizeRole(role string) string {
|
||||
if normalized, ok := NormalizeRoleAlias(role); ok {
|
||||
return normalized
|
||||
}
|
||||
return RoleUser
|
||||
}
|
||||
|
||||
func NormalizeRoleAlias(role string) (string, bool) {
|
||||
normalized := strings.ToLower(strings.TrimSpace(role))
|
||||
switch normalized {
|
||||
case RoleSuperAdmin, RoleUser:
|
||||
return normalized, true
|
||||
case "tenant_admin", "rp_admin", "tenant_member", "member", "admin", "tenantadmin", "tenant-admin":
|
||||
return RoleUser, true
|
||||
case "superadmin", "super-admin":
|
||||
return RoleSuperAdmin, true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
// User represents the user model stored in PostgreSQL
|
||||
type User struct {
|
||||
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
|
||||
Email string `gorm:"uniqueIndex;not null" json:"email"`
|
||||
PasswordHash *string `gorm:"column:password_hash" json:"-"`
|
||||
Name string `gorm:"column:name;not null" json:"name"`
|
||||
Phone string `gorm:"column:phone" json:"phone"`
|
||||
Role string `gorm:"column:role;default:'user';not null" json:"role"` // super_admin, user
|
||||
AffiliationType string `gorm:"column:affiliation_type" json:"affiliationType"`
|
||||
CompanyCode string `gorm:"-" json:"companyCode,omitempty"`
|
||||
CompanyCodes pq.StringArray `gorm:"-" json:"companyCodes,omitempty"`
|
||||
TenantID *string `gorm:"column:tenant_id;type:uuid;index" json:"tenantId,omitempty"`
|
||||
Tenant *Tenant `gorm:"foreignKey:TenantID" json:"tenant,omitempty"`
|
||||
RelyingPartyID *string `gorm:"column:relying_party_id;type:uuid;index" json:"relyingPartyId,omitempty"` // RP Admin용
|
||||
Department string `gorm:"column:department" json:"department"`
|
||||
Grade string `gorm:"column:grade" json:"grade"` // 직급 (예: 수석, 책임, 선임)
|
||||
Position string `gorm:"column:position" json:"position"` // 직책 (예: 팀장, 센터장)
|
||||
JobTitle string `gorm:"column:job_title" json:"jobTitle"` // 직무 (예: 프론트엔드 개발, 기획)
|
||||
Metadata JSONMap `gorm:"column:metadata;type:jsonb" json:"metadata,omitempty"`
|
||||
Status string `gorm:"column:status;default:'active'" json:"status"`
|
||||
CreatedAt time.Time `gorm:"column:created_at" json:"createdAt"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at" json:"updatedAt"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;index" json:"-"`
|
||||
|
||||
// Multiple identifiers support
|
||||
UserLoginIDs []UserLoginID `gorm:"foreignKey:UserID" json:"userLoginIds,omitempty"`
|
||||
}
|
||||
|
||||
// UserLoginID represents multiple custom identifiers for a user
|
||||
type UserLoginID struct {
|
||||
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
|
||||
UserID string `gorm:"type:uuid;not null;index" json:"userId"`
|
||||
TenantID string `gorm:"type:uuid;not null;index" json:"tenantId"` // 발급 테넌트
|
||||
FieldKey string `gorm:"not null" json:"fieldKey"` // 스키마 필드 키 (예: emp_id)
|
||||
LoginID string `gorm:"uniqueIndex;not null" json:"loginId"` // 실제 값 (예: EMP001)
|
||||
}
|
||||
|
||||
// BeforeCreate hook to generate UUID if not present
|
||||
func (u *User) BeforeCreate(tx *gorm.DB) (err error) {
|
||||
if u.ID == "" {
|
||||
u.ID = uuid.New().String()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ValidateLoginID checks if the loginID violates any collision, length, or security rules.
|
||||
func ValidateLoginID(loginID string, emails []string, phone string) error {
|
||||
loginID = strings.TrimSpace(loginID)
|
||||
if loginID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(loginID) < 4 || len(loginID) > 30 {
|
||||
return fmt.Errorf("ID must be between 4 and 30 characters")
|
||||
}
|
||||
|
||||
if strings.Contains(loginID, "@") {
|
||||
return fmt.Errorf("ID cannot be an email format")
|
||||
}
|
||||
|
||||
for _, email := range emails {
|
||||
if email != "" && strings.EqualFold(loginID, email) {
|
||||
return fmt.Errorf("ID cannot be the same as the email address")
|
||||
}
|
||||
}
|
||||
|
||||
if phone != "" {
|
||||
normalizedPhone := NormalizePhoneNumber(phone)
|
||||
|
||||
if loginID == phone || loginID == normalizedPhone {
|
||||
return fmt.Errorf("ID cannot be the same as the phone number")
|
||||
}
|
||||
}
|
||||
|
||||
isPureNumber := true
|
||||
loginIDDigits := strings.ReplaceAll(loginID, "-", "")
|
||||
loginIDDigits = strings.ReplaceAll(loginIDDigits, " ", "")
|
||||
for _, c := range loginIDDigits {
|
||||
if (c < '0' || c > '9') && c != '+' {
|
||||
isPureNumber = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isPureNumber && len(loginIDDigits) >= 10 && len(loginIDDigits) <= 12 {
|
||||
if strings.HasPrefix(loginIDDigits, "010") || strings.HasPrefix(loginIDDigits, "82") || strings.HasPrefix(loginIDDigits, "+82") {
|
||||
return fmt.Errorf("ID cannot be a phone number format")
|
||||
}
|
||||
}
|
||||
|
||||
reserved := []string{"admin", "system", "root", "master", "superuser", "guest", "operator"}
|
||||
lowerID := strings.ToLower(loginID)
|
||||
if slices.Contains(reserved, lowerID) {
|
||||
return fmt.Errorf("reserved ID cannot be used")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func NormalizePhoneNumber(phone string) string {
|
||||
trimmed := strings.TrimSpace(phone)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
hasLeadingPlus := false
|
||||
digits := strings.Builder{}
|
||||
for _, r := range trimmed {
|
||||
switch {
|
||||
case r >= '0' && r <= '9':
|
||||
digits.WriteRune(r)
|
||||
case r == '+' && digits.Len() == 0 && !hasLeadingPlus:
|
||||
hasLeadingPlus = true
|
||||
}
|
||||
}
|
||||
|
||||
number := digits.String()
|
||||
if number == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.HasPrefix(number, "010") {
|
||||
return "+82" + number[1:]
|
||||
}
|
||||
if strings.HasPrefix(number, "82") {
|
||||
rest := number[2:]
|
||||
for strings.HasPrefix(rest, "82") {
|
||||
rest = rest[2:]
|
||||
}
|
||||
if strings.HasPrefix(rest, "0") {
|
||||
rest = rest[1:]
|
||||
}
|
||||
return "+82" + rest
|
||||
}
|
||||
if hasLeadingPlus {
|
||||
return "+" + number
|
||||
}
|
||||
return number
|
||||
}
|
||||
50
baron-sso/backend/internal/domain/user_group.go
Normal file
50
baron-sso/backend/internal/domain/user_group.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// UserGroup represents a collection of users within a tenant.
|
||||
type UserGroup struct {
|
||||
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
|
||||
TenantID string `gorm:"type:uuid;index;not null" json:"tenantId"`
|
||||
ParentID *string `gorm:"type:uuid;index" json:"parentId,omitempty"` // 상위 조직 ID
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
Slug string `gorm:"index" json:"slug"` // 추가
|
||||
Description string `json:"description"`
|
||||
UnitType string `json:"unitType"` // 부, 국, 팀, 셀 등
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
||||
|
||||
// Relationships
|
||||
Parent *UserGroup `gorm:"foreignKey:ParentID" json:"parent,omitempty"`
|
||||
Members []User `gorm:"-" json:"members,omitempty"`
|
||||
}
|
||||
|
||||
type GroupCreateRequest struct {
|
||||
Name string `json:"name"`
|
||||
ParentID *string `json:"parentId"`
|
||||
Description string `json:"description"`
|
||||
UnitType string `json:"unitType"`
|
||||
}
|
||||
|
||||
type GroupRole struct {
|
||||
TenantID string `json:"tenantId"`
|
||||
TenantName string `json:"tenantName"`
|
||||
Relation string `json:"relation"`
|
||||
}
|
||||
|
||||
func (ug *UserGroup) TableName() string {
|
||||
return "user_groups"
|
||||
}
|
||||
|
||||
func (ug *UserGroup) BeforeCreate(tx *gorm.DB) (err error) {
|
||||
if ug.ID == "" {
|
||||
ug.ID = uuid.NewString()
|
||||
}
|
||||
return
|
||||
}
|
||||
29
baron-sso/backend/internal/domain/user_projection.go
Normal file
29
baron-sso/backend/internal/domain/user_projection.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package domain
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
UserProjectionNameKratos = "kratos_users"
|
||||
|
||||
UserProjectionStatusSyncing = "syncing"
|
||||
UserProjectionStatusReady = "ready"
|
||||
UserProjectionStatusFailed = "failed"
|
||||
)
|
||||
|
||||
type UserProjectionState struct {
|
||||
Name string `gorm:"primaryKey;column:name" json:"name"`
|
||||
Status string `gorm:"column:status;not null" json:"status"`
|
||||
LastSyncedAt *time.Time `gorm:"column:last_synced_at" json:"lastSyncedAt,omitempty"`
|
||||
LastError string `gorm:"column:last_error;type:text" json:"lastError,omitempty"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at" json:"updatedAt"`
|
||||
}
|
||||
|
||||
type UserProjectionStatus struct {
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
Ready bool `json:"ready"`
|
||||
LastSyncedAt *time.Time `json:"lastSyncedAt,omitempty"`
|
||||
LastError string `json:"lastError,omitempty"`
|
||||
UpdatedAt *time.Time `json:"updatedAt,omitempty"`
|
||||
ProjectedUsers int64 `json:"projectedUsers"`
|
||||
}
|
||||
73
baron-sso/backend/internal/domain/user_test.go
Normal file
73
baron-sso/backend/internal/domain/user_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package domain
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeRole(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{name: "super admin unchanged", in: "super_admin", want: RoleSuperAdmin},
|
||||
{name: "tenant admin mapped to user", in: "tenant_admin", want: RoleUser},
|
||||
{name: "rp admin mapped to user", in: "rp_admin", want: RoleUser},
|
||||
{name: "user unchanged", in: "user", want: RoleUser},
|
||||
{name: "super admin hyphen alias", in: "super-admin", want: RoleSuperAdmin},
|
||||
{name: "super admin compact alias", in: "superadmin", want: RoleSuperAdmin},
|
||||
{name: "legacy admin mapped to user", in: "admin", want: RoleUser},
|
||||
{name: "legacy tenant member", in: "tenant_member", want: RoleUser},
|
||||
{name: "trim and lower", in: " ADMIN ", want: RoleUser},
|
||||
{name: "unknown role mapped to user", in: "custom_role", want: RoleUser},
|
||||
{name: "empty string mapped to user", in: " ", want: RoleUser},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := NormalizeRole(tc.in); got != tc.want {
|
||||
t.Fatalf("NormalizeRole(%q)=%q, want %q", tc.in, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserStatusPolicy(t *testing.T) {
|
||||
tests := []struct {
|
||||
status string
|
||||
normalized string
|
||||
baronAllowed bool
|
||||
orgVisible bool
|
||||
worksProvisioned bool
|
||||
worksDeprovisioned bool
|
||||
}{
|
||||
{status: UserStatusActive, normalized: UserStatusActive, baronAllowed: true, orgVisible: true, worksProvisioned: true},
|
||||
{status: UserStatusTemporaryLeave, normalized: UserStatusTemporaryLeave, baronAllowed: true, orgVisible: true, worksProvisioned: true},
|
||||
{status: UserStatusSuspended, normalized: UserStatusSuspended, orgVisible: true, worksProvisioned: true},
|
||||
{status: UserStatusPreboarding, normalized: UserStatusPreboarding},
|
||||
{status: UserStatusBaronGuest, normalized: UserStatusBaronGuest, baronAllowed: true, worksDeprovisioned: true},
|
||||
{status: UserStatusExtendedLeave, normalized: UserStatusExtendedLeave, worksDeprovisioned: true},
|
||||
{status: UserStatusArchived, normalized: UserStatusArchived, worksDeprovisioned: true},
|
||||
{status: "inactive", normalized: UserStatusPreboarding},
|
||||
{status: "leave_of_absence", normalized: UserStatusTemporaryLeave, baronAllowed: true, orgVisible: true, worksProvisioned: true},
|
||||
{status: "BARON_ONLY", normalized: UserStatusBaronGuest, baronAllowed: true, worksDeprovisioned: true},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.status, func(t *testing.T) {
|
||||
if got := NormalizeUserStatus(tc.status); got != tc.normalized {
|
||||
t.Fatalf("NormalizeUserStatus(%q)=%q, want %q", tc.status, got, tc.normalized)
|
||||
}
|
||||
if got := IsBaronActivityAllowedStatus(tc.status); got != tc.baronAllowed {
|
||||
t.Fatalf("IsBaronActivityAllowedStatus(%q)=%v, want %v", tc.status, got, tc.baronAllowed)
|
||||
}
|
||||
if got := IsOrgVisibleUserStatus(tc.status); got != tc.orgVisible {
|
||||
t.Fatalf("IsOrgVisibleUserStatus(%q)=%v, want %v", tc.status, got, tc.orgVisible)
|
||||
}
|
||||
if got := IsWorksProvisionedUserStatus(tc.status); got != tc.worksProvisioned {
|
||||
t.Fatalf("IsWorksProvisionedUserStatus(%q)=%v, want %v", tc.status, got, tc.worksProvisioned)
|
||||
}
|
||||
if got := IsWorksDeprovisionUserStatus(tc.status); got != tc.worksDeprovisioned {
|
||||
t.Fatalf("IsWorksDeprovisionUserStatus(%q)=%v, want %v", tc.status, got, tc.worksDeprovisioned)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
64
baron-sso/backend/internal/domain/user_validate_test.go
Normal file
64
baron-sso/backend/internal/domain/user_validate_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidateLoginID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
loginID string
|
||||
emails []string
|
||||
phone string
|
||||
wantErr bool
|
||||
}{
|
||||
{"Empty", "", []string{"test@email.com"}, "01012345678", false},
|
||||
{"Valid alphanumeric", "user123", []string{"test@email.com"}, "01012345678", false},
|
||||
{"Too short", "us", []string{"test@email.com"}, "01012345678", true},
|
||||
{"Too long", "thisisaverylongloginidthatiswayoverthirtycharacters", []string{"test@email.com"}, "01012345678", true},
|
||||
{"Email format", "user@domain.com", []string{"test@email.com"}, "01012345678", true},
|
||||
{"Exact email match", "Test@Email.Com", []string{"test@email.com"}, "01012345678", true},
|
||||
{"Secondary email match", "sub@test.com", []string{"test@email.com", "sub@test.com"}, "01012345678", true},
|
||||
{"Phone number match", "010-1234-5678", []string{"test@email.com"}, "01012345678", true},
|
||||
{"Phone number match +82", "+821012345678", []string{"test@email.com"}, "01012345678", true},
|
||||
{"Phone number match digits", "01012345678", []string{"test@email.com"}, "01012345678", true},
|
||||
{"Phone format (11 digits)", "01098765432", []string{"test@email.com"}, "01012345678", true},
|
||||
{"Valid pure digits (employee ID)", "20230001", []string{"test@email.com"}, "01012345678", false},
|
||||
{"Valid pure digits long", "123456789", []string{"test@email.com"}, "01012345678", false},
|
||||
{"Valid pure digits 10 chars", "1234567890", []string{"test@email.com"}, "01012345678", false},
|
||||
{"Reserved word admin", "ADMIN", []string{"test@email.com"}, "01012345678", true},
|
||||
{"Reserved word root", "root", []string{"test@email.com"}, "01012345678", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateLoginID(tt.loginID, tt.emails, tt.phone)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateLoginID() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizePhoneNumberDeduplicatesKoreanCountryCode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"Local mobile", "010-9191-7771", "+821091917771"},
|
||||
{"Korean country code", "+82 10-9191-7771", "+821091917771"},
|
||||
{"Duplicate plus Korean country code", "+82 +821091917771", "+821091917771"},
|
||||
{"Duplicate compact Korean country code", "+82821091917771", "+821091917771"},
|
||||
{"Duplicate spaced Korean country code", "+82 8210 9191 7771", "+821091917771"},
|
||||
{"Non Korean international phone preserved", "+1 914 481 2222", "+19144812222"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := NormalizePhoneNumber(tt.input); got != tt.want {
|
||||
t.Fatalf("NormalizePhoneNumber(%q)=%q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
73
baron-sso/backend/internal/domain/worksmobile.go
Normal file
73
baron-sso/backend/internal/domain/worksmobile.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
WorksmobileOutboxStatusPending = "pending"
|
||||
WorksmobileOutboxStatusProcessing = "processing"
|
||||
WorksmobileOutboxStatusProcessed = "processed"
|
||||
WorksmobileOutboxStatusFailed = "failed"
|
||||
)
|
||||
|
||||
const (
|
||||
WorksmobileResourceOrgUnit = "ORGUNIT"
|
||||
WorksmobileResourceUser = "USER"
|
||||
)
|
||||
|
||||
const (
|
||||
WorksmobileActionUpsert = "UPSERT"
|
||||
WorksmobileActionDelete = "DELETE"
|
||||
WorksmobileActionDryRun = "DRY_RUN"
|
||||
WorksmobileActionSuspend = "SUSPEND"
|
||||
WorksmobileActionPasswordReset = "PASSWORD_RESET"
|
||||
)
|
||||
|
||||
type WorksmobileOutbox struct {
|
||||
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
|
||||
ResourceType string `gorm:"not null;index:idx_worksmobile_outbox_resource" json:"resourceType"`
|
||||
ResourceID string `gorm:"not null;index:idx_worksmobile_outbox_resource" json:"resourceId"`
|
||||
Action string `gorm:"not null" json:"action"`
|
||||
Payload JSONMap `gorm:"type:jsonb" json:"payload,omitempty"`
|
||||
DedupeKey string `gorm:"uniqueIndex" json:"dedupeKey"`
|
||||
Status string `gorm:"default:'pending';index" json:"status"`
|
||||
RetryCount int `gorm:"default:0" json:"retryCount"`
|
||||
LastError string `json:"lastError,omitempty"`
|
||||
NextAttemptAt *time.Time `json:"nextAttemptAt,omitempty"`
|
||||
ProcessedAt *time.Time `json:"processedAt,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
func (w *WorksmobileOutbox) BeforeCreate(tx *gorm.DB) error {
|
||||
if w.ID == "" {
|
||||
w.ID = uuid.NewString()
|
||||
}
|
||||
if w.Status == "" {
|
||||
w.Status = WorksmobileOutboxStatusPending
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type WorksmobileResourceMapping struct {
|
||||
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
|
||||
BaronResourceType string `gorm:"not null;uniqueIndex:idx_worksmobile_mapping_baron" json:"baronResourceType"`
|
||||
BaronResourceID string `gorm:"not null;uniqueIndex:idx_worksmobile_mapping_baron" json:"baronResourceId"`
|
||||
ExternalKey string `gorm:"not null;uniqueIndex" json:"externalKey"`
|
||||
WorksmobileResourceID string `json:"worksmobileResourceId,omitempty"`
|
||||
DomainID int64 `json:"domainId"`
|
||||
LastSyncedAt *time.Time `json:"lastSyncedAt,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
func (w *WorksmobileResourceMapping) BeforeCreate(tx *gorm.DB) error {
|
||||
if w.ID == "" {
|
||||
w.ID = uuid.NewString()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
491
baron-sso/backend/internal/handler/admin_handler.go
Normal file
491
baron-sso/backend/internal/handler/admin_handler.go
Normal file
@@ -0,0 +1,491 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/repository"
|
||||
"baron-sso-backend/internal/service"
|
||||
"context"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type adminHydraClientLister interface {
|
||||
ListClients(ctx context.Context, limit, offset int) ([]domain.HydraClient, error)
|
||||
}
|
||||
|
||||
type identityCacheAdmin interface {
|
||||
GetIdentityCacheStatus(ctx context.Context) (domain.IdentityCacheStatus, error)
|
||||
FlushIdentityCache(ctx context.Context) (domain.IdentityCacheFlushResult, error)
|
||||
}
|
||||
|
||||
type AdminHandler struct {
|
||||
DB *gorm.DB
|
||||
Keto service.KetoService
|
||||
KetoOutbox repository.KetoOutboxRepository
|
||||
RPUsageQueries domain.RPUsageQueryRepository
|
||||
TenantRepo repository.TenantRepository
|
||||
Hydra adminHydraClientLister
|
||||
AuditRepo domain.AuditRepository
|
||||
UserProjectionRepo repository.UserProjectionRepository
|
||||
IdentityCache identityCacheAdmin
|
||||
IntegrityChecker repository.DataIntegrityChecker
|
||||
}
|
||||
|
||||
const globalCustomClaimsSettingKey = "global_custom_claim_definitions"
|
||||
|
||||
type globalCustomClaimDefinition struct {
|
||||
Key string `json:"key"`
|
||||
Label string `json:"label"`
|
||||
ValueType string `json:"valueType"`
|
||||
ReadPermission string `json:"readPermission"`
|
||||
WritePermission string `json:"writePermission"`
|
||||
Description string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
type globalCustomClaimDefinitionsResponse struct {
|
||||
Items []globalCustomClaimDefinition `json:"items"`
|
||||
}
|
||||
|
||||
func NewAdminHandler(keto service.KetoService, ketoOutbox repository.KetoOutboxRepository) *AdminHandler {
|
||||
return &AdminHandler{
|
||||
Keto: keto,
|
||||
KetoOutbox: ketoOutbox,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AdminHandler) GetRPUsageDaily(c *fiber.Ctx) error {
|
||||
if h == nil || h.RPUsageQueries == nil {
|
||||
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{
|
||||
"error": "rp usage query service unavailable",
|
||||
})
|
||||
}
|
||||
days := 14
|
||||
if raw := c.Query("days"); raw != "" {
|
||||
if parsed, err := strconv.Atoi(raw); err == nil {
|
||||
days = parsed
|
||||
}
|
||||
}
|
||||
period := normalizeRPUsagePeriod(c.Query("period"))
|
||||
tenantID, allowed := h.authorizedRPUsageTenantID(c, strings.TrimSpace(c.Query("tenantId")))
|
||||
if !allowed {
|
||||
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
|
||||
"error": "forbidden: tenant rp usage stats permission denied",
|
||||
})
|
||||
}
|
||||
items, err := h.RPUsageQueries.FindRPUsage(c.Context(), domain.RPUsageQuery{
|
||||
Days: days,
|
||||
Period: period,
|
||||
TenantID: tenantID,
|
||||
})
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
return c.JSON(fiber.Map{
|
||||
"items": items,
|
||||
"days": days,
|
||||
"period": period,
|
||||
"tenantId": tenantID,
|
||||
})
|
||||
}
|
||||
|
||||
func normalizeRPUsagePeriod(period string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(period)) {
|
||||
case "week":
|
||||
return "week"
|
||||
case "month":
|
||||
return "month"
|
||||
default:
|
||||
return "day"
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AdminHandler) authorizedRPUsageTenantID(c *fiber.Ctx, requestedTenantID string) (string, bool) {
|
||||
profile, _ := c.Locals("user_profile").(*domain.UserProfileResponse)
|
||||
if profile != nil && domain.NormalizeRole(profile.Role) == domain.RoleSuperAdmin {
|
||||
return requestedTenantID, true
|
||||
}
|
||||
tenantID := requestedTenantID
|
||||
if tenantID == "" && profile != nil && profile.TenantID != nil {
|
||||
tenantID = strings.TrimSpace(*profile.TenantID)
|
||||
}
|
||||
if tenantID == "" {
|
||||
return "", false
|
||||
}
|
||||
if h == nil || h.Keto == nil || profile == nil || strings.TrimSpace(profile.ID) == "" {
|
||||
return "", false
|
||||
}
|
||||
allowed, err := h.Keto.CheckPermission(c.Context(), "User:"+profile.ID, "Tenant", tenantID, "view_rp_usage_stats")
|
||||
if err != nil || !allowed {
|
||||
return "", false
|
||||
}
|
||||
return tenantID, true
|
||||
}
|
||||
|
||||
func (h *AdminHandler) CheckAuth(c *fiber.Ctx) error {
|
||||
return c.Status(fiber.StatusOK).JSON(fiber.Map{"status": "ok"})
|
||||
}
|
||||
|
||||
func (h *AdminHandler) GetGlobalCustomClaimDefinitions(c *fiber.Ctx) error {
|
||||
if h == nil || h.DB == nil {
|
||||
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{
|
||||
"error": "settings store unavailable",
|
||||
})
|
||||
}
|
||||
|
||||
var setting domain.SystemSetting
|
||||
if err := h.DB.WithContext(c.Context()).First(&setting, "key = ?", globalCustomClaimsSettingKey).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return c.JSON(globalCustomClaimDefinitionsResponse{Items: []globalCustomClaimDefinition{}})
|
||||
}
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
return c.JSON(globalCustomClaimDefinitionsResponse{
|
||||
Items: normalizeGlobalCustomClaimDefinitions(setting.Value["items"]),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AdminHandler) UpdateGlobalCustomClaimDefinitions(c *fiber.Ctx) error {
|
||||
if h == nil || h.DB == nil {
|
||||
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{
|
||||
"error": "settings store unavailable",
|
||||
})
|
||||
}
|
||||
|
||||
var req globalCustomClaimDefinitionsResponse
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid request body"})
|
||||
}
|
||||
items, err := validateGlobalCustomClaimDefinitions(req.Items)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
setting := domain.SystemSetting{
|
||||
Key: globalCustomClaimsSettingKey,
|
||||
Value: domain.JSONMap{"items": globalCustomClaimDefinitionsToJSON(items)},
|
||||
}
|
||||
if err := h.DB.WithContext(c.Context()).Save(&setting).Error; err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
return c.JSON(globalCustomClaimDefinitionsResponse{Items: items})
|
||||
}
|
||||
|
||||
func normalizeGlobalCustomClaimDefinitions(value any) []globalCustomClaimDefinition {
|
||||
rawItems, ok := value.([]any)
|
||||
if !ok {
|
||||
return []globalCustomClaimDefinition{}
|
||||
}
|
||||
items := make([]globalCustomClaimDefinition, 0, len(rawItems))
|
||||
for _, item := range rawItems {
|
||||
raw, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
def := globalCustomClaimDefinition{
|
||||
Key: strings.TrimSpace(stringValue(raw["key"])),
|
||||
Label: strings.TrimSpace(stringValue(raw["label"])),
|
||||
ValueType: normalizeGlobalCustomClaimType(stringValue(raw["valueType"])),
|
||||
ReadPermission: adminNormalizeCustomClaimPermission(stringValue(raw["readPermission"])),
|
||||
WritePermission: adminNormalizeCustomClaimPermission(stringValue(raw["writePermission"])),
|
||||
Description: strings.TrimSpace(stringValue(raw["description"])),
|
||||
}
|
||||
if def.Key != "" {
|
||||
items = append(items, def)
|
||||
}
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
func validateGlobalCustomClaimDefinitions(items []globalCustomClaimDefinition) ([]globalCustomClaimDefinition, error) {
|
||||
seen := map[string]struct{}{}
|
||||
normalized := make([]globalCustomClaimDefinition, 0, len(items))
|
||||
for _, item := range items {
|
||||
key := strings.TrimSpace(item.Key)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if !isValidCustomClaimKey(key) {
|
||||
return nil, fiber.NewError(fiber.StatusBadRequest, "claim key must use letters, numbers, underscore, dot, or hyphen")
|
||||
}
|
||||
if _, exists := seen[key]; exists {
|
||||
return nil, fiber.NewError(fiber.StatusBadRequest, "duplicate claim key: "+key)
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
normalized = append(normalized, globalCustomClaimDefinition{
|
||||
Key: key,
|
||||
Label: strings.TrimSpace(item.Label),
|
||||
ValueType: normalizeGlobalCustomClaimType(item.ValueType),
|
||||
ReadPermission: adminNormalizeCustomClaimPermission(item.ReadPermission),
|
||||
WritePermission: adminNormalizeCustomClaimPermission(item.WritePermission),
|
||||
Description: strings.TrimSpace(item.Description),
|
||||
})
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
func globalCustomClaimDefinitionsToJSON(items []globalCustomClaimDefinition) []any {
|
||||
values := make([]any, 0, len(items))
|
||||
for _, item := range items {
|
||||
values = append(values, map[string]any{
|
||||
"key": item.Key,
|
||||
"label": item.Label,
|
||||
"valueType": item.ValueType,
|
||||
"readPermission": item.ReadPermission,
|
||||
"writePermission": item.WritePermission,
|
||||
"description": item.Description,
|
||||
})
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
func normalizeGlobalCustomClaimType(value string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||
case "number", "boolean", "array", "object", "date", "datetime":
|
||||
return strings.ToLower(strings.TrimSpace(value))
|
||||
default:
|
||||
return "text"
|
||||
}
|
||||
}
|
||||
|
||||
func adminNormalizeCustomClaimPermission(value string) string {
|
||||
if strings.TrimSpace(value) == "user_and_admin" {
|
||||
return "user_and_admin"
|
||||
}
|
||||
return "admin_only"
|
||||
}
|
||||
|
||||
func isValidCustomClaimKey(value string) bool {
|
||||
for _, r := range value {
|
||||
if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' || r == '_' || r == '-' || r == '.' {
|
||||
continue
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func stringValue(value any) string {
|
||||
if text, ok := value.(string); ok {
|
||||
return text
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func requireSuperAdminProfile(c *fiber.Ctx) bool {
|
||||
profile, _ := c.Locals("user_profile").(*domain.UserProfileResponse)
|
||||
if profile == nil || domain.NormalizeRole(profile.Role) != domain.RoleSuperAdmin {
|
||||
_ = c.Status(fiber.StatusForbidden).JSON(fiber.Map{"error": "forbidden: super_admin required"})
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *AdminHandler) GetUserProjectionStatus(c *fiber.Ctx) error {
|
||||
if !requireSuperAdminProfile(c) {
|
||||
return nil
|
||||
}
|
||||
if h == nil || h.UserProjectionRepo == nil {
|
||||
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "user projection service unavailable"})
|
||||
}
|
||||
status, err := h.UserProjectionRepo.GetStatus(c.Context())
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(status)
|
||||
}
|
||||
|
||||
func (h *AdminHandler) GetOrySSOTSystemStatus(c *fiber.Ctx) error {
|
||||
if !requireSuperAdminProfile(c) {
|
||||
return nil
|
||||
}
|
||||
if h == nil || h.UserProjectionRepo == nil {
|
||||
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "user projection service unavailable"})
|
||||
}
|
||||
projectionStatus, err := h.UserProjectionRepo.GetStatus(c.Context())
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
cacheStatus := domain.IdentityCacheStatus{
|
||||
Status: "unavailable",
|
||||
RedisReady: false,
|
||||
LastError: "identity cache service unavailable",
|
||||
}
|
||||
if h.IdentityCache != nil {
|
||||
cacheStatus, err = h.IdentityCache.GetIdentityCacheStatus(c.Context())
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
"userProjection": projectionStatus,
|
||||
"identityCache": cacheStatus,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AdminHandler) FlushIdentityCache(c *fiber.Ctx) error {
|
||||
if !requireSuperAdminProfile(c) {
|
||||
return nil
|
||||
}
|
||||
if h == nil || h.IdentityCache == nil {
|
||||
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "identity cache service unavailable"})
|
||||
}
|
||||
result, err := h.IdentityCache.FlushIdentityCache(c.Context())
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(result)
|
||||
}
|
||||
|
||||
func (h *AdminHandler) GetDataIntegrity(c *fiber.Ctx) error {
|
||||
if !requireSuperAdminProfile(c) {
|
||||
return nil
|
||||
}
|
||||
if h == nil || h.IntegrityChecker == nil {
|
||||
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "data integrity checker unavailable"})
|
||||
}
|
||||
report, err := h.IntegrityChecker.CheckDataIntegrity(c.Context())
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(report)
|
||||
}
|
||||
|
||||
func (h *AdminHandler) ListOrphanUserLoginIDs(c *fiber.Ctx) error {
|
||||
if !requireSuperAdminProfile(c) {
|
||||
return nil
|
||||
}
|
||||
if h == nil || h.IntegrityChecker == nil {
|
||||
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "data integrity checker unavailable"})
|
||||
}
|
||||
items, err := h.IntegrityChecker.ListOrphanUserLoginIDs(c.Context())
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(fiber.Map{
|
||||
"items": items,
|
||||
"total": len(items),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AdminHandler) DeleteOrphanUserLoginIDs(c *fiber.Ctx) error {
|
||||
if !requireSuperAdminProfile(c) {
|
||||
return nil
|
||||
}
|
||||
if h == nil || h.IntegrityChecker == nil {
|
||||
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "data integrity checker unavailable"})
|
||||
}
|
||||
var req struct {
|
||||
IDs []string `json:"ids"`
|
||||
}
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
|
||||
}
|
||||
result, err := h.IntegrityChecker.DeleteOrphanUserLoginIDs(c.Context(), req.IDs)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(result)
|
||||
}
|
||||
|
||||
// GetSystemStats returns runtime statistics for monitoring
|
||||
func (h *AdminHandler) GetSystemStats(c *fiber.Ctx) error {
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
ctx := c.Context()
|
||||
|
||||
stats := fiber.Map{
|
||||
"totalTenants": h.countTenants(ctx),
|
||||
"totalUsers": h.countUsers(ctx),
|
||||
"oidcClients": h.countOIDCClients(ctx),
|
||||
"auditEvents24h": h.countAuditEventsSince(ctx, time.Now().UTC().Add(-24*time.Hour)),
|
||||
"goroutines": runtime.NumGoroutine(),
|
||||
"cpus": runtime.NumCPU(),
|
||||
"memory": fiber.Map{
|
||||
"alloc": m.Alloc,
|
||||
"totalAlign": m.TotalAlloc,
|
||||
"sys": m.Sys,
|
||||
"numGC": m.NumGC,
|
||||
},
|
||||
"timestamp": time.Now(),
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(stats)
|
||||
}
|
||||
|
||||
func (h *AdminHandler) countTenants(ctx context.Context) int64 {
|
||||
if h == nil || h.TenantRepo == nil {
|
||||
return 0
|
||||
}
|
||||
_, total, err := h.TenantRepo.List(ctx, 1, 0, "", "")
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func (h *AdminHandler) countUsers(ctx context.Context) int64 {
|
||||
if h == nil || h.UserProjectionRepo == nil {
|
||||
return 0
|
||||
}
|
||||
status, err := h.UserProjectionRepo.GetStatus(ctx)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return status.ProjectedUsers
|
||||
}
|
||||
|
||||
func (h *AdminHandler) countOIDCClients(ctx context.Context) int64 {
|
||||
if h == nil || h.Hydra == nil {
|
||||
return 0
|
||||
}
|
||||
const pageSize = 500
|
||||
var total int64
|
||||
for offset := 0; ; offset += pageSize {
|
||||
clients, err := h.Hydra.ListClients(ctx, pageSize, offset)
|
||||
if err != nil {
|
||||
return total
|
||||
}
|
||||
for _, client := range clients {
|
||||
if isHiddenSystemClient(client) {
|
||||
continue
|
||||
}
|
||||
total++
|
||||
}
|
||||
if len(clients) < pageSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func (h *AdminHandler) countAuditEventsSince(ctx context.Context, since time.Time) int64 {
|
||||
if h == nil || h.AuditRepo == nil {
|
||||
return 0
|
||||
}
|
||||
count, err := h.AuditRepo.CountEventsSince(ctx, since)
|
||||
if err == nil && count > 0 {
|
||||
return count
|
||||
}
|
||||
logs, pageErr := h.AuditRepo.FindPage(ctx, 10000, nil, "")
|
||||
if pageErr != nil {
|
||||
return count
|
||||
}
|
||||
var fallbackCount int64
|
||||
for _, log := range logs {
|
||||
if !log.Timestamp.Before(since) {
|
||||
fallbackCount++
|
||||
}
|
||||
}
|
||||
return fallbackCount
|
||||
}
|
||||
343
baron-sso/backend/internal/handler/admin_handler_test.go
Normal file
343
baron-sso/backend/internal/handler/admin_handler_test.go
Normal file
@@ -0,0 +1,343 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeRPUsageQueryRepo struct {
|
||||
query domain.RPUsageQuery
|
||||
items []domain.RPUsageDailyMetric
|
||||
}
|
||||
|
||||
func (f *fakeRPUsageQueryRepo) FindRPUsage(ctx context.Context, query domain.RPUsageQuery) ([]domain.RPUsageDailyMetric, error) {
|
||||
f.query = query
|
||||
return f.items, nil
|
||||
}
|
||||
|
||||
type fakeAdminKeto struct {
|
||||
allowed bool
|
||||
subject string
|
||||
object string
|
||||
relation string
|
||||
}
|
||||
|
||||
func (f *fakeAdminKeto) CheckPermission(ctx context.Context, subject, namespace, object, relation string) (bool, error) {
|
||||
f.subject = subject
|
||||
f.object = object
|
||||
f.relation = relation
|
||||
return f.allowed, nil
|
||||
}
|
||||
|
||||
func (f *fakeAdminKeto) CreateRelation(ctx context.Context, namespace, object, relation, subject string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeAdminKeto) DeleteRelation(ctx context.Context, namespace, object, relation, subject string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeAdminKeto) ListRelations(ctx context.Context, namespace, object, relation, subject string) ([]service.RelationTuple, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeAdminKeto) ListObjects(ctx context.Context, namespace, relation, subject string) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type fakeOverviewAuditRepo struct {
|
||||
mockAuditRepo
|
||||
since time.Time
|
||||
count int64
|
||||
}
|
||||
|
||||
func (f *fakeOverviewAuditRepo) CountEventsSince(ctx context.Context, since time.Time) (int64, error) {
|
||||
f.since = since
|
||||
return f.count, nil
|
||||
}
|
||||
|
||||
type fakeAdminUserProjectionRepo struct {
|
||||
status domain.UserProjectionStatus
|
||||
}
|
||||
|
||||
func (f *fakeAdminUserProjectionRepo) IsReady(ctx context.Context) (bool, error) {
|
||||
return f.status.Ready, nil
|
||||
}
|
||||
|
||||
func (f *fakeAdminUserProjectionRepo) CountTenantMembers(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeAdminUserProjectionRepo) CountTenantMembersRecursive(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeAdminUserProjectionRepo) ReplaceAllFromKratos(ctx context.Context, users []domain.User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeAdminUserProjectionRepo) MarkFailed(ctx context.Context, syncErr error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeAdminUserProjectionRepo) GetStatus(ctx context.Context) (domain.UserProjectionStatus, error) {
|
||||
return f.status, nil
|
||||
}
|
||||
|
||||
type fakeIdentityCacheAdmin struct {
|
||||
status domain.IdentityCacheStatus
|
||||
flush domain.IdentityCacheFlushResult
|
||||
err error
|
||||
statusHit int
|
||||
flushCalls int
|
||||
}
|
||||
|
||||
func (f *fakeIdentityCacheAdmin) GetIdentityCacheStatus(ctx context.Context) (domain.IdentityCacheStatus, error) {
|
||||
f.statusHit++
|
||||
return f.status, f.err
|
||||
}
|
||||
|
||||
func (f *fakeIdentityCacheAdmin) FlushIdentityCache(ctx context.Context) (domain.IdentityCacheFlushResult, error) {
|
||||
f.flushCalls++
|
||||
return f.flush, f.err
|
||||
}
|
||||
|
||||
func TestAdminHandler_GetRPUsageDaily(t *testing.T) {
|
||||
repo := &fakeRPUsageQueryRepo{
|
||||
items: []domain.RPUsageDailyMetric{
|
||||
{
|
||||
Date: "2026-05-06",
|
||||
TenantID: "tenant-1",
|
||||
TenantType: domain.TenantTypeCompany,
|
||||
ClientID: "orgfront",
|
||||
ClientName: "OrgFront",
|
||||
LoginRequests: 12,
|
||||
OtherRequests: 4,
|
||||
UniqueSubjects: 8,
|
||||
},
|
||||
},
|
||||
}
|
||||
h := &AdminHandler{RPUsageQueries: repo}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/api/v1/admin/rp-usage/daily", h.GetRPUsageDaily)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/rp-usage/daily?days=7&period=week&tenantId=tenant-1", nil)
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, 7, repo.query.Days)
|
||||
require.Equal(t, "week", repo.query.Period)
|
||||
require.Equal(t, "tenant-1", repo.query.TenantID)
|
||||
|
||||
var body struct {
|
||||
Items []domain.RPUsageDailyMetric `json:"items"`
|
||||
Days int `json:"days"`
|
||||
Period string `json:"period"`
|
||||
TenantID string `json:"tenantId"`
|
||||
}
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
require.Equal(t, 7, body.Days)
|
||||
require.Equal(t, "week", body.Period)
|
||||
require.Equal(t, "tenant-1", body.TenantID)
|
||||
require.Len(t, body.Items, 1)
|
||||
require.Equal(t, "orgfront", body.Items[0].ClientID)
|
||||
require.Equal(t, uint64(12), body.Items[0].LoginRequests)
|
||||
}
|
||||
|
||||
func TestAdminHandler_UserProjectionStatusRequiresSuperAdmin(t *testing.T) {
|
||||
h := &AdminHandler{
|
||||
UserProjectionRepo: &fakeAdminUserProjectionRepo{
|
||||
status: domain.UserProjectionStatus{Name: domain.UserProjectionNameKratos, Status: domain.UserProjectionStatusReady, Ready: true},
|
||||
},
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "tenant-admin", Role: "tenant_admin"})
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/api/v1/admin/projections/users", h.GetUserProjectionStatus)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/projections/users", nil)
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestAdminHandler_UserProjectionStatusReturnsProjectionStateForSuperAdmin(t *testing.T) {
|
||||
syncedAt := time.Date(2026, 5, 11, 3, 0, 0, 0, time.UTC)
|
||||
h := &AdminHandler{
|
||||
UserProjectionRepo: &fakeAdminUserProjectionRepo{
|
||||
status: domain.UserProjectionStatus{
|
||||
Name: domain.UserProjectionNameKratos,
|
||||
Status: domain.UserProjectionStatusReady,
|
||||
Ready: true,
|
||||
LastSyncedAt: &syncedAt,
|
||||
ProjectedUsers: 152,
|
||||
},
|
||||
},
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/api/v1/admin/projections/users", h.GetUserProjectionStatus)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/projections/users", nil)
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body domain.UserProjectionStatus
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
require.Equal(t, domain.UserProjectionNameKratos, body.Name)
|
||||
require.Equal(t, domain.UserProjectionStatusReady, body.Status)
|
||||
require.True(t, body.Ready)
|
||||
require.Equal(t, int64(152), body.ProjectedUsers)
|
||||
}
|
||||
|
||||
func TestAdminHandler_GetOrySSOTSystemStatusReturnsProjectionAndIdentityCache(t *testing.T) {
|
||||
syncedAt := time.Date(2026, 5, 11, 3, 0, 0, 0, time.UTC)
|
||||
cache := &fakeIdentityCacheAdmin{
|
||||
status: domain.IdentityCacheStatus{
|
||||
Status: "ready",
|
||||
RedisReady: true,
|
||||
ObservedCount: 151,
|
||||
KeyCount: 153,
|
||||
LastRefreshedAt: &syncedAt,
|
||||
UpdatedAt: &syncedAt,
|
||||
},
|
||||
}
|
||||
h := &AdminHandler{
|
||||
UserProjectionRepo: &fakeAdminUserProjectionRepo{
|
||||
status: domain.UserProjectionStatus{
|
||||
Name: domain.UserProjectionNameKratos,
|
||||
Status: domain.UserProjectionStatusReady,
|
||||
Ready: true,
|
||||
LastSyncedAt: &syncedAt,
|
||||
ProjectedUsers: 152,
|
||||
},
|
||||
},
|
||||
IdentityCache: cache,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/api/v1/admin/ory/ssot", h.GetOrySSOTSystemStatus)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/ory/ssot", nil)
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body struct {
|
||||
UserProjection domain.UserProjectionStatus `json:"userProjection"`
|
||||
IdentityCache domain.IdentityCacheStatus `json:"identityCache"`
|
||||
}
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
require.Equal(t, int64(152), body.UserProjection.ProjectedUsers)
|
||||
require.True(t, body.IdentityCache.RedisReady)
|
||||
require.Equal(t, int64(151), body.IdentityCache.ObservedCount)
|
||||
require.Equal(t, int64(153), body.IdentityCache.KeyCount)
|
||||
require.Equal(t, 1, cache.statusHit)
|
||||
}
|
||||
|
||||
func TestAdminHandler_FlushIdentityCacheRequiresSuperAdminAndFlushesCacheOnly(t *testing.T) {
|
||||
cache := &fakeIdentityCacheAdmin{
|
||||
flush: domain.IdentityCacheFlushResult{
|
||||
Status: "success",
|
||||
FlushedKeys: 7,
|
||||
UpdatedAt: time.Date(2026, 5, 11, 3, 2, 0, 0, time.UTC),
|
||||
},
|
||||
}
|
||||
h := &AdminHandler{
|
||||
IdentityCache: cache,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Post("/api/v1/admin/ory/ssot/identity-cache/flush", h.FlushIdentityCache)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/ory/ssot/identity-cache/flush", nil)
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body domain.IdentityCacheFlushResult
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
require.Equal(t, int64(7), body.FlushedKeys)
|
||||
require.Equal(t, 1, cache.flushCalls)
|
||||
}
|
||||
|
||||
func TestAdminHandler_GetRPUsageDailyChecksTenantPermission(t *testing.T) {
|
||||
repo := &fakeRPUsageQueryRepo{}
|
||||
keto := &fakeAdminKeto{allowed: true}
|
||||
h := &AdminHandler{RPUsageQueries: repo, Keto: keto}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{
|
||||
ID: "user-1",
|
||||
Role: "tenant_admin",
|
||||
})
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/api/v1/admin/rp-usage/daily", h.GetRPUsageDaily)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/rp-usage/daily?tenantId=tenant-allowed", nil)
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, "User:user-1", keto.subject)
|
||||
require.Equal(t, "tenant-allowed", keto.object)
|
||||
require.Equal(t, "view_rp_usage_stats", keto.relation)
|
||||
require.Equal(t, "tenant-allowed", repo.query.TenantID)
|
||||
}
|
||||
|
||||
func TestAdminHandler_GetSystemStatsIncludesOverviewMetrics(t *testing.T) {
|
||||
auditRepo := &fakeOverviewAuditRepo{count: 22}
|
||||
h := &AdminHandler{
|
||||
AuditRepo: auditRepo,
|
||||
UserProjectionRepo: &fakeAdminUserProjectionRepo{
|
||||
status: domain.UserProjectionStatus{
|
||||
Name: domain.UserProjectionNameKratos,
|
||||
Status: domain.UserProjectionStatusReady,
|
||||
Ready: true,
|
||||
ProjectedUsers: 152,
|
||||
},
|
||||
},
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/admin/stats", h.GetSystemStats)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/stats", nil)
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body map[string]any
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
require.Contains(t, body, "totalTenants")
|
||||
require.Contains(t, body, "totalUsers")
|
||||
require.Contains(t, body, "oidcClients")
|
||||
require.Contains(t, body, "auditEvents24h")
|
||||
require.Equal(t, float64(152), body["totalUsers"])
|
||||
require.Equal(t, float64(22), body["auditEvents24h"])
|
||||
require.Equal(t, time.UTC, auditRepo.since.Location())
|
||||
}
|
||||
196
baron-sso/backend/internal/handler/admin_integrity_test.go
Normal file
196
baron-sso/backend/internal/handler/admin_integrity_test.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeDataIntegrityChecker struct {
|
||||
calls int
|
||||
listCalls int
|
||||
deleteCalls int
|
||||
deletedIDs []string
|
||||
report domain.DataIntegrityReport
|
||||
orphans []domain.OrphanUserLoginID
|
||||
deleteResult domain.DeleteOrphanUserLoginIDsResult
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *fakeDataIntegrityChecker) CheckDataIntegrity(ctx context.Context) (domain.DataIntegrityReport, error) {
|
||||
f.calls++
|
||||
return f.report, f.err
|
||||
}
|
||||
|
||||
func (f *fakeDataIntegrityChecker) ListOrphanUserLoginIDs(ctx context.Context) ([]domain.OrphanUserLoginID, error) {
|
||||
f.listCalls++
|
||||
return f.orphans, f.err
|
||||
}
|
||||
|
||||
func (f *fakeDataIntegrityChecker) DeleteOrphanUserLoginIDs(ctx context.Context, ids []string) (domain.DeleteOrphanUserLoginIDsResult, error) {
|
||||
f.deleteCalls++
|
||||
f.deletedIDs = append([]string(nil), ids...)
|
||||
return f.deleteResult, f.err
|
||||
}
|
||||
|
||||
func TestAdminHandler_GetDataIntegrityRequiresSuperAdmin(t *testing.T) {
|
||||
checker := &fakeDataIntegrityChecker{}
|
||||
h := &AdminHandler{IntegrityChecker: checker}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "tenant-admin", Role: "tenant_admin"})
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/api/v1/admin/integrity", h.GetDataIntegrity)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/integrity", nil)
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
require.Equal(t, 0, checker.calls)
|
||||
}
|
||||
|
||||
func TestAdminHandler_GetDataIntegrityReturnsReportForSuperAdmin(t *testing.T) {
|
||||
checkedAt := time.Date(2026, 5, 14, 0, 0, 0, 0, time.UTC)
|
||||
checker := &fakeDataIntegrityChecker{
|
||||
report: domain.DataIntegrityReport{
|
||||
Status: domain.DataIntegrityStatusFail,
|
||||
CheckedAt: checkedAt,
|
||||
Summary: domain.DataIntegritySummary{
|
||||
TotalChecks: 1,
|
||||
Failures: 1,
|
||||
},
|
||||
Sections: []domain.DataIntegritySection{
|
||||
{
|
||||
Key: "tenant_integrity",
|
||||
Label: "테넌트 정합성",
|
||||
Status: domain.DataIntegrityStatusFail,
|
||||
Checks: []domain.DataIntegrityCheck{
|
||||
{
|
||||
Key: "duplicate_tenant_slugs",
|
||||
Label: "중복 테넌트 slug",
|
||||
Status: domain.DataIntegrityStatusFail,
|
||||
Count: 1,
|
||||
Severity: "error",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
h := &AdminHandler{IntegrityChecker: checker}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/api/v1/admin/integrity", h.GetDataIntegrity)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/integrity", nil)
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, 1, checker.calls)
|
||||
|
||||
var body domain.DataIntegrityReport
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
require.Equal(t, domain.DataIntegrityStatusFail, body.Status)
|
||||
require.Equal(t, int64(1), body.Summary.Failures)
|
||||
require.Len(t, body.Sections, 1)
|
||||
require.Equal(t, "tenant_integrity", body.Sections[0].Key)
|
||||
}
|
||||
|
||||
func TestAdminHandler_ListOrphanUserLoginIDsReturnsTargetsForSuperAdmin(t *testing.T) {
|
||||
checker := &fakeDataIntegrityChecker{
|
||||
orphans: []domain.OrphanUserLoginID{
|
||||
{
|
||||
ID: "login-id-1",
|
||||
UserID: "user-1",
|
||||
TenantID: "tenant-1",
|
||||
FieldKey: "emp_id",
|
||||
LoginID: "EMP001",
|
||||
Reasons: []string{"missing_tenant"},
|
||||
},
|
||||
},
|
||||
}
|
||||
h := &AdminHandler{IntegrityChecker: checker}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/api/v1/admin/integrity/orphan-user-login-ids", h.ListOrphanUserLoginIDs)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/integrity/orphan-user-login-ids", nil)
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, 1, checker.listCalls)
|
||||
|
||||
var body struct {
|
||||
Items []domain.OrphanUserLoginID `json:"items"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
require.Equal(t, 1, body.Total)
|
||||
require.Equal(t, "login-id-1", body.Items[0].ID)
|
||||
require.Equal(t, []string{"missing_tenant"}, body.Items[0].Reasons)
|
||||
}
|
||||
|
||||
func TestAdminHandler_DeleteOrphanUserLoginIDsRequiresSuperAdminAndDeletesSelectedTargets(t *testing.T) {
|
||||
checker := &fakeDataIntegrityChecker{
|
||||
deleteResult: domain.DeleteOrphanUserLoginIDsResult{
|
||||
DeletedCount: 1,
|
||||
Deleted: []domain.OrphanUserLoginID{
|
||||
{ID: "login-id-1", LoginID: "EMP001", Reasons: []string{"missing_user"}},
|
||||
},
|
||||
SkippedIDs: []string{"valid-login-id"},
|
||||
},
|
||||
}
|
||||
h := &AdminHandler{IntegrityChecker: checker}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Delete("/api/v1/admin/integrity/orphan-user-login-ids", h.DeleteOrphanUserLoginIDs)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/v1/admin/integrity/orphan-user-login-ids", strings.NewReader(`{"ids":["login-id-1","valid-login-id"]}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, 1, checker.deleteCalls)
|
||||
require.Equal(t, []string{"login-id-1", "valid-login-id"}, checker.deletedIDs)
|
||||
|
||||
var body domain.DeleteOrphanUserLoginIDsResult
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
require.Equal(t, int64(1), body.DeletedCount)
|
||||
require.Equal(t, []string{"valid-login-id"}, body.SkippedIDs)
|
||||
}
|
||||
|
||||
func TestAdminHandler_DeleteOrphanUserLoginIDsRejectsTenantAdmin(t *testing.T) {
|
||||
checker := &fakeDataIntegrityChecker{}
|
||||
h := &AdminHandler{IntegrityChecker: checker}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "tenant-admin", Role: "tenant_admin"})
|
||||
return c.Next()
|
||||
})
|
||||
app.Delete("/api/v1/admin/integrity/orphan-user-login-ids", h.DeleteOrphanUserLoginIDs)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/v1/admin/integrity/orphan-user-login-ids", strings.NewReader(`{"ids":["login-id-1"]}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
require.Equal(t, 0, checker.deleteCalls)
|
||||
}
|
||||
288
baron-sso/backend/internal/handler/api_key_handler.go
Normal file
288
baron-sso/backend/internal/handler/api_key_handler.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/pagination"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ApiKeyHandler struct {
|
||||
DB *gorm.DB
|
||||
}
|
||||
|
||||
func NewApiKeyHandler(db *gorm.DB) *ApiKeyHandler {
|
||||
return &ApiKeyHandler{DB: db}
|
||||
}
|
||||
|
||||
type apiKeySummary struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
ClientID string `json:"client_id"`
|
||||
Scopes []string `json:"scopes"`
|
||||
Status string `json:"status"`
|
||||
LastUsedAt *string `json:"lastUsedAt"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
type apiKeyListResponse struct {
|
||||
Items []apiKeySummary `json:"items"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
Cursor string `json:"cursor,omitempty"`
|
||||
NextCursor string `json:"nextCursor,omitempty"`
|
||||
}
|
||||
|
||||
func apiKeyToSummary(k domain.ApiKey) apiKeySummary {
|
||||
lastUsed := ""
|
||||
if k.LastUsedAt != nil {
|
||||
lastUsed = k.LastUsedAt.Format(time.RFC3339)
|
||||
}
|
||||
return apiKeySummary{
|
||||
ID: k.ID,
|
||||
Name: k.Name,
|
||||
ClientID: k.ClientID,
|
||||
Scopes: strings.Fields(strings.ReplaceAll(k.Scopes, ",", " ")),
|
||||
Status: k.Status,
|
||||
LastUsedAt: &lastUsed,
|
||||
CreatedAt: k.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func apiKeyWithUpdatedScopes(k domain.ApiKey, scopes []string) domain.ApiKey {
|
||||
k.Scopes = strings.Join(normalizeApiKeyScopes(scopes), " ")
|
||||
return k
|
||||
}
|
||||
|
||||
func apiKeyWithRotatedSecretHash(k domain.ApiKey, hashedSecret string) domain.ApiKey {
|
||||
k.ClientSecretHash = hashedSecret
|
||||
return k
|
||||
}
|
||||
|
||||
func normalizeApiKeyScopes(scopes []string) []string {
|
||||
seen := make(map[string]struct{}, len(scopes))
|
||||
normalized := make([]string, 0, len(scopes))
|
||||
for _, scope := range scopes {
|
||||
scope = strings.TrimSpace(scope)
|
||||
if scope == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[scope]; exists {
|
||||
continue
|
||||
}
|
||||
seen[scope] = struct{}{}
|
||||
normalized = append(normalized, scope)
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func (h *ApiKeyHandler) ListApiKeys(c *fiber.Ctx) error {
|
||||
if h.DB == nil {
|
||||
return errorJSON(c, fiber.StatusServiceUnavailable, "database not available")
|
||||
}
|
||||
|
||||
limit := c.QueryInt("limit", 50)
|
||||
offset := c.QueryInt("offset", 0)
|
||||
cursorRaw := strings.TrimSpace(c.Query("cursor"))
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
|
||||
var total int64
|
||||
if err := h.DB.Model(&domain.ApiKey{}).Count(&total).Error; err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
var keys []domain.ApiKey
|
||||
query := h.DB.Order("created_at desc, id desc").Limit(limit + 1)
|
||||
if cursorRaw != "" {
|
||||
cursor, err := pagination.Decode(cursorRaw)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "invalid cursor")
|
||||
}
|
||||
query = pagination.ApplyCreatedAtIDCursor(query, cursor, "created_at", "id")
|
||||
offset = 0
|
||||
} else {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
if err := query.Find(&keys).Error; err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
nextCursor := ""
|
||||
hasMore := len(keys) > limit
|
||||
if len(keys) > limit {
|
||||
keys = keys[:limit]
|
||||
}
|
||||
if cursorRaw == "" && total > int64(offset+len(keys)) {
|
||||
hasMore = true
|
||||
}
|
||||
if hasMore && len(keys) > 0 {
|
||||
last := keys[len(keys)-1]
|
||||
nextCursor = pagination.Encode(last.CreatedAt, last.ID)
|
||||
}
|
||||
|
||||
items := make([]apiKeySummary, 0, len(keys))
|
||||
for _, k := range keys {
|
||||
items = append(items, apiKeyToSummary(k))
|
||||
}
|
||||
|
||||
return c.JSON(apiKeyListResponse{
|
||||
Items: items,
|
||||
Total: total,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
Cursor: cursorRaw,
|
||||
NextCursor: nextCursor,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *ApiKeyHandler) CreateApiKey(c *fiber.Ctx) error {
|
||||
if h.DB == nil {
|
||||
return errorJSON(c, fiber.StatusServiceUnavailable, "database not available")
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
Scopes []string `json:"scopes"`
|
||||
}
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
|
||||
}
|
||||
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "name is required")
|
||||
}
|
||||
req.Scopes = normalizeApiKeyScopes(req.Scopes)
|
||||
if len(req.Scopes) == 0 {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "at least one scope is required")
|
||||
}
|
||||
|
||||
// Generate Client ID (16 chars hex)
|
||||
clientID := GenerateSecureToken(8)
|
||||
|
||||
// Generate plain secret (16 chars hex)
|
||||
plainSecret := GenerateSecureToken(8)
|
||||
|
||||
hashedSecret, err := bcrypt.GenerateFromPassword([]byte(plainSecret), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "failed to hash secret")
|
||||
}
|
||||
|
||||
apiKey := domain.ApiKey{
|
||||
Name: req.Name,
|
||||
ClientID: clientID,
|
||||
ClientSecretHash: string(hashedSecret),
|
||||
Scopes: strings.Join(req.Scopes, " "),
|
||||
Status: "active",
|
||||
}
|
||||
|
||||
if err := h.DB.Create(&apiKey).Error; err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
// Return summary + PLAIN SECRET (only this time)
|
||||
return c.Status(fiber.StatusCreated).JSON(fiber.Map{
|
||||
"apiKey": apiKeyToSummary(apiKey),
|
||||
"clientSecret": plainSecret, // VERY IMPORTANT: user must save this now
|
||||
})
|
||||
}
|
||||
|
||||
func (h *ApiKeyHandler) UpdateApiKey(c *fiber.Ctx) error {
|
||||
if h.DB == nil {
|
||||
return errorJSON(c, fiber.StatusServiceUnavailable, "database not available")
|
||||
}
|
||||
|
||||
id := c.Params("id")
|
||||
if id == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "id is required")
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Scopes []string `json:"scopes"`
|
||||
}
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
|
||||
}
|
||||
req.Scopes = normalizeApiKeyScopes(req.Scopes)
|
||||
if len(req.Scopes) == 0 {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "at least one scope is required")
|
||||
}
|
||||
|
||||
var apiKey domain.ApiKey
|
||||
if err := h.DB.First(&apiKey, "id = ?", id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errorJSON(c, fiber.StatusNotFound, "api key not found")
|
||||
}
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
apiKey = apiKeyWithUpdatedScopes(apiKey, req.Scopes)
|
||||
if err := h.DB.Model(&domain.ApiKey{}).Where("id = ?", id).Update("scopes", apiKey.Scopes).Error; err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(apiKeyToSummary(apiKey))
|
||||
}
|
||||
|
||||
func (h *ApiKeyHandler) RotateApiKeySecret(c *fiber.Ctx) error {
|
||||
if h.DB == nil {
|
||||
return errorJSON(c, fiber.StatusServiceUnavailable, "database not available")
|
||||
}
|
||||
|
||||
id := c.Params("id")
|
||||
if id == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "id is required")
|
||||
}
|
||||
|
||||
var apiKey domain.ApiKey
|
||||
if err := h.DB.First(&apiKey, "id = ?", id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errorJSON(c, fiber.StatusNotFound, "api key not found")
|
||||
}
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
plainSecret := GenerateSecureToken(8)
|
||||
hashedSecret, err := bcrypt.GenerateFromPassword([]byte(plainSecret), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "failed to hash secret")
|
||||
}
|
||||
|
||||
apiKey = apiKeyWithRotatedSecretHash(apiKey, string(hashedSecret))
|
||||
if err := h.DB.Model(&domain.ApiKey{}).Where("id = ?", id).Update("client_secret_hash", apiKey.ClientSecretHash).Error; err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
"apiKey": apiKeyToSummary(apiKey),
|
||||
"clientSecret": plainSecret,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *ApiKeyHandler) DeleteApiKey(c *fiber.Ctx) error {
|
||||
if h.DB == nil {
|
||||
return errorJSON(c, fiber.StatusServiceUnavailable, "database not available")
|
||||
}
|
||||
|
||||
id := c.Params("id")
|
||||
if id == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "id is required")
|
||||
}
|
||||
|
||||
if err := h.DB.Delete(&domain.ApiKey{}, "id = ?", id).Error; err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.SendStatus(fiber.StatusNoContent)
|
||||
}
|
||||
133
baron-sso/backend/internal/handler/api_key_handler_test.go
Normal file
133
baron-sso/backend/internal/handler/api_key_handler_test.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Mock DB for ApiKey tests using a real GORM instance but with a hijacked connection
|
||||
// or just a simple mock if we only check nil.
|
||||
// For ApiKeyHandler, it uses DB for Create/List/Delete.
|
||||
|
||||
func TestApiKeyHandler_CreateApiKey(t *testing.T) {
|
||||
app := fiber.New()
|
||||
// ApiKeyHandler requires a valid DB connection to perform h.DB.Create
|
||||
// Since we don't have a real DB here, we'll check if it fails gracefully
|
||||
// or we can use sqlite in-memory for a more realistic test.
|
||||
h := &ApiKeyHandler{DB: nil} // Testing ServiceUnavailable
|
||||
|
||||
app.Post("/api-keys", h.CreateApiKey)
|
||||
|
||||
input := map[string]any{
|
||||
"name": "M2M Test",
|
||||
"scopes": []string{"read", "write"},
|
||||
}
|
||||
body, _ := json.Marshal(input)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api-keys", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestApiKeyHandler_Validation(t *testing.T) {
|
||||
app := fiber.New()
|
||||
// Using a dummy DB pointer to pass the nil check
|
||||
h := &ApiKeyHandler{DB: &gorm.DB{}}
|
||||
|
||||
app.Post("/api-keys", h.CreateApiKey)
|
||||
|
||||
// Missing name
|
||||
input := map[string]any{
|
||||
"scopes": []string{"read"},
|
||||
}
|
||||
body, _ := json.Marshal(input)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api-keys", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestApiKeyHandler_UpdateApiKeyScopesRequiresDatabase(t *testing.T) {
|
||||
app := fiber.New()
|
||||
h := &ApiKeyHandler{DB: nil}
|
||||
|
||||
app.Patch("/api-keys/:id", h.UpdateApiKey)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"scopes": []string{"org-context:read"},
|
||||
})
|
||||
req := httptest.NewRequest("PATCH", "/api-keys/api-key-id", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestApiKeyHandler_RotateApiKeySecretRequiresDatabase(t *testing.T) {
|
||||
app := fiber.New()
|
||||
h := &ApiKeyHandler{DB: nil}
|
||||
|
||||
app.Post("/api-keys/:id/secret/rotate", h.RotateApiKeySecret)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api-keys/api-key-id/secret/rotate", nil)
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestApiKeyWithUpdatedScopesPreservesClientID(t *testing.T) {
|
||||
key := domain.ApiKey{
|
||||
ID: "api-key-id",
|
||||
Name: "M2M Test",
|
||||
ClientID: "client-id-stable",
|
||||
ClientSecretHash: "old-secret-hash",
|
||||
Scopes: "audit:read",
|
||||
Status: "active",
|
||||
}
|
||||
|
||||
updated := apiKeyWithUpdatedScopes(key, []string{"audit:read", "org-context:read"})
|
||||
|
||||
assert.Equal(t, "client-id-stable", updated.ClientID)
|
||||
assert.Equal(t, "old-secret-hash", updated.ClientSecretHash)
|
||||
assert.Equal(t, "audit:read org-context:read", updated.Scopes)
|
||||
}
|
||||
|
||||
func TestApiKeyWithRotatedSecretHashPreservesClientIDAndScopes(t *testing.T) {
|
||||
key := domain.ApiKey{
|
||||
ID: "api-key-id",
|
||||
Name: "M2M Test",
|
||||
ClientID: "client-id-stable",
|
||||
ClientSecretHash: "old-secret-hash",
|
||||
Scopes: "audit:read org-context:read",
|
||||
Status: "active",
|
||||
}
|
||||
|
||||
updated := apiKeyWithRotatedSecretHash(key, "new-secret-hash")
|
||||
|
||||
assert.Equal(t, "client-id-stable", updated.ClientID)
|
||||
assert.Equal(t, "audit:read org-context:read", updated.Scopes)
|
||||
assert.Equal(t, "new-secret-hash", updated.ClientSecretHash)
|
||||
}
|
||||
|
||||
func TestNormalizeApiKeyScopesTrimsAndDeduplicates(t *testing.T) {
|
||||
scopes := normalizeApiKeyScopes([]string{
|
||||
" audit:read ",
|
||||
"",
|
||||
"org-context:read",
|
||||
"audit:read",
|
||||
})
|
||||
|
||||
assert.Equal(t, []string{"audit:read", "org-context:read"}, scopes)
|
||||
}
|
||||
139
baron-sso/backend/internal/handler/audit_handler.go
Normal file
139
baron-sso/backend/internal/handler/audit_handler.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditHandler struct {
|
||||
repo domain.AuditRepository
|
||||
}
|
||||
|
||||
func NewAuditHandler(repo domain.AuditRepository) *AuditHandler {
|
||||
return &AuditHandler{repo: repo}
|
||||
}
|
||||
|
||||
// CreateLog handles POST /api/v1/audit
|
||||
func (h *AuditHandler) CreateLog(c *fiber.Ctx) error {
|
||||
var req domain.AuditLog
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "Cannot parse JSON")
|
||||
}
|
||||
|
||||
// Auto-fill metadata if missing
|
||||
if req.IPAddress == "" {
|
||||
req.IPAddress = c.IP()
|
||||
}
|
||||
if req.UserAgent == "" {
|
||||
req.UserAgent = c.Get("User-Agent")
|
||||
}
|
||||
if req.Timestamp.IsZero() {
|
||||
req.Timestamp = time.Now()
|
||||
}
|
||||
if req.EventID == "" {
|
||||
req.EventID = ensureRequestID(c)
|
||||
}
|
||||
|
||||
if h.repo == nil {
|
||||
return errorJSON(c, fiber.StatusServiceUnavailable, "Audit service unavailable")
|
||||
}
|
||||
|
||||
if err := h.repo.Create(&req); err != nil {
|
||||
// Log internal error but don't expose details
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "Failed to save audit log")
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusCreated).JSON(fiber.Map{
|
||||
"message": "Audit log saved",
|
||||
})
|
||||
}
|
||||
|
||||
// ListLogs handles GET /api/v1/audit
|
||||
func (h *AuditHandler) ListLogs(c *fiber.Ctx) error {
|
||||
limit := c.QueryInt("limit", 50)
|
||||
cursorRaw := c.Query("cursor")
|
||||
requestedTenantID := c.Query("tenantId")
|
||||
|
||||
cursor, err := parseAuditCursor(cursorRaw)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "Invalid cursor")
|
||||
}
|
||||
|
||||
if h.repo == nil {
|
||||
return errorJSON(c, fiber.StatusServiceUnavailable, "Audit service unavailable")
|
||||
}
|
||||
|
||||
// [New] Role-based Filtering
|
||||
profile, _ := c.Locals("user_profile").(*domain.UserProfileResponse)
|
||||
var filterTenantID string
|
||||
|
||||
if profile != nil {
|
||||
if profile.Role == domain.RoleSuperAdmin {
|
||||
// Super Admin can see everything or filter by a specific tenant if requested
|
||||
filterTenantID = requestedTenantID
|
||||
} else {
|
||||
return errorJSON(c, fiber.StatusForbidden, "forbidden")
|
||||
}
|
||||
}
|
||||
|
||||
logs, err := h.repo.FindPage(c.Context(), limit+1, cursor, filterTenantID)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "Failed to retrieve audit logs")
|
||||
}
|
||||
|
||||
nextCursor := ""
|
||||
if len(logs) > limit {
|
||||
last := logs[limit-1]
|
||||
nextCursor = encodeAuditCursor(last)
|
||||
logs = logs[:limit]
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
"items": logs,
|
||||
"limit": limit,
|
||||
"cursor": cursorRaw,
|
||||
"next_cursor": nextCursor,
|
||||
})
|
||||
}
|
||||
|
||||
func ensureRequestID(c *fiber.Ctx) string {
|
||||
reqID := c.Get("X-Request-Id")
|
||||
if reqID == "" {
|
||||
reqID = uuid.New().String()
|
||||
c.Set("X-Request-Id", reqID)
|
||||
}
|
||||
return reqID
|
||||
}
|
||||
|
||||
func parseAuditCursor(raw string) (*domain.AuditCursor, error) {
|
||||
if raw == "" {
|
||||
return nil, nil
|
||||
}
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parts := strings.SplitN(string(decoded), "|", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, errors.New("invalid cursor")
|
||||
}
|
||||
ts, err := time.Parse(time.RFC3339Nano, parts[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &domain.AuditCursor{
|
||||
Timestamp: ts,
|
||||
EventID: parts[1],
|
||||
}, nil
|
||||
}
|
||||
|
||||
func encodeAuditCursor(log domain.AuditLog) string {
|
||||
payload := log.Timestamp.UTC().Format(time.RFC3339Nano) + "|" + log.EventID
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(payload))
|
||||
}
|
||||
9066
baron-sso/backend/internal/handler/auth_handler.go
Normal file
9066
baron-sso/backend/internal/handler/auth_handler.go
Normal file
File diff suppressed because it is too large
Load Diff
371
baron-sso/backend/internal/handler/auth_handler_async_test.go
Normal file
371
baron-sso/backend/internal/handler/auth_handler_async_test.go
Normal file
@@ -0,0 +1,371 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// --- Async Test Mocks ---
|
||||
|
||||
type AsyncMockIdpProvider struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *AsyncMockIdpProvider) Name() string { return "mock-idp" }
|
||||
func (m *AsyncMockIdpProvider) GetMetadata() (*domain.IDPMetadata, error) {
|
||||
return &domain.IDPMetadata{}, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockIdpProvider) UserExists(loginID string) (bool, error) {
|
||||
args := m.Called(loginID)
|
||||
return args.Bool(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *AsyncMockIdpProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
|
||||
args := m.Called(user, password)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *AsyncMockIdpProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockIdpProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockIdpProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockIdpProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockIdpProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
|
||||
return &domain.PasswordPolicy{MinLength: 12}, nil
|
||||
}
|
||||
func (m *AsyncMockIdpProvider) InitiatePasswordReset(loginID, redirectUrl string) error { return nil }
|
||||
func (m *AsyncMockIdpProvider) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockIdpProvider) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type AsyncMockUserRepo struct {
|
||||
mock.Mock
|
||||
createCalled chan bool
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) Create(ctx context.Context, user *domain.User) error {
|
||||
// Simulate DB latency
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
args := m.Called(ctx, user)
|
||||
if m.createCalled != nil {
|
||||
m.createCalled <- true
|
||||
}
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) Update(ctx context.Context, user *domain.User) error {
|
||||
args := m.Called(ctx, user)
|
||||
if m.createCalled != nil {
|
||||
m.createCalled <- true
|
||||
}
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) Delete(ctx context.Context, id string) error { return nil }
|
||||
func (m *AsyncMockUserRepo) FindByEmail(ctx context.Context, email string) (*domain.User, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) FindByID(ctx context.Context, id string) (*domain.User, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) List(ctx context.Context, offset, limit int, search string, tenantIDs []string, cursor string) ([]domain.User, int64, string, error) {
|
||||
return nil, 0, "", nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error) {
|
||||
args := m.Called(ctx, tenantIDs)
|
||||
return args.Get(0).([]domain.User), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error) {
|
||||
args := m.Called(ctx, codes)
|
||||
return args.Get(0).([]domain.User), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) DB() *gorm.DB {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) IsLoginIDTaken(ctx context.Context, loginID string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
type AsyncMockRedisRepo struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *AsyncMockRedisRepo) Set(key string, value string, expiration time.Duration) error {
|
||||
args := m.Called(key, value, expiration)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *AsyncMockRedisRepo) Get(key string) (string, error) {
|
||||
args := m.Called(key)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *AsyncMockRedisRepo) Delete(key string) error {
|
||||
args := m.Called(key)
|
||||
return args.Error(0)
|
||||
}
|
||||
func (m *AsyncMockRedisRepo) StoreVerificationCode(phone, code string) error { return nil }
|
||||
func (m *AsyncMockRedisRepo) GetVerificationCode(phone string) (string, error) { return "", nil }
|
||||
func (m *AsyncMockRedisRepo) DeleteVerificationCode(phone string) error { return nil }
|
||||
|
||||
type AsyncMockTenantService struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *AsyncMockTenantService) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string, creatorID string) (*domain.Tenant, error) {
|
||||
args := m.Called(ctx, name, slug, tenantType, description, domains, parentID, creatorID)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.Tenant), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *AsyncMockTenantService) RequestRegistration(ctx context.Context, name, slug, description string, domainName string, adminEmail string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockTenantService) GetTenantByDomain(ctx context.Context, emailDomain string) (*domain.Tenant, error) {
|
||||
args := m.Called(ctx, emailDomain)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.Tenant), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *AsyncMockTenantService) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockTenantService) GetTenant(ctx context.Context, id string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockTenantService) ListTenants(ctx context.Context, limit, offset int, parentID string, search string) ([]domain.Tenant, int64, error) {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockTenantService) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockTenantService) IsDomainAllowed(ctx context.Context, domainName string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (m *AsyncMockTenantService) ApproveTenant(ctx context.Context, id string) error { return nil }
|
||||
func (m *AsyncMockTenantService) ProvisionTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *AsyncMockTenantService) SetKetoService(keto service.KetoService) {}
|
||||
func (m *AsyncMockTenantService) AddTenantAdmin(ctx context.Context, tenantID, userID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockTenantService) RemoveTenantAdmin(ctx context.Context, tenantID, userID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockTenantService) ListTenantAdmins(ctx context.Context, tenantID string) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockTenantService) DeleteTenantsBulk(ctx context.Context, ids []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockTenantService) ListJoinedTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
|
||||
args := m.Called(ctx, userID)
|
||||
if args.Get(0) != nil {
|
||||
return args.Get(0).([]domain.Tenant), args.Error(1)
|
||||
}
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
type AsyncMockKetoService struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *AsyncMockKetoService) CreateRelation(ctx context.Context, namespace, object, relation, subject string) error {
|
||||
args := m.Called(ctx, namespace, object, relation, subject)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *AsyncMockKetoService) DeleteRelation(ctx context.Context, namespace, object, relation, subject string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockKetoService) CheckPermission(ctx context.Context, namespace, object, relation, subject string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockKetoService) ListObjects(ctx context.Context, namespace, relation, subject string) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockKetoService) ListRelations(ctx context.Context, namespace, object, relation, subject string) ([]service.RelationTuple, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// --- Tests ---
|
||||
|
||||
func TestSignup_AsyncDB_Isolation(t *testing.T) {
|
||||
mockIdp := new(AsyncMockIdpProvider)
|
||||
mockUserRepo := new(AsyncMockUserRepo)
|
||||
mockRedis := new(AsyncMockRedisRepo)
|
||||
mockTenant := new(AsyncMockTenantService)
|
||||
mockKeto := new(AsyncMockKetoService)
|
||||
|
||||
h := &AuthHandler{
|
||||
IdpProvider: mockIdp,
|
||||
UserRepo: mockUserRepo,
|
||||
RedisService: mockRedis,
|
||||
TenantService: mockTenant,
|
||||
KetoService: mockKeto,
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Post("/signup", h.Signup)
|
||||
|
||||
t.Run("SoT_DB_Failure_Ignored_And_Async", func(t *testing.T) {
|
||||
email := "test@example.com"
|
||||
phone := "010-1234-5678"
|
||||
emailKey := "signup:email:" + email
|
||||
phoneKey := "signup:phone:" + "01012345678"
|
||||
|
||||
// Redis Mocks
|
||||
mockRedis.On("Get", emailKey).Return(`{"verified": true, "expires_at": 9999999999}`, nil)
|
||||
mockRedis.On("Get", phoneKey).Return(`{"verified": true, "expires_at": 9999999999}`, nil)
|
||||
mockRedis.On("Delete", emailKey).Return(nil)
|
||||
mockRedis.On("Delete", phoneKey).Return(nil)
|
||||
|
||||
// Tenant Mocks
|
||||
personalTenant := &domain.Tenant{ID: "personal-t1", Slug: "personal-test", Type: domain.TenantTypePersonal, Status: domain.TenantStatusActive}
|
||||
mockTenant.On("GetTenantByDomain", mock.Anything, "example.com").Return(nil, nil)
|
||||
mockTenant.On(
|
||||
"RegisterTenant",
|
||||
mock.Anything,
|
||||
"Personal - test@example.com",
|
||||
mock.MatchedBy(func(slug string) bool { return strings.HasPrefix(slug, "personal-") }),
|
||||
domain.TenantTypePersonal,
|
||||
"Automatically provisioned personal tenant",
|
||||
[]string(nil),
|
||||
(*string)(nil),
|
||||
"",
|
||||
).Return(personalTenant, nil)
|
||||
mockTenant.On("GetTenant", mock.Anything, "personal-t1").Return(personalTenant, nil)
|
||||
|
||||
// Kratos Mocks (Success)
|
||||
mockIdp.On("CreateUser", mock.Anything, "Password123!").Return("new-user-uuid", nil)
|
||||
|
||||
// UserRepo Mocks (Async & Failure)
|
||||
mockUserRepo.createCalled = make(chan bool, 1)
|
||||
mockUserRepo.On("Update", mock.Anything, mock.MatchedBy(func(u *domain.User) bool {
|
||||
return u.Email == email
|
||||
})).Return(errors.New("db connection error"))
|
||||
|
||||
// Keto Mocks (Optional, since it's also async)
|
||||
// We won't block on this either
|
||||
|
||||
body, _ := json.Marshal(domain.SignupRequest{
|
||||
Email: email,
|
||||
Password: "Password123!",
|
||||
Name: "Test User",
|
||||
Phone: phone,
|
||||
TermsAccepted: true,
|
||||
})
|
||||
req := httptest.NewRequest("POST", "/signup", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
start := time.Now()
|
||||
resp, err := app.Test(req, 5000)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
// Ensure API responded faster than DB latency (50ms)
|
||||
assert.Less(t, int64(elapsed), int64(60*time.Millisecond), "API should return before DB timeout")
|
||||
|
||||
// Wait for async execution
|
||||
select {
|
||||
case <-mockUserRepo.createCalled:
|
||||
// Pass
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("UserRepo.Create was not called asynchronously")
|
||||
}
|
||||
|
||||
mockRedis.AssertExpectations(t)
|
||||
mockIdp.AssertExpectations(t)
|
||||
mockUserRepo.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
200
baron-sso/backend/internal/handler/auth_handler_client_test.go
Normal file
200
baron-sso/backend/internal/handler/auth_handler_client_test.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"baron-sso-backend/internal/utils"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRevokeLinkedRp_Success(t *testing.T) {
|
||||
// Mock Hydra transport for revocation
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
// 1. Kratos whoami
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{"id": "user-123"},
|
||||
}), nil
|
||||
}
|
||||
// 2. Hydra Revoke
|
||||
if r.Method == http.MethodDelete && r.URL.Path == "/oauth2/auth/sessions/consent" {
|
||||
assert.Equal(t, "user-123", r.URL.Query().Get("subject"))
|
||||
assert.Equal(t, "app-1", r.URL.Query().Get("client"))
|
||||
return httpResponse(r, http.StatusNoContent, ""), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
auditRepo := &mockAuditRepo{}
|
||||
rpUsageSink := &mockRPUsageEventSink{}
|
||||
consentRepo := &mockConsentRepo{
|
||||
consents: []domain.ClientConsent{
|
||||
{
|
||||
ClientID: "app-1",
|
||||
Subject: "user-123",
|
||||
GrantedScopes: []string{"openid", "profile"},
|
||||
},
|
||||
},
|
||||
}
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
AuditRepo: auditRepo,
|
||||
ConsentRepo: consentRepo,
|
||||
RPUsageSink: rpUsageSink,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Delete("/api/v1/user/rp/linked/:id", h.RevokeLinkedRp)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/rp/linked/app-1", nil)
|
||||
req.Header.Set("Cookie", "valid")
|
||||
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Equal(t, 1, len(auditRepo.logs))
|
||||
assert.Equal(t, "consent.revoked", auditRepo.logs[0].EventType)
|
||||
assert.Equal(t, "user-123", auditRepo.logs[0].UserID)
|
||||
assert.Equal(t, "success", auditRepo.logs[0].Status)
|
||||
auditDetails, err := utils.ParseAuditDetails(auditRepo.logs[0].Details)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "app-1", auditDetails["client_id"])
|
||||
assert.Equal(t, 1, len(rpUsageSink.events))
|
||||
assert.Equal(t, domain.RPUsageEventTypeAuthorizationRevoked, rpUsageSink.events[0].EventType)
|
||||
assert.Equal(t, "user-123", rpUsageSink.events[0].Subject)
|
||||
assert.Equal(t, "app-1", rpUsageSink.events[0].ClientID)
|
||||
remaining, err := consentRepo.Find(req.Context(), "app-1", "user-123")
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, remaining)
|
||||
}
|
||||
|
||||
func TestRevokeLinkedRp_SendsBackchannelLogoutTokenWhenConfigured(t *testing.T) {
|
||||
t.Setenv("BACKCHANNEL_LOGOUT_ISSUER", "https://sso.example.com/oidc")
|
||||
|
||||
var receivedBody string
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{"id": "user-123"},
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Host == "hydra.test" && r.Method == http.MethodDelete && r.URL.Path == "/oauth2/auth/sessions/consent" {
|
||||
return httpResponse(r, http.StatusNoContent, ""), nil
|
||||
}
|
||||
if r.URL.Host == "hydra.test" && r.Method == http.MethodGet && r.URL.Path == "/clients/app-1" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": "app-1",
|
||||
"backchannel_logout_uri": "https://rp.example.com/backchannel-logout",
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Host == "rp.example.com" && r.Method == http.MethodPost && r.URL.Path == "/backchannel-logout" {
|
||||
raw, _ := io.ReadAll(r.Body)
|
||||
receivedBody = string(raw)
|
||||
return httpResponse(r, http.StatusNoContent, ""), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
backchannelLogout, err := service.NewBackchannelLogoutService()
|
||||
assert.NoError(t, err)
|
||||
backchannelLogout.HTTPClient = client
|
||||
|
||||
auditRepo := &mockAuditRepo{}
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
BackchannelLogout: backchannelLogout,
|
||||
AuditRepo: auditRepo,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Delete("/api/v1/user/rp/linked/:id", h.RevokeLinkedRp)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/rp/linked/app-1", nil)
|
||||
req.Header.Set("Cookie", "valid")
|
||||
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.True(t, strings.Contains(receivedBody, "logout_token="))
|
||||
|
||||
values, err := url.ParseQuery(receivedBody)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, values.Get("logout_token"))
|
||||
|
||||
assert.Len(t, auditRepo.logs, 2)
|
||||
assert.Equal(t, "backchannel_logout.sent", auditRepo.logs[1].EventType)
|
||||
}
|
||||
|
||||
func TestListRpHistory_Aggregation(t *testing.T) {
|
||||
now := time.Now()
|
||||
auditRepo := &mockAuditRepo{
|
||||
logs: []domain.AuditLog{
|
||||
{
|
||||
UserID: "user-123",
|
||||
EventType: "consent.revoked", // Newest
|
||||
Timestamp: now,
|
||||
Details: `{"client_id":"app-1"}`,
|
||||
},
|
||||
{
|
||||
UserID: "user-123",
|
||||
EventType: "consent.granted", // Oldest
|
||||
Timestamp: now.Add(-1 * time.Hour),
|
||||
Details: `{"client_id":"app-1", "client_name":"App One"}`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
h := &AuthHandler{
|
||||
AuditRepo: auditRepo,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/user/rp/history", h.ListRpHistory)
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{"id": "user-123"},
|
||||
}), nil
|
||||
})
|
||||
http.DefaultClient = &http.Client{Transport: transport}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/rp/history", nil)
|
||||
req.Header.Set("Cookie", "valid")
|
||||
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var res struct {
|
||||
Items []struct {
|
||||
ClientID string `json:"client_id"`
|
||||
Status string `json:"status"`
|
||||
} `json:"items"`
|
||||
}
|
||||
json.NewDecoder(resp.Body).Decode(&res)
|
||||
|
||||
assert.Equal(t, 1, len(res.Items))
|
||||
assert.Equal(t, "app-1", res.Items[0].ClientID)
|
||||
// Newest event (revoked) should win
|
||||
assert.Equal(t, "revoked", res.Items[0].Status)
|
||||
}
|
||||
515
baron-sso/backend/internal/handler/auth_handler_consent_test.go
Normal file
515
baron-sso/backend/internal/handler/auth_handler_consent_test.go
Normal file
@@ -0,0 +1,515 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"baron-sso-backend/internal/utils"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// --- Mocks ---
|
||||
|
||||
type MockKratosAdminServiceForConsent struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceForConsent) FindIdentityIDByIdentifier(ctx context.Context, identifier string) (string, error) {
|
||||
args := m.Called(ctx, identifier)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceForConsent) GetIdentity(ctx context.Context, id string) (*service.KratosIdentity, error) {
|
||||
args := m.Called(ctx, id)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*service.KratosIdentity), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceForConsent) ListIdentities(ctx context.Context) ([]service.KratosIdentity, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceForConsent) UpdateIdentity(ctx context.Context, identityID string, traits map[string]any, state string) (*service.KratosIdentity, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceForConsent) CreateIdentity(ctx context.Context, traits map[string]any) (*service.KratosIdentity, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceForConsent) DeleteIdentity(ctx context.Context, identityID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceForConsent) UpdateIdentityPassword(ctx context.Context, identityID, newPassword string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceForConsent) ListIdentitySessions(ctx context.Context, identityID string) ([]service.KratosSession, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceForConsent) GetSession(ctx context.Context, sessionID string) (*service.KratosSession, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceForConsent) DeleteSession(ctx context.Context, sessionID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceForConsent) CreateUser(ctx context.Context, user *domain.BrokerUser, password string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
type MockTenantServiceForConsent struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForConsent) GetTenant(ctx context.Context, id string) (*domain.Tenant, error) {
|
||||
args := m.Called(ctx, id)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.Tenant), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForConsent) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForConsent) GetTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForConsent) ListTenants(ctx context.Context, limit, offset int, parentID string, search string) ([]domain.Tenant, int64, error) {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForConsent) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string, creatorID string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForConsent) RequestRegistration(ctx context.Context, name, slug, description string, domainName string, adminEmail string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForConsent) ApproveTenant(ctx context.Context, id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForConsent) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
|
||||
args := m.Called(ctx, userID)
|
||||
return args.Get(0).([]domain.Tenant), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForConsent) ListJoinedTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
|
||||
args := m.Called(ctx, userID)
|
||||
return args.Get(0).([]domain.Tenant), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForConsent) IsDomainAllowed(ctx context.Context, domainName string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForConsent) ProvisionTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForConsent) SetKetoService(keto service.KetoService) {}
|
||||
|
||||
func (m *MockTenantServiceForConsent) DeleteTenantsBulk(ctx context.Context, ids []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Test Helpers ---
|
||||
|
||||
func newConsentTestApp(h *AuthHandler) *fiber.App {
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/auth/consent", h.GetConsentRequest)
|
||||
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
|
||||
return app
|
||||
}
|
||||
|
||||
// --- Tests ---
|
||||
|
||||
func TestGetConsentRequest_Normal(t *testing.T) {
|
||||
// Mock Hydra transport
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-123" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-123",
|
||||
"requested_scope": []string{"openid", "profile"},
|
||||
"skip": false,
|
||||
"subject": "user-123",
|
||||
"client": map[string]any{
|
||||
"client_id": "client-app",
|
||||
"client_name": "Test App",
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
}
|
||||
app := newConsentTestApp(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-123", nil)
|
||||
resp, err := app.Test(req)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body map[string]any
|
||||
json.NewDecoder(resp.Body).Decode(&body)
|
||||
|
||||
assert.Equal(t, "challenge-123", body["challenge"])
|
||||
assert.Equal(t, false, body["skip"])
|
||||
}
|
||||
|
||||
func TestGetConsentRequest_AddsMandatoryTenantScope(t *testing.T) {
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-tenant-scope" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-tenant-scope",
|
||||
"requested_scope": []string{"openid", "profile"},
|
||||
"skip": false,
|
||||
"subject": "user-123",
|
||||
"client": map[string]any{
|
||||
"client_id": "client-app",
|
||||
"client_name": "Test App",
|
||||
"metadata": map[string]any{
|
||||
"tenant_access_restricted": true,
|
||||
"allowed_tenants": []string{"tenant-allow"},
|
||||
"structured_scopes": []map[string]any{
|
||||
{"name": "openid", "mandatory": true},
|
||||
{"name": "tenant", "mandatory": true, "locked": true},
|
||||
{"name": "profile", "mandatory": false},
|
||||
},
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
mockTenantSvc := &MockTenantServiceForConsent{}
|
||||
mockKratosAdmin := &MockKratosAdminServiceForConsent{}
|
||||
|
||||
// Mock profile resolution to allow tenant access
|
||||
mockKratosAdmin.On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
|
||||
ID: "user-123",
|
||||
Traits: map[string]any{
|
||||
"email": "user@example.com",
|
||||
},
|
||||
}, nil)
|
||||
|
||||
mockTenantSvc.On("GetTenant", mock.Anything, "tenant-allow").Return(&domain.Tenant{
|
||||
ID: "tenant-allow",
|
||||
Slug: "tenant-allow",
|
||||
Name: "Allowed Tenant",
|
||||
}, nil)
|
||||
|
||||
// Mock hydration calls
|
||||
mockTenantSvc.On("ListJoinedTenants", mock.Anything, mock.Anything).Return([]domain.Tenant{
|
||||
{ID: "tenant-allow", Slug: "tenant-allow", Name: "Allowed Tenant"},
|
||||
}, nil)
|
||||
mockTenantSvc.On("ListManageableTenants", mock.Anything, mock.Anything).Return([]domain.Tenant{}, nil)
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
TenantService: mockTenantSvc,
|
||||
KratosAdmin: mockKratosAdmin,
|
||||
}
|
||||
app := newConsentTestApp(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-tenant-scope", nil)
|
||||
req.Header.Set("X-Mock-Role", "user")
|
||||
req.Header.Set("X-Tenant-ID", "tenant-allow")
|
||||
resp, err := app.Test(req)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body map[string]any
|
||||
json.NewDecoder(resp.Body).Decode(&body)
|
||||
|
||||
assert.Equal(t, []any{"openid", "tenant", "profile"}, body["requested_scope"])
|
||||
scopeDetails := body["scope_details"].(map[string]any)
|
||||
tenantDetail := scopeDetails["tenant"].(map[string]any)
|
||||
assert.Equal(t, true, tenantDetail["mandatory"])
|
||||
}
|
||||
|
||||
func TestGetConsentRequest_Skip_AutoAccept(t *testing.T) {
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
// Hydra: Get Consent Request
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-skip" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-skip",
|
||||
"requested_scope": []string{"openid"},
|
||||
"skip": true,
|
||||
"subject": "user-123",
|
||||
"client": map[string]any{
|
||||
"client_id": "client-app",
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
// Kratos: Get Identity
|
||||
if r.URL.Path == "/admin/identities/user-123" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@test.com",
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
// Hydra: Accept Consent Request
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-skip" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"redirect_to": "http://rp/cb",
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
consentRepo := &mockConsentRepo{}
|
||||
rpUsageSink := &mockRPUsageEventSink{}
|
||||
mockKratosAdmin := &MockKratosAdminServiceForConsent{}
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
KratosAdmin: mockKratosAdmin,
|
||||
ConsentRepo: consentRepo,
|
||||
RPUsageSink: rpUsageSink,
|
||||
}
|
||||
mockKratosAdmin.On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
|
||||
ID: "user-123",
|
||||
Traits: map[string]any{
|
||||
"email": "user@test.com",
|
||||
},
|
||||
}, nil)
|
||||
|
||||
app := newConsentTestApp(h)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-skip", nil)
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body map[string]any
|
||||
json.NewDecoder(resp.Body).Decode(&body)
|
||||
assert.Equal(t, "http://rp/cb", body["redirectTo"])
|
||||
assert.Equal(t, 1, len(rpUsageSink.events))
|
||||
assert.Equal(t, domain.RPUsageEventTypeAuthorizationGranted, rpUsageSink.events[0].EventType)
|
||||
assert.Equal(t, "client-app", rpUsageSink.events[0].ClientID)
|
||||
assert.Equal(t, "challenge-skip", rpUsageSink.events[0].CorrelationID)
|
||||
assert.Equal(t, true, rpUsageSink.events[0].Payload["auto_accepted"])
|
||||
}
|
||||
|
||||
func TestAcceptConsentRequest_Normal(t *testing.T) {
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-accept" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-accept",
|
||||
"requested_scope": []string{"openid", "profile"},
|
||||
"subject": "user-123",
|
||||
"client": map[string]any{
|
||||
"client_id": "client-app",
|
||||
"client_name": "Test App",
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Path == "/admin/identities/user-123" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@test.com",
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-accept" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"redirect_to": "http://rp/cb",
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
auditRepo := &mockAuditRepo{}
|
||||
consentRepo := &mockConsentRepo{}
|
||||
rpUsageSink := &mockRPUsageEventSink{}
|
||||
mockKratosAdmin := &MockKratosAdminServiceForConsent{}
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
KratosAdmin: mockKratosAdmin,
|
||||
AuditRepo: auditRepo,
|
||||
ConsentRepo: consentRepo,
|
||||
RPUsageSink: rpUsageSink,
|
||||
}
|
||||
mockKratosAdmin.On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
|
||||
ID: "user-123",
|
||||
Traits: map[string]any{
|
||||
"email": "user@test.com",
|
||||
},
|
||||
}, nil)
|
||||
|
||||
app := newConsentTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"consent_challenge": "challenge-accept",
|
||||
"grant_scope": []string{"openid"},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
assert.Equal(t, 1, len(auditRepo.logs))
|
||||
assert.Equal(t, "consent.granted", auditRepo.logs[0].EventType)
|
||||
assert.Equal(t, "user-123", auditRepo.logs[0].UserID)
|
||||
assert.Equal(t, "success", auditRepo.logs[0].Status)
|
||||
auditDetails, err := utils.ParseAuditDetails(auditRepo.logs[0].Details)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "client-app", auditDetails["client_id"])
|
||||
assert.Equal(t, "Test App", auditDetails["client_name"])
|
||||
assert.Equal(t, []any{"openid"}, auditDetails["scopes"])
|
||||
assert.Equal(t, 1, len(rpUsageSink.events))
|
||||
assert.Equal(t, domain.RPUsageEventTypeAuthorizationGranted, rpUsageSink.events[0].EventType)
|
||||
assert.Equal(t, "user-123", rpUsageSink.events[0].Subject)
|
||||
assert.Equal(t, "client-app", rpUsageSink.events[0].ClientID)
|
||||
assert.Equal(t, "Test App", rpUsageSink.events[0].ClientName)
|
||||
assert.Equal(t, []string{"openid"}, []string(rpUsageSink.events[0].Scopes))
|
||||
assert.Equal(t, "hydra_consent", rpUsageSink.events[0].Source)
|
||||
}
|
||||
|
||||
func TestAcceptConsentRequest_EnforcesMandatoryTenantScope(t *testing.T) {
|
||||
t.Setenv("APP_ENV", "dev")
|
||||
|
||||
var capturedGrantScopes []string
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-tenant-accept" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-tenant-accept",
|
||||
"requested_scope": []string{"openid", "profile"},
|
||||
"subject": "user-123",
|
||||
"client": map[string]any{
|
||||
"client_id": "client-app",
|
||||
"metadata": map[string]any{
|
||||
"tenant_id": "tenant-abc",
|
||||
"tenant_access_restricted": true,
|
||||
"allowed_tenants": []string{"tenant-abc"},
|
||||
"structured_scopes": []map[string]any{
|
||||
{"name": "openid", "mandatory": true},
|
||||
{"name": "tenant", "mandatory": true, "locked": true},
|
||||
{"name": "profile", "mandatory": false},
|
||||
},
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Path == "/admin/identities/user-123" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@test.com",
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-tenant-accept" {
|
||||
var payload map[string]any
|
||||
assert.NoError(t, json.NewDecoder(r.Body).Decode(&payload))
|
||||
for _, scope := range payload["grant_scope"].([]any) {
|
||||
capturedGrantScopes = append(capturedGrantScopes, scope.(string))
|
||||
}
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"redirect_to": "http://rp/cb",
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
mockKratosAdmin := &MockKratosAdminServiceForConsent{}
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
KratosAdmin: mockKratosAdmin,
|
||||
}
|
||||
mockKratosAdmin.On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
|
||||
ID: "user-123",
|
||||
Traits: map[string]any{
|
||||
"email": "user@test.com",
|
||||
},
|
||||
}, nil)
|
||||
|
||||
app := newConsentTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"consent_challenge": "challenge-tenant-accept",
|
||||
"grant_scope": []string{"openid", "profile"},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Mock-Role", "user")
|
||||
req.Header.Set("X-Tenant-ID", "tenant-abc")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Equal(t, []string{"openid", "tenant", "profile"}, capturedGrantScopes)
|
||||
}
|
||||
@@ -0,0 +1,829 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func TestBuildOidcClaimsFromTraits_DynamicClaims(t *testing.T) {
|
||||
traits := map[string]any{
|
||||
"email": "user@baron.com",
|
||||
"name": "홍길동",
|
||||
"tenant_id": "primary-tenant-999", // Added primary tenant
|
||||
"tenant-1": map[string]any{
|
||||
"department": "개발팀",
|
||||
"grade": "선임",
|
||||
},
|
||||
"tenant-2": map[string]any{
|
||||
"department": "재무팀",
|
||||
"grade": "팀장",
|
||||
},
|
||||
}
|
||||
scopes := []string{"openid", "profile"}
|
||||
|
||||
t.Run("No tenantID", func(t *testing.T) {
|
||||
claims := buildOidcClaimsFromTraits(traits, scopes, "")
|
||||
assert.Equal(t, "user@baron.com", claims["email"])
|
||||
assert.Equal(t, "홍길동", claims["name"])
|
||||
assert.Equal(t, "primary-tenant-999", claims["tenant_id"])
|
||||
assert.Nil(t, claims["department"])
|
||||
assert.Nil(t, claims["grade"])
|
||||
|
||||
assert.Nil(t, claims["tenants"])
|
||||
assert.Contains(t, claims["joined_tenants"], "tenant-1")
|
||||
assert.Contains(t, claims["joined_tenants"], "tenant-2")
|
||||
assert.Contains(t, claims["joined_tenants"], "primary-tenant-999") // Should contain primary
|
||||
})
|
||||
|
||||
t.Run("With tenant-1", func(t *testing.T) {
|
||||
claims := buildOidcClaimsFromTraits(traits, scopes, "tenant-1")
|
||||
assert.Equal(t, "user@baron.com", claims["email"])
|
||||
assert.Equal(t, "홍길동", claims["name"])
|
||||
assert.Equal(t, "tenant-1", claims["tenant_id"])
|
||||
assert.Nil(t, claims["department"])
|
||||
assert.Nil(t, claims["grade"])
|
||||
|
||||
assert.Nil(t, claims["tenants"])
|
||||
assert.Contains(t, claims["joined_tenants"], "tenant-1")
|
||||
assert.Contains(t, claims["joined_tenants"], "tenant-2")
|
||||
assert.Contains(t, claims["joined_tenants"], "primary-tenant-999")
|
||||
})
|
||||
|
||||
t.Run("With tenant-2", func(t *testing.T) {
|
||||
claims := buildOidcClaimsFromTraits(traits, scopes, "tenant-2")
|
||||
assert.Equal(t, "user@baron.com", claims["email"])
|
||||
assert.Equal(t, "홍길동", claims["name"])
|
||||
assert.Equal(t, "tenant-2", claims["tenant_id"])
|
||||
assert.Nil(t, claims["department"])
|
||||
assert.Nil(t, claims["grade"])
|
||||
|
||||
assert.Nil(t, claims["tenants"])
|
||||
assert.Contains(t, claims["joined_tenants"], "primary-tenant-999")
|
||||
})
|
||||
|
||||
t.Run("With non-existent tenant", func(t *testing.T) {
|
||||
claims := buildOidcClaimsFromTraits(traits, scopes, "tenant-3")
|
||||
assert.Equal(t, "user@baron.com", claims["email"])
|
||||
assert.Equal(t, "홍길동", claims["name"])
|
||||
assert.Equal(t, "tenant-3", claims["tenant_id"])
|
||||
assert.Nil(t, claims["department"])
|
||||
assert.Nil(t, claims["grade"])
|
||||
|
||||
assert.Nil(t, claims["tenants"])
|
||||
assert.Contains(t, claims["joined_tenants"], "tenant-1")
|
||||
assert.Contains(t, claims["joined_tenants"], "primary-tenant-999")
|
||||
})
|
||||
|
||||
t.Run("Tenant scope includes detailed tenant metadata", func(t *testing.T) {
|
||||
claims := buildOidcClaimsFromTraits(traits, []string{"openid", "profile", "tenant"}, "tenant-1")
|
||||
assert.Equal(t, "tenant-1", claims["tenant_id"])
|
||||
assert.Equal(t, "개발팀", claims["department"])
|
||||
assert.Equal(t, "선임", claims["grade"])
|
||||
assert.NotNil(t, claims["tenants"])
|
||||
assert.Contains(t, claims["joined_tenants"], "tenant-1")
|
||||
assert.Contains(t, claims["joined_tenants"], "tenant-2")
|
||||
assert.Contains(t, claims["joined_tenants"], "primary-tenant-999")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRepresentativeTenantIDFromTraits(t *testing.T) {
|
||||
t.Run("explicit tenant_id wins", func(t *testing.T) {
|
||||
traits := map[string]any{
|
||||
"tenant_id": "01970f0a-5c28-74d8-a73a-f6e9e9a7b210",
|
||||
"additionalAppointments": []any{
|
||||
map[string]any{"tenantId": "01970f0b-3448-7bb8-bdc7-16b6a1d2e661", "isPrimary": true},
|
||||
},
|
||||
}
|
||||
assert.Equal(t, "01970f0a-5c28-74d8-a73a-f6e9e9a7b210", representativeTenantIDFromTraits(traits))
|
||||
})
|
||||
|
||||
t.Run("primary appointment wins when tenant_id is absent", func(t *testing.T) {
|
||||
traits := map[string]any{
|
||||
"additionalAppointments": []any{
|
||||
map[string]any{"tenantId": "01970f0b-3448-7bb8-bdc7-16b6a1d2e661"},
|
||||
map[string]any{"tenantId": "01970f0c-8c44-7069-9f20-7d28c0b8e630", "representative": true},
|
||||
},
|
||||
}
|
||||
assert.Equal(t, "01970f0c-8c44-7069-9f20-7d28c0b8e630", representativeTenantIDFromTraits(traits))
|
||||
})
|
||||
|
||||
t.Run("first appointment is fallback", func(t *testing.T) {
|
||||
traits := map[string]any{
|
||||
"additionalAppointments": []any{
|
||||
map[string]any{"tenantId": "01970f0b-3448-7bb8-bdc7-16b6a1d2e661"},
|
||||
map[string]any{"tenantId": "01970f0c-8c44-7069-9f20-7d28c0b8e630"},
|
||||
},
|
||||
}
|
||||
assert.Equal(t, "01970f0b-3448-7bb8-bdc7-16b6a1d2e661", representativeTenantIDFromTraits(traits))
|
||||
})
|
||||
}
|
||||
|
||||
func TestAcceptConsentRequest_DynamicClaims(t *testing.T) {
|
||||
var capturedClaims map[string]any
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
// Hydra: Get Consent Request
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-dynamic" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-dynamic",
|
||||
"requested_scope": []string{"openid", "profile", "tenant"},
|
||||
"subject": "user-123",
|
||||
"client": map[string]any{
|
||||
"client_id": "client-app",
|
||||
"metadata": map[string]any{
|
||||
"tenant_id": "tenant-abc",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
// Kratos: Get Identity
|
||||
if r.URL.Path == "/admin/identities/user-123" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@test.com",
|
||||
"name": "Test User",
|
||||
"tenant-abc": map[string]any{
|
||||
"department": "Innovation",
|
||||
"position": "Architect",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
// Hydra: Accept Consent Request
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-dynamic" {
|
||||
// Capture the claims sent to Hydra
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var acceptReq map[string]any
|
||||
json.Unmarshal(body, &acceptReq)
|
||||
if session, ok := acceptReq["session"].(map[string]any); ok {
|
||||
capturedClaims = session["id_token"].(map[string]any)
|
||||
}
|
||||
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"redirect_to": "http://rp/cb",
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
KratosAdmin: new(MockKratosAdminService),
|
||||
}
|
||||
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
|
||||
ID: "user-123",
|
||||
Traits: map[string]any{
|
||||
"email": "user@test.com",
|
||||
"name": "Test User",
|
||||
"tenant-abc": map[string]any{
|
||||
"department": "Innovation",
|
||||
"position": "Architect",
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
|
||||
|
||||
reqBody, _ := json.Marshal(map[string]any{
|
||||
"consent_challenge": "challenge-dynamic",
|
||||
"grant_scope": []string{"openid", "profile", "tenant"},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Verify captured claims
|
||||
assert.NotNil(t, capturedClaims)
|
||||
assert.Equal(t, "user@test.com", capturedClaims["email"])
|
||||
assert.Equal(t, "tenant-abc", capturedClaims["tenant_id"])
|
||||
assert.Equal(t, "Innovation", capturedClaims["department"])
|
||||
assert.Equal(t, "Architect", capturedClaims["position"])
|
||||
}
|
||||
|
||||
func TestAcceptConsentRequest_UsesRepresentativeTenantIDInsteadOfClientTenantContext(t *testing.T) {
|
||||
var capturedClaims map[string]any
|
||||
|
||||
representativeTenantID := "01970f0a-5c28-74d8-a73a-f6e9e9a7b210"
|
||||
rpContextTenantID := "01970f0b-3448-7bb8-bdc7-16b6a1d2e661"
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-representative-tenant" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-representative-tenant",
|
||||
"requested_scope": []string{"openid", "profile", "tenant"},
|
||||
"subject": "user-representative",
|
||||
"client": map[string]any{
|
||||
"client_id": "client-app",
|
||||
"metadata": map[string]any{
|
||||
"tenant_id": rpContextTenantID,
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-representative-tenant" {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var acceptReq map[string]any
|
||||
json.Unmarshal(body, &acceptReq)
|
||||
if session, ok := acceptReq["session"].(map[string]any); ok {
|
||||
capturedClaims = session["id_token"].(map[string]any)
|
||||
}
|
||||
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"redirect_to": "http://rp/cb",
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
KratosAdmin: new(MockKratosAdminService),
|
||||
}
|
||||
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-representative").Return(&service.KratosIdentity{
|
||||
ID: "user-representative",
|
||||
Traits: map[string]any{
|
||||
"email": "user@test.com",
|
||||
"name": "Test User",
|
||||
"additionalAppointments": []any{
|
||||
map[string]any{"tenantId": representativeTenantID, "isPrimary": true},
|
||||
map[string]any{"tenantId": rpContextTenantID},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
|
||||
|
||||
reqBody, _ := json.Marshal(map[string]any{
|
||||
"consent_challenge": "challenge-representative-tenant",
|
||||
"grant_scope": []string{"openid", "profile"},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
assert.NotNil(t, capturedClaims)
|
||||
assert.Equal(t, representativeTenantID, capturedClaims["tenant_id"])
|
||||
assert.Contains(t, capturedClaims["joined_tenants"], representativeTenantID)
|
||||
assert.Contains(t, capturedClaims["joined_tenants"], rpContextTenantID)
|
||||
assert.Nil(t, capturedClaims["tenants"])
|
||||
}
|
||||
|
||||
func TestAcceptConsentRequest_IncludesHanmacFamilyTenantClaimDetails(t *testing.T) {
|
||||
var capturedClaims map[string]any
|
||||
deptID := "01970f0a-5c28-74d8-a73a-f6e9e9a7b210"
|
||||
secondDeptID := "01970f0b-3448-7bb8-bdc7-16b6a1d2e661"
|
||||
companyID := "01970f08-91da-7286-bd19-882fb98d1f2c"
|
||||
rootID := "01970f07-4f01-7d9a-a71e-b53ad508f345"
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-hanmac-tenant-claim" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-hanmac-tenant-claim",
|
||||
"requested_scope": []string{"openid", "profile", "tenant"},
|
||||
"subject": "user-hanmac",
|
||||
"client": map[string]any{
|
||||
"client_id": "hanmac-rp",
|
||||
"metadata": map[string]any{
|
||||
"tenant_id": deptID,
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-hanmac-tenant-claim" {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var acceptReq map[string]any
|
||||
json.Unmarshal(body, &acceptReq)
|
||||
if session, ok := acceptReq["session"].(map[string]any); ok {
|
||||
capturedClaims = session["id_token"].(map[string]any)
|
||||
}
|
||||
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"redirect_to": "http://rp/cb",
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
KratosAdmin: new(MockKratosAdminService),
|
||||
}
|
||||
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-hanmac").Return(&service.KratosIdentity{
|
||||
ID: "user-hanmac",
|
||||
Traits: map[string]any{
|
||||
"email": "hanmac-user@example.com",
|
||||
"name": "한맥 사용자",
|
||||
"additionalAppointments": []any{
|
||||
map[string]any{
|
||||
"tenantId": deptID,
|
||||
"isPrimary": true,
|
||||
"isOwner": true,
|
||||
"grade": "책임",
|
||||
"jobTitle": "기술기획",
|
||||
"position": "팀장",
|
||||
},
|
||||
map[string]any{
|
||||
"tenantId": secondDeptID,
|
||||
"isPrimary": false,
|
||||
"isOwner": false,
|
||||
"grade": "선임",
|
||||
"jobTitle": "품질관리",
|
||||
"position": "파트원",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
mockTenantSvc := new(MockTenantService)
|
||||
mockTenantSvc.On("ListJoinedTenants", mock.Anything, "user-hanmac").Return([]domain.Tenant{}, nil)
|
||||
mockTenantSvc.On("GetTenant", mock.Anything, deptID).Return(&domain.Tenant{
|
||||
ID: deptID,
|
||||
Slug: "tech-planning",
|
||||
Name: "기술기획팀",
|
||||
Type: domain.TenantTypeUserGroup,
|
||||
ParentID: &companyID,
|
||||
}, nil)
|
||||
mockTenantSvc.On("GetTenant", mock.Anything, secondDeptID).Return(&domain.Tenant{
|
||||
ID: secondDeptID,
|
||||
Slug: "quality",
|
||||
Name: "품질관리팀",
|
||||
Type: domain.TenantTypeUserGroup,
|
||||
ParentID: &companyID,
|
||||
}, nil)
|
||||
mockTenantSvc.On("GetTenant", mock.Anything, companyID).Return(&domain.Tenant{
|
||||
ID: companyID,
|
||||
Slug: "hanmac",
|
||||
Name: "한맥기술",
|
||||
Type: domain.TenantTypeCompany,
|
||||
ParentID: &rootID,
|
||||
}, nil)
|
||||
mockTenantSvc.On("GetTenant", mock.Anything, rootID).Return(&domain.Tenant{
|
||||
ID: rootID,
|
||||
Slug: "hanmac-family",
|
||||
Name: "한맥가족",
|
||||
Type: domain.TenantTypeCompanyGroup,
|
||||
}, nil)
|
||||
h.TenantService = mockTenantSvc
|
||||
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
|
||||
|
||||
reqBody, _ := json.Marshal(map[string]any{
|
||||
"consent_challenge": "challenge-hanmac-tenant-claim",
|
||||
"grant_scope": []string{"openid", "profile", "tenant"},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
assert.NotNil(t, capturedClaims)
|
||||
assert.Equal(t, []any{deptID}, capturedClaims["lead_tenants"])
|
||||
assert.ElementsMatch(t, []any{deptID, secondDeptID}, capturedClaims["joined_tenants"])
|
||||
tenants := capturedClaims["tenants"].(map[string]any)
|
||||
dept := tenants[deptID].(map[string]any)
|
||||
assert.Equal(t, true, dept["lead"])
|
||||
assert.Equal(t, true, dept["representative"])
|
||||
assert.Equal(t, "책임", dept["grade"])
|
||||
assert.Equal(t, "기술기획", dept["jobTitle"])
|
||||
assert.Equal(t, "팀장", dept["position"])
|
||||
assert.Equal(t, companyID, dept["parentTenantId"])
|
||||
assert.NotContains(t, dept, "parentTenant")
|
||||
|
||||
ancestors := dept["ancestors"].([]any)
|
||||
assert.Len(t, ancestors, 2)
|
||||
companyAncestor := ancestors[0].(map[string]any)
|
||||
assert.Equal(t, companyID, companyAncestor["id"])
|
||||
assert.Equal(t, "hanmac", companyAncestor["slug"])
|
||||
assert.Equal(t, rootID, companyAncestor["parentTenantId"])
|
||||
assert.NotContains(t, companyAncestor, "parentTenant")
|
||||
rootAncestor := ancestors[1].(map[string]any)
|
||||
assert.Equal(t, rootID, rootAncestor["id"])
|
||||
assert.Equal(t, "hanmac-family", rootAncestor["slug"])
|
||||
assert.Contains(t, rootAncestor, "parentTenantId")
|
||||
assert.Nil(t, rootAncestor["parentTenantId"])
|
||||
assert.NotContains(t, rootAncestor, "parentTenant")
|
||||
|
||||
secondDept := tenants[secondDeptID].(map[string]any)
|
||||
assert.Equal(t, false, secondDept["lead"])
|
||||
assert.Equal(t, false, secondDept["representative"])
|
||||
assert.Equal(t, "선임", secondDept["grade"])
|
||||
assert.Equal(t, "품질관리", secondDept["jobTitle"])
|
||||
assert.Equal(t, "파트원", secondDept["position"])
|
||||
assert.Equal(t, companyID, secondDept["parentTenantId"])
|
||||
}
|
||||
|
||||
func TestWithHanmacFamilyTenantClaims_DefaultClaimsOnlyWithoutTenantScope(t *testing.T) {
|
||||
deptID := "01970f0a-5c28-74d8-a73a-f6e9e9a7b210"
|
||||
secondDeptID := "01970f0b-3448-7bb8-bdc7-16b6a1d2e661"
|
||||
companyID := "01970f08-91da-7286-bd19-882fb98d1f2c"
|
||||
rootID := "01970f07-4f01-7d9a-a71e-b53ad508f345"
|
||||
|
||||
mockTenantSvc := new(MockTenantService)
|
||||
mockTenantSvc.On("GetTenant", mock.Anything, deptID).Return(&domain.Tenant{
|
||||
ID: deptID,
|
||||
Slug: "tech-planning",
|
||||
Name: "기술기획팀",
|
||||
Type: domain.TenantTypeUserGroup,
|
||||
ParentID: &companyID,
|
||||
}, nil)
|
||||
mockTenantSvc.On("GetTenant", mock.Anything, secondDeptID).Return(&domain.Tenant{
|
||||
ID: secondDeptID,
|
||||
Slug: "quality",
|
||||
Name: "품질관리팀",
|
||||
Type: domain.TenantTypeUserGroup,
|
||||
ParentID: &companyID,
|
||||
}, nil)
|
||||
mockTenantSvc.On("GetTenant", mock.Anything, companyID).Return(&domain.Tenant{
|
||||
ID: companyID,
|
||||
Slug: "hanmac",
|
||||
Name: "한맥기술",
|
||||
Type: domain.TenantTypeCompany,
|
||||
ParentID: &rootID,
|
||||
}, nil)
|
||||
mockTenantSvc.On("GetTenant", mock.Anything, rootID).Return(&domain.Tenant{
|
||||
ID: rootID,
|
||||
Slug: "hanmac-family",
|
||||
Name: "한맥가족",
|
||||
Type: domain.TenantTypeCompanyGroup,
|
||||
}, nil)
|
||||
|
||||
h := &AuthHandler{TenantService: mockTenantSvc}
|
||||
claims := map[string]any{"tenant_id": deptID}
|
||||
traits := map[string]any{
|
||||
"additionalAppointments": []any{
|
||||
map[string]any{
|
||||
"tenantId": deptID,
|
||||
"isPrimary": true,
|
||||
"isOwner": true,
|
||||
"grade": "책임",
|
||||
},
|
||||
map[string]any{
|
||||
"tenantId": secondDeptID,
|
||||
"grade": "선임",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
claims = h.withHanmacFamilyTenantClaims(context.Background(), claims, traits, []string{"openid", "profile"})
|
||||
|
||||
assert.Equal(t, deptID, claims["tenant_id"])
|
||||
assert.ElementsMatch(t, []string{deptID, secondDeptID}, claims["joined_tenants"])
|
||||
assert.NotContains(t, claims, "tenants")
|
||||
assert.NotContains(t, claims, "lead_tenants")
|
||||
}
|
||||
|
||||
func TestAcceptConsentRequest_IncludesRPProfileClaims(t *testing.T) {
|
||||
var capturedClaims map[string]any
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-rp-profile" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-rp-profile",
|
||||
"requested_scope": []string{"openid", "profile", "tenant"},
|
||||
"subject": "user-123",
|
||||
"client": map[string]any{
|
||||
"client_id": "client-app",
|
||||
"metadata": map[string]any{
|
||||
"customUserSchema": []map[string]any{
|
||||
{
|
||||
"key": "approvalLevel",
|
||||
"label": "승인 등급",
|
||||
"type": "text",
|
||||
"claimEnabled": true,
|
||||
},
|
||||
{
|
||||
"key": "internalMemo",
|
||||
"label": "내부 메모",
|
||||
"type": "text",
|
||||
"claimEnabled": false,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-rp-profile" {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var acceptReq map[string]any
|
||||
json.Unmarshal(body, &acceptReq)
|
||||
if session, ok := acceptReq["session"].(map[string]any); ok {
|
||||
capturedClaims = session["id_token"].(map[string]any)
|
||||
}
|
||||
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"redirect_to": "http://rp/cb",
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
KratosAdmin: new(MockKratosAdminService),
|
||||
}
|
||||
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
|
||||
ID: "user-123",
|
||||
Traits: map[string]any{
|
||||
"email": "user@test.com",
|
||||
"name": "Test User",
|
||||
},
|
||||
}, nil)
|
||||
repo := new(devMockRPUserMetadataRepo)
|
||||
repo.On("Get", mock.Anything, "client-app", "user-123").Return(&domain.RPUserMetadata{
|
||||
ClientID: "client-app",
|
||||
UserID: "user-123",
|
||||
Metadata: domain.JSONMap{
|
||||
"approvalLevel": "A",
|
||||
"internalMemo": "관리자 전용",
|
||||
},
|
||||
}, nil).Once()
|
||||
h.RPUserMetadataRepo = repo
|
||||
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
|
||||
|
||||
reqBody, _ := json.Marshal(map[string]any{
|
||||
"consent_challenge": "challenge-rp-profile",
|
||||
"grant_scope": []string{"openid", "profile"},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
assert.NotNil(t, capturedClaims)
|
||||
rpProfiles, ok := capturedClaims["rp_profiles"].([]any)
|
||||
assert.True(t, ok)
|
||||
assert.Len(t, rpProfiles, 1)
|
||||
profile := rpProfiles[0].(map[string]any)
|
||||
assert.Equal(t, "client-app", profile["client_id"])
|
||||
fields := profile["fields"].(map[string]any)
|
||||
assert.Equal(t, "A", fields["approvalLevel"])
|
||||
assert.NotContains(t, fields, "internalMemo")
|
||||
repo.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestGetConsentRequest_Skip_DynamicClaims(t *testing.T) {
|
||||
var capturedClaims map[string]any
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
// Hydra: Get Consent Request
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-skip-dynamic" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-skip-dynamic",
|
||||
"requested_scope": []string{"openid", "profile", "tenant"},
|
||||
"skip": true,
|
||||
"subject": "user-456",
|
||||
"client": map[string]any{
|
||||
"client_id": "skip-app",
|
||||
"metadata": map[string]any{
|
||||
"tenant_id": "tenant-xyz",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
// Kratos: Get Identity
|
||||
if r.URL.Path == "/admin/identities/user-456" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "user-456",
|
||||
"traits": map[string]any{
|
||||
"email": "skip@test.com",
|
||||
"tenant-xyz": map[string]any{
|
||||
"department": "Security",
|
||||
"position": "Officer",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
// Hydra: Accept Consent Request
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-skip-dynamic" {
|
||||
// Capture the claims sent to Hydra
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var acceptReq map[string]any
|
||||
json.Unmarshal(body, &acceptReq)
|
||||
if session, ok := acceptReq["session"].(map[string]any); ok {
|
||||
capturedClaims = session["id_token"].(map[string]any)
|
||||
}
|
||||
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"redirect_to": "http://rp/cb",
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
KratosAdmin: new(MockKratosAdminService),
|
||||
}
|
||||
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-456").Return(&service.KratosIdentity{
|
||||
ID: "user-456",
|
||||
Traits: map[string]any{
|
||||
"email": "skip@test.com",
|
||||
"tenant-xyz": map[string]any{
|
||||
"department": "Security",
|
||||
"position": "Officer",
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/auth/consent", h.GetConsentRequest)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-skip-dynamic", nil)
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Verify captured claims
|
||||
assert.NotNil(t, capturedClaims)
|
||||
assert.Equal(t, "skip@test.com", capturedClaims["email"])
|
||||
assert.Equal(t, "tenant-xyz", capturedClaims["tenant_id"])
|
||||
assert.Equal(t, "Security", capturedClaims["department"])
|
||||
assert.Equal(t, "Officer", capturedClaims["position"])
|
||||
}
|
||||
|
||||
func TestBuildOidcClaimsFromTraits_IncludesGlobalCustomClaims(t *testing.T) {
|
||||
claims := buildOidcClaimsFromTraits(map[string]any{
|
||||
"email": "user@test.com",
|
||||
"name": "Test User",
|
||||
"global_custom_claims": map[string]any{
|
||||
"contract_date": "2026-06-09",
|
||||
"approved_at": "2026-06-09T09:30:00+09:00",
|
||||
"email": "override@test.com",
|
||||
"rp_claims": "reserved",
|
||||
},
|
||||
"global_custom_claim_permissions": map[string]any{
|
||||
"contract_date": map[string]any{
|
||||
"readPermission": "user_and_admin",
|
||||
"writePermission": "admin_only",
|
||||
},
|
||||
},
|
||||
}, []string{"openid", "profile", "email"}, "")
|
||||
|
||||
assert.Equal(t, "2026-06-09", claims["contract_date"])
|
||||
assert.Equal(t, "2026-06-09T09:30:00+09:00", claims["approved_at"])
|
||||
assert.Equal(t, "user@test.com", claims["email"])
|
||||
assert.NotEqual(t, "reserved", claims["rp_claims"])
|
||||
assert.NotContains(t, claims, "global_custom_claim_permissions")
|
||||
}
|
||||
|
||||
func TestAcceptConsentRequest_AppliesConfiguredIDTokenClaims(t *testing.T) {
|
||||
var capturedClaims map[string]any
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-configured-claims" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-configured-claims",
|
||||
"requested_scope": []string{"openid", "profile"},
|
||||
"subject": "user-789",
|
||||
"client": map[string]any{
|
||||
"client_id": "client-configured-claims",
|
||||
"metadata": map[string]any{
|
||||
"tenant_id": "tenant-claims",
|
||||
"id_token_claims": []map[string]any{
|
||||
{
|
||||
"namespace": "top_level",
|
||||
"key": "locale",
|
||||
"value": "ko-KR",
|
||||
"valueType": "text",
|
||||
},
|
||||
{
|
||||
"namespace": "top_level",
|
||||
"key": "email",
|
||||
"value": "should-not-override@example.com",
|
||||
"valueType": "text",
|
||||
},
|
||||
{
|
||||
"namespace": "rp_claims",
|
||||
"key": "tier",
|
||||
"value": "2",
|
||||
"valueType": "number",
|
||||
},
|
||||
{
|
||||
"namespace": "rp_claims",
|
||||
"key": "features",
|
||||
"value": "[\"sso\",\"claims\"]",
|
||||
"valueType": "array",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-configured-claims" {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var acceptReq map[string]any
|
||||
json.Unmarshal(body, &acceptReq)
|
||||
if session, ok := acceptReq["session"].(map[string]any); ok {
|
||||
capturedClaims = session["id_token"].(map[string]any)
|
||||
}
|
||||
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"redirect_to": "http://rp/cb",
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
KratosAdmin: new(MockKratosAdminService),
|
||||
}
|
||||
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-789").Return(&service.KratosIdentity{
|
||||
ID: "user-789",
|
||||
Traits: map[string]any{
|
||||
"email": "real-user@example.com",
|
||||
"name": "Configured User",
|
||||
"tenant-claims": map[string]any{
|
||||
"department": "Platform",
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
|
||||
|
||||
reqBody, _ := json.Marshal(map[string]any{
|
||||
"consent_challenge": "challenge-configured-claims",
|
||||
"grant_scope": []string{"openid", "profile"},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
assert.NotNil(t, capturedClaims)
|
||||
assert.Equal(t, "real-user@example.com", capturedClaims["email"])
|
||||
assert.Equal(t, "ko-KR", capturedClaims["locale"])
|
||||
assert.Equal(t, "tenant-claims", capturedClaims["tenant_id"])
|
||||
|
||||
rpClaims, ok := capturedClaims["rp_claims"].(map[string]any)
|
||||
if assert.True(t, ok) {
|
||||
assert.Equal(t, float64(2), rpClaims["tier"])
|
||||
assert.Equal(t, []any{"sso", "claims"}, rpClaims["features"])
|
||||
}
|
||||
}
|
||||
904
baron-sso/backend/internal/handler/auth_handler_link_test.go
Normal file
904
baron-sso/backend/internal/handler/auth_handler_link_test.go
Normal file
@@ -0,0 +1,904 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"baron-sso-backend/internal/testsupport"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// Mock services
|
||||
type mockEmailService struct {
|
||||
lastTo string
|
||||
lastSubject string
|
||||
lastBody string
|
||||
}
|
||||
|
||||
func (m *mockEmailService) SendEmail(to, subject, body string) error {
|
||||
m.lastTo = to
|
||||
m.lastSubject = subject
|
||||
m.lastBody = body
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockSmsService struct {
|
||||
lastTo string
|
||||
lastContent string
|
||||
}
|
||||
|
||||
func (m *mockSmsService) SendSms(to, content string) error {
|
||||
m.lastTo = to
|
||||
m.lastContent = content
|
||||
return nil
|
||||
}
|
||||
|
||||
func newHeadlessLinkTestApp(h *AuthHandler) *fiber.App {
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/headless/link/init", h.HeadlessLinkInit)
|
||||
app.Post("/api/v1/auth/headless/link/poll", h.HeadlessLinkPoll)
|
||||
app.Post("/api/v1/auth/magic-link/verify", h.VerifyMagicLink)
|
||||
return app
|
||||
}
|
||||
|
||||
func newKratosWhoamiTestServer(t *testing.T, identityID string) *httptest.Server {
|
||||
t.Helper()
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/sessions/whoami" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
if r.Header.Get("Cookie") == "" && r.Header.Get("X-Session-Token") == "" {
|
||||
http.Error(w, "missing session", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "session-123",
|
||||
"authenticated_at": "2026-05-21T00:00:00Z",
|
||||
"identity": map[string]any{
|
||||
"id": identityID,
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
origDefaultClient := http.DefaultClient
|
||||
http.DefaultClient = server.Client()
|
||||
t.Cleanup(func() {
|
||||
http.DefaultClient = origDefaultClient
|
||||
})
|
||||
t.Cleanup(server.Close)
|
||||
return server
|
||||
}
|
||||
|
||||
func TestEnchantedLinkFlow_Email_Success(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
// Force "Not Supported" for InitiateLinkLogin only to trigger custom Enchanted Link logic
|
||||
idp := &mockIdpProvider{
|
||||
userExists: true,
|
||||
initiateLinkErr: domain.ErrNotSupported,
|
||||
}
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: idp,
|
||||
EmailService: &mockEmailService{},
|
||||
SmsService: &mockSmsService{},
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/enchanted-link/init", h.InitEnchantedLink)
|
||||
app.Post("/api/v1/auth/enchanted-link/poll", h.PollEnchantedLink)
|
||||
app.Post("/api/v1/auth/magic-link/verify", h.VerifyMagicLink)
|
||||
|
||||
t.Setenv("USERFRONT_URL", "http://userfront.test")
|
||||
|
||||
// 1. Init Enchanted Link (Email)
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"loginId": "user@example.com",
|
||||
"method": "email",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/enchanted-link/init", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var initResp map[string]any
|
||||
json.NewDecoder(resp.Body).Decode(&initResp)
|
||||
pendingRef := initResp["pendingRef"].(string)
|
||||
assert.NotEmpty(t, pendingRef)
|
||||
|
||||
// Find the token key "enchanted_token:..." in mock redis
|
||||
var token string
|
||||
for k := range redis.data {
|
||||
if len(k) > 16 && k[:16] == "enchanted_token:" {
|
||||
token = k[16:]
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
// 2. Verify Magic Link
|
||||
verifyBody, _ := json.Marshal(map[string]any{
|
||||
"token": token,
|
||||
"verifyOnly": true,
|
||||
})
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(verifyBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ = app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// 3. Poll (Success)
|
||||
pollBody, _ := json.Marshal(map[string]string{"pendingRef": pendingRef})
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/enchanted-link/poll", bytes.NewReader(pollBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ = app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
var pollResp map[string]any
|
||||
json.NewDecoder(resp.Body).Decode(&pollResp)
|
||||
assert.Equal(t, "ok", pollResp["status"])
|
||||
assert.Equal(t, "valid-jwt", pollResp["sessionJwt"])
|
||||
}
|
||||
|
||||
func TestEnchantedLinkFlow_Sms_Success(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
idp := &mockIdpProvider{
|
||||
userExists: true,
|
||||
initiateLinkErr: domain.ErrNotSupported,
|
||||
}
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: idp,
|
||||
SmsService: &mockSmsService{},
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/enchanted-link/init", h.InitEnchantedLink)
|
||||
|
||||
// 1. Init Enchanted Link (SMS)
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"loginId": "010-1234-5678",
|
||||
"method": "sms",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/enchanted-link/init", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var initResp map[string]any
|
||||
json.NewDecoder(resp.Body).Decode(&initResp)
|
||||
assert.NotEmpty(t, initResp["userCode"])
|
||||
}
|
||||
|
||||
func TestVerifyMagicLink_VerifyOnlyWithoutSharedBrowserSessionApprovesOnly(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: map[string]string{
|
||||
prefixToken + "token-123": `{"pendingRef":"pending-123","loginId":"user@example.com"}`,
|
||||
}}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: &mockIdpProvider{},
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/magic-link/verify", h.VerifyMagicLink)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"token": "token-123",
|
||||
"verifyOnly": true,
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Empty(t, resp.Cookies())
|
||||
|
||||
var got map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&got)
|
||||
assert.Equal(t, "approved", got["status"])
|
||||
assert.Nil(t, got["sessionJwt"])
|
||||
assert.Nil(t, got["token"])
|
||||
}
|
||||
|
||||
func TestVerifyMagicLink_VerifyOnlySharedBrowserSameSubjectApprovesOnly(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: map[string]string{
|
||||
prefixToken + "token-123": `{"pendingRef":"pending-123","loginId":"user@example.com"}`,
|
||||
}}
|
||||
kratosPublic := newKratosWhoamiTestServer(t, "kratos-user-1")
|
||||
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: &mockIdpProvider{},
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/magic-link/verify", h.VerifyMagicLink)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"token": "token-123",
|
||||
"verifyOnly": true,
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Cookie", "ory_kratos_session=shared-browser-session")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Empty(t, resp.Cookies())
|
||||
|
||||
var got map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&got)
|
||||
assert.Equal(t, "approved", got["status"])
|
||||
assert.Nil(t, got["sessionJwt"])
|
||||
assert.Nil(t, got["token"])
|
||||
}
|
||||
|
||||
func TestVerifyMagicLink_VerifyOnlySharedBrowserDifferentSubjectApprovesOnly(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: map[string]string{
|
||||
prefixToken + "token-123": `{"pendingRef":"pending-123","loginId":"user@example.com"}`,
|
||||
}}
|
||||
kratosPublic := newKratosWhoamiTestServer(t, "kratos-other-user")
|
||||
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: &mockIdpProvider{},
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/magic-link/verify", h.VerifyMagicLink)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"token": "token-123",
|
||||
"verifyOnly": true,
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Cookie", "ory_kratos_session=shared-browser-session")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Empty(t, resp.Cookies())
|
||||
|
||||
var got map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&got)
|
||||
assert.Equal(t, "approved", got["status"])
|
||||
assert.Nil(t, got["sessionJwt"])
|
||||
assert.Nil(t, got["token"])
|
||||
assert.Contains(t, redis.data[prefixSession+"pending-123"], "approved")
|
||||
}
|
||||
|
||||
func TestResolveUserfrontURL_DevLocalhostUsesConfiguredPort(t *testing.T) {
|
||||
t.Setenv("APP_ENV", "dev")
|
||||
t.Setenv("USERFRONT_URL", "http://localhost:5000")
|
||||
|
||||
h := &AuthHandler{}
|
||||
app := fiber.New()
|
||||
app.Get("/probe", func(c *fiber.Ctx) error {
|
||||
return c.SendString(h.resolveUserfrontURL(c))
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost/probe", nil)
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
assert.Equal(t, "http://localhost:5000", string(body))
|
||||
}
|
||||
|
||||
func TestVerifyLoginCode_VerifyOnlySharedBrowserDifferentSubjectApprovesOnly(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: map[string]string{
|
||||
prefixLoginCode + "user@example.com": "flow-123",
|
||||
prefixLoginCodePending + "user@example.com": "pending-123",
|
||||
prefixLoginCodeValue + "pending-123": "569765",
|
||||
}}
|
||||
kratosPublic := newKratosWhoamiTestServer(t, "kratos-other-user")
|
||||
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: &mockIdpProvider{},
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/login/code/verify", h.VerifyLoginCode)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"loginId": "user@example.com",
|
||||
"code": "569765",
|
||||
"pendingRef": "pending-123",
|
||||
"verifyOnly": true,
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Cookie", "ory_kratos_session=shared-browser-session")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Empty(t, resp.Cookies())
|
||||
|
||||
var got map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&got)
|
||||
assert.Equal(t, "approved", got["status"])
|
||||
assert.Nil(t, got["sessionJwt"])
|
||||
assert.Nil(t, got["token"])
|
||||
assert.Contains(t, redis.data[prefixSession+"pending-123"], "approved")
|
||||
}
|
||||
|
||||
func TestVerifyLoginCode_MapsSmsPhoneBeforeFlowLookup(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: map[string]string{
|
||||
prefixLoginCode + "su-@samaneng.com": "flow-123",
|
||||
prefixLoginCodePending + "su-@samaneng.com": "pending-123",
|
||||
prefixLoginCodeSmsLookup + "+821041585840": "su-@samaneng.com",
|
||||
prefixLoginCodeSmsTarget + "su-@samaneng.com": "+821041585840",
|
||||
prefixLoginCodeValue + "pending-123": "569765",
|
||||
}}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: &mockIdpProvider{},
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/login/code/verify", h.VerifyLoginCode)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"loginId": "01041585840",
|
||||
"code": "569765",
|
||||
"pendingRef": "pending-123",
|
||||
"verifyOnly": true,
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
var got map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&got)
|
||||
assert.Equal(t, "approved", got["status"])
|
||||
assert.Equal(t, "pending-123", got["pendingRef"])
|
||||
}
|
||||
|
||||
func TestPollEnchantedLink_ExpiredToken_ReturnsCode(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/enchanted-link/poll", h.PollEnchantedLink)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"pendingRef": "missing-ref",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/enchanted-link/poll", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var got map[string]any
|
||||
json.NewDecoder(resp.Body).Decode(&got)
|
||||
assert.Equal(t, "expired_token", got["error"])
|
||||
assert.Equal(t, "expired_token", got["code"])
|
||||
}
|
||||
|
||||
func TestPollEnchantedLink_SharedBrowserSameSubjectIssuesSession(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: map[string]string{
|
||||
prefixSession + "pending-123": `{"status":"approved","loginId":"user@example.com"}`,
|
||||
}}
|
||||
kratosPublic := newKratosWhoamiTestServer(t, "kratos-user-1")
|
||||
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: &mockIdpProvider{
|
||||
issueSession: &domain.AuthInfo{
|
||||
SessionToken: &domain.Token{JWT: "valid-jwt", SessionID: "new-session-id"},
|
||||
Subject: "kratos-user-1",
|
||||
},
|
||||
},
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/enchanted-link/poll", h.PollEnchantedLink)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"pendingRef": "pending-123"})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/enchanted-link/poll", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Cookie", "ory_kratos_session=shared-browser-session")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
var got map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&got)
|
||||
assert.Equal(t, "ok", got["status"])
|
||||
assert.Equal(t, "valid-jwt", got["sessionJwt"])
|
||||
}
|
||||
|
||||
func TestPollEnchantedLink_SharedBrowserDifferentSubjectConflicts(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: map[string]string{
|
||||
prefixSession + "pending-123": `{"status":"approved","loginId":"user@example.com"}`,
|
||||
}}
|
||||
kratosPublic := newKratosWhoamiTestServer(t, "kratos-other-user")
|
||||
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: &mockIdpProvider{
|
||||
issueSession: &domain.AuthInfo{
|
||||
SessionToken: &domain.Token{JWT: "valid-jwt", SessionID: "new-session-id"},
|
||||
Subject: "kratos-user-1",
|
||||
},
|
||||
},
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/enchanted-link/poll", h.PollEnchantedLink)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"pendingRef": "pending-123"})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/enchanted-link/poll", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Cookie", "ory_kratos_session=shared-browser-session")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusConflict, resp.StatusCode)
|
||||
var got map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&got)
|
||||
assert.Equal(t, "session_subject_conflict", got["code"])
|
||||
assert.NotContains(t, redis.data[prefixSession+"pending-123"], "valid-jwt")
|
||||
}
|
||||
|
||||
func TestHeadlessLinkInit_HeadlessLoginClientSuccess(t *testing.T) {
|
||||
t.Setenv("BACKEND_PUBLIC_URL", "")
|
||||
|
||||
if !testsupport.PortBindingAvailable() {
|
||||
t.Skip("skipping headless link tests because this environment cannot bind local TCP listeners")
|
||||
}
|
||||
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
privateKey, jwks := mustHeadlessRSAJWK(t)
|
||||
jwksBody, _ := json.Marshal(jwks)
|
||||
jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write(jwksBody)
|
||||
}))
|
||||
defer jwksServer.Close()
|
||||
|
||||
idp := &mockIdpProvider{
|
||||
userExists: true,
|
||||
initiateLinkErr: domain.ErrNotSupported,
|
||||
}
|
||||
|
||||
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet {
|
||||
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
|
||||
Challenge: "challenge-123",
|
||||
Client: domain.HydraClient{
|
||||
ClientID: "headless-login-client",
|
||||
TokenEndpointAuthMethod: "none",
|
||||
Metadata: map[string]any{
|
||||
"status": "active",
|
||||
"headless_login_enabled": true,
|
||||
"headless_token_endpoint_auth_method": "private_key_jwt",
|
||||
"headless_jwks_uri": jwksServer.URL + "/.well-known/jwks.json",
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: idp,
|
||||
SmsService: &mockSmsService{},
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
|
||||
},
|
||||
}
|
||||
|
||||
app := newHeadlessLinkTestApp(h)
|
||||
t.Setenv("USERFRONT_URL", "http://userfront.test")
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"client_id": "headless-login-client",
|
||||
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/init"),
|
||||
"loginId": "010-1234-5678",
|
||||
"login_challenge": "challenge-123",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var got map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&got)
|
||||
assert.NotEmpty(t, got["pendingRef"])
|
||||
_, hasUserCode := got["userCode"]
|
||||
assert.False(t, hasUserCode)
|
||||
}
|
||||
|
||||
func TestHeadlessLinkPoll_AfterApprovalReturnsRedirect(t *testing.T) {
|
||||
t.Setenv("BACKEND_PUBLIC_URL", "")
|
||||
|
||||
if !testsupport.PortBindingAvailable() {
|
||||
t.Skip("skipping headless link tests because this environment cannot bind local TCP listeners")
|
||||
}
|
||||
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
privateKey, jwks := mustHeadlessRSAJWK(t)
|
||||
jwksBody, _ := json.Marshal(jwks)
|
||||
|
||||
idp := &mockIdpProvider{
|
||||
userExists: true,
|
||||
initiateLinkErr: domain.ErrNotSupported,
|
||||
}
|
||||
|
||||
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
|
||||
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
|
||||
Challenge: "challenge-123",
|
||||
Client: domain.HydraClient{
|
||||
ClientID: "headless-login-client",
|
||||
ClientName: "local-demo-rp",
|
||||
TokenEndpointAuthMethod: "none",
|
||||
Metadata: map[string]any{
|
||||
"status": "active",
|
||||
"headless_login_enabled": true,
|
||||
"headless_token_endpoint_auth_method": "private_key_jwt",
|
||||
"headless_jwks_uri": "https://rp.example.com/.well-known/jwks.json",
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "+821012345678").Return("kratos-identity-id", nil)
|
||||
auditRepo := &mockAuditRepo{}
|
||||
headlessClient := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Host == "rp.example.com" && r.URL.Path == "/.well-known/jwks.json" {
|
||||
return httpResponse(r, http.StatusOK, string(jwksBody)), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})}
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: idp,
|
||||
SmsService: &mockSmsService{},
|
||||
KratosAdmin: mockKratos,
|
||||
AuditRepo: auditRepo,
|
||||
HeadlessJWKS: service.NewHeadlessJWKSCacheService(nil, headlessClient),
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
|
||||
},
|
||||
}
|
||||
|
||||
app := newHeadlessLinkTestApp(h)
|
||||
t.Setenv("USERFRONT_URL", "http://userfront.test")
|
||||
|
||||
initBody, _ := json.Marshal(map[string]string{
|
||||
"client_id": "headless-login-client",
|
||||
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/init"),
|
||||
"loginId": "010-1234-5678",
|
||||
"login_challenge": "challenge-123",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(initBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var initResp map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&initResp)
|
||||
pendingRef := initResp["pendingRef"].(string)
|
||||
assert.NotEmpty(t, pendingRef)
|
||||
|
||||
var token string
|
||||
for k := range redis.data {
|
||||
if len(k) > 16 && k[:16] == "enchanted_token:" {
|
||||
token = k[16:]
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
verifyBody, _ := json.Marshal(map[string]any{
|
||||
"token": token,
|
||||
"verifyOnly": true,
|
||||
})
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(verifyBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ = app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
pollBody, _ := json.Marshal(map[string]string{
|
||||
"client_id": "headless-login-client",
|
||||
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/poll"),
|
||||
"pendingRef": pendingRef,
|
||||
})
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/poll", bytes.NewReader(pollBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ = app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
var pollResp map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&pollResp)
|
||||
assert.Equal(t, "http://rp/cb", pollResp["redirectTo"])
|
||||
assert.Equal(t, "ok", pollResp["status"])
|
||||
assert.Nil(t, pollResp["sessionJwt"])
|
||||
assert.Nil(t, pollResp["token"])
|
||||
assert.Empty(t, resp.Cookies())
|
||||
if assert.Len(t, auditRepo.logs, 1) {
|
||||
assert.Contains(t, auditRepo.logs[0].EventType, "/api/v1/auth/")
|
||||
details, err := parseAuditDetails(auditRepo.logs[0].Details)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse audit details: %v", err)
|
||||
}
|
||||
assert.Equal(t, "headless-login-client", details["client_id"])
|
||||
assert.Equal(t, "local-demo-rp", details["client_name"])
|
||||
assert.Equal(t, "challenge-123", details["login_challenge"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadlessLinkPoll_ApproverSubjectConflictBlocksMixedRP(t *testing.T) {
|
||||
t.Setenv("BACKEND_PUBLIC_URL", "")
|
||||
|
||||
if !testsupport.PortBindingAvailable() {
|
||||
t.Skip("skipping headless link tests because this environment cannot bind local TCP listeners")
|
||||
}
|
||||
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
privateKey, jwks := mustHeadlessRSAJWK(t)
|
||||
jwksBody, _ := json.Marshal(jwks)
|
||||
acceptCalled := false
|
||||
|
||||
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
|
||||
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
|
||||
Challenge: "challenge-123",
|
||||
Client: domain.HydraClient{
|
||||
ClientID: "headless-login-client",
|
||||
ClientName: "local-demo-rp",
|
||||
TokenEndpointAuthMethod: "none",
|
||||
Metadata: map[string]any{
|
||||
"status": "active",
|
||||
"headless_login_enabled": true,
|
||||
"headless_token_endpoint_auth_method": "private_key_jwt",
|
||||
"headless_jwks_uri": "https://rp.example.com/.well-known/jwks.json",
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
|
||||
acceptCalled = true
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "+821012345678").Return("kratos-target-b", nil)
|
||||
auditRepo := &mockAuditRepo{}
|
||||
headlessClient := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Host == "rp.example.com" && r.URL.Path == "/.well-known/jwks.json" {
|
||||
return httpResponse(r, http.StatusOK, string(jwksBody)), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})}
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: &mockIdpProvider{
|
||||
userExists: true,
|
||||
initiateLinkErr: domain.ErrNotSupported,
|
||||
},
|
||||
SmsService: &mockSmsService{},
|
||||
KratosAdmin: mockKratos,
|
||||
AuditRepo: auditRepo,
|
||||
HeadlessJWKS: service.NewHeadlessJWKSCacheService(nil, headlessClient),
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
|
||||
},
|
||||
}
|
||||
|
||||
app := newHeadlessLinkTestApp(h)
|
||||
t.Setenv("USERFRONT_URL", "http://userfront.test")
|
||||
|
||||
initBody, _ := json.Marshal(map[string]string{
|
||||
"client_id": "headless-login-client",
|
||||
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/init"),
|
||||
"loginId": "010-1234-5678",
|
||||
"login_challenge": "challenge-123",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(initBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var initResp map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&initResp)
|
||||
pendingRef := initResp["pendingRef"].(string)
|
||||
assert.NotEmpty(t, pendingRef)
|
||||
|
||||
var token string
|
||||
for k := range redis.data {
|
||||
if len(k) > 16 && k[:16] == "enchanted_token:" {
|
||||
token = k[16:]
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
kratosPublic := newKratosWhoamiTestServer(t, "kratos-userfront-a")
|
||||
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
|
||||
|
||||
verifyBody, _ := json.Marshal(map[string]any{
|
||||
"token": token,
|
||||
"verifyOnly": true,
|
||||
})
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(verifyBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Cookie", "ory_kratos_session=userfront-a-session")
|
||||
resp, _ = app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
pollBody, _ := json.Marshal(map[string]string{
|
||||
"client_id": "headless-login-client",
|
||||
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/poll"),
|
||||
"pendingRef": pendingRef,
|
||||
})
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/poll", bytes.NewReader(pollBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ = app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusConflict, resp.StatusCode)
|
||||
assert.False(t, acceptCalled)
|
||||
assert.Empty(t, resp.Cookies())
|
||||
var got map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&got)
|
||||
assert.Equal(t, "oidc_subject_conflict", got["code"])
|
||||
assert.Equal(t, "redirect_to_userfront_login", got["recommendedAction"])
|
||||
assert.Equal(t, "kratos-userfront-a", got["currentSubject"])
|
||||
assert.Equal(t, "kratos-target-b", got["targetSubject"])
|
||||
assert.Empty(t, auditRepo.logs)
|
||||
}
|
||||
|
||||
func TestHeadlessLinkPoll_RequestCookieSubjectConflictBlocksMixedRP(t *testing.T) {
|
||||
t.Setenv("BACKEND_PUBLIC_URL", "")
|
||||
|
||||
if !testsupport.PortBindingAvailable() {
|
||||
t.Skip("skipping headless link tests because this environment cannot bind local TCP listeners")
|
||||
}
|
||||
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
privateKey, jwks := mustHeadlessRSAJWK(t)
|
||||
jwksBody, _ := json.Marshal(jwks)
|
||||
acceptCalled := false
|
||||
|
||||
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
|
||||
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
|
||||
Challenge: "challenge-123",
|
||||
Client: domain.HydraClient{
|
||||
ClientID: "headless-login-client",
|
||||
TokenEndpointAuthMethod: "none",
|
||||
Metadata: map[string]any{
|
||||
"status": "active",
|
||||
"headless_login_enabled": true,
|
||||
"headless_token_endpoint_auth_method": "private_key_jwt",
|
||||
"headless_jwks_uri": "https://rp.example.com/.well-known/jwks.json",
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
|
||||
acceptCalled = true
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "+821012345678").Return("kratos-target-b", nil)
|
||||
headlessClient := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Host == "rp.example.com" && r.URL.Path == "/.well-known/jwks.json" {
|
||||
return httpResponse(r, http.StatusOK, string(jwksBody)), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})}
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: &mockIdpProvider{
|
||||
userExists: true,
|
||||
initiateLinkErr: domain.ErrNotSupported,
|
||||
},
|
||||
SmsService: &mockSmsService{},
|
||||
KratosAdmin: mockKratos,
|
||||
HeadlessJWKS: service.NewHeadlessJWKSCacheService(nil, headlessClient),
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
|
||||
},
|
||||
}
|
||||
|
||||
app := newHeadlessLinkTestApp(h)
|
||||
t.Setenv("USERFRONT_URL", "http://userfront.test")
|
||||
|
||||
initBody, _ := json.Marshal(map[string]string{
|
||||
"client_id": "headless-login-client",
|
||||
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/init"),
|
||||
"loginId": "010-1234-5678",
|
||||
"login_challenge": "challenge-123",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(initBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var initResp map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&initResp)
|
||||
pendingRef := initResp["pendingRef"].(string)
|
||||
assert.NotEmpty(t, pendingRef)
|
||||
|
||||
var token string
|
||||
for k := range redis.data {
|
||||
if len(k) > 16 && k[:16] == "enchanted_token:" {
|
||||
token = k[16:]
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
verifyBody, _ := json.Marshal(map[string]any{
|
||||
"token": token,
|
||||
"verifyOnly": true,
|
||||
})
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(verifyBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ = app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
kratosPublic := newKratosWhoamiTestServer(t, "kratos-userfront-a")
|
||||
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
|
||||
|
||||
pollBody, _ := json.Marshal(map[string]string{
|
||||
"client_id": "headless-login-client",
|
||||
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/poll"),
|
||||
"pendingRef": pendingRef,
|
||||
})
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/poll", bytes.NewReader(pollBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Cookie", "ory_kratos_session=userfront-a-session")
|
||||
resp, _ = app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusConflict, resp.StatusCode)
|
||||
assert.False(t, acceptCalled)
|
||||
assert.Empty(t, resp.Cookies())
|
||||
var got map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&got)
|
||||
assert.Equal(t, "oidc_subject_conflict", got["code"])
|
||||
assert.Equal(t, "kratos-userfront-a", got["currentSubject"])
|
||||
assert.Equal(t, "kratos-target-b", got["targetSubject"])
|
||||
}
|
||||
287
baron-sso/backend/internal/handler/auth_handler_linked_test.go
Normal file
287
baron-sso/backend/internal/handler/auth_handler_linked_test.go
Normal file
@@ -0,0 +1,287 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// --- Helper ---
|
||||
|
||||
func newLinkedRpTestApp(h *AuthHandler) *fiber.App {
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/user/rp/linked", h.ListLinkedRps)
|
||||
return app
|
||||
}
|
||||
|
||||
// --- Tests ---
|
||||
|
||||
func TestListLinkedRps_PriorityAndAggregation(t *testing.T) {
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
switch r.URL.Host {
|
||||
case "kratos.test":
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
if r.Header.Get("X-Session-Token") == "" && r.Header.Get("Cookie") == "" {
|
||||
return httpResponse(r, http.StatusUnauthorized, "unauthorized"), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@test.com",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
case "hydra.test":
|
||||
if r.URL.Path == "/oauth2/auth/sessions/consent" {
|
||||
return httpJSONAny(r, http.StatusOK, []map[string]any{
|
||||
{
|
||||
"client": map[string]any{
|
||||
"client_id": "devfront",
|
||||
"client_name": "DevFront",
|
||||
"redirect_uris": []string{
|
||||
"https://active.example.com/callback",
|
||||
},
|
||||
},
|
||||
"grant_scope": []string{"openid", "profile"},
|
||||
"handled_at": time.Now().Format(time.RFC3339),
|
||||
},
|
||||
{
|
||||
"client": map[string]any{
|
||||
"client_id": "orgfront",
|
||||
"client_name": "OrgFront",
|
||||
"metadata": map[string]any{
|
||||
"auto_login_supported": true,
|
||||
"auto_login_url": "http://localhost:5175/login",
|
||||
},
|
||||
"redirect_uris": []string{
|
||||
"http://localhost:5175/auth/callback",
|
||||
},
|
||||
},
|
||||
"grant_scope": []string{"openid", "profile"},
|
||||
"handled_at": time.Now().Format(time.RFC3339),
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Path == "/admin/clients/client-audit" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": "client-audit",
|
||||
"client_name": "Audit App",
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Path == "/admin/clients/client-consent" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": "client-consent",
|
||||
"client_name": "Consent App",
|
||||
}), nil
|
||||
}
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() {
|
||||
http.DefaultClient = origDefault
|
||||
}()
|
||||
|
||||
auditRepo := &mockAuditRepo{
|
||||
logs: []domain.AuditLog{
|
||||
{
|
||||
UserID: "user-123",
|
||||
EventType: "consent.granted",
|
||||
Timestamp: time.Now().Add(-10 * time.Hour),
|
||||
Details: `{"client_id":"client-audit", "scopes":["audit_scope"]}`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
consentRepo := &mockConsentRepo{
|
||||
consents: []domain.ClientConsent{
|
||||
{
|
||||
Subject: "user-123",
|
||||
ClientID: "client-consent",
|
||||
GrantedScopes: []string{"consent_scope"},
|
||||
UpdatedAt: time.Now().Add(-2 * time.Hour),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
AuditRepo: auditRepo,
|
||||
ConsentRepo: consentRepo,
|
||||
KratosAdmin: new(MockKratosAdminService),
|
||||
}
|
||||
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
t.Setenv("KRATOS_ADMIN_URL", "http://kratos.test")
|
||||
t.Setenv("HYDRA_PUBLIC_URL", "https://sso.example.com/oidc")
|
||||
t.Setenv("DEVFRONT_URL", "http://localhost:5174")
|
||||
|
||||
app := newLinkedRpTestApp(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/rp/linked", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var res struct {
|
||||
Items []struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
Scopes []string `json:"scopes"`
|
||||
InitURL string `json:"init_url"`
|
||||
AutoLoginSupported bool `json:"auto_login_supported"`
|
||||
AutoLoginURL string `json:"auto_login_url"`
|
||||
} `json:"items"`
|
||||
}
|
||||
json.NewDecoder(resp.Body).Decode(&res)
|
||||
|
||||
assert.Equal(t, 4, len(res.Items))
|
||||
|
||||
statusMap := make(map[string]string)
|
||||
for _, item := range res.Items {
|
||||
statusMap[item.ID] = item.Status
|
||||
}
|
||||
|
||||
assert.Equal(t, "active", statusMap["devfront"])
|
||||
assert.Equal(t, "active", statusMap["orgfront"])
|
||||
assert.Equal(t, "inactive", statusMap["client-consent"])
|
||||
assert.Equal(t, "inactive", statusMap["client-audit"])
|
||||
|
||||
var activeInitURL string
|
||||
for _, item := range res.Items {
|
||||
if item.ID == "devfront" {
|
||||
activeInitURL = item.InitURL
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
parsedInitURL, err := url.Parse(activeInitURL)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "http", parsedInitURL.Scheme)
|
||||
assert.Equal(t, "localhost:5174", parsedInitURL.Host)
|
||||
assert.Equal(t, "/login", parsedInitURL.Path)
|
||||
assert.Equal(t, "1", parsedInitURL.Query().Get("auto"))
|
||||
assert.Equal(t, "/clients", parsedInitURL.Query().Get("returnTo"))
|
||||
|
||||
var orgfrontItem struct {
|
||||
InitURL string
|
||||
AutoLoginSupported bool
|
||||
AutoLoginURL string
|
||||
}
|
||||
for _, item := range res.Items {
|
||||
if item.ID == "orgfront" {
|
||||
orgfrontItem.InitURL = item.InitURL
|
||||
orgfrontItem.AutoLoginSupported = item.AutoLoginSupported
|
||||
orgfrontItem.AutoLoginURL = item.AutoLoginURL
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, orgfrontItem.AutoLoginSupported)
|
||||
assert.Equal(t, "http://localhost:5175/login?auto=1", orgfrontItem.AutoLoginURL)
|
||||
assert.Equal(t, orgfrontItem.AutoLoginURL, orgfrontItem.InitURL)
|
||||
}
|
||||
|
||||
func TestListLinkedRps_EnrichesLogoFromHydraClientWhenConsentSessionOmitsMetadata(t *testing.T) {
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
switch r.URL.Host {
|
||||
case "kratos.test":
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
case "hydra.test":
|
||||
if r.URL.Path == "/oauth2/auth/sessions/consent" {
|
||||
return httpJSONAny(r, http.StatusOK, []map[string]any{
|
||||
{
|
||||
"client": map[string]any{
|
||||
"client_id": "gitea-client",
|
||||
"client_name": "Gitea",
|
||||
"redirect_uris": []string{
|
||||
"https://gitea.example.com/callback",
|
||||
},
|
||||
},
|
||||
"grant_scope": []string{"openid", "profile"},
|
||||
"handled_at": time.Now().Format(time.RFC3339),
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Path == "/clients/gitea-client" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": "gitea-client",
|
||||
"client_name": "Gitea",
|
||||
"redirect_uris": []string{
|
||||
"https://gitea.example.com/callback",
|
||||
},
|
||||
"metadata": map[string]any{
|
||||
"logo_url": "https://cdn.example.com/gitea.svg",
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() {
|
||||
http.DefaultClient = origDefault
|
||||
}()
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
KratosAdmin: new(MockKratosAdminService),
|
||||
}
|
||||
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
t.Setenv("KRATOS_ADMIN_URL", "http://kratos.test")
|
||||
t.Setenv("HYDRA_PUBLIC_URL", "https://sso.example.com/oidc")
|
||||
|
||||
app := newLinkedRpTestApp(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/rp/linked", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var res struct {
|
||||
Items []struct {
|
||||
ID string `json:"id"`
|
||||
Logo string `json:"logo"`
|
||||
} `json:"items"`
|
||||
}
|
||||
json.NewDecoder(resp.Body).Decode(&res)
|
||||
|
||||
assert.Len(t, res.Items, 1)
|
||||
assert.Equal(t, "gitea-client", res.Items[0].ID)
|
||||
assert.Equal(t, "https://cdn.example.com/gitea.svg", res.Items[0].Logo)
|
||||
}
|
||||
@@ -0,0 +1,206 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func newVerifyLoginCodeTestApp(h *AuthHandler) *fiber.App {
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/login/code/verify", h.VerifyLoginCode)
|
||||
app.Post("/api/v1/auth/login/code/verify-short", h.VerifyLoginShortCode)
|
||||
return app
|
||||
}
|
||||
|
||||
func decodeJSONBody(t *testing.T, resp *http.Response) map[string]any {
|
||||
t.Helper()
|
||||
|
||||
var got map[string]any
|
||||
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
|
||||
t.Fatalf("failed to decode response body: %v", err)
|
||||
}
|
||||
return got
|
||||
}
|
||||
|
||||
func TestVerifyLoginCode_InvalidBody_ReturnsExplicitCode(t *testing.T) {
|
||||
h := &AuthHandler{}
|
||||
app := newVerifyLoginCodeTestApp(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify", bytes.NewBufferString("{"))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
got := decodeJSONBody(t, resp)
|
||||
if got["code"] != "bad_request" {
|
||||
t.Fatalf("expected code=bad_request, got %v", got["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyLoginCode_IdpUnavailable_ReturnsExplicitCode(t *testing.T) {
|
||||
h := &AuthHandler{}
|
||||
app := newVerifyLoginCodeTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"loginId": "user@example.com",
|
||||
"code": "AA-111111",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusServiceUnavailable {
|
||||
t.Fatalf("expected 503, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
got := decodeJSONBody(t, resp)
|
||||
if got["code"] != "service_unavailable" {
|
||||
t.Fatalf("expected code=service_unavailable, got %v", got["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyLoginCode_VerifyOnlyInvalidCode_ReturnsExplicitCode(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
redis.data[prefixLoginCode+"user@example.com"] = "flow-1"
|
||||
redis.data[prefixLoginCodePending+"user@example.com"] = "pending-1"
|
||||
redis.data[prefixLoginCodeValue+"pending-1"] = "AB-123"
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: &mockIdpProvider{},
|
||||
}
|
||||
app := newVerifyLoginCodeTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"loginId": "user@example.com",
|
||||
"code": "ZZ-999",
|
||||
"verifyOnly": true,
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
got := decodeJSONBody(t, resp)
|
||||
if got["code"] != "invalid_code" {
|
||||
t.Fatalf("expected code=invalid_code, got %v", got["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyLoginShortCode_MissingShortCode_ReturnsExplicitCode(t *testing.T) {
|
||||
h := &AuthHandler{
|
||||
RedisService: &mockRedisRepo{data: make(map[string]string)},
|
||||
}
|
||||
app := newVerifyLoginCodeTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"shortCode": "",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify-short", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
got := decodeJSONBody(t, resp)
|
||||
if got["code"] != "bad_request" {
|
||||
t.Fatalf("expected code=bad_request, got %v", got["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyLoginShortCode_InvalidOrExpired_ReturnsExplicitCode(t *testing.T) {
|
||||
h := &AuthHandler{
|
||||
RedisService: &mockRedisRepo{data: make(map[string]string)},
|
||||
}
|
||||
app := newVerifyLoginCodeTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"shortCode": "AB-123456",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify-short", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
got := decodeJSONBody(t, resp)
|
||||
if got["code"] != "invalid_or_expired_code" {
|
||||
t.Fatalf("expected code=invalid_or_expired_code, got %v", got["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyLoginShortCode_VerifyOnlyMissingPendingRef_ReturnsExplicitCode(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
payload, _ := json.Marshal(shortLoginCodePayload{
|
||||
LoginID: "user@example.com",
|
||||
Code: "AB-123",
|
||||
})
|
||||
redis.data[prefixLoginCodeShort+"AB-123456"] = string(payload)
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
}
|
||||
app := newVerifyLoginCodeTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"shortCode": "AB-123456",
|
||||
"verifyOnly": true,
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify-short", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
got := decodeJSONBody(t, resp)
|
||||
if got["code"] != "invalid_session_reference" {
|
||||
t.Fatalf("expected code=invalid_session_reference, got %v", got["code"])
|
||||
}
|
||||
}
|
||||
2398
baron-sso/backend/internal/handler/auth_handler_login_test.go
Normal file
2398
baron-sso/backend/internal/handler/auth_handler_login_test.go
Normal file
File diff suppressed because it is too large
Load Diff
178
baron-sso/backend/internal/handler/auth_handler_oidc_test.go
Normal file
178
baron-sso/backend/internal/handler/auth_handler_oidc_test.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func newOidcLoginTestApp(h *AuthHandler) *fiber.App {
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/oidc/login/accept", h.AcceptOidcLoginRequest)
|
||||
return app
|
||||
}
|
||||
|
||||
func TestAcceptOidcLoginRequest_CookieOnly(t *testing.T) {
|
||||
var gotSubject string
|
||||
var gotChallenge string
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
switch r.URL.Host {
|
||||
case "kratos.test":
|
||||
if r.URL.Path != "/sessions/whoami" {
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}
|
||||
if r.Header.Get("X-Session-Token") != "" {
|
||||
return httpResponse(r, http.StatusUnauthorized, "invalid token"), nil
|
||||
}
|
||||
if r.Header.Get("Cookie") == "" {
|
||||
return httpResponse(r, http.StatusUnauthorized, "missing cookie"), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{
|
||||
"id": "kratos-123",
|
||||
"traits": map[string]any{},
|
||||
},
|
||||
}), nil
|
||||
case "hydra.test":
|
||||
if r.URL.Path != "/oauth2/auth/requests/login/accept" {
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}
|
||||
gotChallenge = r.URL.Query().Get("login_challenge")
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var payload map[string]any
|
||||
_ = json.Unmarshal(body, &payload)
|
||||
if subject, ok := payload["subject"].(string); ok {
|
||||
gotSubject = subject
|
||||
}
|
||||
return httpResponse(r, http.StatusOK, `{"redirect_to":"http://rp/cb"}`), nil
|
||||
default:
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}
|
||||
})
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() {
|
||||
http.DefaultClient = origDefault
|
||||
}()
|
||||
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
}
|
||||
app := newOidcLoginTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"login_challenge": "challenge-123",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oidc/login/accept", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Cookie", "ory_kratos_session=abc123")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
var got map[string]string
|
||||
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if got["redirectTo"] != "http://rp/cb" {
|
||||
t.Fatalf("unexpected redirectTo: %v", got["redirectTo"])
|
||||
}
|
||||
if gotSubject != "kratos-123" {
|
||||
t.Fatalf("unexpected subject: %v", gotSubject)
|
||||
}
|
||||
if gotChallenge != "challenge-123" {
|
||||
t.Fatalf("unexpected login_challenge: %v", gotChallenge)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAcceptOidcLoginRequest_TokenFallbackToCookie(t *testing.T) {
|
||||
var gotSubject string
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
switch r.URL.Host {
|
||||
case "kratos.test":
|
||||
if r.URL.Path != "/sessions/whoami" {
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}
|
||||
if r.Header.Get("X-Session-Token") != "" {
|
||||
return httpResponse(r, http.StatusUnauthorized, "invalid token"), nil
|
||||
}
|
||||
if r.Header.Get("Cookie") == "" {
|
||||
return httpResponse(r, http.StatusUnauthorized, "missing cookie"), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{
|
||||
"id": "kratos-456",
|
||||
"traits": map[string]any{},
|
||||
},
|
||||
}), nil
|
||||
case "hydra.test":
|
||||
if r.URL.Path != "/oauth2/auth/requests/login/accept" {
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var payload map[string]any
|
||||
_ = json.Unmarshal(body, &payload)
|
||||
if subject, ok := payload["subject"].(string); ok {
|
||||
gotSubject = subject
|
||||
}
|
||||
return httpResponse(r, http.StatusOK, `{"redirect_to":"http://rp/cb"}`), nil
|
||||
default:
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}
|
||||
})
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() {
|
||||
http.DefaultClient = origDefault
|
||||
}()
|
||||
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
}
|
||||
app := newOidcLoginTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"login_challenge": "challenge-456",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oidc/login/accept", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer invalid-token")
|
||||
req.Header.Set("Cookie", "ory_kratos_session=def456")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
if gotSubject != "kratos-456" {
|
||||
t.Fatalf("unexpected subject: %v", gotSubject)
|
||||
}
|
||||
}
|
||||
110
baron-sso/backend/internal/handler/auth_handler_otp_test.go
Normal file
110
baron-sso/backend/internal/handler/auth_handler_otp_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestHandleKratosCourierRelay_Email(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
emailSvc := &mockEmailService{}
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
EmailService: emailSvc,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/kratos/courier", h.HandleKratosCourierRelay)
|
||||
|
||||
// Simulate Kratos Courier Request for Email
|
||||
reqBody := map[string]any{
|
||||
"recipient": "user@example.com",
|
||||
"template_type": "verification_code",
|
||||
"template_data": map[string]any{
|
||||
"verification_code": "123456",
|
||||
},
|
||||
"subject": "Verify your email",
|
||||
"body": "Your code is 123456",
|
||||
}
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/kratos/courier", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestVerifySignupCode_Success(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/signup/verify", h.VerifySignupCode)
|
||||
|
||||
// Mock stored code in redis
|
||||
// signup:email:user@test.com -> {"code":"654321", "verified":false, "expires_at":...}
|
||||
state := map[string]any{
|
||||
"code": "654321",
|
||||
"verified": false,
|
||||
"expires_at": 9999999999, // far future
|
||||
}
|
||||
stateJSON, _ := json.Marshal(state)
|
||||
redis.data["signup:email:user@test.com"] = string(stateJSON)
|
||||
|
||||
// Verify Code
|
||||
verifyBody := map[string]string{
|
||||
"type": "email",
|
||||
"target": "user@test.com",
|
||||
"code": "654321",
|
||||
}
|
||||
body, _ := json.Marshal(verifyBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/signup/verify", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var res map[string]any
|
||||
json.NewDecoder(resp.Body).Decode(&res)
|
||||
assert.True(t, res["success"].(bool))
|
||||
|
||||
// Check redis state updated to verified
|
||||
val, _ := redis.Get("signup:email:user@test.com")
|
||||
var updatedState map[string]any
|
||||
json.Unmarshal([]byte(val), &updatedState)
|
||||
assert.True(t, updatedState["verified"].(bool))
|
||||
}
|
||||
|
||||
func TestVerifySignupCode_Invalid(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/signup/verify", h.VerifySignupCode)
|
||||
|
||||
stateJSON, _ := json.Marshal(map[string]any{
|
||||
"code": "111111",
|
||||
"expires_at": 9999999999,
|
||||
})
|
||||
redis.data["signup:email:user@test.com"] = string(stateJSON)
|
||||
|
||||
verifyBody := map[string]string{
|
||||
"type": "email",
|
||||
"target": "user@test.com",
|
||||
"code": "222222", // wrong code
|
||||
}
|
||||
body, _ := json.Marshal(verifyBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/signup/verify", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
@@ -0,0 +1,244 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type recordingUpdateMeUserRepo struct {
|
||||
MockUserRepoForHandler
|
||||
updated *domain.User
|
||||
loginIDs []domain.UserLoginID
|
||||
}
|
||||
|
||||
func (r *recordingUpdateMeUserRepo) Update(ctx context.Context, user *domain.User) error {
|
||||
copied := *user
|
||||
r.updated = &copied
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *recordingUpdateMeUserRepo) UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error {
|
||||
r.loginIDs = append([]domain.UserLoginID(nil), loginIDs...)
|
||||
return nil
|
||||
}
|
||||
|
||||
type recordingUpdateMeKratosAdmin struct {
|
||||
MockKratosAdminService
|
||||
updatedIdentityID string
|
||||
updatedTraits map[string]any
|
||||
updatedState string
|
||||
storedTraits map[string]any
|
||||
}
|
||||
|
||||
func (r *recordingUpdateMeKratosAdmin) UpdateIdentity(ctx context.Context, identityID string, traits map[string]any, state string) (*service.KratosIdentity, error) {
|
||||
r.updatedIdentityID = identityID
|
||||
r.updatedTraits = maps.Clone(traits)
|
||||
r.updatedState = state
|
||||
if r.storedTraits != nil {
|
||||
maps.Copy(r.storedTraits, traits)
|
||||
}
|
||||
return &service.KratosIdentity{
|
||||
ID: identityID,
|
||||
Traits: traits,
|
||||
State: state,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestUpdateMe_InvalidatesProfileCacheForTokenSession(t *testing.T) {
|
||||
token := "token-abc"
|
||||
identityID := "user-1"
|
||||
traits := map[string]any{
|
||||
"email": "qa@example.com",
|
||||
"name": "QA User",
|
||||
"phone_number": "+821012345678",
|
||||
"department": "Old Dept",
|
||||
"affiliationType": "employee",
|
||||
"companyCode": "",
|
||||
"role": domain.RoleUser,
|
||||
}
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
switch {
|
||||
case r.URL.Host == "kratos.test" &&
|
||||
r.URL.Path == "/sessions/whoami" &&
|
||||
r.Method == http.MethodGet:
|
||||
if r.Header.Get("X-Session-Token") != token {
|
||||
return httpResponse(r, http.StatusUnauthorized, `{"error":"invalid token"}`), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{
|
||||
"id": identityID,
|
||||
"traits": traits,
|
||||
},
|
||||
}), nil
|
||||
|
||||
case r.URL.Host == "kratos.test" &&
|
||||
r.URL.Path == "/admin/identities/"+identityID &&
|
||||
r.Method == http.MethodPut:
|
||||
var payload struct {
|
||||
Traits map[string]any `json:"traits"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
return httpResponse(r, http.StatusBadRequest, `{"error":"invalid body"}`), nil
|
||||
}
|
||||
maps.Copy(traits, payload.Traits)
|
||||
return httpResponse(r, http.StatusOK, `{"ok":true}`), nil
|
||||
}
|
||||
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
setDefaultHTTPClientForTest(t, transport)
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
t.Setenv("KRATOS_ADMIN_URL", "http://kratos.test")
|
||||
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
kratosAdmin := &recordingUpdateMeKratosAdmin{storedTraits: traits}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
KratosAdmin: kratosAdmin,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/user/me", h.GetMe)
|
||||
app.Put("/api/v1/user/me", h.UpdateMe)
|
||||
|
||||
// 1) 첫 조회로 Old Dept가 캐시에 저장됨
|
||||
getReq1 := httptest.NewRequest(http.MethodGet, "/api/v1/user/me", nil)
|
||||
getReq1.Header.Set("Authorization", "Bearer "+token)
|
||||
getResp1, err := app.Test(getReq1, -1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, getResp1.StatusCode)
|
||||
var profile1 map[string]any
|
||||
require.NoError(t, json.NewDecoder(getResp1.Body).Decode(&profile1))
|
||||
require.Equal(t, "Old Dept", profile1["department"])
|
||||
|
||||
// 2) 소속을 New Dept로 변경
|
||||
updateBody, _ := json.Marshal(map[string]string{
|
||||
"name": "QA User",
|
||||
"phone": "01012345678",
|
||||
"department": "New Dept",
|
||||
})
|
||||
updateReq := httptest.NewRequest(
|
||||
http.MethodPut,
|
||||
"/api/v1/user/me",
|
||||
bytes.NewReader(updateBody),
|
||||
)
|
||||
updateReq.Header.Set("Content-Type", "application/json")
|
||||
updateReq.Header.Set("Authorization", "Bearer "+token)
|
||||
updateResp, err := app.Test(updateReq, -1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, updateResp.StatusCode)
|
||||
require.Equal(t, "New Dept", traits["department"])
|
||||
require.Equal(t, identityID, kratosAdmin.updatedIdentityID)
|
||||
require.Equal(t, "New Dept", kratosAdmin.updatedTraits["department"])
|
||||
|
||||
// 3) 새로고침 재조회 시 New Dept가 보여야 함(캐시 무효화 회귀 방지)
|
||||
getReq2 := httptest.NewRequest(http.MethodGet, "/api/v1/user/me", nil)
|
||||
getReq2.Header.Set("Authorization", "Bearer "+token)
|
||||
getResp2, err := app.Test(getReq2, -1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, getResp2.StatusCode)
|
||||
var profile2 map[string]any
|
||||
require.NoError(t, json.NewDecoder(getResp2.Body).Decode(&profile2))
|
||||
require.Equal(t, "New Dept", profile2["department"])
|
||||
}
|
||||
|
||||
func TestUpdateMe_SyncsLocalReadModelFields(t *testing.T) {
|
||||
token := "token-sync"
|
||||
identityID := "user-sync"
|
||||
traits := map[string]any{
|
||||
"email": "sync@example.com",
|
||||
"name": "Old Name",
|
||||
"phone_number": "+821012345678",
|
||||
"department": "Old Dept",
|
||||
"affiliationType": "employee",
|
||||
"companyCode": "saman",
|
||||
"tenant_id": "11111111-1111-1111-1111-111111111111",
|
||||
"role": domain.RoleUser,
|
||||
}
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
switch {
|
||||
case r.URL.Host == "kratos.test" &&
|
||||
r.URL.Path == "/sessions/whoami" &&
|
||||
r.Method == http.MethodGet:
|
||||
if r.Header.Get("X-Session-Token") != token {
|
||||
return httpResponse(r, http.StatusUnauthorized, `{"error":"invalid token"}`), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{
|
||||
"id": identityID,
|
||||
"traits": traits,
|
||||
},
|
||||
}), nil
|
||||
|
||||
case r.URL.Host == "kratos.test" &&
|
||||
r.URL.Path == "/admin/identities/"+identityID &&
|
||||
r.Method == http.MethodPut:
|
||||
var payload struct {
|
||||
Traits map[string]any `json:"traits"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
return httpResponse(r, http.StatusBadRequest, `{"error":"invalid body"}`), nil
|
||||
}
|
||||
maps.Copy(traits, payload.Traits)
|
||||
return httpResponse(r, http.StatusOK, `{"ok":true}`), nil
|
||||
}
|
||||
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
setDefaultHTTPClientForTest(t, transport)
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
t.Setenv("KRATOS_ADMIN_URL", "http://kratos.test")
|
||||
|
||||
redis := &mockRedisRepo{data: map[string]string{
|
||||
"verify_update_phone:" + identityID + ":+821087654321": "verified",
|
||||
}}
|
||||
userRepo := &recordingUpdateMeUserRepo{}
|
||||
kratosAdmin := &recordingUpdateMeKratosAdmin{storedTraits: traits}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
UserRepo: userRepo,
|
||||
KratosAdmin: kratosAdmin,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Put("/api/v1/user/me", h.UpdateMe)
|
||||
|
||||
updateBody, _ := json.Marshal(map[string]any{
|
||||
"name": "New Name",
|
||||
"phone": "01087654321",
|
||||
"department": "New Dept",
|
||||
})
|
||||
updateReq := httptest.NewRequest(
|
||||
http.MethodPut,
|
||||
"/api/v1/user/me",
|
||||
bytes.NewReader(updateBody),
|
||||
)
|
||||
updateReq.Header.Set("Content-Type", "application/json")
|
||||
updateReq.Header.Set("Authorization", "Bearer "+token)
|
||||
updateResp, err := app.Test(updateReq, -1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, updateResp.StatusCode)
|
||||
require.Equal(t, identityID, kratosAdmin.updatedIdentityID)
|
||||
require.Equal(t, "New Name", kratosAdmin.updatedTraits["name"])
|
||||
require.Equal(t, "+821087654321", kratosAdmin.updatedTraits["phone_number"])
|
||||
|
||||
require.NotNil(t, userRepo.updated)
|
||||
require.Equal(t, identityID, userRepo.updated.ID)
|
||||
require.Equal(t, "sync@example.com", userRepo.updated.Email)
|
||||
require.Equal(t, "New Name", userRepo.updated.Name)
|
||||
require.Equal(t, "+821087654321", userRepo.updated.Phone)
|
||||
require.Equal(t, "New Dept", userRepo.updated.Department)
|
||||
require.Empty(t, userRepo.updated.CompanyCode)
|
||||
require.NotNil(t, userRepo.updated.TenantID)
|
||||
require.Equal(t, "11111111-1111-1111-1111-111111111111", *userRepo.updated.TenantID)
|
||||
}
|
||||
206
baron-sso/backend/internal/handler/auth_handler_qr_test.go
Normal file
206
baron-sso/backend/internal/handler/auth_handler_qr_test.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// --- Mock Redis ---
|
||||
|
||||
type mockRedisRepo struct {
|
||||
data map[string]string
|
||||
}
|
||||
|
||||
func (m *mockRedisRepo) Set(key, value string, ttl time.Duration) error {
|
||||
if m.data == nil {
|
||||
m.data = make(map[string]string)
|
||||
}
|
||||
m.data[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockRedisRepo) Get(key string) (string, error) {
|
||||
// Bypass rate limiting for tests
|
||||
if strings.HasPrefix(key, "poll_meta:") {
|
||||
return "", nil
|
||||
}
|
||||
return m.data[key], nil
|
||||
}
|
||||
|
||||
func (m *mockRedisRepo) Delete(key string) error {
|
||||
delete(m.data, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockRedisRepo) StoreVerificationCode(phone, code string) error {
|
||||
return m.Set("sms:"+phone, code, time.Minute)
|
||||
}
|
||||
|
||||
func (m *mockRedisRepo) GetVerificationCode(phone string) (string, error) {
|
||||
return m.Get("sms:" + phone)
|
||||
}
|
||||
|
||||
func (m *mockRedisRepo) DeleteVerificationCode(phone string) error {
|
||||
return m.Delete("sms:" + phone)
|
||||
}
|
||||
|
||||
// --- Tests ---
|
||||
|
||||
func TestQRLoginFlow_Success(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/qr/init", h.InitQRLogin)
|
||||
app.Post("/api/v1/auth/qr/poll", h.PollQRLogin)
|
||||
|
||||
// 1. Init QR Login
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/qr/init", nil)
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var initResp map[string]any
|
||||
json.NewDecoder(resp.Body).Decode(&initResp)
|
||||
pendingRef := initResp["pendingRef"].(string)
|
||||
|
||||
// 2. Poll (Pending)
|
||||
body, _ := json.Marshal(map[string]string{"pendingRef": pendingRef})
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/qr/poll", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ = app.Test(req, -1)
|
||||
|
||||
// Expect authorization_pending (400)
|
||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
var pollResp map[string]any
|
||||
json.NewDecoder(resp.Body).Decode(&pollResp)
|
||||
assert.Equal(t, "authorization_pending", pollResp["error"])
|
||||
assert.Equal(t, "authorization_pending", pollResp["code"])
|
||||
|
||||
// 3. Mock Approval
|
||||
sessionData, _ := json.Marshal(map[string]string{
|
||||
"status": "success",
|
||||
"jwt": "mock-session-jwt",
|
||||
})
|
||||
redis.data["enchanted_session:"+pendingRef] = string(sessionData)
|
||||
|
||||
// 4. Poll (Success)
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/qr/poll", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ = app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var successResp map[string]any
|
||||
json.NewDecoder(resp.Body).Decode(&successResp)
|
||||
assert.Equal(t, "ok", successResp["status"])
|
||||
assert.Equal(t, "mock-session-jwt", successResp["sessionJwt"])
|
||||
}
|
||||
|
||||
func TestScanQRLogin_Success(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
idp := &mockIdpProvider{userExists: true}
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: idp,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/qr/approve", h.ScanQRLogin)
|
||||
|
||||
pendingRef := "test-ref"
|
||||
redis.data["enchanted_session:"+pendingRef] = `{"status":"pending"}`
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = &http.Client{Transport: transport}
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"pendingRef": pendingRef,
|
||||
"token": "valid-token",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/qr/approve", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestResolveConsentSubjects_TokenAndCookie(t *testing.T) {
|
||||
h := &AuthHandler{}
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.Header.Get("X-Session-Token") == "token-123" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{
|
||||
"id": "user-token",
|
||||
"traits": map[string]any{
|
||||
"email": "token@test.com",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if r.Header.Get("Cookie") == "ory_kratos_session=cookie-123" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{
|
||||
"id": "user-cookie",
|
||||
"traits": map[string]any{
|
||||
"email": "cookie@test.com",
|
||||
"phone": "010-1234-5678",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusUnauthorized, "unauthorized"), nil
|
||||
})
|
||||
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = &http.Client{Transport: transport}
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
// Token case
|
||||
app.Get("/test-token", func(c *fiber.Ctx) error {
|
||||
subjects, err := h.resolveConsentSubjects(c)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, subjects, "user-token")
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
req := httptest.NewRequest("GET", "/test-token", nil)
|
||||
req.Header.Set("Authorization", "Bearer token-123")
|
||||
app.Test(req, -1)
|
||||
|
||||
// Cookie case
|
||||
app.Get("/test-cookie", func(c *fiber.Ctx) error {
|
||||
subjects, err := h.resolveConsentSubjects(c)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, subjects, "user-cookie")
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
req = httptest.NewRequest("GET", "/test-cookie", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=cookie-123")
|
||||
app.Test(req, -1)
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetMe_IncludesSessionAuthenticatedAtFromKratosSession(t *testing.T) {
|
||||
const (
|
||||
token = "token-session"
|
||||
identityID = "user-session"
|
||||
sessionAuthenticated = "2026-03-23T15:30:00Z"
|
||||
)
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Host == "kratos.test" &&
|
||||
r.URL.Path == "/sessions/whoami" &&
|
||||
r.Method == http.MethodGet {
|
||||
require.Equal(t, token, r.Header.Get("X-Session-Token"))
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "kratos-session-1",
|
||||
"authenticated_at": sessionAuthenticated,
|
||||
"identity": map[string]any{
|
||||
"id": identityID,
|
||||
"traits": map[string]any{
|
||||
"email": "qa@example.com",
|
||||
"name": "QA User",
|
||||
"department": "Platform",
|
||||
"affiliationType": "GENERAL",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
setDefaultHTTPClientForTest(t, transport)
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
|
||||
h := &AuthHandler{}
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/user/me", h.GetMe)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/me", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
resp, err := app.Test(req, -1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var profile map[string]any
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&profile))
|
||||
require.Equal(t, sessionAuthenticated, profile["sessionAuthenticatedAt"])
|
||||
}
|
||||
|
||||
func TestGetMe_IncludesSessionAuthenticatedAtForCookieSession(t *testing.T) {
|
||||
const (
|
||||
cookieHeader = "ory_kratos_session=session-cookie"
|
||||
identityID = "user-cookie"
|
||||
sessionAuthenticated = "2026-03-24T01:20:00Z"
|
||||
)
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Host == "kratos.test" &&
|
||||
r.URL.Path == "/sessions/whoami" &&
|
||||
r.Method == http.MethodGet {
|
||||
require.Equal(t, cookieHeader, r.Header.Get("Cookie"))
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "kratos-session-cookie",
|
||||
"authenticated_at": sessionAuthenticated,
|
||||
"identity": map[string]any{
|
||||
"id": identityID,
|
||||
"traits": map[string]any{
|
||||
"email": "cookie@example.com",
|
||||
"name": "Cookie User",
|
||||
"department": "Platform",
|
||||
"affiliationType": "GENERAL",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
setDefaultHTTPClientForTest(t, transport)
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
|
||||
h := &AuthHandler{}
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/user/me", h.GetMe)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/me", nil)
|
||||
req.Header.Set("Cookie", cookieHeader)
|
||||
resp, err := app.Test(req, -1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var profile map[string]any
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&profile))
|
||||
require.Equal(t, sessionAuthenticated, profile["sessionAuthenticatedAt"])
|
||||
}
|
||||
944
baron-sso/backend/internal/handler/auth_handler_sessions_test.go
Normal file
944
baron-sso/backend/internal/handler/auth_handler_sessions_test.go
Normal file
@@ -0,0 +1,944 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func TestListMySessions_Success(t *testing.T) {
|
||||
now := time.Date(2026, 4, 2, 1, 2, 3, 0, time.UTC)
|
||||
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "current-sid",
|
||||
"authenticated_at": now.Format(time.RFC3339),
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "User",
|
||||
"role": "user",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}))
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
|
||||
{
|
||||
ID: "current-sid",
|
||||
Active: true,
|
||||
AuthenticatedAt: now,
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
},
|
||||
{
|
||||
ID: "other-sid",
|
||||
Active: true,
|
||||
AuthenticatedAt: now.Add(-2 * time.Hour),
|
||||
ExpiresAt: now.Add(22 * time.Hour),
|
||||
},
|
||||
}, nil).Once()
|
||||
|
||||
auditRepo := &mockAuditRepo{
|
||||
logs: []domain.AuditLog{
|
||||
{
|
||||
UserID: "user-123",
|
||||
EventType: "login_success",
|
||||
SessionID: "other-sid",
|
||||
Timestamp: now.Add(-30 * time.Minute),
|
||||
IPAddress: "203.0.113.10",
|
||||
UserAgent: "Mozilla/5.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
h := &AuthHandler{
|
||||
KratosAdmin: mockKratos,
|
||||
AuditRepo: auditRepo,
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/user/sessions", h.ListMySessions)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/sessions", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body struct {
|
||||
Items []struct {
|
||||
SessionID string `json:"session_id"`
|
||||
IsCurrent bool `json:"is_current"`
|
||||
IsActive bool `json:"is_active"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
} `json:"items"`
|
||||
}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, body.Items, 2) {
|
||||
assert.Equal(t, "current-sid", body.Items[0].SessionID)
|
||||
assert.True(t, body.Items[0].IsCurrent)
|
||||
assert.Equal(t, "other-sid", body.Items[1].SessionID)
|
||||
assert.True(t, body.Items[1].IsActive)
|
||||
assert.Equal(t, "203.0.113.10", body.Items[1].IPAddress)
|
||||
assert.Equal(t, "Mozilla/5.0", body.Items[1].UserAgent)
|
||||
}
|
||||
|
||||
mockKratos.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestListMySessions_UsesConsentGrantForAppName(t *testing.T) {
|
||||
now := time.Date(2026, 4, 2, 4, 40, 0, 0, time.UTC)
|
||||
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "current-sid",
|
||||
"authenticated_at": now.Format(time.RFC3339),
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "User",
|
||||
"role": "user",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}))
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
|
||||
{
|
||||
ID: "current-sid",
|
||||
Active: true,
|
||||
AuthenticatedAt: now,
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
},
|
||||
{
|
||||
ID: "c7c721ea-session",
|
||||
Active: true,
|
||||
AuthenticatedAt: now.Add(-5 * time.Minute),
|
||||
ExpiresAt: now.Add(23*time.Hour + 55*time.Minute),
|
||||
},
|
||||
}, nil).Once()
|
||||
|
||||
auditRepo := &mockAuditRepo{
|
||||
logs: []domain.AuditLog{
|
||||
{
|
||||
UserID: "user-123",
|
||||
EventType: "consent.granted",
|
||||
SessionID: "c7c721ea-session",
|
||||
Timestamp: now,
|
||||
Details: `{"client_id":"devfront","client_name":"DevFront","session_id":"c7c721ea-session","approved_session_id":"c7c721ea-session"}`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
h := &AuthHandler{
|
||||
KratosAdmin: mockKratos,
|
||||
AuditRepo: auditRepo,
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/user/sessions", h.ListMySessions)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/sessions", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body struct {
|
||||
Items []struct {
|
||||
SessionID string `json:"session_id"`
|
||||
AppName string `json:"app_name"`
|
||||
ClientID string `json:"client_id"`
|
||||
} `json:"items"`
|
||||
}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, body.Items, 2) {
|
||||
assert.Equal(t, "c7c721ea-session", body.Items[1].SessionID)
|
||||
assert.Equal(t, "DevFront", body.Items[1].AppName)
|
||||
assert.Equal(t, "devfront", body.Items[1].ClientID)
|
||||
}
|
||||
|
||||
mockKratos.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestListMySessions_PreservesAppNameFromOlderConsentGrant(t *testing.T) {
|
||||
now := time.Date(2026, 4, 2, 4, 40, 0, 0, time.UTC)
|
||||
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "current-sid",
|
||||
"authenticated_at": now.Format(time.RFC3339),
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "User",
|
||||
"role": "user",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}))
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
|
||||
{
|
||||
ID: "current-sid",
|
||||
Active: true,
|
||||
AuthenticatedAt: now,
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
},
|
||||
{
|
||||
ID: "c7c721ea-session",
|
||||
Active: true,
|
||||
AuthenticatedAt: now.Add(-5 * time.Minute),
|
||||
ExpiresAt: now.Add(23*time.Hour + 55*time.Minute),
|
||||
},
|
||||
}, nil).Once()
|
||||
|
||||
auditRepo := &mockAuditRepo{
|
||||
logs: []domain.AuditLog{
|
||||
{
|
||||
UserID: "user-123",
|
||||
EventType: "consent.granted",
|
||||
SessionID: "c7c721ea-session",
|
||||
Timestamp: now.Add(-30 * time.Second),
|
||||
IPAddress: "203.0.113.10",
|
||||
Details: `{"client_id":"devfront","client_name":"DevFront","session_id":"c7c721ea-session"}`,
|
||||
},
|
||||
{
|
||||
UserID: "user-123",
|
||||
EventType: "login_success",
|
||||
SessionID: "c7c721ea-session",
|
||||
Timestamp: now,
|
||||
IPAddress: "10.0.0.12",
|
||||
UserAgent: "Mozilla/5.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
h := &AuthHandler{
|
||||
KratosAdmin: mockKratos,
|
||||
AuditRepo: auditRepo,
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/user/sessions", h.ListMySessions)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/sessions", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body struct {
|
||||
Items []struct {
|
||||
SessionID string `json:"session_id"`
|
||||
AppName string `json:"app_name"`
|
||||
ClientID string `json:"client_id"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
} `json:"items"`
|
||||
}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, body.Items, 2) {
|
||||
assert.Equal(t, "c7c721ea-session", body.Items[1].SessionID)
|
||||
assert.Equal(t, "DevFront", body.Items[1].AppName)
|
||||
assert.Equal(t, "devfront", body.Items[1].ClientID)
|
||||
assert.Equal(t, "203.0.113.10", body.Items[1].IPAddress)
|
||||
}
|
||||
|
||||
mockKratos.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestListMySessions_CurrentSessionFallsBackToRequestMetadata(t *testing.T) {
|
||||
now := time.Date(2026, 4, 6, 1, 2, 3, 0, time.UTC)
|
||||
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "current-sid",
|
||||
"authenticated_at": now.Format(time.RFC3339),
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "User",
|
||||
"role": "user",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}))
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
|
||||
{
|
||||
ID: "current-sid",
|
||||
Active: true,
|
||||
AuthenticatedAt: now,
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
},
|
||||
}, nil).Once()
|
||||
|
||||
h := &AuthHandler{
|
||||
KratosAdmin: mockKratos,
|
||||
AuditRepo: &mockAuditRepo{},
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/user/sessions", h.ListMySessions)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/sessions", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) Chrome/146.0.0.0 Safari/537.36")
|
||||
req.Header.Set("X-Forwarded-For", "100.100.100.1, 203.0.113.25")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body struct {
|
||||
Items []struct {
|
||||
SessionID string `json:"session_id"`
|
||||
IsCurrent bool `json:"is_current"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
ClientID string `json:"client_id"`
|
||||
AppName string `json:"app_name"`
|
||||
} `json:"items"`
|
||||
}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, body.Items, 1) {
|
||||
assert.Equal(t, "current-sid", body.Items[0].SessionID)
|
||||
assert.True(t, body.Items[0].IsCurrent)
|
||||
assert.Equal(t, "203.0.113.25", body.Items[0].IPAddress)
|
||||
assert.Contains(t, body.Items[0].UserAgent, "Mozilla/5.0")
|
||||
assert.Equal(t, "userfront", body.Items[0].ClientID)
|
||||
assert.Equal(t, "UserFront", body.Items[0].AppName)
|
||||
}
|
||||
|
||||
mockKratos.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestDeleteMySession_Success(t *testing.T) {
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
var hydraRevokeCalls int
|
||||
client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
switch r.URL.Host {
|
||||
case "kratos.test":
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "current-sid",
|
||||
"authenticated_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "User",
|
||||
"role": "user",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
case "hydra.test":
|
||||
if r.Method == http.MethodDelete && r.URL.Path == "/oauth2/auth/sessions/consent" {
|
||||
if r.URL.Query().Get("subject") != "user-123" {
|
||||
t.Fatalf("unexpected revoke subject: %s", r.URL.Query().Get("subject"))
|
||||
}
|
||||
if r.URL.Query().Get("client") != "devfront" {
|
||||
t.Fatalf("unexpected revoke client: %s", r.URL.Query().Get("client"))
|
||||
}
|
||||
hydraRevokeCalls++
|
||||
return httpResponse(r, http.StatusNoContent, ""), nil
|
||||
}
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})}
|
||||
setDefaultHTTPClientForTest(t, client.Transport)
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
|
||||
{ID: "target-sid", Active: true},
|
||||
}, nil).Once()
|
||||
mockKratos.On("GetSession", mock.Anything, "target-sid").Return(&service.KratosSession{
|
||||
ID: "target-sid",
|
||||
Active: true,
|
||||
}, nil).Once()
|
||||
mockKratos.On("DeleteSession", mock.Anything, "target-sid").Return(nil).Once()
|
||||
|
||||
auditRepo := &mockAuditRepo{}
|
||||
h := &AuthHandler{
|
||||
KratosAdmin: mockKratos,
|
||||
AuditRepo: auditRepo,
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
}
|
||||
auditRepo.logs = append(auditRepo.logs, domain.AuditLog{
|
||||
UserID: "user-123",
|
||||
EventType: "POST /api/v1/auth/oidc/login/accept",
|
||||
SessionID: "target-sid",
|
||||
Details: `{"client_id":"devfront","client_name":"Devfront"}`,
|
||||
})
|
||||
|
||||
app := fiber.New()
|
||||
app.Delete("/api/v1/user/sessions/:id", h.DeleteMySession)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/sessions/target-sid", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
req.Header.Set("User-Agent", "session-test-agent")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
if assert.Len(t, auditRepo.logs, 2) {
|
||||
assert.Equal(t, "session.revoked", auditRepo.logs[len(auditRepo.logs)-1].EventType)
|
||||
assert.Equal(t, "user-123", auditRepo.logs[len(auditRepo.logs)-1].UserID)
|
||||
assert.Equal(t, "current-sid", auditRepo.logs[len(auditRepo.logs)-1].SessionID)
|
||||
assert.Contains(t, auditRepo.logs[len(auditRepo.logs)-1].Details, "target-sid")
|
||||
}
|
||||
assert.Equal(t, 1, hydraRevokeCalls)
|
||||
|
||||
mockKratos.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestDeleteMySession_DoesNotRevokeAllHydraSessionsWhenClientBindingMissing(t *testing.T) {
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
var hydraRevokeCalls int
|
||||
client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
switch r.URL.Host {
|
||||
case "kratos.test":
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "current-sid",
|
||||
"authenticated_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "User",
|
||||
"role": "user",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
case "hydra.test":
|
||||
if r.Method == http.MethodDelete && r.URL.Path == "/oauth2/auth/sessions/consent" {
|
||||
hydraRevokeCalls++
|
||||
return httpResponse(r, http.StatusNoContent, ""), nil
|
||||
}
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})}
|
||||
setDefaultHTTPClientForTest(t, client.Transport)
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
|
||||
{ID: "target-sid", Active: true},
|
||||
}, nil).Once()
|
||||
mockKratos.On("GetSession", mock.Anything, "target-sid").Return(&service.KratosSession{
|
||||
ID: "target-sid",
|
||||
Active: true,
|
||||
}, nil).Once()
|
||||
mockKratos.On("DeleteSession", mock.Anything, "target-sid").Return(nil).Once()
|
||||
|
||||
auditRepo := &mockAuditRepo{}
|
||||
h := &AuthHandler{
|
||||
KratosAdmin: mockKratos,
|
||||
AuditRepo: auditRepo,
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Delete("/api/v1/user/sessions/:id", h.DeleteMySession)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/sessions/target-sid", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
req.Header.Set("User-Agent", "session-test-agent")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Equal(t, 0, hydraRevokeCalls)
|
||||
if assert.Len(t, auditRepo.logs, 1) {
|
||||
assert.Equal(t, "session.revoked", auditRepo.logs[0].EventType)
|
||||
assert.Equal(t, "user-123", auditRepo.logs[0].UserID)
|
||||
assert.Contains(t, auditRepo.logs[0].Details, "target-sid")
|
||||
}
|
||||
|
||||
mockKratos.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestDeleteMySession_SendsBackchannelLogoutTokenWhenClientConfigured(t *testing.T) {
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
t.Setenv("BACKCHANNEL_LOGOUT_ISSUER", "https://sso.example.com/oidc")
|
||||
|
||||
var receivedBody string
|
||||
client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
switch r.URL.Host {
|
||||
case "kratos.test":
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "current-sid",
|
||||
"authenticated_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "User",
|
||||
"role": "user",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
case "hydra.test":
|
||||
if r.Method == http.MethodDelete && r.URL.Path == "/oauth2/auth/sessions/consent" {
|
||||
return httpResponse(r, http.StatusNoContent, ""), nil
|
||||
}
|
||||
if r.Method == http.MethodGet && r.URL.Path == "/clients/devfront" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": "devfront",
|
||||
"backchannel_logout_uri": "https://rp.example.com/backchannel-logout",
|
||||
}), nil
|
||||
}
|
||||
case "rp.example.com":
|
||||
if r.Method == http.MethodPost && r.URL.Path == "/backchannel-logout" {
|
||||
raw, _ := io.ReadAll(r.Body)
|
||||
receivedBody = string(raw)
|
||||
return httpResponse(r, http.StatusNoContent, ""), nil
|
||||
}
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})}
|
||||
setDefaultHTTPClientForTest(t, client.Transport)
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
|
||||
{ID: "target-sid", Active: true},
|
||||
}, nil).Once()
|
||||
mockKratos.On("GetSession", mock.Anything, "target-sid").Return(&service.KratosSession{
|
||||
ID: "target-sid",
|
||||
Active: true,
|
||||
}, nil).Once()
|
||||
mockKratos.On("DeleteSession", mock.Anything, "target-sid").Return(nil).Once()
|
||||
|
||||
backchannelLogout, err := service.NewBackchannelLogoutService()
|
||||
assert.NoError(t, err)
|
||||
backchannelLogout.HTTPClient = client
|
||||
|
||||
auditRepo := &mockAuditRepo{}
|
||||
h := &AuthHandler{
|
||||
KratosAdmin: mockKratos,
|
||||
AuditRepo: auditRepo,
|
||||
BackchannelLogout: backchannelLogout,
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
}
|
||||
auditRepo.logs = append(auditRepo.logs, domain.AuditLog{
|
||||
UserID: "user-123",
|
||||
EventType: "POST /api/v1/auth/oidc/login/accept",
|
||||
SessionID: "target-sid",
|
||||
Details: `{"client_id":"devfront","client_name":"Devfront"}`,
|
||||
})
|
||||
|
||||
app := fiber.New()
|
||||
app.Delete("/api/v1/user/sessions/:id", h.DeleteMySession)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/sessions/target-sid", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
req.Header.Set("User-Agent", "session-test-agent")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.True(t, strings.Contains(receivedBody, "logout_token="))
|
||||
|
||||
values, err := url.ParseQuery(receivedBody)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, values.Get("logout_token"))
|
||||
|
||||
foundBackchannelAudit := false
|
||||
for _, log := range auditRepo.logs {
|
||||
if log.EventType == "backchannel_logout.sent" {
|
||||
foundBackchannelAudit = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, foundBackchannelAudit)
|
||||
|
||||
mockKratos.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestDeleteMySession_RevokesHydraClientBoundFromPasswordLoginAudit(t *testing.T) {
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
var hydraRevokeCalls int
|
||||
var revokedClient string
|
||||
client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
switch r.URL.Host {
|
||||
case "kratos.test":
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "current-sid",
|
||||
"authenticated_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "User",
|
||||
"role": "user",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
case "hydra.test":
|
||||
if r.Method == http.MethodDelete && r.URL.Path == "/oauth2/auth/sessions/consent" {
|
||||
revokedClient = r.URL.Query().Get("client")
|
||||
hydraRevokeCalls++
|
||||
return httpResponse(r, http.StatusNoContent, ""), nil
|
||||
}
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})}
|
||||
setDefaultHTTPClientForTest(t, client.Transport)
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
|
||||
{ID: "target-sid", Active: true},
|
||||
}, nil).Once()
|
||||
mockKratos.On("GetSession", mock.Anything, "target-sid").Return(&service.KratosSession{
|
||||
ID: "target-sid",
|
||||
Active: true,
|
||||
}, nil).Once()
|
||||
mockKratos.On("DeleteSession", mock.Anything, "target-sid").Return(nil).Once()
|
||||
|
||||
auditRepo := &mockAuditRepo{}
|
||||
h := &AuthHandler{
|
||||
KratosAdmin: mockKratos,
|
||||
AuditRepo: auditRepo,
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
}
|
||||
auditRepo.logs = append(auditRepo.logs, domain.AuditLog{
|
||||
UserID: "user-123",
|
||||
EventType: "POST /api/v1/auth/password/login",
|
||||
SessionID: "target-sid",
|
||||
Details: `{"client_id":"adminfront","client_name":"AdminFront","session_id":"target-sid"}`,
|
||||
})
|
||||
|
||||
app := fiber.New()
|
||||
app.Delete("/api/v1/user/sessions/:id", h.DeleteMySession)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/sessions/target-sid", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
req.Header.Set("User-Agent", "session-test-agent")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Equal(t, 1, hydraRevokeCalls)
|
||||
assert.Equal(t, "adminfront", revokedClient)
|
||||
|
||||
mockKratos.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestGetHydraProfile_RejectsInactiveLinkedSession(t *testing.T) {
|
||||
client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Host == "hydra.test" && r.URL.Path == "/oauth2/introspect" {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
if string(body) != "token=opaque-token" {
|
||||
t.Fatalf("unexpected introspect body: %s", string(body))
|
||||
}
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"active": true,
|
||||
"sub": "user-123",
|
||||
"client_id": "devfront",
|
||||
"ext": map[string]any{
|
||||
"session_id": "target-sid",
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})}
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("GetSession", mock.Anything, "target-sid").Return(&service.KratosSession{
|
||||
ID: "target-sid",
|
||||
Active: false,
|
||||
Identity: &service.KratosIdentity{
|
||||
ID: "user-123",
|
||||
},
|
||||
}, nil).Once()
|
||||
|
||||
h := &AuthHandler{
|
||||
KratosAdmin: mockKratos,
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
}
|
||||
|
||||
profile, err := h.getHydraProfile(context.Background(), "opaque-token")
|
||||
assert.Nil(t, profile)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "inactive")
|
||||
mockKratos.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestGetAuthTimeline_FillsSessionIDFromOathkeeperRaw(t *testing.T) {
|
||||
now := time.Date(2026, 4, 7, 4, 39, 0, 0, time.UTC)
|
||||
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "current-sid",
|
||||
"authenticated_at": now.Format(time.RFC3339),
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "User",
|
||||
"role": "user",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}))
|
||||
|
||||
h := &AuthHandler{
|
||||
AuditRepo: &mockAuditRepo{},
|
||||
OathkeeperRepo: &mockOathkeeperRepo{
|
||||
logs: []domain.OathkeeperAccessLog{
|
||||
{
|
||||
Timestamp: now,
|
||||
RequestID: "req-1",
|
||||
Method: http.MethodGet,
|
||||
Path: "/api/v1/dev/sessions",
|
||||
Status: http.StatusOK,
|
||||
Subject: "user-123",
|
||||
ClientIP: "203.0.113.7",
|
||||
UserAgent: "Mozilla/5.0",
|
||||
Raw: `{"request":{"url":"https://devfront.example.com/callback?client_id=devfront"},"extra":{"session_id":"target-sid"}}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/audit/auth/timeline", h.GetAuthTimeline)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/audit/auth/timeline", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body struct {
|
||||
Items []struct {
|
||||
SessionID string `json:"session_id"`
|
||||
ClientID string `json:"client_id"`
|
||||
AppName string `json:"app_name"`
|
||||
Source string `json:"source"`
|
||||
} `json:"items"`
|
||||
}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, body.Items, 1) {
|
||||
assert.Equal(t, "target-sid", body.Items[0].SessionID)
|
||||
assert.Equal(t, "devfront", body.Items[0].ClientID)
|
||||
assert.Equal(t, "devfront", body.Items[0].AppName)
|
||||
assert.Equal(t, "oathkeeper", body.Items[0].Source)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAuthTimeline_IncludesHeadlessPasswordLogin(t *testing.T) {
|
||||
now := time.Date(2026, 4, 7, 5, 10, 0, 0, time.UTC)
|
||||
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "current-sid",
|
||||
"authenticated_at": now.Format(time.RFC3339),
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "User",
|
||||
"role": "user",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}))
|
||||
|
||||
h := &AuthHandler{
|
||||
AuditRepo: &mockAuditRepo{
|
||||
logs: []domain.AuditLog{
|
||||
{
|
||||
EventID: "audit-1",
|
||||
Timestamp: now,
|
||||
UserID: "user-123",
|
||||
SessionID: "headless-session-1",
|
||||
EventType: "POST /api/v1/auth/headless/password/login",
|
||||
Status: "success",
|
||||
IPAddress: "203.0.113.20",
|
||||
UserAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/146.0.0.0 Safari/537.36",
|
||||
Details: `{"client_id":"headless-login-client","client_name":"Headless Login Portal","session_id":"headless-session-1","login_id":"user@example.com","login_challenge":"challenge-123"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/audit/auth/timeline", h.GetAuthTimeline)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/audit/auth/timeline", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body struct {
|
||||
Items []struct {
|
||||
SessionID string `json:"session_id"`
|
||||
ClientID string `json:"client_id"`
|
||||
AppName string `json:"app_name"`
|
||||
AuthMethod string `json:"auth_method"`
|
||||
EventType string `json:"event_type"`
|
||||
} `json:"items"`
|
||||
}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, body.Items, 1) {
|
||||
assert.Equal(t, "headless-session-1", body.Items[0].SessionID)
|
||||
assert.Equal(t, "headless-login-client", body.Items[0].ClientID)
|
||||
assert.Equal(t, "Headless Login Portal", body.Items[0].AppName)
|
||||
assert.Equal(t, "비밀번호(Email)", body.Items[0].AuthMethod)
|
||||
assert.Equal(t, "POST /api/v1/auth/headless/password/login", body.Items[0].EventType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListMySessions_UsesHeadlessPasswordLoginForClientBinding(t *testing.T) {
|
||||
now := time.Date(2026, 4, 7, 5, 35, 0, 0, time.UTC)
|
||||
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "current-sid",
|
||||
"authenticated_at": now.Format(time.RFC3339),
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "User",
|
||||
"role": "user",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}))
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
|
||||
{
|
||||
ID: "current-sid",
|
||||
Active: true,
|
||||
AuthenticatedAt: now,
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
},
|
||||
{
|
||||
ID: "headless-session-1",
|
||||
Active: true,
|
||||
AuthenticatedAt: now.Add(-10 * time.Minute),
|
||||
ExpiresAt: now.Add(23*time.Hour + 50*time.Minute),
|
||||
},
|
||||
}, nil).Once()
|
||||
|
||||
auditRepo := &mockAuditRepo{
|
||||
logs: []domain.AuditLog{
|
||||
{
|
||||
UserID: "user-123",
|
||||
EventType: "POST /api/v1/auth/headless/password/login",
|
||||
SessionID: "headless-session-1",
|
||||
Timestamp: now,
|
||||
IPAddress: "203.0.113.20",
|
||||
UserAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/146.0.0.0 Safari/537.36",
|
||||
Details: `{"client_id":"headless-login-client","client_name":"Headless Login Portal","session_id":"headless-session-1"}`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
h := &AuthHandler{
|
||||
KratosAdmin: mockKratos,
|
||||
AuditRepo: auditRepo,
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/user/sessions", h.ListMySessions)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/sessions", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body struct {
|
||||
Items []struct {
|
||||
SessionID string `json:"session_id"`
|
||||
AppName string `json:"app_name"`
|
||||
ClientID string `json:"client_id"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
} `json:"items"`
|
||||
}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, body.Items, 2) {
|
||||
assert.Equal(t, "headless-session-1", body.Items[1].SessionID)
|
||||
assert.Equal(t, "Headless Login Portal", body.Items[1].AppName)
|
||||
assert.Equal(t, "headless-login-client", body.Items[1].ClientID)
|
||||
assert.Equal(t, "203.0.113.20", body.Items[1].IPAddress)
|
||||
assert.Equal(t, "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/146.0.0.0 Safari/537.36", body.Items[1].UserAgent)
|
||||
}
|
||||
|
||||
mockKratos.AssertExpectations(t)
|
||||
}
|
||||
144
baron-sso/backend/internal/handler/auth_handler_signup_test.go
Normal file
144
baron-sso/backend/internal/handler/auth_handler_signup_test.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// --- Local Mocks for Signup Test ---
|
||||
|
||||
type MockRedisForSignup struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockRedisForSignup) Set(key string, value string, ttl time.Duration) error {
|
||||
return m.Called(key, value, ttl).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockRedisForSignup) Get(key string) (string, error) {
|
||||
args := m.Called(key)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockRedisForSignup) Delete(key string) error {
|
||||
return m.Called(key).Error(0)
|
||||
}
|
||||
func (m *MockRedisForSignup) StoreVerificationCode(phone, code string) error { return nil }
|
||||
func (m *MockRedisForSignup) GetVerificationCode(phone string) (string, error) { return "", nil }
|
||||
func (m *MockRedisForSignup) DeleteVerificationCode(phone string) error { return nil }
|
||||
func (m *MockRedisForSignup) Ping(ctx context.Context) error { return nil }
|
||||
|
||||
type MockIdpForSignup struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockIdpForSignup) Name() string { return "mock-idp" }
|
||||
func (m *MockIdpForSignup) GetMetadata() (*domain.IDPMetadata, error) {
|
||||
return &domain.IDPMetadata{SupportedFields: []string{"email", "name", "phoneNumber", "grade", "department"}}, nil
|
||||
}
|
||||
|
||||
func (m *MockIdpForSignup) CreateUser(user *domain.BrokerUser, password string) (string, error) {
|
||||
args := m.Called(user, password)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockIdpForSignup) SignIn(loginID, password string) (*domain.AuthInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *MockIdpForSignup) UserExists(loginID string) (bool, error) { return false, nil }
|
||||
func (m *MockIdpForSignup) IssueSession(loginID string) (*domain.AuthInfo, error) { return nil, nil }
|
||||
func (m *MockIdpForSignup) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockIdpForSignup) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockIdpForSignup) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
|
||||
return &domain.PasswordPolicy{MinLength: 12}, nil
|
||||
}
|
||||
func (m *MockIdpForSignup) InitiatePasswordReset(loginID, redirectUrl string) error { return nil }
|
||||
func (m *MockIdpForSignup) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockIdpForSignup) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSignup_TenantSlugValidation(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockTenantSvc := new(MockTenantService)
|
||||
mockRedis := new(MockRedisForSignup)
|
||||
mockIdp := new(MockIdpForSignup)
|
||||
|
||||
h := &AuthHandler{
|
||||
TenantService: mockTenantSvc,
|
||||
RedisService: mockRedis,
|
||||
IdpProvider: mockIdp,
|
||||
}
|
||||
|
||||
app.Post("/signup", h.Signup)
|
||||
|
||||
// Prepare mock state (already verified email/phone)
|
||||
verifiedState, _ := json.Marshal(map[string]any{
|
||||
"verified": true,
|
||||
"expires_at": time.Now().Add(time.Hour).Unix(),
|
||||
})
|
||||
mockRedis.On("Get", mock.Anything).Return(string(verifiedState), nil)
|
||||
|
||||
t.Run("Rejects legacy CompanyCode", func(t *testing.T) {
|
||||
reqBody := domain.SignupRequest{
|
||||
Email: "user@gmail.com",
|
||||
Password: "StrongPass123!",
|
||||
Name: "Test User",
|
||||
Phone: "010-1234-5678",
|
||||
TermsAccepted: true,
|
||||
CompanyCode: "new-slug",
|
||||
}
|
||||
body, _ := json.Marshal(reqBody)
|
||||
|
||||
req := httptest.NewRequest("POST", "/signup", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("Active Tenant Slug", func(t *testing.T) {
|
||||
reqBody := domain.SignupRequest{
|
||||
Email: "user@hanmaceng.co.kr",
|
||||
Password: "StrongPass123!",
|
||||
Name: "Test User",
|
||||
Phone: "010-1234-5678",
|
||||
TermsAccepted: true,
|
||||
TenantSlug: "hanmac",
|
||||
}
|
||||
body, _ := json.Marshal(reqBody)
|
||||
|
||||
validTenant := &domain.Tenant{ID: "t1", Slug: "hanmac", Status: domain.TenantStatusActive}
|
||||
mockTenantSvc.On("GetTenantByDomain", mock.Anything, "hanmaceng.co.kr").Return(&domain.Tenant{Slug: "hanmac"}, nil).Once()
|
||||
mockTenantSvc.On("ProvisionTenantByDomain", mock.Anything, "hanmaceng.co.kr").Return(validTenant, nil).Maybe()
|
||||
mockTenantSvc.On("GetTenantBySlug", mock.Anything, "hanmac").Return(validTenant, nil).Once()
|
||||
mockTenantSvc.On("GetTenant", mock.Anything, "t1").Return(validTenant, nil).Once()
|
||||
mockIdp.On("CreateUser", mock.Anything, mock.Anything).Return("user-id", nil).Once()
|
||||
mockRedis.On("Delete", mock.Anything).Return(nil)
|
||||
|
||||
req := httptest.NewRequest("POST", "/signup", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
})
|
||||
}
|
||||
521
baron-sso/backend/internal/handler/auth_handler_test.go
Normal file
521
baron-sso/backend/internal/handler/auth_handler_test.go
Normal file
@@ -0,0 +1,521 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/middleware"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// helper to build a Fiber app with the handler route mounted.
|
||||
func newTestApp(h *AuthHandler) *fiber.App {
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/password/reset/complete", h.CompletePasswordReset)
|
||||
return app
|
||||
}
|
||||
|
||||
func newResetFlowTestApp(h *AuthHandler) *fiber.App {
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/password/reset/verify", h.ProcessPasswordResetToken)
|
||||
app.Post("/api/v1/auth/password/reset/complete", h.CompletePasswordReset)
|
||||
return app
|
||||
}
|
||||
|
||||
func newResetInitAppWithErrorCodeEnricher(h *AuthHandler) *fiber.App {
|
||||
app := fiber.New()
|
||||
app.Use(middleware.ErrorCodeEnricher())
|
||||
app.Post("/api/v1/auth/password/reset/init", h.InitiatePasswordReset)
|
||||
return app
|
||||
}
|
||||
|
||||
type testRedisRepo struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func (m *testRedisRepo) Set(key string, value string, expiration time.Duration) error {
|
||||
if m.values == nil {
|
||||
m.values = map[string]string{}
|
||||
}
|
||||
m.values[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testRedisRepo) Get(key string) (string, error) {
|
||||
if m.values == nil {
|
||||
return "", nil
|
||||
}
|
||||
return m.values[key], nil
|
||||
}
|
||||
|
||||
func (m *testRedisRepo) Delete(key string) error {
|
||||
if m.values != nil {
|
||||
delete(m.values, key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testRedisRepo) StoreVerificationCode(phone, code string) error {
|
||||
return m.Set("sms:"+phone, code, time.Minute)
|
||||
}
|
||||
|
||||
func (m *testRedisRepo) GetVerificationCode(phone string) (string, error) {
|
||||
return m.Get("sms:" + phone)
|
||||
}
|
||||
|
||||
func (m *testRedisRepo) DeleteVerificationCode(phone string) error {
|
||||
return m.Delete("sms:" + phone)
|
||||
}
|
||||
|
||||
func TestCompletePasswordReset_MissingLoginID(t *testing.T) {
|
||||
h := &AuthHandler{}
|
||||
app := newTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"newPassword": "Password1!",
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/password/reset/complete", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for missing loginId, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var got map[string]string
|
||||
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if got["error"] != "Login ID and new password are required" {
|
||||
t.Fatalf("unexpected error message: %v", got["error"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletePasswordReset_InvalidPasswordPolicy(t *testing.T) {
|
||||
h := &AuthHandler{}
|
||||
app := newTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"newPassword": "short", // too short + missing complexity
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/password/reset/complete?loginId=user@example.com", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for weak password, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var got map[string]string
|
||||
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if got["error"] != "비밀번호는 최소 12자 이상이어야 합니다" {
|
||||
t.Fatalf("unexpected error message: %v", got["error"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletePasswordReset_NilIDPProvider(t *testing.T) {
|
||||
h := &AuthHandler{} // IdpProvider intentionally nil to hit the configuration error branch
|
||||
app := newTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"newPassword": "StrongPass1!",
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/password/reset/complete?loginId=user@example.com", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusInternalServerError {
|
||||
t.Fatalf("expected 500 when IDP provider is nil, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var got map[string]string
|
||||
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if got["error"] != "Authentication service not configured" {
|
||||
t.Fatalf("unexpected error message: %v", got["error"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletePasswordReset_TokenValueOverridesLoginIDQuery(t *testing.T) {
|
||||
const resetToken = "tok-reset-1"
|
||||
const tokenLoginID = "user@example.com"
|
||||
const wrongLoginID = "wrong@example.com"
|
||||
const newPassword = "StrongPass1!"
|
||||
|
||||
redis := &testRedisRepo{
|
||||
values: map[string]string{
|
||||
prefixPwdResetToken + resetToken: tokenLoginID,
|
||||
},
|
||||
}
|
||||
idp := &mockIdpProvider{
|
||||
userExists: true,
|
||||
err: nil,
|
||||
}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: idp,
|
||||
}
|
||||
app := newResetFlowTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"newPassword": newPassword,
|
||||
})
|
||||
url := fmt.Sprintf(
|
||||
"/api/v1/auth/password/reset/complete?loginId=%s&token=%s",
|
||||
wrongLoginID,
|
||||
resetToken,
|
||||
)
|
||||
req := httptest.NewRequest(http.MethodPost, url, bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
if !idp.updateCalled {
|
||||
t.Fatal("expected UpdateUserPassword to be called")
|
||||
}
|
||||
if idp.updatedLoginID != tokenLoginID {
|
||||
t.Fatalf("expected loginId from token(%s), got %s", tokenLoginID, idp.updatedLoginID)
|
||||
}
|
||||
if idp.updatedPassword != newPassword {
|
||||
t.Fatalf("expected newPassword propagated, got %s", idp.updatedPassword)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletePasswordReset_InvalidTokenRejectedEvenWhenLoginIDExists(t *testing.T) {
|
||||
const resetToken = "invalid-token"
|
||||
|
||||
redis := &testRedisRepo{
|
||||
values: map[string]string{},
|
||||
}
|
||||
idp := &mockIdpProvider{
|
||||
userExists: true,
|
||||
err: nil,
|
||||
}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: idp,
|
||||
}
|
||||
app := newResetFlowTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"newPassword": "StrongPass1!",
|
||||
})
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/auth/password/reset/complete?loginId=user@example.com&token="+resetToken,
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401 for invalid token, got %d", resp.StatusCode)
|
||||
}
|
||||
if idp.updateCalled {
|
||||
t.Fatal("UpdateUserPassword must not be called when token is invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletePasswordReset_DuplicateTokenSubmitIsIdempotent(t *testing.T) {
|
||||
const resetToken = "dup-token"
|
||||
const loginID = "user@example.com"
|
||||
const newPassword = "StrongPass1!"
|
||||
|
||||
redis := &testRedisRepo{
|
||||
values: map[string]string{
|
||||
prefixPwdResetToken + resetToken: loginID,
|
||||
},
|
||||
}
|
||||
idp := &mockIdpProvider{
|
||||
userExists: true,
|
||||
err: nil,
|
||||
}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: idp,
|
||||
}
|
||||
app := newResetFlowTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"newPassword": newPassword,
|
||||
})
|
||||
url := fmt.Sprintf(
|
||||
"/api/v1/auth/password/reset/complete?token=%s",
|
||||
resetToken,
|
||||
)
|
||||
|
||||
firstReq := httptest.NewRequest(http.MethodPost, url, bytes.NewReader(body))
|
||||
firstReq.Header.Set("Content-Type", "application/json")
|
||||
firstResp, err := app.Test(firstReq)
|
||||
if err != nil {
|
||||
t.Fatalf("first request failed: %v", err)
|
||||
}
|
||||
defer firstResp.Body.Close()
|
||||
|
||||
if firstResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected first response to be 200, got %d", firstResp.StatusCode)
|
||||
}
|
||||
if idp.updateCallCount != 1 {
|
||||
t.Fatalf("expected first request to update password once, got %d", idp.updateCallCount)
|
||||
}
|
||||
|
||||
secondReq := httptest.NewRequest(http.MethodPost, url, bytes.NewReader(body))
|
||||
secondReq.Header.Set("Content-Type", "application/json")
|
||||
secondResp, err := app.Test(secondReq)
|
||||
if err != nil {
|
||||
t.Fatalf("second request failed: %v", err)
|
||||
}
|
||||
defer secondResp.Body.Close()
|
||||
|
||||
if secondResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected duplicate response to be 200, got %d", secondResp.StatusCode)
|
||||
}
|
||||
if idp.updateCallCount != 1 {
|
||||
t.Fatalf("expected duplicate request not to update password again, got %d", idp.updateCallCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessPasswordResetToken_EncodesLoginIDInRedirect(t *testing.T) {
|
||||
const token = "tok-enc"
|
||||
const loginID = "user+alias@example.com"
|
||||
|
||||
t.Setenv("USERFRONT_URL", "https://sss.hmac.kr")
|
||||
|
||||
redis := &testRedisRepo{
|
||||
values: map[string]string{
|
||||
prefixPwdResetToken + token: loginID,
|
||||
},
|
||||
}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
}
|
||||
app := newResetFlowTestApp(h)
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/auth/password/reset/verify?token="+token,
|
||||
nil,
|
||||
)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusFound {
|
||||
t.Fatalf("expected 302, got %d", resp.StatusCode)
|
||||
}
|
||||
location := resp.Header.Get("Location")
|
||||
if location == "" {
|
||||
t.Fatal("missing redirect location")
|
||||
}
|
||||
redirectReq := httptest.NewRequest(http.MethodGet, location, nil)
|
||||
gotLoginID := redirectReq.URL.Query().Get("loginId")
|
||||
if gotLoginID != loginID {
|
||||
t.Fatalf("expected encoded loginId round-trip=%s, got %s (location=%s)", loginID, gotLoginID, location)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetVerifyAlias_AcceptsShortVePath(t *testing.T) {
|
||||
const token = "tok-ve"
|
||||
const loginID = "user@example.com"
|
||||
|
||||
redis := &testRedisRepo{
|
||||
values: map[string]string{
|
||||
prefixPwdResetToken + token: loginID,
|
||||
},
|
||||
}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/auth/password/reset/ve", h.VerifyPasswordResetPage)
|
||||
app.Post("/api/v1/auth/password/reset/ve", h.ProcessPasswordResetToken)
|
||||
|
||||
getReq := httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
"/api/v1/auth/password/reset/ve?token="+token,
|
||||
nil,
|
||||
)
|
||||
getResp, err := app.Test(getReq)
|
||||
if err != nil {
|
||||
t.Fatalf("get request failed: %v", err)
|
||||
}
|
||||
defer getResp.Body.Close()
|
||||
|
||||
if getResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected alias GET to return 200, got %d", getResp.StatusCode)
|
||||
}
|
||||
|
||||
postReq := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/auth/password/reset/ve?token="+token,
|
||||
nil,
|
||||
)
|
||||
postResp, err := app.Test(postReq)
|
||||
if err != nil {
|
||||
t.Fatalf("post request failed: %v", err)
|
||||
}
|
||||
defer postResp.Body.Close()
|
||||
|
||||
if postResp.StatusCode != http.StatusFound {
|
||||
t.Fatalf("expected alias POST to return 302, got %d", postResp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetVerifyPathToken_AcceptsShortVPath(t *testing.T) {
|
||||
const token = "tok-path"
|
||||
const loginID = "user@example.com"
|
||||
|
||||
redis := &testRedisRepo{
|
||||
values: map[string]string{
|
||||
prefixPwdResetToken + token: loginID,
|
||||
},
|
||||
}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/auth/password/reset/v/:token", h.VerifyPasswordResetPage)
|
||||
app.Post("/api/v1/auth/password/reset/v/:token", h.ProcessPasswordResetToken)
|
||||
|
||||
getReq := httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
"/api/v1/auth/password/reset/v/"+token,
|
||||
nil,
|
||||
)
|
||||
getResp, err := app.Test(getReq)
|
||||
if err != nil {
|
||||
t.Fatalf("get request failed: %v", err)
|
||||
}
|
||||
defer getResp.Body.Close()
|
||||
|
||||
if getResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected path-token GET to return 200, got %d", getResp.StatusCode)
|
||||
}
|
||||
|
||||
postReq := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/auth/password/reset/v/"+token,
|
||||
nil,
|
||||
)
|
||||
postResp, err := app.Test(postReq)
|
||||
if err != nil {
|
||||
t.Fatalf("post request failed: %v", err)
|
||||
}
|
||||
defer postResp.Body.Close()
|
||||
|
||||
if postResp.StatusCode != http.StatusFound {
|
||||
t.Fatalf("expected path-token POST to return 302, got %d", postResp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetInit_LegacyErrorResponseHasCodeViaMiddleware(t *testing.T) {
|
||||
h := &AuthHandler{}
|
||||
app := newResetInitAppWithErrorCodeEnricher(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"loginId": "",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/password/reset/init", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var got map[string]any
|
||||
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if got["error"] != "Login ID is required" {
|
||||
t.Fatalf("unexpected error message: %v", got["error"])
|
||||
}
|
||||
if got["code"] != "bad_request" {
|
||||
t.Fatalf("expected code=bad_request, got %v", got["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitiatePasswordReset_SmsContainsVerifyLink(t *testing.T) {
|
||||
t.Setenv("USERFRONT_URL", "https://sss.hmac.kr")
|
||||
|
||||
redis := &testRedisRepo{values: map[string]string{}}
|
||||
smsSvc := &mockSmsService{}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: &mockIdpProvider{},
|
||||
SmsService: smsSvc,
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/password/reset/init", h.InitiatePasswordReset)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"loginId": "01012345678",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/password/reset/init", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
if !strings.Contains(smsSvc.lastContent, "/api/v1/auth/password/reset/v/") {
|
||||
t.Fatalf("expected SMS to contain short path verify link, got %q", smsSvc.lastContent)
|
||||
}
|
||||
if strings.Contains(smsSvc.lastContent, "/reset-password?token=") {
|
||||
t.Fatalf("expected direct reset-password link to be removed, got %q", smsSvc.lastContent)
|
||||
}
|
||||
}
|
||||
537
baron-sso/backend/internal/handler/client_tenant_access.go
Normal file
537
baron-sso/backend/internal/handler/client_tenant_access.go
Normal file
@@ -0,0 +1,537 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/response"
|
||||
"baron-sso-backend/internal/service"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
clientTenantAccessRestrictedKey = "tenant_access_restricted"
|
||||
clientAllowedTenantsKey = "allowed_tenants"
|
||||
)
|
||||
|
||||
func normalizeClientTenantAccessMetadata(metadata map[string]any) (map[string]any, error) {
|
||||
if metadata == nil {
|
||||
metadata = map[string]any{}
|
||||
}
|
||||
|
||||
restricted := readMetadataBoolValue(metadata, clientTenantAccessRestrictedKey)
|
||||
allowedTenants := normalizeMetadataStringSlice(metadata[clientAllowedTenantsKey])
|
||||
ownerTenantID := normalizeMetadataString(metadata["tenant_id"])
|
||||
|
||||
if len(allowedTenants) > 0 {
|
||||
restricted = true
|
||||
}
|
||||
|
||||
if !restricted {
|
||||
delete(metadata, clientAllowedTenantsKey)
|
||||
metadata[clientTenantAccessRestrictedKey] = false
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
if ownerTenantID != "" {
|
||||
allowedTenants = append(allowedTenants, ownerTenantID)
|
||||
}
|
||||
allowedTenants = uniqueSortedStrings(allowedTenants)
|
||||
if len(allowedTenants) == 0 {
|
||||
return nil, errors.New("allowed_tenants is required when tenant_access_restricted is enabled")
|
||||
}
|
||||
|
||||
metadata[clientTenantAccessRestrictedKey] = true
|
||||
metadata[clientAllowedTenantsKey] = allowedTenants
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
func clientTenantAccessRestricted(metadata map[string]any) bool {
|
||||
if metadata == nil {
|
||||
return false
|
||||
}
|
||||
if readMetadataBoolValue(metadata, clientTenantAccessRestrictedKey) {
|
||||
return true
|
||||
}
|
||||
return len(normalizeMetadataStringSlice(metadata[clientAllowedTenantsKey])) > 0
|
||||
}
|
||||
|
||||
func clientAllowedTenants(metadata map[string]any) []string {
|
||||
if metadata == nil {
|
||||
return nil
|
||||
}
|
||||
if !clientTenantAccessRestricted(metadata) {
|
||||
return nil
|
||||
}
|
||||
return uniqueSortedStrings(normalizeMetadataStringSlice(metadata[clientAllowedTenantsKey]))
|
||||
}
|
||||
|
||||
func normalizeMetadataStringSlice(raw any) []string {
|
||||
switch value := raw.(type) {
|
||||
case []string:
|
||||
return uniqueSortedStrings(value)
|
||||
case []any:
|
||||
items := make([]string, 0, len(value))
|
||||
for _, item := range value {
|
||||
if s, ok := item.(string); ok {
|
||||
items = append(items, s)
|
||||
}
|
||||
}
|
||||
return uniqueSortedStrings(items)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeMetadataString(raw any) string {
|
||||
s, ok := raw.(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
|
||||
func uniqueSortedStrings(values []string) []string {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[string]struct{}, len(values))
|
||||
out := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[trimmed]; ok {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = struct{}{}
|
||||
out = append(out, trimmed)
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
|
||||
func clientTenantAccessAllowed(profile *domain.UserProfileResponse, client domain.HydraClient) bool {
|
||||
if !clientTenantAccessRestricted(client.Metadata) {
|
||||
return true
|
||||
}
|
||||
allowed := clientAllowedTenants(client.Metadata)
|
||||
if len(allowed) == 0 {
|
||||
return false
|
||||
}
|
||||
keys := manageableTenantKeysFromProfile(profile)
|
||||
if len(keys) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, tenantID := range allowed {
|
||||
if _, ok := keys[strings.ToLower(strings.TrimSpace(tenantID))]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func clientTenantAccessAllowedForSubtree(c *fiber.Ctx, tenantSvc service.TenantService, profile *domain.UserProfileResponse, client domain.HydraClient) bool {
|
||||
if clientTenantAccessAllowed(profile, client) {
|
||||
return true
|
||||
}
|
||||
if tenantSvc == nil || profile == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
allowedTenants := make([]domain.Tenant, 0)
|
||||
for _, identifier := range clientAllowedTenants(client.Metadata) {
|
||||
if tenant, ok := resolveTenantAccessTenant(c, tenantSvc, domain.Tenant{ID: identifier, Slug: identifier}); ok {
|
||||
allowedTenants = append(allowedTenants, tenant)
|
||||
}
|
||||
}
|
||||
if len(allowedTenants) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, candidate := range tenantAccessProfileTenants(profile) {
|
||||
resolvedCandidate, ok := resolveTenantAccessTenant(c, tenantSvc, candidate)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, allowed := range allowedTenants {
|
||||
if tenantMatchesOrDescendsFrom(c, tenantSvc, resolvedCandidate, allowed) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func tenantAccessProfileTenants(profile *domain.UserProfileResponse) []domain.Tenant {
|
||||
if profile == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{})
|
||||
tenants := make([]domain.Tenant, 0, len(profile.ManageableTenants)+len(profile.JoinedTenants)+2)
|
||||
add := func(tenant domain.Tenant) {
|
||||
key := strings.ToLower(firstNonEmptyString(tenant.ID, tenant.Slug, tenant.Name))
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := seen[key]; ok {
|
||||
return
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
tenants = append(tenants, tenant)
|
||||
}
|
||||
|
||||
if profile.Tenant != nil {
|
||||
add(*profile.Tenant)
|
||||
}
|
||||
if profile.TenantID != nil {
|
||||
add(domain.Tenant{ID: strings.TrimSpace(*profile.TenantID)})
|
||||
}
|
||||
for _, tenant := range profile.ManageableTenants {
|
||||
add(tenant)
|
||||
}
|
||||
for _, tenant := range profile.JoinedTenants {
|
||||
add(tenant)
|
||||
}
|
||||
return tenants
|
||||
}
|
||||
|
||||
func resolveTenantAccessTenant(c *fiber.Ctx, tenantSvc service.TenantService, tenant domain.Tenant) (domain.Tenant, bool) {
|
||||
if tenantSvc == nil {
|
||||
return tenant, firstNonEmptyString(tenant.ID, tenant.Slug) != ""
|
||||
}
|
||||
if strings.TrimSpace(tenant.ID) != "" {
|
||||
if resolved, err := tenantSvc.GetTenant(c.Context(), strings.TrimSpace(tenant.ID)); err == nil && resolved != nil {
|
||||
return *resolved, true
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(tenant.Slug) != "" {
|
||||
if resolved, err := tenantSvc.GetTenantBySlug(c.Context(), strings.TrimSpace(tenant.Slug)); err == nil && resolved != nil {
|
||||
return *resolved, true
|
||||
}
|
||||
}
|
||||
return tenant, firstNonEmptyString(tenant.ID, tenant.Slug) != ""
|
||||
}
|
||||
|
||||
func tenantMatchesOrDescendsFrom(c *fiber.Ctx, tenantSvc service.TenantService, tenant domain.Tenant, ancestor domain.Tenant) bool {
|
||||
if tenantAccessTenantMatches(tenant, ancestor) {
|
||||
return true
|
||||
}
|
||||
if tenantSvc == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
visited := make(map[string]struct{})
|
||||
current := tenant
|
||||
for current.ParentID != nil && strings.TrimSpace(*current.ParentID) != "" {
|
||||
parentID := strings.TrimSpace(*current.ParentID)
|
||||
if _, ok := visited[parentID]; ok {
|
||||
return false
|
||||
}
|
||||
visited[parentID] = struct{}{}
|
||||
|
||||
parent, err := tenantSvc.GetTenant(c.Context(), parentID)
|
||||
if err != nil || parent == nil {
|
||||
return false
|
||||
}
|
||||
if tenantAccessTenantMatches(*parent, ancestor) {
|
||||
return true
|
||||
}
|
||||
current = *parent
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func tenantAccessTenantMatches(left, right domain.Tenant) bool {
|
||||
leftID := strings.ToLower(strings.TrimSpace(left.ID))
|
||||
rightID := strings.ToLower(strings.TrimSpace(right.ID))
|
||||
if leftID != "" && rightID != "" && leftID == rightID {
|
||||
return true
|
||||
}
|
||||
|
||||
leftSlug := strings.ToLower(strings.TrimSpace(left.Slug))
|
||||
rightSlug := strings.ToLower(strings.TrimSpace(right.Slug))
|
||||
return leftSlug != "" && rightSlug != "" && leftSlug == rightSlug
|
||||
}
|
||||
|
||||
type tenantAccessDeniedDetails struct {
|
||||
Account tenantAccessDeniedAccount `json:"account"`
|
||||
CurrentTenant tenantAccessDeniedTenant `json:"current_tenant"`
|
||||
AffiliatedTenants []tenantAccessDeniedTenant `json:"affiliated_tenants,omitempty"`
|
||||
AllowedTenants []tenantAccessDeniedTenant `json:"allowed_tenants,omitempty"`
|
||||
}
|
||||
|
||||
type tenantAccessDeniedAccount struct {
|
||||
Email string `json:"email,omitempty"`
|
||||
}
|
||||
|
||||
type tenantAccessDeniedTenant struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Slug string `json:"slug,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Identifier string `json:"identifier,omitempty"`
|
||||
}
|
||||
|
||||
func tenantNotAllowedError(c *fiber.Ctx, details tenantAccessDeniedDetails) error {
|
||||
return response.ErrorWithDetails(
|
||||
c,
|
||||
fiber.StatusForbidden,
|
||||
"tenant_not_allowed",
|
||||
"허용되지 않은 테넌트입니다.",
|
||||
details,
|
||||
)
|
||||
}
|
||||
|
||||
func isClientTenantAccessAllowed(profile *domain.UserProfileResponse, client domain.HydraClient) bool {
|
||||
if profile == nil {
|
||||
return false
|
||||
}
|
||||
return clientTenantAccessAllowed(profile, client)
|
||||
}
|
||||
|
||||
func enforceClientTenantAccess(c *fiber.Ctx, tenantSvc service.TenantService, client domain.HydraClient, profile *domain.UserProfileResponse, resolveErr error) bool {
|
||||
if !clientTenantAccessRestricted(client.Metadata) {
|
||||
return false
|
||||
}
|
||||
details := buildTenantAccessDeniedDetails(c, tenantSvc, client, profile)
|
||||
if resolveErr != nil || profile == nil {
|
||||
_ = tenantNotAllowedError(c, details)
|
||||
return true
|
||||
}
|
||||
if !clientTenantAccessAllowedForSubtree(c, tenantSvc, profile, client) {
|
||||
_ = tenantNotAllowedError(c, details)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func buildTenantAccessDeniedDetails(c *fiber.Ctx, tenantSvc service.TenantService, client domain.HydraClient, profile *domain.UserProfileResponse) tenantAccessDeniedDetails {
|
||||
details := tenantAccessDeniedDetails{
|
||||
Account: tenantAccessDeniedAccount{Email: strings.TrimSpace(profileEmail(profile))},
|
||||
CurrentTenant: resolveCurrentTenantDetails(c, tenantSvc, profile),
|
||||
AffiliatedTenants: resolveAffiliatedTenantDetails(c, tenantSvc, profile),
|
||||
}
|
||||
|
||||
for _, identifier := range clientAllowedTenants(client.Metadata) {
|
||||
details.AllowedTenants = append(details.AllowedTenants, resolveAllowedTenantDetails(c, tenantSvc, identifier))
|
||||
}
|
||||
|
||||
return details
|
||||
}
|
||||
|
||||
func resolveAffiliatedTenantDetails(c *fiber.Ctx, tenantSvc service.TenantService, profile *domain.UserProfileResponse) []tenantAccessDeniedTenant {
|
||||
if profile == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{})
|
||||
out := make([]tenantAccessDeniedTenant, 0, len(profile.JoinedTenants)+1)
|
||||
appendTenant := func(tenant tenantAccessDeniedTenant) {
|
||||
key := strings.ToLower(firstNonEmptyString(tenant.ID, tenant.Slug, tenant.Identifier, tenant.Name))
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := seen[key]; ok {
|
||||
return
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, tenant)
|
||||
}
|
||||
|
||||
appendTenant(resolveCurrentTenantDetails(c, tenantSvc, profile))
|
||||
|
||||
for _, joined := range profile.JoinedTenants {
|
||||
appendTenant(tenantAccessDeniedTenant{
|
||||
ID: strings.TrimSpace(joined.ID),
|
||||
Slug: strings.TrimSpace(joined.Slug),
|
||||
Name: strings.TrimSpace(joined.Name),
|
||||
Identifier: firstNonEmptyString(strings.TrimSpace(joined.Slug), strings.TrimSpace(joined.ID)),
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func resolveCurrentTenantDetails(c *fiber.Ctx, tenantSvc service.TenantService, profile *domain.UserProfileResponse) tenantAccessDeniedTenant {
|
||||
if profile == nil {
|
||||
return tenantAccessDeniedTenant{}
|
||||
}
|
||||
|
||||
if profile.Tenant != nil {
|
||||
return tenantAccessDeniedTenant{
|
||||
ID: strings.TrimSpace(profile.Tenant.ID),
|
||||
Slug: strings.TrimSpace(profile.Tenant.Slug),
|
||||
Name: strings.TrimSpace(profile.Tenant.Name),
|
||||
Identifier: firstNonEmptyString(strings.TrimSpace(profile.Tenant.Slug), strings.TrimSpace(profile.Tenant.ID)),
|
||||
}
|
||||
}
|
||||
|
||||
if tenantSvc != nil {
|
||||
if profile.TenantID != nil && strings.TrimSpace(*profile.TenantID) != "" {
|
||||
if tenant, err := tenantSvc.GetTenant(c.Context(), strings.TrimSpace(*profile.TenantID)); err == nil && tenant != nil {
|
||||
return tenantAccessDeniedTenant{
|
||||
ID: strings.TrimSpace(tenant.ID),
|
||||
Slug: strings.TrimSpace(tenant.Slug),
|
||||
Name: strings.TrimSpace(tenant.Name),
|
||||
Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tenantAccessDeniedTenant{
|
||||
ID: strings.TrimSpace(pointerValue(profile.TenantID)),
|
||||
Identifier: strings.TrimSpace(pointerValue(profile.TenantID)),
|
||||
}
|
||||
}
|
||||
|
||||
func resolveAllowedTenantDetails(c *fiber.Ctx, tenantSvc service.TenantService, identifier string) tenantAccessDeniedTenant {
|
||||
identifier = strings.TrimSpace(identifier)
|
||||
if identifier == "" {
|
||||
return tenantAccessDeniedTenant{}
|
||||
}
|
||||
|
||||
if tenantSvc != nil {
|
||||
if tenant, err := tenantSvc.GetTenant(c.Context(), identifier); err == nil && tenant != nil {
|
||||
return tenantAccessDeniedTenant{
|
||||
ID: strings.TrimSpace(tenant.ID),
|
||||
Slug: strings.TrimSpace(tenant.Slug),
|
||||
Name: strings.TrimSpace(tenant.Name),
|
||||
Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID), identifier),
|
||||
}
|
||||
}
|
||||
if tenant, err := tenantSvc.GetTenantBySlug(c.Context(), identifier); err == nil && tenant != nil {
|
||||
return tenantAccessDeniedTenant{
|
||||
ID: strings.TrimSpace(tenant.ID),
|
||||
Slug: strings.TrimSpace(tenant.Slug),
|
||||
Name: strings.TrimSpace(tenant.Name),
|
||||
Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID), identifier),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tenantAccessDeniedTenant{Identifier: identifier}
|
||||
}
|
||||
|
||||
func profileEmail(profile *domain.UserProfileResponse) string {
|
||||
if profile == nil {
|
||||
return ""
|
||||
}
|
||||
return profile.Email
|
||||
}
|
||||
|
||||
func pointerValue(value *string) string {
|
||||
if value == nil {
|
||||
return ""
|
||||
}
|
||||
return *value
|
||||
}
|
||||
|
||||
func firstNonEmptyString(values ...string) string {
|
||||
for _, value := range values {
|
||||
if strings.TrimSpace(value) != "" {
|
||||
return strings.TrimSpace(value)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type clientStructuredScope struct {
|
||||
Name string `json:"name"`
|
||||
Mandatory bool `json:"mandatory"`
|
||||
Locked bool `json:"locked"`
|
||||
}
|
||||
|
||||
func mergeRequestedScopesWithClientRequirements(client domain.HydraClient, requested []string) []string {
|
||||
combined := make([]string, 0, len(requested)+2)
|
||||
combined = append(combined, requested...)
|
||||
combined = append(combined, requiredClientScopes(client)...)
|
||||
|
||||
return normalizeScopesInConsentOrder(combined)
|
||||
}
|
||||
|
||||
func normalizeScopesInConsentOrder(scopes []string) []string {
|
||||
combined := make([]string, 0, len(scopes))
|
||||
combined = append(combined, scopes...)
|
||||
|
||||
seen := make(map[string]struct{}, len(combined))
|
||||
out := make([]string, 0, len(combined))
|
||||
|
||||
appendIfPresent := func(scope string) {
|
||||
scope = strings.TrimSpace(scope)
|
||||
if scope == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := seen[scope]; ok {
|
||||
return
|
||||
}
|
||||
for _, candidate := range combined {
|
||||
if strings.TrimSpace(candidate) != scope {
|
||||
continue
|
||||
}
|
||||
seen[scope] = struct{}{}
|
||||
out = append(out, scope)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
appendIfPresent("openid")
|
||||
appendIfPresent("tenant")
|
||||
|
||||
for _, scope := range combined {
|
||||
scope = strings.TrimSpace(scope)
|
||||
if scope == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[scope]; ok {
|
||||
continue
|
||||
}
|
||||
seen[scope] = struct{}{}
|
||||
out = append(out, scope)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func requiredClientScopes(client domain.HydraClient) []string {
|
||||
required := make([]string, 0, 4)
|
||||
if clientTenantAccessRestricted(client.Metadata) {
|
||||
required = append(required, "tenant")
|
||||
}
|
||||
|
||||
if client.Metadata == nil {
|
||||
return normalizeScopesInConsentOrder(required)
|
||||
}
|
||||
|
||||
rawStructuredScopes, ok := client.Metadata["structured_scopes"]
|
||||
if !ok || rawStructuredScopes == nil {
|
||||
return normalizeScopesInConsentOrder(required)
|
||||
}
|
||||
|
||||
rawBytes, err := json.Marshal(rawStructuredScopes)
|
||||
if err != nil {
|
||||
return normalizeScopesInConsentOrder(required)
|
||||
}
|
||||
|
||||
var scopes []clientStructuredScope
|
||||
if err := json.Unmarshal(rawBytes, &scopes); err != nil {
|
||||
return normalizeScopesInConsentOrder(required)
|
||||
}
|
||||
|
||||
for _, scope := range scopes {
|
||||
name := strings.TrimSpace(scope.Name)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
if scope.Mandatory || scope.Locked {
|
||||
required = append(required, name)
|
||||
}
|
||||
}
|
||||
|
||||
return normalizeScopesInConsentOrder(required)
|
||||
}
|
||||
458
baron-sso/backend/internal/handler/client_tenant_access_test.go
Normal file
458
baron-sso/backend/internal/handler/client_tenant_access_test.go
Normal file
@@ -0,0 +1,458 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func TestCreateClient_NormalizesTenantAccessMetadata(t *testing.T) {
|
||||
var captured domain.HydraClient
|
||||
ownerTenantID := "tenant-owner"
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.Method == http.MethodPost && r.URL.Path == "/clients" {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, json.Unmarshal(body, &captured))
|
||||
return httpJSONAny(r, http.StatusCreated, map[string]any{
|
||||
"client_id": captured.ClientID,
|
||||
"client_name": captured.ClientName,
|
||||
"redirect_uris": captured.RedirectURIs,
|
||||
"grant_types": captured.GrantTypes,
|
||||
"response_types": captured.ResponseTypes,
|
||||
"scope": captured.Scope,
|
||||
"token_endpoint_auth_method": captured.TokenEndpointAuthMethod,
|
||||
"metadata": captured.Metadata,
|
||||
}), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusNotFound, nil), nil
|
||||
})
|
||||
|
||||
h := &DevHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: transport},
|
||||
},
|
||||
Keto: new(devMockKetoService),
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{
|
||||
ID: "user-1",
|
||||
Role: domain.RoleSuperAdmin,
|
||||
TenantID: &ownerTenantID,
|
||||
})
|
||||
return c.Next()
|
||||
})
|
||||
app.Post("/api/v1/dev/clients", h.CreateClient)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"id": "client-tenant",
|
||||
"name": "Tenant Client",
|
||||
"type": "pkce",
|
||||
"redirectUris": []string{"https://rp.example.com/cb"},
|
||||
"metadata": map[string]any{
|
||||
"tenant_access_restricted": true,
|
||||
"allowed_tenants": []string{"tenant-b", "tenant-a", "tenant-b"},
|
||||
},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/dev/clients", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusCreated, resp.StatusCode)
|
||||
assert.True(t, clientTenantAccessRestricted(captured.Metadata))
|
||||
assert.Equal(t, []string{"tenant-a", "tenant-b", "tenant-owner"}, clientAllowedTenants(captured.Metadata))
|
||||
}
|
||||
|
||||
func TestCreateClient_RejectsTenantAccessWithoutAllowedTenants(t *testing.T) {
|
||||
hydraCalled := false
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.Method == http.MethodPost && r.URL.Path == "/clients" {
|
||||
hydraCalled = true
|
||||
}
|
||||
return httpJSONAny(r, http.StatusNotFound, nil), nil
|
||||
})
|
||||
|
||||
h := &DevHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: transport},
|
||||
},
|
||||
Keto: new(devMockKetoService),
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "user-1", Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Post("/api/v1/dev/clients", h.CreateClient)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"id": "client-tenant",
|
||||
"name": "Tenant Client",
|
||||
"type": "pkce",
|
||||
"redirectUris": []string{"https://rp.example.com/cb"},
|
||||
"metadata": map[string]any{
|
||||
"tenant_access_restricted": true,
|
||||
},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/dev/clients", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
assert.False(t, hydraCalled)
|
||||
}
|
||||
|
||||
func TestMergeRequestedScopesWithClientRequirements_AddsTenantScope(t *testing.T) {
|
||||
client := domain.HydraClient{
|
||||
Metadata: map[string]any{
|
||||
"tenant_access_restricted": true,
|
||||
"structured_scopes": []map[string]any{
|
||||
{"name": "openid", "mandatory": true},
|
||||
{"name": "tenant", "mandatory": true, "locked": true},
|
||||
{"name": "profile", "mandatory": false},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
merged := mergeRequestedScopesWithClientRequirements(client, []string{"openid", "profile"})
|
||||
assert.Equal(t, []string{"openid", "tenant", "profile"}, merged)
|
||||
}
|
||||
|
||||
func TestGetConsentRequest_DeniesTenantAccess(t *testing.T) {
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
switch {
|
||||
case r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-tenant":
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-tenant",
|
||||
"requested_scope": []string{"openid", "profile"},
|
||||
"skip": false,
|
||||
"subject": "user-123",
|
||||
"client": map[string]any{
|
||||
"client_id": "client-tenant",
|
||||
"metadata": map[string]any{
|
||||
"tenant_access_restricted": true,
|
||||
"allowed_tenants": []string{"tenant-b"},
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
case r.URL.Host == "kratos.test" && r.URL.Path == "/sessions/whoami":
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@test.com",
|
||||
"tenant_id": "tenant-a",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
default:
|
||||
return httpJSONAny(r, http.StatusNotFound, nil), nil
|
||||
}
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/auth/consent", h.GetConsentRequest)
|
||||
|
||||
t.Setenv("APP_ENV", "dev")
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-tenant", nil)
|
||||
req.Header.Set("X-Mock-Role", "user")
|
||||
req.Header.Set("X-Tenant-ID", "tenant-a")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
|
||||
var body map[string]any
|
||||
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
assert.Equal(t, "tenant_not_allowed", body["code"])
|
||||
details, ok := body["details"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
account, ok := details["account"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.NotEmpty(t, account["email"])
|
||||
currentTenant, ok := details["current_tenant"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.NotEmpty(t, currentTenant["identifier"])
|
||||
}
|
||||
|
||||
func TestGetConsentRequest_DeniesRestrictedClientWhenProfileResolutionFails(t *testing.T) {
|
||||
acceptCalled := false
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
switch {
|
||||
case r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-profile-missing":
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-profile-missing",
|
||||
"requested_scope": []string{"openid", "profile"},
|
||||
"skip": false,
|
||||
"subject": "user-123",
|
||||
"client": map[string]any{
|
||||
"client_id": "client-tenant",
|
||||
"metadata": map[string]any{
|
||||
"tenant_access_restricted": true,
|
||||
"allowed_tenants": []string{"tenant-b"},
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
case r.URL.Path == "/oauth2/auth/requests/consent/accept":
|
||||
acceptCalled = true
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"redirect_to": "http://rp/cb",
|
||||
}), nil
|
||||
default:
|
||||
return httpJSONAny(r, http.StatusNotFound, nil), nil
|
||||
}
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
KratosAdmin: func() service.KratosAdminService {
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
|
||||
ID: "user-123",
|
||||
Traits: map[string]any{
|
||||
"email": "user@test.com",
|
||||
"tenant_id": "tenant-a",
|
||||
"companyCode": "tenant-a",
|
||||
},
|
||||
}, nil).Once()
|
||||
return mockKratos
|
||||
}(),
|
||||
TenantService: func() service.TenantService {
|
||||
tenantSvc := new(MockTenantService)
|
||||
tenantSvc.On("GetTenant", mock.Anything, "tenant-a").Return(&domain.Tenant{
|
||||
ID: "tenant-a",
|
||||
Slug: "tenant-a",
|
||||
Name: "Tenant A",
|
||||
}, nil)
|
||||
tenantSvc.On("GetTenant", mock.Anything, "tenant-c").Return(&domain.Tenant{
|
||||
ID: "tenant-c",
|
||||
Slug: "tenant-c",
|
||||
Name: "Tenant C",
|
||||
}, nil)
|
||||
tenantSvc.On("ListJoinedTenants", mock.Anything, "user-123").Return([]domain.Tenant{
|
||||
{ID: "tenant-a", Slug: "tenant-a", Name: "Tenant A"},
|
||||
{ID: "tenant-c", Slug: "tenant-c", Name: "Tenant C"},
|
||||
}, nil).Once()
|
||||
tenantSvc.On("GetTenant", mock.Anything, "tenant-b").Return(nil, assert.AnError)
|
||||
tenantSvc.On("GetTenantBySlug", mock.Anything, "tenant-b").Return(&domain.Tenant{
|
||||
ID: "tenant-b-id",
|
||||
Slug: "tenant-b",
|
||||
Name: "Tenant B",
|
||||
}, nil)
|
||||
return tenantSvc
|
||||
}(),
|
||||
ConsentRepo: &mockConsentRepo{
|
||||
consents: []domain.ClientConsent{
|
||||
{
|
||||
ClientID: "client-tenant",
|
||||
Subject: "user-123",
|
||||
GrantedScopes: []string{"openid", "profile"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/auth/consent", h.GetConsentRequest)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-profile-missing", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=invalid-session")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
assert.False(t, acceptCalled)
|
||||
|
||||
var body map[string]any
|
||||
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
assert.Equal(t, "tenant_not_allowed", body["code"])
|
||||
|
||||
details, ok := body["details"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
account, ok := details["account"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "user@test.com", account["email"])
|
||||
currentTenant, ok := details["current_tenant"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "Tenant A", currentTenant["name"])
|
||||
affiliatedTenants, ok := details["affiliated_tenants"].([]any)
|
||||
assert.True(t, ok)
|
||||
assert.Len(t, affiliatedTenants, 2)
|
||||
}
|
||||
|
||||
func TestAcceptOidcLoginRequest_DeniesTenantAccess(t *testing.T) {
|
||||
app := fiber.New()
|
||||
app.Get("/deny", func(c *fiber.Ctx) error {
|
||||
tenantID := "tenant-a"
|
||||
profile := &domain.UserProfileResponse{
|
||||
ID: "user-123",
|
||||
Role: domain.RoleUser,
|
||||
Email: "user@test.com",
|
||||
TenantID: &tenantID,
|
||||
CompanyCode: "tenant-a",
|
||||
JoinedTenants: []domain.Tenant{
|
||||
{ID: "tenant-a", Slug: "tenant-a", Name: "Tenant A"},
|
||||
{ID: "tenant-c", Slug: "tenant-c", Name: "Tenant C"},
|
||||
},
|
||||
}
|
||||
client := domain.HydraClient{
|
||||
ClientID: "client-tenant",
|
||||
Metadata: map[string]any{
|
||||
"tenant_access_restricted": true,
|
||||
"allowed_tenants": []string{"tenant-b"},
|
||||
},
|
||||
}
|
||||
tenantSvc := new(MockTenantService)
|
||||
tenantSvc.On("GetTenant", mock.Anything, "tenant-a").Return(&domain.Tenant{
|
||||
ID: "tenant-a",
|
||||
Slug: "tenant-a",
|
||||
Name: "Tenant A",
|
||||
}, nil)
|
||||
tenantSvc.On("GetTenant", mock.Anything, "tenant-c").Return(&domain.Tenant{
|
||||
ID: "tenant-c",
|
||||
Slug: "tenant-c",
|
||||
Name: "Tenant C",
|
||||
}, nil)
|
||||
tenantSvc.On("GetTenant", mock.Anything, "tenant-b").Return(nil, assert.AnError)
|
||||
tenantSvc.On("GetTenantBySlug", mock.Anything, "tenant-b").Return(&domain.Tenant{
|
||||
ID: "tenant-b-id",
|
||||
Slug: "tenant-b",
|
||||
Name: "Tenant B",
|
||||
}, nil)
|
||||
enforceClientTenantAccess(c, tenantSvc, client, profile, nil)
|
||||
return nil
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/deny", nil)
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
|
||||
var body map[string]any
|
||||
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
assert.Equal(t, "tenant_not_allowed", body["code"])
|
||||
|
||||
details, ok := body["details"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
|
||||
account, ok := details["account"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "user@test.com", account["email"])
|
||||
|
||||
currentTenant, ok := details["current_tenant"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "Tenant A", currentTenant["name"])
|
||||
affiliatedTenants, ok := details["affiliated_tenants"].([]any)
|
||||
assert.True(t, ok)
|
||||
assert.Len(t, affiliatedTenants, 2)
|
||||
|
||||
allowedTenants, ok := details["allowed_tenants"].([]any)
|
||||
assert.True(t, ok)
|
||||
assert.Len(t, allowedTenants, 1)
|
||||
allowedTenant, ok := allowedTenants[0].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "Tenant B", allowedTenant["name"])
|
||||
}
|
||||
|
||||
func TestAcceptOidcLoginRequest_AllowsRestrictedClientForHanmacFamilyDescendant(t *testing.T) {
|
||||
app := fiber.New()
|
||||
app.Get("/allow-descendant", func(c *fiber.Ctx) error {
|
||||
hanmacFamilyID := "hanmac-family-id"
|
||||
samanID := "saman-id"
|
||||
profile := &domain.UserProfileResponse{
|
||||
ID: "user-123",
|
||||
Role: domain.RoleUser,
|
||||
Email: "user@samaneng.com",
|
||||
TenantID: &samanID,
|
||||
Tenant: &domain.Tenant{
|
||||
ID: samanID,
|
||||
Slug: "saman",
|
||||
Name: "삼안",
|
||||
ParentID: &hanmacFamilyID,
|
||||
},
|
||||
JoinedTenants: []domain.Tenant{
|
||||
{
|
||||
ID: samanID,
|
||||
Slug: "saman",
|
||||
Name: "삼안",
|
||||
ParentID: &hanmacFamilyID,
|
||||
},
|
||||
},
|
||||
}
|
||||
client := domain.HydraClient{
|
||||
ClientID: "orgfront",
|
||||
Metadata: map[string]any{
|
||||
"tenant_access_restricted": true,
|
||||
"allowed_tenants": []string{"hanmac-family"},
|
||||
},
|
||||
}
|
||||
tenantSvc := new(MockTenantService)
|
||||
tenantSvc.On("GetTenant", mock.Anything, "hanmac-family").Return(nil, assert.AnError).Maybe()
|
||||
tenantSvc.On("GetTenantBySlug", mock.Anything, "hanmac-family").Return(&domain.Tenant{
|
||||
ID: hanmacFamilyID,
|
||||
Slug: "hanmac-family",
|
||||
Name: "한맥가족",
|
||||
}, nil).Maybe()
|
||||
tenantSvc.On("GetTenant", mock.Anything, samanID).Return(&domain.Tenant{
|
||||
ID: samanID,
|
||||
Slug: "saman",
|
||||
Name: "삼안",
|
||||
ParentID: &hanmacFamilyID,
|
||||
}, nil).Maybe()
|
||||
tenantSvc.On("GetTenant", mock.Anything, hanmacFamilyID).Return(&domain.Tenant{
|
||||
ID: hanmacFamilyID,
|
||||
Slug: "hanmac-family",
|
||||
Name: "한맥가족",
|
||||
}, nil).Maybe()
|
||||
|
||||
blocked := enforceClientTenantAccess(c, tenantSvc, client, profile, nil)
|
||||
assert.False(t, blocked)
|
||||
return c.SendStatus(http.StatusNoContent)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/allow-descendant", nil)
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusNoContent, resp.StatusCode)
|
||||
}
|
||||
303
baron-sso/backend/internal/handler/common_test.go
Normal file
303
baron-sso/backend/internal/handler/common_test.go
Normal file
@@ -0,0 +1,303 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
"time"
|
||||
)
|
||||
|
||||
// --- Mock IDP Provider ---
|
||||
|
||||
type mockIdpProvider struct {
|
||||
userExists bool
|
||||
name string
|
||||
signInInfo *domain.AuthInfo
|
||||
issueSession *domain.AuthInfo
|
||||
verifyCodeInfo *domain.AuthInfo
|
||||
err error
|
||||
initiateLinkErr error
|
||||
updateCalled bool
|
||||
updateCallCount int
|
||||
updatedLoginID string
|
||||
updatedPassword string
|
||||
}
|
||||
|
||||
func (m *mockIdpProvider) Name() string {
|
||||
if m.name != "" {
|
||||
return m.name
|
||||
}
|
||||
return "mock-idp"
|
||||
}
|
||||
|
||||
func (m *mockIdpProvider) GetMetadata() (*domain.IDPMetadata, error) { return nil, m.err }
|
||||
func (m *mockIdpProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
|
||||
return "mock-user-id", m.err
|
||||
}
|
||||
|
||||
func (m *mockIdpProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
|
||||
return m.signInInfo, m.err
|
||||
}
|
||||
func (m *mockIdpProvider) UserExists(loginID string) (bool, error) { return m.userExists, m.err }
|
||||
func (m *mockIdpProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
|
||||
if m.issueSession != nil {
|
||||
return m.issueSession, m.err
|
||||
}
|
||||
return &domain.AuthInfo{
|
||||
SessionToken: &domain.Token{JWT: "valid-jwt", SessionID: "valid-sid"},
|
||||
}, m.err
|
||||
}
|
||||
|
||||
func (m *mockIdpProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
|
||||
if m.initiateLinkErr != nil {
|
||||
return nil, m.initiateLinkErr
|
||||
}
|
||||
return &domain.LinkLoginInit{FlowID: "mock-flow-id", Mode: "code"}, m.err
|
||||
}
|
||||
|
||||
func (m *mockIdpProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
|
||||
return m.verifyCodeInfo, m.err
|
||||
}
|
||||
func (m *mockIdpProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) { return nil, m.err }
|
||||
func (m *mockIdpProvider) InitiatePasswordReset(loginID, redirectUrl string) error { return m.err }
|
||||
func (m *mockIdpProvider) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
|
||||
return nil, m.err
|
||||
}
|
||||
|
||||
func (m *mockIdpProvider) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
|
||||
m.updateCalled = true
|
||||
m.updateCallCount++
|
||||
m.updatedLoginID = loginID
|
||||
m.updatedPassword = newPassword
|
||||
return m.err
|
||||
}
|
||||
|
||||
// --- Mock Audit Repository ---
|
||||
|
||||
type mockAuditRepo struct {
|
||||
logs []domain.AuditLog
|
||||
}
|
||||
|
||||
func (m *mockAuditRepo) Create(log *domain.AuditLog) error {
|
||||
m.logs = append(m.logs, *log)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAuditRepo) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor, tenantID string) ([]domain.AuditLog, error) {
|
||||
return m.logs, nil
|
||||
}
|
||||
|
||||
func (m *mockAuditRepo) FindByUserAndEvents(ctx context.Context, userID string, eventTypes []string, limit int) ([]domain.AuditLog, error) {
|
||||
var results []domain.AuditLog
|
||||
for _, log := range m.logs {
|
||||
if log.UserID == userID {
|
||||
if slices.Contains(eventTypes, log.EventType) {
|
||||
results = append(results, log)
|
||||
}
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (m *mockAuditRepo) CountFailuresSince(ctx context.Context, since time.Time, tenantID string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockAuditRepo) CountEventsSince(ctx context.Context, since time.Time) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockAuditRepo) CountActiveSessionsSince(ctx context.Context, since time.Time, tenantID string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockAuditRepo) Ping(ctx context.Context) error { return nil }
|
||||
|
||||
type mockRPUsageEventSink struct {
|
||||
events []domain.RPUsageEvent
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockRPUsageEventSink) EmitRPUsageEvent(ctx context.Context, event domain.RPUsageEvent) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
m.events = append(m.events, event)
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockOathkeeperRepo struct {
|
||||
logs []domain.OathkeeperAccessLog
|
||||
}
|
||||
|
||||
func (m *mockOathkeeperRepo) FindPageBySubject(ctx context.Context, subject string, limit int, cursor *domain.AuditCursor) ([]domain.OathkeeperAccessLog, error) {
|
||||
if subject == "" {
|
||||
return m.logs, nil
|
||||
}
|
||||
results := make([]domain.OathkeeperAccessLog, 0, len(m.logs))
|
||||
for _, log := range m.logs {
|
||||
if log.Subject == subject {
|
||||
results = append(results, log)
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (m *mockOathkeeperRepo) Ping(ctx context.Context) error { return nil }
|
||||
|
||||
// --- Mock Consent Repository ---
|
||||
|
||||
type mockConsentRepo struct {
|
||||
consents []domain.ClientConsent
|
||||
}
|
||||
|
||||
func (m *mockConsentRepo) Upsert(ctx context.Context, consent *domain.ClientConsent) error {
|
||||
m.consents = append(m.consents, *consent)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConsentRepo) ListBySubject(ctx context.Context, subject string) ([]domain.ClientConsent, error) {
|
||||
var results []domain.ClientConsent
|
||||
for _, c := range m.consents {
|
||||
if c.Subject == subject {
|
||||
results = append(results, c)
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (m *mockConsentRepo) ListSubjectsByClient(ctx context.Context, clientID string) ([]string, error) {
|
||||
seen := map[string]struct{}{}
|
||||
subjects := make([]string, 0, len(m.consents))
|
||||
for _, consent := range m.consents {
|
||||
if consent.ClientID != clientID {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[consent.Subject]; ok {
|
||||
continue
|
||||
}
|
||||
seen[consent.Subject] = struct{}{}
|
||||
subjects = append(subjects, consent.Subject)
|
||||
}
|
||||
return subjects, nil
|
||||
}
|
||||
|
||||
func (m *mockConsentRepo) Find(ctx context.Context, clientID, subject string) (*domain.ClientConsent, error) {
|
||||
for _, consent := range m.consents {
|
||||
if consent.ClientID == clientID && consent.Subject == subject {
|
||||
found := consent
|
||||
return &found, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockConsentRepo) Delete(ctx context.Context, subject, clientID string) error {
|
||||
filtered := m.consents[:0]
|
||||
for _, consent := range m.consents {
|
||||
if consent.Subject == subject && (clientID == "" || consent.ClientID == clientID) {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, consent)
|
||||
}
|
||||
m.consents = filtered
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConsentRepo) DeleteByClient(ctx context.Context, clientID string) error {
|
||||
filtered := m.consents[:0]
|
||||
for _, consent := range m.consents {
|
||||
if consent.ClientID != clientID {
|
||||
filtered = append(filtered, consent)
|
||||
}
|
||||
}
|
||||
m.consents = filtered
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConsentRepo) List(ctx context.Context, clientID string, limit, offset int) ([]domain.ClientConsentWithTenantInfo, int64, error) {
|
||||
results := make([]domain.ClientConsentWithTenantInfo, 0, len(m.consents))
|
||||
for _, consent := range m.consents {
|
||||
if consent.ClientID == clientID {
|
||||
results = append(results, domain.ClientConsentWithTenantInfo{ClientConsent: consent})
|
||||
}
|
||||
}
|
||||
return results, int64(len(results)), nil
|
||||
}
|
||||
|
||||
func (m *mockConsentRepo) ListByTenant(ctx context.Context, clientID, tenantID string, limit, offset int) ([]domain.ClientConsentWithTenantInfo, int64, error) {
|
||||
results := make([]domain.ClientConsentWithTenantInfo, 0, len(m.consents))
|
||||
for _, consent := range m.consents {
|
||||
if consent.ClientID == clientID {
|
||||
results = append(results, domain.ClientConsentWithTenantInfo{
|
||||
ClientConsent: consent,
|
||||
TenantID: tenantID,
|
||||
})
|
||||
}
|
||||
}
|
||||
return results, int64(len(results)), nil
|
||||
}
|
||||
|
||||
// --- Mock Secret Repository ---
|
||||
|
||||
type mockSecretRepo struct {
|
||||
secrets map[string]string
|
||||
}
|
||||
|
||||
func (m *mockSecretRepo) Upsert(ctx context.Context, clientID, secret string) error {
|
||||
if m.secrets == nil {
|
||||
m.secrets = make(map[string]string)
|
||||
}
|
||||
m.secrets[clientID] = secret
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSecretRepo) GetByID(ctx context.Context, clientID string) (string, error) {
|
||||
return m.secrets[clientID], nil
|
||||
}
|
||||
|
||||
func (m *mockSecretRepo) Delete(ctx context.Context, clientID string) error {
|
||||
delete(m.secrets, clientID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- HTTP Mock Helpers ---
|
||||
|
||||
type roundTripFunc func(req *http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func setDefaultHTTPClientForTest(t interface{ Cleanup(func()) }, transport http.RoundTripper) {
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = &http.Client{Transport: transport}
|
||||
t.Cleanup(func() {
|
||||
http.DefaultClient = origDefault
|
||||
})
|
||||
}
|
||||
|
||||
func httpResponse(r *http.Request, code int, body string) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: code,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(bytes.NewBufferString(body)),
|
||||
Request: r,
|
||||
}
|
||||
}
|
||||
|
||||
func httpJSONAny(r *http.Request, code int, data any) *http.Response {
|
||||
body, _ := json.Marshal(data)
|
||||
return &http.Response{
|
||||
StatusCode: code,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
Body: io.NopCloser(bytes.NewBuffer(body)),
|
||||
Request: r,
|
||||
}
|
||||
}
|
||||
4165
baron-sso/backend/internal/handler/dev_handler.go
Normal file
4165
baron-sso/backend/internal/handler/dev_handler.go
Normal file
File diff suppressed because it is too large
Load Diff
242
baron-sso/backend/internal/handler/dev_handler_isolation_test.go
Normal file
242
baron-sso/backend/internal/handler/dev_handler_isolation_test.go
Normal file
@@ -0,0 +1,242 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func TestDevHandler_Isolation(t *testing.T) {
|
||||
createHandler := func(mockKeto *devMockKetoService) *DevHandler {
|
||||
return &DevHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{
|
||||
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.Method == http.MethodGet && r.URL.Path == "/clients" {
|
||||
return httpJSONAny(r, http.StatusOK, []map[string]any{
|
||||
{
|
||||
"client_id": "client-tenant-a",
|
||||
"client_name": "App Tenant A",
|
||||
"token_endpoint_auth_method": "none", // PKCE
|
||||
"metadata": map[string]any{"tenant_id": "tenant-a"},
|
||||
},
|
||||
{
|
||||
"client_id": "client-tenant-b",
|
||||
"client_name": "App Tenant B",
|
||||
"token_endpoint_auth_method": "none", // PKCE
|
||||
"metadata": map[string]any{"tenant_id": "tenant-b"},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if (r.Method == http.MethodGet || r.Method == http.MethodPut) && strings.HasPrefix(r.URL.Path, "/clients/") {
|
||||
id := strings.TrimPrefix(r.URL.Path, "/clients/")
|
||||
tenantID := "tenant-a"
|
||||
if id == "client-tenant-b" {
|
||||
tenantID = "tenant-b"
|
||||
}
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": id,
|
||||
"client_name": "App " + id,
|
||||
"token_endpoint_auth_method": "none",
|
||||
"metadata": map[string]any{"tenant_id": tenantID},
|
||||
}), nil
|
||||
}
|
||||
if r.Method == http.MethodPost && r.URL.Path == "/clients" {
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
return httpJSONAny(r, http.StatusCreated, body), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusNotFound, nil), nil
|
||||
}),
|
||||
},
|
||||
},
|
||||
Keto: mockKeto,
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("Local bypass should be removed", func(t *testing.T) {
|
||||
mockKeto := new(devMockKetoService)
|
||||
h := createHandler(mockKeto)
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/dev/clients", h.ListClients)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients", nil)
|
||||
req.Header.Set("Origin", "http://localhost:5174")
|
||||
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("ListClients should show all for SuperAdmin", func(t *testing.T) {
|
||||
mockKeto := new(devMockKetoService)
|
||||
h := createHandler(mockKeto)
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{
|
||||
ID: "super-user",
|
||||
Role: domain.RoleSuperAdmin,
|
||||
})
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/api/v1/dev/clients", h.ListClients)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients", nil)
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
var res struct {
|
||||
Items []clientSummary `json:"items"`
|
||||
}
|
||||
json.NewDecoder(resp.Body).Decode(&res)
|
||||
|
||||
// Should see both clients
|
||||
assert.Equal(t, 2, len(res.Items))
|
||||
})
|
||||
|
||||
t.Run("ListClients should filter by permit for non-SuperAdmin", func(t *testing.T) {
|
||||
mockKeto := new(devMockKetoService)
|
||||
h := createHandler(mockKeto)
|
||||
app := fiber.New()
|
||||
tenantA := "tenant-a"
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{
|
||||
ID: "user-a",
|
||||
Role: domain.RoleUser,
|
||||
TenantID: &tenantA,
|
||||
})
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/api/v1/dev/clients", h.ListClients)
|
||||
|
||||
// Explicit permission for private client check bypass
|
||||
mockKeto.On("CheckPermission", mock.Anything, "User:user-a", "System", "global", "manage_all").Return(true, nil).Maybe()
|
||||
// Mock permit for the specific client
|
||||
mockKeto.On("CheckPermission", mock.Anything, "User:user-a", "RelyingParty", "client-tenant-a", "view").Return(true, nil).Maybe()
|
||||
// Deny for other clients
|
||||
mockKeto.On("CheckPermission", mock.Anything, "User:user-a", "RelyingParty", "client-tenant-b", "view").Return(false, nil).Maybe()
|
||||
|
||||
mockKeto.On("ListRelations", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]service.RelationTuple{}, nil).Maybe()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients", nil)
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
var res struct {
|
||||
Items []clientSummary `json:"items"`
|
||||
}
|
||||
json.NewDecoder(resp.Body).Decode(&res)
|
||||
|
||||
// Should only see client-tenant-a (tenant permit)
|
||||
assert.Equal(t, 1, len(res.Items))
|
||||
assert.Equal(t, "client-tenant-a", res.Items[0].ID)
|
||||
})
|
||||
|
||||
t.Run("Tenant member should see empty list from DevFront clients if no relation", func(t *testing.T) {
|
||||
mockKeto := new(devMockKetoService)
|
||||
h := createHandler(mockKeto)
|
||||
app := fiber.New()
|
||||
tenantA := "tenant-a"
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{
|
||||
ID: "user-member",
|
||||
Role: domain.RoleUser,
|
||||
TenantID: &tenantA,
|
||||
})
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/api/v1/dev/clients", h.ListClients)
|
||||
|
||||
// Deny all by default
|
||||
mockKeto.On("CheckPermission", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(false, nil).Maybe()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients", nil)
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var res struct {
|
||||
Items []clientSummary `json:"items"`
|
||||
}
|
||||
json.NewDecoder(resp.Body).Decode(&res)
|
||||
// Empty list because we didn't mock any specific 'view' permissions for this user
|
||||
assert.Equal(t, 0, len(res.Items))
|
||||
})
|
||||
|
||||
t.Run("GetClient should enforce isolation for non-SuperAdmin", func(t *testing.T) {
|
||||
mockKeto := new(devMockKetoService)
|
||||
h := createHandler(mockKeto)
|
||||
app := fiber.New()
|
||||
tenantA := "tenant-a"
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{
|
||||
ID: "user-a",
|
||||
Role: domain.RoleUser,
|
||||
TenantID: &tenantA,
|
||||
})
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/api/v1/dev/clients/:id", h.GetClient)
|
||||
|
||||
// Case 1: Same tenant BUT no permit (Normal users need permit now)
|
||||
mockKeto.On("CheckPermission", mock.Anything, "User:user-a", "RelyingParty", "client-tenant-a", "view").Return(false, nil).Once()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-tenant-a", nil)
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
|
||||
// Case 2: Same tenant WITH permit
|
||||
mockKeto.On("CheckPermission", mock.Anything, "User:user-a", "RelyingParty", "client-tenant-a", "view").Return(true, nil).Maybe()
|
||||
mockKeto.On("ListRelations", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]service.RelationTuple{}, nil).Maybe()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-tenant-a", nil)
|
||||
resp, _ = app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Case 3: Different tenant
|
||||
mockKeto.On("CheckPermission", mock.Anything, "User:user-a", "RelyingParty", "client-tenant-b", "view").Return(false, nil).Maybe()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-tenant-b", nil)
|
||||
resp, _ = app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("CreateClient should record user_id and tenant_id", func(t *testing.T) {
|
||||
mockKeto := new(devMockKetoService)
|
||||
h := createHandler(mockKeto)
|
||||
app := fiber.New()
|
||||
tenantA := "tenant-a"
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{
|
||||
ID: "user-a",
|
||||
Role: domain.RoleSuperAdmin, // Bypass for creation permission
|
||||
TenantID: &tenantA,
|
||||
})
|
||||
return c.Next()
|
||||
})
|
||||
app.Post("/api/v1/dev/clients", h.CreateClient)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"client_name": "New App",
|
||||
"type": "pkce",
|
||||
"redirectUris": []string{"http://localhost/cb"},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/dev/clients", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Tenant-ID", "tenant-a")
|
||||
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusCreated, resp.StatusCode)
|
||||
|
||||
var res clientDetailResponse
|
||||
json.NewDecoder(resp.Body).Decode(&res)
|
||||
|
||||
assert.Equal(t, "tenant-a", res.Client.Metadata["tenant_id"])
|
||||
assert.Equal(t, "user-a", res.Client.Metadata["user_id"])
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,278 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type devMockRPUserMetadataRepo struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *devMockRPUserMetadataRepo) Get(ctx context.Context, clientID, userID string) (*domain.RPUserMetadata, error) {
|
||||
args := m.Called(ctx, clientID, userID)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.RPUserMetadata), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *devMockRPUserMetadataRepo) Upsert(ctx context.Context, metadata *domain.RPUserMetadata) error {
|
||||
args := m.Called(ctx, metadata)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestDevHandler_RPUserMetadataRoundTrip(t *testing.T) {
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/clients/client-1" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": "client-1",
|
||||
"client_name": "Client One",
|
||||
"metadata": map[string]any{
|
||||
"tenant_id": "tenant-1",
|
||||
"id_token_claims": []map[string]any{
|
||||
{
|
||||
"namespace": "rp_claims",
|
||||
"key": "approvalLevel",
|
||||
"valueType": "text",
|
||||
"value": "A",
|
||||
},
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusNotFound, nil), nil
|
||||
})
|
||||
|
||||
repo := new(devMockRPUserMetadataRepo)
|
||||
repo.On("Upsert", mock.Anything, mock.MatchedBy(func(row *domain.RPUserMetadata) bool {
|
||||
return row.ClientID == "client-1" &&
|
||||
row.UserID == "user-1" &&
|
||||
row.Metadata["approvalLevel"] == "A" &&
|
||||
row.Metadata["approvalLevel_permissions"].(map[string]any)["readPermission"] == "admin_only" &&
|
||||
row.Metadata["approvalLevel_permissions"].(map[string]any)["writePermission"] == "user_and_admin"
|
||||
})).Return(nil).Once()
|
||||
repo.On("Get", mock.Anything, "client-1", "user-1").Return(&domain.RPUserMetadata{
|
||||
ClientID: "client-1",
|
||||
UserID: "user-1",
|
||||
Metadata: domain.JSONMap{"approvalLevel": "A"},
|
||||
}, nil).Once()
|
||||
|
||||
h := &DevHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: transport},
|
||||
},
|
||||
RPUserMetadataRepo: repo,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "admin", Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Put("/api/v1/dev/clients/:id/users/:userId/metadata", h.UpsertRPUserMetadata)
|
||||
app.Get("/api/v1/dev/clients/:id/users/:userId/metadata", h.GetRPUserMetadata)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"metadata": map[string]any{
|
||||
"approvalLevel": "A",
|
||||
"approvalLevel_permissions": map[string]any{
|
||||
"writePermission": "user_and_admin",
|
||||
},
|
||||
},
|
||||
})
|
||||
putReq := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-1/users/user-1/metadata", bytes.NewReader(body))
|
||||
putReq.Header.Set("Content-Type", "application/json")
|
||||
putResp, _ := app.Test(putReq, -1)
|
||||
assert.Equal(t, http.StatusOK, putResp.StatusCode)
|
||||
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-1/users/user-1/metadata", nil)
|
||||
getResp, _ := app.Test(getReq, -1)
|
||||
assert.Equal(t, http.StatusOK, getResp.StatusCode)
|
||||
|
||||
var got map[string]any
|
||||
assert.NoError(t, json.NewDecoder(getResp.Body).Decode(&got))
|
||||
assert.Equal(t, "client-1", got["clientId"])
|
||||
assert.Equal(t, "user-1", got["userId"])
|
||||
assert.Equal(t, "A", got["metadata"].(map[string]any)["approvalLevel"])
|
||||
repo.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestDevHandler_RPUserMetadataMirrorsToKratosTraits(t *testing.T) {
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/clients/client-1" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": "client-1",
|
||||
"client_name": "Client One",
|
||||
"metadata": map[string]any{
|
||||
"tenant_id": "tenant-1",
|
||||
"id_token_claims": []map[string]any{
|
||||
{
|
||||
"namespace": "rp_claims",
|
||||
"key": "approvalLevel",
|
||||
"valueType": "text",
|
||||
"value": "A",
|
||||
"readPermission": "user_and_admin",
|
||||
"writePermission": "admin_only",
|
||||
},
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusNotFound, nil), nil
|
||||
})
|
||||
|
||||
repo := new(devMockRPUserMetadataRepo)
|
||||
repo.On("Upsert", mock.Anything, mock.AnythingOfType("*domain.RPUserMetadata")).Return(nil).Once()
|
||||
kratos := new(MockKratosAdmin)
|
||||
kratos.On("GetIdentity", mock.Anything, "user-1").Return(&service.KratosIdentity{
|
||||
ID: "user-1",
|
||||
State: "active",
|
||||
Traits: map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "User One",
|
||||
},
|
||||
}, nil).Once()
|
||||
var capturedTraits map[string]any
|
||||
kratos.On("UpdateIdentity", mock.Anything, "user-1", mock.Anything, "active").Run(func(args mock.Arguments) {
|
||||
capturedTraits = args.Get(2).(map[string]any)
|
||||
}).Return(&service.KratosIdentity{ID: "user-1", State: "active", Traits: map[string]any{}}, nil).Once()
|
||||
|
||||
h := &DevHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: transport},
|
||||
},
|
||||
KratosAdmin: kratos,
|
||||
RPUserMetadataRepo: repo,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "admin", Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Put("/api/v1/dev/clients/:id/users/:userId/metadata", h.UpsertRPUserMetadata)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"metadata": map[string]any{"approvalLevel": "B"},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-1/users/user-1/metadata", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
rpClaims := capturedTraits["rp_custom_claims"].(map[string]any)
|
||||
clientClaims := rpClaims["client-1"].(domain.JSONMap)
|
||||
require.Equal(t, "B", clientClaims["approvalLevel"])
|
||||
require.Equal(t, map[string]any{
|
||||
"readPermission": "user_and_admin",
|
||||
"writePermission": "admin_only",
|
||||
}, clientClaims["approvalLevel_permissions"])
|
||||
repo.AssertExpectations(t)
|
||||
kratos.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestDevHandler_RPUserMetadataRejectsUndefinedClaimKey(t *testing.T) {
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/clients/client-1" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": "client-1",
|
||||
"client_name": "Client One",
|
||||
"metadata": map[string]any{
|
||||
"id_token_claims": []map[string]any{
|
||||
{
|
||||
"namespace": "rp_claims",
|
||||
"key": "contract_date",
|
||||
"valueType": "date",
|
||||
"value": "2026-06-09",
|
||||
},
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusNotFound, nil), nil
|
||||
})
|
||||
|
||||
repo := new(devMockRPUserMetadataRepo)
|
||||
h := &DevHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: transport},
|
||||
},
|
||||
RPUserMetadataRepo: repo,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "admin", Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Put("/api/v1/dev/clients/:id/users/:userId/metadata", h.UpsertRPUserMetadata)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"metadata": map[string]any{"unknown_claim": "A"},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-1/users/user-1/metadata", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
repo.AssertNotCalled(t, "Upsert", mock.Anything, mock.Anything)
|
||||
}
|
||||
|
||||
func TestDevHandler_RPUserMetadataRejectsInvalidTypedClaimValue(t *testing.T) {
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/clients/client-1" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": "client-1",
|
||||
"client_name": "Client One",
|
||||
"metadata": map[string]any{
|
||||
"id_token_claims": []map[string]any{
|
||||
{
|
||||
"namespace": "rp_claims",
|
||||
"key": "contract_date",
|
||||
"valueType": "date",
|
||||
"value": "2026-06-09",
|
||||
},
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusNotFound, nil), nil
|
||||
})
|
||||
|
||||
repo := new(devMockRPUserMetadataRepo)
|
||||
h := &DevHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: transport},
|
||||
},
|
||||
RPUserMetadataRepo: repo,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "admin", Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Put("/api/v1/dev/clients/:id/users/:userId/metadata", h.UpsertRPUserMetadata)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"metadata": map[string]any{"contract_date": "2026/06/09"},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-1/users/user-1/metadata", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
repo.AssertNotCalled(t, "Upsert", mock.Anything, mock.Anything)
|
||||
}
|
||||
3791
baron-sso/backend/internal/handler/dev_handler_test.go
Normal file
3791
baron-sso/backend/internal/handler/dev_handler_test.go
Normal file
File diff suppressed because it is too large
Load Diff
17
baron-sso/backend/internal/handler/error_helper.go
Normal file
17
baron-sso/backend/internal/handler/error_helper.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/response"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// errorJSON은 기존 error 필드를 유지하면서 기계 판독용 code를 명시적으로 추가합니다.
|
||||
func errorJSON(c *fiber.Ctx, status int, message string) error {
|
||||
return response.Error(c, status, response.StatusCode(status), message)
|
||||
}
|
||||
|
||||
// errorJSONCode는 상태코드 기반 매핑만으로 부족한 경우 명시 코드를 강제할 때 사용합니다.
|
||||
func errorJSONCode(c *fiber.Ctx, status int, code, message string) error {
|
||||
return response.Error(c, status, code, message)
|
||||
}
|
||||
161
baron-sso/backend/internal/handler/federation_handler.go
Normal file
161
baron-sso/backend/internal/handler/federation_handler.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/repository"
|
||||
"baron-sso-backend/internal/service"
|
||||
"errors"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// FederationHandler handles API requests for IdP federation.
|
||||
type FederationHandler struct {
|
||||
fedSvc *service.FederationService
|
||||
repo repository.FederationRepository // For IdP Config CRUD
|
||||
db *gorm.DB // For tenant existence checks, etc. in CRUD
|
||||
}
|
||||
|
||||
// NewFederationHandler creates a new FederationHandler.
|
||||
func NewFederationHandler(fedSvc *service.FederationService, repo repository.FederationRepository, db *gorm.DB) *FederationHandler {
|
||||
return &FederationHandler{
|
||||
fedSvc: fedSvc,
|
||||
repo: repo,
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// InitiateOIDCLogin handles the start of the OIDC login flow.
|
||||
// It expects `provider_id` and `login_challenge` as query parameters.
|
||||
func (h *FederationHandler) InitiateOIDCLogin(c *fiber.Ctx) error {
|
||||
providerID := c.Query("provider_id")
|
||||
loginChallenge := c.Query("login_challenge")
|
||||
|
||||
if providerID == "" || loginChallenge == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "provider_id and login_challenge are required")
|
||||
}
|
||||
|
||||
redirectURL, err := h.fedSvc.InitiateOIDCLogin(c.Context(), providerID, loginChallenge)
|
||||
if err != nil {
|
||||
// Log the error properly in a real application
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "failed to initiate OIDC login")
|
||||
}
|
||||
|
||||
return c.Redirect(redirectURL, fiber.StatusFound)
|
||||
}
|
||||
|
||||
// HandleOIDCCallback handles the OIDC callback from the IdP.
|
||||
func (h *FederationHandler) HandleOIDCCallback(c *fiber.Ctx) error {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
|
||||
if code == "" || state == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "code and state are required")
|
||||
}
|
||||
|
||||
redirectURL, err := h.fedSvc.HandleOIDCCallback(c.Context(), code, state)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "failed to handle OIDC callback")
|
||||
}
|
||||
|
||||
return c.Redirect(redirectURL, fiber.StatusFound)
|
||||
}
|
||||
|
||||
// --- New Client-based IdP Config Methods ---
|
||||
|
||||
// ListIdpConfigsForClient handles listing all IdP configurations for a client.
|
||||
func (h *FederationHandler) ListIdpConfigsForClient(c *fiber.Ctx) error {
|
||||
clientID := c.Params("clientId")
|
||||
if clientID == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "clientId is required")
|
||||
}
|
||||
|
||||
var configs []domain.IdentityProviderConfig
|
||||
if err := h.db.Where("client_id = ?", clientID).Find(&configs).Error; err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(configs)
|
||||
}
|
||||
|
||||
// CreateIdpConfigForClient handles the creation of a new IdP configuration for a client.
|
||||
func (h *FederationHandler) CreateIdpConfigForClient(c *fiber.Ctx) error {
|
||||
clientID := c.Params("clientId")
|
||||
if clientID == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "clientId is required in path")
|
||||
}
|
||||
|
||||
var req domain.IdentityProviderConfig
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
|
||||
}
|
||||
|
||||
// Assign clientID from path parameter
|
||||
req.ClientID = clientID
|
||||
|
||||
// Basic validation
|
||||
if req.DisplayName == "" || req.ProviderType == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "display_name and provider_type are required")
|
||||
}
|
||||
|
||||
// TODO: Optionally, validate if the clientID exists in Hydra
|
||||
|
||||
// Create in DB
|
||||
if err := h.db.Create(&req).Error; err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusCreated).JSON(req)
|
||||
}
|
||||
|
||||
// --- Deprecated Tenant-based IdP Config Methods ---
|
||||
|
||||
// ListIdpConfigsForTenant handles listing all IdP configurations for a tenant.
|
||||
func (h *FederationHandler) ListIdpConfigsForTenant(c *fiber.Ctx) error {
|
||||
tenantID := c.Params("tenantId")
|
||||
if tenantID == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "tenantId is required")
|
||||
}
|
||||
|
||||
// This is a temporary solution. We should create a proper method in the repository.
|
||||
var configs []domain.IdentityProviderConfig
|
||||
// Note: This now queries client_id, which is incorrect for tenants.
|
||||
// This method is deprecated.
|
||||
if err := h.db.Where("tenant_id = ?", tenantID).Find(&configs).Error; err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(configs)
|
||||
}
|
||||
|
||||
// CreateIdpConfig handles the creation of a new IdP configuration.
|
||||
func (h *FederationHandler) CreateIdpConfig(c *fiber.Ctx) error {
|
||||
var req domain.IdentityProviderConfig
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
|
||||
}
|
||||
|
||||
// Basic validation - This is the old validation logic
|
||||
if req.ClientID == "" || req.DisplayName == "" || req.ProviderType == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "client_id, display_name, and provider_type are required")
|
||||
}
|
||||
|
||||
// This check is now incorrect and deprecated.
|
||||
var tenant domain.Tenant
|
||||
if err := h.db.First(&tenant, "id = ?", req.ClientID).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "tenant not found")
|
||||
}
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
// Create in DB
|
||||
if err := h.db.Create(&req).Error; err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusCreated).JSON(req)
|
||||
}
|
||||
|
||||
// TODO: Re-implement Update, Delete handlers for IdP Configs for Clients
|
||||
243
baron-sso/backend/internal/handler/hanmac_email_policy.go
Normal file
243
baron-sso/backend/internal/handler/hanmac_email_policy.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const hanmacFamilyTenantSlug = "hanmac-family"
|
||||
|
||||
type hanmacEmailScope struct {
|
||||
TenantIDs map[string]bool
|
||||
Slugs map[string]bool
|
||||
IDList []string
|
||||
SlugList []string
|
||||
}
|
||||
|
||||
type hanmacEmailEvaluation struct {
|
||||
Email string
|
||||
OriginalEmail string
|
||||
SuggestedEmail string
|
||||
Status string
|
||||
Warnings []string
|
||||
Message string
|
||||
Blocking bool
|
||||
LocalPart string
|
||||
}
|
||||
|
||||
func (h *UserHandler) evaluateHanmacImportEmail(ctx context.Context, item bulkUserItem, scope *hanmacEmailScope, usedLocalParts map[string]bool) hanmacEmailEvaluation {
|
||||
originalEmail := strings.TrimSpace(item.Email)
|
||||
name := strings.TrimSpace(item.Name)
|
||||
evaluation := hanmacEmailEvaluation{
|
||||
Email: originalEmail,
|
||||
OriginalEmail: originalEmail,
|
||||
Status: "valid",
|
||||
}
|
||||
|
||||
localPart, domainPart, err := domain.SplitEmailDomain(originalEmail)
|
||||
if err != nil {
|
||||
evaluation.Status = "blockingError"
|
||||
evaluation.Message = "invalid email format"
|
||||
evaluation.Blocking = true
|
||||
return evaluation
|
||||
}
|
||||
|
||||
base, needsReview, _ := domain.BuildKoreanNameEmailBase(name)
|
||||
if needsReview {
|
||||
evaluation.Warnings = append(evaluation.Warnings, "needsReview")
|
||||
evaluation.Status = "needsReview"
|
||||
}
|
||||
|
||||
if localPart == "" {
|
||||
if base == "" {
|
||||
evaluation.Status = "blockingError"
|
||||
evaluation.Message = "이름으로 이메일 ID를 제안할 수 없습니다."
|
||||
evaluation.Blocking = true
|
||||
return evaluation
|
||||
}
|
||||
nextLocalPart := nextAvailableHanmacLocalPart(base, usedLocalParts)
|
||||
evaluation.Email = nextLocalPart + "@" + domainPart
|
||||
evaluation.SuggestedEmail = evaluation.Email
|
||||
evaluation.LocalPart = nextLocalPart
|
||||
evaluation.Status = "suggested"
|
||||
evaluation.Warnings = appendUniqueString(evaluation.Warnings, "suggested")
|
||||
return evaluation
|
||||
}
|
||||
|
||||
evaluation.LocalPart = localPart
|
||||
if usedLocalParts[localPart] {
|
||||
evaluation.Status = "blockingError"
|
||||
evaluation.Message = "한맥가족 내에서 이미 사용 중인 이메일 ID입니다."
|
||||
evaluation.Blocking = true
|
||||
return evaluation
|
||||
}
|
||||
|
||||
if base != "" && !domain.MatchesSuggestedNameRule(localPart, base) {
|
||||
evaluation.Status = "ruleMismatch"
|
||||
evaluation.Warnings = appendUniqueString(evaluation.Warnings, "ruleMismatch")
|
||||
}
|
||||
|
||||
if evaluation.Status == "needsReview" && len(evaluation.Warnings) == 0 {
|
||||
evaluation.Warnings = append(evaluation.Warnings, "needsReview")
|
||||
}
|
||||
_ = scope
|
||||
return evaluation
|
||||
}
|
||||
|
||||
func (h *UserHandler) ensureHanmacCreateEmailAllowed(ctx context.Context, email string, tenantSlug string, tenantID string) error {
|
||||
scope, err := h.resolveHanmacEmailScope(ctx)
|
||||
if err != nil || scope == nil || !scope.ContainsTenant(tenantID, tenantSlug) {
|
||||
return nil
|
||||
}
|
||||
|
||||
localPart, err := domain.ExtractNormalizedEmailLocalPart(email)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
usedLocalParts, err := h.loadHanmacLocalParts(ctx, scope)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usedLocalParts[localPart] {
|
||||
return fmt.Errorf("한맥가족 내에서 이미 사용 중인 이메일 ID입니다.")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *UserHandler) resolveHanmacEmailScope(ctx context.Context) (*hanmacEmailScope, error) {
|
||||
if h.TenantService == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
tenants, _, err := h.TenantService.ListTenants(ctx, 10000, 0, "", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var rootID string
|
||||
for _, tenant := range tenants {
|
||||
if strings.EqualFold(strings.TrimSpace(tenant.Slug), hanmacFamilyTenantSlug) {
|
||||
rootID = tenant.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
if rootID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
tenantByID := make(map[string]domain.Tenant, len(tenants))
|
||||
for _, tenant := range tenants {
|
||||
tenantByID[tenant.ID] = tenant
|
||||
}
|
||||
|
||||
scope := &hanmacEmailScope{
|
||||
TenantIDs: make(map[string]bool),
|
||||
Slugs: make(map[string]bool),
|
||||
}
|
||||
for _, tenant := range tenants {
|
||||
if isTenantDescendantOf(tenant, rootID, tenantByID) {
|
||||
scope.TenantIDs[tenant.ID] = true
|
||||
scope.Slugs[strings.ToLower(strings.TrimSpace(tenant.Slug))] = true
|
||||
scope.IDList = append(scope.IDList, tenant.ID)
|
||||
scope.SlugList = append(scope.SlugList, tenant.Slug)
|
||||
}
|
||||
}
|
||||
return scope, nil
|
||||
}
|
||||
|
||||
func (h *UserHandler) loadHanmacLocalParts(ctx context.Context, scope *hanmacEmailScope) (map[string]bool, error) {
|
||||
used := make(map[string]bool)
|
||||
if h.UserRepo == nil || scope == nil {
|
||||
return used, nil
|
||||
}
|
||||
|
||||
if len(scope.IDList) > 0 {
|
||||
users, err := h.UserRepo.FindByTenantIDs(ctx, scope.IDList)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addUserEmailLocalParts(used, users)
|
||||
}
|
||||
|
||||
if len(scope.SlugList) > 0 {
|
||||
users, err := h.UserRepo.FindByCompanyCodes(ctx, scope.SlugList)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addUserEmailLocalParts(used, users)
|
||||
}
|
||||
|
||||
return used, nil
|
||||
}
|
||||
|
||||
func (s *hanmacEmailScope) ContainsTenant(tenantID string, slug string) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
if tenantID != "" && s.TenantIDs[tenantID] {
|
||||
return true
|
||||
}
|
||||
return s.Slugs[strings.ToLower(strings.TrimSpace(slug))]
|
||||
}
|
||||
|
||||
func isTenantDescendantOf(tenant domain.Tenant, rootID string, tenantByID map[string]domain.Tenant) bool {
|
||||
if tenant.ID == rootID {
|
||||
return true
|
||||
}
|
||||
visited := make(map[string]bool)
|
||||
parentID := ""
|
||||
if tenant.ParentID != nil {
|
||||
parentID = *tenant.ParentID
|
||||
}
|
||||
for parentID != "" {
|
||||
if parentID == rootID {
|
||||
return true
|
||||
}
|
||||
if visited[parentID] {
|
||||
return false
|
||||
}
|
||||
visited[parentID] = true
|
||||
parent, ok := tenantByID[parentID]
|
||||
if !ok || parent.ParentID == nil {
|
||||
return false
|
||||
}
|
||||
parentID = *parent.ParentID
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func addUserEmailLocalParts(target map[string]bool, users []domain.User) {
|
||||
for _, user := range users {
|
||||
localPart, err := domain.ExtractNormalizedEmailLocalPart(user.Email)
|
||||
if err == nil && localPart != "" {
|
||||
target[localPart] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func nextAvailableHanmacLocalPart(base string, usedLocalParts map[string]bool) string {
|
||||
base = strings.ToLower(strings.TrimSpace(base))
|
||||
if base == "" {
|
||||
return ""
|
||||
}
|
||||
if !usedLocalParts[base] {
|
||||
return base
|
||||
}
|
||||
for index := 1; ; index++ {
|
||||
candidate := fmt.Sprintf("%s%d", base, index)
|
||||
if !usedLocalParts[candidate] {
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func appendUniqueString(values []string, value string) []string {
|
||||
if slices.Contains(values, value) {
|
||||
return values
|
||||
}
|
||||
return append(values, value)
|
||||
}
|
||||
98
baron-sso/backend/internal/handler/password_policy_test.go
Normal file
98
baron-sso/backend/internal/handler/password_policy_test.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"testing"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// 정책을 받아 필수 요구사항을 모두 포함하는 비밀번호를 생성한다.
|
||||
func generatePasswordFromPolicy(policy *domain.PasswordPolicy) string {
|
||||
minLen := policy.MinLength
|
||||
if minLen < 8 {
|
||||
minLen = 12 // 안전한 기본값
|
||||
}
|
||||
|
||||
pwd := make([]rune, 0, minLen)
|
||||
|
||||
if policy.Lowercase {
|
||||
pwd = append(pwd, 'a')
|
||||
}
|
||||
if policy.Uppercase {
|
||||
pwd = append(pwd, 'B')
|
||||
}
|
||||
if policy.Number {
|
||||
pwd = append(pwd, '3')
|
||||
}
|
||||
if policy.NonAlphanumeric {
|
||||
pwd = append(pwd, '!')
|
||||
}
|
||||
|
||||
charset := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*"
|
||||
for len(pwd) < minLen {
|
||||
pwd = append(pwd, rune(charset[randomInt(len(charset))]))
|
||||
}
|
||||
|
||||
// 섞어서 예측 가능성을 낮춘다.
|
||||
for i := range pwd {
|
||||
j := randomInt(len(pwd))
|
||||
pwd[i], pwd[j] = pwd[j], pwd[i]
|
||||
}
|
||||
return string(pwd)
|
||||
}
|
||||
|
||||
func randomInt(n int) int {
|
||||
if n <= 0 {
|
||||
return 0
|
||||
}
|
||||
var b [8]byte
|
||||
if _, err := rand.Read(b[:]); err != nil {
|
||||
return 0
|
||||
}
|
||||
return int(binary.BigEndian.Uint64(b[:]) % uint64(n))
|
||||
}
|
||||
|
||||
func TestGeneratePasswordUsesNonAlphanumericRequirement(t *testing.T) {
|
||||
policy := &domain.PasswordPolicy{
|
||||
MinLength: 8,
|
||||
Lowercase: true,
|
||||
Uppercase: true,
|
||||
Number: true,
|
||||
NonAlphanumeric: true,
|
||||
}
|
||||
|
||||
pwd := generatePasswordFromPolicy(policy)
|
||||
|
||||
if len(pwd) < policy.MinLength {
|
||||
t.Fatalf("비밀번호 길이가 정책 최소 길이 미만: got %d, want >= %d", len(pwd), policy.MinLength)
|
||||
}
|
||||
|
||||
var hasLower, hasUpper, hasNumber, hasSymbol bool
|
||||
for _, r := range pwd {
|
||||
switch {
|
||||
case unicode.IsLower(r):
|
||||
hasLower = true
|
||||
case unicode.IsUpper(r):
|
||||
hasUpper = true
|
||||
case unicode.IsNumber(r):
|
||||
hasNumber = true
|
||||
case !unicode.IsLetter(r) && !unicode.IsNumber(r):
|
||||
hasSymbol = true
|
||||
}
|
||||
}
|
||||
|
||||
if policy.Lowercase && !hasLower {
|
||||
t.Fatalf("소문자 요구사항 미충족: %q", pwd)
|
||||
}
|
||||
if policy.Uppercase && !hasUpper {
|
||||
t.Fatalf("대문자 요구사항 미충족: %q", pwd)
|
||||
}
|
||||
if policy.Number && !hasNumber {
|
||||
t.Fatalf("숫자 요구사항 미충족: %q", pwd)
|
||||
}
|
||||
if policy.NonAlphanumeric && !hasSymbol {
|
||||
t.Fatalf("비영문자 요구사항 미충족: %q", pwd)
|
||||
}
|
||||
}
|
||||
114
baron-sso/backend/internal/handler/relying_party_handler.go
Normal file
114
baron-sso/backend/internal/handler/relying_party_handler.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"log/slog"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type RelyingPartyHandler struct {
|
||||
Service service.RelyingPartyService
|
||||
KratosAdmin service.KratosAdminService
|
||||
}
|
||||
|
||||
func NewRelyingPartyHandler(s service.RelyingPartyService, kratos service.KratosAdminService) *RelyingPartyHandler {
|
||||
return &RelyingPartyHandler{Service: s, KratosAdmin: kratos}
|
||||
}
|
||||
|
||||
func (h *RelyingPartyHandler) Create(c *fiber.Ctx) error {
|
||||
tenantID := c.Params("tenantId")
|
||||
if tenantID == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "tenantId is required")
|
||||
}
|
||||
|
||||
var req domain.HydraClient
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
|
||||
}
|
||||
|
||||
rp, err := h.Service.Create(c.Context(), tenantID, req)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusCreated).JSON(rp)
|
||||
}
|
||||
|
||||
func (h *RelyingPartyHandler) ListAll(c *fiber.Ctx) error {
|
||||
profile, ok := c.Locals("user_profile").(*domain.UserProfileResponse)
|
||||
if !ok {
|
||||
return errorJSON(c, fiber.StatusUnauthorized, "unauthorized: user profile not found in context")
|
||||
}
|
||||
|
||||
var rps []domain.RelyingParty
|
||||
var err error
|
||||
role := domain.NormalizeRole(profile.Role)
|
||||
|
||||
if role == domain.RoleSuperAdmin {
|
||||
rps, err = h.Service.ListAll(c.Context())
|
||||
} else if role == "tenant_admin" && profile.TenantID != nil {
|
||||
rps, err = h.Service.List(c.Context(), *profile.TenantID)
|
||||
} else {
|
||||
slog.Warn("Forbidden access to all applications", "userID", profile.ID, "role", role)
|
||||
return errorJSON(c, fiber.StatusForbidden, "forbidden: insufficient role to list all applications")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(rps)
|
||||
}
|
||||
|
||||
func (h *RelyingPartyHandler) List(c *fiber.Ctx) error {
|
||||
tenantID := c.Params("tenantId")
|
||||
if tenantID == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "tenantId is required")
|
||||
}
|
||||
|
||||
rps, err := h.Service.List(c.Context(), tenantID)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(rps)
|
||||
}
|
||||
|
||||
func (h *RelyingPartyHandler) Get(c *fiber.Ctx) error {
|
||||
id := c.Params("id")
|
||||
rp, hydraClient, err := h.Service.Get(c.Context(), id)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusNotFound, "relying party not found")
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
"relyingParty": rp,
|
||||
"oauth2Config": hydraClient,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *RelyingPartyHandler) Update(c *fiber.Ctx) error {
|
||||
id := c.Params("id")
|
||||
var req domain.HydraClient
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
|
||||
}
|
||||
|
||||
rp, err := h.Service.Update(c.Context(), id, req)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(rp)
|
||||
}
|
||||
|
||||
func (h *RelyingPartyHandler) Delete(c *fiber.Ctx) error {
|
||||
id := c.Params("id")
|
||||
if err := h.Service.Delete(c.Context(), id); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.SendStatus(fiber.StatusNoContent)
|
||||
}
|
||||
255
baron-sso/backend/internal/handler/rp_manifest_handler.go
Normal file
255
baron-sso/backend/internal/handler/rp_manifest_handler.go
Normal file
@@ -0,0 +1,255 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"html"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type RPManifestHandler struct{}
|
||||
|
||||
const rpObjectLookupMermaid = `flowchart TD
|
||||
A[RP request] --> B{obj_id supplied?}
|
||||
B -->|yes| C[Normalize object type and obj_id]
|
||||
B -->|no| D{Route has client_id?}
|
||||
D -->|yes| E[obj_id = RelyingParty:<client_id>]
|
||||
D -->|no| F{Route has tenant_id?}
|
||||
F -->|yes| G[obj_id = Tenant:<tenant_id>]
|
||||
F -->|no| H[Reject: explicit obj_id required]
|
||||
C --> I[Check Keto relation]
|
||||
E --> I
|
||||
G --> I
|
||||
I --> J{allowed?}
|
||||
J -->|yes| K[Inject trusted Baron headers]
|
||||
J -->|no| L[Reject request]
|
||||
K --> M[Write audit with obj_id, relation, client_id, X-Request-Id]`
|
||||
|
||||
const rpExternalKeyMermaid = `flowchart TD
|
||||
A[User authenticates through Baron SSO] --> B[Baron resolves internal identity]
|
||||
B --> C[Baron derives or loads Baron-issued alias]
|
||||
C --> D[Baron injects X-Baron-External-Key]
|
||||
D --> E[Baron injects X-Baron-Subject]
|
||||
E --> I[RP receives trusted headers from Baron gateway]
|
||||
I --> F[RP upserts local user with provider + X-Baron-External-Key]
|
||||
F --> G[RP stores the full external key as opaque value]
|
||||
G --> H[RP never parses or stores raw kratos_identity_id]`
|
||||
|
||||
func NewRPManifestHandler() *RPManifestHandler {
|
||||
return &RPManifestHandler{}
|
||||
}
|
||||
|
||||
func (h *RPManifestHandler) GetJSON(c *fiber.Ctx) error {
|
||||
c.Set(fiber.HeaderCacheControl, "public, max-age=300")
|
||||
return c.JSON(buildRPManifest(c))
|
||||
}
|
||||
|
||||
func (h *RPManifestHandler) GetSchema(c *fiber.Ctx) error {
|
||||
c.Set(fiber.HeaderCacheControl, "public, max-age=300")
|
||||
return c.JSON(rpManifestSchema())
|
||||
}
|
||||
|
||||
func (h *RPManifestHandler) GetHTML(c *fiber.Ctx) error {
|
||||
manifest := buildRPManifest(c)
|
||||
issuer, _ := manifest["issuer"].(string)
|
||||
c.Set(fiber.HeaderCacheControl, "public, max-age=300")
|
||||
c.Type("html", "utf-8")
|
||||
return c.SendString(`<!doctype html>
|
||||
<html lang="ko">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<title>Baron RP IAM Manifest</title>
|
||||
<style>
|
||||
body { font-family: system-ui, sans-serif; margin: 2rem; line-height: 1.6; max-width: 920px; }
|
||||
code, pre { background: #f5f5f5; border-radius: 4px; padding: .1rem .3rem; }
|
||||
pre { padding: 1rem; overflow: auto; }
|
||||
table { border-collapse: collapse; width: 100%; }
|
||||
th, td { border: 1px solid #ddd; padding: .5rem; text-align: left; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Baron RP IAM Manifest</h1>
|
||||
<p>외부 RP가 Baron SSO/Ory Stack/Keto 기반 공용 IAM을 연동하기 위한 공개 규격입니다.</p>
|
||||
<ul>
|
||||
<li>Machine-readable manifest: <a href="/.well-known/baron-rp-manifest.json">/.well-known/baron-rp-manifest.json</a></li>
|
||||
<li>JSON schema: <a href="/.well-known/baron-rp-manifest.schema.json">/.well-known/baron-rp-manifest.schema.json</a></li>
|
||||
</ul>
|
||||
<h2>Issuer</h2>
|
||||
<pre>` + html.EscapeString(issuer) + `</pre>
|
||||
<h2>Identity Contract</h2>
|
||||
<table>
|
||||
<tr><th>용도</th><th>Header</th><th>정책</th></tr>
|
||||
<tr><td>Keto subject</td><td><code>X-Baron-Subject</code></td><td><code>User:<baron_identity_id></code> 전체 문자열을 opaque subject로 취급합니다.</td></tr>
|
||||
<tr><td>RP upsert key</td><td><code>X-Baron-External-Key</code></td><td>Baron-issued alias입니다. RP가 만들거나 제출하지 않고, Baron이 주입한 전체 문자열을 local user external key로 저장합니다.</td></tr>
|
||||
<tr><td>RP client</td><td><code>X-Baron-Client-ID</code></td><td>현재 접근 중인 RP client id입니다.</td></tr>
|
||||
</table>
|
||||
<h2>External Key Flow</h2>
|
||||
<p><code>X-Baron-External-Key</code>는 RP 입력값이 아니라 Baron이 인증된 subject에서 발급/조회해 주입하는 opaque alias입니다. RP upserts local user from the Baron-issued alias.</p>
|
||||
<pre>` + "```mermaid\n" + html.EscapeString(rpExternalKeyMermaid) + "\n```" + `</pre>
|
||||
<h2>Object Lookup</h2>
|
||||
<pre>check(User:abc, viewers, RelyingParty:<client_id>)
|
||||
check(User:abc, members, Tenant:<tenant_id>)
|
||||
check(User:abc, viewers, Resource:<resource_type>:<resource_id>)</pre>
|
||||
<h2>audit_contract</h2>
|
||||
<p>권한과 설정을 변경하는 command는 sync audit write에 실패하면 요청도 실패해야 합니다. Read audit은 allowlist된 조회에 한해 best effort로 취급합니다.</p>
|
||||
<pre>{
|
||||
"mutating_command_mode": "fail_closed_sync",
|
||||
"missing_audit_sink_behavior": "reject_mutation",
|
||||
"correlation_header": "X-Request-Id"
|
||||
}</pre>
|
||||
<h2>Object Lookup Flow</h2>
|
||||
<pre>` + "```mermaid\n" + html.EscapeString(rpObjectLookupMermaid) + "\n```" + `</pre>
|
||||
</body>
|
||||
</html>`)
|
||||
}
|
||||
|
||||
func buildRPManifest(c *fiber.Ctx) map[string]any {
|
||||
issuer := resolvePublicRequestBaseURL(c, os.Getenv("BACKEND_PUBLIC_URL"))
|
||||
if issuer == "" {
|
||||
issuer = strings.TrimRight(os.Getenv("USERFRONT_URL"), "/")
|
||||
}
|
||||
if issuer == "" {
|
||||
issuer = "https://sso.hmac.kr"
|
||||
}
|
||||
issuer = strings.TrimRight(issuer, "/")
|
||||
|
||||
return map[string]any{
|
||||
"version": "2026-05-11",
|
||||
"issuer": issuer,
|
||||
"oidc": map[string]any{
|
||||
"discovery_url": issuer + "/.well-known/openid-configuration",
|
||||
"jwks_url": issuer + "/.well-known/jwks.json",
|
||||
"supported_flows": []string{"authorization_code_pkce"},
|
||||
"required_scopes": []string{"openid", "profile", "email"},
|
||||
},
|
||||
"iam": map[string]any{
|
||||
"authorization_engine": "ory-keto",
|
||||
"subject_format": "User:<baron_identity_id>",
|
||||
"target_object_patterns": []string{
|
||||
"RelyingParty:<client_id>",
|
||||
"Tenant:<tenant_id>",
|
||||
"Resource:<resource_type>:<resource_id>",
|
||||
},
|
||||
"supported_relations": []string{
|
||||
"admins",
|
||||
"users",
|
||||
"viewers",
|
||||
"operators",
|
||||
"members",
|
||||
"owners",
|
||||
"editors",
|
||||
},
|
||||
},
|
||||
"identity_contract": map[string]any{
|
||||
"subject_header": "X-Baron-Subject",
|
||||
"external_key_header": "X-Baron-External-Key",
|
||||
"external_key_is_opaque": true,
|
||||
"external_key_issuer": "baron",
|
||||
"external_key_delivery": "baron_injected_header",
|
||||
"external_key_lifecycle": "issued_or_loaded_after_successful_authentication_before_rp_request",
|
||||
"rp_supplied_external_key_allowed": false,
|
||||
"rp_user_upsert_source": "rp_must_upsert_from_header_value",
|
||||
"raw_kratos_identity_id_exposed": false,
|
||||
"rp_user_upsert_key": "provider + external_key",
|
||||
"email_is_stable_primary_key": false,
|
||||
"initial_external_key_expression": "X-Baron-External-Key",
|
||||
"fallback_to_subject_allowed": false,
|
||||
},
|
||||
"trusted_headers": map[string]any{
|
||||
"subject": "X-Baron-Subject",
|
||||
"external_key": "X-Baron-External-Key",
|
||||
"email": "X-Baron-Email",
|
||||
"tenant": "X-Baron-Tenant",
|
||||
"relations": "X-Baron-Relations",
|
||||
"client_id": "X-Baron-Client-ID",
|
||||
},
|
||||
"object_lookup": map[string]any{
|
||||
"rp_level": map[string]any{
|
||||
"object": "RelyingParty:<client_id>",
|
||||
"relations": []string{"viewers", "users", "operators", "admins"},
|
||||
"example": "check(User:abc, viewers, RelyingParty:mh-dashboard)",
|
||||
},
|
||||
"tenant_level": map[string]any{
|
||||
"object": "Tenant:<tenant_id>",
|
||||
"relations": []string{"members", "admins", "owners"},
|
||||
"example": "check(User:abc, members, Tenant:9caf62e1-297d-4e8f-870b-61780998bbe)",
|
||||
},
|
||||
"resource_level": map[string]any{
|
||||
"object": "Resource:<resource_type>:<resource_id>",
|
||||
"relations": []string{"viewers", "editors", "owners"},
|
||||
"example": "check(User:abc, viewers, Resource:dashboard:mh-monthly-2026-05)",
|
||||
},
|
||||
"recommended_order": []string{
|
||||
"authenticated",
|
||||
"rp_level",
|
||||
"tenant_or_resource_level",
|
||||
"trusted_header_injection",
|
||||
},
|
||||
},
|
||||
"object_lookup_flow": map[string]any{
|
||||
"format": "mermaid",
|
||||
"mermaid": rpObjectLookupMermaid,
|
||||
},
|
||||
"external_key_flow": map[string]any{
|
||||
"format": "mermaid",
|
||||
"mermaid": rpExternalKeyMermaid,
|
||||
},
|
||||
"audit_contract": map[string]any{
|
||||
"mutating_command_mode": "fail_closed_sync",
|
||||
"missing_audit_sink_behavior": "reject_mutation",
|
||||
"read_audit_mode": "best_effort_allowlisted",
|
||||
"correlation_header": "X-Request-Id",
|
||||
"rp_business_audit_required": true,
|
||||
"baron_gateway_audit_required": true,
|
||||
"required_detail_fields": []string{
|
||||
"obj_id",
|
||||
"relation",
|
||||
"client_id",
|
||||
"subject",
|
||||
"decision",
|
||||
},
|
||||
"guarantee_scope": "Baron-mediated IAM mutations fail closed on audit write failure; RP-owned business events must be emitted by the RP with the same correlation header.",
|
||||
},
|
||||
"security_requirements": map[string]any{
|
||||
"strip_external_identity_headers": true,
|
||||
"backend_direct_exposure_allowed": false,
|
||||
"static_snapshot_requires_auth": true,
|
||||
"email_as_primary_key_allowed": false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func rpManifestSchema() map[string]any {
|
||||
return map[string]any{
|
||||
"$schema": "https://json-schema.org/draft/2020-12/schema",
|
||||
"title": "Baron RP IAM Manifest",
|
||||
"type": "object",
|
||||
"required": []string{
|
||||
"version",
|
||||
"issuer",
|
||||
"oidc",
|
||||
"iam",
|
||||
"trusted_headers",
|
||||
"identity_contract",
|
||||
"object_lookup",
|
||||
"object_lookup_flow",
|
||||
"external_key_flow",
|
||||
"audit_contract",
|
||||
"security_requirements",
|
||||
},
|
||||
"properties": map[string]any{
|
||||
"version": map[string]any{"type": "string"},
|
||||
"issuer": map[string]any{"type": "string", "format": "uri"},
|
||||
"oidc": map[string]any{"type": "object"},
|
||||
"iam": map[string]any{"type": "object"},
|
||||
"trusted_headers": map[string]any{"type": "object"},
|
||||
"identity_contract": map[string]any{"type": "object"},
|
||||
"object_lookup": map[string]any{"type": "object"},
|
||||
"object_lookup_flow": map[string]any{"type": "object"},
|
||||
"external_key_flow": map[string]any{"type": "object"},
|
||||
"audit_contract": map[string]any{"type": "object"},
|
||||
"security_requirements": map[string]any{"type": "object"},
|
||||
},
|
||||
}
|
||||
}
|
||||
125
baron-sso/backend/internal/handler/rp_manifest_handler_test.go
Normal file
125
baron-sso/backend/internal/handler/rp_manifest_handler_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRPManifestJSONIncludesIAMAndExternalKeyContract(t *testing.T) {
|
||||
t.Setenv("BACKEND_PUBLIC_URL", "")
|
||||
|
||||
app := fiber.New()
|
||||
h := NewRPManifestHandler()
|
||||
app.Get("/.well-known/baron-rp-manifest.json", h.GetJSON)
|
||||
|
||||
req := httptest.NewRequest("GET", "/.well-known/baron-rp-manifest.json", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("X-Forwarded-Host", "sso.hmac.kr")
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
require.Contains(t, resp.Header.Get("Content-Type"), "application/json")
|
||||
|
||||
var body map[string]any
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
require.Equal(t, "https://sso.hmac.kr", body["issuer"])
|
||||
|
||||
oidc := body["oidc"].(map[string]any)
|
||||
require.Equal(t, "https://sso.hmac.kr/.well-known/openid-configuration", oidc["discovery_url"])
|
||||
require.Equal(t, "https://sso.hmac.kr/.well-known/jwks.json", oidc["jwks_url"])
|
||||
|
||||
iam := body["iam"].(map[string]any)
|
||||
require.Equal(t, "ory-keto", iam["authorization_engine"])
|
||||
require.Equal(t, "User:<baron_identity_id>", iam["subject_format"])
|
||||
require.Contains(t, iam["target_object_patterns"].([]any), "RelyingParty:<client_id>")
|
||||
require.Contains(t, iam["target_object_patterns"].([]any), "Tenant:<tenant_id>")
|
||||
require.Contains(t, iam["target_object_patterns"].([]any), "Resource:<resource_type>:<resource_id>")
|
||||
|
||||
identity := body["identity_contract"].(map[string]any)
|
||||
require.Equal(t, "X-Baron-External-Key", identity["external_key_header"])
|
||||
require.Equal(t, true, identity["external_key_is_opaque"])
|
||||
require.Equal(t, false, identity["raw_kratos_identity_id_exposed"])
|
||||
require.Equal(t, "baron", identity["external_key_issuer"])
|
||||
require.Equal(t, "baron_injected_header", identity["external_key_delivery"])
|
||||
require.Equal(t, false, identity["rp_supplied_external_key_allowed"])
|
||||
require.Equal(t, "rp_must_upsert_from_header_value", identity["rp_user_upsert_source"])
|
||||
|
||||
headers := body["trusted_headers"].(map[string]any)
|
||||
require.Equal(t, "X-Baron-Subject", headers["subject"])
|
||||
require.Equal(t, "X-Baron-External-Key", headers["external_key"])
|
||||
require.Equal(t, "X-Baron-Client-ID", headers["client_id"])
|
||||
|
||||
security := body["security_requirements"].(map[string]any)
|
||||
require.Equal(t, true, security["strip_external_identity_headers"])
|
||||
require.Equal(t, false, security["backend_direct_exposure_allowed"])
|
||||
|
||||
audit := body["audit_contract"].(map[string]any)
|
||||
require.Equal(t, "fail_closed_sync", audit["mutating_command_mode"])
|
||||
require.Equal(t, "reject_mutation", audit["missing_audit_sink_behavior"])
|
||||
require.Equal(t, "X-Request-Id", audit["correlation_header"])
|
||||
require.Contains(t, audit["required_detail_fields"].([]any), "obj_id")
|
||||
require.Contains(t, audit["required_detail_fields"].([]any), "client_id")
|
||||
|
||||
flow := body["object_lookup_flow"].(map[string]any)
|
||||
require.Contains(t, flow["mermaid"].(string), "flowchart TD")
|
||||
require.Contains(t, flow["mermaid"].(string), "obj_id")
|
||||
|
||||
aliasFlow := body["external_key_flow"].(map[string]any)
|
||||
require.Contains(t, aliasFlow["mermaid"].(string), "Baron resolves internal identity")
|
||||
require.Contains(t, aliasFlow["mermaid"].(string), "Baron injects X-Baron-External-Key")
|
||||
require.Contains(t, aliasFlow["mermaid"].(string), "RP upserts local user")
|
||||
require.NotContains(t, aliasFlow["mermaid"].(string), "RP creates external key")
|
||||
}
|
||||
|
||||
func TestRPManifestSchemaRequiresLookupAndIdentityContracts(t *testing.T) {
|
||||
app := fiber.New()
|
||||
h := NewRPManifestHandler()
|
||||
app.Get("/.well-known/baron-rp-manifest.schema.json", h.GetSchema)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/.well-known/baron-rp-manifest.schema.json", nil))
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
var body map[string]any
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
|
||||
required := body["required"].([]any)
|
||||
require.Contains(t, required, "iam")
|
||||
require.Contains(t, required, "trusted_headers")
|
||||
require.Contains(t, required, "identity_contract")
|
||||
require.Contains(t, required, "object_lookup")
|
||||
require.Contains(t, required, "audit_contract")
|
||||
require.Contains(t, required, "object_lookup_flow")
|
||||
require.Contains(t, required, "external_key_flow")
|
||||
}
|
||||
|
||||
func TestRPManifestHTMLLinksMachineReadableManifest(t *testing.T) {
|
||||
app := fiber.New()
|
||||
h := NewRPManifestHandler()
|
||||
app.Get("/.well-known/baron-rp-manifest", h.GetHTML)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/.well-known/baron-rp-manifest", nil))
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
require.Contains(t, resp.Header.Get("Content-Type"), "text/html")
|
||||
raw, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
text := string(raw)
|
||||
require.Contains(t, text, "/.well-known/baron-rp-manifest.json")
|
||||
require.Contains(t, text, "X-Baron-External-Key")
|
||||
require.Contains(t, text, "RelyingParty:<client_id>")
|
||||
require.Contains(t, text, "```mermaid")
|
||||
require.Contains(t, text, "audit_contract")
|
||||
require.Contains(t, text, "Baron-issued alias")
|
||||
require.Contains(t, text, "RP upserts local user")
|
||||
}
|
||||
154
baron-sso/backend/internal/handler/tenant_access_cleanup.go
Normal file
154
baron-sso/backend/internal/handler/tenant_access_cleanup.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/repository"
|
||||
"baron-sso-backend/internal/service"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const tenantAccessCleanupClientPageSize = 500
|
||||
|
||||
func cleanupDeletedTenantReferences(ctx context.Context, hydra *service.HydraAdminService, consentRepo repository.ClientConsentRepository, ketoOutbox repository.KetoOutboxRepository, deletedTenantIDs []string) error {
|
||||
if hydra == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
deletedTenantSet := make(map[string]struct{}, len(deletedTenantIDs))
|
||||
for _, tenantID := range deletedTenantIDs {
|
||||
tenantID = strings.TrimSpace(tenantID)
|
||||
if tenantID == "" {
|
||||
continue
|
||||
}
|
||||
deletedTenantSet[tenantID] = struct{}{}
|
||||
}
|
||||
if len(deletedTenantSet) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for offset := 0; ; offset += tenantAccessCleanupClientPageSize {
|
||||
clients, err := hydra.ListClients(ctx, tenantAccessCleanupClientPageSize, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list hydra clients for tenant cleanup: %w", err)
|
||||
}
|
||||
|
||||
for _, client := range clients {
|
||||
beforeMetadata := maps.Clone(client.Metadata)
|
||||
updatedMetadata, changed, removedOwnerTenantID := pruneDeletedTenantReferences(beforeMetadata, deletedTenantSet)
|
||||
if !changed {
|
||||
continue
|
||||
}
|
||||
|
||||
updatedClient := client
|
||||
updatedClient.Metadata = updatedMetadata
|
||||
if _, err := hydra.UpdateClient(ctx, client.ClientID, updatedClient); err != nil {
|
||||
return fmt.Errorf("failed to update hydra client %s during tenant cleanup: %w", client.ClientID, err)
|
||||
}
|
||||
if removedOwnerTenantID != "" {
|
||||
if err := enqueueDeletedTenantRelyingPartyParentCleanup(ctx, ketoOutbox, client.ClientID, removedOwnerTenantID); err != nil {
|
||||
return fmt.Errorf("failed to cleanup RP parent relation for client %s during tenant cleanup: %w", client.ClientID, err)
|
||||
}
|
||||
}
|
||||
|
||||
if tenantAccessPolicyChanged(beforeMetadata, updatedMetadata) {
|
||||
if err := revokeClientConsentsForPolicyChange(ctx, hydra, consentRepo, client.ClientID); err != nil {
|
||||
return fmt.Errorf("failed to revoke consent sessions for client %s during tenant cleanup: %w", client.ClientID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(clients) < tenantAccessCleanupClientPageSize {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func pruneDeletedTenantReferences(metadata map[string]any, deletedTenantSet map[string]struct{}) (map[string]any, bool, string) {
|
||||
if len(deletedTenantSet) == 0 {
|
||||
return metadata, false, ""
|
||||
}
|
||||
|
||||
ownerTenantID := normalizeMetadataString(metadata["tenant_id"])
|
||||
_, ownerDeleted := deletedTenantSet[ownerTenantID]
|
||||
|
||||
allowedTenants := normalizeMetadataStringSlice(metadata[clientAllowedTenantsKey])
|
||||
filtered := make([]string, 0, len(allowedTenants))
|
||||
for _, tenantID := range allowedTenants {
|
||||
if _, ok := deletedTenantSet[tenantID]; ok {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, tenantID)
|
||||
}
|
||||
allowedChanged := len(filtered) != len(allowedTenants)
|
||||
|
||||
if !ownerDeleted && !allowedChanged {
|
||||
return metadata, false, ""
|
||||
}
|
||||
|
||||
updated := maps.Clone(metadata)
|
||||
if ownerDeleted {
|
||||
delete(updated, "tenant_id")
|
||||
}
|
||||
|
||||
if len(filtered) == 0 {
|
||||
delete(updated, clientAllowedTenantsKey)
|
||||
updated[clientTenantAccessRestrictedKey] = false
|
||||
return updated, true, ownerTenantID
|
||||
}
|
||||
|
||||
updated[clientAllowedTenantsKey] = uniqueSortedStrings(filtered)
|
||||
updated[clientTenantAccessRestrictedKey] = true
|
||||
return updated, true, ownerTenantID
|
||||
}
|
||||
|
||||
func enqueueDeletedTenantRelyingPartyParentCleanup(ctx context.Context, ketoOutbox repository.KetoOutboxRepository, clientID, tenantID string) error {
|
||||
if ketoOutbox == nil {
|
||||
return nil
|
||||
}
|
||||
clientID = strings.TrimSpace(clientID)
|
||||
tenantID = strings.TrimSpace(tenantID)
|
||||
if clientID == "" || tenantID == "" {
|
||||
return nil
|
||||
}
|
||||
return ketoOutbox.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "RelyingParty",
|
||||
Object: clientID,
|
||||
Relation: "parents",
|
||||
Subject: "Tenant:" + tenantID,
|
||||
Action: domain.KetoOutboxActionDelete,
|
||||
})
|
||||
}
|
||||
|
||||
func revokeClientConsentsForPolicyChange(ctx context.Context, hydra *service.HydraAdminService, consentRepo repository.ClientConsentRepository, clientID string) error {
|
||||
if consentRepo == nil || hydra == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
subjects, err := consentRepo.ListSubjectsByClient(ctx, clientID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, subject := range subjects {
|
||||
subject = strings.TrimSpace(subject)
|
||||
if subject == "" {
|
||||
continue
|
||||
}
|
||||
if err := hydra.RevokeConsentSessions(ctx, subject, clientID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return consentRepo.DeleteByClient(ctx, clientID)
|
||||
}
|
||||
|
||||
func logTenantCleanupFailure(err error, deletedTenantIDs []string) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
slog.Error("Failed to cleanup RP tenant restrictions after tenant deletion", "tenant_ids", deletedTenantIDs, "error", err)
|
||||
}
|
||||
178
baron-sso/backend/internal/handler/tenant_access_cleanup_test.go
Normal file
178
baron-sso/backend/internal/handler/tenant_access_cleanup_test.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/repository"
|
||||
"baron-sso-backend/internal/service"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestPruneDeletedTenantReferences_PreservesOtherAllowedTenants(t *testing.T) {
|
||||
metadata := map[string]any{
|
||||
clientTenantAccessRestrictedKey: true,
|
||||
clientAllowedTenantsKey: []string{"keep-tenant", "deleted-tenant"},
|
||||
"tenant_id": "deleted-tenant",
|
||||
}
|
||||
|
||||
updated, changed, removedOwnerTenantID := pruneDeletedTenantReferences(metadata, map[string]struct{}{
|
||||
"deleted-tenant": {},
|
||||
})
|
||||
|
||||
require.True(t, changed)
|
||||
assert.Equal(t, "deleted-tenant", removedOwnerTenantID)
|
||||
assert.Equal(t, true, updated[clientTenantAccessRestrictedKey])
|
||||
assert.Equal(t, []string{"keep-tenant"}, updated[clientAllowedTenantsKey])
|
||||
_, exists := updated["tenant_id"]
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestPruneDeletedTenantReferences_DisablesRestrictionWhenLastTenantRemoved(t *testing.T) {
|
||||
metadata := map[string]any{
|
||||
clientTenantAccessRestrictedKey: true,
|
||||
clientAllowedTenantsKey: []string{"deleted-tenant"},
|
||||
"tenant_id": "deleted-tenant",
|
||||
}
|
||||
|
||||
updated, changed, removedOwnerTenantID := pruneDeletedTenantReferences(metadata, map[string]struct{}{
|
||||
"deleted-tenant": {},
|
||||
})
|
||||
|
||||
require.True(t, changed)
|
||||
assert.Equal(t, "deleted-tenant", removedOwnerTenantID)
|
||||
assert.Equal(t, false, updated[clientTenantAccessRestrictedKey])
|
||||
_, exists := updated[clientAllowedTenantsKey]
|
||||
assert.False(t, exists)
|
||||
_, exists = updated["tenant_id"]
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestCleanupDeletedTenantReferences_PrunesClientsAndRevokesConsents(t *testing.T) {
|
||||
var (
|
||||
mu sync.Mutex
|
||||
page0Called bool
|
||||
updated = map[string]map[string]any{}
|
||||
revokes []string
|
||||
)
|
||||
|
||||
transport := roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
switch {
|
||||
case req.Method == http.MethodGet && req.URL.Path == "/clients":
|
||||
switch req.URL.Query().Get("offset") {
|
||||
case "":
|
||||
page0Called = true
|
||||
return httpJSONAny(req, http.StatusOK, []domain.HydraClient{
|
||||
{
|
||||
ClientID: "client-keep",
|
||||
Metadata: map[string]any{
|
||||
clientTenantAccessRestrictedKey: true,
|
||||
clientAllowedTenantsKey: []string{"keep-tenant", "deleted-tenant"},
|
||||
"tenant_id": "deleted-tenant",
|
||||
},
|
||||
},
|
||||
{
|
||||
ClientID: "client-drop",
|
||||
Metadata: map[string]any{
|
||||
clientTenantAccessRestrictedKey: true,
|
||||
clientAllowedTenantsKey: []string{"deleted-tenant"},
|
||||
"tenant_id": "deleted-tenant",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
default:
|
||||
return httpResponse(req, http.StatusBadRequest, "unexpected offset"), nil
|
||||
}
|
||||
|
||||
case req.Method == http.MethodPut && strings.HasPrefix(req.URL.Path, "/clients/"):
|
||||
var client domain.HydraClient
|
||||
require.NoError(t, json.NewDecoder(req.Body).Decode(&client))
|
||||
updated[client.ClientID] = client.Metadata
|
||||
return httpJSONAny(req, http.StatusOK, client), nil
|
||||
|
||||
case req.Method == http.MethodDelete && req.URL.Path == "/oauth2/auth/sessions/consent":
|
||||
revokes = append(revokes, req.URL.Query().Get("subject")+"|"+req.URL.Query().Get("client"))
|
||||
return httpResponse(req, http.StatusNoContent, ""), nil
|
||||
|
||||
default:
|
||||
return httpResponse(req, http.StatusNotFound, "unexpected request"), nil
|
||||
}
|
||||
})
|
||||
|
||||
hydra := &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: transport},
|
||||
}
|
||||
consentRepo := &mockConsentRepo{
|
||||
consents: []domain.ClientConsent{
|
||||
{ClientID: "client-keep", Subject: "user-a"},
|
||||
{ClientID: "client-drop", Subject: "user-b"},
|
||||
},
|
||||
}
|
||||
outbox := &tenantCleanupMockKetoOutboxRepository{}
|
||||
|
||||
err := cleanupDeletedTenantReferences(context.Background(), hydra, consentRepo, outbox, []string{"deleted-tenant"})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, page0Called)
|
||||
assert.Equal(t, map[string]any{
|
||||
clientTenantAccessRestrictedKey: true,
|
||||
clientAllowedTenantsKey: []any{"keep-tenant"},
|
||||
}, updated["client-keep"])
|
||||
assert.Equal(t, map[string]any{
|
||||
clientTenantAccessRestrictedKey: false,
|
||||
}, updated["client-drop"])
|
||||
assert.ElementsMatch(t, []string{"user-a|client-keep", "user-b|client-drop"}, revokes)
|
||||
assert.Empty(t, consentRepo.consents)
|
||||
require.Len(t, outbox.entries, 2)
|
||||
assert.ElementsMatch(t, []string{"client-keep", "client-drop"}, []string{outbox.entries[0].Object, outbox.entries[1].Object})
|
||||
for _, entry := range outbox.entries {
|
||||
assert.Equal(t, "RelyingParty", entry.Namespace)
|
||||
assert.Equal(t, "parents", entry.Relation)
|
||||
assert.Equal(t, "Tenant:deleted-tenant", entry.Subject)
|
||||
assert.Equal(t, domain.KetoOutboxActionDelete, entry.Action)
|
||||
}
|
||||
}
|
||||
|
||||
type tenantCleanupMockKetoOutboxRepository struct {
|
||||
entries []domain.KetoOutbox
|
||||
}
|
||||
|
||||
var _ repository.KetoOutboxRepository = (*tenantCleanupMockKetoOutboxRepository)(nil)
|
||||
|
||||
func (m *tenantCleanupMockKetoOutboxRepository) Create(ctx context.Context, entry *domain.KetoOutbox) error {
|
||||
if entry == nil {
|
||||
return nil
|
||||
}
|
||||
m.entries = append(m.entries, *entry)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *tenantCleanupMockKetoOutboxRepository) CreateWithTx(tx *gorm.DB, entry *domain.KetoOutbox) error {
|
||||
return m.Create(context.Background(), entry)
|
||||
}
|
||||
|
||||
func (m *tenantCleanupMockKetoOutboxRepository) FindPending(ctx context.Context, limit int) ([]domain.KetoOutbox, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *tenantCleanupMockKetoOutboxRepository) ListCurrentBySubject(ctx context.Context, namespace, subject string) ([]domain.KetoOutbox, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *tenantCleanupMockKetoOutboxRepository) UpdateStatus(ctx context.Context, id string, status string, retryCount int, lastError string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *tenantCleanupMockKetoOutboxRepository) MarkProcessed(ctx context.Context, id string) error {
|
||||
return nil
|
||||
}
|
||||
155
baron-sso/backend/internal/handler/tenant_assignment_policy.go
Normal file
155
baron-sso/backend/internal/handler/tenant_assignment_policy.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func representativeTenantIDFromTraits(traits map[string]any) string {
|
||||
if value := tenantClaimString(traits, "tenant_id"); value != "" {
|
||||
return value
|
||||
}
|
||||
if value := tenantClaimString(traits, "primaryTenantId"); value != "" {
|
||||
return value
|
||||
}
|
||||
if metadata, ok := traits["metadata"].(map[string]any); ok {
|
||||
if value := tenantClaimString(metadata, "primaryTenantId"); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
appointments := tenantAssignmentAppointmentsFromTraits(traits)
|
||||
for _, appointment := range appointments {
|
||||
if tenantAssignmentBool(appointment, "isPrimary", "primary", "representative", "isRepresentative") {
|
||||
if value := tenantAssignmentTenantID(appointment); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, appointment := range appointments {
|
||||
if value := tenantAssignmentTenantID(appointment); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
for _, tenantID := range tenantNamespaceIDsFromTraits(traits) {
|
||||
return tenantID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func joinedTenantIDsFromTraits(traits map[string]any, representativeTenantID string) []string {
|
||||
values := make([]string, 0)
|
||||
if representativeTenantID != "" {
|
||||
values = append(values, representativeTenantID)
|
||||
}
|
||||
if value := tenantClaimString(traits, "tenant_id"); value != "" {
|
||||
values = append(values, value)
|
||||
}
|
||||
for _, appointment := range tenantAssignmentAppointmentsFromTraits(traits) {
|
||||
if value := tenantAssignmentTenantID(appointment); value != "" {
|
||||
values = append(values, value)
|
||||
}
|
||||
}
|
||||
values = append(values, tenantNamespaceIDsFromTraits(traits)...)
|
||||
return uniqueSortedStrings(values)
|
||||
}
|
||||
|
||||
func tenantAssignmentAppointmentsFromTraits(traits map[string]any) []map[string]any {
|
||||
raw := rawAdditionalAppointments(traits)
|
||||
switch values := raw.(type) {
|
||||
case []any:
|
||||
appointments := make([]map[string]any, 0, len(values))
|
||||
for _, item := range values {
|
||||
if appointment, ok := item.(map[string]any); ok {
|
||||
appointments = append(appointments, appointment)
|
||||
}
|
||||
}
|
||||
return appointments
|
||||
case []map[string]any:
|
||||
return values
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func tenantAssignmentTenantID(appointment map[string]any) string {
|
||||
for _, key := range []string{"tenantId", "tenant_id"} {
|
||||
if value := tenantClaimString(appointment, key); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func tenantAssignmentBool(values map[string]any, keys ...string) bool {
|
||||
for _, key := range keys {
|
||||
raw, ok := values[key]
|
||||
if !ok || raw == nil {
|
||||
continue
|
||||
}
|
||||
switch value := raw.(type) {
|
||||
case bool:
|
||||
if value {
|
||||
return true
|
||||
}
|
||||
case string:
|
||||
normalized := strings.ToLower(strings.TrimSpace(value))
|
||||
if normalized == "true" || normalized == "1" || normalized == "yes" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func tenantNamespaceIDsFromTraits(traits map[string]any) []string {
|
||||
if traits == nil {
|
||||
return nil
|
||||
}
|
||||
ids := make([]string, 0)
|
||||
for key, value := range traits {
|
||||
if key == "" || key == "metadata" {
|
||||
continue
|
||||
}
|
||||
switch value.(type) {
|
||||
case map[string]any:
|
||||
ids = append(ids, key)
|
||||
}
|
||||
}
|
||||
return uniqueSortedStrings(ids)
|
||||
}
|
||||
|
||||
func createPersonalTenantForUser(ctx context.Context, tenantService service.TenantService, email string) (*domain.Tenant, error) {
|
||||
if tenantService == nil {
|
||||
return nil, errors.New("tenant service unavailable")
|
||||
}
|
||||
normalizedEmail := strings.ToLower(strings.TrimSpace(email))
|
||||
if normalizedEmail == "" {
|
||||
normalizedEmail = "user"
|
||||
}
|
||||
slug := "personal-" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
tenant, err := tenantService.RegisterTenant(
|
||||
ctx,
|
||||
fmt.Sprintf("Personal - %s", normalizedEmail),
|
||||
slug,
|
||||
domain.TenantTypePersonal,
|
||||
"Automatically provisioned personal tenant",
|
||||
nil,
|
||||
nil,
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tenant == nil {
|
||||
return nil, errors.New("personal tenant not created")
|
||||
}
|
||||
return tenant, nil
|
||||
}
|
||||
3154
baron-sso/backend/internal/handler/tenant_handler.go
Normal file
3154
baron-sso/backend/internal/handler/tenant_handler.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,159 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/testsupport"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
postgres_module "github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
gorm_postgres "gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func newTenantHandlerSeedDeleteDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
if !testsupport.DockerAvailable() {
|
||||
t.Skip("Docker provider is unavailable in this environment")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
postgresContainer, err := postgres_module.Run(ctx,
|
||||
"postgres:16-alpine",
|
||||
postgres_module.WithDatabase("testdb"),
|
||||
postgres_module.WithUsername("user"),
|
||||
postgres_module.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(
|
||||
wait.ForLog("database system is ready to accept connections").
|
||||
WithOccurrence(2).
|
||||
WithStartupTimeout(30*time.Second)),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start postgres container: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := postgresContainer.Terminate(ctx); err != nil {
|
||||
log.Printf("failed to terminate postgres container: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
connStr, err := postgresContainer.ConnectionString(ctx, "sslmode=disable")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get postgres connection string: %v", err)
|
||||
}
|
||||
db, err := gorm.Open(gorm_postgres.Open(connStr), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open postgres connection: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(&domain.Tenant{}); err != nil {
|
||||
t.Fatalf("failed to migrate tenants: %v", err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func setSeedTenantCSVForDeleteGuard(t *testing.T, slug string) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "seed-tenant.csv")
|
||||
csv := "name,type,parent_tenant_slug,slug,memo,email_domain\n" +
|
||||
"Protected,COMPANY_GROUP,," + slug + ",Protected seed,\n"
|
||||
if err := os.WriteFile(path, []byte(csv), 0o600); err != nil {
|
||||
t.Fatalf("failed to write seed csv: %v", err)
|
||||
}
|
||||
t.Setenv("SEED_TENANT_CSV_PATH", path)
|
||||
}
|
||||
|
||||
func TestTenantHandlerDeleteTenantRejectsSeedTenant(t *testing.T) {
|
||||
setSeedTenantCSVForDeleteGuard(t, "protected-root")
|
||||
db := newTenantHandlerSeedDeleteDB(t)
|
||||
tenant := domain.Tenant{
|
||||
ID: "00000000-0000-0000-0000-000000000001",
|
||||
Name: "Protected",
|
||||
Slug: "protected-root",
|
||||
Type: domain.TenantTypeCompanyGroup,
|
||||
Status: domain.TenantStatusActive,
|
||||
}
|
||||
if err := db.Create(&tenant).Error; err != nil {
|
||||
t.Fatalf("failed to create tenant: %v", err)
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Delete("/tenants/:id", (&TenantHandler{DB: db}).DeleteTenant)
|
||||
req := httptest.NewRequest(http.MethodDelete, "/tenants/"+tenant.ID, nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusConflict {
|
||||
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusConflict)
|
||||
}
|
||||
var count int64
|
||||
if err := db.Model(&domain.Tenant{}).Where("id = ?", tenant.ID).Count(&count).Error; err != nil {
|
||||
t.Fatalf("count tenant: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Fatalf("seed tenant count = %d, want 1", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTenantHandlerDeleteTenantsBulkRejectsSeedTenant(t *testing.T) {
|
||||
setSeedTenantCSVForDeleteGuard(t, "protected-root")
|
||||
db := newTenantHandlerSeedDeleteDB(t)
|
||||
seed := domain.Tenant{
|
||||
ID: "00000000-0000-0000-0000-000000000011",
|
||||
Name: "Protected",
|
||||
Slug: "protected-root",
|
||||
Type: domain.TenantTypeCompanyGroup,
|
||||
Status: domain.TenantStatusActive,
|
||||
}
|
||||
normal := domain.Tenant{
|
||||
ID: "00000000-0000-0000-0000-000000000012",
|
||||
Name: "Normal",
|
||||
Slug: "normal",
|
||||
Type: domain.TenantTypeCompany,
|
||||
Status: domain.TenantStatusActive,
|
||||
}
|
||||
if err := db.Create(&seed).Error; err != nil {
|
||||
t.Fatalf("failed to create seed tenant: %v", err)
|
||||
}
|
||||
if err := db.Create(&normal).Error; err != nil {
|
||||
t.Fatalf("failed to create normal tenant: %v", err)
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Delete("/tenants/bulk", (&TenantHandler{DB: db}).DeleteTenantsBulk)
|
||||
body, _ := json.Marshal(map[string][]string{"ids": {seed.ID, normal.ID}})
|
||||
req := httptest.NewRequest(http.MethodDelete, "/tenants/bulk", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusConflict {
|
||||
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusConflict)
|
||||
}
|
||||
var count int64
|
||||
if err := db.Model(&domain.Tenant{}).Where("id IN ?", []string{seed.ID, normal.ID}).Count(&count).Error; err != nil {
|
||||
t.Fatalf("count tenants: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Fatalf("remaining tenant count = %d, want 2", count)
|
||||
}
|
||||
}
|
||||
1732
baron-sso/backend/internal/handler/tenant_handler_test.go
Normal file
1732
baron-sso/backend/internal/handler/tenant_handler_test.go
Normal file
File diff suppressed because it is too large
Load Diff
133
baron-sso/backend/internal/handler/user_group_handler.go
Normal file
133
baron-sso/backend/internal/handler/user_group_handler.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type UserGroupHandler struct {
|
||||
Service service.UserGroupService
|
||||
}
|
||||
|
||||
func NewUserGroupHandler(s service.UserGroupService) *UserGroupHandler {
|
||||
return &UserGroupHandler{Service: s}
|
||||
}
|
||||
|
||||
func (h *UserGroupHandler) List(c *fiber.Ctx) error {
|
||||
tenantID := c.Params("tenantId")
|
||||
groups, err := h.Service.List(c.Context(), tenantID)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.JSON(groups)
|
||||
}
|
||||
|
||||
func (h *UserGroupHandler) Create(c *fiber.Ctx) error {
|
||||
tenantID := c.Params("tenantId")
|
||||
var req domain.GroupCreateRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid body"})
|
||||
}
|
||||
|
||||
group, err := h.Service.Create(c.Context(), tenantID, req.ParentID, req.Name, req.Description, req.UnitType)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
return c.Status(fiber.StatusCreated).JSON(group)
|
||||
}
|
||||
|
||||
func (h *UserGroupHandler) Get(c *fiber.Ctx) error {
|
||||
id := c.Params("id")
|
||||
group, err := h.Service.Get(c.Context(), id)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "failed to get group: "+err.Error())
|
||||
}
|
||||
return c.JSON(group)
|
||||
}
|
||||
|
||||
func (h *UserGroupHandler) Update(c *fiber.Ctx) error {
|
||||
tenantID := c.Params("tenantId")
|
||||
groupID := c.Params("id")
|
||||
var req domain.GroupCreateRequest // Using create request for update fields
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid body"})
|
||||
}
|
||||
|
||||
group, err := h.Service.Update(c.Context(), tenantID, groupID, req.Name, req.Description, req.UnitType, req.ParentID)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(group)
|
||||
}
|
||||
|
||||
func (h *UserGroupHandler) Delete(c *fiber.Ctx) error {
|
||||
tenantID := c.Params("tenantId")
|
||||
groupID := c.Params("id")
|
||||
if err := h.Service.Delete(c.Context(), tenantID, groupID); err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
return c.SendStatus(fiber.StatusNoContent)
|
||||
}
|
||||
|
||||
func (h *UserGroupHandler) AddMember(c *fiber.Ctx) error {
|
||||
groupID := c.Params("id")
|
||||
var req struct {
|
||||
UserID string `json:"userId"`
|
||||
}
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "userId is required")
|
||||
}
|
||||
|
||||
if err := h.Service.AddMember(c.Context(), groupID, req.UserID); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
}
|
||||
|
||||
func (h *UserGroupHandler) RemoveMember(c *fiber.Ctx) error {
|
||||
groupID := c.Params("id")
|
||||
userID := c.Params("userId")
|
||||
|
||||
if err := h.Service.RemoveMember(c.Context(), groupID, userID); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.SendStatus(fiber.StatusNoContent)
|
||||
}
|
||||
|
||||
func (h *UserGroupHandler) AssignRole(c *fiber.Ctx) error {
|
||||
groupID := c.Params("id")
|
||||
var req struct {
|
||||
TenantID string `json:"tenantId"`
|
||||
Relation string `json:"relation"`
|
||||
}
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "invalid body")
|
||||
}
|
||||
|
||||
if err := h.Service.AssignRoleToTenant(c.Context(), groupID, req.TenantID, req.Relation); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
}
|
||||
|
||||
func (h *UserGroupHandler) ListRoles(c *fiber.Ctx) error {
|
||||
groupID := c.Params("id")
|
||||
roles, err := h.Service.ListRoles(c.Context(), groupID)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.JSON(roles)
|
||||
}
|
||||
|
||||
func (h *UserGroupHandler) RemoveRole(c *fiber.Ctx) error {
|
||||
groupID := c.Params("id")
|
||||
tenantID := c.Params("tenantId")
|
||||
relation := c.Params("relation")
|
||||
|
||||
if err := h.Service.RemoveRoleFromTenant(c.Context(), groupID, tenantID, relation); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.SendStatus(fiber.StatusNoContent)
|
||||
}
|
||||
155
baron-sso/backend/internal/handler/user_group_handler_test.go
Normal file
155
baron-sso/backend/internal/handler/user_group_handler_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// --- Mocks ---
|
||||
|
||||
type MockUserGroupService struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockUserGroupService) Create(ctx context.Context, tenantID string, parentID *string, name, description, unitType string) (*domain.UserGroup, error) {
|
||||
args := m.Called(ctx, tenantID, parentID, name, description, unitType)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.UserGroup), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockUserGroupService) Update(ctx context.Context, tenantID, groupID string, name, description, unitType string, parentID *string) (*domain.UserGroup, error) {
|
||||
args := m.Called(ctx, tenantID, groupID, name, description, unitType, parentID)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.UserGroup), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockUserGroupService) Delete(ctx context.Context, tenantID, groupID string) error {
|
||||
return m.Called(ctx, tenantID, groupID).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockUserGroupService) Get(ctx context.Context, id string) (*domain.UserGroup, error) {
|
||||
args := m.Called(ctx, id)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.UserGroup), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockUserGroupService) List(ctx context.Context, tenantID string) ([]domain.UserGroup, error) {
|
||||
args := m.Called(ctx, tenantID)
|
||||
return args.Get(0).([]domain.UserGroup), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockUserGroupService) SetWorksmobileSyncer(syncer service.WorksmobileSyncer) {}
|
||||
|
||||
func (m *MockUserGroupService) AddMember(ctx context.Context, groupID, userID string) error {
|
||||
return m.Called(ctx, groupID, userID).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockUserGroupService) RemoveMember(ctx context.Context, groupID, userID string) error {
|
||||
return m.Called(ctx, groupID, userID).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockUserGroupService) ListRoles(ctx context.Context, groupID string) ([]domain.GroupRole, error) {
|
||||
args := m.Called(ctx, groupID)
|
||||
return args.Get(0).([]domain.GroupRole), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockUserGroupService) AssignRoleToTenant(ctx context.Context, groupID, tenantID, relation string) error {
|
||||
return m.Called(ctx, groupID, tenantID, relation).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockUserGroupService) RemoveRoleFromTenant(ctx context.Context, groupID, tenantID, relation string) error {
|
||||
return m.Called(ctx, groupID, tenantID, relation).Error(0)
|
||||
}
|
||||
|
||||
// --- Tests ---
|
||||
|
||||
func TestUserGroupHandler_List(t *testing.T) {
|
||||
mockSvc := new(MockUserGroupService)
|
||||
h := NewUserGroupHandler(mockSvc)
|
||||
app := fiber.New()
|
||||
app.Get("/tenants/:tenantId/user-groups", h.List)
|
||||
|
||||
tenantID := "t1"
|
||||
groups := []domain.UserGroup{{ID: "g1", Name: "Group 1"}}
|
||||
mockSvc.On("List", mock.Anything, tenantID).Return(groups, nil)
|
||||
|
||||
req := httptest.NewRequest("GET", "/tenants/t1/user-groups", nil)
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
var result []domain.UserGroup
|
||||
json.NewDecoder(resp.Body).Decode(&result)
|
||||
assert.Len(t, result, 1)
|
||||
assert.Equal(t, "Group 1", result[0].Name)
|
||||
}
|
||||
|
||||
func TestUserGroupHandler_Create(t *testing.T) {
|
||||
mockSvc := new(MockUserGroupService)
|
||||
h := NewUserGroupHandler(mockSvc)
|
||||
app := fiber.New()
|
||||
app.Post("/tenants/:tenantId/user-groups", h.Create)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"name": "New Group"})
|
||||
mockSvc.On("Create", mock.Anything, "t1", mock.Anything, "New Group", mock.Anything, mock.Anything).Return(&domain.UserGroup{ID: "g1", Name: "New Group"}, nil)
|
||||
|
||||
req := httptest.NewRequest("POST", "/tenants/t1/user-groups", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, http.StatusCreated, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestUserGroupHandler_AddMember(t *testing.T) {
|
||||
mockSvc := new(MockUserGroupService)
|
||||
h := NewUserGroupHandler(mockSvc)
|
||||
app := fiber.New()
|
||||
app.Post("/user-groups/:id/members", h.AddMember)
|
||||
|
||||
groupID := "g1"
|
||||
userID := "u1"
|
||||
body, _ := json.Marshal(map[string]string{"userId": userID})
|
||||
|
||||
mockSvc.On("AddMember", mock.Anything, groupID, userID).Return(nil)
|
||||
|
||||
req := httptest.NewRequest("POST", "/user-groups/g1/members", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestUserGroupHandler_AssignRole(t *testing.T) {
|
||||
mockSvc := new(MockUserGroupService)
|
||||
h := NewUserGroupHandler(mockSvc)
|
||||
app := fiber.New()
|
||||
app.Post("/user-groups/:id/roles", h.AssignRole)
|
||||
|
||||
groupID := "g1"
|
||||
targetTenantID := "t2"
|
||||
relation := "manage"
|
||||
body, _ := json.Marshal(map[string]string{"tenantId": targetTenantID, "relation": relation})
|
||||
|
||||
mockSvc.On("AssignRoleToTenant", mock.Anything, groupID, targetTenantID, relation).Return(nil)
|
||||
|
||||
req := httptest.NewRequest("POST", "/user-groups/g1/roles", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
3702
baron-sso/backend/internal/handler/user_handler.go
Normal file
3702
baron-sso/backend/internal/handler/user_handler.go
Normal file
File diff suppressed because it is too large
Load Diff
2707
baron-sso/backend/internal/handler/user_handler_test.go
Normal file
2707
baron-sso/backend/internal/handler/user_handler_test.go
Normal file
File diff suppressed because it is too large
Load Diff
68
baron-sso/backend/internal/handler/user_handler_uuid_test.go
Normal file
68
baron-sso/backend/internal/handler/user_handler_uuid_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func TestUserHandler_BulkCreateUsers_UUIDImportPolicy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
field string
|
||||
}{
|
||||
{name: "id 필드 차단", field: "id"},
|
||||
{name: "uuid 필드 차단", field: "uuid"},
|
||||
{name: "userId 필드 차단", field: "userId"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockKratos := new(MockKratosAdmin)
|
||||
mockOry := new(MockOryProvider)
|
||||
|
||||
h := &UserHandler{
|
||||
KratosAdmin: mockKratos,
|
||||
OryProvider: mockOry,
|
||||
}
|
||||
app.Post("/users/bulk", h.BulkCreateUsers)
|
||||
|
||||
mockOry.On("GetPasswordPolicy").Return(&domain.PasswordPolicy{MinLength: 8}, nil).Once()
|
||||
|
||||
payload := map[string]any{
|
||||
"users": []map[string]any{
|
||||
{
|
||||
"email": "uuid-import@test.com",
|
||||
"name": "UUID Import User",
|
||||
tt.field: "550e8400-e29b-41d4-a716-446655440000",
|
||||
"metadata": map[string]any{},
|
||||
},
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
req := httptest.NewRequest("POST", "/users/bulk", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, _ := app.Test(req)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
var result struct {
|
||||
Results []bulkUserResult `json:"results"`
|
||||
}
|
||||
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&result))
|
||||
assert.Len(t, result.Results, 1)
|
||||
assert.False(t, result.Results[0].Success)
|
||||
assert.Contains(t, result.Results[0].Message, "사용자 UUID 가져오기는 지원하지 않습니다")
|
||||
|
||||
mockOry.AssertExpectations(t)
|
||||
mockKratos.AssertNotCalled(t, "FindIdentityIDByIdentifier", mock.Anything, mock.Anything)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/repository"
|
||||
"context"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
func markUserProjectionFailed(ctx context.Context, repo repository.UserProjectionRepository, syncErr error) {
|
||||
if repo == nil || syncErr == nil {
|
||||
return
|
||||
}
|
||||
if err := repo.MarkFailed(ctx, syncErr); err != nil {
|
||||
slog.Error("Failed to mark user projection as failed", "syncError", syncErr, "error", err)
|
||||
}
|
||||
}
|
||||
217
baron-sso/backend/internal/handler/worksmobile_handler.go
Normal file
217
baron-sso/backend/internal/handler/worksmobile_handler.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/csv"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type WorksmobileHandler struct {
|
||||
Service service.WorksmobileAdminService
|
||||
}
|
||||
|
||||
func NewWorksmobileHandler(svc service.WorksmobileAdminService) *WorksmobileHandler {
|
||||
return &WorksmobileHandler{Service: svc}
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) GetOverview(c *fiber.Ctx) error {
|
||||
overview, err := h.Service.GetTenantOverview(c.Context(), strings.TrimSpace(c.Params("tenantId")))
|
||||
if err != nil {
|
||||
return worksmobileGuardError(c, err, "get_overview")
|
||||
}
|
||||
if !worksmobileOverviewAllowed(overview) {
|
||||
return errorJSON(c, fiber.StatusNotFound, "worksmobile is only available for hanmac-family root tenant")
|
||||
}
|
||||
return c.JSON(overview)
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) GetComparison(c *fiber.Ctx) error {
|
||||
includeMatched := strings.EqualFold(strings.TrimSpace(c.Query("includeMatched")), "true")
|
||||
comparison, err := h.Service.GetComparison(c.Context(), strings.TrimSpace(c.Params("tenantId")), includeMatched)
|
||||
if err != nil {
|
||||
return worksmobileGuardError(c, err, "get_comparison")
|
||||
}
|
||||
return c.JSON(comparison)
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) OAuthCallback(c *fiber.Ctx) error {
|
||||
return c.Type("html").SendString("<!doctype html><html><body>Worksmobile OAuth callback reachable</body></html>")
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) BackfillDryRun(c *fiber.Ctx) error {
|
||||
result, err := h.Service.EnqueueBackfillDryRun(c.Context(), strings.TrimSpace(c.Params("tenantId")))
|
||||
if err != nil {
|
||||
return worksmobileGuardError(c, err, "backfill_dry_run")
|
||||
}
|
||||
return c.JSON(result)
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) SyncOrgUnit(c *fiber.Ctx) error {
|
||||
orgUnitID := strings.TrimSpace(c.Params("orgUnitId"))
|
||||
job, err := h.Service.EnqueueOrgUnitSync(c.Context(), strings.TrimSpace(c.Params("tenantId")), orgUnitID)
|
||||
if err != nil {
|
||||
return worksmobileGuardError(c, err, "sync_orgunit", "org_unit_id", orgUnitID)
|
||||
}
|
||||
return c.Status(fiber.StatusAccepted).JSON(job)
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) DeleteOrgUnit(c *fiber.Ctx) error {
|
||||
orgUnitID := strings.TrimSpace(c.Params("orgUnitId"))
|
||||
job, err := h.Service.EnqueueOrgUnitDelete(c.Context(), strings.TrimSpace(c.Params("tenantId")), orgUnitID)
|
||||
if err != nil {
|
||||
return worksmobileGuardError(c, err, "delete_orgunit", "org_unit_id", orgUnitID)
|
||||
}
|
||||
return c.Status(fiber.StatusAccepted).JSON(job)
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) SyncUser(c *fiber.Ctx) error {
|
||||
userID := strings.TrimSpace(c.Params("userId"))
|
||||
credentialRequest, err := parseWorksmobileCredentialRequest(c)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, err.Error())
|
||||
}
|
||||
job, err := h.Service.EnqueueUserSync(
|
||||
c.Context(),
|
||||
strings.TrimSpace(c.Params("tenantId")),
|
||||
userID,
|
||||
credentialRequest.CredentialBatchID,
|
||||
credentialRequest.InitialPassword,
|
||||
)
|
||||
if err != nil {
|
||||
return worksmobileGuardError(c, err, "sync_user", "user_id", userID)
|
||||
}
|
||||
return c.Status(fiber.StatusAccepted).JSON(job)
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) ResetUserPassword(c *fiber.Ctx) error {
|
||||
userID := strings.TrimSpace(c.Params("userId"))
|
||||
credentialBatchID, err := parseWorksmobileCredentialBatchID(c)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, err.Error())
|
||||
}
|
||||
job, err := h.Service.EnqueueUserPasswordReset(c.Context(), strings.TrimSpace(c.Params("tenantId")), userID, credentialBatchID)
|
||||
if err != nil {
|
||||
return worksmobileGuardError(c, err, "reset_user_password", "user_id", userID)
|
||||
}
|
||||
return c.Status(fiber.StatusAccepted).JSON(job)
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) RetryJob(c *fiber.Ctx) error {
|
||||
jobID := strings.TrimSpace(c.Params("jobId"))
|
||||
job, err := h.Service.RetryJob(c.Context(), strings.TrimSpace(c.Params("tenantId")), jobID)
|
||||
if err != nil {
|
||||
return worksmobileGuardError(c, err, "retry_job", "job_id", jobID)
|
||||
}
|
||||
return c.JSON(job)
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) DeletePendingJobs(c *fiber.Ctx) error {
|
||||
result, err := h.Service.DeletePendingJobs(c.Context(), strings.TrimSpace(c.Params("tenantId")))
|
||||
if err != nil {
|
||||
return worksmobileGuardError(c, err, "delete_pending_jobs")
|
||||
}
|
||||
return c.JSON(result)
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) DownloadInitialPasswordsCSV(c *fiber.Ctx) error {
|
||||
credentials, err := h.Service.ListInitialPasswordCredentials(c.Context(), strings.TrimSpace(c.Params("tenantId")), strings.TrimSpace(c.Query("batchId")))
|
||||
if err != nil {
|
||||
return worksmobileGuardError(c, err, "download_initial_passwords")
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer := csv.NewWriter(&buf)
|
||||
if err := writer.Write([]string{"email", "name", "primaryLeafOrgName", "initialPassword", "status", "lastError"}); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
for _, credential := range credentials {
|
||||
if err := writer.Write([]string{credential.Email, credential.Name, credential.PrimaryLeafOrgName, credential.InitialPassword, credential.Status, credential.LastError}); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
}
|
||||
writer.Flush()
|
||||
if err := writer.Error(); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
c.Set(fiber.HeaderContentType, "text/csv; charset=utf-8")
|
||||
c.Set(fiber.HeaderContentDisposition, `attachment; filename="worksmobile_initial_passwords.csv"`)
|
||||
return c.Send(buf.Bytes())
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) ListCredentialBatches(c *fiber.Ctx) error {
|
||||
batches, err := h.Service.ListCredentialBatches(c.Context(), strings.TrimSpace(c.Params("tenantId")))
|
||||
if err != nil {
|
||||
return worksmobileGuardError(c, err, "list_credential_batches")
|
||||
}
|
||||
return c.JSON(batches)
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) DeleteCredentialBatchPasswords(c *fiber.Ctx) error {
|
||||
batchID := strings.TrimSpace(c.Params("batchId"))
|
||||
batch, err := h.Service.DeleteCredentialBatchPasswords(c.Context(), strings.TrimSpace(c.Params("tenantId")), batchID)
|
||||
if err != nil {
|
||||
return worksmobileGuardError(c, err, "delete_credential_batch_passwords", "batch_id", batchID)
|
||||
}
|
||||
return c.JSON(batch)
|
||||
}
|
||||
|
||||
type worksmobileCredentialBatchRequest struct {
|
||||
CredentialBatchID string `json:"credentialBatchId"`
|
||||
InitialPassword string `json:"initialPassword"`
|
||||
}
|
||||
|
||||
func parseWorksmobileCredentialBatchID(c *fiber.Ctx) (string, error) {
|
||||
req, err := parseWorksmobileCredentialRequest(c)
|
||||
return req.CredentialBatchID, err
|
||||
}
|
||||
|
||||
func parseWorksmobileCredentialRequest(c *fiber.Ctx) (worksmobileCredentialBatchRequest, error) {
|
||||
batchID := strings.TrimSpace(c.Query("credentialBatchId"))
|
||||
req := worksmobileCredentialBatchRequest{CredentialBatchID: batchID}
|
||||
if len(bytes.TrimSpace(c.Body())) == 0 {
|
||||
return req, nil
|
||||
}
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return worksmobileCredentialBatchRequest{}, err
|
||||
}
|
||||
req.InitialPassword = strings.TrimSpace(req.InitialPassword)
|
||||
if bodyBatchID := strings.TrimSpace(req.CredentialBatchID); bodyBatchID != "" {
|
||||
req.CredentialBatchID = bodyBatchID
|
||||
return req, nil
|
||||
}
|
||||
req.CredentialBatchID = batchID
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func worksmobileOverviewAllowed(overview service.WorksmobileTenantOverview) bool {
|
||||
return overview.Tenant.Slug == service.HanmacFamilyTenantSlug && overview.Tenant.ParentID == nil
|
||||
}
|
||||
|
||||
func worksmobileGuardError(c *fiber.Ctx, err error, operation string, attrs ...any) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
logAttrs := []any{
|
||||
"operation", operation,
|
||||
"tenant_id", strings.TrimSpace(c.Params("tenantId")),
|
||||
"path", c.Path(),
|
||||
"error", err,
|
||||
}
|
||||
logAttrs = append(logAttrs, attrs...)
|
||||
if errors.Is(err, context.Canceled) {
|
||||
slog.Warn("worksmobile admin operation failed", logAttrs...)
|
||||
return errorJSON(c, fiber.StatusRequestTimeout, err.Error())
|
||||
}
|
||||
slog.Error("worksmobile admin operation failed", logAttrs...)
|
||||
if strings.Contains(err.Error(), "hanmac-family root") {
|
||||
return errorJSON(c, fiber.StatusNotFound, err.Error())
|
||||
}
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
267
baron-sso/backend/internal/handler/worksmobile_handler_test.go
Normal file
267
baron-sso/backend/internal/handler/worksmobile_handler_test.go
Normal file
@@ -0,0 +1,267 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWorksmobileHandlerRejectsNonHanmacTenant(t *testing.T) {
|
||||
h := NewWorksmobileHandler(&fakeWorksmobileAdminService{
|
||||
overview: service.WorksmobileTenantOverview{
|
||||
Tenant: domain.Tenant{ID: "tenant-1", Slug: "other"},
|
||||
},
|
||||
})
|
||||
app := fiber.New()
|
||||
app.Get("/tenants/:tenantId/worksmobile", h.GetOverview)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/tenants/tenant-1/worksmobile", nil))
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestWorksmobileHandlerReturnsOverviewForHanmacTenant(t *testing.T) {
|
||||
h := NewWorksmobileHandler(&fakeWorksmobileAdminService{
|
||||
overview: service.WorksmobileTenantOverview{
|
||||
Tenant: domain.Tenant{ID: "hanmac-id", Slug: "hanmac-family"},
|
||||
Config: service.WorksmobileConfigSummary{
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
app := fiber.New()
|
||||
app.Get("/tenants/:tenantId/worksmobile", h.GetOverview)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/tenants/hanmac-id/worksmobile", nil))
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestWorksmobileHandlerDownloadsInitialPasswordCSV(t *testing.T) {
|
||||
h := NewWorksmobileHandler(&fakeWorksmobileAdminService{
|
||||
credentials: []service.WorksmobileInitialPasswordCredential{
|
||||
{
|
||||
Email: "user@hanmaceng.co.kr",
|
||||
Name: "홍길동",
|
||||
PrimaryLeafOrgName: "인재성장",
|
||||
InitialPassword: "Aa1!Aa1!Aa1!Aa1!",
|
||||
Status: "processed",
|
||||
},
|
||||
},
|
||||
})
|
||||
app := fiber.New()
|
||||
app.Get("/tenants/:tenantId/worksmobile/initial-passwords.csv", h.DownloadInitialPasswordsCSV)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/tenants/hanmac-id/worksmobile/initial-passwords.csv", nil))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
require.Contains(t, resp.Header.Get("Content-Disposition"), "worksmobile_initial_passwords.csv")
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(body), "email,name,primaryLeafOrgName,initialPassword,status,lastError")
|
||||
require.Contains(t, string(body), "user@hanmaceng.co.kr,홍길동,인재성장,Aa1!Aa1!Aa1!Aa1!,processed,")
|
||||
}
|
||||
|
||||
func TestWorksmobileHandlerPassesInitialPasswordBatchID(t *testing.T) {
|
||||
fakeService := &fakeWorksmobileAdminService{
|
||||
credentials: []service.WorksmobileInitialPasswordCredential{
|
||||
{Email: "batch-user@hanmaceng.co.kr", InitialPassword: "BatchPass1!", Status: "pending"},
|
||||
},
|
||||
}
|
||||
h := NewWorksmobileHandler(fakeService)
|
||||
app := fiber.New()
|
||||
app.Get("/tenants/:tenantId/worksmobile/initial-passwords.csv", h.DownloadInitialPasswordsCSV)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/tenants/hanmac-id/worksmobile/initial-passwords.csv?batchId=batch-1", nil))
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, "batch-1", fakeService.downloadCredentialBatchID)
|
||||
}
|
||||
|
||||
func TestWorksmobileHandlerPassesSyncUserCredentialBatchID(t *testing.T) {
|
||||
fakeService := &fakeWorksmobileAdminService{}
|
||||
h := NewWorksmobileHandler(fakeService)
|
||||
app := fiber.New()
|
||||
app.Post("/tenants/:tenantId/worksmobile/users/:userId/sync", h.SyncUser)
|
||||
|
||||
req := httptest.NewRequest("POST", "/tenants/hanmac-id/worksmobile/users/user-1/sync", strings.NewReader(`{"credentialBatchId":"batch-1","initialPassword":"InputPass1!"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := app.Test(req)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fiber.StatusAccepted, resp.StatusCode)
|
||||
require.Equal(t, "batch-1", fakeService.syncUserCredentialBatchID)
|
||||
require.Equal(t, "InputPass1!", fakeService.syncUserInitialPassword)
|
||||
}
|
||||
|
||||
func TestWorksmobileHandlerPassesPasswordResetCredentialBatchID(t *testing.T) {
|
||||
fakeService := &fakeWorksmobileAdminService{}
|
||||
h := NewWorksmobileHandler(fakeService)
|
||||
app := fiber.New()
|
||||
app.Post("/tenants/:tenantId/worksmobile/users/:userId/password/reset", h.ResetUserPassword)
|
||||
|
||||
req := httptest.NewRequest("POST", "/tenants/hanmac-id/worksmobile/users/user-1/password/reset", strings.NewReader(`{"credentialBatchId":"batch-1"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := app.Test(req)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fiber.StatusAccepted, resp.StatusCode)
|
||||
require.Equal(t, "batch-1", fakeService.resetPasswordCredentialBatchID)
|
||||
}
|
||||
|
||||
func TestWorksmobileHandlerReturnsCredentialBatchHistory(t *testing.T) {
|
||||
h := NewWorksmobileHandler(&fakeWorksmobileAdminService{
|
||||
credentialBatches: []service.WorksmobileCredentialBatch{
|
||||
{BatchID: "batch-1", UserCount: 2, HasPasswords: true},
|
||||
},
|
||||
})
|
||||
app := fiber.New()
|
||||
app.Get("/tenants/:tenantId/worksmobile/credential-batches", h.ListCredentialBatches)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/tenants/hanmac-id/worksmobile/credential-batches", nil))
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(body), `"batchId":"batch-1"`)
|
||||
require.Contains(t, string(body), `"userCount":2`)
|
||||
}
|
||||
|
||||
func TestWorksmobileHandlerDeletesCredentialBatchPasswords(t *testing.T) {
|
||||
fakeService := &fakeWorksmobileAdminService{}
|
||||
h := NewWorksmobileHandler(fakeService)
|
||||
app := fiber.New()
|
||||
app.Delete("/tenants/:tenantId/worksmobile/credential-batches/:batchId/passwords", h.DeleteCredentialBatchPasswords)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("DELETE", "/tenants/hanmac-id/worksmobile/credential-batches/batch-1/passwords", nil))
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, "batch-1", fakeService.deletedCredentialBatchID)
|
||||
}
|
||||
|
||||
func TestWorksmobileHandlerDeletesPendingJobs(t *testing.T) {
|
||||
fakeService := &fakeWorksmobileAdminService{
|
||||
pendingJobsDeleteResult: service.WorksmobilePendingJobDeleteResult{DeletedCount: 3},
|
||||
}
|
||||
h := NewWorksmobileHandler(fakeService)
|
||||
app := fiber.New()
|
||||
app.Delete("/tenants/:tenantId/worksmobile/jobs/pending", h.DeletePendingJobs)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("DELETE", "/tenants/hanmac-id/worksmobile/jobs/pending", nil))
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, "hanmac-id", fakeService.deletedPendingJobsTenantID)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(body), `"deletedCount":3`)
|
||||
}
|
||||
|
||||
func TestWorksmobileHandlerLogsActionFailures(t *testing.T) {
|
||||
var logs bytes.Buffer
|
||||
previous := slog.Default()
|
||||
slog.SetDefault(slog.New(slog.NewJSONHandler(&logs, nil)))
|
||||
t.Cleanup(func() {
|
||||
slog.SetDefault(previous)
|
||||
})
|
||||
|
||||
h := NewWorksmobileHandler(&fakeWorksmobileAdminService{
|
||||
syncUserErr: errors.New("works user sync failed"),
|
||||
})
|
||||
app := fiber.New()
|
||||
app.Post("/tenants/:tenantId/worksmobile/users/:userId/sync", h.SyncUser)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("POST", "/tenants/hanmac-id/worksmobile/users/user-1/sync", nil))
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
|
||||
require.Contains(t, logs.String(), "worksmobile admin operation failed")
|
||||
require.Contains(t, logs.String(), "sync_user")
|
||||
require.Contains(t, logs.String(), "works user sync failed")
|
||||
}
|
||||
|
||||
type fakeWorksmobileAdminService struct {
|
||||
overview service.WorksmobileTenantOverview
|
||||
credentials []service.WorksmobileInitialPasswordCredential
|
||||
syncUserErr error
|
||||
syncUserCredentialBatchID string
|
||||
syncUserInitialPassword string
|
||||
resetPasswordCredentialBatchID string
|
||||
downloadCredentialBatchID string
|
||||
deletedCredentialBatchID string
|
||||
deletedPendingJobsTenantID string
|
||||
pendingJobsDeleteResult service.WorksmobilePendingJobDeleteResult
|
||||
credentialBatches []service.WorksmobileCredentialBatch
|
||||
}
|
||||
|
||||
func (f *fakeWorksmobileAdminService) GetTenantOverview(ctx context.Context, tenantID string) (service.WorksmobileTenantOverview, error) {
|
||||
return f.overview, nil
|
||||
}
|
||||
|
||||
func (f *fakeWorksmobileAdminService) GetComparison(ctx context.Context, tenantID string, includeMatched bool) (service.WorksmobileComparison, error) {
|
||||
return service.WorksmobileComparison{}, nil
|
||||
}
|
||||
|
||||
func (f *fakeWorksmobileAdminService) EnqueueBackfillDryRun(ctx context.Context, tenantID string) (service.WorksmobileBackfillDryRun, error) {
|
||||
return service.WorksmobileBackfillDryRun{}, nil
|
||||
}
|
||||
|
||||
func (f *fakeWorksmobileAdminService) EnqueueOrgUnitSync(ctx context.Context, tenantID, orgUnitID string) (*domain.WorksmobileOutbox, error) {
|
||||
return &domain.WorksmobileOutbox{ID: "job-orgunit", ResourceID: orgUnitID}, nil
|
||||
}
|
||||
|
||||
func (f *fakeWorksmobileAdminService) EnqueueOrgUnitDelete(ctx context.Context, tenantID, orgUnitID string) (*domain.WorksmobileOutbox, error) {
|
||||
return &domain.WorksmobileOutbox{ID: "job-orgunit-delete", ResourceID: orgUnitID, Action: domain.WorksmobileActionDelete}, nil
|
||||
}
|
||||
|
||||
func (f *fakeWorksmobileAdminService) EnqueueUserSync(ctx context.Context, tenantID, userID, credentialBatchID, initialPassword string) (*domain.WorksmobileOutbox, error) {
|
||||
f.syncUserCredentialBatchID = credentialBatchID
|
||||
f.syncUserInitialPassword = initialPassword
|
||||
if f.syncUserErr != nil {
|
||||
return nil, f.syncUserErr
|
||||
}
|
||||
return &domain.WorksmobileOutbox{ID: "job-user", ResourceID: userID}, nil
|
||||
}
|
||||
|
||||
func (f *fakeWorksmobileAdminService) EnqueueUserPasswordReset(ctx context.Context, tenantID, userID, credentialBatchID string) (*domain.WorksmobileOutbox, error) {
|
||||
f.resetPasswordCredentialBatchID = credentialBatchID
|
||||
return &domain.WorksmobileOutbox{ID: "job-user-password-reset", ResourceID: userID}, nil
|
||||
}
|
||||
|
||||
func (f *fakeWorksmobileAdminService) RetryJob(ctx context.Context, tenantID, jobID string) (*domain.WorksmobileOutbox, error) {
|
||||
return &domain.WorksmobileOutbox{ID: jobID}, nil
|
||||
}
|
||||
|
||||
func (f *fakeWorksmobileAdminService) ListInitialPasswordCredentials(ctx context.Context, tenantID, credentialBatchID string) ([]service.WorksmobileInitialPasswordCredential, error) {
|
||||
f.downloadCredentialBatchID = credentialBatchID
|
||||
return f.credentials, nil
|
||||
}
|
||||
|
||||
func (f *fakeWorksmobileAdminService) ListCredentialBatches(ctx context.Context, tenantID string) ([]service.WorksmobileCredentialBatch, error) {
|
||||
return f.credentialBatches, nil
|
||||
}
|
||||
|
||||
func (f *fakeWorksmobileAdminService) DeleteCredentialBatchPasswords(ctx context.Context, tenantID, credentialBatchID string) (service.WorksmobileCredentialBatch, error) {
|
||||
f.deletedCredentialBatchID = credentialBatchID
|
||||
return service.WorksmobileCredentialBatch{BatchID: credentialBatchID}, nil
|
||||
}
|
||||
|
||||
func (f *fakeWorksmobileAdminService) DeletePendingJobs(ctx context.Context, tenantID string) (service.WorksmobilePendingJobDeleteResult, error) {
|
||||
f.deletedPendingJobsTenantID = tenantID
|
||||
return f.pendingJobsDeleteResult, nil
|
||||
}
|
||||
@@ -0,0 +1,560 @@
|
||||
package handlerregression
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/handler"
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type roundTripFunc func(req *http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
type mockSecretRepo struct {
|
||||
secrets map[string]string
|
||||
}
|
||||
|
||||
func (m *mockSecretRepo) Upsert(ctx context.Context, clientID, secret string) error {
|
||||
if m.secrets == nil {
|
||||
m.secrets = make(map[string]string)
|
||||
}
|
||||
m.secrets[clientID] = secret
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSecretRepo) GetByID(ctx context.Context, clientID string) (string, error) {
|
||||
return m.secrets[clientID], nil
|
||||
}
|
||||
|
||||
func (m *mockSecretRepo) Delete(ctx context.Context, clientID string) error {
|
||||
delete(m.secrets, clientID)
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockRedisRepo struct {
|
||||
data map[string]string
|
||||
}
|
||||
|
||||
func (m *mockRedisRepo) Set(key, value string, exp time.Duration) error {
|
||||
if m.data == nil {
|
||||
m.data = make(map[string]string)
|
||||
}
|
||||
m.data[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockRedisRepo) Get(key string) (string, error) {
|
||||
v, ok := m.data[key]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("not found")
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (m *mockRedisRepo) Delete(key string) error {
|
||||
delete(m.data, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockRedisRepo) StoreVerificationCode(p, c string) error { return nil }
|
||||
func (m *mockRedisRepo) GetVerificationCode(p string) (string, error) { return "", nil }
|
||||
func (m *mockRedisRepo) DeleteVerificationCode(p string) error { return nil }
|
||||
|
||||
func httpJSONAny(r *http.Request, code int, payload any) *http.Response {
|
||||
body, _ := json.Marshal(payload)
|
||||
return &http.Response{
|
||||
StatusCode: code,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
Request: r,
|
||||
}
|
||||
}
|
||||
|
||||
func newDevHandlerApp(h *handler.DevHandler) *fiber.App {
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "test-user", Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Post("/api/v1/dev/clients", h.CreateClient)
|
||||
app.Get("/api/v1/dev/clients/:id", h.GetClient)
|
||||
app.Put("/api/v1/dev/clients/:id", h.UpdateClient)
|
||||
app.Post("/api/v1/dev/clients/:id/secret/rotate", h.RotateClientSecret)
|
||||
return app
|
||||
}
|
||||
|
||||
func decodeClientSecret(t *testing.T, resp *http.Response) string {
|
||||
t.Helper()
|
||||
|
||||
var payload struct {
|
||||
Client struct {
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
} `json:"client"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
return payload.Client.ClientSecret
|
||||
}
|
||||
|
||||
func TestGetClient_ClientSecretFallbackPaths(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hydraClient map[string]any
|
||||
initialRedis map[string]string
|
||||
initialSecrets map[string]string
|
||||
expectedSecret string
|
||||
expectedRedisAfter string
|
||||
expectRedisAfterSet bool
|
||||
}{
|
||||
{
|
||||
name: "uses hydra client_secret directly",
|
||||
hydraClient: map[string]any{
|
||||
"client_id": "client-direct",
|
||||
"client_name": "Direct App",
|
||||
"client_secret": "hydra-secret",
|
||||
"redirect_uris": []string{"https://direct.example.com/callback"},
|
||||
"grant_types": []string{"authorization_code", "refresh_token"},
|
||||
"response_types": []string{"code"},
|
||||
"scope": "openid profile",
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"metadata": map[string]any{"status": "active"},
|
||||
},
|
||||
expectedSecret: "hydra-secret",
|
||||
},
|
||||
{
|
||||
name: "falls back to metadata client_secret",
|
||||
hydraClient: map[string]any{
|
||||
"client_id": "client-metadata",
|
||||
"client_name": "Metadata App",
|
||||
"redirect_uris": []string{"https://metadata.example.com/callback"},
|
||||
"grant_types": []string{"authorization_code", "refresh_token"},
|
||||
"response_types": []string{"code"},
|
||||
"scope": "openid profile",
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"metadata": map[string]any{
|
||||
"status": "active",
|
||||
"client_secret": "metadata-secret",
|
||||
},
|
||||
},
|
||||
expectedSecret: "metadata-secret",
|
||||
},
|
||||
{
|
||||
name: "falls back to redis cache",
|
||||
hydraClient: map[string]any{
|
||||
"client_id": "client-redis",
|
||||
"client_name": "Redis App",
|
||||
"redirect_uris": []string{"https://redis.example.com/callback"},
|
||||
"grant_types": []string{"authorization_code", "refresh_token"},
|
||||
"response_types": []string{"code"},
|
||||
"scope": "openid profile",
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"metadata": map[string]any{"status": "active"},
|
||||
},
|
||||
initialRedis: map[string]string{"client_secret:client-redis": "redis-secret"},
|
||||
expectedSecret: "redis-secret",
|
||||
},
|
||||
{
|
||||
name: "falls back to postgres and warms redis",
|
||||
hydraClient: map[string]any{
|
||||
"client_id": "client-postgres",
|
||||
"client_name": "Postgres App",
|
||||
"redirect_uris": []string{"https://postgres.example.com/callback"},
|
||||
"grant_types": []string{"authorization_code", "refresh_token"},
|
||||
"response_types": []string{"code"},
|
||||
"scope": "openid profile",
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"metadata": map[string]any{"status": "active"},
|
||||
},
|
||||
initialSecrets: map[string]string{"client-postgres": "postgres-secret"},
|
||||
expectedSecret: "postgres-secret",
|
||||
expectedRedisAfter: "postgres-secret",
|
||||
expectRedisAfterSet: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.Method == http.MethodGet && r.URL.Path == "/clients/"+tt.hydraClient["client_id"].(string) {
|
||||
return httpJSONAny(r, http.StatusOK, tt.hydraClient), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusNotFound, nil), nil
|
||||
})
|
||||
|
||||
secretRepo := &mockSecretRepo{secrets: map[string]string{}}
|
||||
maps.Copy(secretRepo.secrets, tt.initialSecrets)
|
||||
|
||||
redisRepo := &mockRedisRepo{data: map[string]string{}}
|
||||
maps.Copy(redisRepo.data, tt.initialRedis)
|
||||
|
||||
h := &handler.DevHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
PublicURL: "http://hydra.public",
|
||||
HTTPClient: &http.Client{Transport: transport},
|
||||
},
|
||||
SecretRepo: secretRepo,
|
||||
Redis: redisRepo,
|
||||
}
|
||||
|
||||
app := newDevHandlerApp(h)
|
||||
|
||||
clientID := tt.hydraClient["client_id"].(string)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/"+clientID, nil)
|
||||
resp, err := app.Test(req, -1)
|
||||
if err != nil {
|
||||
t.Fatalf("get request failed: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected get 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if secret := decodeClientSecret(t, resp); secret != tt.expectedSecret {
|
||||
t.Fatalf("expected secret %q, got %q", tt.expectedSecret, secret)
|
||||
}
|
||||
|
||||
if tt.expectRedisAfterSet {
|
||||
redisSecret, err := redisRepo.Get("client_secret:" + clientID)
|
||||
if err != nil {
|
||||
t.Fatalf("expected warmed redis secret, got error: %v", err)
|
||||
}
|
||||
if redisSecret != tt.expectedRedisAfter {
|
||||
t.Fatalf("expected warmed redis secret %q, got %q", tt.expectedRedisAfter, redisSecret)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateClient_PersistsSecretForLaterDetailFetch(t *testing.T) {
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.Method == http.MethodPost && r.URL.Path == "/clients" {
|
||||
return httpJSONAny(r, http.StatusCreated, map[string]any{
|
||||
"client_id": "client-created",
|
||||
"client_name": "Created App",
|
||||
"client_secret": "created-secret",
|
||||
"redirect_uris": []string{"https://created.example.com/callback"},
|
||||
"grant_types": []string{"authorization_code", "refresh_token"},
|
||||
"response_types": []string{"code"},
|
||||
"scope": "openid profile",
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"metadata": map[string]any{"status": "active"},
|
||||
}), nil
|
||||
}
|
||||
if r.Method == http.MethodGet && r.URL.Path == "/clients/client-created" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": "client-created",
|
||||
"client_name": "Created App",
|
||||
"redirect_uris": []string{"https://created.example.com/callback"},
|
||||
"grant_types": []string{"authorization_code", "refresh_token"},
|
||||
"response_types": []string{"code"},
|
||||
"scope": "openid profile",
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"metadata": map[string]any{"status": "active"},
|
||||
}), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusNotFound, nil), nil
|
||||
})
|
||||
|
||||
secretRepo := &mockSecretRepo{secrets: make(map[string]string)}
|
||||
redisRepo := &mockRedisRepo{data: make(map[string]string)}
|
||||
|
||||
h := &handler.DevHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
PublicURL: "http://hydra.public",
|
||||
HTTPClient: &http.Client{Transport: transport},
|
||||
},
|
||||
SecretRepo: secretRepo,
|
||||
Redis: redisRepo,
|
||||
}
|
||||
|
||||
app := newDevHandlerApp(h)
|
||||
|
||||
createBody, _ := json.Marshal(map[string]any{
|
||||
"name": "Created App",
|
||||
"type": "private",
|
||||
"redirectUris": []string{"https://created.example.com/callback"},
|
||||
"scopes": []string{"openid", "profile"},
|
||||
})
|
||||
createReq := httptest.NewRequest(http.MethodPost, "/api/v1/dev/clients", bytes.NewReader(createBody))
|
||||
createReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
createResp, err := app.Test(createReq, -1)
|
||||
if err != nil {
|
||||
t.Fatalf("create request failed: %v", err)
|
||||
}
|
||||
if createResp.StatusCode != http.StatusCreated {
|
||||
t.Fatalf("expected create 201, got %d", createResp.StatusCode)
|
||||
}
|
||||
|
||||
if secret := decodeClientSecret(t, createResp); secret != "created-secret" {
|
||||
t.Fatalf("expected create secret created-secret, got %q", secret)
|
||||
}
|
||||
|
||||
storedSecret, _ := secretRepo.GetByID(context.Background(), "client-created")
|
||||
if storedSecret != "created-secret" {
|
||||
t.Fatalf("expected postgres secret created-secret, got %q", storedSecret)
|
||||
}
|
||||
|
||||
redisSecret, err := redisRepo.Get("client_secret:client-created")
|
||||
if err != nil {
|
||||
t.Fatalf("expected redis secret after create, got error: %v", err)
|
||||
}
|
||||
if redisSecret != "created-secret" {
|
||||
t.Fatalf("expected redis secret created-secret, got %q", redisSecret)
|
||||
}
|
||||
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-created", nil)
|
||||
getResp, err := app.Test(getReq, -1)
|
||||
if err != nil {
|
||||
t.Fatalf("get request failed: %v", err)
|
||||
}
|
||||
if getResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected get 200, got %d", getResp.StatusCode)
|
||||
}
|
||||
|
||||
if secret := decodeClientSecret(t, getResp); secret != "created-secret" {
|
||||
t.Fatalf("expected detail secret created-secret, got %q", secret)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRotateClientSecret_PersistsForLaterDetailFetch(t *testing.T) {
|
||||
getCount := 0
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.Method == http.MethodGet && r.URL.Path == "/clients/client-rotate" {
|
||||
getCount++
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": "client-rotate",
|
||||
"client_name": "Rotate App",
|
||||
"redirect_uris": []string{"https://rotate.example.com/callback"},
|
||||
"grant_types": []string{"authorization_code", "refresh_token"},
|
||||
"response_types": []string{"code"},
|
||||
"scope": "openid profile",
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"metadata": map[string]any{"status": "active"},
|
||||
}), nil
|
||||
}
|
||||
if r.Method == http.MethodPut && r.URL.Path == "/clients/client-rotate" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": "client-rotate",
|
||||
"client_name": "Rotate App",
|
||||
"redirect_uris": []string{"https://rotate.example.com/callback"},
|
||||
"grant_types": []string{"authorization_code", "refresh_token"},
|
||||
"response_types": []string{"code"},
|
||||
"scope": "openid profile",
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"metadata": map[string]any{"status": "active"},
|
||||
}), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusNotFound, nil), nil
|
||||
})
|
||||
|
||||
secretRepo := &mockSecretRepo{secrets: make(map[string]string)}
|
||||
redisRepo := &mockRedisRepo{data: make(map[string]string)}
|
||||
|
||||
h := &handler.DevHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
PublicURL: "http://hydra.public",
|
||||
HTTPClient: &http.Client{Transport: transport},
|
||||
},
|
||||
SecretRepo: secretRepo,
|
||||
Redis: redisRepo,
|
||||
}
|
||||
|
||||
app := newDevHandlerApp(h)
|
||||
|
||||
rotateReq := httptest.NewRequest(http.MethodPost, "/api/v1/dev/clients/client-rotate/secret/rotate", nil)
|
||||
rotateResp, err := app.Test(rotateReq, -1)
|
||||
if err != nil {
|
||||
t.Fatalf("rotate request failed: %v", err)
|
||||
}
|
||||
if rotateResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected rotate 200, got %d", rotateResp.StatusCode)
|
||||
}
|
||||
|
||||
rotatedSecret := decodeClientSecret(t, rotateResp)
|
||||
if rotatedSecret == "" {
|
||||
t.Fatalf("expected rotated secret to be present")
|
||||
}
|
||||
|
||||
storedSecret, _ := secretRepo.GetByID(context.Background(), "client-rotate")
|
||||
if storedSecret != rotatedSecret {
|
||||
t.Fatalf("expected postgres secret %q, got %q", rotatedSecret, storedSecret)
|
||||
}
|
||||
|
||||
redisSecret, err := redisRepo.Get("client_secret:client-rotate")
|
||||
if err != nil {
|
||||
t.Fatalf("expected redis secret after rotate, got error: %v", err)
|
||||
}
|
||||
if redisSecret != rotatedSecret {
|
||||
t.Fatalf("expected redis secret %q, got %q", rotatedSecret, redisSecret)
|
||||
}
|
||||
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-rotate", nil)
|
||||
getResp, err := app.Test(getReq, -1)
|
||||
if err != nil {
|
||||
t.Fatalf("get request failed: %v", err)
|
||||
}
|
||||
if getResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected get 200, got %d", getResp.StatusCode)
|
||||
}
|
||||
|
||||
if secret := decodeClientSecret(t, getResp); secret != rotatedSecret {
|
||||
t.Fatalf("expected detail secret %q, got %q", rotatedSecret, secret)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateClient_HeadlessLoginSecretPersistsForLaterDetailFetch(t *testing.T) {
|
||||
getCount := 0
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.Method == http.MethodGet && r.URL.Path == "/clients/client-headless-login" {
|
||||
getCount++
|
||||
if getCount == 1 {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": "client-headless-login",
|
||||
"client_name": "Headless Login Before",
|
||||
"redirect_uris": []string{"https://before.example.com/callback"},
|
||||
"grant_types": []string{"authorization_code", "refresh_token"},
|
||||
"response_types": []string{"code"},
|
||||
"scope": "openid profile",
|
||||
"token_endpoint_auth_method": "none",
|
||||
"metadata": map[string]any{
|
||||
"status": "active",
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": "client-headless-login",
|
||||
"client_name": "Headless Login After",
|
||||
"redirect_uris": []string{"https://headless.example.com/callback"},
|
||||
"grant_types": []string{"authorization_code", "refresh_token"},
|
||||
"response_types": []string{"code"},
|
||||
"scope": "openid profile",
|
||||
"token_endpoint_auth_method": "private_key_jwt",
|
||||
"jwks_uri": "https://headless.example.com/jwks.json",
|
||||
"metadata": map[string]any{
|
||||
"status": "active",
|
||||
"headless_login_enabled": true,
|
||||
"request_object_signing_alg": "RS256",
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
|
||||
if r.Method == http.MethodPut && r.URL.Path == "/clients/client-headless-login" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": "client-headless-login",
|
||||
"client_name": "Headless Login After",
|
||||
"client_secret": "headless-secret",
|
||||
"redirect_uris": []string{"https://headless.example.com/callback"},
|
||||
"grant_types": []string{"authorization_code", "refresh_token"},
|
||||
"response_types": []string{"code"},
|
||||
"scope": "openid profile",
|
||||
"token_endpoint_auth_method": "private_key_jwt",
|
||||
"jwks_uri": "https://headless.example.com/jwks.json",
|
||||
"metadata": map[string]any{
|
||||
"status": "active",
|
||||
"headless_login_enabled": true,
|
||||
"request_object_signing_alg": "RS256",
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
|
||||
return httpJSONAny(r, http.StatusNotFound, nil), nil
|
||||
})
|
||||
|
||||
secretRepo := &mockSecretRepo{secrets: make(map[string]string)}
|
||||
redisRepo := &mockRedisRepo{data: make(map[string]string)}
|
||||
|
||||
h := &handler.DevHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
PublicURL: "http://hydra.public",
|
||||
HTTPClient: &http.Client{Transport: transport},
|
||||
},
|
||||
SecretRepo: secretRepo,
|
||||
Redis: redisRepo,
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "test-user", Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Put("/api/v1/dev/clients/:id", h.UpdateClient)
|
||||
app.Get("/api/v1/dev/clients/:id", h.GetClient)
|
||||
|
||||
updateBody, _ := json.Marshal(map[string]any{
|
||||
"name": "Headless Login After",
|
||||
"redirectUris": []string{"https://headless.example.com/callback"},
|
||||
"tokenEndpointAuthMethod": "private_key_jwt",
|
||||
"jwksUri": "https://headless.example.com/jwks.json",
|
||||
"metadata": map[string]any{
|
||||
"headless_login_enabled": true,
|
||||
"request_object_signing_alg": "RS256",
|
||||
},
|
||||
})
|
||||
updateReq := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-headless-login", bytes.NewReader(updateBody))
|
||||
updateReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
updateResp, err := app.Test(updateReq, -1)
|
||||
if err != nil {
|
||||
t.Fatalf("update request failed: %v", err)
|
||||
}
|
||||
if updateResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected update 200, got %d", updateResp.StatusCode)
|
||||
}
|
||||
|
||||
storedSecret, _ := secretRepo.GetByID(context.Background(), "client-headless-login")
|
||||
if storedSecret != "headless-secret" {
|
||||
t.Fatalf("expected postgres secret headless-secret, got %q", storedSecret)
|
||||
}
|
||||
|
||||
redisSecret, err := redisRepo.Get("client_secret:client-headless-login")
|
||||
if err != nil {
|
||||
t.Fatalf("expected redis secret, got error: %v", err)
|
||||
}
|
||||
if redisSecret != "headless-secret" {
|
||||
t.Fatalf("expected redis secret headless-secret, got %q", redisSecret)
|
||||
}
|
||||
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-headless-login", nil)
|
||||
getResp, err := app.Test(getReq, -1)
|
||||
if err != nil {
|
||||
t.Fatalf("get request failed: %v", err)
|
||||
}
|
||||
if getResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected get 200, got %d", getResp.StatusCode)
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Client struct {
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
} `json:"client"`
|
||||
}
|
||||
if err := json.NewDecoder(getResp.Body).Decode(&payload); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
if payload.Client.ClientSecret != "headless-secret" {
|
||||
t.Fatalf("expected detail secret headless-secret, got %q", payload.Client.ClientSecret)
|
||||
}
|
||||
}
|
||||
305
baron-sso/backend/internal/idp/factory.go
Normal file
305
baron-sso/backend/internal/idp/factory.go
Normal file
@@ -0,0 +1,305 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Ory 계열(kratos/hydra) 공급자 문자열을 정규화하기 위한 매핑.
|
||||
var providerAliases = map[string]string{
|
||||
"ory": "ory",
|
||||
"hydra": "ory",
|
||||
"kratos": "ory",
|
||||
"ory-kratos": "ory",
|
||||
"ory_hydra": "ory",
|
||||
"ory_kratos": "ory",
|
||||
}
|
||||
|
||||
// getEnv는 환경 변수를 읽거나 대체 값을 반환하는 헬퍼 함수입니다.
|
||||
func getEnv(key, fallback string) string {
|
||||
if value, ok := os.LookupEnv(key); ok {
|
||||
return value
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
// InitializeProvider는 환경 설정을 기반으로 IDP 공급자를 생성하고 반환합니다.
|
||||
// 이것은 IdentityProvider 인터페이스의 팩토리 역할을 합니다.
|
||||
func InitializeProvider() (domain.IdentityProvider, error) {
|
||||
rawProviders := getEnv("IDP_PROVIDER", "ory")
|
||||
providers := strings.Split(rawProviders, ",")
|
||||
slog.Info("Initializing IDP chain", "providers", rawProviders)
|
||||
|
||||
var initialized []domain.IdentityProvider
|
||||
for _, p := range providers {
|
||||
providerName := strings.TrimSpace(strings.ToLower(p))
|
||||
if canonical, ok := providerAliases[providerName]; ok {
|
||||
providerName = canonical
|
||||
}
|
||||
|
||||
switch providerName {
|
||||
case "ory":
|
||||
// Kratos/Hydra 주 공급자
|
||||
oryProvider := service.NewOryProvider()
|
||||
initialized = append(initialized, oryProvider)
|
||||
|
||||
default:
|
||||
// 알 수 없는 공급자는 건너뛰고 다음 후보를 시도
|
||||
slog.Warn("Skipping unsupported IDP provider entry", "provider", providerName)
|
||||
}
|
||||
}
|
||||
if len(initialized) == 0 {
|
||||
return nil, fmt.Errorf("no valid IDP_PROVIDER entries configured from: %s", rawProviders)
|
||||
}
|
||||
|
||||
if len(initialized) == 1 {
|
||||
slog.Info("Initialized IDP provider", "provider", initialized[0].Name())
|
||||
return initialized[0], nil
|
||||
}
|
||||
|
||||
chain := newChainedProvider(initialized)
|
||||
slog.Info("Initialized IDP provider chain", "providers", chain.Name())
|
||||
return chain, nil
|
||||
}
|
||||
|
||||
// newChainedProvider는 우선순위 순으로 IDP를 시도하는 체인을 생성합니다.
|
||||
func newChainedProvider(providers []domain.IdentityProvider) domain.IdentityProvider {
|
||||
names := make([]string, len(providers))
|
||||
for i, p := range providers {
|
||||
names[i] = p.Name()
|
||||
}
|
||||
return &chainedProvider{
|
||||
providers: providers,
|
||||
names: names,
|
||||
}
|
||||
}
|
||||
|
||||
// chainedProvider는 다중 IDP를 우선순위대로 호출하며 실패 시 폴백합니다.
|
||||
type chainedProvider struct {
|
||||
providers []domain.IdentityProvider
|
||||
names []string
|
||||
}
|
||||
|
||||
func (c *chainedProvider) Name() string {
|
||||
return strings.Join(c.names, " > ")
|
||||
}
|
||||
|
||||
func (c *chainedProvider) GetMetadata() (*domain.IDPMetadata, error) {
|
||||
supported := make([]string, 0)
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, p := range c.providers {
|
||||
meta, err := p.GetMetadata()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch metadata from %s: %w", p.Name(), err)
|
||||
}
|
||||
for _, field := range meta.SupportedFields {
|
||||
if !seen[field] {
|
||||
seen[field] = true
|
||||
supported = append(supported, field)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &domain.IDPMetadata{SupportedFields: supported}, nil
|
||||
}
|
||||
|
||||
func (c *chainedProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
|
||||
for _, p := range c.providers {
|
||||
id, err := p.CreateUser(user, password)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
continue
|
||||
}
|
||||
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "CreateUser", "error", err)
|
||||
return "", err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
return "", domain.ErrNotSupported
|
||||
}
|
||||
|
||||
func (c *chainedProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
|
||||
for _, p := range c.providers {
|
||||
info, err := p.SignIn(loginID, password)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
continue
|
||||
}
|
||||
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "SignIn", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
|
||||
func (c *chainedProvider) UserExists(loginID string) (bool, error) {
|
||||
var errs []error
|
||||
for _, p := range c.providers {
|
||||
exists, err := p.UserExists(loginID)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
continue
|
||||
}
|
||||
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
|
||||
continue
|
||||
}
|
||||
if exists {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
if len(errs) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("all IDP providers failed for UserExists: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
func (c *chainedProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
|
||||
var errs []error
|
||||
for idx, p := range c.providers {
|
||||
info, err := p.IssueSession(loginID)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
continue
|
||||
}
|
||||
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
|
||||
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "IssueSession", "error", err)
|
||||
continue
|
||||
}
|
||||
if idx > 0 {
|
||||
slog.Info("IDP fallback succeeded", "operation", "IssueSession", "provider", p.Name())
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
if len(errs) == 0 {
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
return nil, fmt.Errorf("all IDP providers failed for IssueSession: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
func (c *chainedProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
|
||||
var errs []error
|
||||
for idx, p := range c.providers {
|
||||
info, err := p.InitiateLinkLogin(loginID, returnTo)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
continue
|
||||
}
|
||||
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
|
||||
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "InitiateLinkLogin", "error", err)
|
||||
continue
|
||||
}
|
||||
if idx > 0 {
|
||||
slog.Info("IDP fallback succeeded", "operation", "InitiateLinkLogin", "provider", p.Name())
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
if len(errs) == 0 {
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
return nil, fmt.Errorf("all IDP providers failed for InitiateLinkLogin: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
func (c *chainedProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
|
||||
var errs []error
|
||||
for idx, p := range c.providers {
|
||||
info, err := p.VerifyLoginCode(loginID, flowID, code)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
continue
|
||||
}
|
||||
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
|
||||
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", "VerifyLoginCode", "error", err)
|
||||
continue
|
||||
}
|
||||
if idx > 0 {
|
||||
slog.Info("IDP fallback succeeded", "operation", "VerifyLoginCode", "provider", p.Name())
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
if len(errs) == 0 {
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
return nil, fmt.Errorf("all IDP providers failed for VerifyLoginCode: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
func (c *chainedProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
|
||||
var errs []error
|
||||
for _, p := range c.providers {
|
||||
policy, err := p.GetPasswordPolicy()
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotSupported) {
|
||||
continue
|
||||
}
|
||||
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
|
||||
continue
|
||||
}
|
||||
if policy != nil {
|
||||
return policy, nil
|
||||
}
|
||||
}
|
||||
if len(errs) == 0 {
|
||||
return nil, domain.ErrNotSupported
|
||||
}
|
||||
return nil, fmt.Errorf("all IDP providers failed for GetPasswordPolicy: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
func (c *chainedProvider) InitiatePasswordReset(loginID, redirectUrl string) error {
|
||||
return c.tryProviders("InitiatePasswordReset", func(p domain.IdentityProvider) error {
|
||||
return p.InitiatePasswordReset(loginID, redirectUrl)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *chainedProvider) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
|
||||
var errs []error
|
||||
for idx, p := range c.providers {
|
||||
info, err := p.VerifyPasswordResetToken(token)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
|
||||
slog.Warn("IDP VerifyPasswordResetToken failed", "provider", p.Name(), "error", err)
|
||||
continue
|
||||
}
|
||||
if idx > 0 {
|
||||
slog.Info("IDP fallback succeeded", "operation", "VerifyPasswordResetToken", "provider", p.Name())
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
|
||||
if len(errs) == 0 {
|
||||
return nil, fmt.Errorf("no IDP providers available for VerifyPasswordResetToken")
|
||||
}
|
||||
return nil, fmt.Errorf("all IDP providers failed for VerifyPasswordResetToken: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
func (c *chainedProvider) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
|
||||
return c.tryProviders("UpdateUserPassword", func(p domain.IdentityProvider) error {
|
||||
return p.UpdateUserPassword(loginID, newPassword, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *chainedProvider) tryProviders(operation string, fn func(domain.IdentityProvider) error) error {
|
||||
var errs []error
|
||||
for idx, p := range c.providers {
|
||||
if err := fn(p); err != nil {
|
||||
errs = append(errs, fmt.Errorf("%s: %w", p.Name(), err))
|
||||
slog.Warn("IDP provider failed", "provider", p.Name(), "operation", operation, "error", err)
|
||||
continue
|
||||
}
|
||||
if idx > 0 {
|
||||
slog.Info("IDP fallback succeeded", "operation", operation, "provider", p.Name())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(errs) == 0 {
|
||||
return fmt.Errorf("no IDP providers available for %s", operation)
|
||||
}
|
||||
return fmt.Errorf("all IDP providers failed for %s: %w", operation, errors.Join(errs...))
|
||||
}
|
||||
167
baron-sso/backend/internal/idp/factory_test.go
Normal file
167
baron-sso/backend/internal/idp/factory_test.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"errors"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type stubProvider struct {
|
||||
name string
|
||||
metadata []string
|
||||
createErr error
|
||||
initiateErr error
|
||||
verifyErr error
|
||||
updateErr error
|
||||
signInErr error
|
||||
userExistsErr error
|
||||
issueErr error
|
||||
linkInitErr error
|
||||
verifyCodeErr error
|
||||
policyErr error
|
||||
initiateCalls int
|
||||
verifyCalls int
|
||||
updateCalls int
|
||||
signInCalls int
|
||||
createCalls int
|
||||
userExistsCalls int
|
||||
issueCalls int
|
||||
linkInitCalls int
|
||||
verifyCodeCalls int
|
||||
policyCalls int
|
||||
verifyResponse *domain.AuthInfo
|
||||
userExists bool
|
||||
policy *domain.PasswordPolicy
|
||||
}
|
||||
|
||||
func (s *stubProvider) Name() string { return s.name }
|
||||
|
||||
func (s *stubProvider) GetMetadata() (*domain.IDPMetadata, error) {
|
||||
return &domain.IDPMetadata{SupportedFields: s.metadata}, nil
|
||||
}
|
||||
|
||||
func (s *stubProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
|
||||
s.createCalls++
|
||||
if s.createErr != nil {
|
||||
return "", s.createErr
|
||||
}
|
||||
return "created-id", nil
|
||||
}
|
||||
|
||||
func (s *stubProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
|
||||
s.signInCalls++
|
||||
if s.signInErr != nil {
|
||||
return nil, s.signInErr
|
||||
}
|
||||
return &domain.AuthInfo{Subject: "subject-123"}, nil
|
||||
}
|
||||
|
||||
func (s *stubProvider) UserExists(loginID string) (bool, error) {
|
||||
s.userExistsCalls++
|
||||
if s.userExistsErr != nil {
|
||||
return false, s.userExistsErr
|
||||
}
|
||||
return s.userExists, nil
|
||||
}
|
||||
|
||||
func (s *stubProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
|
||||
s.issueCalls++
|
||||
if s.issueErr != nil {
|
||||
return nil, s.issueErr
|
||||
}
|
||||
return &domain.AuthInfo{Subject: "issue-subject"}, nil
|
||||
}
|
||||
|
||||
func (s *stubProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
|
||||
s.linkInitCalls++
|
||||
if s.linkInitErr != nil {
|
||||
return nil, s.linkInitErr
|
||||
}
|
||||
return &domain.LinkLoginInit{FlowID: "flow-123", Mode: "cookie"}, nil
|
||||
}
|
||||
|
||||
func (s *stubProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
|
||||
s.verifyCodeCalls++
|
||||
if s.verifyCodeErr != nil {
|
||||
return nil, s.verifyCodeErr
|
||||
}
|
||||
return &domain.AuthInfo{Subject: "verify-code-subject"}, nil
|
||||
}
|
||||
|
||||
func (s *stubProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
|
||||
s.policyCalls++
|
||||
if s.policyErr != nil {
|
||||
return nil, s.policyErr
|
||||
}
|
||||
return s.policy, nil
|
||||
}
|
||||
|
||||
func (s *stubProvider) InitiatePasswordReset(loginID, redirectUrl string) error {
|
||||
s.initiateCalls++
|
||||
return s.initiateErr
|
||||
}
|
||||
|
||||
func (s *stubProvider) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
|
||||
s.verifyCalls++
|
||||
if s.verifyErr != nil {
|
||||
return nil, s.verifyErr
|
||||
}
|
||||
if s.verifyResponse != nil {
|
||||
return s.verifyResponse, nil
|
||||
}
|
||||
return &domain.AuthInfo{}, nil
|
||||
}
|
||||
|
||||
func (s *stubProvider) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
|
||||
s.updateCalls++
|
||||
return s.updateErr
|
||||
}
|
||||
|
||||
func TestChainedProviderMetadataUnion(t *testing.T) {
|
||||
p1 := &stubProvider{name: "primary", metadata: []string{"id", "email"}}
|
||||
p2 := &stubProvider{name: "backup", metadata: []string{"email", "phone_number", "grade"}}
|
||||
|
||||
chain := newChainedProvider([]domain.IdentityProvider{p1, p2})
|
||||
meta, err := chain.GetMetadata()
|
||||
if err != nil {
|
||||
t.Fatalf("GetMetadata returned error: %v", err)
|
||||
}
|
||||
|
||||
expected := []string{"id", "email", "phone_number", "grade"}
|
||||
if !reflect.DeepEqual(meta.SupportedFields, expected) {
|
||||
t.Fatalf("metadata mismatch: got %v, want %v", meta.SupportedFields, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChainedProviderUpdateUserPasswordFallback(t *testing.T) {
|
||||
p1 := &stubProvider{name: "primary", metadata: []string{"id"}, updateErr: errors.New("boom")}
|
||||
p2 := &stubProvider{name: "backup", metadata: []string{"id"}}
|
||||
|
||||
chain := newChainedProvider([]domain.IdentityProvider{p1, p2})
|
||||
if err := chain.UpdateUserPassword("user@example.com", "Sup3r!Pass123", nil); err != nil {
|
||||
t.Fatalf("expected fallback to succeed, got error: %v", err)
|
||||
}
|
||||
if p1.updateCalls != 1 || p2.updateCalls != 1 {
|
||||
t.Fatalf("unexpected call counts: p1=%d p2=%d", p1.updateCalls, p2.updateCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChainedProviderUpdateUserPasswordAllFail(t *testing.T) {
|
||||
p1 := &stubProvider{name: "primary", metadata: []string{"id"}, updateErr: errors.New("fail1")}
|
||||
p2 := &stubProvider{name: "backup", metadata: []string{"id"}, updateErr: errors.New("fail2")}
|
||||
|
||||
chain := newChainedProvider([]domain.IdentityProvider{p1, p2})
|
||||
err := chain.UpdateUserPassword("user@example.com", "Sup3r!Pass123", nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error when all providers fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "all IDP providers failed") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if p1.updateCalls != 1 || p2.updateCalls != 1 {
|
||||
t.Fatalf("unexpected call counts: p1=%d p2=%d", p1.updateCalls, p2.updateCalls)
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user