첫 커밋: 로컬 프로젝트 업로드

This commit is contained in:
2026-06-10 15:51:34 +09:00
commit 6a8dbeb2e9
1211 changed files with 312864 additions and 0 deletions

View File

@@ -0,0 +1,491 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/repository"
"baron-sso-backend/internal/service"
"context"
"runtime"
"strconv"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"gorm.io/gorm"
)
type adminHydraClientLister interface {
ListClients(ctx context.Context, limit, offset int) ([]domain.HydraClient, error)
}
type identityCacheAdmin interface {
GetIdentityCacheStatus(ctx context.Context) (domain.IdentityCacheStatus, error)
FlushIdentityCache(ctx context.Context) (domain.IdentityCacheFlushResult, error)
}
type AdminHandler struct {
DB *gorm.DB
Keto service.KetoService
KetoOutbox repository.KetoOutboxRepository
RPUsageQueries domain.RPUsageQueryRepository
TenantRepo repository.TenantRepository
Hydra adminHydraClientLister
AuditRepo domain.AuditRepository
UserProjectionRepo repository.UserProjectionRepository
IdentityCache identityCacheAdmin
IntegrityChecker repository.DataIntegrityChecker
}
const globalCustomClaimsSettingKey = "global_custom_claim_definitions"
type globalCustomClaimDefinition struct {
Key string `json:"key"`
Label string `json:"label"`
ValueType string `json:"valueType"`
ReadPermission string `json:"readPermission"`
WritePermission string `json:"writePermission"`
Description string `json:"description,omitempty"`
}
type globalCustomClaimDefinitionsResponse struct {
Items []globalCustomClaimDefinition `json:"items"`
}
func NewAdminHandler(keto service.KetoService, ketoOutbox repository.KetoOutboxRepository) *AdminHandler {
return &AdminHandler{
Keto: keto,
KetoOutbox: ketoOutbox,
}
}
func (h *AdminHandler) GetRPUsageDaily(c *fiber.Ctx) error {
if h == nil || h.RPUsageQueries == nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{
"error": "rp usage query service unavailable",
})
}
days := 14
if raw := c.Query("days"); raw != "" {
if parsed, err := strconv.Atoi(raw); err == nil {
days = parsed
}
}
period := normalizeRPUsagePeriod(c.Query("period"))
tenantID, allowed := h.authorizedRPUsageTenantID(c, strings.TrimSpace(c.Query("tenantId")))
if !allowed {
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
"error": "forbidden: tenant rp usage stats permission denied",
})
}
items, err := h.RPUsageQueries.FindRPUsage(c.Context(), domain.RPUsageQuery{
Days: days,
Period: period,
TenantID: tenantID,
})
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"error": err.Error(),
})
}
return c.JSON(fiber.Map{
"items": items,
"days": days,
"period": period,
"tenantId": tenantID,
})
}
func normalizeRPUsagePeriod(period string) string {
switch strings.ToLower(strings.TrimSpace(period)) {
case "week":
return "week"
case "month":
return "month"
default:
return "day"
}
}
func (h *AdminHandler) authorizedRPUsageTenantID(c *fiber.Ctx, requestedTenantID string) (string, bool) {
profile, _ := c.Locals("user_profile").(*domain.UserProfileResponse)
if profile != nil && domain.NormalizeRole(profile.Role) == domain.RoleSuperAdmin {
return requestedTenantID, true
}
tenantID := requestedTenantID
if tenantID == "" && profile != nil && profile.TenantID != nil {
tenantID = strings.TrimSpace(*profile.TenantID)
}
if tenantID == "" {
return "", false
}
if h == nil || h.Keto == nil || profile == nil || strings.TrimSpace(profile.ID) == "" {
return "", false
}
allowed, err := h.Keto.CheckPermission(c.Context(), "User:"+profile.ID, "Tenant", tenantID, "view_rp_usage_stats")
if err != nil || !allowed {
return "", false
}
return tenantID, true
}
func (h *AdminHandler) CheckAuth(c *fiber.Ctx) error {
return c.Status(fiber.StatusOK).JSON(fiber.Map{"status": "ok"})
}
func (h *AdminHandler) GetGlobalCustomClaimDefinitions(c *fiber.Ctx) error {
if h == nil || h.DB == nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{
"error": "settings store unavailable",
})
}
var setting domain.SystemSetting
if err := h.DB.WithContext(c.Context()).First(&setting, "key = ?", globalCustomClaimsSettingKey).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return c.JSON(globalCustomClaimDefinitionsResponse{Items: []globalCustomClaimDefinition{}})
}
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(globalCustomClaimDefinitionsResponse{
Items: normalizeGlobalCustomClaimDefinitions(setting.Value["items"]),
})
}
func (h *AdminHandler) UpdateGlobalCustomClaimDefinitions(c *fiber.Ctx) error {
if h == nil || h.DB == nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{
"error": "settings store unavailable",
})
}
var req globalCustomClaimDefinitionsResponse
if err := c.BodyParser(&req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid request body"})
}
items, err := validateGlobalCustomClaimDefinitions(req.Items)
if err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
}
setting := domain.SystemSetting{
Key: globalCustomClaimsSettingKey,
Value: domain.JSONMap{"items": globalCustomClaimDefinitionsToJSON(items)},
}
if err := h.DB.WithContext(c.Context()).Save(&setting).Error; err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(globalCustomClaimDefinitionsResponse{Items: items})
}
func normalizeGlobalCustomClaimDefinitions(value any) []globalCustomClaimDefinition {
rawItems, ok := value.([]any)
if !ok {
return []globalCustomClaimDefinition{}
}
items := make([]globalCustomClaimDefinition, 0, len(rawItems))
for _, item := range rawItems {
raw, ok := item.(map[string]any)
if !ok {
continue
}
def := globalCustomClaimDefinition{
Key: strings.TrimSpace(stringValue(raw["key"])),
Label: strings.TrimSpace(stringValue(raw["label"])),
ValueType: normalizeGlobalCustomClaimType(stringValue(raw["valueType"])),
ReadPermission: adminNormalizeCustomClaimPermission(stringValue(raw["readPermission"])),
WritePermission: adminNormalizeCustomClaimPermission(stringValue(raw["writePermission"])),
Description: strings.TrimSpace(stringValue(raw["description"])),
}
if def.Key != "" {
items = append(items, def)
}
}
return items
}
func validateGlobalCustomClaimDefinitions(items []globalCustomClaimDefinition) ([]globalCustomClaimDefinition, error) {
seen := map[string]struct{}{}
normalized := make([]globalCustomClaimDefinition, 0, len(items))
for _, item := range items {
key := strings.TrimSpace(item.Key)
if key == "" {
continue
}
if !isValidCustomClaimKey(key) {
return nil, fiber.NewError(fiber.StatusBadRequest, "claim key must use letters, numbers, underscore, dot, or hyphen")
}
if _, exists := seen[key]; exists {
return nil, fiber.NewError(fiber.StatusBadRequest, "duplicate claim key: "+key)
}
seen[key] = struct{}{}
normalized = append(normalized, globalCustomClaimDefinition{
Key: key,
Label: strings.TrimSpace(item.Label),
ValueType: normalizeGlobalCustomClaimType(item.ValueType),
ReadPermission: adminNormalizeCustomClaimPermission(item.ReadPermission),
WritePermission: adminNormalizeCustomClaimPermission(item.WritePermission),
Description: strings.TrimSpace(item.Description),
})
}
return normalized, nil
}
func globalCustomClaimDefinitionsToJSON(items []globalCustomClaimDefinition) []any {
values := make([]any, 0, len(items))
for _, item := range items {
values = append(values, map[string]any{
"key": item.Key,
"label": item.Label,
"valueType": item.ValueType,
"readPermission": item.ReadPermission,
"writePermission": item.WritePermission,
"description": item.Description,
})
}
return values
}
func normalizeGlobalCustomClaimType(value string) string {
switch strings.ToLower(strings.TrimSpace(value)) {
case "number", "boolean", "array", "object", "date", "datetime":
return strings.ToLower(strings.TrimSpace(value))
default:
return "text"
}
}
func adminNormalizeCustomClaimPermission(value string) string {
if strings.TrimSpace(value) == "user_and_admin" {
return "user_and_admin"
}
return "admin_only"
}
func isValidCustomClaimKey(value string) bool {
for _, r := range value {
if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' || r == '_' || r == '-' || r == '.' {
continue
}
return false
}
return true
}
func stringValue(value any) string {
if text, ok := value.(string); ok {
return text
}
return ""
}
func requireSuperAdminProfile(c *fiber.Ctx) bool {
profile, _ := c.Locals("user_profile").(*domain.UserProfileResponse)
if profile == nil || domain.NormalizeRole(profile.Role) != domain.RoleSuperAdmin {
_ = c.Status(fiber.StatusForbidden).JSON(fiber.Map{"error": "forbidden: super_admin required"})
return false
}
return true
}
func (h *AdminHandler) GetUserProjectionStatus(c *fiber.Ctx) error {
if !requireSuperAdminProfile(c) {
return nil
}
if h == nil || h.UserProjectionRepo == nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "user projection service unavailable"})
}
status, err := h.UserProjectionRepo.GetStatus(c.Context())
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(status)
}
func (h *AdminHandler) GetOrySSOTSystemStatus(c *fiber.Ctx) error {
if !requireSuperAdminProfile(c) {
return nil
}
if h == nil || h.UserProjectionRepo == nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "user projection service unavailable"})
}
projectionStatus, err := h.UserProjectionRepo.GetStatus(c.Context())
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
cacheStatus := domain.IdentityCacheStatus{
Status: "unavailable",
RedisReady: false,
LastError: "identity cache service unavailable",
}
if h.IdentityCache != nil {
cacheStatus, err = h.IdentityCache.GetIdentityCacheStatus(c.Context())
if err != nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": err.Error()})
}
}
return c.JSON(fiber.Map{
"userProjection": projectionStatus,
"identityCache": cacheStatus,
})
}
func (h *AdminHandler) FlushIdentityCache(c *fiber.Ctx) error {
if !requireSuperAdminProfile(c) {
return nil
}
if h == nil || h.IdentityCache == nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "identity cache service unavailable"})
}
result, err := h.IdentityCache.FlushIdentityCache(c.Context())
if err != nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(result)
}
func (h *AdminHandler) GetDataIntegrity(c *fiber.Ctx) error {
if !requireSuperAdminProfile(c) {
return nil
}
if h == nil || h.IntegrityChecker == nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "data integrity checker unavailable"})
}
report, err := h.IntegrityChecker.CheckDataIntegrity(c.Context())
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(report)
}
func (h *AdminHandler) ListOrphanUserLoginIDs(c *fiber.Ctx) error {
if !requireSuperAdminProfile(c) {
return nil
}
if h == nil || h.IntegrityChecker == nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "data integrity checker unavailable"})
}
items, err := h.IntegrityChecker.ListOrphanUserLoginIDs(c.Context())
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(fiber.Map{
"items": items,
"total": len(items),
})
}
func (h *AdminHandler) DeleteOrphanUserLoginIDs(c *fiber.Ctx) error {
if !requireSuperAdminProfile(c) {
return nil
}
if h == nil || h.IntegrityChecker == nil {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "data integrity checker unavailable"})
}
var req struct {
IDs []string `json:"ids"`
}
if err := c.BodyParser(&req); err != nil {
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
}
result, err := h.IntegrityChecker.DeleteOrphanUserLoginIDs(c.Context(), req.IDs)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(result)
}
// GetSystemStats returns runtime statistics for monitoring
func (h *AdminHandler) GetSystemStats(c *fiber.Ctx) error {
var m runtime.MemStats
runtime.ReadMemStats(&m)
ctx := c.Context()
stats := fiber.Map{
"totalTenants": h.countTenants(ctx),
"totalUsers": h.countUsers(ctx),
"oidcClients": h.countOIDCClients(ctx),
"auditEvents24h": h.countAuditEventsSince(ctx, time.Now().UTC().Add(-24*time.Hour)),
"goroutines": runtime.NumGoroutine(),
"cpus": runtime.NumCPU(),
"memory": fiber.Map{
"alloc": m.Alloc,
"totalAlign": m.TotalAlloc,
"sys": m.Sys,
"numGC": m.NumGC,
},
"timestamp": time.Now(),
}
return c.Status(fiber.StatusOK).JSON(stats)
}
func (h *AdminHandler) countTenants(ctx context.Context) int64 {
if h == nil || h.TenantRepo == nil {
return 0
}
_, total, err := h.TenantRepo.List(ctx, 1, 0, "", "")
if err != nil {
return 0
}
return total
}
func (h *AdminHandler) countUsers(ctx context.Context) int64 {
if h == nil || h.UserProjectionRepo == nil {
return 0
}
status, err := h.UserProjectionRepo.GetStatus(ctx)
if err != nil {
return 0
}
return status.ProjectedUsers
}
func (h *AdminHandler) countOIDCClients(ctx context.Context) int64 {
if h == nil || h.Hydra == nil {
return 0
}
const pageSize = 500
var total int64
for offset := 0; ; offset += pageSize {
clients, err := h.Hydra.ListClients(ctx, pageSize, offset)
if err != nil {
return total
}
for _, client := range clients {
if isHiddenSystemClient(client) {
continue
}
total++
}
if len(clients) < pageSize {
break
}
}
return total
}
func (h *AdminHandler) countAuditEventsSince(ctx context.Context, since time.Time) int64 {
if h == nil || h.AuditRepo == nil {
return 0
}
count, err := h.AuditRepo.CountEventsSince(ctx, since)
if err == nil && count > 0 {
return count
}
logs, pageErr := h.AuditRepo.FindPage(ctx, 10000, nil, "")
if pageErr != nil {
return count
}
var fallbackCount int64
for _, log := range logs {
if !log.Timestamp.Before(since) {
fallbackCount++
}
}
return fallbackCount
}

View File

@@ -0,0 +1,343 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/require"
)
type fakeRPUsageQueryRepo struct {
query domain.RPUsageQuery
items []domain.RPUsageDailyMetric
}
func (f *fakeRPUsageQueryRepo) FindRPUsage(ctx context.Context, query domain.RPUsageQuery) ([]domain.RPUsageDailyMetric, error) {
f.query = query
return f.items, nil
}
type fakeAdminKeto struct {
allowed bool
subject string
object string
relation string
}
func (f *fakeAdminKeto) CheckPermission(ctx context.Context, subject, namespace, object, relation string) (bool, error) {
f.subject = subject
f.object = object
f.relation = relation
return f.allowed, nil
}
func (f *fakeAdminKeto) CreateRelation(ctx context.Context, namespace, object, relation, subject string) error {
return nil
}
func (f *fakeAdminKeto) DeleteRelation(ctx context.Context, namespace, object, relation, subject string) error {
return nil
}
func (f *fakeAdminKeto) ListRelations(ctx context.Context, namespace, object, relation, subject string) ([]service.RelationTuple, error) {
return nil, nil
}
func (f *fakeAdminKeto) ListObjects(ctx context.Context, namespace, relation, subject string) ([]string, error) {
return nil, nil
}
type fakeOverviewAuditRepo struct {
mockAuditRepo
since time.Time
count int64
}
func (f *fakeOverviewAuditRepo) CountEventsSince(ctx context.Context, since time.Time) (int64, error) {
f.since = since
return f.count, nil
}
type fakeAdminUserProjectionRepo struct {
status domain.UserProjectionStatus
}
func (f *fakeAdminUserProjectionRepo) IsReady(ctx context.Context) (bool, error) {
return f.status.Ready, nil
}
func (f *fakeAdminUserProjectionRepo) CountTenantMembers(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
return nil, nil
}
func (f *fakeAdminUserProjectionRepo) CountTenantMembersRecursive(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
return nil, nil
}
func (f *fakeAdminUserProjectionRepo) ReplaceAllFromKratos(ctx context.Context, users []domain.User) error {
return nil
}
func (f *fakeAdminUserProjectionRepo) MarkFailed(ctx context.Context, syncErr error) error {
return nil
}
func (f *fakeAdminUserProjectionRepo) GetStatus(ctx context.Context) (domain.UserProjectionStatus, error) {
return f.status, nil
}
type fakeIdentityCacheAdmin struct {
status domain.IdentityCacheStatus
flush domain.IdentityCacheFlushResult
err error
statusHit int
flushCalls int
}
func (f *fakeIdentityCacheAdmin) GetIdentityCacheStatus(ctx context.Context) (domain.IdentityCacheStatus, error) {
f.statusHit++
return f.status, f.err
}
func (f *fakeIdentityCacheAdmin) FlushIdentityCache(ctx context.Context) (domain.IdentityCacheFlushResult, error) {
f.flushCalls++
return f.flush, f.err
}
func TestAdminHandler_GetRPUsageDaily(t *testing.T) {
repo := &fakeRPUsageQueryRepo{
items: []domain.RPUsageDailyMetric{
{
Date: "2026-05-06",
TenantID: "tenant-1",
TenantType: domain.TenantTypeCompany,
ClientID: "orgfront",
ClientName: "OrgFront",
LoginRequests: 12,
OtherRequests: 4,
UniqueSubjects: 8,
},
},
}
h := &AdminHandler{RPUsageQueries: repo}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Get("/api/v1/admin/rp-usage/daily", h.GetRPUsageDaily)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/rp-usage/daily?days=7&period=week&tenantId=tenant-1", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, 7, repo.query.Days)
require.Equal(t, "week", repo.query.Period)
require.Equal(t, "tenant-1", repo.query.TenantID)
var body struct {
Items []domain.RPUsageDailyMetric `json:"items"`
Days int `json:"days"`
Period string `json:"period"`
TenantID string `json:"tenantId"`
}
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, 7, body.Days)
require.Equal(t, "week", body.Period)
require.Equal(t, "tenant-1", body.TenantID)
require.Len(t, body.Items, 1)
require.Equal(t, "orgfront", body.Items[0].ClientID)
require.Equal(t, uint64(12), body.Items[0].LoginRequests)
}
func TestAdminHandler_UserProjectionStatusRequiresSuperAdmin(t *testing.T) {
h := &AdminHandler{
UserProjectionRepo: &fakeAdminUserProjectionRepo{
status: domain.UserProjectionStatus{Name: domain.UserProjectionNameKratos, Status: domain.UserProjectionStatusReady, Ready: true},
},
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "tenant-admin", Role: "tenant_admin"})
return c.Next()
})
app.Get("/api/v1/admin/projections/users", h.GetUserProjectionStatus)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/projections/users", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusForbidden, resp.StatusCode)
}
func TestAdminHandler_UserProjectionStatusReturnsProjectionStateForSuperAdmin(t *testing.T) {
syncedAt := time.Date(2026, 5, 11, 3, 0, 0, 0, time.UTC)
h := &AdminHandler{
UserProjectionRepo: &fakeAdminUserProjectionRepo{
status: domain.UserProjectionStatus{
Name: domain.UserProjectionNameKratos,
Status: domain.UserProjectionStatusReady,
Ready: true,
LastSyncedAt: &syncedAt,
ProjectedUsers: 152,
},
},
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Get("/api/v1/admin/projections/users", h.GetUserProjectionStatus)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/projections/users", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
var body domain.UserProjectionStatus
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, domain.UserProjectionNameKratos, body.Name)
require.Equal(t, domain.UserProjectionStatusReady, body.Status)
require.True(t, body.Ready)
require.Equal(t, int64(152), body.ProjectedUsers)
}
func TestAdminHandler_GetOrySSOTSystemStatusReturnsProjectionAndIdentityCache(t *testing.T) {
syncedAt := time.Date(2026, 5, 11, 3, 0, 0, 0, time.UTC)
cache := &fakeIdentityCacheAdmin{
status: domain.IdentityCacheStatus{
Status: "ready",
RedisReady: true,
ObservedCount: 151,
KeyCount: 153,
LastRefreshedAt: &syncedAt,
UpdatedAt: &syncedAt,
},
}
h := &AdminHandler{
UserProjectionRepo: &fakeAdminUserProjectionRepo{
status: domain.UserProjectionStatus{
Name: domain.UserProjectionNameKratos,
Status: domain.UserProjectionStatusReady,
Ready: true,
LastSyncedAt: &syncedAt,
ProjectedUsers: 152,
},
},
IdentityCache: cache,
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Get("/api/v1/admin/ory/ssot", h.GetOrySSOTSystemStatus)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/ory/ssot", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
var body struct {
UserProjection domain.UserProjectionStatus `json:"userProjection"`
IdentityCache domain.IdentityCacheStatus `json:"identityCache"`
}
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, int64(152), body.UserProjection.ProjectedUsers)
require.True(t, body.IdentityCache.RedisReady)
require.Equal(t, int64(151), body.IdentityCache.ObservedCount)
require.Equal(t, int64(153), body.IdentityCache.KeyCount)
require.Equal(t, 1, cache.statusHit)
}
func TestAdminHandler_FlushIdentityCacheRequiresSuperAdminAndFlushesCacheOnly(t *testing.T) {
cache := &fakeIdentityCacheAdmin{
flush: domain.IdentityCacheFlushResult{
Status: "success",
FlushedKeys: 7,
UpdatedAt: time.Date(2026, 5, 11, 3, 2, 0, 0, time.UTC),
},
}
h := &AdminHandler{
IdentityCache: cache,
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Post("/api/v1/admin/ory/ssot/identity-cache/flush", h.FlushIdentityCache)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/ory/ssot/identity-cache/flush", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
var body domain.IdentityCacheFlushResult
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, int64(7), body.FlushedKeys)
require.Equal(t, 1, cache.flushCalls)
}
func TestAdminHandler_GetRPUsageDailyChecksTenantPermission(t *testing.T) {
repo := &fakeRPUsageQueryRepo{}
keto := &fakeAdminKeto{allowed: true}
h := &AdminHandler{RPUsageQueries: repo, Keto: keto}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{
ID: "user-1",
Role: "tenant_admin",
})
return c.Next()
})
app.Get("/api/v1/admin/rp-usage/daily", h.GetRPUsageDaily)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/rp-usage/daily?tenantId=tenant-allowed", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, "User:user-1", keto.subject)
require.Equal(t, "tenant-allowed", keto.object)
require.Equal(t, "view_rp_usage_stats", keto.relation)
require.Equal(t, "tenant-allowed", repo.query.TenantID)
}
func TestAdminHandler_GetSystemStatsIncludesOverviewMetrics(t *testing.T) {
auditRepo := &fakeOverviewAuditRepo{count: 22}
h := &AdminHandler{
AuditRepo: auditRepo,
UserProjectionRepo: &fakeAdminUserProjectionRepo{
status: domain.UserProjectionStatus{
Name: domain.UserProjectionNameKratos,
Status: domain.UserProjectionStatusReady,
Ready: true,
ProjectedUsers: 152,
},
},
}
app := fiber.New()
app.Get("/api/v1/admin/stats", h.GetSystemStats)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/stats", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
var body map[string]any
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Contains(t, body, "totalTenants")
require.Contains(t, body, "totalUsers")
require.Contains(t, body, "oidcClients")
require.Contains(t, body, "auditEvents24h")
require.Equal(t, float64(152), body["totalUsers"])
require.Equal(t, float64(22), body["auditEvents24h"])
require.Equal(t, time.UTC, auditRepo.since.Location())
}

View File

@@ -0,0 +1,196 @@
package handler
import (
"baron-sso-backend/internal/domain"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/require"
)
type fakeDataIntegrityChecker struct {
calls int
listCalls int
deleteCalls int
deletedIDs []string
report domain.DataIntegrityReport
orphans []domain.OrphanUserLoginID
deleteResult domain.DeleteOrphanUserLoginIDsResult
err error
}
func (f *fakeDataIntegrityChecker) CheckDataIntegrity(ctx context.Context) (domain.DataIntegrityReport, error) {
f.calls++
return f.report, f.err
}
func (f *fakeDataIntegrityChecker) ListOrphanUserLoginIDs(ctx context.Context) ([]domain.OrphanUserLoginID, error) {
f.listCalls++
return f.orphans, f.err
}
func (f *fakeDataIntegrityChecker) DeleteOrphanUserLoginIDs(ctx context.Context, ids []string) (domain.DeleteOrphanUserLoginIDsResult, error) {
f.deleteCalls++
f.deletedIDs = append([]string(nil), ids...)
return f.deleteResult, f.err
}
func TestAdminHandler_GetDataIntegrityRequiresSuperAdmin(t *testing.T) {
checker := &fakeDataIntegrityChecker{}
h := &AdminHandler{IntegrityChecker: checker}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "tenant-admin", Role: "tenant_admin"})
return c.Next()
})
app.Get("/api/v1/admin/integrity", h.GetDataIntegrity)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/integrity", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusForbidden, resp.StatusCode)
require.Equal(t, 0, checker.calls)
}
func TestAdminHandler_GetDataIntegrityReturnsReportForSuperAdmin(t *testing.T) {
checkedAt := time.Date(2026, 5, 14, 0, 0, 0, 0, time.UTC)
checker := &fakeDataIntegrityChecker{
report: domain.DataIntegrityReport{
Status: domain.DataIntegrityStatusFail,
CheckedAt: checkedAt,
Summary: domain.DataIntegritySummary{
TotalChecks: 1,
Failures: 1,
},
Sections: []domain.DataIntegritySection{
{
Key: "tenant_integrity",
Label: "테넌트 정합성",
Status: domain.DataIntegrityStatusFail,
Checks: []domain.DataIntegrityCheck{
{
Key: "duplicate_tenant_slugs",
Label: "중복 테넌트 slug",
Status: domain.DataIntegrityStatusFail,
Count: 1,
Severity: "error",
},
},
},
},
},
}
h := &AdminHandler{IntegrityChecker: checker}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Get("/api/v1/admin/integrity", h.GetDataIntegrity)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/integrity", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, 1, checker.calls)
var body domain.DataIntegrityReport
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, domain.DataIntegrityStatusFail, body.Status)
require.Equal(t, int64(1), body.Summary.Failures)
require.Len(t, body.Sections, 1)
require.Equal(t, "tenant_integrity", body.Sections[0].Key)
}
func TestAdminHandler_ListOrphanUserLoginIDsReturnsTargetsForSuperAdmin(t *testing.T) {
checker := &fakeDataIntegrityChecker{
orphans: []domain.OrphanUserLoginID{
{
ID: "login-id-1",
UserID: "user-1",
TenantID: "tenant-1",
FieldKey: "emp_id",
LoginID: "EMP001",
Reasons: []string{"missing_tenant"},
},
},
}
h := &AdminHandler{IntegrityChecker: checker}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Get("/api/v1/admin/integrity/orphan-user-login-ids", h.ListOrphanUserLoginIDs)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/integrity/orphan-user-login-ids", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, 1, checker.listCalls)
var body struct {
Items []domain.OrphanUserLoginID `json:"items"`
Total int `json:"total"`
}
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, 1, body.Total)
require.Equal(t, "login-id-1", body.Items[0].ID)
require.Equal(t, []string{"missing_tenant"}, body.Items[0].Reasons)
}
func TestAdminHandler_DeleteOrphanUserLoginIDsRequiresSuperAdminAndDeletesSelectedTargets(t *testing.T) {
checker := &fakeDataIntegrityChecker{
deleteResult: domain.DeleteOrphanUserLoginIDsResult{
DeletedCount: 1,
Deleted: []domain.OrphanUserLoginID{
{ID: "login-id-1", LoginID: "EMP001", Reasons: []string{"missing_user"}},
},
SkippedIDs: []string{"valid-login-id"},
},
}
h := &AdminHandler{IntegrityChecker: checker}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Delete("/api/v1/admin/integrity/orphan-user-login-ids", h.DeleteOrphanUserLoginIDs)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/admin/integrity/orphan-user-login-ids", strings.NewReader(`{"ids":["login-id-1","valid-login-id"]}`))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, 1, checker.deleteCalls)
require.Equal(t, []string{"login-id-1", "valid-login-id"}, checker.deletedIDs)
var body domain.DeleteOrphanUserLoginIDsResult
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, int64(1), body.DeletedCount)
require.Equal(t, []string{"valid-login-id"}, body.SkippedIDs)
}
func TestAdminHandler_DeleteOrphanUserLoginIDsRejectsTenantAdmin(t *testing.T) {
checker := &fakeDataIntegrityChecker{}
h := &AdminHandler{IntegrityChecker: checker}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "tenant-admin", Role: "tenant_admin"})
return c.Next()
})
app.Delete("/api/v1/admin/integrity/orphan-user-login-ids", h.DeleteOrphanUserLoginIDs)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/admin/integrity/orphan-user-login-ids", strings.NewReader(`{"ids":["login-id-1"]}`))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusForbidden, resp.StatusCode)
require.Equal(t, 0, checker.deleteCalls)
}

View File

@@ -0,0 +1,288 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/pagination"
"errors"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
type ApiKeyHandler struct {
DB *gorm.DB
}
func NewApiKeyHandler(db *gorm.DB) *ApiKeyHandler {
return &ApiKeyHandler{DB: db}
}
type apiKeySummary struct {
ID string `json:"id"`
Name string `json:"name"`
ClientID string `json:"client_id"`
Scopes []string `json:"scopes"`
Status string `json:"status"`
LastUsedAt *string `json:"lastUsedAt"`
CreatedAt time.Time `json:"createdAt"`
}
type apiKeyListResponse struct {
Items []apiKeySummary `json:"items"`
Total int64 `json:"total"`
Limit int `json:"limit"`
Offset int `json:"offset"`
Cursor string `json:"cursor,omitempty"`
NextCursor string `json:"nextCursor,omitempty"`
}
func apiKeyToSummary(k domain.ApiKey) apiKeySummary {
lastUsed := ""
if k.LastUsedAt != nil {
lastUsed = k.LastUsedAt.Format(time.RFC3339)
}
return apiKeySummary{
ID: k.ID,
Name: k.Name,
ClientID: k.ClientID,
Scopes: strings.Fields(strings.ReplaceAll(k.Scopes, ",", " ")),
Status: k.Status,
LastUsedAt: &lastUsed,
CreatedAt: k.CreatedAt,
}
}
func apiKeyWithUpdatedScopes(k domain.ApiKey, scopes []string) domain.ApiKey {
k.Scopes = strings.Join(normalizeApiKeyScopes(scopes), " ")
return k
}
func apiKeyWithRotatedSecretHash(k domain.ApiKey, hashedSecret string) domain.ApiKey {
k.ClientSecretHash = hashedSecret
return k
}
func normalizeApiKeyScopes(scopes []string) []string {
seen := make(map[string]struct{}, len(scopes))
normalized := make([]string, 0, len(scopes))
for _, scope := range scopes {
scope = strings.TrimSpace(scope)
if scope == "" {
continue
}
if _, exists := seen[scope]; exists {
continue
}
seen[scope] = struct{}{}
normalized = append(normalized, scope)
}
return normalized
}
func (h *ApiKeyHandler) ListApiKeys(c *fiber.Ctx) error {
if h.DB == nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "database not available")
}
limit := c.QueryInt("limit", 50)
offset := c.QueryInt("offset", 0)
cursorRaw := strings.TrimSpace(c.Query("cursor"))
if limit <= 0 {
limit = 50
}
if offset < 0 {
offset = 0
}
var total int64
if err := h.DB.Model(&domain.ApiKey{}).Count(&total).Error; err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
var keys []domain.ApiKey
query := h.DB.Order("created_at desc, id desc").Limit(limit + 1)
if cursorRaw != "" {
cursor, err := pagination.Decode(cursorRaw)
if err != nil {
return errorJSON(c, fiber.StatusBadRequest, "invalid cursor")
}
query = pagination.ApplyCreatedAtIDCursor(query, cursor, "created_at", "id")
offset = 0
} else {
query = query.Offset(offset)
}
if err := query.Find(&keys).Error; err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
nextCursor := ""
hasMore := len(keys) > limit
if len(keys) > limit {
keys = keys[:limit]
}
if cursorRaw == "" && total > int64(offset+len(keys)) {
hasMore = true
}
if hasMore && len(keys) > 0 {
last := keys[len(keys)-1]
nextCursor = pagination.Encode(last.CreatedAt, last.ID)
}
items := make([]apiKeySummary, 0, len(keys))
for _, k := range keys {
items = append(items, apiKeyToSummary(k))
}
return c.JSON(apiKeyListResponse{
Items: items,
Total: total,
Limit: limit,
Offset: offset,
Cursor: cursorRaw,
NextCursor: nextCursor,
})
}
func (h *ApiKeyHandler) CreateApiKey(c *fiber.Ctx) error {
if h.DB == nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "database not available")
}
var req struct {
Name string `json:"name"`
Scopes []string `json:"scopes"`
}
if err := c.BodyParser(&req); err != nil {
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
}
if strings.TrimSpace(req.Name) == "" {
return errorJSON(c, fiber.StatusBadRequest, "name is required")
}
req.Scopes = normalizeApiKeyScopes(req.Scopes)
if len(req.Scopes) == 0 {
return errorJSON(c, fiber.StatusBadRequest, "at least one scope is required")
}
// Generate Client ID (16 chars hex)
clientID := GenerateSecureToken(8)
// Generate plain secret (16 chars hex)
plainSecret := GenerateSecureToken(8)
hashedSecret, err := bcrypt.GenerateFromPassword([]byte(plainSecret), bcrypt.DefaultCost)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, "failed to hash secret")
}
apiKey := domain.ApiKey{
Name: req.Name,
ClientID: clientID,
ClientSecretHash: string(hashedSecret),
Scopes: strings.Join(req.Scopes, " "),
Status: "active",
}
if err := h.DB.Create(&apiKey).Error; err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
// Return summary + PLAIN SECRET (only this time)
return c.Status(fiber.StatusCreated).JSON(fiber.Map{
"apiKey": apiKeyToSummary(apiKey),
"clientSecret": plainSecret, // VERY IMPORTANT: user must save this now
})
}
func (h *ApiKeyHandler) UpdateApiKey(c *fiber.Ctx) error {
if h.DB == nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "database not available")
}
id := c.Params("id")
if id == "" {
return errorJSON(c, fiber.StatusBadRequest, "id is required")
}
var req struct {
Scopes []string `json:"scopes"`
}
if err := c.BodyParser(&req); err != nil {
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
}
req.Scopes = normalizeApiKeyScopes(req.Scopes)
if len(req.Scopes) == 0 {
return errorJSON(c, fiber.StatusBadRequest, "at least one scope is required")
}
var apiKey domain.ApiKey
if err := h.DB.First(&apiKey, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errorJSON(c, fiber.StatusNotFound, "api key not found")
}
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
apiKey = apiKeyWithUpdatedScopes(apiKey, req.Scopes)
if err := h.DB.Model(&domain.ApiKey{}).Where("id = ?", id).Update("scopes", apiKey.Scopes).Error; err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.JSON(apiKeyToSummary(apiKey))
}
func (h *ApiKeyHandler) RotateApiKeySecret(c *fiber.Ctx) error {
if h.DB == nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "database not available")
}
id := c.Params("id")
if id == "" {
return errorJSON(c, fiber.StatusBadRequest, "id is required")
}
var apiKey domain.ApiKey
if err := h.DB.First(&apiKey, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errorJSON(c, fiber.StatusNotFound, "api key not found")
}
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
plainSecret := GenerateSecureToken(8)
hashedSecret, err := bcrypt.GenerateFromPassword([]byte(plainSecret), bcrypt.DefaultCost)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, "failed to hash secret")
}
apiKey = apiKeyWithRotatedSecretHash(apiKey, string(hashedSecret))
if err := h.DB.Model(&domain.ApiKey{}).Where("id = ?", id).Update("client_secret_hash", apiKey.ClientSecretHash).Error; err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.JSON(fiber.Map{
"apiKey": apiKeyToSummary(apiKey),
"clientSecret": plainSecret,
})
}
func (h *ApiKeyHandler) DeleteApiKey(c *fiber.Ctx) error {
if h.DB == nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "database not available")
}
id := c.Params("id")
if id == "" {
return errorJSON(c, fiber.StatusBadRequest, "id is required")
}
if err := h.DB.Delete(&domain.ApiKey{}, "id = ?", id).Error; err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.SendStatus(fiber.StatusNoContent)
}

View File

@@ -0,0 +1,133 @@
package handler
import (
"baron-sso-backend/internal/domain"
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
)
// Mock DB for ApiKey tests using a real GORM instance but with a hijacked connection
// or just a simple mock if we only check nil.
// For ApiKeyHandler, it uses DB for Create/List/Delete.
func TestApiKeyHandler_CreateApiKey(t *testing.T) {
app := fiber.New()
// ApiKeyHandler requires a valid DB connection to perform h.DB.Create
// Since we don't have a real DB here, we'll check if it fails gracefully
// or we can use sqlite in-memory for a more realistic test.
h := &ApiKeyHandler{DB: nil} // Testing ServiceUnavailable
app.Post("/api-keys", h.CreateApiKey)
input := map[string]any{
"name": "M2M Test",
"scopes": []string{"read", "write"},
}
body, _ := json.Marshal(input)
req := httptest.NewRequest("POST", "/api-keys", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req)
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
}
func TestApiKeyHandler_Validation(t *testing.T) {
app := fiber.New()
// Using a dummy DB pointer to pass the nil check
h := &ApiKeyHandler{DB: &gorm.DB{}}
app.Post("/api-keys", h.CreateApiKey)
// Missing name
input := map[string]any{
"scopes": []string{"read"},
}
body, _ := json.Marshal(input)
req := httptest.NewRequest("POST", "/api-keys", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
}
func TestApiKeyHandler_UpdateApiKeyScopesRequiresDatabase(t *testing.T) {
app := fiber.New()
h := &ApiKeyHandler{DB: nil}
app.Patch("/api-keys/:id", h.UpdateApiKey)
body, _ := json.Marshal(map[string]any{
"scopes": []string{"org-context:read"},
})
req := httptest.NewRequest("PATCH", "/api-keys/api-key-id", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req)
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
}
func TestApiKeyHandler_RotateApiKeySecretRequiresDatabase(t *testing.T) {
app := fiber.New()
h := &ApiKeyHandler{DB: nil}
app.Post("/api-keys/:id/secret/rotate", h.RotateApiKeySecret)
req := httptest.NewRequest("POST", "/api-keys/api-key-id/secret/rotate", nil)
resp, _ := app.Test(req)
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
}
func TestApiKeyWithUpdatedScopesPreservesClientID(t *testing.T) {
key := domain.ApiKey{
ID: "api-key-id",
Name: "M2M Test",
ClientID: "client-id-stable",
ClientSecretHash: "old-secret-hash",
Scopes: "audit:read",
Status: "active",
}
updated := apiKeyWithUpdatedScopes(key, []string{"audit:read", "org-context:read"})
assert.Equal(t, "client-id-stable", updated.ClientID)
assert.Equal(t, "old-secret-hash", updated.ClientSecretHash)
assert.Equal(t, "audit:read org-context:read", updated.Scopes)
}
func TestApiKeyWithRotatedSecretHashPreservesClientIDAndScopes(t *testing.T) {
key := domain.ApiKey{
ID: "api-key-id",
Name: "M2M Test",
ClientID: "client-id-stable",
ClientSecretHash: "old-secret-hash",
Scopes: "audit:read org-context:read",
Status: "active",
}
updated := apiKeyWithRotatedSecretHash(key, "new-secret-hash")
assert.Equal(t, "client-id-stable", updated.ClientID)
assert.Equal(t, "audit:read org-context:read", updated.Scopes)
assert.Equal(t, "new-secret-hash", updated.ClientSecretHash)
}
func TestNormalizeApiKeyScopesTrimsAndDeduplicates(t *testing.T) {
scopes := normalizeApiKeyScopes([]string{
" audit:read ",
"",
"org-context:read",
"audit:read",
})
assert.Equal(t, []string{"audit:read", "org-context:read"}, scopes)
}

View File

@@ -0,0 +1,139 @@
package handler
import (
"baron-sso-backend/internal/domain"
"encoding/base64"
"errors"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
)
type AuditHandler struct {
repo domain.AuditRepository
}
func NewAuditHandler(repo domain.AuditRepository) *AuditHandler {
return &AuditHandler{repo: repo}
}
// CreateLog handles POST /api/v1/audit
func (h *AuditHandler) CreateLog(c *fiber.Ctx) error {
var req domain.AuditLog
if err := c.BodyParser(&req); err != nil {
return errorJSON(c, fiber.StatusBadRequest, "Cannot parse JSON")
}
// Auto-fill metadata if missing
if req.IPAddress == "" {
req.IPAddress = c.IP()
}
if req.UserAgent == "" {
req.UserAgent = c.Get("User-Agent")
}
if req.Timestamp.IsZero() {
req.Timestamp = time.Now()
}
if req.EventID == "" {
req.EventID = ensureRequestID(c)
}
if h.repo == nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "Audit service unavailable")
}
if err := h.repo.Create(&req); err != nil {
// Log internal error but don't expose details
return errorJSON(c, fiber.StatusInternalServerError, "Failed to save audit log")
}
return c.Status(fiber.StatusCreated).JSON(fiber.Map{
"message": "Audit log saved",
})
}
// ListLogs handles GET /api/v1/audit
func (h *AuditHandler) ListLogs(c *fiber.Ctx) error {
limit := c.QueryInt("limit", 50)
cursorRaw := c.Query("cursor")
requestedTenantID := c.Query("tenantId")
cursor, err := parseAuditCursor(cursorRaw)
if err != nil {
return errorJSON(c, fiber.StatusBadRequest, "Invalid cursor")
}
if h.repo == nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "Audit service unavailable")
}
// [New] Role-based Filtering
profile, _ := c.Locals("user_profile").(*domain.UserProfileResponse)
var filterTenantID string
if profile != nil {
if profile.Role == domain.RoleSuperAdmin {
// Super Admin can see everything or filter by a specific tenant if requested
filterTenantID = requestedTenantID
} else {
return errorJSON(c, fiber.StatusForbidden, "forbidden")
}
}
logs, err := h.repo.FindPage(c.Context(), limit+1, cursor, filterTenantID)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, "Failed to retrieve audit logs")
}
nextCursor := ""
if len(logs) > limit {
last := logs[limit-1]
nextCursor = encodeAuditCursor(last)
logs = logs[:limit]
}
return c.JSON(fiber.Map{
"items": logs,
"limit": limit,
"cursor": cursorRaw,
"next_cursor": nextCursor,
})
}
func ensureRequestID(c *fiber.Ctx) string {
reqID := c.Get("X-Request-Id")
if reqID == "" {
reqID = uuid.New().String()
c.Set("X-Request-Id", reqID)
}
return reqID
}
func parseAuditCursor(raw string) (*domain.AuditCursor, error) {
if raw == "" {
return nil, nil
}
decoded, err := base64.RawURLEncoding.DecodeString(raw)
if err != nil {
return nil, err
}
parts := strings.SplitN(string(decoded), "|", 2)
if len(parts) != 2 {
return nil, errors.New("invalid cursor")
}
ts, err := time.Parse(time.RFC3339Nano, parts[0])
if err != nil {
return nil, err
}
return &domain.AuditCursor{
Timestamp: ts,
EventID: parts[1],
}, nil
}
func encodeAuditCursor(log domain.AuditLog) string {
payload := log.Timestamp.UTC().Format(time.RFC3339Nano) + "|" + log.EventID
return base64.RawURLEncoding.EncodeToString([]byte(payload))
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,371 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"gorm.io/gorm"
)
// --- Async Test Mocks ---
type AsyncMockIdpProvider struct {
mock.Mock
}
func (m *AsyncMockIdpProvider) Name() string { return "mock-idp" }
func (m *AsyncMockIdpProvider) GetMetadata() (*domain.IDPMetadata, error) {
return &domain.IDPMetadata{}, nil
}
func (m *AsyncMockIdpProvider) UserExists(loginID string) (bool, error) {
args := m.Called(loginID)
return args.Bool(0), args.Error(1)
}
func (m *AsyncMockIdpProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
args := m.Called(user, password)
return args.String(0), args.Error(1)
}
func (m *AsyncMockIdpProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
return nil, nil
}
func (m *AsyncMockIdpProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
return nil, nil
}
func (m *AsyncMockIdpProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
return nil, nil
}
func (m *AsyncMockIdpProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
return nil, nil
}
func (m *AsyncMockIdpProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
return &domain.PasswordPolicy{MinLength: 12}, nil
}
func (m *AsyncMockIdpProvider) InitiatePasswordReset(loginID, redirectUrl string) error { return nil }
func (m *AsyncMockIdpProvider) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
return nil, nil
}
func (m *AsyncMockIdpProvider) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
return nil
}
type AsyncMockUserRepo struct {
mock.Mock
createCalled chan bool
}
func (m *AsyncMockUserRepo) Create(ctx context.Context, user *domain.User) error {
// Simulate DB latency
time.Sleep(50 * time.Millisecond)
args := m.Called(ctx, user)
if m.createCalled != nil {
m.createCalled <- true
}
return args.Error(0)
}
func (m *AsyncMockUserRepo) Update(ctx context.Context, user *domain.User) error {
args := m.Called(ctx, user)
if m.createCalled != nil {
m.createCalled <- true
}
return args.Error(0)
}
func (m *AsyncMockUserRepo) Delete(ctx context.Context, id string) error { return nil }
func (m *AsyncMockUserRepo) FindByEmail(ctx context.Context, email string) (*domain.User, error) {
return nil, nil
}
func (m *AsyncMockUserRepo) FindByID(ctx context.Context, id string) (*domain.User, error) {
return nil, nil
}
func (m *AsyncMockUserRepo) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) {
return nil, nil
}
func (m *AsyncMockUserRepo) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
return nil, nil
}
func (m *AsyncMockUserRepo) List(ctx context.Context, offset, limit int, search string, tenantIDs []string, cursor string) ([]domain.User, int64, string, error) {
return nil, 0, "", nil
}
func (m *AsyncMockUserRepo) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
return 0, nil
}
func (m *AsyncMockUserRepo) FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error) {
args := m.Called(ctx, tenantIDs)
return args.Get(0).([]domain.User), args.Error(1)
}
func (m *AsyncMockUserRepo) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
return nil, nil
}
func (m *AsyncMockUserRepo) FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error) {
args := m.Called(ctx, codes)
return args.Get(0).([]domain.User), args.Error(1)
}
func (m *AsyncMockUserRepo) CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error) {
return nil, nil
}
func (m *AsyncMockUserRepo) DB() *gorm.DB {
return nil
}
func (m *AsyncMockUserRepo) UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error {
return nil
}
func (m *AsyncMockUserRepo) GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error) {
return nil, nil
}
func (m *AsyncMockUserRepo) IsLoginIDTaken(ctx context.Context, loginID string) (bool, error) {
return false, nil
}
func (m *AsyncMockUserRepo) FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error) {
return "", nil
}
type AsyncMockRedisRepo struct {
mock.Mock
}
func (m *AsyncMockRedisRepo) Set(key string, value string, expiration time.Duration) error {
args := m.Called(key, value, expiration)
return args.Error(0)
}
func (m *AsyncMockRedisRepo) Get(key string) (string, error) {
args := m.Called(key)
return args.String(0), args.Error(1)
}
func (m *AsyncMockRedisRepo) Delete(key string) error {
args := m.Called(key)
return args.Error(0)
}
func (m *AsyncMockRedisRepo) StoreVerificationCode(phone, code string) error { return nil }
func (m *AsyncMockRedisRepo) GetVerificationCode(phone string) (string, error) { return "", nil }
func (m *AsyncMockRedisRepo) DeleteVerificationCode(phone string) error { return nil }
type AsyncMockTenantService struct {
mock.Mock
}
func (m *AsyncMockTenantService) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string, creatorID string) (*domain.Tenant, error) {
args := m.Called(ctx, name, slug, tenantType, description, domains, parentID, creatorID)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.Tenant), args.Error(1)
}
func (m *AsyncMockTenantService) RequestRegistration(ctx context.Context, name, slug, description string, domainName string, adminEmail string) (*domain.Tenant, error) {
return nil, nil
}
func (m *AsyncMockTenantService) GetTenantByDomain(ctx context.Context, emailDomain string) (*domain.Tenant, error) {
args := m.Called(ctx, emailDomain)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.Tenant), args.Error(1)
}
func (m *AsyncMockTenantService) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
return nil, nil
}
func (m *AsyncMockTenantService) GetTenant(ctx context.Context, id string) (*domain.Tenant, error) {
return nil, nil
}
func (m *AsyncMockTenantService) ListTenants(ctx context.Context, limit, offset int, parentID string, search string) ([]domain.Tenant, int64, error) {
return nil, 0, nil
}
func (m *AsyncMockTenantService) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
return nil, nil
}
func (m *AsyncMockTenantService) IsDomainAllowed(ctx context.Context, domainName string) (bool, error) {
return false, nil
}
func (m *AsyncMockTenantService) ApproveTenant(ctx context.Context, id string) error { return nil }
func (m *AsyncMockTenantService) ProvisionTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
return nil, nil
}
func (m *AsyncMockTenantService) SetKetoService(keto service.KetoService) {}
func (m *AsyncMockTenantService) AddTenantAdmin(ctx context.Context, tenantID, userID string) error {
return nil
}
func (m *AsyncMockTenantService) RemoveTenantAdmin(ctx context.Context, tenantID, userID string) error {
return nil
}
func (m *AsyncMockTenantService) ListTenantAdmins(ctx context.Context, tenantID string) ([]string, error) {
return nil, nil
}
func (m *AsyncMockTenantService) DeleteTenantsBulk(ctx context.Context, ids []string) error {
return nil
}
func (m *AsyncMockTenantService) ListJoinedTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
args := m.Called(ctx, userID)
if args.Get(0) != nil {
return args.Get(0).([]domain.Tenant), args.Error(1)
}
return nil, args.Error(1)
}
type AsyncMockKetoService struct {
mock.Mock
}
func (m *AsyncMockKetoService) CreateRelation(ctx context.Context, namespace, object, relation, subject string) error {
args := m.Called(ctx, namespace, object, relation, subject)
return args.Error(0)
}
func (m *AsyncMockKetoService) DeleteRelation(ctx context.Context, namespace, object, relation, subject string) error {
return nil
}
func (m *AsyncMockKetoService) CheckPermission(ctx context.Context, namespace, object, relation, subject string) (bool, error) {
return false, nil
}
func (m *AsyncMockKetoService) ListObjects(ctx context.Context, namespace, relation, subject string) ([]string, error) {
return nil, nil
}
func (m *AsyncMockKetoService) ListRelations(ctx context.Context, namespace, object, relation, subject string) ([]service.RelationTuple, error) {
return nil, nil
}
// --- Tests ---
func TestSignup_AsyncDB_Isolation(t *testing.T) {
mockIdp := new(AsyncMockIdpProvider)
mockUserRepo := new(AsyncMockUserRepo)
mockRedis := new(AsyncMockRedisRepo)
mockTenant := new(AsyncMockTenantService)
mockKeto := new(AsyncMockKetoService)
h := &AuthHandler{
IdpProvider: mockIdp,
UserRepo: mockUserRepo,
RedisService: mockRedis,
TenantService: mockTenant,
KetoService: mockKeto,
}
app := fiber.New()
app.Post("/signup", h.Signup)
t.Run("SoT_DB_Failure_Ignored_And_Async", func(t *testing.T) {
email := "test@example.com"
phone := "010-1234-5678"
emailKey := "signup:email:" + email
phoneKey := "signup:phone:" + "01012345678"
// Redis Mocks
mockRedis.On("Get", emailKey).Return(`{"verified": true, "expires_at": 9999999999}`, nil)
mockRedis.On("Get", phoneKey).Return(`{"verified": true, "expires_at": 9999999999}`, nil)
mockRedis.On("Delete", emailKey).Return(nil)
mockRedis.On("Delete", phoneKey).Return(nil)
// Tenant Mocks
personalTenant := &domain.Tenant{ID: "personal-t1", Slug: "personal-test", Type: domain.TenantTypePersonal, Status: domain.TenantStatusActive}
mockTenant.On("GetTenantByDomain", mock.Anything, "example.com").Return(nil, nil)
mockTenant.On(
"RegisterTenant",
mock.Anything,
"Personal - test@example.com",
mock.MatchedBy(func(slug string) bool { return strings.HasPrefix(slug, "personal-") }),
domain.TenantTypePersonal,
"Automatically provisioned personal tenant",
[]string(nil),
(*string)(nil),
"",
).Return(personalTenant, nil)
mockTenant.On("GetTenant", mock.Anything, "personal-t1").Return(personalTenant, nil)
// Kratos Mocks (Success)
mockIdp.On("CreateUser", mock.Anything, "Password123!").Return("new-user-uuid", nil)
// UserRepo Mocks (Async & Failure)
mockUserRepo.createCalled = make(chan bool, 1)
mockUserRepo.On("Update", mock.Anything, mock.MatchedBy(func(u *domain.User) bool {
return u.Email == email
})).Return(errors.New("db connection error"))
// Keto Mocks (Optional, since it's also async)
// We won't block on this either
body, _ := json.Marshal(domain.SignupRequest{
Email: email,
Password: "Password123!",
Name: "Test User",
Phone: phone,
TermsAccepted: true,
})
req := httptest.NewRequest("POST", "/signup", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
start := time.Now()
resp, err := app.Test(req, 5000)
elapsed := time.Since(start)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
assert.Equal(t, 200, resp.StatusCode)
// Ensure API responded faster than DB latency (50ms)
assert.Less(t, int64(elapsed), int64(60*time.Millisecond), "API should return before DB timeout")
// Wait for async execution
select {
case <-mockUserRepo.createCalled:
// Pass
case <-time.After(2 * time.Second):
t.Fatal("UserRepo.Create was not called asynchronously")
}
mockRedis.AssertExpectations(t)
mockIdp.AssertExpectations(t)
mockUserRepo.AssertExpectations(t)
})
}

View File

@@ -0,0 +1,200 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"baron-sso-backend/internal/utils"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
)
func TestRevokeLinkedRp_Success(t *testing.T) {
// Mock Hydra transport for revocation
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
// 1. Kratos whoami
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{"id": "user-123"},
}), nil
}
// 2. Hydra Revoke
if r.Method == http.MethodDelete && r.URL.Path == "/oauth2/auth/sessions/consent" {
assert.Equal(t, "user-123", r.URL.Query().Get("subject"))
assert.Equal(t, "app-1", r.URL.Query().Get("client"))
return httpResponse(r, http.StatusNoContent, ""), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() { http.DefaultClient = origDefault }()
auditRepo := &mockAuditRepo{}
rpUsageSink := &mockRPUsageEventSink{}
consentRepo := &mockConsentRepo{
consents: []domain.ClientConsent{
{
ClientID: "app-1",
Subject: "user-123",
GrantedScopes: []string{"openid", "profile"},
},
},
}
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
AuditRepo: auditRepo,
ConsentRepo: consentRepo,
RPUsageSink: rpUsageSink,
}
app := fiber.New()
app.Delete("/api/v1/user/rp/linked/:id", h.RevokeLinkedRp)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/rp/linked/app-1", nil)
req.Header.Set("Cookie", "valid")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, 1, len(auditRepo.logs))
assert.Equal(t, "consent.revoked", auditRepo.logs[0].EventType)
assert.Equal(t, "user-123", auditRepo.logs[0].UserID)
assert.Equal(t, "success", auditRepo.logs[0].Status)
auditDetails, err := utils.ParseAuditDetails(auditRepo.logs[0].Details)
assert.NoError(t, err)
assert.Equal(t, "app-1", auditDetails["client_id"])
assert.Equal(t, 1, len(rpUsageSink.events))
assert.Equal(t, domain.RPUsageEventTypeAuthorizationRevoked, rpUsageSink.events[0].EventType)
assert.Equal(t, "user-123", rpUsageSink.events[0].Subject)
assert.Equal(t, "app-1", rpUsageSink.events[0].ClientID)
remaining, err := consentRepo.Find(req.Context(), "app-1", "user-123")
assert.NoError(t, err)
assert.Nil(t, remaining)
}
func TestRevokeLinkedRp_SendsBackchannelLogoutTokenWhenConfigured(t *testing.T) {
t.Setenv("BACKCHANNEL_LOGOUT_ISSUER", "https://sso.example.com/oidc")
var receivedBody string
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{"id": "user-123"},
}), nil
}
if r.URL.Host == "hydra.test" && r.Method == http.MethodDelete && r.URL.Path == "/oauth2/auth/sessions/consent" {
return httpResponse(r, http.StatusNoContent, ""), nil
}
if r.URL.Host == "hydra.test" && r.Method == http.MethodGet && r.URL.Path == "/clients/app-1" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "app-1",
"backchannel_logout_uri": "https://rp.example.com/backchannel-logout",
}), nil
}
if r.URL.Host == "rp.example.com" && r.Method == http.MethodPost && r.URL.Path == "/backchannel-logout" {
raw, _ := io.ReadAll(r.Body)
receivedBody = string(raw)
return httpResponse(r, http.StatusNoContent, ""), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() { http.DefaultClient = origDefault }()
backchannelLogout, err := service.NewBackchannelLogoutService()
assert.NoError(t, err)
backchannelLogout.HTTPClient = client
auditRepo := &mockAuditRepo{}
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
BackchannelLogout: backchannelLogout,
AuditRepo: auditRepo,
}
app := fiber.New()
app.Delete("/api/v1/user/rp/linked/:id", h.RevokeLinkedRp)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/rp/linked/app-1", nil)
req.Header.Set("Cookie", "valid")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.True(t, strings.Contains(receivedBody, "logout_token="))
values, err := url.ParseQuery(receivedBody)
assert.NoError(t, err)
assert.NotEmpty(t, values.Get("logout_token"))
assert.Len(t, auditRepo.logs, 2)
assert.Equal(t, "backchannel_logout.sent", auditRepo.logs[1].EventType)
}
func TestListRpHistory_Aggregation(t *testing.T) {
now := time.Now()
auditRepo := &mockAuditRepo{
logs: []domain.AuditLog{
{
UserID: "user-123",
EventType: "consent.revoked", // Newest
Timestamp: now,
Details: `{"client_id":"app-1"}`,
},
{
UserID: "user-123",
EventType: "consent.granted", // Oldest
Timestamp: now.Add(-1 * time.Hour),
Details: `{"client_id":"app-1", "client_name":"App One"}`,
},
},
}
h := &AuthHandler{
AuditRepo: auditRepo,
}
app := fiber.New()
app.Get("/api/v1/user/rp/history", h.ListRpHistory)
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{"id": "user-123"},
}), nil
})
http.DefaultClient = &http.Client{Transport: transport}
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/rp/history", nil)
req.Header.Set("Cookie", "valid")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var res struct {
Items []struct {
ClientID string `json:"client_id"`
Status string `json:"status"`
} `json:"items"`
}
json.NewDecoder(resp.Body).Decode(&res)
assert.Equal(t, 1, len(res.Items))
assert.Equal(t, "app-1", res.Items[0].ClientID)
// Newest event (revoked) should win
assert.Equal(t, "revoked", res.Items[0].Status)
}

View File

@@ -0,0 +1,515 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"baron-sso-backend/internal/utils"
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
// --- Mocks ---
type MockKratosAdminServiceForConsent struct {
mock.Mock
}
func (m *MockKratosAdminServiceForConsent) FindIdentityIDByIdentifier(ctx context.Context, identifier string) (string, error) {
args := m.Called(ctx, identifier)
return args.String(0), args.Error(1)
}
func (m *MockKratosAdminServiceForConsent) GetIdentity(ctx context.Context, id string) (*service.KratosIdentity, error) {
args := m.Called(ctx, id)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*service.KratosIdentity), args.Error(1)
}
func (m *MockKratosAdminServiceForConsent) ListIdentities(ctx context.Context) ([]service.KratosIdentity, error) {
return nil, nil
}
func (m *MockKratosAdminServiceForConsent) UpdateIdentity(ctx context.Context, identityID string, traits map[string]any, state string) (*service.KratosIdentity, error) {
return nil, nil
}
func (m *MockKratosAdminServiceForConsent) CreateIdentity(ctx context.Context, traits map[string]any) (*service.KratosIdentity, error) {
return nil, nil
}
func (m *MockKratosAdminServiceForConsent) DeleteIdentity(ctx context.Context, identityID string) error {
return nil
}
func (m *MockKratosAdminServiceForConsent) UpdateIdentityPassword(ctx context.Context, identityID, newPassword string) error {
return nil
}
func (m *MockKratosAdminServiceForConsent) ListIdentitySessions(ctx context.Context, identityID string) ([]service.KratosSession, error) {
return nil, nil
}
func (m *MockKratosAdminServiceForConsent) GetSession(ctx context.Context, sessionID string) (*service.KratosSession, error) {
return nil, nil
}
func (m *MockKratosAdminServiceForConsent) DeleteSession(ctx context.Context, sessionID string) error {
return nil
}
func (m *MockKratosAdminServiceForConsent) CreateUser(ctx context.Context, user *domain.BrokerUser, password string) (string, error) {
return "", nil
}
type MockTenantServiceForConsent struct {
mock.Mock
}
func (m *MockTenantServiceForConsent) GetTenant(ctx context.Context, id string) (*domain.Tenant, error) {
args := m.Called(ctx, id)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.Tenant), args.Error(1)
}
func (m *MockTenantServiceForConsent) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
return nil, nil
}
func (m *MockTenantServiceForConsent) GetTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
return nil, nil
}
func (m *MockTenantServiceForConsent) ListTenants(ctx context.Context, limit, offset int, parentID string, search string) ([]domain.Tenant, int64, error) {
return nil, 0, nil
}
func (m *MockTenantServiceForConsent) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string, creatorID string) (*domain.Tenant, error) {
return nil, nil
}
func (m *MockTenantServiceForConsent) RequestRegistration(ctx context.Context, name, slug, description string, domainName string, adminEmail string) (*domain.Tenant, error) {
return nil, nil
}
func (m *MockTenantServiceForConsent) ApproveTenant(ctx context.Context, id string) error {
return nil
}
func (m *MockTenantServiceForConsent) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
args := m.Called(ctx, userID)
return args.Get(0).([]domain.Tenant), args.Error(1)
}
func (m *MockTenantServiceForConsent) ListJoinedTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
args := m.Called(ctx, userID)
return args.Get(0).([]domain.Tenant), args.Error(1)
}
func (m *MockTenantServiceForConsent) IsDomainAllowed(ctx context.Context, domainName string) (bool, error) {
return false, nil
}
func (m *MockTenantServiceForConsent) ProvisionTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
return nil, nil
}
func (m *MockTenantServiceForConsent) SetKetoService(keto service.KetoService) {}
func (m *MockTenantServiceForConsent) DeleteTenantsBulk(ctx context.Context, ids []string) error {
return nil
}
// --- Test Helpers ---
func newConsentTestApp(h *AuthHandler) *fiber.App {
app := fiber.New()
app.Get("/api/v1/auth/consent", h.GetConsentRequest)
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
return app
}
// --- Tests ---
func TestGetConsentRequest_Normal(t *testing.T) {
// Mock Hydra transport
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-123" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-123",
"requested_scope": []string{"openid", "profile"},
"skip": false,
"subject": "user-123",
"client": map[string]any{
"client_id": "client-app",
"client_name": "Test App",
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() { http.DefaultClient = origDefault }()
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
}
app := newConsentTestApp(h)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-123", nil)
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var body map[string]any
json.NewDecoder(resp.Body).Decode(&body)
assert.Equal(t, "challenge-123", body["challenge"])
assert.Equal(t, false, body["skip"])
}
func TestGetConsentRequest_AddsMandatoryTenantScope(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-tenant-scope" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-tenant-scope",
"requested_scope": []string{"openid", "profile"},
"skip": false,
"subject": "user-123",
"client": map[string]any{
"client_id": "client-app",
"client_name": "Test App",
"metadata": map[string]any{
"tenant_access_restricted": true,
"allowed_tenants": []string{"tenant-allow"},
"structured_scopes": []map[string]any{
{"name": "openid", "mandatory": true},
{"name": "tenant", "mandatory": true, "locked": true},
{"name": "profile", "mandatory": false},
},
},
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() { http.DefaultClient = origDefault }()
mockTenantSvc := &MockTenantServiceForConsent{}
mockKratosAdmin := &MockKratosAdminServiceForConsent{}
// Mock profile resolution to allow tenant access
mockKratosAdmin.On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
ID: "user-123",
Traits: map[string]any{
"email": "user@example.com",
},
}, nil)
mockTenantSvc.On("GetTenant", mock.Anything, "tenant-allow").Return(&domain.Tenant{
ID: "tenant-allow",
Slug: "tenant-allow",
Name: "Allowed Tenant",
}, nil)
// Mock hydration calls
mockTenantSvc.On("ListJoinedTenants", mock.Anything, mock.Anything).Return([]domain.Tenant{
{ID: "tenant-allow", Slug: "tenant-allow", Name: "Allowed Tenant"},
}, nil)
mockTenantSvc.On("ListManageableTenants", mock.Anything, mock.Anything).Return([]domain.Tenant{}, nil)
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
TenantService: mockTenantSvc,
KratosAdmin: mockKratosAdmin,
}
app := newConsentTestApp(h)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-tenant-scope", nil)
req.Header.Set("X-Mock-Role", "user")
req.Header.Set("X-Tenant-ID", "tenant-allow")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var body map[string]any
json.NewDecoder(resp.Body).Decode(&body)
assert.Equal(t, []any{"openid", "tenant", "profile"}, body["requested_scope"])
scopeDetails := body["scope_details"].(map[string]any)
tenantDetail := scopeDetails["tenant"].(map[string]any)
assert.Equal(t, true, tenantDetail["mandatory"])
}
func TestGetConsentRequest_Skip_AutoAccept(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
// Hydra: Get Consent Request
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-skip" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-skip",
"requested_scope": []string{"openid"},
"skip": true,
"subject": "user-123",
"client": map[string]any{
"client_id": "client-app",
},
}), nil
}
// Kratos: Get Identity
if r.URL.Path == "/admin/identities/user-123" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@test.com",
},
}), nil
}
// Hydra: Accept Consent Request
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-skip" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"redirect_to": "http://rp/cb",
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() { http.DefaultClient = origDefault }()
consentRepo := &mockConsentRepo{}
rpUsageSink := &mockRPUsageEventSink{}
mockKratosAdmin := &MockKratosAdminServiceForConsent{}
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: mockKratosAdmin,
ConsentRepo: consentRepo,
RPUsageSink: rpUsageSink,
}
mockKratosAdmin.On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
ID: "user-123",
Traits: map[string]any{
"email": "user@test.com",
},
}, nil)
app := newConsentTestApp(h)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-skip", nil)
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var body map[string]any
json.NewDecoder(resp.Body).Decode(&body)
assert.Equal(t, "http://rp/cb", body["redirectTo"])
assert.Equal(t, 1, len(rpUsageSink.events))
assert.Equal(t, domain.RPUsageEventTypeAuthorizationGranted, rpUsageSink.events[0].EventType)
assert.Equal(t, "client-app", rpUsageSink.events[0].ClientID)
assert.Equal(t, "challenge-skip", rpUsageSink.events[0].CorrelationID)
assert.Equal(t, true, rpUsageSink.events[0].Payload["auto_accepted"])
}
func TestAcceptConsentRequest_Normal(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-accept" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-accept",
"requested_scope": []string{"openid", "profile"},
"subject": "user-123",
"client": map[string]any{
"client_id": "client-app",
"client_name": "Test App",
},
}), nil
}
if r.URL.Path == "/admin/identities/user-123" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@test.com",
},
}), nil
}
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-accept" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"redirect_to": "http://rp/cb",
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() { http.DefaultClient = origDefault }()
auditRepo := &mockAuditRepo{}
consentRepo := &mockConsentRepo{}
rpUsageSink := &mockRPUsageEventSink{}
mockKratosAdmin := &MockKratosAdminServiceForConsent{}
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: mockKratosAdmin,
AuditRepo: auditRepo,
ConsentRepo: consentRepo,
RPUsageSink: rpUsageSink,
}
mockKratosAdmin.On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
ID: "user-123",
Traits: map[string]any{
"email": "user@test.com",
},
}, nil)
app := newConsentTestApp(h)
body, _ := json.Marshal(map[string]any{
"consent_challenge": "challenge-accept",
"grant_scope": []string{"openid"},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, 1, len(auditRepo.logs))
assert.Equal(t, "consent.granted", auditRepo.logs[0].EventType)
assert.Equal(t, "user-123", auditRepo.logs[0].UserID)
assert.Equal(t, "success", auditRepo.logs[0].Status)
auditDetails, err := utils.ParseAuditDetails(auditRepo.logs[0].Details)
assert.NoError(t, err)
assert.Equal(t, "client-app", auditDetails["client_id"])
assert.Equal(t, "Test App", auditDetails["client_name"])
assert.Equal(t, []any{"openid"}, auditDetails["scopes"])
assert.Equal(t, 1, len(rpUsageSink.events))
assert.Equal(t, domain.RPUsageEventTypeAuthorizationGranted, rpUsageSink.events[0].EventType)
assert.Equal(t, "user-123", rpUsageSink.events[0].Subject)
assert.Equal(t, "client-app", rpUsageSink.events[0].ClientID)
assert.Equal(t, "Test App", rpUsageSink.events[0].ClientName)
assert.Equal(t, []string{"openid"}, []string(rpUsageSink.events[0].Scopes))
assert.Equal(t, "hydra_consent", rpUsageSink.events[0].Source)
}
func TestAcceptConsentRequest_EnforcesMandatoryTenantScope(t *testing.T) {
t.Setenv("APP_ENV", "dev")
var capturedGrantScopes []string
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-tenant-accept" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-tenant-accept",
"requested_scope": []string{"openid", "profile"},
"subject": "user-123",
"client": map[string]any{
"client_id": "client-app",
"metadata": map[string]any{
"tenant_id": "tenant-abc",
"tenant_access_restricted": true,
"allowed_tenants": []string{"tenant-abc"},
"structured_scopes": []map[string]any{
{"name": "openid", "mandatory": true},
{"name": "tenant", "mandatory": true, "locked": true},
{"name": "profile", "mandatory": false},
},
},
},
}), nil
}
if r.URL.Path == "/admin/identities/user-123" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@test.com",
},
}), nil
}
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-tenant-accept" {
var payload map[string]any
assert.NoError(t, json.NewDecoder(r.Body).Decode(&payload))
for _, scope := range payload["grant_scope"].([]any) {
capturedGrantScopes = append(capturedGrantScopes, scope.(string))
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"redirect_to": "http://rp/cb",
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() { http.DefaultClient = origDefault }()
mockKratosAdmin := &MockKratosAdminServiceForConsent{}
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: mockKratosAdmin,
}
mockKratosAdmin.On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
ID: "user-123",
Traits: map[string]any{
"email": "user@test.com",
},
}, nil)
app := newConsentTestApp(h)
body, _ := json.Marshal(map[string]any{
"consent_challenge": "challenge-tenant-accept",
"grant_scope": []string{"openid", "profile"},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Mock-Role", "user")
req.Header.Set("X-Tenant-ID", "tenant-abc")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, []string{"openid", "tenant", "profile"}, capturedGrantScopes)
}

View File

@@ -0,0 +1,829 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func TestBuildOidcClaimsFromTraits_DynamicClaims(t *testing.T) {
traits := map[string]any{
"email": "user@baron.com",
"name": "홍길동",
"tenant_id": "primary-tenant-999", // Added primary tenant
"tenant-1": map[string]any{
"department": "개발팀",
"grade": "선임",
},
"tenant-2": map[string]any{
"department": "재무팀",
"grade": "팀장",
},
}
scopes := []string{"openid", "profile"}
t.Run("No tenantID", func(t *testing.T) {
claims := buildOidcClaimsFromTraits(traits, scopes, "")
assert.Equal(t, "user@baron.com", claims["email"])
assert.Equal(t, "홍길동", claims["name"])
assert.Equal(t, "primary-tenant-999", claims["tenant_id"])
assert.Nil(t, claims["department"])
assert.Nil(t, claims["grade"])
assert.Nil(t, claims["tenants"])
assert.Contains(t, claims["joined_tenants"], "tenant-1")
assert.Contains(t, claims["joined_tenants"], "tenant-2")
assert.Contains(t, claims["joined_tenants"], "primary-tenant-999") // Should contain primary
})
t.Run("With tenant-1", func(t *testing.T) {
claims := buildOidcClaimsFromTraits(traits, scopes, "tenant-1")
assert.Equal(t, "user@baron.com", claims["email"])
assert.Equal(t, "홍길동", claims["name"])
assert.Equal(t, "tenant-1", claims["tenant_id"])
assert.Nil(t, claims["department"])
assert.Nil(t, claims["grade"])
assert.Nil(t, claims["tenants"])
assert.Contains(t, claims["joined_tenants"], "tenant-1")
assert.Contains(t, claims["joined_tenants"], "tenant-2")
assert.Contains(t, claims["joined_tenants"], "primary-tenant-999")
})
t.Run("With tenant-2", func(t *testing.T) {
claims := buildOidcClaimsFromTraits(traits, scopes, "tenant-2")
assert.Equal(t, "user@baron.com", claims["email"])
assert.Equal(t, "홍길동", claims["name"])
assert.Equal(t, "tenant-2", claims["tenant_id"])
assert.Nil(t, claims["department"])
assert.Nil(t, claims["grade"])
assert.Nil(t, claims["tenants"])
assert.Contains(t, claims["joined_tenants"], "primary-tenant-999")
})
t.Run("With non-existent tenant", func(t *testing.T) {
claims := buildOidcClaimsFromTraits(traits, scopes, "tenant-3")
assert.Equal(t, "user@baron.com", claims["email"])
assert.Equal(t, "홍길동", claims["name"])
assert.Equal(t, "tenant-3", claims["tenant_id"])
assert.Nil(t, claims["department"])
assert.Nil(t, claims["grade"])
assert.Nil(t, claims["tenants"])
assert.Contains(t, claims["joined_tenants"], "tenant-1")
assert.Contains(t, claims["joined_tenants"], "primary-tenant-999")
})
t.Run("Tenant scope includes detailed tenant metadata", func(t *testing.T) {
claims := buildOidcClaimsFromTraits(traits, []string{"openid", "profile", "tenant"}, "tenant-1")
assert.Equal(t, "tenant-1", claims["tenant_id"])
assert.Equal(t, "개발팀", claims["department"])
assert.Equal(t, "선임", claims["grade"])
assert.NotNil(t, claims["tenants"])
assert.Contains(t, claims["joined_tenants"], "tenant-1")
assert.Contains(t, claims["joined_tenants"], "tenant-2")
assert.Contains(t, claims["joined_tenants"], "primary-tenant-999")
})
}
func TestRepresentativeTenantIDFromTraits(t *testing.T) {
t.Run("explicit tenant_id wins", func(t *testing.T) {
traits := map[string]any{
"tenant_id": "01970f0a-5c28-74d8-a73a-f6e9e9a7b210",
"additionalAppointments": []any{
map[string]any{"tenantId": "01970f0b-3448-7bb8-bdc7-16b6a1d2e661", "isPrimary": true},
},
}
assert.Equal(t, "01970f0a-5c28-74d8-a73a-f6e9e9a7b210", representativeTenantIDFromTraits(traits))
})
t.Run("primary appointment wins when tenant_id is absent", func(t *testing.T) {
traits := map[string]any{
"additionalAppointments": []any{
map[string]any{"tenantId": "01970f0b-3448-7bb8-bdc7-16b6a1d2e661"},
map[string]any{"tenantId": "01970f0c-8c44-7069-9f20-7d28c0b8e630", "representative": true},
},
}
assert.Equal(t, "01970f0c-8c44-7069-9f20-7d28c0b8e630", representativeTenantIDFromTraits(traits))
})
t.Run("first appointment is fallback", func(t *testing.T) {
traits := map[string]any{
"additionalAppointments": []any{
map[string]any{"tenantId": "01970f0b-3448-7bb8-bdc7-16b6a1d2e661"},
map[string]any{"tenantId": "01970f0c-8c44-7069-9f20-7d28c0b8e630"},
},
}
assert.Equal(t, "01970f0b-3448-7bb8-bdc7-16b6a1d2e661", representativeTenantIDFromTraits(traits))
})
}
func TestAcceptConsentRequest_DynamicClaims(t *testing.T) {
var capturedClaims map[string]any
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
// Hydra: Get Consent Request
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-dynamic" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-dynamic",
"requested_scope": []string{"openid", "profile", "tenant"},
"subject": "user-123",
"client": map[string]any{
"client_id": "client-app",
"metadata": map[string]any{
"tenant_id": "tenant-abc",
},
},
}), nil
}
// Kratos: Get Identity
if r.URL.Path == "/admin/identities/user-123" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@test.com",
"name": "Test User",
"tenant-abc": map[string]any{
"department": "Innovation",
"position": "Architect",
},
},
}), nil
}
// Hydra: Accept Consent Request
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-dynamic" {
// Capture the claims sent to Hydra
body, _ := io.ReadAll(r.Body)
var acceptReq map[string]any
json.Unmarshal(body, &acceptReq)
if session, ok := acceptReq["session"].(map[string]any); ok {
capturedClaims = session["id_token"].(map[string]any)
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"redirect_to": "http://rp/cb",
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() { http.DefaultClient = origDefault }()
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: new(MockKratosAdminService),
}
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
ID: "user-123",
Traits: map[string]any{
"email": "user@test.com",
"name": "Test User",
"tenant-abc": map[string]any{
"department": "Innovation",
"position": "Architect",
},
},
}, nil)
app := fiber.New()
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
reqBody, _ := json.Marshal(map[string]any{
"consent_challenge": "challenge-dynamic",
"grant_scope": []string{"openid", "profile", "tenant"},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(reqBody))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
// Verify captured claims
assert.NotNil(t, capturedClaims)
assert.Equal(t, "user@test.com", capturedClaims["email"])
assert.Equal(t, "tenant-abc", capturedClaims["tenant_id"])
assert.Equal(t, "Innovation", capturedClaims["department"])
assert.Equal(t, "Architect", capturedClaims["position"])
}
func TestAcceptConsentRequest_UsesRepresentativeTenantIDInsteadOfClientTenantContext(t *testing.T) {
var capturedClaims map[string]any
representativeTenantID := "01970f0a-5c28-74d8-a73a-f6e9e9a7b210"
rpContextTenantID := "01970f0b-3448-7bb8-bdc7-16b6a1d2e661"
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-representative-tenant" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-representative-tenant",
"requested_scope": []string{"openid", "profile", "tenant"},
"subject": "user-representative",
"client": map[string]any{
"client_id": "client-app",
"metadata": map[string]any{
"tenant_id": rpContextTenantID,
},
},
}), nil
}
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-representative-tenant" {
body, _ := io.ReadAll(r.Body)
var acceptReq map[string]any
json.Unmarshal(body, &acceptReq)
if session, ok := acceptReq["session"].(map[string]any); ok {
capturedClaims = session["id_token"].(map[string]any)
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"redirect_to": "http://rp/cb",
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: new(MockKratosAdminService),
}
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-representative").Return(&service.KratosIdentity{
ID: "user-representative",
Traits: map[string]any{
"email": "user@test.com",
"name": "Test User",
"additionalAppointments": []any{
map[string]any{"tenantId": representativeTenantID, "isPrimary": true},
map[string]any{"tenantId": rpContextTenantID},
},
},
}, nil)
app := fiber.New()
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
reqBody, _ := json.Marshal(map[string]any{
"consent_challenge": "challenge-representative-tenant",
"grant_scope": []string{"openid", "profile"},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(reqBody))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.NotNil(t, capturedClaims)
assert.Equal(t, representativeTenantID, capturedClaims["tenant_id"])
assert.Contains(t, capturedClaims["joined_tenants"], representativeTenantID)
assert.Contains(t, capturedClaims["joined_tenants"], rpContextTenantID)
assert.Nil(t, capturedClaims["tenants"])
}
func TestAcceptConsentRequest_IncludesHanmacFamilyTenantClaimDetails(t *testing.T) {
var capturedClaims map[string]any
deptID := "01970f0a-5c28-74d8-a73a-f6e9e9a7b210"
secondDeptID := "01970f0b-3448-7bb8-bdc7-16b6a1d2e661"
companyID := "01970f08-91da-7286-bd19-882fb98d1f2c"
rootID := "01970f07-4f01-7d9a-a71e-b53ad508f345"
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-hanmac-tenant-claim" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-hanmac-tenant-claim",
"requested_scope": []string{"openid", "profile", "tenant"},
"subject": "user-hanmac",
"client": map[string]any{
"client_id": "hanmac-rp",
"metadata": map[string]any{
"tenant_id": deptID,
},
},
}), nil
}
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-hanmac-tenant-claim" {
body, _ := io.ReadAll(r.Body)
var acceptReq map[string]any
json.Unmarshal(body, &acceptReq)
if session, ok := acceptReq["session"].(map[string]any); ok {
capturedClaims = session["id_token"].(map[string]any)
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"redirect_to": "http://rp/cb",
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: new(MockKratosAdminService),
}
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-hanmac").Return(&service.KratosIdentity{
ID: "user-hanmac",
Traits: map[string]any{
"email": "hanmac-user@example.com",
"name": "한맥 사용자",
"additionalAppointments": []any{
map[string]any{
"tenantId": deptID,
"isPrimary": true,
"isOwner": true,
"grade": "책임",
"jobTitle": "기술기획",
"position": "팀장",
},
map[string]any{
"tenantId": secondDeptID,
"isPrimary": false,
"isOwner": false,
"grade": "선임",
"jobTitle": "품질관리",
"position": "파트원",
},
},
},
}, nil)
mockTenantSvc := new(MockTenantService)
mockTenantSvc.On("ListJoinedTenants", mock.Anything, "user-hanmac").Return([]domain.Tenant{}, nil)
mockTenantSvc.On("GetTenant", mock.Anything, deptID).Return(&domain.Tenant{
ID: deptID,
Slug: "tech-planning",
Name: "기술기획팀",
Type: domain.TenantTypeUserGroup,
ParentID: &companyID,
}, nil)
mockTenantSvc.On("GetTenant", mock.Anything, secondDeptID).Return(&domain.Tenant{
ID: secondDeptID,
Slug: "quality",
Name: "품질관리팀",
Type: domain.TenantTypeUserGroup,
ParentID: &companyID,
}, nil)
mockTenantSvc.On("GetTenant", mock.Anything, companyID).Return(&domain.Tenant{
ID: companyID,
Slug: "hanmac",
Name: "한맥기술",
Type: domain.TenantTypeCompany,
ParentID: &rootID,
}, nil)
mockTenantSvc.On("GetTenant", mock.Anything, rootID).Return(&domain.Tenant{
ID: rootID,
Slug: "hanmac-family",
Name: "한맥가족",
Type: domain.TenantTypeCompanyGroup,
}, nil)
h.TenantService = mockTenantSvc
app := fiber.New()
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
reqBody, _ := json.Marshal(map[string]any{
"consent_challenge": "challenge-hanmac-tenant-claim",
"grant_scope": []string{"openid", "profile", "tenant"},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(reqBody))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.NotNil(t, capturedClaims)
assert.Equal(t, []any{deptID}, capturedClaims["lead_tenants"])
assert.ElementsMatch(t, []any{deptID, secondDeptID}, capturedClaims["joined_tenants"])
tenants := capturedClaims["tenants"].(map[string]any)
dept := tenants[deptID].(map[string]any)
assert.Equal(t, true, dept["lead"])
assert.Equal(t, true, dept["representative"])
assert.Equal(t, "책임", dept["grade"])
assert.Equal(t, "기술기획", dept["jobTitle"])
assert.Equal(t, "팀장", dept["position"])
assert.Equal(t, companyID, dept["parentTenantId"])
assert.NotContains(t, dept, "parentTenant")
ancestors := dept["ancestors"].([]any)
assert.Len(t, ancestors, 2)
companyAncestor := ancestors[0].(map[string]any)
assert.Equal(t, companyID, companyAncestor["id"])
assert.Equal(t, "hanmac", companyAncestor["slug"])
assert.Equal(t, rootID, companyAncestor["parentTenantId"])
assert.NotContains(t, companyAncestor, "parentTenant")
rootAncestor := ancestors[1].(map[string]any)
assert.Equal(t, rootID, rootAncestor["id"])
assert.Equal(t, "hanmac-family", rootAncestor["slug"])
assert.Contains(t, rootAncestor, "parentTenantId")
assert.Nil(t, rootAncestor["parentTenantId"])
assert.NotContains(t, rootAncestor, "parentTenant")
secondDept := tenants[secondDeptID].(map[string]any)
assert.Equal(t, false, secondDept["lead"])
assert.Equal(t, false, secondDept["representative"])
assert.Equal(t, "선임", secondDept["grade"])
assert.Equal(t, "품질관리", secondDept["jobTitle"])
assert.Equal(t, "파트원", secondDept["position"])
assert.Equal(t, companyID, secondDept["parentTenantId"])
}
func TestWithHanmacFamilyTenantClaims_DefaultClaimsOnlyWithoutTenantScope(t *testing.T) {
deptID := "01970f0a-5c28-74d8-a73a-f6e9e9a7b210"
secondDeptID := "01970f0b-3448-7bb8-bdc7-16b6a1d2e661"
companyID := "01970f08-91da-7286-bd19-882fb98d1f2c"
rootID := "01970f07-4f01-7d9a-a71e-b53ad508f345"
mockTenantSvc := new(MockTenantService)
mockTenantSvc.On("GetTenant", mock.Anything, deptID).Return(&domain.Tenant{
ID: deptID,
Slug: "tech-planning",
Name: "기술기획팀",
Type: domain.TenantTypeUserGroup,
ParentID: &companyID,
}, nil)
mockTenantSvc.On("GetTenant", mock.Anything, secondDeptID).Return(&domain.Tenant{
ID: secondDeptID,
Slug: "quality",
Name: "품질관리팀",
Type: domain.TenantTypeUserGroup,
ParentID: &companyID,
}, nil)
mockTenantSvc.On("GetTenant", mock.Anything, companyID).Return(&domain.Tenant{
ID: companyID,
Slug: "hanmac",
Name: "한맥기술",
Type: domain.TenantTypeCompany,
ParentID: &rootID,
}, nil)
mockTenantSvc.On("GetTenant", mock.Anything, rootID).Return(&domain.Tenant{
ID: rootID,
Slug: "hanmac-family",
Name: "한맥가족",
Type: domain.TenantTypeCompanyGroup,
}, nil)
h := &AuthHandler{TenantService: mockTenantSvc}
claims := map[string]any{"tenant_id": deptID}
traits := map[string]any{
"additionalAppointments": []any{
map[string]any{
"tenantId": deptID,
"isPrimary": true,
"isOwner": true,
"grade": "책임",
},
map[string]any{
"tenantId": secondDeptID,
"grade": "선임",
},
},
}
claims = h.withHanmacFamilyTenantClaims(context.Background(), claims, traits, []string{"openid", "profile"})
assert.Equal(t, deptID, claims["tenant_id"])
assert.ElementsMatch(t, []string{deptID, secondDeptID}, claims["joined_tenants"])
assert.NotContains(t, claims, "tenants")
assert.NotContains(t, claims, "lead_tenants")
}
func TestAcceptConsentRequest_IncludesRPProfileClaims(t *testing.T) {
var capturedClaims map[string]any
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-rp-profile" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-rp-profile",
"requested_scope": []string{"openid", "profile", "tenant"},
"subject": "user-123",
"client": map[string]any{
"client_id": "client-app",
"metadata": map[string]any{
"customUserSchema": []map[string]any{
{
"key": "approvalLevel",
"label": "승인 등급",
"type": "text",
"claimEnabled": true,
},
{
"key": "internalMemo",
"label": "내부 메모",
"type": "text",
"claimEnabled": false,
},
},
},
},
}), nil
}
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-rp-profile" {
body, _ := io.ReadAll(r.Body)
var acceptReq map[string]any
json.Unmarshal(body, &acceptReq)
if session, ok := acceptReq["session"].(map[string]any); ok {
capturedClaims = session["id_token"].(map[string]any)
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"redirect_to": "http://rp/cb",
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: new(MockKratosAdminService),
}
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
ID: "user-123",
Traits: map[string]any{
"email": "user@test.com",
"name": "Test User",
},
}, nil)
repo := new(devMockRPUserMetadataRepo)
repo.On("Get", mock.Anything, "client-app", "user-123").Return(&domain.RPUserMetadata{
ClientID: "client-app",
UserID: "user-123",
Metadata: domain.JSONMap{
"approvalLevel": "A",
"internalMemo": "관리자 전용",
},
}, nil).Once()
h.RPUserMetadataRepo = repo
app := fiber.New()
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
reqBody, _ := json.Marshal(map[string]any{
"consent_challenge": "challenge-rp-profile",
"grant_scope": []string{"openid", "profile"},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(reqBody))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.NotNil(t, capturedClaims)
rpProfiles, ok := capturedClaims["rp_profiles"].([]any)
assert.True(t, ok)
assert.Len(t, rpProfiles, 1)
profile := rpProfiles[0].(map[string]any)
assert.Equal(t, "client-app", profile["client_id"])
fields := profile["fields"].(map[string]any)
assert.Equal(t, "A", fields["approvalLevel"])
assert.NotContains(t, fields, "internalMemo")
repo.AssertExpectations(t)
}
func TestGetConsentRequest_Skip_DynamicClaims(t *testing.T) {
var capturedClaims map[string]any
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
// Hydra: Get Consent Request
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-skip-dynamic" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-skip-dynamic",
"requested_scope": []string{"openid", "profile", "tenant"},
"skip": true,
"subject": "user-456",
"client": map[string]any{
"client_id": "skip-app",
"metadata": map[string]any{
"tenant_id": "tenant-xyz",
},
},
}), nil
}
// Kratos: Get Identity
if r.URL.Path == "/admin/identities/user-456" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "user-456",
"traits": map[string]any{
"email": "skip@test.com",
"tenant-xyz": map[string]any{
"department": "Security",
"position": "Officer",
},
},
}), nil
}
// Hydra: Accept Consent Request
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-skip-dynamic" {
// Capture the claims sent to Hydra
body, _ := io.ReadAll(r.Body)
var acceptReq map[string]any
json.Unmarshal(body, &acceptReq)
if session, ok := acceptReq["session"].(map[string]any); ok {
capturedClaims = session["id_token"].(map[string]any)
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"redirect_to": "http://rp/cb",
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() { http.DefaultClient = origDefault }()
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: new(MockKratosAdminService),
}
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-456").Return(&service.KratosIdentity{
ID: "user-456",
Traits: map[string]any{
"email": "skip@test.com",
"tenant-xyz": map[string]any{
"department": "Security",
"position": "Officer",
},
},
}, nil)
app := fiber.New()
app.Get("/api/v1/auth/consent", h.GetConsentRequest)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-skip-dynamic", nil)
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
// Verify captured claims
assert.NotNil(t, capturedClaims)
assert.Equal(t, "skip@test.com", capturedClaims["email"])
assert.Equal(t, "tenant-xyz", capturedClaims["tenant_id"])
assert.Equal(t, "Security", capturedClaims["department"])
assert.Equal(t, "Officer", capturedClaims["position"])
}
func TestBuildOidcClaimsFromTraits_IncludesGlobalCustomClaims(t *testing.T) {
claims := buildOidcClaimsFromTraits(map[string]any{
"email": "user@test.com",
"name": "Test User",
"global_custom_claims": map[string]any{
"contract_date": "2026-06-09",
"approved_at": "2026-06-09T09:30:00+09:00",
"email": "override@test.com",
"rp_claims": "reserved",
},
"global_custom_claim_permissions": map[string]any{
"contract_date": map[string]any{
"readPermission": "user_and_admin",
"writePermission": "admin_only",
},
},
}, []string{"openid", "profile", "email"}, "")
assert.Equal(t, "2026-06-09", claims["contract_date"])
assert.Equal(t, "2026-06-09T09:30:00+09:00", claims["approved_at"])
assert.Equal(t, "user@test.com", claims["email"])
assert.NotEqual(t, "reserved", claims["rp_claims"])
assert.NotContains(t, claims, "global_custom_claim_permissions")
}
func TestAcceptConsentRequest_AppliesConfiguredIDTokenClaims(t *testing.T) {
var capturedClaims map[string]any
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-configured-claims" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-configured-claims",
"requested_scope": []string{"openid", "profile"},
"subject": "user-789",
"client": map[string]any{
"client_id": "client-configured-claims",
"metadata": map[string]any{
"tenant_id": "tenant-claims",
"id_token_claims": []map[string]any{
{
"namespace": "top_level",
"key": "locale",
"value": "ko-KR",
"valueType": "text",
},
{
"namespace": "top_level",
"key": "email",
"value": "should-not-override@example.com",
"valueType": "text",
},
{
"namespace": "rp_claims",
"key": "tier",
"value": "2",
"valueType": "number",
},
{
"namespace": "rp_claims",
"key": "features",
"value": "[\"sso\",\"claims\"]",
"valueType": "array",
},
},
},
},
}), nil
}
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-configured-claims" {
body, _ := io.ReadAll(r.Body)
var acceptReq map[string]any
json.Unmarshal(body, &acceptReq)
if session, ok := acceptReq["session"].(map[string]any); ok {
capturedClaims = session["id_token"].(map[string]any)
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"redirect_to": "http://rp/cb",
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() { http.DefaultClient = origDefault }()
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: new(MockKratosAdminService),
}
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-789").Return(&service.KratosIdentity{
ID: "user-789",
Traits: map[string]any{
"email": "real-user@example.com",
"name": "Configured User",
"tenant-claims": map[string]any{
"department": "Platform",
},
},
}, nil)
app := fiber.New()
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
reqBody, _ := json.Marshal(map[string]any{
"consent_challenge": "challenge-configured-claims",
"grant_scope": []string{"openid", "profile"},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(reqBody))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.NotNil(t, capturedClaims)
assert.Equal(t, "real-user@example.com", capturedClaims["email"])
assert.Equal(t, "ko-KR", capturedClaims["locale"])
assert.Equal(t, "tenant-claims", capturedClaims["tenant_id"])
rpClaims, ok := capturedClaims["rp_claims"].(map[string]any)
if assert.True(t, ok) {
assert.Equal(t, float64(2), rpClaims["tier"])
assert.Equal(t, []any{"sso", "claims"}, rpClaims["features"])
}
}

View File

@@ -0,0 +1,904 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"baron-sso-backend/internal/testsupport"
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
// Mock services
type mockEmailService struct {
lastTo string
lastSubject string
lastBody string
}
func (m *mockEmailService) SendEmail(to, subject, body string) error {
m.lastTo = to
m.lastSubject = subject
m.lastBody = body
return nil
}
type mockSmsService struct {
lastTo string
lastContent string
}
func (m *mockSmsService) SendSms(to, content string) error {
m.lastTo = to
m.lastContent = content
return nil
}
func newHeadlessLinkTestApp(h *AuthHandler) *fiber.App {
app := fiber.New()
app.Post("/api/v1/auth/headless/link/init", h.HeadlessLinkInit)
app.Post("/api/v1/auth/headless/link/poll", h.HeadlessLinkPoll)
app.Post("/api/v1/auth/magic-link/verify", h.VerifyMagicLink)
return app
}
func newKratosWhoamiTestServer(t *testing.T, identityID string) *httptest.Server {
t.Helper()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/sessions/whoami" {
http.NotFound(w, r)
return
}
if r.Header.Get("Cookie") == "" && r.Header.Get("X-Session-Token") == "" {
http.Error(w, "missing session", http.StatusUnauthorized)
return
}
_ = json.NewEncoder(w).Encode(map[string]any{
"id": "session-123",
"authenticated_at": "2026-05-21T00:00:00Z",
"identity": map[string]any{
"id": identityID,
"traits": map[string]any{
"email": "user@example.com",
},
},
})
}))
origDefaultClient := http.DefaultClient
http.DefaultClient = server.Client()
t.Cleanup(func() {
http.DefaultClient = origDefaultClient
})
t.Cleanup(server.Close)
return server
}
func TestEnchantedLinkFlow_Email_Success(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
// Force "Not Supported" for InitiateLinkLogin only to trigger custom Enchanted Link logic
idp := &mockIdpProvider{
userExists: true,
initiateLinkErr: domain.ErrNotSupported,
}
h := &AuthHandler{
RedisService: redis,
IdpProvider: idp,
EmailService: &mockEmailService{},
SmsService: &mockSmsService{},
}
app := fiber.New()
app.Post("/api/v1/auth/enchanted-link/init", h.InitEnchantedLink)
app.Post("/api/v1/auth/enchanted-link/poll", h.PollEnchantedLink)
app.Post("/api/v1/auth/magic-link/verify", h.VerifyMagicLink)
t.Setenv("USERFRONT_URL", "http://userfront.test")
// 1. Init Enchanted Link (Email)
body, _ := json.Marshal(map[string]string{
"loginId": "user@example.com",
"method": "email",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/enchanted-link/init", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var initResp map[string]any
json.NewDecoder(resp.Body).Decode(&initResp)
pendingRef := initResp["pendingRef"].(string)
assert.NotEmpty(t, pendingRef)
// Find the token key "enchanted_token:..." in mock redis
var token string
for k := range redis.data {
if len(k) > 16 && k[:16] == "enchanted_token:" {
token = k[16:]
break
}
}
assert.NotEmpty(t, token)
// 2. Verify Magic Link
verifyBody, _ := json.Marshal(map[string]any{
"token": token,
"verifyOnly": true,
})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(verifyBody))
req.Header.Set("Content-Type", "application/json")
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
// 3. Poll (Success)
pollBody, _ := json.Marshal(map[string]string{"pendingRef": pendingRef})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/enchanted-link/poll", bytes.NewReader(pollBody))
req.Header.Set("Content-Type", "application/json")
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var pollResp map[string]any
json.NewDecoder(resp.Body).Decode(&pollResp)
assert.Equal(t, "ok", pollResp["status"])
assert.Equal(t, "valid-jwt", pollResp["sessionJwt"])
}
func TestEnchantedLinkFlow_Sms_Success(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
idp := &mockIdpProvider{
userExists: true,
initiateLinkErr: domain.ErrNotSupported,
}
h := &AuthHandler{
RedisService: redis,
IdpProvider: idp,
SmsService: &mockSmsService{},
}
app := fiber.New()
app.Post("/api/v1/auth/enchanted-link/init", h.InitEnchantedLink)
// 1. Init Enchanted Link (SMS)
body, _ := json.Marshal(map[string]string{
"loginId": "010-1234-5678",
"method": "sms",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/enchanted-link/init", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var initResp map[string]any
json.NewDecoder(resp.Body).Decode(&initResp)
assert.NotEmpty(t, initResp["userCode"])
}
func TestVerifyMagicLink_VerifyOnlyWithoutSharedBrowserSessionApprovesOnly(t *testing.T) {
redis := &mockRedisRepo{data: map[string]string{
prefixToken + "token-123": `{"pendingRef":"pending-123","loginId":"user@example.com"}`,
}}
h := &AuthHandler{
RedisService: redis,
IdpProvider: &mockIdpProvider{},
}
app := fiber.New()
app.Post("/api/v1/auth/magic-link/verify", h.VerifyMagicLink)
body, _ := json.Marshal(map[string]any{
"token": "token-123",
"verifyOnly": true,
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Empty(t, resp.Cookies())
var got map[string]any
_ = json.NewDecoder(resp.Body).Decode(&got)
assert.Equal(t, "approved", got["status"])
assert.Nil(t, got["sessionJwt"])
assert.Nil(t, got["token"])
}
func TestVerifyMagicLink_VerifyOnlySharedBrowserSameSubjectApprovesOnly(t *testing.T) {
redis := &mockRedisRepo{data: map[string]string{
prefixToken + "token-123": `{"pendingRef":"pending-123","loginId":"user@example.com"}`,
}}
kratosPublic := newKratosWhoamiTestServer(t, "kratos-user-1")
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
h := &AuthHandler{
RedisService: redis,
IdpProvider: &mockIdpProvider{},
}
app := fiber.New()
app.Post("/api/v1/auth/magic-link/verify", h.VerifyMagicLink)
body, _ := json.Marshal(map[string]any{
"token": "token-123",
"verifyOnly": true,
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Cookie", "ory_kratos_session=shared-browser-session")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Empty(t, resp.Cookies())
var got map[string]any
_ = json.NewDecoder(resp.Body).Decode(&got)
assert.Equal(t, "approved", got["status"])
assert.Nil(t, got["sessionJwt"])
assert.Nil(t, got["token"])
}
func TestVerifyMagicLink_VerifyOnlySharedBrowserDifferentSubjectApprovesOnly(t *testing.T) {
redis := &mockRedisRepo{data: map[string]string{
prefixToken + "token-123": `{"pendingRef":"pending-123","loginId":"user@example.com"}`,
}}
kratosPublic := newKratosWhoamiTestServer(t, "kratos-other-user")
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
h := &AuthHandler{
RedisService: redis,
IdpProvider: &mockIdpProvider{},
}
app := fiber.New()
app.Post("/api/v1/auth/magic-link/verify", h.VerifyMagicLink)
body, _ := json.Marshal(map[string]any{
"token": "token-123",
"verifyOnly": true,
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Cookie", "ory_kratos_session=shared-browser-session")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Empty(t, resp.Cookies())
var got map[string]any
_ = json.NewDecoder(resp.Body).Decode(&got)
assert.Equal(t, "approved", got["status"])
assert.Nil(t, got["sessionJwt"])
assert.Nil(t, got["token"])
assert.Contains(t, redis.data[prefixSession+"pending-123"], "approved")
}
func TestResolveUserfrontURL_DevLocalhostUsesConfiguredPort(t *testing.T) {
t.Setenv("APP_ENV", "dev")
t.Setenv("USERFRONT_URL", "http://localhost:5000")
h := &AuthHandler{}
app := fiber.New()
app.Get("/probe", func(c *fiber.Ctx) error {
return c.SendString(h.resolveUserfrontURL(c))
})
req := httptest.NewRequest(http.MethodGet, "http://localhost/probe", nil)
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
body, _ := io.ReadAll(resp.Body)
assert.Equal(t, "http://localhost:5000", string(body))
}
func TestVerifyLoginCode_VerifyOnlySharedBrowserDifferentSubjectApprovesOnly(t *testing.T) {
redis := &mockRedisRepo{data: map[string]string{
prefixLoginCode + "user@example.com": "flow-123",
prefixLoginCodePending + "user@example.com": "pending-123",
prefixLoginCodeValue + "pending-123": "569765",
}}
kratosPublic := newKratosWhoamiTestServer(t, "kratos-other-user")
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
h := &AuthHandler{
RedisService: redis,
IdpProvider: &mockIdpProvider{},
}
app := fiber.New()
app.Post("/api/v1/auth/login/code/verify", h.VerifyLoginCode)
body, _ := json.Marshal(map[string]any{
"loginId": "user@example.com",
"code": "569765",
"pendingRef": "pending-123",
"verifyOnly": true,
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Cookie", "ory_kratos_session=shared-browser-session")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Empty(t, resp.Cookies())
var got map[string]any
_ = json.NewDecoder(resp.Body).Decode(&got)
assert.Equal(t, "approved", got["status"])
assert.Nil(t, got["sessionJwt"])
assert.Nil(t, got["token"])
assert.Contains(t, redis.data[prefixSession+"pending-123"], "approved")
}
func TestVerifyLoginCode_MapsSmsPhoneBeforeFlowLookup(t *testing.T) {
redis := &mockRedisRepo{data: map[string]string{
prefixLoginCode + "su-@samaneng.com": "flow-123",
prefixLoginCodePending + "su-@samaneng.com": "pending-123",
prefixLoginCodeSmsLookup + "+821041585840": "su-@samaneng.com",
prefixLoginCodeSmsTarget + "su-@samaneng.com": "+821041585840",
prefixLoginCodeValue + "pending-123": "569765",
}}
h := &AuthHandler{
RedisService: redis,
IdpProvider: &mockIdpProvider{},
}
app := fiber.New()
app.Post("/api/v1/auth/login/code/verify", h.VerifyLoginCode)
body, _ := json.Marshal(map[string]any{
"loginId": "01041585840",
"code": "569765",
"pendingRef": "pending-123",
"verifyOnly": true,
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var got map[string]any
_ = json.NewDecoder(resp.Body).Decode(&got)
assert.Equal(t, "approved", got["status"])
assert.Equal(t, "pending-123", got["pendingRef"])
}
func TestPollEnchantedLink_ExpiredToken_ReturnsCode(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
h := &AuthHandler{
RedisService: redis,
}
app := fiber.New()
app.Post("/api/v1/auth/enchanted-link/poll", h.PollEnchantedLink)
body, _ := json.Marshal(map[string]string{
"pendingRef": "missing-ref",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/enchanted-link/poll", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var got map[string]any
json.NewDecoder(resp.Body).Decode(&got)
assert.Equal(t, "expired_token", got["error"])
assert.Equal(t, "expired_token", got["code"])
}
func TestPollEnchantedLink_SharedBrowserSameSubjectIssuesSession(t *testing.T) {
redis := &mockRedisRepo{data: map[string]string{
prefixSession + "pending-123": `{"status":"approved","loginId":"user@example.com"}`,
}}
kratosPublic := newKratosWhoamiTestServer(t, "kratos-user-1")
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
h := &AuthHandler{
RedisService: redis,
IdpProvider: &mockIdpProvider{
issueSession: &domain.AuthInfo{
SessionToken: &domain.Token{JWT: "valid-jwt", SessionID: "new-session-id"},
Subject: "kratos-user-1",
},
},
}
app := fiber.New()
app.Post("/api/v1/auth/enchanted-link/poll", h.PollEnchantedLink)
body, _ := json.Marshal(map[string]string{"pendingRef": "pending-123"})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/enchanted-link/poll", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Cookie", "ory_kratos_session=shared-browser-session")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var got map[string]any
_ = json.NewDecoder(resp.Body).Decode(&got)
assert.Equal(t, "ok", got["status"])
assert.Equal(t, "valid-jwt", got["sessionJwt"])
}
func TestPollEnchantedLink_SharedBrowserDifferentSubjectConflicts(t *testing.T) {
redis := &mockRedisRepo{data: map[string]string{
prefixSession + "pending-123": `{"status":"approved","loginId":"user@example.com"}`,
}}
kratosPublic := newKratosWhoamiTestServer(t, "kratos-other-user")
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
h := &AuthHandler{
RedisService: redis,
IdpProvider: &mockIdpProvider{
issueSession: &domain.AuthInfo{
SessionToken: &domain.Token{JWT: "valid-jwt", SessionID: "new-session-id"},
Subject: "kratos-user-1",
},
},
}
app := fiber.New()
app.Post("/api/v1/auth/enchanted-link/poll", h.PollEnchantedLink)
body, _ := json.Marshal(map[string]string{"pendingRef": "pending-123"})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/enchanted-link/poll", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Cookie", "ory_kratos_session=shared-browser-session")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusConflict, resp.StatusCode)
var got map[string]any
_ = json.NewDecoder(resp.Body).Decode(&got)
assert.Equal(t, "session_subject_conflict", got["code"])
assert.NotContains(t, redis.data[prefixSession+"pending-123"], "valid-jwt")
}
func TestHeadlessLinkInit_HeadlessLoginClientSuccess(t *testing.T) {
t.Setenv("BACKEND_PUBLIC_URL", "")
if !testsupport.PortBindingAvailable() {
t.Skip("skipping headless link tests because this environment cannot bind local TCP listeners")
}
redis := &mockRedisRepo{data: make(map[string]string)}
privateKey, jwks := mustHeadlessRSAJWK(t)
jwksBody, _ := json.Marshal(jwks)
jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(jwksBody)
}))
defer jwksServer.Close()
idp := &mockIdpProvider{
userExists: true,
initiateLinkErr: domain.ErrNotSupported,
}
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet {
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "headless-login-client",
TokenEndpointAuthMethod: "none",
Metadata: map[string]any{
"status": "active",
"headless_login_enabled": true,
"headless_token_endpoint_auth_method": "private_key_jwt",
"headless_jwks_uri": jwksServer.URL + "/.well-known/jwks.json",
},
},
})
return
}
http.NotFound(w, r)
})
h := &AuthHandler{
RedisService: redis,
IdpProvider: idp,
SmsService: &mockSmsService{},
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
},
}
app := newHeadlessLinkTestApp(h)
t.Setenv("USERFRONT_URL", "http://userfront.test")
body, _ := json.Marshal(map[string]string{
"client_id": "headless-login-client",
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/init"),
"loginId": "010-1234-5678",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var got map[string]any
_ = json.NewDecoder(resp.Body).Decode(&got)
assert.NotEmpty(t, got["pendingRef"])
_, hasUserCode := got["userCode"]
assert.False(t, hasUserCode)
}
func TestHeadlessLinkPoll_AfterApprovalReturnsRedirect(t *testing.T) {
t.Setenv("BACKEND_PUBLIC_URL", "")
if !testsupport.PortBindingAvailable() {
t.Skip("skipping headless link tests because this environment cannot bind local TCP listeners")
}
redis := &mockRedisRepo{data: make(map[string]string)}
privateKey, jwks := mustHeadlessRSAJWK(t)
jwksBody, _ := json.Marshal(jwks)
idp := &mockIdpProvider{
userExists: true,
initiateLinkErr: domain.ErrNotSupported,
}
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "headless-login-client",
ClientName: "local-demo-rp",
TokenEndpointAuthMethod: "none",
Metadata: map[string]any{
"status": "active",
"headless_login_enabled": true,
"headless_token_endpoint_auth_method": "private_key_jwt",
"headless_jwks_uri": "https://rp.example.com/.well-known/jwks.json",
},
},
})
return
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
return
}
http.NotFound(w, r)
})
mockKratos := new(MockKratosAdminService)
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "+821012345678").Return("kratos-identity-id", nil)
auditRepo := &mockAuditRepo{}
headlessClient := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Host == "rp.example.com" && r.URL.Path == "/.well-known/jwks.json" {
return httpResponse(r, http.StatusOK, string(jwksBody)), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})}
h := &AuthHandler{
RedisService: redis,
IdpProvider: idp,
SmsService: &mockSmsService{},
KratosAdmin: mockKratos,
AuditRepo: auditRepo,
HeadlessJWKS: service.NewHeadlessJWKSCacheService(nil, headlessClient),
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
},
}
app := newHeadlessLinkTestApp(h)
t.Setenv("USERFRONT_URL", "http://userfront.test")
initBody, _ := json.Marshal(map[string]string{
"client_id": "headless-login-client",
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/init"),
"loginId": "010-1234-5678",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(initBody))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var initResp map[string]any
_ = json.NewDecoder(resp.Body).Decode(&initResp)
pendingRef := initResp["pendingRef"].(string)
assert.NotEmpty(t, pendingRef)
var token string
for k := range redis.data {
if len(k) > 16 && k[:16] == "enchanted_token:" {
token = k[16:]
break
}
}
assert.NotEmpty(t, token)
verifyBody, _ := json.Marshal(map[string]any{
"token": token,
"verifyOnly": true,
})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(verifyBody))
req.Header.Set("Content-Type", "application/json")
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
pollBody, _ := json.Marshal(map[string]string{
"client_id": "headless-login-client",
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/poll"),
"pendingRef": pendingRef,
})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/poll", bytes.NewReader(pollBody))
req.Header.Set("Content-Type", "application/json")
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var pollResp map[string]any
_ = json.NewDecoder(resp.Body).Decode(&pollResp)
assert.Equal(t, "http://rp/cb", pollResp["redirectTo"])
assert.Equal(t, "ok", pollResp["status"])
assert.Nil(t, pollResp["sessionJwt"])
assert.Nil(t, pollResp["token"])
assert.Empty(t, resp.Cookies())
if assert.Len(t, auditRepo.logs, 1) {
assert.Contains(t, auditRepo.logs[0].EventType, "/api/v1/auth/")
details, err := parseAuditDetails(auditRepo.logs[0].Details)
if err != nil {
t.Fatalf("failed to parse audit details: %v", err)
}
assert.Equal(t, "headless-login-client", details["client_id"])
assert.Equal(t, "local-demo-rp", details["client_name"])
assert.Equal(t, "challenge-123", details["login_challenge"])
}
}
func TestHeadlessLinkPoll_ApproverSubjectConflictBlocksMixedRP(t *testing.T) {
t.Setenv("BACKEND_PUBLIC_URL", "")
if !testsupport.PortBindingAvailable() {
t.Skip("skipping headless link tests because this environment cannot bind local TCP listeners")
}
redis := &mockRedisRepo{data: make(map[string]string)}
privateKey, jwks := mustHeadlessRSAJWK(t)
jwksBody, _ := json.Marshal(jwks)
acceptCalled := false
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "headless-login-client",
ClientName: "local-demo-rp",
TokenEndpointAuthMethod: "none",
Metadata: map[string]any{
"status": "active",
"headless_login_enabled": true,
"headless_token_endpoint_auth_method": "private_key_jwt",
"headless_jwks_uri": "https://rp.example.com/.well-known/jwks.json",
},
},
})
return
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
acceptCalled = true
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
return
}
http.NotFound(w, r)
})
mockKratos := new(MockKratosAdminService)
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "+821012345678").Return("kratos-target-b", nil)
auditRepo := &mockAuditRepo{}
headlessClient := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Host == "rp.example.com" && r.URL.Path == "/.well-known/jwks.json" {
return httpResponse(r, http.StatusOK, string(jwksBody)), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})}
h := &AuthHandler{
RedisService: redis,
IdpProvider: &mockIdpProvider{
userExists: true,
initiateLinkErr: domain.ErrNotSupported,
},
SmsService: &mockSmsService{},
KratosAdmin: mockKratos,
AuditRepo: auditRepo,
HeadlessJWKS: service.NewHeadlessJWKSCacheService(nil, headlessClient),
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
},
}
app := newHeadlessLinkTestApp(h)
t.Setenv("USERFRONT_URL", "http://userfront.test")
initBody, _ := json.Marshal(map[string]string{
"client_id": "headless-login-client",
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/init"),
"loginId": "010-1234-5678",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(initBody))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var initResp map[string]any
_ = json.NewDecoder(resp.Body).Decode(&initResp)
pendingRef := initResp["pendingRef"].(string)
assert.NotEmpty(t, pendingRef)
var token string
for k := range redis.data {
if len(k) > 16 && k[:16] == "enchanted_token:" {
token = k[16:]
break
}
}
assert.NotEmpty(t, token)
kratosPublic := newKratosWhoamiTestServer(t, "kratos-userfront-a")
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
verifyBody, _ := json.Marshal(map[string]any{
"token": token,
"verifyOnly": true,
})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(verifyBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Cookie", "ory_kratos_session=userfront-a-session")
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
pollBody, _ := json.Marshal(map[string]string{
"client_id": "headless-login-client",
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/poll"),
"pendingRef": pendingRef,
})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/poll", bytes.NewReader(pollBody))
req.Header.Set("Content-Type", "application/json")
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusConflict, resp.StatusCode)
assert.False(t, acceptCalled)
assert.Empty(t, resp.Cookies())
var got map[string]any
_ = json.NewDecoder(resp.Body).Decode(&got)
assert.Equal(t, "oidc_subject_conflict", got["code"])
assert.Equal(t, "redirect_to_userfront_login", got["recommendedAction"])
assert.Equal(t, "kratos-userfront-a", got["currentSubject"])
assert.Equal(t, "kratos-target-b", got["targetSubject"])
assert.Empty(t, auditRepo.logs)
}
func TestHeadlessLinkPoll_RequestCookieSubjectConflictBlocksMixedRP(t *testing.T) {
t.Setenv("BACKEND_PUBLIC_URL", "")
if !testsupport.PortBindingAvailable() {
t.Skip("skipping headless link tests because this environment cannot bind local TCP listeners")
}
redis := &mockRedisRepo{data: make(map[string]string)}
privateKey, jwks := mustHeadlessRSAJWK(t)
jwksBody, _ := json.Marshal(jwks)
acceptCalled := false
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
Challenge: "challenge-123",
Client: domain.HydraClient{
ClientID: "headless-login-client",
TokenEndpointAuthMethod: "none",
Metadata: map[string]any{
"status": "active",
"headless_login_enabled": true,
"headless_token_endpoint_auth_method": "private_key_jwt",
"headless_jwks_uri": "https://rp.example.com/.well-known/jwks.json",
},
},
})
return
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
acceptCalled = true
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
return
}
http.NotFound(w, r)
})
mockKratos := new(MockKratosAdminService)
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "+821012345678").Return("kratos-target-b", nil)
headlessClient := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Host == "rp.example.com" && r.URL.Path == "/.well-known/jwks.json" {
return httpResponse(r, http.StatusOK, string(jwksBody)), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})}
h := &AuthHandler{
RedisService: redis,
IdpProvider: &mockIdpProvider{
userExists: true,
initiateLinkErr: domain.ErrNotSupported,
},
SmsService: &mockSmsService{},
KratosAdmin: mockKratos,
HeadlessJWKS: service.NewHeadlessJWKSCacheService(nil, headlessClient),
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
},
}
app := newHeadlessLinkTestApp(h)
t.Setenv("USERFRONT_URL", "http://userfront.test")
initBody, _ := json.Marshal(map[string]string{
"client_id": "headless-login-client",
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/init"),
"loginId": "010-1234-5678",
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(initBody))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var initResp map[string]any
_ = json.NewDecoder(resp.Body).Decode(&initResp)
pendingRef := initResp["pendingRef"].(string)
assert.NotEmpty(t, pendingRef)
var token string
for k := range redis.data {
if len(k) > 16 && k[:16] == "enchanted_token:" {
token = k[16:]
break
}
}
assert.NotEmpty(t, token)
verifyBody, _ := json.Marshal(map[string]any{
"token": token,
"verifyOnly": true,
})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(verifyBody))
req.Header.Set("Content-Type", "application/json")
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
kratosPublic := newKratosWhoamiTestServer(t, "kratos-userfront-a")
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
pollBody, _ := json.Marshal(map[string]string{
"client_id": "headless-login-client",
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/poll"),
"pendingRef": pendingRef,
})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/poll", bytes.NewReader(pollBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Cookie", "ory_kratos_session=userfront-a-session")
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusConflict, resp.StatusCode)
assert.False(t, acceptCalled)
assert.Empty(t, resp.Cookies())
var got map[string]any
_ = json.NewDecoder(resp.Body).Decode(&got)
assert.Equal(t, "oidc_subject_conflict", got["code"])
assert.Equal(t, "kratos-userfront-a", got["currentSubject"])
assert.Equal(t, "kratos-target-b", got["targetSubject"])
}

View File

@@ -0,0 +1,287 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
)
// --- Helper ---
func newLinkedRpTestApp(h *AuthHandler) *fiber.App {
app := fiber.New()
app.Get("/api/v1/user/rp/linked", h.ListLinkedRps)
return app
}
// --- Tests ---
func TestListLinkedRps_PriorityAndAggregation(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
switch r.URL.Host {
case "kratos.test":
if r.URL.Path == "/sessions/whoami" {
if r.Header.Get("X-Session-Token") == "" && r.Header.Get("Cookie") == "" {
return httpResponse(r, http.StatusUnauthorized, "unauthorized"), nil
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@test.com",
},
},
}), nil
}
case "hydra.test":
if r.URL.Path == "/oauth2/auth/sessions/consent" {
return httpJSONAny(r, http.StatusOK, []map[string]any{
{
"client": map[string]any{
"client_id": "devfront",
"client_name": "DevFront",
"redirect_uris": []string{
"https://active.example.com/callback",
},
},
"grant_scope": []string{"openid", "profile"},
"handled_at": time.Now().Format(time.RFC3339),
},
{
"client": map[string]any{
"client_id": "orgfront",
"client_name": "OrgFront",
"metadata": map[string]any{
"auto_login_supported": true,
"auto_login_url": "http://localhost:5175/login",
},
"redirect_uris": []string{
"http://localhost:5175/auth/callback",
},
},
"grant_scope": []string{"openid", "profile"},
"handled_at": time.Now().Format(time.RFC3339),
},
}), nil
}
if r.URL.Path == "/admin/clients/client-audit" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "client-audit",
"client_name": "Audit App",
}), nil
}
if r.URL.Path == "/admin/clients/client-consent" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "client-consent",
"client_name": "Consent App",
}), nil
}
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() {
http.DefaultClient = origDefault
}()
auditRepo := &mockAuditRepo{
logs: []domain.AuditLog{
{
UserID: "user-123",
EventType: "consent.granted",
Timestamp: time.Now().Add(-10 * time.Hour),
Details: `{"client_id":"client-audit", "scopes":["audit_scope"]}`,
},
},
}
consentRepo := &mockConsentRepo{
consents: []domain.ClientConsent{
{
Subject: "user-123",
ClientID: "client-consent",
GrantedScopes: []string{"consent_scope"},
UpdatedAt: time.Now().Add(-2 * time.Hour),
},
},
}
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
AuditRepo: auditRepo,
ConsentRepo: consentRepo,
KratosAdmin: new(MockKratosAdminService),
}
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
t.Setenv("KRATOS_ADMIN_URL", "http://kratos.test")
t.Setenv("HYDRA_PUBLIC_URL", "https://sso.example.com/oidc")
t.Setenv("DEVFRONT_URL", "http://localhost:5174")
app := newLinkedRpTestApp(h)
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/rp/linked", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var res struct {
Items []struct {
ID string `json:"id"`
Name string `json:"name"`
Status string `json:"status"`
Scopes []string `json:"scopes"`
InitURL string `json:"init_url"`
AutoLoginSupported bool `json:"auto_login_supported"`
AutoLoginURL string `json:"auto_login_url"`
} `json:"items"`
}
json.NewDecoder(resp.Body).Decode(&res)
assert.Equal(t, 4, len(res.Items))
statusMap := make(map[string]string)
for _, item := range res.Items {
statusMap[item.ID] = item.Status
}
assert.Equal(t, "active", statusMap["devfront"])
assert.Equal(t, "active", statusMap["orgfront"])
assert.Equal(t, "inactive", statusMap["client-consent"])
assert.Equal(t, "inactive", statusMap["client-audit"])
var activeInitURL string
for _, item := range res.Items {
if item.ID == "devfront" {
activeInitURL = item.InitURL
break
}
}
parsedInitURL, err := url.Parse(activeInitURL)
assert.NoError(t, err)
assert.Equal(t, "http", parsedInitURL.Scheme)
assert.Equal(t, "localhost:5174", parsedInitURL.Host)
assert.Equal(t, "/login", parsedInitURL.Path)
assert.Equal(t, "1", parsedInitURL.Query().Get("auto"))
assert.Equal(t, "/clients", parsedInitURL.Query().Get("returnTo"))
var orgfrontItem struct {
InitURL string
AutoLoginSupported bool
AutoLoginURL string
}
for _, item := range res.Items {
if item.ID == "orgfront" {
orgfrontItem.InitURL = item.InitURL
orgfrontItem.AutoLoginSupported = item.AutoLoginSupported
orgfrontItem.AutoLoginURL = item.AutoLoginURL
break
}
}
assert.True(t, orgfrontItem.AutoLoginSupported)
assert.Equal(t, "http://localhost:5175/login?auto=1", orgfrontItem.AutoLoginURL)
assert.Equal(t, orgfrontItem.AutoLoginURL, orgfrontItem.InitURL)
}
func TestListLinkedRps_EnrichesLogoFromHydraClientWhenConsentSessionOmitsMetadata(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
switch r.URL.Host {
case "kratos.test":
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{
"id": "user-123",
},
}), nil
}
case "hydra.test":
if r.URL.Path == "/oauth2/auth/sessions/consent" {
return httpJSONAny(r, http.StatusOK, []map[string]any{
{
"client": map[string]any{
"client_id": "gitea-client",
"client_name": "Gitea",
"redirect_uris": []string{
"https://gitea.example.com/callback",
},
},
"grant_scope": []string{"openid", "profile"},
"handled_at": time.Now().Format(time.RFC3339),
},
}), nil
}
if r.URL.Path == "/clients/gitea-client" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "gitea-client",
"client_name": "Gitea",
"redirect_uris": []string{
"https://gitea.example.com/callback",
},
"metadata": map[string]any{
"logo_url": "https://cdn.example.com/gitea.svg",
},
}), nil
}
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() {
http.DefaultClient = origDefault
}()
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: new(MockKratosAdminService),
}
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
t.Setenv("KRATOS_ADMIN_URL", "http://kratos.test")
t.Setenv("HYDRA_PUBLIC_URL", "https://sso.example.com/oidc")
app := newLinkedRpTestApp(h)
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/rp/linked", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var res struct {
Items []struct {
ID string `json:"id"`
Logo string `json:"logo"`
} `json:"items"`
}
json.NewDecoder(resp.Body).Decode(&res)
assert.Len(t, res.Items, 1)
assert.Equal(t, "gitea-client", res.Items[0].ID)
assert.Equal(t, "https://cdn.example.com/gitea.svg", res.Items[0].Logo)
}

View File

@@ -0,0 +1,206 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
)
func newVerifyLoginCodeTestApp(h *AuthHandler) *fiber.App {
app := fiber.New()
app.Post("/api/v1/auth/login/code/verify", h.VerifyLoginCode)
app.Post("/api/v1/auth/login/code/verify-short", h.VerifyLoginShortCode)
return app
}
func decodeJSONBody(t *testing.T, resp *http.Response) map[string]any {
t.Helper()
var got map[string]any
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
t.Fatalf("failed to decode response body: %v", err)
}
return got
}
func TestVerifyLoginCode_InvalidBody_ReturnsExplicitCode(t *testing.T) {
h := &AuthHandler{}
app := newVerifyLoginCodeTestApp(h)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify", bytes.NewBufferString("{"))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.StatusCode)
}
got := decodeJSONBody(t, resp)
if got["code"] != "bad_request" {
t.Fatalf("expected code=bad_request, got %v", got["code"])
}
}
func TestVerifyLoginCode_IdpUnavailable_ReturnsExplicitCode(t *testing.T) {
h := &AuthHandler{}
app := newVerifyLoginCodeTestApp(h)
body, _ := json.Marshal(map[string]any{
"loginId": "user@example.com",
"code": "AA-111111",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusServiceUnavailable {
t.Fatalf("expected 503, got %d", resp.StatusCode)
}
got := decodeJSONBody(t, resp)
if got["code"] != "service_unavailable" {
t.Fatalf("expected code=service_unavailable, got %v", got["code"])
}
}
func TestVerifyLoginCode_VerifyOnlyInvalidCode_ReturnsExplicitCode(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
redis.data[prefixLoginCode+"user@example.com"] = "flow-1"
redis.data[prefixLoginCodePending+"user@example.com"] = "pending-1"
redis.data[prefixLoginCodeValue+"pending-1"] = "AB-123"
h := &AuthHandler{
RedisService: redis,
IdpProvider: &mockIdpProvider{},
}
app := newVerifyLoginCodeTestApp(h)
body, _ := json.Marshal(map[string]any{
"loginId": "user@example.com",
"code": "ZZ-999",
"verifyOnly": true,
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", resp.StatusCode)
}
got := decodeJSONBody(t, resp)
if got["code"] != "invalid_code" {
t.Fatalf("expected code=invalid_code, got %v", got["code"])
}
}
func TestVerifyLoginShortCode_MissingShortCode_ReturnsExplicitCode(t *testing.T) {
h := &AuthHandler{
RedisService: &mockRedisRepo{data: make(map[string]string)},
}
app := newVerifyLoginCodeTestApp(h)
body, _ := json.Marshal(map[string]any{
"shortCode": "",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify-short", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.StatusCode)
}
got := decodeJSONBody(t, resp)
if got["code"] != "bad_request" {
t.Fatalf("expected code=bad_request, got %v", got["code"])
}
}
func TestVerifyLoginShortCode_InvalidOrExpired_ReturnsExplicitCode(t *testing.T) {
h := &AuthHandler{
RedisService: &mockRedisRepo{data: make(map[string]string)},
}
app := newVerifyLoginCodeTestApp(h)
body, _ := json.Marshal(map[string]any{
"shortCode": "AB-123456",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify-short", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", resp.StatusCode)
}
got := decodeJSONBody(t, resp)
if got["code"] != "invalid_or_expired_code" {
t.Fatalf("expected code=invalid_or_expired_code, got %v", got["code"])
}
}
func TestVerifyLoginShortCode_VerifyOnlyMissingPendingRef_ReturnsExplicitCode(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
payload, _ := json.Marshal(shortLoginCodePayload{
LoginID: "user@example.com",
Code: "AB-123",
})
redis.data[prefixLoginCodeShort+"AB-123456"] = string(payload)
h := &AuthHandler{
RedisService: redis,
}
app := newVerifyLoginCodeTestApp(h)
body, _ := json.Marshal(map[string]any{
"shortCode": "AB-123456",
"verifyOnly": true,
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify-short", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.StatusCode)
}
got := decodeJSONBody(t, resp)
if got["code"] != "invalid_session_reference" {
t.Fatalf("expected code=invalid_session_reference, got %v", got["code"])
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,178 @@
package handler
import (
"baron-sso-backend/internal/service"
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
)
func newOidcLoginTestApp(h *AuthHandler) *fiber.App {
app := fiber.New()
app.Post("/api/v1/auth/oidc/login/accept", h.AcceptOidcLoginRequest)
return app
}
func TestAcceptOidcLoginRequest_CookieOnly(t *testing.T) {
var gotSubject string
var gotChallenge string
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
switch r.URL.Host {
case "kratos.test":
if r.URL.Path != "/sessions/whoami" {
return httpResponse(r, http.StatusNotFound, "not found"), nil
}
if r.Header.Get("X-Session-Token") != "" {
return httpResponse(r, http.StatusUnauthorized, "invalid token"), nil
}
if r.Header.Get("Cookie") == "" {
return httpResponse(r, http.StatusUnauthorized, "missing cookie"), nil
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{
"id": "kratos-123",
"traits": map[string]any{},
},
}), nil
case "hydra.test":
if r.URL.Path != "/oauth2/auth/requests/login/accept" {
return httpResponse(r, http.StatusNotFound, "not found"), nil
}
gotChallenge = r.URL.Query().Get("login_challenge")
body, _ := io.ReadAll(r.Body)
var payload map[string]any
_ = json.Unmarshal(body, &payload)
if subject, ok := payload["subject"].(string); ok {
gotSubject = subject
}
return httpResponse(r, http.StatusOK, `{"redirect_to":"http://rp/cb"}`), nil
default:
return httpResponse(r, http.StatusNotFound, "not found"), nil
}
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() {
http.DefaultClient = origDefault
}()
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
}
app := newOidcLoginTestApp(h)
body, _ := json.Marshal(map[string]string{
"login_challenge": "challenge-123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oidc/login/accept", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Cookie", "ory_kratos_session=abc123")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
var got map[string]string
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if got["redirectTo"] != "http://rp/cb" {
t.Fatalf("unexpected redirectTo: %v", got["redirectTo"])
}
if gotSubject != "kratos-123" {
t.Fatalf("unexpected subject: %v", gotSubject)
}
if gotChallenge != "challenge-123" {
t.Fatalf("unexpected login_challenge: %v", gotChallenge)
}
}
func TestAcceptOidcLoginRequest_TokenFallbackToCookie(t *testing.T) {
var gotSubject string
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
switch r.URL.Host {
case "kratos.test":
if r.URL.Path != "/sessions/whoami" {
return httpResponse(r, http.StatusNotFound, "not found"), nil
}
if r.Header.Get("X-Session-Token") != "" {
return httpResponse(r, http.StatusUnauthorized, "invalid token"), nil
}
if r.Header.Get("Cookie") == "" {
return httpResponse(r, http.StatusUnauthorized, "missing cookie"), nil
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{
"id": "kratos-456",
"traits": map[string]any{},
},
}), nil
case "hydra.test":
if r.URL.Path != "/oauth2/auth/requests/login/accept" {
return httpResponse(r, http.StatusNotFound, "not found"), nil
}
body, _ := io.ReadAll(r.Body)
var payload map[string]any
_ = json.Unmarshal(body, &payload)
if subject, ok := payload["subject"].(string); ok {
gotSubject = subject
}
return httpResponse(r, http.StatusOK, `{"redirect_to":"http://rp/cb"}`), nil
default:
return httpResponse(r, http.StatusNotFound, "not found"), nil
}
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() {
http.DefaultClient = origDefault
}()
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
}
app := newOidcLoginTestApp(h)
body, _ := json.Marshal(map[string]string{
"login_challenge": "challenge-456",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oidc/login/accept", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer invalid-token")
req.Header.Set("Cookie", "ory_kratos_session=def456")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
if gotSubject != "kratos-456" {
t.Fatalf("unexpected subject: %v", gotSubject)
}
}

View File

@@ -0,0 +1,110 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
)
func TestHandleKratosCourierRelay_Email(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
emailSvc := &mockEmailService{}
h := &AuthHandler{
RedisService: redis,
EmailService: emailSvc,
}
app := fiber.New()
app.Post("/api/v1/auth/kratos/courier", h.HandleKratosCourierRelay)
// Simulate Kratos Courier Request for Email
reqBody := map[string]any{
"recipient": "user@example.com",
"template_type": "verification_code",
"template_data": map[string]any{
"verification_code": "123456",
},
"subject": "Verify your email",
"body": "Your code is 123456",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/kratos/courier", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
}
func TestVerifySignupCode_Success(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
h := &AuthHandler{
RedisService: redis,
}
app := fiber.New()
app.Post("/api/v1/auth/signup/verify", h.VerifySignupCode)
// Mock stored code in redis
// signup:email:user@test.com -> {"code":"654321", "verified":false, "expires_at":...}
state := map[string]any{
"code": "654321",
"verified": false,
"expires_at": 9999999999, // far future
}
stateJSON, _ := json.Marshal(state)
redis.data["signup:email:user@test.com"] = string(stateJSON)
// Verify Code
verifyBody := map[string]string{
"type": "email",
"target": "user@test.com",
"code": "654321",
}
body, _ := json.Marshal(verifyBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/signup/verify", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var res map[string]any
json.NewDecoder(resp.Body).Decode(&res)
assert.True(t, res["success"].(bool))
// Check redis state updated to verified
val, _ := redis.Get("signup:email:user@test.com")
var updatedState map[string]any
json.Unmarshal([]byte(val), &updatedState)
assert.True(t, updatedState["verified"].(bool))
}
func TestVerifySignupCode_Invalid(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
h := &AuthHandler{
RedisService: redis,
}
app := fiber.New()
app.Post("/api/v1/auth/signup/verify", h.VerifySignupCode)
stateJSON, _ := json.Marshal(map[string]any{
"code": "111111",
"expires_at": 9999999999,
})
redis.data["signup:email:user@test.com"] = string(stateJSON)
verifyBody := map[string]string{
"type": "email",
"target": "user@test.com",
"code": "222222", // wrong code
}
body, _ := json.Marshal(verifyBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/signup/verify", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
}

View File

@@ -0,0 +1,244 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes"
"context"
"encoding/json"
"maps"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/require"
)
type recordingUpdateMeUserRepo struct {
MockUserRepoForHandler
updated *domain.User
loginIDs []domain.UserLoginID
}
func (r *recordingUpdateMeUserRepo) Update(ctx context.Context, user *domain.User) error {
copied := *user
r.updated = &copied
return nil
}
func (r *recordingUpdateMeUserRepo) UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error {
r.loginIDs = append([]domain.UserLoginID(nil), loginIDs...)
return nil
}
type recordingUpdateMeKratosAdmin struct {
MockKratosAdminService
updatedIdentityID string
updatedTraits map[string]any
updatedState string
storedTraits map[string]any
}
func (r *recordingUpdateMeKratosAdmin) UpdateIdentity(ctx context.Context, identityID string, traits map[string]any, state string) (*service.KratosIdentity, error) {
r.updatedIdentityID = identityID
r.updatedTraits = maps.Clone(traits)
r.updatedState = state
if r.storedTraits != nil {
maps.Copy(r.storedTraits, traits)
}
return &service.KratosIdentity{
ID: identityID,
Traits: traits,
State: state,
}, nil
}
func TestUpdateMe_InvalidatesProfileCacheForTokenSession(t *testing.T) {
token := "token-abc"
identityID := "user-1"
traits := map[string]any{
"email": "qa@example.com",
"name": "QA User",
"phone_number": "+821012345678",
"department": "Old Dept",
"affiliationType": "employee",
"companyCode": "",
"role": domain.RoleUser,
}
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
switch {
case r.URL.Host == "kratos.test" &&
r.URL.Path == "/sessions/whoami" &&
r.Method == http.MethodGet:
if r.Header.Get("X-Session-Token") != token {
return httpResponse(r, http.StatusUnauthorized, `{"error":"invalid token"}`), nil
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{
"id": identityID,
"traits": traits,
},
}), nil
case r.URL.Host == "kratos.test" &&
r.URL.Path == "/admin/identities/"+identityID &&
r.Method == http.MethodPut:
var payload struct {
Traits map[string]any `json:"traits"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
return httpResponse(r, http.StatusBadRequest, `{"error":"invalid body"}`), nil
}
maps.Copy(traits, payload.Traits)
return httpResponse(r, http.StatusOK, `{"ok":true}`), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
setDefaultHTTPClientForTest(t, transport)
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
t.Setenv("KRATOS_ADMIN_URL", "http://kratos.test")
redis := &mockRedisRepo{data: make(map[string]string)}
kratosAdmin := &recordingUpdateMeKratosAdmin{storedTraits: traits}
h := &AuthHandler{
RedisService: redis,
KratosAdmin: kratosAdmin,
}
app := fiber.New()
app.Get("/api/v1/user/me", h.GetMe)
app.Put("/api/v1/user/me", h.UpdateMe)
// 1) 첫 조회로 Old Dept가 캐시에 저장됨
getReq1 := httptest.NewRequest(http.MethodGet, "/api/v1/user/me", nil)
getReq1.Header.Set("Authorization", "Bearer "+token)
getResp1, err := app.Test(getReq1, -1)
require.NoError(t, err)
require.Equal(t, http.StatusOK, getResp1.StatusCode)
var profile1 map[string]any
require.NoError(t, json.NewDecoder(getResp1.Body).Decode(&profile1))
require.Equal(t, "Old Dept", profile1["department"])
// 2) 소속을 New Dept로 변경
updateBody, _ := json.Marshal(map[string]string{
"name": "QA User",
"phone": "01012345678",
"department": "New Dept",
})
updateReq := httptest.NewRequest(
http.MethodPut,
"/api/v1/user/me",
bytes.NewReader(updateBody),
)
updateReq.Header.Set("Content-Type", "application/json")
updateReq.Header.Set("Authorization", "Bearer "+token)
updateResp, err := app.Test(updateReq, -1)
require.NoError(t, err)
require.Equal(t, http.StatusOK, updateResp.StatusCode)
require.Equal(t, "New Dept", traits["department"])
require.Equal(t, identityID, kratosAdmin.updatedIdentityID)
require.Equal(t, "New Dept", kratosAdmin.updatedTraits["department"])
// 3) 새로고침 재조회 시 New Dept가 보여야 함(캐시 무효화 회귀 방지)
getReq2 := httptest.NewRequest(http.MethodGet, "/api/v1/user/me", nil)
getReq2.Header.Set("Authorization", "Bearer "+token)
getResp2, err := app.Test(getReq2, -1)
require.NoError(t, err)
require.Equal(t, http.StatusOK, getResp2.StatusCode)
var profile2 map[string]any
require.NoError(t, json.NewDecoder(getResp2.Body).Decode(&profile2))
require.Equal(t, "New Dept", profile2["department"])
}
func TestUpdateMe_SyncsLocalReadModelFields(t *testing.T) {
token := "token-sync"
identityID := "user-sync"
traits := map[string]any{
"email": "sync@example.com",
"name": "Old Name",
"phone_number": "+821012345678",
"department": "Old Dept",
"affiliationType": "employee",
"companyCode": "saman",
"tenant_id": "11111111-1111-1111-1111-111111111111",
"role": domain.RoleUser,
}
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
switch {
case r.URL.Host == "kratos.test" &&
r.URL.Path == "/sessions/whoami" &&
r.Method == http.MethodGet:
if r.Header.Get("X-Session-Token") != token {
return httpResponse(r, http.StatusUnauthorized, `{"error":"invalid token"}`), nil
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{
"id": identityID,
"traits": traits,
},
}), nil
case r.URL.Host == "kratos.test" &&
r.URL.Path == "/admin/identities/"+identityID &&
r.Method == http.MethodPut:
var payload struct {
Traits map[string]any `json:"traits"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
return httpResponse(r, http.StatusBadRequest, `{"error":"invalid body"}`), nil
}
maps.Copy(traits, payload.Traits)
return httpResponse(r, http.StatusOK, `{"ok":true}`), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
setDefaultHTTPClientForTest(t, transport)
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
t.Setenv("KRATOS_ADMIN_URL", "http://kratos.test")
redis := &mockRedisRepo{data: map[string]string{
"verify_update_phone:" + identityID + ":+821087654321": "verified",
}}
userRepo := &recordingUpdateMeUserRepo{}
kratosAdmin := &recordingUpdateMeKratosAdmin{storedTraits: traits}
h := &AuthHandler{
RedisService: redis,
UserRepo: userRepo,
KratosAdmin: kratosAdmin,
}
app := fiber.New()
app.Put("/api/v1/user/me", h.UpdateMe)
updateBody, _ := json.Marshal(map[string]any{
"name": "New Name",
"phone": "01087654321",
"department": "New Dept",
})
updateReq := httptest.NewRequest(
http.MethodPut,
"/api/v1/user/me",
bytes.NewReader(updateBody),
)
updateReq.Header.Set("Content-Type", "application/json")
updateReq.Header.Set("Authorization", "Bearer "+token)
updateResp, err := app.Test(updateReq, -1)
require.NoError(t, err)
require.Equal(t, http.StatusOK, updateResp.StatusCode)
require.Equal(t, identityID, kratosAdmin.updatedIdentityID)
require.Equal(t, "New Name", kratosAdmin.updatedTraits["name"])
require.Equal(t, "+821087654321", kratosAdmin.updatedTraits["phone_number"])
require.NotNil(t, userRepo.updated)
require.Equal(t, identityID, userRepo.updated.ID)
require.Equal(t, "sync@example.com", userRepo.updated.Email)
require.Equal(t, "New Name", userRepo.updated.Name)
require.Equal(t, "+821087654321", userRepo.updated.Phone)
require.Equal(t, "New Dept", userRepo.updated.Department)
require.Empty(t, userRepo.updated.CompanyCode)
require.NotNil(t, userRepo.updated.TenantID)
require.Equal(t, "11111111-1111-1111-1111-111111111111", *userRepo.updated.TenantID)
}

View File

@@ -0,0 +1,206 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
)
// --- Mock Redis ---
type mockRedisRepo struct {
data map[string]string
}
func (m *mockRedisRepo) Set(key, value string, ttl time.Duration) error {
if m.data == nil {
m.data = make(map[string]string)
}
m.data[key] = value
return nil
}
func (m *mockRedisRepo) Get(key string) (string, error) {
// Bypass rate limiting for tests
if strings.HasPrefix(key, "poll_meta:") {
return "", nil
}
return m.data[key], nil
}
func (m *mockRedisRepo) Delete(key string) error {
delete(m.data, key)
return nil
}
func (m *mockRedisRepo) StoreVerificationCode(phone, code string) error {
return m.Set("sms:"+phone, code, time.Minute)
}
func (m *mockRedisRepo) GetVerificationCode(phone string) (string, error) {
return m.Get("sms:" + phone)
}
func (m *mockRedisRepo) DeleteVerificationCode(phone string) error {
return m.Delete("sms:" + phone)
}
// --- Tests ---
func TestQRLoginFlow_Success(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
h := &AuthHandler{
RedisService: redis,
}
app := fiber.New()
app.Post("/api/v1/auth/qr/init", h.InitQRLogin)
app.Post("/api/v1/auth/qr/poll", h.PollQRLogin)
// 1. Init QR Login
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/qr/init", nil)
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var initResp map[string]any
json.NewDecoder(resp.Body).Decode(&initResp)
pendingRef := initResp["pendingRef"].(string)
// 2. Poll (Pending)
body, _ := json.Marshal(map[string]string{"pendingRef": pendingRef})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/qr/poll", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ = app.Test(req, -1)
// Expect authorization_pending (400)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
var pollResp map[string]any
json.NewDecoder(resp.Body).Decode(&pollResp)
assert.Equal(t, "authorization_pending", pollResp["error"])
assert.Equal(t, "authorization_pending", pollResp["code"])
// 3. Mock Approval
sessionData, _ := json.Marshal(map[string]string{
"status": "success",
"jwt": "mock-session-jwt",
})
redis.data["enchanted_session:"+pendingRef] = string(sessionData)
// 4. Poll (Success)
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/qr/poll", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var successResp map[string]any
json.NewDecoder(resp.Body).Decode(&successResp)
assert.Equal(t, "ok", successResp["status"])
assert.Equal(t, "mock-session-jwt", successResp["sessionJwt"])
}
func TestScanQRLogin_Success(t *testing.T) {
redis := &mockRedisRepo{data: make(map[string]string)}
idp := &mockIdpProvider{userExists: true}
h := &AuthHandler{
RedisService: redis,
IdpProvider: idp,
}
app := fiber.New()
app.Post("/api/v1/auth/qr/approve", h.ScanQRLogin)
pendingRef := "test-ref"
redis.data["enchanted_session:"+pendingRef] = `{"status":"pending"}`
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@example.com",
},
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
origDefault := http.DefaultClient
http.DefaultClient = &http.Client{Transport: transport}
defer func() { http.DefaultClient = origDefault }()
body, _ := json.Marshal(map[string]string{
"pendingRef": pendingRef,
"token": "valid-token",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/qr/approve", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
}
func TestResolveConsentSubjects_TokenAndCookie(t *testing.T) {
h := &AuthHandler{}
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.Header.Get("X-Session-Token") == "token-123" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{
"id": "user-token",
"traits": map[string]any{
"email": "token@test.com",
},
},
}), nil
}
if r.Header.Get("Cookie") == "ory_kratos_session=cookie-123" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{
"id": "user-cookie",
"traits": map[string]any{
"email": "cookie@test.com",
"phone": "010-1234-5678",
},
},
}), nil
}
return httpResponse(r, http.StatusUnauthorized, "unauthorized"), nil
})
origDefault := http.DefaultClient
http.DefaultClient = &http.Client{Transport: transport}
defer func() { http.DefaultClient = origDefault }()
app := fiber.New()
// Token case
app.Get("/test-token", func(c *fiber.Ctx) error {
subjects, err := h.resolveConsentSubjects(c)
assert.NoError(t, err)
assert.Contains(t, subjects, "user-token")
return c.SendStatus(200)
})
req := httptest.NewRequest("GET", "/test-token", nil)
req.Header.Set("Authorization", "Bearer token-123")
app.Test(req, -1)
// Cookie case
app.Get("/test-cookie", func(c *fiber.Ctx) error {
subjects, err := h.resolveConsentSubjects(c)
assert.NoError(t, err)
assert.Contains(t, subjects, "user-cookie")
return c.SendStatus(200)
})
req = httptest.NewRequest("GET", "/test-cookie", nil)
req.Header.Set("Cookie", "ory_kratos_session=cookie-123")
app.Test(req, -1)
}

View File

@@ -0,0 +1,105 @@
package handler
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/require"
)
func TestGetMe_IncludesSessionAuthenticatedAtFromKratosSession(t *testing.T) {
const (
token = "token-session"
identityID = "user-session"
sessionAuthenticated = "2026-03-23T15:30:00Z"
)
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Host == "kratos.test" &&
r.URL.Path == "/sessions/whoami" &&
r.Method == http.MethodGet {
require.Equal(t, token, r.Header.Get("X-Session-Token"))
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "kratos-session-1",
"authenticated_at": sessionAuthenticated,
"identity": map[string]any{
"id": identityID,
"traits": map[string]any{
"email": "qa@example.com",
"name": "QA User",
"department": "Platform",
"affiliationType": "GENERAL",
},
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
setDefaultHTTPClientForTest(t, transport)
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
h := &AuthHandler{}
app := fiber.New()
app.Get("/api/v1/user/me", h.GetMe)
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/me", nil)
req.Header.Set("Authorization", "Bearer "+token)
resp, err := app.Test(req, -1)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
var profile map[string]any
require.NoError(t, json.NewDecoder(resp.Body).Decode(&profile))
require.Equal(t, sessionAuthenticated, profile["sessionAuthenticatedAt"])
}
func TestGetMe_IncludesSessionAuthenticatedAtForCookieSession(t *testing.T) {
const (
cookieHeader = "ory_kratos_session=session-cookie"
identityID = "user-cookie"
sessionAuthenticated = "2026-03-24T01:20:00Z"
)
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Host == "kratos.test" &&
r.URL.Path == "/sessions/whoami" &&
r.Method == http.MethodGet {
require.Equal(t, cookieHeader, r.Header.Get("Cookie"))
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "kratos-session-cookie",
"authenticated_at": sessionAuthenticated,
"identity": map[string]any{
"id": identityID,
"traits": map[string]any{
"email": "cookie@example.com",
"name": "Cookie User",
"department": "Platform",
"affiliationType": "GENERAL",
},
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
setDefaultHTTPClientForTest(t, transport)
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
h := &AuthHandler{}
app := fiber.New()
app.Get("/api/v1/user/me", h.GetMe)
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/me", nil)
req.Header.Set("Cookie", cookieHeader)
resp, err := app.Test(req, -1)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
var profile map[string]any
require.NoError(t, json.NewDecoder(resp.Body).Decode(&profile))
require.Equal(t, sessionAuthenticated, profile["sessionAuthenticatedAt"])
}

View File

@@ -0,0 +1,944 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func TestListMySessions_Success(t *testing.T) {
now := time.Date(2026, 4, 2, 1, 2, 3, 0, time.UTC)
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "current-sid",
"authenticated_at": now.Format(time.RFC3339),
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@example.com",
"name": "User",
"role": "user",
},
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
}))
mockKratos := new(MockKratosAdminService)
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
{
ID: "current-sid",
Active: true,
AuthenticatedAt: now,
ExpiresAt: now.Add(24 * time.Hour),
},
{
ID: "other-sid",
Active: true,
AuthenticatedAt: now.Add(-2 * time.Hour),
ExpiresAt: now.Add(22 * time.Hour),
},
}, nil).Once()
auditRepo := &mockAuditRepo{
logs: []domain.AuditLog{
{
UserID: "user-123",
EventType: "login_success",
SessionID: "other-sid",
Timestamp: now.Add(-30 * time.Minute),
IPAddress: "203.0.113.10",
UserAgent: "Mozilla/5.0",
},
},
}
h := &AuthHandler{
KratosAdmin: mockKratos,
AuditRepo: auditRepo,
}
app := fiber.New()
app.Get("/api/v1/user/sessions", h.ListMySessions)
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/sessions", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var body struct {
Items []struct {
SessionID string `json:"session_id"`
IsCurrent bool `json:"is_current"`
IsActive bool `json:"is_active"`
IPAddress string `json:"ip_address"`
UserAgent string `json:"user_agent"`
} `json:"items"`
}
err = json.NewDecoder(resp.Body).Decode(&body)
assert.NoError(t, err)
if assert.Len(t, body.Items, 2) {
assert.Equal(t, "current-sid", body.Items[0].SessionID)
assert.True(t, body.Items[0].IsCurrent)
assert.Equal(t, "other-sid", body.Items[1].SessionID)
assert.True(t, body.Items[1].IsActive)
assert.Equal(t, "203.0.113.10", body.Items[1].IPAddress)
assert.Equal(t, "Mozilla/5.0", body.Items[1].UserAgent)
}
mockKratos.AssertExpectations(t)
}
func TestListMySessions_UsesConsentGrantForAppName(t *testing.T) {
now := time.Date(2026, 4, 2, 4, 40, 0, 0, time.UTC)
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "current-sid",
"authenticated_at": now.Format(time.RFC3339),
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@example.com",
"name": "User",
"role": "user",
},
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
}))
mockKratos := new(MockKratosAdminService)
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
{
ID: "current-sid",
Active: true,
AuthenticatedAt: now,
ExpiresAt: now.Add(24 * time.Hour),
},
{
ID: "c7c721ea-session",
Active: true,
AuthenticatedAt: now.Add(-5 * time.Minute),
ExpiresAt: now.Add(23*time.Hour + 55*time.Minute),
},
}, nil).Once()
auditRepo := &mockAuditRepo{
logs: []domain.AuditLog{
{
UserID: "user-123",
EventType: "consent.granted",
SessionID: "c7c721ea-session",
Timestamp: now,
Details: `{"client_id":"devfront","client_name":"DevFront","session_id":"c7c721ea-session","approved_session_id":"c7c721ea-session"}`,
},
},
}
h := &AuthHandler{
KratosAdmin: mockKratos,
AuditRepo: auditRepo,
}
app := fiber.New()
app.Get("/api/v1/user/sessions", h.ListMySessions)
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/sessions", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var body struct {
Items []struct {
SessionID string `json:"session_id"`
AppName string `json:"app_name"`
ClientID string `json:"client_id"`
} `json:"items"`
}
err = json.NewDecoder(resp.Body).Decode(&body)
assert.NoError(t, err)
if assert.Len(t, body.Items, 2) {
assert.Equal(t, "c7c721ea-session", body.Items[1].SessionID)
assert.Equal(t, "DevFront", body.Items[1].AppName)
assert.Equal(t, "devfront", body.Items[1].ClientID)
}
mockKratos.AssertExpectations(t)
}
func TestListMySessions_PreservesAppNameFromOlderConsentGrant(t *testing.T) {
now := time.Date(2026, 4, 2, 4, 40, 0, 0, time.UTC)
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "current-sid",
"authenticated_at": now.Format(time.RFC3339),
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@example.com",
"name": "User",
"role": "user",
},
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
}))
mockKratos := new(MockKratosAdminService)
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
{
ID: "current-sid",
Active: true,
AuthenticatedAt: now,
ExpiresAt: now.Add(24 * time.Hour),
},
{
ID: "c7c721ea-session",
Active: true,
AuthenticatedAt: now.Add(-5 * time.Minute),
ExpiresAt: now.Add(23*time.Hour + 55*time.Minute),
},
}, nil).Once()
auditRepo := &mockAuditRepo{
logs: []domain.AuditLog{
{
UserID: "user-123",
EventType: "consent.granted",
SessionID: "c7c721ea-session",
Timestamp: now.Add(-30 * time.Second),
IPAddress: "203.0.113.10",
Details: `{"client_id":"devfront","client_name":"DevFront","session_id":"c7c721ea-session"}`,
},
{
UserID: "user-123",
EventType: "login_success",
SessionID: "c7c721ea-session",
Timestamp: now,
IPAddress: "10.0.0.12",
UserAgent: "Mozilla/5.0",
},
},
}
h := &AuthHandler{
KratosAdmin: mockKratos,
AuditRepo: auditRepo,
}
app := fiber.New()
app.Get("/api/v1/user/sessions", h.ListMySessions)
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/sessions", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var body struct {
Items []struct {
SessionID string `json:"session_id"`
AppName string `json:"app_name"`
ClientID string `json:"client_id"`
IPAddress string `json:"ip_address"`
} `json:"items"`
}
err = json.NewDecoder(resp.Body).Decode(&body)
assert.NoError(t, err)
if assert.Len(t, body.Items, 2) {
assert.Equal(t, "c7c721ea-session", body.Items[1].SessionID)
assert.Equal(t, "DevFront", body.Items[1].AppName)
assert.Equal(t, "devfront", body.Items[1].ClientID)
assert.Equal(t, "203.0.113.10", body.Items[1].IPAddress)
}
mockKratos.AssertExpectations(t)
}
func TestListMySessions_CurrentSessionFallsBackToRequestMetadata(t *testing.T) {
now := time.Date(2026, 4, 6, 1, 2, 3, 0, time.UTC)
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "current-sid",
"authenticated_at": now.Format(time.RFC3339),
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@example.com",
"name": "User",
"role": "user",
},
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
}))
mockKratos := new(MockKratosAdminService)
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
{
ID: "current-sid",
Active: true,
AuthenticatedAt: now,
ExpiresAt: now.Add(24 * time.Hour),
},
}, nil).Once()
h := &AuthHandler{
KratosAdmin: mockKratos,
AuditRepo: &mockAuditRepo{},
}
app := fiber.New()
app.Get("/api/v1/user/sessions", h.ListMySessions)
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/sessions", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
req.Header.Set("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) Chrome/146.0.0.0 Safari/537.36")
req.Header.Set("X-Forwarded-For", "100.100.100.1, 203.0.113.25")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var body struct {
Items []struct {
SessionID string `json:"session_id"`
IsCurrent bool `json:"is_current"`
IPAddress string `json:"ip_address"`
UserAgent string `json:"user_agent"`
ClientID string `json:"client_id"`
AppName string `json:"app_name"`
} `json:"items"`
}
err = json.NewDecoder(resp.Body).Decode(&body)
assert.NoError(t, err)
if assert.Len(t, body.Items, 1) {
assert.Equal(t, "current-sid", body.Items[0].SessionID)
assert.True(t, body.Items[0].IsCurrent)
assert.Equal(t, "203.0.113.25", body.Items[0].IPAddress)
assert.Contains(t, body.Items[0].UserAgent, "Mozilla/5.0")
assert.Equal(t, "userfront", body.Items[0].ClientID)
assert.Equal(t, "UserFront", body.Items[0].AppName)
}
mockKratos.AssertExpectations(t)
}
func TestDeleteMySession_Success(t *testing.T) {
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
var hydraRevokeCalls int
client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
switch r.URL.Host {
case "kratos.test":
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "current-sid",
"authenticated_at": time.Now().UTC().Format(time.RFC3339),
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@example.com",
"name": "User",
"role": "user",
},
},
}), nil
}
case "hydra.test":
if r.Method == http.MethodDelete && r.URL.Path == "/oauth2/auth/sessions/consent" {
if r.URL.Query().Get("subject") != "user-123" {
t.Fatalf("unexpected revoke subject: %s", r.URL.Query().Get("subject"))
}
if r.URL.Query().Get("client") != "devfront" {
t.Fatalf("unexpected revoke client: %s", r.URL.Query().Get("client"))
}
hydraRevokeCalls++
return httpResponse(r, http.StatusNoContent, ""), nil
}
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})}
setDefaultHTTPClientForTest(t, client.Transport)
mockKratos := new(MockKratosAdminService)
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
{ID: "target-sid", Active: true},
}, nil).Once()
mockKratos.On("GetSession", mock.Anything, "target-sid").Return(&service.KratosSession{
ID: "target-sid",
Active: true,
}, nil).Once()
mockKratos.On("DeleteSession", mock.Anything, "target-sid").Return(nil).Once()
auditRepo := &mockAuditRepo{}
h := &AuthHandler{
KratosAdmin: mockKratos,
AuditRepo: auditRepo,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
}
auditRepo.logs = append(auditRepo.logs, domain.AuditLog{
UserID: "user-123",
EventType: "POST /api/v1/auth/oidc/login/accept",
SessionID: "target-sid",
Details: `{"client_id":"devfront","client_name":"Devfront"}`,
})
app := fiber.New()
app.Delete("/api/v1/user/sessions/:id", h.DeleteMySession)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/sessions/target-sid", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
req.Header.Set("User-Agent", "session-test-agent")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
if assert.Len(t, auditRepo.logs, 2) {
assert.Equal(t, "session.revoked", auditRepo.logs[len(auditRepo.logs)-1].EventType)
assert.Equal(t, "user-123", auditRepo.logs[len(auditRepo.logs)-1].UserID)
assert.Equal(t, "current-sid", auditRepo.logs[len(auditRepo.logs)-1].SessionID)
assert.Contains(t, auditRepo.logs[len(auditRepo.logs)-1].Details, "target-sid")
}
assert.Equal(t, 1, hydraRevokeCalls)
mockKratos.AssertExpectations(t)
}
func TestDeleteMySession_DoesNotRevokeAllHydraSessionsWhenClientBindingMissing(t *testing.T) {
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
var hydraRevokeCalls int
client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
switch r.URL.Host {
case "kratos.test":
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "current-sid",
"authenticated_at": time.Now().UTC().Format(time.RFC3339),
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@example.com",
"name": "User",
"role": "user",
},
},
}), nil
}
case "hydra.test":
if r.Method == http.MethodDelete && r.URL.Path == "/oauth2/auth/sessions/consent" {
hydraRevokeCalls++
return httpResponse(r, http.StatusNoContent, ""), nil
}
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})}
setDefaultHTTPClientForTest(t, client.Transport)
mockKratos := new(MockKratosAdminService)
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
{ID: "target-sid", Active: true},
}, nil).Once()
mockKratos.On("GetSession", mock.Anything, "target-sid").Return(&service.KratosSession{
ID: "target-sid",
Active: true,
}, nil).Once()
mockKratos.On("DeleteSession", mock.Anything, "target-sid").Return(nil).Once()
auditRepo := &mockAuditRepo{}
h := &AuthHandler{
KratosAdmin: mockKratos,
AuditRepo: auditRepo,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
}
app := fiber.New()
app.Delete("/api/v1/user/sessions/:id", h.DeleteMySession)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/sessions/target-sid", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
req.Header.Set("User-Agent", "session-test-agent")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, 0, hydraRevokeCalls)
if assert.Len(t, auditRepo.logs, 1) {
assert.Equal(t, "session.revoked", auditRepo.logs[0].EventType)
assert.Equal(t, "user-123", auditRepo.logs[0].UserID)
assert.Contains(t, auditRepo.logs[0].Details, "target-sid")
}
mockKratos.AssertExpectations(t)
}
func TestDeleteMySession_SendsBackchannelLogoutTokenWhenClientConfigured(t *testing.T) {
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
t.Setenv("BACKCHANNEL_LOGOUT_ISSUER", "https://sso.example.com/oidc")
var receivedBody string
client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
switch r.URL.Host {
case "kratos.test":
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "current-sid",
"authenticated_at": time.Now().UTC().Format(time.RFC3339),
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@example.com",
"name": "User",
"role": "user",
},
},
}), nil
}
case "hydra.test":
if r.Method == http.MethodDelete && r.URL.Path == "/oauth2/auth/sessions/consent" {
return httpResponse(r, http.StatusNoContent, ""), nil
}
if r.Method == http.MethodGet && r.URL.Path == "/clients/devfront" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "devfront",
"backchannel_logout_uri": "https://rp.example.com/backchannel-logout",
}), nil
}
case "rp.example.com":
if r.Method == http.MethodPost && r.URL.Path == "/backchannel-logout" {
raw, _ := io.ReadAll(r.Body)
receivedBody = string(raw)
return httpResponse(r, http.StatusNoContent, ""), nil
}
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})}
setDefaultHTTPClientForTest(t, client.Transport)
mockKratos := new(MockKratosAdminService)
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
{ID: "target-sid", Active: true},
}, nil).Once()
mockKratos.On("GetSession", mock.Anything, "target-sid").Return(&service.KratosSession{
ID: "target-sid",
Active: true,
}, nil).Once()
mockKratos.On("DeleteSession", mock.Anything, "target-sid").Return(nil).Once()
backchannelLogout, err := service.NewBackchannelLogoutService()
assert.NoError(t, err)
backchannelLogout.HTTPClient = client
auditRepo := &mockAuditRepo{}
h := &AuthHandler{
KratosAdmin: mockKratos,
AuditRepo: auditRepo,
BackchannelLogout: backchannelLogout,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
}
auditRepo.logs = append(auditRepo.logs, domain.AuditLog{
UserID: "user-123",
EventType: "POST /api/v1/auth/oidc/login/accept",
SessionID: "target-sid",
Details: `{"client_id":"devfront","client_name":"Devfront"}`,
})
app := fiber.New()
app.Delete("/api/v1/user/sessions/:id", h.DeleteMySession)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/sessions/target-sid", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
req.Header.Set("User-Agent", "session-test-agent")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.True(t, strings.Contains(receivedBody, "logout_token="))
values, err := url.ParseQuery(receivedBody)
assert.NoError(t, err)
assert.NotEmpty(t, values.Get("logout_token"))
foundBackchannelAudit := false
for _, log := range auditRepo.logs {
if log.EventType == "backchannel_logout.sent" {
foundBackchannelAudit = true
break
}
}
assert.True(t, foundBackchannelAudit)
mockKratos.AssertExpectations(t)
}
func TestDeleteMySession_RevokesHydraClientBoundFromPasswordLoginAudit(t *testing.T) {
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
var hydraRevokeCalls int
var revokedClient string
client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
switch r.URL.Host {
case "kratos.test":
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "current-sid",
"authenticated_at": time.Now().UTC().Format(time.RFC3339),
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@example.com",
"name": "User",
"role": "user",
},
},
}), nil
}
case "hydra.test":
if r.Method == http.MethodDelete && r.URL.Path == "/oauth2/auth/sessions/consent" {
revokedClient = r.URL.Query().Get("client")
hydraRevokeCalls++
return httpResponse(r, http.StatusNoContent, ""), nil
}
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})}
setDefaultHTTPClientForTest(t, client.Transport)
mockKratos := new(MockKratosAdminService)
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
{ID: "target-sid", Active: true},
}, nil).Once()
mockKratos.On("GetSession", mock.Anything, "target-sid").Return(&service.KratosSession{
ID: "target-sid",
Active: true,
}, nil).Once()
mockKratos.On("DeleteSession", mock.Anything, "target-sid").Return(nil).Once()
auditRepo := &mockAuditRepo{}
h := &AuthHandler{
KratosAdmin: mockKratos,
AuditRepo: auditRepo,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
}
auditRepo.logs = append(auditRepo.logs, domain.AuditLog{
UserID: "user-123",
EventType: "POST /api/v1/auth/password/login",
SessionID: "target-sid",
Details: `{"client_id":"adminfront","client_name":"AdminFront","session_id":"target-sid"}`,
})
app := fiber.New()
app.Delete("/api/v1/user/sessions/:id", h.DeleteMySession)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/sessions/target-sid", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
req.Header.Set("User-Agent", "session-test-agent")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, 1, hydraRevokeCalls)
assert.Equal(t, "adminfront", revokedClient)
mockKratos.AssertExpectations(t)
}
func TestGetHydraProfile_RejectsInactiveLinkedSession(t *testing.T) {
client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Host == "hydra.test" && r.URL.Path == "/oauth2/introspect" {
body, _ := io.ReadAll(r.Body)
if string(body) != "token=opaque-token" {
t.Fatalf("unexpected introspect body: %s", string(body))
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"active": true,
"sub": "user-123",
"client_id": "devfront",
"ext": map[string]any{
"session_id": "target-sid",
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})}
mockKratos := new(MockKratosAdminService)
mockKratos.On("GetSession", mock.Anything, "target-sid").Return(&service.KratosSession{
ID: "target-sid",
Active: false,
Identity: &service.KratosIdentity{
ID: "user-123",
},
}, nil).Once()
h := &AuthHandler{
KratosAdmin: mockKratos,
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
}
profile, err := h.getHydraProfile(context.Background(), "opaque-token")
assert.Nil(t, profile)
assert.Error(t, err)
assert.Contains(t, err.Error(), "inactive")
mockKratos.AssertExpectations(t)
}
func TestGetAuthTimeline_FillsSessionIDFromOathkeeperRaw(t *testing.T) {
now := time.Date(2026, 4, 7, 4, 39, 0, 0, time.UTC)
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "current-sid",
"authenticated_at": now.Format(time.RFC3339),
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@example.com",
"name": "User",
"role": "user",
},
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
}))
h := &AuthHandler{
AuditRepo: &mockAuditRepo{},
OathkeeperRepo: &mockOathkeeperRepo{
logs: []domain.OathkeeperAccessLog{
{
Timestamp: now,
RequestID: "req-1",
Method: http.MethodGet,
Path: "/api/v1/dev/sessions",
Status: http.StatusOK,
Subject: "user-123",
ClientIP: "203.0.113.7",
UserAgent: "Mozilla/5.0",
Raw: `{"request":{"url":"https://devfront.example.com/callback?client_id=devfront"},"extra":{"session_id":"target-sid"}}`,
},
},
},
}
app := fiber.New()
app.Get("/api/v1/audit/auth/timeline", h.GetAuthTimeline)
req := httptest.NewRequest(http.MethodGet, "/api/v1/audit/auth/timeline", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var body struct {
Items []struct {
SessionID string `json:"session_id"`
ClientID string `json:"client_id"`
AppName string `json:"app_name"`
Source string `json:"source"`
} `json:"items"`
}
err = json.NewDecoder(resp.Body).Decode(&body)
assert.NoError(t, err)
if assert.Len(t, body.Items, 1) {
assert.Equal(t, "target-sid", body.Items[0].SessionID)
assert.Equal(t, "devfront", body.Items[0].ClientID)
assert.Equal(t, "devfront", body.Items[0].AppName)
assert.Equal(t, "oathkeeper", body.Items[0].Source)
}
}
func TestGetAuthTimeline_IncludesHeadlessPasswordLogin(t *testing.T) {
now := time.Date(2026, 4, 7, 5, 10, 0, 0, time.UTC)
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "current-sid",
"authenticated_at": now.Format(time.RFC3339),
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@example.com",
"name": "User",
"role": "user",
},
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
}))
h := &AuthHandler{
AuditRepo: &mockAuditRepo{
logs: []domain.AuditLog{
{
EventID: "audit-1",
Timestamp: now,
UserID: "user-123",
SessionID: "headless-session-1",
EventType: "POST /api/v1/auth/headless/password/login",
Status: "success",
IPAddress: "203.0.113.20",
UserAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/146.0.0.0 Safari/537.36",
Details: `{"client_id":"headless-login-client","client_name":"Headless Login Portal","session_id":"headless-session-1","login_id":"user@example.com","login_challenge":"challenge-123"}`,
},
},
},
}
app := fiber.New()
app.Get("/api/v1/audit/auth/timeline", h.GetAuthTimeline)
req := httptest.NewRequest(http.MethodGet, "/api/v1/audit/auth/timeline", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var body struct {
Items []struct {
SessionID string `json:"session_id"`
ClientID string `json:"client_id"`
AppName string `json:"app_name"`
AuthMethod string `json:"auth_method"`
EventType string `json:"event_type"`
} `json:"items"`
}
err = json.NewDecoder(resp.Body).Decode(&body)
assert.NoError(t, err)
if assert.Len(t, body.Items, 1) {
assert.Equal(t, "headless-session-1", body.Items[0].SessionID)
assert.Equal(t, "headless-login-client", body.Items[0].ClientID)
assert.Equal(t, "Headless Login Portal", body.Items[0].AppName)
assert.Equal(t, "비밀번호(Email)", body.Items[0].AuthMethod)
assert.Equal(t, "POST /api/v1/auth/headless/password/login", body.Items[0].EventType)
}
}
func TestListMySessions_UsesHeadlessPasswordLoginForClientBinding(t *testing.T) {
now := time.Date(2026, 4, 7, 5, 35, 0, 0, time.UTC)
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/sessions/whoami" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"id": "current-sid",
"authenticated_at": now.Format(time.RFC3339),
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@example.com",
"name": "User",
"role": "user",
},
},
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
}))
mockKratos := new(MockKratosAdminService)
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
{
ID: "current-sid",
Active: true,
AuthenticatedAt: now,
ExpiresAt: now.Add(24 * time.Hour),
},
{
ID: "headless-session-1",
Active: true,
AuthenticatedAt: now.Add(-10 * time.Minute),
ExpiresAt: now.Add(23*time.Hour + 50*time.Minute),
},
}, nil).Once()
auditRepo := &mockAuditRepo{
logs: []domain.AuditLog{
{
UserID: "user-123",
EventType: "POST /api/v1/auth/headless/password/login",
SessionID: "headless-session-1",
Timestamp: now,
IPAddress: "203.0.113.20",
UserAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/146.0.0.0 Safari/537.36",
Details: `{"client_id":"headless-login-client","client_name":"Headless Login Portal","session_id":"headless-session-1"}`,
},
},
}
h := &AuthHandler{
KratosAdmin: mockKratos,
AuditRepo: auditRepo,
}
app := fiber.New()
app.Get("/api/v1/user/sessions", h.ListMySessions)
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/sessions", nil)
req.Header.Set("Cookie", "ory_kratos_session=valid")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var body struct {
Items []struct {
SessionID string `json:"session_id"`
AppName string `json:"app_name"`
ClientID string `json:"client_id"`
IPAddress string `json:"ip_address"`
UserAgent string `json:"user_agent"`
} `json:"items"`
}
err = json.NewDecoder(resp.Body).Decode(&body)
assert.NoError(t, err)
if assert.Len(t, body.Items, 2) {
assert.Equal(t, "headless-session-1", body.Items[1].SessionID)
assert.Equal(t, "Headless Login Portal", body.Items[1].AppName)
assert.Equal(t, "headless-login-client", body.Items[1].ClientID)
assert.Equal(t, "203.0.113.20", body.Items[1].IPAddress)
assert.Equal(t, "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/146.0.0.0 Safari/537.36", body.Items[1].UserAgent)
}
mockKratos.AssertExpectations(t)
}

View File

@@ -0,0 +1,144 @@
package handler
import (
"baron-sso-backend/internal/domain"
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
// --- Local Mocks for Signup Test ---
type MockRedisForSignup struct {
mock.Mock
}
func (m *MockRedisForSignup) Set(key string, value string, ttl time.Duration) error {
return m.Called(key, value, ttl).Error(0)
}
func (m *MockRedisForSignup) Get(key string) (string, error) {
args := m.Called(key)
return args.String(0), args.Error(1)
}
func (m *MockRedisForSignup) Delete(key string) error {
return m.Called(key).Error(0)
}
func (m *MockRedisForSignup) StoreVerificationCode(phone, code string) error { return nil }
func (m *MockRedisForSignup) GetVerificationCode(phone string) (string, error) { return "", nil }
func (m *MockRedisForSignup) DeleteVerificationCode(phone string) error { return nil }
func (m *MockRedisForSignup) Ping(ctx context.Context) error { return nil }
type MockIdpForSignup struct {
mock.Mock
}
func (m *MockIdpForSignup) Name() string { return "mock-idp" }
func (m *MockIdpForSignup) GetMetadata() (*domain.IDPMetadata, error) {
return &domain.IDPMetadata{SupportedFields: []string{"email", "name", "phoneNumber", "grade", "department"}}, nil
}
func (m *MockIdpForSignup) CreateUser(user *domain.BrokerUser, password string) (string, error) {
args := m.Called(user, password)
return args.String(0), args.Error(1)
}
func (m *MockIdpForSignup) SignIn(loginID, password string) (*domain.AuthInfo, error) {
return nil, nil
}
func (m *MockIdpForSignup) UserExists(loginID string) (bool, error) { return false, nil }
func (m *MockIdpForSignup) IssueSession(loginID string) (*domain.AuthInfo, error) { return nil, nil }
func (m *MockIdpForSignup) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
return nil, nil
}
func (m *MockIdpForSignup) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
return nil, nil
}
func (m *MockIdpForSignup) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
return &domain.PasswordPolicy{MinLength: 12}, nil
}
func (m *MockIdpForSignup) InitiatePasswordReset(loginID, redirectUrl string) error { return nil }
func (m *MockIdpForSignup) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
return nil, nil
}
func (m *MockIdpForSignup) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
return nil
}
func TestSignup_TenantSlugValidation(t *testing.T) {
app := fiber.New()
mockTenantSvc := new(MockTenantService)
mockRedis := new(MockRedisForSignup)
mockIdp := new(MockIdpForSignup)
h := &AuthHandler{
TenantService: mockTenantSvc,
RedisService: mockRedis,
IdpProvider: mockIdp,
}
app.Post("/signup", h.Signup)
// Prepare mock state (already verified email/phone)
verifiedState, _ := json.Marshal(map[string]any{
"verified": true,
"expires_at": time.Now().Add(time.Hour).Unix(),
})
mockRedis.On("Get", mock.Anything).Return(string(verifiedState), nil)
t.Run("Rejects legacy CompanyCode", func(t *testing.T) {
reqBody := domain.SignupRequest{
Email: "user@gmail.com",
Password: "StrongPass123!",
Name: "Test User",
Phone: "010-1234-5678",
TermsAccepted: true,
CompanyCode: "new-slug",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/signup", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
})
t.Run("Active Tenant Slug", func(t *testing.T) {
reqBody := domain.SignupRequest{
Email: "user@hanmaceng.co.kr",
Password: "StrongPass123!",
Name: "Test User",
Phone: "010-1234-5678",
TermsAccepted: true,
TenantSlug: "hanmac",
}
body, _ := json.Marshal(reqBody)
validTenant := &domain.Tenant{ID: "t1", Slug: "hanmac", Status: domain.TenantStatusActive}
mockTenantSvc.On("GetTenantByDomain", mock.Anything, "hanmaceng.co.kr").Return(&domain.Tenant{Slug: "hanmac"}, nil).Once()
mockTenantSvc.On("ProvisionTenantByDomain", mock.Anything, "hanmaceng.co.kr").Return(validTenant, nil).Maybe()
mockTenantSvc.On("GetTenantBySlug", mock.Anything, "hanmac").Return(validTenant, nil).Once()
mockTenantSvc.On("GetTenant", mock.Anything, "t1").Return(validTenant, nil).Once()
mockIdp.On("CreateUser", mock.Anything, mock.Anything).Return("user-id", nil).Once()
mockRedis.On("Delete", mock.Anything).Return(nil)
req := httptest.NewRequest("POST", "/signup", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req)
assert.Equal(t, http.StatusOK, resp.StatusCode)
})
}

View File

@@ -0,0 +1,521 @@
package handler
import (
"baron-sso-backend/internal/middleware"
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gofiber/fiber/v2"
)
// helper to build a Fiber app with the handler route mounted.
func newTestApp(h *AuthHandler) *fiber.App {
app := fiber.New()
app.Post("/api/v1/auth/password/reset/complete", h.CompletePasswordReset)
return app
}
func newResetFlowTestApp(h *AuthHandler) *fiber.App {
app := fiber.New()
app.Post("/api/v1/auth/password/reset/verify", h.ProcessPasswordResetToken)
app.Post("/api/v1/auth/password/reset/complete", h.CompletePasswordReset)
return app
}
func newResetInitAppWithErrorCodeEnricher(h *AuthHandler) *fiber.App {
app := fiber.New()
app.Use(middleware.ErrorCodeEnricher())
app.Post("/api/v1/auth/password/reset/init", h.InitiatePasswordReset)
return app
}
type testRedisRepo struct {
values map[string]string
}
func (m *testRedisRepo) Set(key string, value string, expiration time.Duration) error {
if m.values == nil {
m.values = map[string]string{}
}
m.values[key] = value
return nil
}
func (m *testRedisRepo) Get(key string) (string, error) {
if m.values == nil {
return "", nil
}
return m.values[key], nil
}
func (m *testRedisRepo) Delete(key string) error {
if m.values != nil {
delete(m.values, key)
}
return nil
}
func (m *testRedisRepo) StoreVerificationCode(phone, code string) error {
return m.Set("sms:"+phone, code, time.Minute)
}
func (m *testRedisRepo) GetVerificationCode(phone string) (string, error) {
return m.Get("sms:" + phone)
}
func (m *testRedisRepo) DeleteVerificationCode(phone string) error {
return m.Delete("sms:" + phone)
}
func TestCompletePasswordReset_MissingLoginID(t *testing.T) {
h := &AuthHandler{}
app := newTestApp(h)
body, _ := json.Marshal(map[string]string{
"newPassword": "Password1!",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/password/reset/complete", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for missing loginId, got %d", resp.StatusCode)
}
var got map[string]string
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if got["error"] != "Login ID and new password are required" {
t.Fatalf("unexpected error message: %v", got["error"])
}
}
func TestCompletePasswordReset_InvalidPasswordPolicy(t *testing.T) {
h := &AuthHandler{}
app := newTestApp(h)
body, _ := json.Marshal(map[string]string{
"newPassword": "short", // too short + missing complexity
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/password/reset/complete?loginId=user@example.com", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for weak password, got %d", resp.StatusCode)
}
var got map[string]string
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if got["error"] != "비밀번호는 최소 12자 이상이어야 합니다" {
t.Fatalf("unexpected error message: %v", got["error"])
}
}
func TestCompletePasswordReset_NilIDPProvider(t *testing.T) {
h := &AuthHandler{} // IdpProvider intentionally nil to hit the configuration error branch
app := newTestApp(h)
body, _ := json.Marshal(map[string]string{
"newPassword": "StrongPass1!",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/password/reset/complete?loginId=user@example.com", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusInternalServerError {
t.Fatalf("expected 500 when IDP provider is nil, got %d", resp.StatusCode)
}
var got map[string]string
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if got["error"] != "Authentication service not configured" {
t.Fatalf("unexpected error message: %v", got["error"])
}
}
func TestCompletePasswordReset_TokenValueOverridesLoginIDQuery(t *testing.T) {
const resetToken = "tok-reset-1"
const tokenLoginID = "user@example.com"
const wrongLoginID = "wrong@example.com"
const newPassword = "StrongPass1!"
redis := &testRedisRepo{
values: map[string]string{
prefixPwdResetToken + resetToken: tokenLoginID,
},
}
idp := &mockIdpProvider{
userExists: true,
err: nil,
}
h := &AuthHandler{
RedisService: redis,
IdpProvider: idp,
}
app := newResetFlowTestApp(h)
body, _ := json.Marshal(map[string]string{
"newPassword": newPassword,
})
url := fmt.Sprintf(
"/api/v1/auth/password/reset/complete?loginId=%s&token=%s",
wrongLoginID,
resetToken,
)
req := httptest.NewRequest(http.MethodPost, url, bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
if !idp.updateCalled {
t.Fatal("expected UpdateUserPassword to be called")
}
if idp.updatedLoginID != tokenLoginID {
t.Fatalf("expected loginId from token(%s), got %s", tokenLoginID, idp.updatedLoginID)
}
if idp.updatedPassword != newPassword {
t.Fatalf("expected newPassword propagated, got %s", idp.updatedPassword)
}
}
func TestCompletePasswordReset_InvalidTokenRejectedEvenWhenLoginIDExists(t *testing.T) {
const resetToken = "invalid-token"
redis := &testRedisRepo{
values: map[string]string{},
}
idp := &mockIdpProvider{
userExists: true,
err: nil,
}
h := &AuthHandler{
RedisService: redis,
IdpProvider: idp,
}
app := newResetFlowTestApp(h)
body, _ := json.Marshal(map[string]string{
"newPassword": "StrongPass1!",
})
req := httptest.NewRequest(
http.MethodPost,
"/api/v1/auth/password/reset/complete?loginId=user@example.com&token="+resetToken,
bytes.NewReader(body),
)
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("expected 401 for invalid token, got %d", resp.StatusCode)
}
if idp.updateCalled {
t.Fatal("UpdateUserPassword must not be called when token is invalid")
}
}
func TestCompletePasswordReset_DuplicateTokenSubmitIsIdempotent(t *testing.T) {
const resetToken = "dup-token"
const loginID = "user@example.com"
const newPassword = "StrongPass1!"
redis := &testRedisRepo{
values: map[string]string{
prefixPwdResetToken + resetToken: loginID,
},
}
idp := &mockIdpProvider{
userExists: true,
err: nil,
}
h := &AuthHandler{
RedisService: redis,
IdpProvider: idp,
}
app := newResetFlowTestApp(h)
body, _ := json.Marshal(map[string]string{
"newPassword": newPassword,
})
url := fmt.Sprintf(
"/api/v1/auth/password/reset/complete?token=%s",
resetToken,
)
firstReq := httptest.NewRequest(http.MethodPost, url, bytes.NewReader(body))
firstReq.Header.Set("Content-Type", "application/json")
firstResp, err := app.Test(firstReq)
if err != nil {
t.Fatalf("first request failed: %v", err)
}
defer firstResp.Body.Close()
if firstResp.StatusCode != http.StatusOK {
t.Fatalf("expected first response to be 200, got %d", firstResp.StatusCode)
}
if idp.updateCallCount != 1 {
t.Fatalf("expected first request to update password once, got %d", idp.updateCallCount)
}
secondReq := httptest.NewRequest(http.MethodPost, url, bytes.NewReader(body))
secondReq.Header.Set("Content-Type", "application/json")
secondResp, err := app.Test(secondReq)
if err != nil {
t.Fatalf("second request failed: %v", err)
}
defer secondResp.Body.Close()
if secondResp.StatusCode != http.StatusOK {
t.Fatalf("expected duplicate response to be 200, got %d", secondResp.StatusCode)
}
if idp.updateCallCount != 1 {
t.Fatalf("expected duplicate request not to update password again, got %d", idp.updateCallCount)
}
}
func TestProcessPasswordResetToken_EncodesLoginIDInRedirect(t *testing.T) {
const token = "tok-enc"
const loginID = "user+alias@example.com"
t.Setenv("USERFRONT_URL", "https://sss.hmac.kr")
redis := &testRedisRepo{
values: map[string]string{
prefixPwdResetToken + token: loginID,
},
}
h := &AuthHandler{
RedisService: redis,
}
app := newResetFlowTestApp(h)
req := httptest.NewRequest(
http.MethodPost,
"/api/v1/auth/password/reset/verify?token="+token,
nil,
)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusFound {
t.Fatalf("expected 302, got %d", resp.StatusCode)
}
location := resp.Header.Get("Location")
if location == "" {
t.Fatal("missing redirect location")
}
redirectReq := httptest.NewRequest(http.MethodGet, location, nil)
gotLoginID := redirectReq.URL.Query().Get("loginId")
if gotLoginID != loginID {
t.Fatalf("expected encoded loginId round-trip=%s, got %s (location=%s)", loginID, gotLoginID, location)
}
}
func TestPasswordResetVerifyAlias_AcceptsShortVePath(t *testing.T) {
const token = "tok-ve"
const loginID = "user@example.com"
redis := &testRedisRepo{
values: map[string]string{
prefixPwdResetToken + token: loginID,
},
}
h := &AuthHandler{
RedisService: redis,
}
app := fiber.New()
app.Get("/api/v1/auth/password/reset/ve", h.VerifyPasswordResetPage)
app.Post("/api/v1/auth/password/reset/ve", h.ProcessPasswordResetToken)
getReq := httptest.NewRequest(
http.MethodGet,
"/api/v1/auth/password/reset/ve?token="+token,
nil,
)
getResp, err := app.Test(getReq)
if err != nil {
t.Fatalf("get request failed: %v", err)
}
defer getResp.Body.Close()
if getResp.StatusCode != http.StatusOK {
t.Fatalf("expected alias GET to return 200, got %d", getResp.StatusCode)
}
postReq := httptest.NewRequest(
http.MethodPost,
"/api/v1/auth/password/reset/ve?token="+token,
nil,
)
postResp, err := app.Test(postReq)
if err != nil {
t.Fatalf("post request failed: %v", err)
}
defer postResp.Body.Close()
if postResp.StatusCode != http.StatusFound {
t.Fatalf("expected alias POST to return 302, got %d", postResp.StatusCode)
}
}
func TestPasswordResetVerifyPathToken_AcceptsShortVPath(t *testing.T) {
const token = "tok-path"
const loginID = "user@example.com"
redis := &testRedisRepo{
values: map[string]string{
prefixPwdResetToken + token: loginID,
},
}
h := &AuthHandler{
RedisService: redis,
}
app := fiber.New()
app.Get("/api/v1/auth/password/reset/v/:token", h.VerifyPasswordResetPage)
app.Post("/api/v1/auth/password/reset/v/:token", h.ProcessPasswordResetToken)
getReq := httptest.NewRequest(
http.MethodGet,
"/api/v1/auth/password/reset/v/"+token,
nil,
)
getResp, err := app.Test(getReq)
if err != nil {
t.Fatalf("get request failed: %v", err)
}
defer getResp.Body.Close()
if getResp.StatusCode != http.StatusOK {
t.Fatalf("expected path-token GET to return 200, got %d", getResp.StatusCode)
}
postReq := httptest.NewRequest(
http.MethodPost,
"/api/v1/auth/password/reset/v/"+token,
nil,
)
postResp, err := app.Test(postReq)
if err != nil {
t.Fatalf("post request failed: %v", err)
}
defer postResp.Body.Close()
if postResp.StatusCode != http.StatusFound {
t.Fatalf("expected path-token POST to return 302, got %d", postResp.StatusCode)
}
}
func TestPasswordResetInit_LegacyErrorResponseHasCodeViaMiddleware(t *testing.T) {
h := &AuthHandler{}
app := newResetInitAppWithErrorCodeEnricher(h)
body, _ := json.Marshal(map[string]string{
"loginId": "",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/password/reset/init", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.StatusCode)
}
var got map[string]any
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if got["error"] != "Login ID is required" {
t.Fatalf("unexpected error message: %v", got["error"])
}
if got["code"] != "bad_request" {
t.Fatalf("expected code=bad_request, got %v", got["code"])
}
}
func TestInitiatePasswordReset_SmsContainsVerifyLink(t *testing.T) {
t.Setenv("USERFRONT_URL", "https://sss.hmac.kr")
redis := &testRedisRepo{values: map[string]string{}}
smsSvc := &mockSmsService{}
h := &AuthHandler{
RedisService: redis,
IdpProvider: &mockIdpProvider{},
SmsService: smsSvc,
}
app := fiber.New()
app.Post("/api/v1/auth/password/reset/init", h.InitiatePasswordReset)
body, _ := json.Marshal(map[string]string{
"loginId": "01012345678",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/password/reset/init", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
if !strings.Contains(smsSvc.lastContent, "/api/v1/auth/password/reset/v/") {
t.Fatalf("expected SMS to contain short path verify link, got %q", smsSvc.lastContent)
}
if strings.Contains(smsSvc.lastContent, "/reset-password?token=") {
t.Fatalf("expected direct reset-password link to be removed, got %q", smsSvc.lastContent)
}
}

View File

@@ -0,0 +1,537 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/response"
"baron-sso-backend/internal/service"
"encoding/json"
"errors"
"sort"
"strings"
"github.com/gofiber/fiber/v2"
)
const (
clientTenantAccessRestrictedKey = "tenant_access_restricted"
clientAllowedTenantsKey = "allowed_tenants"
)
func normalizeClientTenantAccessMetadata(metadata map[string]any) (map[string]any, error) {
if metadata == nil {
metadata = map[string]any{}
}
restricted := readMetadataBoolValue(metadata, clientTenantAccessRestrictedKey)
allowedTenants := normalizeMetadataStringSlice(metadata[clientAllowedTenantsKey])
ownerTenantID := normalizeMetadataString(metadata["tenant_id"])
if len(allowedTenants) > 0 {
restricted = true
}
if !restricted {
delete(metadata, clientAllowedTenantsKey)
metadata[clientTenantAccessRestrictedKey] = false
return metadata, nil
}
if ownerTenantID != "" {
allowedTenants = append(allowedTenants, ownerTenantID)
}
allowedTenants = uniqueSortedStrings(allowedTenants)
if len(allowedTenants) == 0 {
return nil, errors.New("allowed_tenants is required when tenant_access_restricted is enabled")
}
metadata[clientTenantAccessRestrictedKey] = true
metadata[clientAllowedTenantsKey] = allowedTenants
return metadata, nil
}
func clientTenantAccessRestricted(metadata map[string]any) bool {
if metadata == nil {
return false
}
if readMetadataBoolValue(metadata, clientTenantAccessRestrictedKey) {
return true
}
return len(normalizeMetadataStringSlice(metadata[clientAllowedTenantsKey])) > 0
}
func clientAllowedTenants(metadata map[string]any) []string {
if metadata == nil {
return nil
}
if !clientTenantAccessRestricted(metadata) {
return nil
}
return uniqueSortedStrings(normalizeMetadataStringSlice(metadata[clientAllowedTenantsKey]))
}
func normalizeMetadataStringSlice(raw any) []string {
switch value := raw.(type) {
case []string:
return uniqueSortedStrings(value)
case []any:
items := make([]string, 0, len(value))
for _, item := range value {
if s, ok := item.(string); ok {
items = append(items, s)
}
}
return uniqueSortedStrings(items)
default:
return nil
}
}
func normalizeMetadataString(raw any) string {
s, ok := raw.(string)
if !ok {
return ""
}
return strings.TrimSpace(s)
}
func uniqueSortedStrings(values []string) []string {
if len(values) == 0 {
return nil
}
seen := make(map[string]struct{}, len(values))
out := make([]string, 0, len(values))
for _, value := range values {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
continue
}
if _, ok := seen[trimmed]; ok {
continue
}
seen[trimmed] = struct{}{}
out = append(out, trimmed)
}
sort.Strings(out)
return out
}
func clientTenantAccessAllowed(profile *domain.UserProfileResponse, client domain.HydraClient) bool {
if !clientTenantAccessRestricted(client.Metadata) {
return true
}
allowed := clientAllowedTenants(client.Metadata)
if len(allowed) == 0 {
return false
}
keys := manageableTenantKeysFromProfile(profile)
if len(keys) == 0 {
return false
}
for _, tenantID := range allowed {
if _, ok := keys[strings.ToLower(strings.TrimSpace(tenantID))]; ok {
return true
}
}
return false
}
func clientTenantAccessAllowedForSubtree(c *fiber.Ctx, tenantSvc service.TenantService, profile *domain.UserProfileResponse, client domain.HydraClient) bool {
if clientTenantAccessAllowed(profile, client) {
return true
}
if tenantSvc == nil || profile == nil {
return false
}
allowedTenants := make([]domain.Tenant, 0)
for _, identifier := range clientAllowedTenants(client.Metadata) {
if tenant, ok := resolveTenantAccessTenant(c, tenantSvc, domain.Tenant{ID: identifier, Slug: identifier}); ok {
allowedTenants = append(allowedTenants, tenant)
}
}
if len(allowedTenants) == 0 {
return false
}
for _, candidate := range tenantAccessProfileTenants(profile) {
resolvedCandidate, ok := resolveTenantAccessTenant(c, tenantSvc, candidate)
if !ok {
continue
}
for _, allowed := range allowedTenants {
if tenantMatchesOrDescendsFrom(c, tenantSvc, resolvedCandidate, allowed) {
return true
}
}
}
return false
}
func tenantAccessProfileTenants(profile *domain.UserProfileResponse) []domain.Tenant {
if profile == nil {
return nil
}
seen := make(map[string]struct{})
tenants := make([]domain.Tenant, 0, len(profile.ManageableTenants)+len(profile.JoinedTenants)+2)
add := func(tenant domain.Tenant) {
key := strings.ToLower(firstNonEmptyString(tenant.ID, tenant.Slug, tenant.Name))
if key == "" {
return
}
if _, ok := seen[key]; ok {
return
}
seen[key] = struct{}{}
tenants = append(tenants, tenant)
}
if profile.Tenant != nil {
add(*profile.Tenant)
}
if profile.TenantID != nil {
add(domain.Tenant{ID: strings.TrimSpace(*profile.TenantID)})
}
for _, tenant := range profile.ManageableTenants {
add(tenant)
}
for _, tenant := range profile.JoinedTenants {
add(tenant)
}
return tenants
}
func resolveTenantAccessTenant(c *fiber.Ctx, tenantSvc service.TenantService, tenant domain.Tenant) (domain.Tenant, bool) {
if tenantSvc == nil {
return tenant, firstNonEmptyString(tenant.ID, tenant.Slug) != ""
}
if strings.TrimSpace(tenant.ID) != "" {
if resolved, err := tenantSvc.GetTenant(c.Context(), strings.TrimSpace(tenant.ID)); err == nil && resolved != nil {
return *resolved, true
}
}
if strings.TrimSpace(tenant.Slug) != "" {
if resolved, err := tenantSvc.GetTenantBySlug(c.Context(), strings.TrimSpace(tenant.Slug)); err == nil && resolved != nil {
return *resolved, true
}
}
return tenant, firstNonEmptyString(tenant.ID, tenant.Slug) != ""
}
func tenantMatchesOrDescendsFrom(c *fiber.Ctx, tenantSvc service.TenantService, tenant domain.Tenant, ancestor domain.Tenant) bool {
if tenantAccessTenantMatches(tenant, ancestor) {
return true
}
if tenantSvc == nil {
return false
}
visited := make(map[string]struct{})
current := tenant
for current.ParentID != nil && strings.TrimSpace(*current.ParentID) != "" {
parentID := strings.TrimSpace(*current.ParentID)
if _, ok := visited[parentID]; ok {
return false
}
visited[parentID] = struct{}{}
parent, err := tenantSvc.GetTenant(c.Context(), parentID)
if err != nil || parent == nil {
return false
}
if tenantAccessTenantMatches(*parent, ancestor) {
return true
}
current = *parent
}
return false
}
func tenantAccessTenantMatches(left, right domain.Tenant) bool {
leftID := strings.ToLower(strings.TrimSpace(left.ID))
rightID := strings.ToLower(strings.TrimSpace(right.ID))
if leftID != "" && rightID != "" && leftID == rightID {
return true
}
leftSlug := strings.ToLower(strings.TrimSpace(left.Slug))
rightSlug := strings.ToLower(strings.TrimSpace(right.Slug))
return leftSlug != "" && rightSlug != "" && leftSlug == rightSlug
}
type tenantAccessDeniedDetails struct {
Account tenantAccessDeniedAccount `json:"account"`
CurrentTenant tenantAccessDeniedTenant `json:"current_tenant"`
AffiliatedTenants []tenantAccessDeniedTenant `json:"affiliated_tenants,omitempty"`
AllowedTenants []tenantAccessDeniedTenant `json:"allowed_tenants,omitempty"`
}
type tenantAccessDeniedAccount struct {
Email string `json:"email,omitempty"`
}
type tenantAccessDeniedTenant struct {
ID string `json:"id,omitempty"`
Slug string `json:"slug,omitempty"`
Name string `json:"name,omitempty"`
Identifier string `json:"identifier,omitempty"`
}
func tenantNotAllowedError(c *fiber.Ctx, details tenantAccessDeniedDetails) error {
return response.ErrorWithDetails(
c,
fiber.StatusForbidden,
"tenant_not_allowed",
"허용되지 않은 테넌트입니다.",
details,
)
}
func isClientTenantAccessAllowed(profile *domain.UserProfileResponse, client domain.HydraClient) bool {
if profile == nil {
return false
}
return clientTenantAccessAllowed(profile, client)
}
func enforceClientTenantAccess(c *fiber.Ctx, tenantSvc service.TenantService, client domain.HydraClient, profile *domain.UserProfileResponse, resolveErr error) bool {
if !clientTenantAccessRestricted(client.Metadata) {
return false
}
details := buildTenantAccessDeniedDetails(c, tenantSvc, client, profile)
if resolveErr != nil || profile == nil {
_ = tenantNotAllowedError(c, details)
return true
}
if !clientTenantAccessAllowedForSubtree(c, tenantSvc, profile, client) {
_ = tenantNotAllowedError(c, details)
return true
}
return false
}
func buildTenantAccessDeniedDetails(c *fiber.Ctx, tenantSvc service.TenantService, client domain.HydraClient, profile *domain.UserProfileResponse) tenantAccessDeniedDetails {
details := tenantAccessDeniedDetails{
Account: tenantAccessDeniedAccount{Email: strings.TrimSpace(profileEmail(profile))},
CurrentTenant: resolveCurrentTenantDetails(c, tenantSvc, profile),
AffiliatedTenants: resolveAffiliatedTenantDetails(c, tenantSvc, profile),
}
for _, identifier := range clientAllowedTenants(client.Metadata) {
details.AllowedTenants = append(details.AllowedTenants, resolveAllowedTenantDetails(c, tenantSvc, identifier))
}
return details
}
func resolveAffiliatedTenantDetails(c *fiber.Ctx, tenantSvc service.TenantService, profile *domain.UserProfileResponse) []tenantAccessDeniedTenant {
if profile == nil {
return nil
}
seen := make(map[string]struct{})
out := make([]tenantAccessDeniedTenant, 0, len(profile.JoinedTenants)+1)
appendTenant := func(tenant tenantAccessDeniedTenant) {
key := strings.ToLower(firstNonEmptyString(tenant.ID, tenant.Slug, tenant.Identifier, tenant.Name))
if key == "" {
return
}
if _, ok := seen[key]; ok {
return
}
seen[key] = struct{}{}
out = append(out, tenant)
}
appendTenant(resolveCurrentTenantDetails(c, tenantSvc, profile))
for _, joined := range profile.JoinedTenants {
appendTenant(tenantAccessDeniedTenant{
ID: strings.TrimSpace(joined.ID),
Slug: strings.TrimSpace(joined.Slug),
Name: strings.TrimSpace(joined.Name),
Identifier: firstNonEmptyString(strings.TrimSpace(joined.Slug), strings.TrimSpace(joined.ID)),
})
}
return out
}
func resolveCurrentTenantDetails(c *fiber.Ctx, tenantSvc service.TenantService, profile *domain.UserProfileResponse) tenantAccessDeniedTenant {
if profile == nil {
return tenantAccessDeniedTenant{}
}
if profile.Tenant != nil {
return tenantAccessDeniedTenant{
ID: strings.TrimSpace(profile.Tenant.ID),
Slug: strings.TrimSpace(profile.Tenant.Slug),
Name: strings.TrimSpace(profile.Tenant.Name),
Identifier: firstNonEmptyString(strings.TrimSpace(profile.Tenant.Slug), strings.TrimSpace(profile.Tenant.ID)),
}
}
if tenantSvc != nil {
if profile.TenantID != nil && strings.TrimSpace(*profile.TenantID) != "" {
if tenant, err := tenantSvc.GetTenant(c.Context(), strings.TrimSpace(*profile.TenantID)); err == nil && tenant != nil {
return tenantAccessDeniedTenant{
ID: strings.TrimSpace(tenant.ID),
Slug: strings.TrimSpace(tenant.Slug),
Name: strings.TrimSpace(tenant.Name),
Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID)),
}
}
}
}
return tenantAccessDeniedTenant{
ID: strings.TrimSpace(pointerValue(profile.TenantID)),
Identifier: strings.TrimSpace(pointerValue(profile.TenantID)),
}
}
func resolveAllowedTenantDetails(c *fiber.Ctx, tenantSvc service.TenantService, identifier string) tenantAccessDeniedTenant {
identifier = strings.TrimSpace(identifier)
if identifier == "" {
return tenantAccessDeniedTenant{}
}
if tenantSvc != nil {
if tenant, err := tenantSvc.GetTenant(c.Context(), identifier); err == nil && tenant != nil {
return tenantAccessDeniedTenant{
ID: strings.TrimSpace(tenant.ID),
Slug: strings.TrimSpace(tenant.Slug),
Name: strings.TrimSpace(tenant.Name),
Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID), identifier),
}
}
if tenant, err := tenantSvc.GetTenantBySlug(c.Context(), identifier); err == nil && tenant != nil {
return tenantAccessDeniedTenant{
ID: strings.TrimSpace(tenant.ID),
Slug: strings.TrimSpace(tenant.Slug),
Name: strings.TrimSpace(tenant.Name),
Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID), identifier),
}
}
}
return tenantAccessDeniedTenant{Identifier: identifier}
}
func profileEmail(profile *domain.UserProfileResponse) string {
if profile == nil {
return ""
}
return profile.Email
}
func pointerValue(value *string) string {
if value == nil {
return ""
}
return *value
}
func firstNonEmptyString(values ...string) string {
for _, value := range values {
if strings.TrimSpace(value) != "" {
return strings.TrimSpace(value)
}
}
return ""
}
type clientStructuredScope struct {
Name string `json:"name"`
Mandatory bool `json:"mandatory"`
Locked bool `json:"locked"`
}
func mergeRequestedScopesWithClientRequirements(client domain.HydraClient, requested []string) []string {
combined := make([]string, 0, len(requested)+2)
combined = append(combined, requested...)
combined = append(combined, requiredClientScopes(client)...)
return normalizeScopesInConsentOrder(combined)
}
func normalizeScopesInConsentOrder(scopes []string) []string {
combined := make([]string, 0, len(scopes))
combined = append(combined, scopes...)
seen := make(map[string]struct{}, len(combined))
out := make([]string, 0, len(combined))
appendIfPresent := func(scope string) {
scope = strings.TrimSpace(scope)
if scope == "" {
return
}
if _, ok := seen[scope]; ok {
return
}
for _, candidate := range combined {
if strings.TrimSpace(candidate) != scope {
continue
}
seen[scope] = struct{}{}
out = append(out, scope)
return
}
}
appendIfPresent("openid")
appendIfPresent("tenant")
for _, scope := range combined {
scope = strings.TrimSpace(scope)
if scope == "" {
continue
}
if _, ok := seen[scope]; ok {
continue
}
seen[scope] = struct{}{}
out = append(out, scope)
}
return out
}
func requiredClientScopes(client domain.HydraClient) []string {
required := make([]string, 0, 4)
if clientTenantAccessRestricted(client.Metadata) {
required = append(required, "tenant")
}
if client.Metadata == nil {
return normalizeScopesInConsentOrder(required)
}
rawStructuredScopes, ok := client.Metadata["structured_scopes"]
if !ok || rawStructuredScopes == nil {
return normalizeScopesInConsentOrder(required)
}
rawBytes, err := json.Marshal(rawStructuredScopes)
if err != nil {
return normalizeScopesInConsentOrder(required)
}
var scopes []clientStructuredScope
if err := json.Unmarshal(rawBytes, &scopes); err != nil {
return normalizeScopesInConsentOrder(required)
}
for _, scope := range scopes {
name := strings.TrimSpace(scope.Name)
if name == "" {
continue
}
if scope.Mandatory || scope.Locked {
required = append(required, name)
}
}
return normalizeScopesInConsentOrder(required)
}

View File

@@ -0,0 +1,458 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func TestCreateClient_NormalizesTenantAccessMetadata(t *testing.T) {
var captured domain.HydraClient
ownerTenantID := "tenant-owner"
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.Method == http.MethodPost && r.URL.Path == "/clients" {
body, err := io.ReadAll(r.Body)
assert.NoError(t, err)
assert.NoError(t, json.Unmarshal(body, &captured))
return httpJSONAny(r, http.StatusCreated, map[string]any{
"client_id": captured.ClientID,
"client_name": captured.ClientName,
"redirect_uris": captured.RedirectURIs,
"grant_types": captured.GrantTypes,
"response_types": captured.ResponseTypes,
"scope": captured.Scope,
"token_endpoint_auth_method": captured.TokenEndpointAuthMethod,
"metadata": captured.Metadata,
}), nil
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
})
h := &DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: transport},
},
Keto: new(devMockKetoService),
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{
ID: "user-1",
Role: domain.RoleSuperAdmin,
TenantID: &ownerTenantID,
})
return c.Next()
})
app.Post("/api/v1/dev/clients", h.CreateClient)
body, _ := json.Marshal(map[string]any{
"id": "client-tenant",
"name": "Tenant Client",
"type": "pkce",
"redirectUris": []string{"https://rp.example.com/cb"},
"metadata": map[string]any{
"tenant_access_restricted": true,
"allowed_tenants": []string{"tenant-b", "tenant-a", "tenant-b"},
},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/dev/clients", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusCreated, resp.StatusCode)
assert.True(t, clientTenantAccessRestricted(captured.Metadata))
assert.Equal(t, []string{"tenant-a", "tenant-b", "tenant-owner"}, clientAllowedTenants(captured.Metadata))
}
func TestCreateClient_RejectsTenantAccessWithoutAllowedTenants(t *testing.T) {
hydraCalled := false
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.Method == http.MethodPost && r.URL.Path == "/clients" {
hydraCalled = true
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
})
h := &DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: transport},
},
Keto: new(devMockKetoService),
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "user-1", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Post("/api/v1/dev/clients", h.CreateClient)
body, _ := json.Marshal(map[string]any{
"id": "client-tenant",
"name": "Tenant Client",
"type": "pkce",
"redirectUris": []string{"https://rp.example.com/cb"},
"metadata": map[string]any{
"tenant_access_restricted": true,
},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/dev/clients", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
assert.False(t, hydraCalled)
}
func TestMergeRequestedScopesWithClientRequirements_AddsTenantScope(t *testing.T) {
client := domain.HydraClient{
Metadata: map[string]any{
"tenant_access_restricted": true,
"structured_scopes": []map[string]any{
{"name": "openid", "mandatory": true},
{"name": "tenant", "mandatory": true, "locked": true},
{"name": "profile", "mandatory": false},
},
},
}
merged := mergeRequestedScopesWithClientRequirements(client, []string{"openid", "profile"})
assert.Equal(t, []string{"openid", "tenant", "profile"}, merged)
}
func TestGetConsentRequest_DeniesTenantAccess(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
switch {
case r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-tenant":
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-tenant",
"requested_scope": []string{"openid", "profile"},
"skip": false,
"subject": "user-123",
"client": map[string]any{
"client_id": "client-tenant",
"metadata": map[string]any{
"tenant_access_restricted": true,
"allowed_tenants": []string{"tenant-b"},
},
},
}), nil
case r.URL.Host == "kratos.test" && r.URL.Path == "/sessions/whoami":
return httpJSONAny(r, http.StatusOK, map[string]any{
"identity": map[string]any{
"id": "user-123",
"traits": map[string]any{
"email": "user@test.com",
"tenant_id": "tenant-a",
},
},
}), nil
default:
return httpJSONAny(r, http.StatusNotFound, nil), nil
}
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() { http.DefaultClient = origDefault }()
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
}
app := fiber.New()
app.Get("/api/v1/auth/consent", h.GetConsentRequest)
t.Setenv("APP_ENV", "dev")
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-tenant", nil)
req.Header.Set("X-Mock-Role", "user")
req.Header.Set("X-Tenant-ID", "tenant-a")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
var body map[string]any
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
assert.Equal(t, "tenant_not_allowed", body["code"])
details, ok := body["details"].(map[string]any)
assert.True(t, ok)
account, ok := details["account"].(map[string]any)
assert.True(t, ok)
assert.NotEmpty(t, account["email"])
currentTenant, ok := details["current_tenant"].(map[string]any)
assert.True(t, ok)
assert.NotEmpty(t, currentTenant["identifier"])
}
func TestGetConsentRequest_DeniesRestrictedClientWhenProfileResolutionFails(t *testing.T) {
acceptCalled := false
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
switch {
case r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-profile-missing":
return httpJSONAny(r, http.StatusOK, map[string]any{
"challenge": "challenge-profile-missing",
"requested_scope": []string{"openid", "profile"},
"skip": false,
"subject": "user-123",
"client": map[string]any{
"client_id": "client-tenant",
"metadata": map[string]any{
"tenant_access_restricted": true,
"allowed_tenants": []string{"tenant-b"},
},
},
}), nil
case r.URL.Path == "/oauth2/auth/requests/consent/accept":
acceptCalled = true
return httpJSONAny(r, http.StatusOK, map[string]any{
"redirect_to": "http://rp/cb",
}), nil
default:
return httpJSONAny(r, http.StatusNotFound, nil), nil
}
})
client := &http.Client{Transport: transport}
origDefault := http.DefaultClient
http.DefaultClient = client
defer func() { http.DefaultClient = origDefault }()
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: func() service.KratosAdminService {
mockKratos := new(MockKratosAdminService)
mockKratos.On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
ID: "user-123",
Traits: map[string]any{
"email": "user@test.com",
"tenant_id": "tenant-a",
"companyCode": "tenant-a",
},
}, nil).Once()
return mockKratos
}(),
TenantService: func() service.TenantService {
tenantSvc := new(MockTenantService)
tenantSvc.On("GetTenant", mock.Anything, "tenant-a").Return(&domain.Tenant{
ID: "tenant-a",
Slug: "tenant-a",
Name: "Tenant A",
}, nil)
tenantSvc.On("GetTenant", mock.Anything, "tenant-c").Return(&domain.Tenant{
ID: "tenant-c",
Slug: "tenant-c",
Name: "Tenant C",
}, nil)
tenantSvc.On("ListJoinedTenants", mock.Anything, "user-123").Return([]domain.Tenant{
{ID: "tenant-a", Slug: "tenant-a", Name: "Tenant A"},
{ID: "tenant-c", Slug: "tenant-c", Name: "Tenant C"},
}, nil).Once()
tenantSvc.On("GetTenant", mock.Anything, "tenant-b").Return(nil, assert.AnError)
tenantSvc.On("GetTenantBySlug", mock.Anything, "tenant-b").Return(&domain.Tenant{
ID: "tenant-b-id",
Slug: "tenant-b",
Name: "Tenant B",
}, nil)
return tenantSvc
}(),
ConsentRepo: &mockConsentRepo{
consents: []domain.ClientConsent{
{
ClientID: "client-tenant",
Subject: "user-123",
GrantedScopes: []string{"openid", "profile"},
},
},
},
}
app := fiber.New()
app.Get("/api/v1/auth/consent", h.GetConsentRequest)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-profile-missing", nil)
req.Header.Set("Cookie", "ory_kratos_session=invalid-session")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
assert.False(t, acceptCalled)
var body map[string]any
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
assert.Equal(t, "tenant_not_allowed", body["code"])
details, ok := body["details"].(map[string]any)
assert.True(t, ok)
account, ok := details["account"].(map[string]any)
assert.True(t, ok)
assert.Equal(t, "user@test.com", account["email"])
currentTenant, ok := details["current_tenant"].(map[string]any)
assert.True(t, ok)
assert.Equal(t, "Tenant A", currentTenant["name"])
affiliatedTenants, ok := details["affiliated_tenants"].([]any)
assert.True(t, ok)
assert.Len(t, affiliatedTenants, 2)
}
func TestAcceptOidcLoginRequest_DeniesTenantAccess(t *testing.T) {
app := fiber.New()
app.Get("/deny", func(c *fiber.Ctx) error {
tenantID := "tenant-a"
profile := &domain.UserProfileResponse{
ID: "user-123",
Role: domain.RoleUser,
Email: "user@test.com",
TenantID: &tenantID,
CompanyCode: "tenant-a",
JoinedTenants: []domain.Tenant{
{ID: "tenant-a", Slug: "tenant-a", Name: "Tenant A"},
{ID: "tenant-c", Slug: "tenant-c", Name: "Tenant C"},
},
}
client := domain.HydraClient{
ClientID: "client-tenant",
Metadata: map[string]any{
"tenant_access_restricted": true,
"allowed_tenants": []string{"tenant-b"},
},
}
tenantSvc := new(MockTenantService)
tenantSvc.On("GetTenant", mock.Anything, "tenant-a").Return(&domain.Tenant{
ID: "tenant-a",
Slug: "tenant-a",
Name: "Tenant A",
}, nil)
tenantSvc.On("GetTenant", mock.Anything, "tenant-c").Return(&domain.Tenant{
ID: "tenant-c",
Slug: "tenant-c",
Name: "Tenant C",
}, nil)
tenantSvc.On("GetTenant", mock.Anything, "tenant-b").Return(nil, assert.AnError)
tenantSvc.On("GetTenantBySlug", mock.Anything, "tenant-b").Return(&domain.Tenant{
ID: "tenant-b-id",
Slug: "tenant-b",
Name: "Tenant B",
}, nil)
enforceClientTenantAccess(c, tenantSvc, client, profile, nil)
return nil
})
req := httptest.NewRequest(http.MethodGet, "/deny", nil)
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
var body map[string]any
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
assert.Equal(t, "tenant_not_allowed", body["code"])
details, ok := body["details"].(map[string]any)
assert.True(t, ok)
account, ok := details["account"].(map[string]any)
assert.True(t, ok)
assert.Equal(t, "user@test.com", account["email"])
currentTenant, ok := details["current_tenant"].(map[string]any)
assert.True(t, ok)
assert.Equal(t, "Tenant A", currentTenant["name"])
affiliatedTenants, ok := details["affiliated_tenants"].([]any)
assert.True(t, ok)
assert.Len(t, affiliatedTenants, 2)
allowedTenants, ok := details["allowed_tenants"].([]any)
assert.True(t, ok)
assert.Len(t, allowedTenants, 1)
allowedTenant, ok := allowedTenants[0].(map[string]any)
assert.True(t, ok)
assert.Equal(t, "Tenant B", allowedTenant["name"])
}
func TestAcceptOidcLoginRequest_AllowsRestrictedClientForHanmacFamilyDescendant(t *testing.T) {
app := fiber.New()
app.Get("/allow-descendant", func(c *fiber.Ctx) error {
hanmacFamilyID := "hanmac-family-id"
samanID := "saman-id"
profile := &domain.UserProfileResponse{
ID: "user-123",
Role: domain.RoleUser,
Email: "user@samaneng.com",
TenantID: &samanID,
Tenant: &domain.Tenant{
ID: samanID,
Slug: "saman",
Name: "삼안",
ParentID: &hanmacFamilyID,
},
JoinedTenants: []domain.Tenant{
{
ID: samanID,
Slug: "saman",
Name: "삼안",
ParentID: &hanmacFamilyID,
},
},
}
client := domain.HydraClient{
ClientID: "orgfront",
Metadata: map[string]any{
"tenant_access_restricted": true,
"allowed_tenants": []string{"hanmac-family"},
},
}
tenantSvc := new(MockTenantService)
tenantSvc.On("GetTenant", mock.Anything, "hanmac-family").Return(nil, assert.AnError).Maybe()
tenantSvc.On("GetTenantBySlug", mock.Anything, "hanmac-family").Return(&domain.Tenant{
ID: hanmacFamilyID,
Slug: "hanmac-family",
Name: "한맥가족",
}, nil).Maybe()
tenantSvc.On("GetTenant", mock.Anything, samanID).Return(&domain.Tenant{
ID: samanID,
Slug: "saman",
Name: "삼안",
ParentID: &hanmacFamilyID,
}, nil).Maybe()
tenantSvc.On("GetTenant", mock.Anything, hanmacFamilyID).Return(&domain.Tenant{
ID: hanmacFamilyID,
Slug: "hanmac-family",
Name: "한맥가족",
}, nil).Maybe()
blocked := enforceClientTenantAccess(c, tenantSvc, client, profile, nil)
assert.False(t, blocked)
return c.SendStatus(http.StatusNoContent)
})
req := httptest.NewRequest(http.MethodGet, "/allow-descendant", nil)
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusNoContent, resp.StatusCode)
}

View File

@@ -0,0 +1,303 @@
package handler
import (
"baron-sso-backend/internal/domain"
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"slices"
"time"
)
// --- Mock IDP Provider ---
type mockIdpProvider struct {
userExists bool
name string
signInInfo *domain.AuthInfo
issueSession *domain.AuthInfo
verifyCodeInfo *domain.AuthInfo
err error
initiateLinkErr error
updateCalled bool
updateCallCount int
updatedLoginID string
updatedPassword string
}
func (m *mockIdpProvider) Name() string {
if m.name != "" {
return m.name
}
return "mock-idp"
}
func (m *mockIdpProvider) GetMetadata() (*domain.IDPMetadata, error) { return nil, m.err }
func (m *mockIdpProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
return "mock-user-id", m.err
}
func (m *mockIdpProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
return m.signInInfo, m.err
}
func (m *mockIdpProvider) UserExists(loginID string) (bool, error) { return m.userExists, m.err }
func (m *mockIdpProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
if m.issueSession != nil {
return m.issueSession, m.err
}
return &domain.AuthInfo{
SessionToken: &domain.Token{JWT: "valid-jwt", SessionID: "valid-sid"},
}, m.err
}
func (m *mockIdpProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
if m.initiateLinkErr != nil {
return nil, m.initiateLinkErr
}
return &domain.LinkLoginInit{FlowID: "mock-flow-id", Mode: "code"}, m.err
}
func (m *mockIdpProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
return m.verifyCodeInfo, m.err
}
func (m *mockIdpProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) { return nil, m.err }
func (m *mockIdpProvider) InitiatePasswordReset(loginID, redirectUrl string) error { return m.err }
func (m *mockIdpProvider) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
return nil, m.err
}
func (m *mockIdpProvider) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
m.updateCalled = true
m.updateCallCount++
m.updatedLoginID = loginID
m.updatedPassword = newPassword
return m.err
}
// --- Mock Audit Repository ---
type mockAuditRepo struct {
logs []domain.AuditLog
}
func (m *mockAuditRepo) Create(log *domain.AuditLog) error {
m.logs = append(m.logs, *log)
return nil
}
func (m *mockAuditRepo) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor, tenantID string) ([]domain.AuditLog, error) {
return m.logs, nil
}
func (m *mockAuditRepo) FindByUserAndEvents(ctx context.Context, userID string, eventTypes []string, limit int) ([]domain.AuditLog, error) {
var results []domain.AuditLog
for _, log := range m.logs {
if log.UserID == userID {
if slices.Contains(eventTypes, log.EventType) {
results = append(results, log)
}
}
}
return results, nil
}
func (m *mockAuditRepo) CountFailuresSince(ctx context.Context, since time.Time, tenantID string) (int64, error) {
return 0, nil
}
func (m *mockAuditRepo) CountEventsSince(ctx context.Context, since time.Time) (int64, error) {
return 0, nil
}
func (m *mockAuditRepo) CountActiveSessionsSince(ctx context.Context, since time.Time, tenantID string) (int64, error) {
return 0, nil
}
func (m *mockAuditRepo) Ping(ctx context.Context) error { return nil }
type mockRPUsageEventSink struct {
events []domain.RPUsageEvent
err error
}
func (m *mockRPUsageEventSink) EmitRPUsageEvent(ctx context.Context, event domain.RPUsageEvent) error {
if m.err != nil {
return m.err
}
m.events = append(m.events, event)
return nil
}
type mockOathkeeperRepo struct {
logs []domain.OathkeeperAccessLog
}
func (m *mockOathkeeperRepo) FindPageBySubject(ctx context.Context, subject string, limit int, cursor *domain.AuditCursor) ([]domain.OathkeeperAccessLog, error) {
if subject == "" {
return m.logs, nil
}
results := make([]domain.OathkeeperAccessLog, 0, len(m.logs))
for _, log := range m.logs {
if log.Subject == subject {
results = append(results, log)
}
}
return results, nil
}
func (m *mockOathkeeperRepo) Ping(ctx context.Context) error { return nil }
// --- Mock Consent Repository ---
type mockConsentRepo struct {
consents []domain.ClientConsent
}
func (m *mockConsentRepo) Upsert(ctx context.Context, consent *domain.ClientConsent) error {
m.consents = append(m.consents, *consent)
return nil
}
func (m *mockConsentRepo) ListBySubject(ctx context.Context, subject string) ([]domain.ClientConsent, error) {
var results []domain.ClientConsent
for _, c := range m.consents {
if c.Subject == subject {
results = append(results, c)
}
}
return results, nil
}
func (m *mockConsentRepo) ListSubjectsByClient(ctx context.Context, clientID string) ([]string, error) {
seen := map[string]struct{}{}
subjects := make([]string, 0, len(m.consents))
for _, consent := range m.consents {
if consent.ClientID != clientID {
continue
}
if _, ok := seen[consent.Subject]; ok {
continue
}
seen[consent.Subject] = struct{}{}
subjects = append(subjects, consent.Subject)
}
return subjects, nil
}
func (m *mockConsentRepo) Find(ctx context.Context, clientID, subject string) (*domain.ClientConsent, error) {
for _, consent := range m.consents {
if consent.ClientID == clientID && consent.Subject == subject {
found := consent
return &found, nil
}
}
return nil, nil
}
func (m *mockConsentRepo) Delete(ctx context.Context, subject, clientID string) error {
filtered := m.consents[:0]
for _, consent := range m.consents {
if consent.Subject == subject && (clientID == "" || consent.ClientID == clientID) {
continue
}
filtered = append(filtered, consent)
}
m.consents = filtered
return nil
}
func (m *mockConsentRepo) DeleteByClient(ctx context.Context, clientID string) error {
filtered := m.consents[:0]
for _, consent := range m.consents {
if consent.ClientID != clientID {
filtered = append(filtered, consent)
}
}
m.consents = filtered
return nil
}
func (m *mockConsentRepo) List(ctx context.Context, clientID string, limit, offset int) ([]domain.ClientConsentWithTenantInfo, int64, error) {
results := make([]domain.ClientConsentWithTenantInfo, 0, len(m.consents))
for _, consent := range m.consents {
if consent.ClientID == clientID {
results = append(results, domain.ClientConsentWithTenantInfo{ClientConsent: consent})
}
}
return results, int64(len(results)), nil
}
func (m *mockConsentRepo) ListByTenant(ctx context.Context, clientID, tenantID string, limit, offset int) ([]domain.ClientConsentWithTenantInfo, int64, error) {
results := make([]domain.ClientConsentWithTenantInfo, 0, len(m.consents))
for _, consent := range m.consents {
if consent.ClientID == clientID {
results = append(results, domain.ClientConsentWithTenantInfo{
ClientConsent: consent,
TenantID: tenantID,
})
}
}
return results, int64(len(results)), nil
}
// --- Mock Secret Repository ---
type mockSecretRepo struct {
secrets map[string]string
}
func (m *mockSecretRepo) Upsert(ctx context.Context, clientID, secret string) error {
if m.secrets == nil {
m.secrets = make(map[string]string)
}
m.secrets[clientID] = secret
return nil
}
func (m *mockSecretRepo) GetByID(ctx context.Context, clientID string) (string, error) {
return m.secrets[clientID], nil
}
func (m *mockSecretRepo) Delete(ctx context.Context, clientID string) error {
delete(m.secrets, clientID)
return nil
}
// --- HTTP Mock Helpers ---
type roundTripFunc func(req *http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func setDefaultHTTPClientForTest(t interface{ Cleanup(func()) }, transport http.RoundTripper) {
origDefault := http.DefaultClient
http.DefaultClient = &http.Client{Transport: transport}
t.Cleanup(func() {
http.DefaultClient = origDefault
})
}
func httpResponse(r *http.Request, code int, body string) *http.Response {
return &http.Response{
StatusCode: code,
Header: make(http.Header),
Body: io.NopCloser(bytes.NewBufferString(body)),
Request: r,
}
}
func httpJSONAny(r *http.Request, code int, data any) *http.Response {
body, _ := json.Marshal(data)
return &http.Response{
StatusCode: code,
Header: http.Header{
"Content-Type": []string{"application/json"},
},
Body: io.NopCloser(bytes.NewBuffer(body)),
Request: r,
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,242 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func TestDevHandler_Isolation(t *testing.T) {
createHandler := func(mockKeto *devMockKetoService) *DevHandler {
return &DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.Method == http.MethodGet && r.URL.Path == "/clients" {
return httpJSONAny(r, http.StatusOK, []map[string]any{
{
"client_id": "client-tenant-a",
"client_name": "App Tenant A",
"token_endpoint_auth_method": "none", // PKCE
"metadata": map[string]any{"tenant_id": "tenant-a"},
},
{
"client_id": "client-tenant-b",
"client_name": "App Tenant B",
"token_endpoint_auth_method": "none", // PKCE
"metadata": map[string]any{"tenant_id": "tenant-b"},
},
}), nil
}
if (r.Method == http.MethodGet || r.Method == http.MethodPut) && strings.HasPrefix(r.URL.Path, "/clients/") {
id := strings.TrimPrefix(r.URL.Path, "/clients/")
tenantID := "tenant-a"
if id == "client-tenant-b" {
tenantID = "tenant-b"
}
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": id,
"client_name": "App " + id,
"token_endpoint_auth_method": "none",
"metadata": map[string]any{"tenant_id": tenantID},
}), nil
}
if r.Method == http.MethodPost && r.URL.Path == "/clients" {
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
return httpJSONAny(r, http.StatusCreated, body), nil
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
}),
},
},
Keto: mockKeto,
}
}
t.Run("Local bypass should be removed", func(t *testing.T) {
mockKeto := new(devMockKetoService)
h := createHandler(mockKeto)
app := fiber.New()
app.Get("/api/v1/dev/clients", h.ListClients)
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients", nil)
req.Header.Set("Origin", "http://localhost:5174")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
})
t.Run("ListClients should show all for SuperAdmin", func(t *testing.T) {
mockKeto := new(devMockKetoService)
h := createHandler(mockKeto)
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{
ID: "super-user",
Role: domain.RoleSuperAdmin,
})
return c.Next()
})
app.Get("/api/v1/dev/clients", h.ListClients)
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients", nil)
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var res struct {
Items []clientSummary `json:"items"`
}
json.NewDecoder(resp.Body).Decode(&res)
// Should see both clients
assert.Equal(t, 2, len(res.Items))
})
t.Run("ListClients should filter by permit for non-SuperAdmin", func(t *testing.T) {
mockKeto := new(devMockKetoService)
h := createHandler(mockKeto)
app := fiber.New()
tenantA := "tenant-a"
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{
ID: "user-a",
Role: domain.RoleUser,
TenantID: &tenantA,
})
return c.Next()
})
app.Get("/api/v1/dev/clients", h.ListClients)
// Explicit permission for private client check bypass
mockKeto.On("CheckPermission", mock.Anything, "User:user-a", "System", "global", "manage_all").Return(true, nil).Maybe()
// Mock permit for the specific client
mockKeto.On("CheckPermission", mock.Anything, "User:user-a", "RelyingParty", "client-tenant-a", "view").Return(true, nil).Maybe()
// Deny for other clients
mockKeto.On("CheckPermission", mock.Anything, "User:user-a", "RelyingParty", "client-tenant-b", "view").Return(false, nil).Maybe()
mockKeto.On("ListRelations", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]service.RelationTuple{}, nil).Maybe()
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients", nil)
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var res struct {
Items []clientSummary `json:"items"`
}
json.NewDecoder(resp.Body).Decode(&res)
// Should only see client-tenant-a (tenant permit)
assert.Equal(t, 1, len(res.Items))
assert.Equal(t, "client-tenant-a", res.Items[0].ID)
})
t.Run("Tenant member should see empty list from DevFront clients if no relation", func(t *testing.T) {
mockKeto := new(devMockKetoService)
h := createHandler(mockKeto)
app := fiber.New()
tenantA := "tenant-a"
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{
ID: "user-member",
Role: domain.RoleUser,
TenantID: &tenantA,
})
return c.Next()
})
app.Get("/api/v1/dev/clients", h.ListClients)
// Deny all by default
mockKeto.On("CheckPermission", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(false, nil).Maybe()
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients", nil)
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var res struct {
Items []clientSummary `json:"items"`
}
json.NewDecoder(resp.Body).Decode(&res)
// Empty list because we didn't mock any specific 'view' permissions for this user
assert.Equal(t, 0, len(res.Items))
})
t.Run("GetClient should enforce isolation for non-SuperAdmin", func(t *testing.T) {
mockKeto := new(devMockKetoService)
h := createHandler(mockKeto)
app := fiber.New()
tenantA := "tenant-a"
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{
ID: "user-a",
Role: domain.RoleUser,
TenantID: &tenantA,
})
return c.Next()
})
app.Get("/api/v1/dev/clients/:id", h.GetClient)
// Case 1: Same tenant BUT no permit (Normal users need permit now)
mockKeto.On("CheckPermission", mock.Anything, "User:user-a", "RelyingParty", "client-tenant-a", "view").Return(false, nil).Once()
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-tenant-a", nil)
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
// Case 2: Same tenant WITH permit
mockKeto.On("CheckPermission", mock.Anything, "User:user-a", "RelyingParty", "client-tenant-a", "view").Return(true, nil).Maybe()
mockKeto.On("ListRelations", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]service.RelationTuple{}, nil).Maybe()
req = httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-tenant-a", nil)
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusOK, resp.StatusCode)
// Case 3: Different tenant
mockKeto.On("CheckPermission", mock.Anything, "User:user-a", "RelyingParty", "client-tenant-b", "view").Return(false, nil).Maybe()
req = httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-tenant-b", nil)
resp, _ = app.Test(req, -1)
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
})
t.Run("CreateClient should record user_id and tenant_id", func(t *testing.T) {
mockKeto := new(devMockKetoService)
h := createHandler(mockKeto)
app := fiber.New()
tenantA := "tenant-a"
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{
ID: "user-a",
Role: domain.RoleSuperAdmin, // Bypass for creation permission
TenantID: &tenantA,
})
return c.Next()
})
app.Post("/api/v1/dev/clients", h.CreateClient)
body, _ := json.Marshal(map[string]any{
"client_name": "New App",
"type": "pkce",
"redirectUris": []string{"http://localhost/cb"},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/dev/clients", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Tenant-ID", "tenant-a")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusCreated, resp.StatusCode)
var res clientDetailResponse
json.NewDecoder(resp.Body).Decode(&res)
assert.Equal(t, "tenant-a", res.Client.Metadata["tenant_id"])
assert.Equal(t, "user-a", res.Client.Metadata["user_id"])
})
}

View File

@@ -0,0 +1,278 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
type devMockRPUserMetadataRepo struct {
mock.Mock
}
func (m *devMockRPUserMetadataRepo) Get(ctx context.Context, clientID, userID string) (*domain.RPUserMetadata, error) {
args := m.Called(ctx, clientID, userID)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.RPUserMetadata), args.Error(1)
}
func (m *devMockRPUserMetadataRepo) Upsert(ctx context.Context, metadata *domain.RPUserMetadata) error {
args := m.Called(ctx, metadata)
return args.Error(0)
}
func TestDevHandler_RPUserMetadataRoundTrip(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/clients/client-1" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "client-1",
"client_name": "Client One",
"metadata": map[string]any{
"tenant_id": "tenant-1",
"id_token_claims": []map[string]any{
{
"namespace": "rp_claims",
"key": "approvalLevel",
"valueType": "text",
"value": "A",
},
},
},
}), nil
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
})
repo := new(devMockRPUserMetadataRepo)
repo.On("Upsert", mock.Anything, mock.MatchedBy(func(row *domain.RPUserMetadata) bool {
return row.ClientID == "client-1" &&
row.UserID == "user-1" &&
row.Metadata["approvalLevel"] == "A" &&
row.Metadata["approvalLevel_permissions"].(map[string]any)["readPermission"] == "admin_only" &&
row.Metadata["approvalLevel_permissions"].(map[string]any)["writePermission"] == "user_and_admin"
})).Return(nil).Once()
repo.On("Get", mock.Anything, "client-1", "user-1").Return(&domain.RPUserMetadata{
ClientID: "client-1",
UserID: "user-1",
Metadata: domain.JSONMap{"approvalLevel": "A"},
}, nil).Once()
h := &DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: transport},
},
RPUserMetadataRepo: repo,
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "admin", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Put("/api/v1/dev/clients/:id/users/:userId/metadata", h.UpsertRPUserMetadata)
app.Get("/api/v1/dev/clients/:id/users/:userId/metadata", h.GetRPUserMetadata)
body, _ := json.Marshal(map[string]any{
"metadata": map[string]any{
"approvalLevel": "A",
"approvalLevel_permissions": map[string]any{
"writePermission": "user_and_admin",
},
},
})
putReq := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-1/users/user-1/metadata", bytes.NewReader(body))
putReq.Header.Set("Content-Type", "application/json")
putResp, _ := app.Test(putReq, -1)
assert.Equal(t, http.StatusOK, putResp.StatusCode)
getReq := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-1/users/user-1/metadata", nil)
getResp, _ := app.Test(getReq, -1)
assert.Equal(t, http.StatusOK, getResp.StatusCode)
var got map[string]any
assert.NoError(t, json.NewDecoder(getResp.Body).Decode(&got))
assert.Equal(t, "client-1", got["clientId"])
assert.Equal(t, "user-1", got["userId"])
assert.Equal(t, "A", got["metadata"].(map[string]any)["approvalLevel"])
repo.AssertExpectations(t)
}
func TestDevHandler_RPUserMetadataMirrorsToKratosTraits(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/clients/client-1" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "client-1",
"client_name": "Client One",
"metadata": map[string]any{
"tenant_id": "tenant-1",
"id_token_claims": []map[string]any{
{
"namespace": "rp_claims",
"key": "approvalLevel",
"valueType": "text",
"value": "A",
"readPermission": "user_and_admin",
"writePermission": "admin_only",
},
},
},
}), nil
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
})
repo := new(devMockRPUserMetadataRepo)
repo.On("Upsert", mock.Anything, mock.AnythingOfType("*domain.RPUserMetadata")).Return(nil).Once()
kratos := new(MockKratosAdmin)
kratos.On("GetIdentity", mock.Anything, "user-1").Return(&service.KratosIdentity{
ID: "user-1",
State: "active",
Traits: map[string]any{
"email": "user@example.com",
"name": "User One",
},
}, nil).Once()
var capturedTraits map[string]any
kratos.On("UpdateIdentity", mock.Anything, "user-1", mock.Anything, "active").Run(func(args mock.Arguments) {
capturedTraits = args.Get(2).(map[string]any)
}).Return(&service.KratosIdentity{ID: "user-1", State: "active", Traits: map[string]any{}}, nil).Once()
h := &DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: transport},
},
KratosAdmin: kratos,
RPUserMetadataRepo: repo,
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "admin", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Put("/api/v1/dev/clients/:id/users/:userId/metadata", h.UpsertRPUserMetadata)
body, _ := json.Marshal(map[string]any{
"metadata": map[string]any{"approvalLevel": "B"},
})
req := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-1/users/user-1/metadata", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
require.Equal(t, http.StatusOK, resp.StatusCode)
rpClaims := capturedTraits["rp_custom_claims"].(map[string]any)
clientClaims := rpClaims["client-1"].(domain.JSONMap)
require.Equal(t, "B", clientClaims["approvalLevel"])
require.Equal(t, map[string]any{
"readPermission": "user_and_admin",
"writePermission": "admin_only",
}, clientClaims["approvalLevel_permissions"])
repo.AssertExpectations(t)
kratos.AssertExpectations(t)
}
func TestDevHandler_RPUserMetadataRejectsUndefinedClaimKey(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/clients/client-1" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "client-1",
"client_name": "Client One",
"metadata": map[string]any{
"id_token_claims": []map[string]any{
{
"namespace": "rp_claims",
"key": "contract_date",
"valueType": "date",
"value": "2026-06-09",
},
},
},
}), nil
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
})
repo := new(devMockRPUserMetadataRepo)
h := &DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: transport},
},
RPUserMetadataRepo: repo,
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "admin", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Put("/api/v1/dev/clients/:id/users/:userId/metadata", h.UpsertRPUserMetadata)
body, _ := json.Marshal(map[string]any{
"metadata": map[string]any{"unknown_claim": "A"},
})
req := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-1/users/user-1/metadata", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
repo.AssertNotCalled(t, "Upsert", mock.Anything, mock.Anything)
}
func TestDevHandler_RPUserMetadataRejectsInvalidTypedClaimValue(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/clients/client-1" {
return httpJSONAny(r, http.StatusOK, map[string]any{
"client_id": "client-1",
"client_name": "Client One",
"metadata": map[string]any{
"id_token_claims": []map[string]any{
{
"namespace": "rp_claims",
"key": "contract_date",
"valueType": "date",
"value": "2026-06-09",
},
},
},
}), nil
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
})
repo := new(devMockRPUserMetadataRepo)
h := &DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: transport},
},
RPUserMetadataRepo: repo,
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "admin", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Put("/api/v1/dev/clients/:id/users/:userId/metadata", h.UpsertRPUserMetadata)
body, _ := json.Marshal(map[string]any{
"metadata": map[string]any{"contract_date": "2026/06/09"},
})
req := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-1/users/user-1/metadata", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
repo.AssertNotCalled(t, "Upsert", mock.Anything, mock.Anything)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,17 @@
package handler
import (
"baron-sso-backend/internal/response"
"github.com/gofiber/fiber/v2"
)
// errorJSON은 기존 error 필드를 유지하면서 기계 판독용 code를 명시적으로 추가합니다.
func errorJSON(c *fiber.Ctx, status int, message string) error {
return response.Error(c, status, response.StatusCode(status), message)
}
// errorJSONCode는 상태코드 기반 매핑만으로 부족한 경우 명시 코드를 강제할 때 사용합니다.
func errorJSONCode(c *fiber.Ctx, status int, code, message string) error {
return response.Error(c, status, code, message)
}

View File

@@ -0,0 +1,161 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/repository"
"baron-sso-backend/internal/service"
"errors"
"github.com/gofiber/fiber/v2"
"gorm.io/gorm"
)
// FederationHandler handles API requests for IdP federation.
type FederationHandler struct {
fedSvc *service.FederationService
repo repository.FederationRepository // For IdP Config CRUD
db *gorm.DB // For tenant existence checks, etc. in CRUD
}
// NewFederationHandler creates a new FederationHandler.
func NewFederationHandler(fedSvc *service.FederationService, repo repository.FederationRepository, db *gorm.DB) *FederationHandler {
return &FederationHandler{
fedSvc: fedSvc,
repo: repo,
db: db,
}
}
// InitiateOIDCLogin handles the start of the OIDC login flow.
// It expects `provider_id` and `login_challenge` as query parameters.
func (h *FederationHandler) InitiateOIDCLogin(c *fiber.Ctx) error {
providerID := c.Query("provider_id")
loginChallenge := c.Query("login_challenge")
if providerID == "" || loginChallenge == "" {
return errorJSON(c, fiber.StatusBadRequest, "provider_id and login_challenge are required")
}
redirectURL, err := h.fedSvc.InitiateOIDCLogin(c.Context(), providerID, loginChallenge)
if err != nil {
// Log the error properly in a real application
return errorJSON(c, fiber.StatusInternalServerError, "failed to initiate OIDC login")
}
return c.Redirect(redirectURL, fiber.StatusFound)
}
// HandleOIDCCallback handles the OIDC callback from the IdP.
func (h *FederationHandler) HandleOIDCCallback(c *fiber.Ctx) error {
code := c.Query("code")
state := c.Query("state")
if code == "" || state == "" {
return errorJSON(c, fiber.StatusBadRequest, "code and state are required")
}
redirectURL, err := h.fedSvc.HandleOIDCCallback(c.Context(), code, state)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, "failed to handle OIDC callback")
}
return c.Redirect(redirectURL, fiber.StatusFound)
}
// --- New Client-based IdP Config Methods ---
// ListIdpConfigsForClient handles listing all IdP configurations for a client.
func (h *FederationHandler) ListIdpConfigsForClient(c *fiber.Ctx) error {
clientID := c.Params("clientId")
if clientID == "" {
return errorJSON(c, fiber.StatusBadRequest, "clientId is required")
}
var configs []domain.IdentityProviderConfig
if err := h.db.Where("client_id = ?", clientID).Find(&configs).Error; err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.JSON(configs)
}
// CreateIdpConfigForClient handles the creation of a new IdP configuration for a client.
func (h *FederationHandler) CreateIdpConfigForClient(c *fiber.Ctx) error {
clientID := c.Params("clientId")
if clientID == "" {
return errorJSON(c, fiber.StatusBadRequest, "clientId is required in path")
}
var req domain.IdentityProviderConfig
if err := c.BodyParser(&req); err != nil {
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
}
// Assign clientID from path parameter
req.ClientID = clientID
// Basic validation
if req.DisplayName == "" || req.ProviderType == "" {
return errorJSON(c, fiber.StatusBadRequest, "display_name and provider_type are required")
}
// TODO: Optionally, validate if the clientID exists in Hydra
// Create in DB
if err := h.db.Create(&req).Error; err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.Status(fiber.StatusCreated).JSON(req)
}
// --- Deprecated Tenant-based IdP Config Methods ---
// ListIdpConfigsForTenant handles listing all IdP configurations for a tenant.
func (h *FederationHandler) ListIdpConfigsForTenant(c *fiber.Ctx) error {
tenantID := c.Params("tenantId")
if tenantID == "" {
return errorJSON(c, fiber.StatusBadRequest, "tenantId is required")
}
// This is a temporary solution. We should create a proper method in the repository.
var configs []domain.IdentityProviderConfig
// Note: This now queries client_id, which is incorrect for tenants.
// This method is deprecated.
if err := h.db.Where("tenant_id = ?", tenantID).Find(&configs).Error; err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.JSON(configs)
}
// CreateIdpConfig handles the creation of a new IdP configuration.
func (h *FederationHandler) CreateIdpConfig(c *fiber.Ctx) error {
var req domain.IdentityProviderConfig
if err := c.BodyParser(&req); err != nil {
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
}
// Basic validation - This is the old validation logic
if req.ClientID == "" || req.DisplayName == "" || req.ProviderType == "" {
return errorJSON(c, fiber.StatusBadRequest, "client_id, display_name, and provider_type are required")
}
// This check is now incorrect and deprecated.
var tenant domain.Tenant
if err := h.db.First(&tenant, "id = ?", req.ClientID).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errorJSON(c, fiber.StatusBadRequest, "tenant not found")
}
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
// Create in DB
if err := h.db.Create(&req).Error; err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.Status(fiber.StatusCreated).JSON(req)
}
// TODO: Re-implement Update, Delete handlers for IdP Configs for Clients

View File

@@ -0,0 +1,243 @@
package handler
import (
"baron-sso-backend/internal/domain"
"context"
"fmt"
"slices"
"strings"
)
const hanmacFamilyTenantSlug = "hanmac-family"
type hanmacEmailScope struct {
TenantIDs map[string]bool
Slugs map[string]bool
IDList []string
SlugList []string
}
type hanmacEmailEvaluation struct {
Email string
OriginalEmail string
SuggestedEmail string
Status string
Warnings []string
Message string
Blocking bool
LocalPart string
}
func (h *UserHandler) evaluateHanmacImportEmail(ctx context.Context, item bulkUserItem, scope *hanmacEmailScope, usedLocalParts map[string]bool) hanmacEmailEvaluation {
originalEmail := strings.TrimSpace(item.Email)
name := strings.TrimSpace(item.Name)
evaluation := hanmacEmailEvaluation{
Email: originalEmail,
OriginalEmail: originalEmail,
Status: "valid",
}
localPart, domainPart, err := domain.SplitEmailDomain(originalEmail)
if err != nil {
evaluation.Status = "blockingError"
evaluation.Message = "invalid email format"
evaluation.Blocking = true
return evaluation
}
base, needsReview, _ := domain.BuildKoreanNameEmailBase(name)
if needsReview {
evaluation.Warnings = append(evaluation.Warnings, "needsReview")
evaluation.Status = "needsReview"
}
if localPart == "" {
if base == "" {
evaluation.Status = "blockingError"
evaluation.Message = "이름으로 이메일 ID를 제안할 수 없습니다."
evaluation.Blocking = true
return evaluation
}
nextLocalPart := nextAvailableHanmacLocalPart(base, usedLocalParts)
evaluation.Email = nextLocalPart + "@" + domainPart
evaluation.SuggestedEmail = evaluation.Email
evaluation.LocalPart = nextLocalPart
evaluation.Status = "suggested"
evaluation.Warnings = appendUniqueString(evaluation.Warnings, "suggested")
return evaluation
}
evaluation.LocalPart = localPart
if usedLocalParts[localPart] {
evaluation.Status = "blockingError"
evaluation.Message = "한맥가족 내에서 이미 사용 중인 이메일 ID입니다."
evaluation.Blocking = true
return evaluation
}
if base != "" && !domain.MatchesSuggestedNameRule(localPart, base) {
evaluation.Status = "ruleMismatch"
evaluation.Warnings = appendUniqueString(evaluation.Warnings, "ruleMismatch")
}
if evaluation.Status == "needsReview" && len(evaluation.Warnings) == 0 {
evaluation.Warnings = append(evaluation.Warnings, "needsReview")
}
_ = scope
return evaluation
}
func (h *UserHandler) ensureHanmacCreateEmailAllowed(ctx context.Context, email string, tenantSlug string, tenantID string) error {
scope, err := h.resolveHanmacEmailScope(ctx)
if err != nil || scope == nil || !scope.ContainsTenant(tenantID, tenantSlug) {
return nil
}
localPart, err := domain.ExtractNormalizedEmailLocalPart(email)
if err != nil {
return err
}
usedLocalParts, err := h.loadHanmacLocalParts(ctx, scope)
if err != nil {
return err
}
if usedLocalParts[localPart] {
return fmt.Errorf("한맥가족 내에서 이미 사용 중인 이메일 ID입니다.")
}
return nil
}
func (h *UserHandler) resolveHanmacEmailScope(ctx context.Context) (*hanmacEmailScope, error) {
if h.TenantService == nil {
return nil, nil
}
tenants, _, err := h.TenantService.ListTenants(ctx, 10000, 0, "", "")
if err != nil {
return nil, err
}
var rootID string
for _, tenant := range tenants {
if strings.EqualFold(strings.TrimSpace(tenant.Slug), hanmacFamilyTenantSlug) {
rootID = tenant.ID
break
}
}
if rootID == "" {
return nil, nil
}
tenantByID := make(map[string]domain.Tenant, len(tenants))
for _, tenant := range tenants {
tenantByID[tenant.ID] = tenant
}
scope := &hanmacEmailScope{
TenantIDs: make(map[string]bool),
Slugs: make(map[string]bool),
}
for _, tenant := range tenants {
if isTenantDescendantOf(tenant, rootID, tenantByID) {
scope.TenantIDs[tenant.ID] = true
scope.Slugs[strings.ToLower(strings.TrimSpace(tenant.Slug))] = true
scope.IDList = append(scope.IDList, tenant.ID)
scope.SlugList = append(scope.SlugList, tenant.Slug)
}
}
return scope, nil
}
func (h *UserHandler) loadHanmacLocalParts(ctx context.Context, scope *hanmacEmailScope) (map[string]bool, error) {
used := make(map[string]bool)
if h.UserRepo == nil || scope == nil {
return used, nil
}
if len(scope.IDList) > 0 {
users, err := h.UserRepo.FindByTenantIDs(ctx, scope.IDList)
if err != nil {
return nil, err
}
addUserEmailLocalParts(used, users)
}
if len(scope.SlugList) > 0 {
users, err := h.UserRepo.FindByCompanyCodes(ctx, scope.SlugList)
if err != nil {
return nil, err
}
addUserEmailLocalParts(used, users)
}
return used, nil
}
func (s *hanmacEmailScope) ContainsTenant(tenantID string, slug string) bool {
if s == nil {
return false
}
if tenantID != "" && s.TenantIDs[tenantID] {
return true
}
return s.Slugs[strings.ToLower(strings.TrimSpace(slug))]
}
func isTenantDescendantOf(tenant domain.Tenant, rootID string, tenantByID map[string]domain.Tenant) bool {
if tenant.ID == rootID {
return true
}
visited := make(map[string]bool)
parentID := ""
if tenant.ParentID != nil {
parentID = *tenant.ParentID
}
for parentID != "" {
if parentID == rootID {
return true
}
if visited[parentID] {
return false
}
visited[parentID] = true
parent, ok := tenantByID[parentID]
if !ok || parent.ParentID == nil {
return false
}
parentID = *parent.ParentID
}
return false
}
func addUserEmailLocalParts(target map[string]bool, users []domain.User) {
for _, user := range users {
localPart, err := domain.ExtractNormalizedEmailLocalPart(user.Email)
if err == nil && localPart != "" {
target[localPart] = true
}
}
}
func nextAvailableHanmacLocalPart(base string, usedLocalParts map[string]bool) string {
base = strings.ToLower(strings.TrimSpace(base))
if base == "" {
return ""
}
if !usedLocalParts[base] {
return base
}
for index := 1; ; index++ {
candidate := fmt.Sprintf("%s%d", base, index)
if !usedLocalParts[candidate] {
return candidate
}
}
}
func appendUniqueString(values []string, value string) []string {
if slices.Contains(values, value) {
return values
}
return append(values, value)
}

View File

@@ -0,0 +1,98 @@
package handler
import (
"baron-sso-backend/internal/domain"
"crypto/rand"
"encoding/binary"
"testing"
"unicode"
)
// 정책을 받아 필수 요구사항을 모두 포함하는 비밀번호를 생성한다.
func generatePasswordFromPolicy(policy *domain.PasswordPolicy) string {
minLen := policy.MinLength
if minLen < 8 {
minLen = 12 // 안전한 기본값
}
pwd := make([]rune, 0, minLen)
if policy.Lowercase {
pwd = append(pwd, 'a')
}
if policy.Uppercase {
pwd = append(pwd, 'B')
}
if policy.Number {
pwd = append(pwd, '3')
}
if policy.NonAlphanumeric {
pwd = append(pwd, '!')
}
charset := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*"
for len(pwd) < minLen {
pwd = append(pwd, rune(charset[randomInt(len(charset))]))
}
// 섞어서 예측 가능성을 낮춘다.
for i := range pwd {
j := randomInt(len(pwd))
pwd[i], pwd[j] = pwd[j], pwd[i]
}
return string(pwd)
}
func randomInt(n int) int {
if n <= 0 {
return 0
}
var b [8]byte
if _, err := rand.Read(b[:]); err != nil {
return 0
}
return int(binary.BigEndian.Uint64(b[:]) % uint64(n))
}
func TestGeneratePasswordUsesNonAlphanumericRequirement(t *testing.T) {
policy := &domain.PasswordPolicy{
MinLength: 8,
Lowercase: true,
Uppercase: true,
Number: true,
NonAlphanumeric: true,
}
pwd := generatePasswordFromPolicy(policy)
if len(pwd) < policy.MinLength {
t.Fatalf("비밀번호 길이가 정책 최소 길이 미만: got %d, want >= %d", len(pwd), policy.MinLength)
}
var hasLower, hasUpper, hasNumber, hasSymbol bool
for _, r := range pwd {
switch {
case unicode.IsLower(r):
hasLower = true
case unicode.IsUpper(r):
hasUpper = true
case unicode.IsNumber(r):
hasNumber = true
case !unicode.IsLetter(r) && !unicode.IsNumber(r):
hasSymbol = true
}
}
if policy.Lowercase && !hasLower {
t.Fatalf("소문자 요구사항 미충족: %q", pwd)
}
if policy.Uppercase && !hasUpper {
t.Fatalf("대문자 요구사항 미충족: %q", pwd)
}
if policy.Number && !hasNumber {
t.Fatalf("숫자 요구사항 미충족: %q", pwd)
}
if policy.NonAlphanumeric && !hasSymbol {
t.Fatalf("비영문자 요구사항 미충족: %q", pwd)
}
}

View File

@@ -0,0 +1,114 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"log/slog"
"github.com/gofiber/fiber/v2"
)
type RelyingPartyHandler struct {
Service service.RelyingPartyService
KratosAdmin service.KratosAdminService
}
func NewRelyingPartyHandler(s service.RelyingPartyService, kratos service.KratosAdminService) *RelyingPartyHandler {
return &RelyingPartyHandler{Service: s, KratosAdmin: kratos}
}
func (h *RelyingPartyHandler) Create(c *fiber.Ctx) error {
tenantID := c.Params("tenantId")
if tenantID == "" {
return errorJSON(c, fiber.StatusBadRequest, "tenantId is required")
}
var req domain.HydraClient
if err := c.BodyParser(&req); err != nil {
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
}
rp, err := h.Service.Create(c.Context(), tenantID, req)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.Status(fiber.StatusCreated).JSON(rp)
}
func (h *RelyingPartyHandler) ListAll(c *fiber.Ctx) error {
profile, ok := c.Locals("user_profile").(*domain.UserProfileResponse)
if !ok {
return errorJSON(c, fiber.StatusUnauthorized, "unauthorized: user profile not found in context")
}
var rps []domain.RelyingParty
var err error
role := domain.NormalizeRole(profile.Role)
if role == domain.RoleSuperAdmin {
rps, err = h.Service.ListAll(c.Context())
} else if role == "tenant_admin" && profile.TenantID != nil {
rps, err = h.Service.List(c.Context(), *profile.TenantID)
} else {
slog.Warn("Forbidden access to all applications", "userID", profile.ID, "role", role)
return errorJSON(c, fiber.StatusForbidden, "forbidden: insufficient role to list all applications")
}
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.JSON(rps)
}
func (h *RelyingPartyHandler) List(c *fiber.Ctx) error {
tenantID := c.Params("tenantId")
if tenantID == "" {
return errorJSON(c, fiber.StatusBadRequest, "tenantId is required")
}
rps, err := h.Service.List(c.Context(), tenantID)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.JSON(rps)
}
func (h *RelyingPartyHandler) Get(c *fiber.Ctx) error {
id := c.Params("id")
rp, hydraClient, err := h.Service.Get(c.Context(), id)
if err != nil {
return errorJSON(c, fiber.StatusNotFound, "relying party not found")
}
return c.JSON(fiber.Map{
"relyingParty": rp,
"oauth2Config": hydraClient,
})
}
func (h *RelyingPartyHandler) Update(c *fiber.Ctx) error {
id := c.Params("id")
var req domain.HydraClient
if err := c.BodyParser(&req); err != nil {
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
}
rp, err := h.Service.Update(c.Context(), id, req)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.JSON(rp)
}
func (h *RelyingPartyHandler) Delete(c *fiber.Ctx) error {
id := c.Params("id")
if err := h.Service.Delete(c.Context(), id); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.SendStatus(fiber.StatusNoContent)
}

View File

@@ -0,0 +1,255 @@
package handler
import (
"html"
"os"
"strings"
"github.com/gofiber/fiber/v2"
)
type RPManifestHandler struct{}
const rpObjectLookupMermaid = `flowchart TD
A[RP request] --> B{obj_id supplied?}
B -->|yes| C[Normalize object type and obj_id]
B -->|no| D{Route has client_id?}
D -->|yes| E[obj_id = RelyingParty:<client_id>]
D -->|no| F{Route has tenant_id?}
F -->|yes| G[obj_id = Tenant:<tenant_id>]
F -->|no| H[Reject: explicit obj_id required]
C --> I[Check Keto relation]
E --> I
G --> I
I --> J{allowed?}
J -->|yes| K[Inject trusted Baron headers]
J -->|no| L[Reject request]
K --> M[Write audit with obj_id, relation, client_id, X-Request-Id]`
const rpExternalKeyMermaid = `flowchart TD
A[User authenticates through Baron SSO] --> B[Baron resolves internal identity]
B --> C[Baron derives or loads Baron-issued alias]
C --> D[Baron injects X-Baron-External-Key]
D --> E[Baron injects X-Baron-Subject]
E --> I[RP receives trusted headers from Baron gateway]
I --> F[RP upserts local user with provider + X-Baron-External-Key]
F --> G[RP stores the full external key as opaque value]
G --> H[RP never parses or stores raw kratos_identity_id]`
func NewRPManifestHandler() *RPManifestHandler {
return &RPManifestHandler{}
}
func (h *RPManifestHandler) GetJSON(c *fiber.Ctx) error {
c.Set(fiber.HeaderCacheControl, "public, max-age=300")
return c.JSON(buildRPManifest(c))
}
func (h *RPManifestHandler) GetSchema(c *fiber.Ctx) error {
c.Set(fiber.HeaderCacheControl, "public, max-age=300")
return c.JSON(rpManifestSchema())
}
func (h *RPManifestHandler) GetHTML(c *fiber.Ctx) error {
manifest := buildRPManifest(c)
issuer, _ := manifest["issuer"].(string)
c.Set(fiber.HeaderCacheControl, "public, max-age=300")
c.Type("html", "utf-8")
return c.SendString(`<!doctype html>
<html lang="ko">
<head>
<meta charset="utf-8">
<title>Baron RP IAM Manifest</title>
<style>
body { font-family: system-ui, sans-serif; margin: 2rem; line-height: 1.6; max-width: 920px; }
code, pre { background: #f5f5f5; border-radius: 4px; padding: .1rem .3rem; }
pre { padding: 1rem; overflow: auto; }
table { border-collapse: collapse; width: 100%; }
th, td { border: 1px solid #ddd; padding: .5rem; text-align: left; }
</style>
</head>
<body>
<h1>Baron RP IAM Manifest</h1>
<p>외부 RP가 Baron SSO/Ory Stack/Keto 기반 공용 IAM을 연동하기 위한 공개 규격입니다.</p>
<ul>
<li>Machine-readable manifest: <a href="/.well-known/baron-rp-manifest.json">/.well-known/baron-rp-manifest.json</a></li>
<li>JSON schema: <a href="/.well-known/baron-rp-manifest.schema.json">/.well-known/baron-rp-manifest.schema.json</a></li>
</ul>
<h2>Issuer</h2>
<pre>` + html.EscapeString(issuer) + `</pre>
<h2>Identity Contract</h2>
<table>
<tr><th>용도</th><th>Header</th><th>정책</th></tr>
<tr><td>Keto subject</td><td><code>X-Baron-Subject</code></td><td><code>User:&lt;baron_identity_id&gt;</code> 전체 문자열을 opaque subject로 취급합니다.</td></tr>
<tr><td>RP upsert key</td><td><code>X-Baron-External-Key</code></td><td>Baron-issued alias입니다. RP가 만들거나 제출하지 않고, Baron이 주입한 전체 문자열을 local user external key로 저장합니다.</td></tr>
<tr><td>RP client</td><td><code>X-Baron-Client-ID</code></td><td>현재 접근 중인 RP client id입니다.</td></tr>
</table>
<h2>External Key Flow</h2>
<p><code>X-Baron-External-Key</code>는 RP 입력값이 아니라 Baron이 인증된 subject에서 발급/조회해 주입하는 opaque alias입니다. RP upserts local user from the Baron-issued alias.</p>
<pre>` + "```mermaid\n" + html.EscapeString(rpExternalKeyMermaid) + "\n```" + `</pre>
<h2>Object Lookup</h2>
<pre>check(User:abc, viewers, RelyingParty:&lt;client_id&gt;)
check(User:abc, members, Tenant:&lt;tenant_id&gt;)
check(User:abc, viewers, Resource:&lt;resource_type&gt;:&lt;resource_id&gt;)</pre>
<h2>audit_contract</h2>
<p>권한과 설정을 변경하는 command는 sync audit write에 실패하면 요청도 실패해야 합니다. Read audit은 allowlist된 조회에 한해 best effort로 취급합니다.</p>
<pre>{
"mutating_command_mode": "fail_closed_sync",
"missing_audit_sink_behavior": "reject_mutation",
"correlation_header": "X-Request-Id"
}</pre>
<h2>Object Lookup Flow</h2>
<pre>` + "```mermaid\n" + html.EscapeString(rpObjectLookupMermaid) + "\n```" + `</pre>
</body>
</html>`)
}
func buildRPManifest(c *fiber.Ctx) map[string]any {
issuer := resolvePublicRequestBaseURL(c, os.Getenv("BACKEND_PUBLIC_URL"))
if issuer == "" {
issuer = strings.TrimRight(os.Getenv("USERFRONT_URL"), "/")
}
if issuer == "" {
issuer = "https://sso.hmac.kr"
}
issuer = strings.TrimRight(issuer, "/")
return map[string]any{
"version": "2026-05-11",
"issuer": issuer,
"oidc": map[string]any{
"discovery_url": issuer + "/.well-known/openid-configuration",
"jwks_url": issuer + "/.well-known/jwks.json",
"supported_flows": []string{"authorization_code_pkce"},
"required_scopes": []string{"openid", "profile", "email"},
},
"iam": map[string]any{
"authorization_engine": "ory-keto",
"subject_format": "User:<baron_identity_id>",
"target_object_patterns": []string{
"RelyingParty:<client_id>",
"Tenant:<tenant_id>",
"Resource:<resource_type>:<resource_id>",
},
"supported_relations": []string{
"admins",
"users",
"viewers",
"operators",
"members",
"owners",
"editors",
},
},
"identity_contract": map[string]any{
"subject_header": "X-Baron-Subject",
"external_key_header": "X-Baron-External-Key",
"external_key_is_opaque": true,
"external_key_issuer": "baron",
"external_key_delivery": "baron_injected_header",
"external_key_lifecycle": "issued_or_loaded_after_successful_authentication_before_rp_request",
"rp_supplied_external_key_allowed": false,
"rp_user_upsert_source": "rp_must_upsert_from_header_value",
"raw_kratos_identity_id_exposed": false,
"rp_user_upsert_key": "provider + external_key",
"email_is_stable_primary_key": false,
"initial_external_key_expression": "X-Baron-External-Key",
"fallback_to_subject_allowed": false,
},
"trusted_headers": map[string]any{
"subject": "X-Baron-Subject",
"external_key": "X-Baron-External-Key",
"email": "X-Baron-Email",
"tenant": "X-Baron-Tenant",
"relations": "X-Baron-Relations",
"client_id": "X-Baron-Client-ID",
},
"object_lookup": map[string]any{
"rp_level": map[string]any{
"object": "RelyingParty:<client_id>",
"relations": []string{"viewers", "users", "operators", "admins"},
"example": "check(User:abc, viewers, RelyingParty:mh-dashboard)",
},
"tenant_level": map[string]any{
"object": "Tenant:<tenant_id>",
"relations": []string{"members", "admins", "owners"},
"example": "check(User:abc, members, Tenant:9caf62e1-297d-4e8f-870b-61780998bbe)",
},
"resource_level": map[string]any{
"object": "Resource:<resource_type>:<resource_id>",
"relations": []string{"viewers", "editors", "owners"},
"example": "check(User:abc, viewers, Resource:dashboard:mh-monthly-2026-05)",
},
"recommended_order": []string{
"authenticated",
"rp_level",
"tenant_or_resource_level",
"trusted_header_injection",
},
},
"object_lookup_flow": map[string]any{
"format": "mermaid",
"mermaid": rpObjectLookupMermaid,
},
"external_key_flow": map[string]any{
"format": "mermaid",
"mermaid": rpExternalKeyMermaid,
},
"audit_contract": map[string]any{
"mutating_command_mode": "fail_closed_sync",
"missing_audit_sink_behavior": "reject_mutation",
"read_audit_mode": "best_effort_allowlisted",
"correlation_header": "X-Request-Id",
"rp_business_audit_required": true,
"baron_gateway_audit_required": true,
"required_detail_fields": []string{
"obj_id",
"relation",
"client_id",
"subject",
"decision",
},
"guarantee_scope": "Baron-mediated IAM mutations fail closed on audit write failure; RP-owned business events must be emitted by the RP with the same correlation header.",
},
"security_requirements": map[string]any{
"strip_external_identity_headers": true,
"backend_direct_exposure_allowed": false,
"static_snapshot_requires_auth": true,
"email_as_primary_key_allowed": false,
},
}
}
func rpManifestSchema() map[string]any {
return map[string]any{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"title": "Baron RP IAM Manifest",
"type": "object",
"required": []string{
"version",
"issuer",
"oidc",
"iam",
"trusted_headers",
"identity_contract",
"object_lookup",
"object_lookup_flow",
"external_key_flow",
"audit_contract",
"security_requirements",
},
"properties": map[string]any{
"version": map[string]any{"type": "string"},
"issuer": map[string]any{"type": "string", "format": "uri"},
"oidc": map[string]any{"type": "object"},
"iam": map[string]any{"type": "object"},
"trusted_headers": map[string]any{"type": "object"},
"identity_contract": map[string]any{"type": "object"},
"object_lookup": map[string]any{"type": "object"},
"object_lookup_flow": map[string]any{"type": "object"},
"external_key_flow": map[string]any{"type": "object"},
"audit_contract": map[string]any{"type": "object"},
"security_requirements": map[string]any{"type": "object"},
},
}
}

View File

@@ -0,0 +1,125 @@
package handler
import (
"encoding/json"
"io"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/require"
)
func TestRPManifestJSONIncludesIAMAndExternalKeyContract(t *testing.T) {
t.Setenv("BACKEND_PUBLIC_URL", "")
app := fiber.New()
h := NewRPManifestHandler()
app.Get("/.well-known/baron-rp-manifest.json", h.GetJSON)
req := httptest.NewRequest("GET", "/.well-known/baron-rp-manifest.json", nil)
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("X-Forwarded-Host", "sso.hmac.kr")
resp, err := app.Test(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, fiber.StatusOK, resp.StatusCode)
require.Contains(t, resp.Header.Get("Content-Type"), "application/json")
var body map[string]any
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, "https://sso.hmac.kr", body["issuer"])
oidc := body["oidc"].(map[string]any)
require.Equal(t, "https://sso.hmac.kr/.well-known/openid-configuration", oidc["discovery_url"])
require.Equal(t, "https://sso.hmac.kr/.well-known/jwks.json", oidc["jwks_url"])
iam := body["iam"].(map[string]any)
require.Equal(t, "ory-keto", iam["authorization_engine"])
require.Equal(t, "User:<baron_identity_id>", iam["subject_format"])
require.Contains(t, iam["target_object_patterns"].([]any), "RelyingParty:<client_id>")
require.Contains(t, iam["target_object_patterns"].([]any), "Tenant:<tenant_id>")
require.Contains(t, iam["target_object_patterns"].([]any), "Resource:<resource_type>:<resource_id>")
identity := body["identity_contract"].(map[string]any)
require.Equal(t, "X-Baron-External-Key", identity["external_key_header"])
require.Equal(t, true, identity["external_key_is_opaque"])
require.Equal(t, false, identity["raw_kratos_identity_id_exposed"])
require.Equal(t, "baron", identity["external_key_issuer"])
require.Equal(t, "baron_injected_header", identity["external_key_delivery"])
require.Equal(t, false, identity["rp_supplied_external_key_allowed"])
require.Equal(t, "rp_must_upsert_from_header_value", identity["rp_user_upsert_source"])
headers := body["trusted_headers"].(map[string]any)
require.Equal(t, "X-Baron-Subject", headers["subject"])
require.Equal(t, "X-Baron-External-Key", headers["external_key"])
require.Equal(t, "X-Baron-Client-ID", headers["client_id"])
security := body["security_requirements"].(map[string]any)
require.Equal(t, true, security["strip_external_identity_headers"])
require.Equal(t, false, security["backend_direct_exposure_allowed"])
audit := body["audit_contract"].(map[string]any)
require.Equal(t, "fail_closed_sync", audit["mutating_command_mode"])
require.Equal(t, "reject_mutation", audit["missing_audit_sink_behavior"])
require.Equal(t, "X-Request-Id", audit["correlation_header"])
require.Contains(t, audit["required_detail_fields"].([]any), "obj_id")
require.Contains(t, audit["required_detail_fields"].([]any), "client_id")
flow := body["object_lookup_flow"].(map[string]any)
require.Contains(t, flow["mermaid"].(string), "flowchart TD")
require.Contains(t, flow["mermaid"].(string), "obj_id")
aliasFlow := body["external_key_flow"].(map[string]any)
require.Contains(t, aliasFlow["mermaid"].(string), "Baron resolves internal identity")
require.Contains(t, aliasFlow["mermaid"].(string), "Baron injects X-Baron-External-Key")
require.Contains(t, aliasFlow["mermaid"].(string), "RP upserts local user")
require.NotContains(t, aliasFlow["mermaid"].(string), "RP creates external key")
}
func TestRPManifestSchemaRequiresLookupAndIdentityContracts(t *testing.T) {
app := fiber.New()
h := NewRPManifestHandler()
app.Get("/.well-known/baron-rp-manifest.schema.json", h.GetSchema)
resp, err := app.Test(httptest.NewRequest("GET", "/.well-known/baron-rp-manifest.schema.json", nil))
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, fiber.StatusOK, resp.StatusCode)
var body map[string]any
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
required := body["required"].([]any)
require.Contains(t, required, "iam")
require.Contains(t, required, "trusted_headers")
require.Contains(t, required, "identity_contract")
require.Contains(t, required, "object_lookup")
require.Contains(t, required, "audit_contract")
require.Contains(t, required, "object_lookup_flow")
require.Contains(t, required, "external_key_flow")
}
func TestRPManifestHTMLLinksMachineReadableManifest(t *testing.T) {
app := fiber.New()
h := NewRPManifestHandler()
app.Get("/.well-known/baron-rp-manifest", h.GetHTML)
resp, err := app.Test(httptest.NewRequest("GET", "/.well-known/baron-rp-manifest", nil))
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, fiber.StatusOK, resp.StatusCode)
require.Contains(t, resp.Header.Get("Content-Type"), "text/html")
raw, err := io.ReadAll(resp.Body)
require.NoError(t, err)
text := string(raw)
require.Contains(t, text, "/.well-known/baron-rp-manifest.json")
require.Contains(t, text, "X-Baron-External-Key")
require.Contains(t, text, "RelyingParty:&lt;client_id&gt;")
require.Contains(t, text, "```mermaid")
require.Contains(t, text, "audit_contract")
require.Contains(t, text, "Baron-issued alias")
require.Contains(t, text, "RP upserts local user")
}

View File

@@ -0,0 +1,154 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/repository"
"baron-sso-backend/internal/service"
"context"
"fmt"
"log/slog"
"maps"
"strings"
)
const tenantAccessCleanupClientPageSize = 500
func cleanupDeletedTenantReferences(ctx context.Context, hydra *service.HydraAdminService, consentRepo repository.ClientConsentRepository, ketoOutbox repository.KetoOutboxRepository, deletedTenantIDs []string) error {
if hydra == nil {
return nil
}
deletedTenantSet := make(map[string]struct{}, len(deletedTenantIDs))
for _, tenantID := range deletedTenantIDs {
tenantID = strings.TrimSpace(tenantID)
if tenantID == "" {
continue
}
deletedTenantSet[tenantID] = struct{}{}
}
if len(deletedTenantSet) == 0 {
return nil
}
for offset := 0; ; offset += tenantAccessCleanupClientPageSize {
clients, err := hydra.ListClients(ctx, tenantAccessCleanupClientPageSize, offset)
if err != nil {
return fmt.Errorf("failed to list hydra clients for tenant cleanup: %w", err)
}
for _, client := range clients {
beforeMetadata := maps.Clone(client.Metadata)
updatedMetadata, changed, removedOwnerTenantID := pruneDeletedTenantReferences(beforeMetadata, deletedTenantSet)
if !changed {
continue
}
updatedClient := client
updatedClient.Metadata = updatedMetadata
if _, err := hydra.UpdateClient(ctx, client.ClientID, updatedClient); err != nil {
return fmt.Errorf("failed to update hydra client %s during tenant cleanup: %w", client.ClientID, err)
}
if removedOwnerTenantID != "" {
if err := enqueueDeletedTenantRelyingPartyParentCleanup(ctx, ketoOutbox, client.ClientID, removedOwnerTenantID); err != nil {
return fmt.Errorf("failed to cleanup RP parent relation for client %s during tenant cleanup: %w", client.ClientID, err)
}
}
if tenantAccessPolicyChanged(beforeMetadata, updatedMetadata) {
if err := revokeClientConsentsForPolicyChange(ctx, hydra, consentRepo, client.ClientID); err != nil {
return fmt.Errorf("failed to revoke consent sessions for client %s during tenant cleanup: %w", client.ClientID, err)
}
}
}
if len(clients) < tenantAccessCleanupClientPageSize {
return nil
}
}
}
func pruneDeletedTenantReferences(metadata map[string]any, deletedTenantSet map[string]struct{}) (map[string]any, bool, string) {
if len(deletedTenantSet) == 0 {
return metadata, false, ""
}
ownerTenantID := normalizeMetadataString(metadata["tenant_id"])
_, ownerDeleted := deletedTenantSet[ownerTenantID]
allowedTenants := normalizeMetadataStringSlice(metadata[clientAllowedTenantsKey])
filtered := make([]string, 0, len(allowedTenants))
for _, tenantID := range allowedTenants {
if _, ok := deletedTenantSet[tenantID]; ok {
continue
}
filtered = append(filtered, tenantID)
}
allowedChanged := len(filtered) != len(allowedTenants)
if !ownerDeleted && !allowedChanged {
return metadata, false, ""
}
updated := maps.Clone(metadata)
if ownerDeleted {
delete(updated, "tenant_id")
}
if len(filtered) == 0 {
delete(updated, clientAllowedTenantsKey)
updated[clientTenantAccessRestrictedKey] = false
return updated, true, ownerTenantID
}
updated[clientAllowedTenantsKey] = uniqueSortedStrings(filtered)
updated[clientTenantAccessRestrictedKey] = true
return updated, true, ownerTenantID
}
func enqueueDeletedTenantRelyingPartyParentCleanup(ctx context.Context, ketoOutbox repository.KetoOutboxRepository, clientID, tenantID string) error {
if ketoOutbox == nil {
return nil
}
clientID = strings.TrimSpace(clientID)
tenantID = strings.TrimSpace(tenantID)
if clientID == "" || tenantID == "" {
return nil
}
return ketoOutbox.Create(ctx, &domain.KetoOutbox{
Namespace: "RelyingParty",
Object: clientID,
Relation: "parents",
Subject: "Tenant:" + tenantID,
Action: domain.KetoOutboxActionDelete,
})
}
func revokeClientConsentsForPolicyChange(ctx context.Context, hydra *service.HydraAdminService, consentRepo repository.ClientConsentRepository, clientID string) error {
if consentRepo == nil || hydra == nil {
return nil
}
subjects, err := consentRepo.ListSubjectsByClient(ctx, clientID)
if err != nil {
return err
}
for _, subject := range subjects {
subject = strings.TrimSpace(subject)
if subject == "" {
continue
}
if err := hydra.RevokeConsentSessions(ctx, subject, clientID); err != nil {
return err
}
}
return consentRepo.DeleteByClient(ctx, clientID)
}
func logTenantCleanupFailure(err error, deletedTenantIDs []string) {
if err == nil {
return
}
slog.Error("Failed to cleanup RP tenant restrictions after tenant deletion", "tenant_ids", deletedTenantIDs, "error", err)
}

View File

@@ -0,0 +1,178 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/repository"
"baron-sso-backend/internal/service"
"context"
"encoding/json"
"net/http"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
)
func TestPruneDeletedTenantReferences_PreservesOtherAllowedTenants(t *testing.T) {
metadata := map[string]any{
clientTenantAccessRestrictedKey: true,
clientAllowedTenantsKey: []string{"keep-tenant", "deleted-tenant"},
"tenant_id": "deleted-tenant",
}
updated, changed, removedOwnerTenantID := pruneDeletedTenantReferences(metadata, map[string]struct{}{
"deleted-tenant": {},
})
require.True(t, changed)
assert.Equal(t, "deleted-tenant", removedOwnerTenantID)
assert.Equal(t, true, updated[clientTenantAccessRestrictedKey])
assert.Equal(t, []string{"keep-tenant"}, updated[clientAllowedTenantsKey])
_, exists := updated["tenant_id"]
assert.False(t, exists)
}
func TestPruneDeletedTenantReferences_DisablesRestrictionWhenLastTenantRemoved(t *testing.T) {
metadata := map[string]any{
clientTenantAccessRestrictedKey: true,
clientAllowedTenantsKey: []string{"deleted-tenant"},
"tenant_id": "deleted-tenant",
}
updated, changed, removedOwnerTenantID := pruneDeletedTenantReferences(metadata, map[string]struct{}{
"deleted-tenant": {},
})
require.True(t, changed)
assert.Equal(t, "deleted-tenant", removedOwnerTenantID)
assert.Equal(t, false, updated[clientTenantAccessRestrictedKey])
_, exists := updated[clientAllowedTenantsKey]
assert.False(t, exists)
_, exists = updated["tenant_id"]
assert.False(t, exists)
}
func TestCleanupDeletedTenantReferences_PrunesClientsAndRevokesConsents(t *testing.T) {
var (
mu sync.Mutex
page0Called bool
updated = map[string]map[string]any{}
revokes []string
)
transport := roundTripFunc(func(req *http.Request) (*http.Response, error) {
mu.Lock()
defer mu.Unlock()
switch {
case req.Method == http.MethodGet && req.URL.Path == "/clients":
switch req.URL.Query().Get("offset") {
case "":
page0Called = true
return httpJSONAny(req, http.StatusOK, []domain.HydraClient{
{
ClientID: "client-keep",
Metadata: map[string]any{
clientTenantAccessRestrictedKey: true,
clientAllowedTenantsKey: []string{"keep-tenant", "deleted-tenant"},
"tenant_id": "deleted-tenant",
},
},
{
ClientID: "client-drop",
Metadata: map[string]any{
clientTenantAccessRestrictedKey: true,
clientAllowedTenantsKey: []string{"deleted-tenant"},
"tenant_id": "deleted-tenant",
},
},
}), nil
default:
return httpResponse(req, http.StatusBadRequest, "unexpected offset"), nil
}
case req.Method == http.MethodPut && strings.HasPrefix(req.URL.Path, "/clients/"):
var client domain.HydraClient
require.NoError(t, json.NewDecoder(req.Body).Decode(&client))
updated[client.ClientID] = client.Metadata
return httpJSONAny(req, http.StatusOK, client), nil
case req.Method == http.MethodDelete && req.URL.Path == "/oauth2/auth/sessions/consent":
revokes = append(revokes, req.URL.Query().Get("subject")+"|"+req.URL.Query().Get("client"))
return httpResponse(req, http.StatusNoContent, ""), nil
default:
return httpResponse(req, http.StatusNotFound, "unexpected request"), nil
}
})
hydra := &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: transport},
}
consentRepo := &mockConsentRepo{
consents: []domain.ClientConsent{
{ClientID: "client-keep", Subject: "user-a"},
{ClientID: "client-drop", Subject: "user-b"},
},
}
outbox := &tenantCleanupMockKetoOutboxRepository{}
err := cleanupDeletedTenantReferences(context.Background(), hydra, consentRepo, outbox, []string{"deleted-tenant"})
require.NoError(t, err)
assert.True(t, page0Called)
assert.Equal(t, map[string]any{
clientTenantAccessRestrictedKey: true,
clientAllowedTenantsKey: []any{"keep-tenant"},
}, updated["client-keep"])
assert.Equal(t, map[string]any{
clientTenantAccessRestrictedKey: false,
}, updated["client-drop"])
assert.ElementsMatch(t, []string{"user-a|client-keep", "user-b|client-drop"}, revokes)
assert.Empty(t, consentRepo.consents)
require.Len(t, outbox.entries, 2)
assert.ElementsMatch(t, []string{"client-keep", "client-drop"}, []string{outbox.entries[0].Object, outbox.entries[1].Object})
for _, entry := range outbox.entries {
assert.Equal(t, "RelyingParty", entry.Namespace)
assert.Equal(t, "parents", entry.Relation)
assert.Equal(t, "Tenant:deleted-tenant", entry.Subject)
assert.Equal(t, domain.KetoOutboxActionDelete, entry.Action)
}
}
type tenantCleanupMockKetoOutboxRepository struct {
entries []domain.KetoOutbox
}
var _ repository.KetoOutboxRepository = (*tenantCleanupMockKetoOutboxRepository)(nil)
func (m *tenantCleanupMockKetoOutboxRepository) Create(ctx context.Context, entry *domain.KetoOutbox) error {
if entry == nil {
return nil
}
m.entries = append(m.entries, *entry)
return nil
}
func (m *tenantCleanupMockKetoOutboxRepository) CreateWithTx(tx *gorm.DB, entry *domain.KetoOutbox) error {
return m.Create(context.Background(), entry)
}
func (m *tenantCleanupMockKetoOutboxRepository) FindPending(ctx context.Context, limit int) ([]domain.KetoOutbox, error) {
return nil, nil
}
func (m *tenantCleanupMockKetoOutboxRepository) ListCurrentBySubject(ctx context.Context, namespace, subject string) ([]domain.KetoOutbox, error) {
return nil, nil
}
func (m *tenantCleanupMockKetoOutboxRepository) UpdateStatus(ctx context.Context, id string, status string, retryCount int, lastError string) error {
return nil
}
func (m *tenantCleanupMockKetoOutboxRepository) MarkProcessed(ctx context.Context, id string) error {
return nil
}

View File

@@ -0,0 +1,155 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"context"
"errors"
"fmt"
"strings"
"github.com/google/uuid"
)
func representativeTenantIDFromTraits(traits map[string]any) string {
if value := tenantClaimString(traits, "tenant_id"); value != "" {
return value
}
if value := tenantClaimString(traits, "primaryTenantId"); value != "" {
return value
}
if metadata, ok := traits["metadata"].(map[string]any); ok {
if value := tenantClaimString(metadata, "primaryTenantId"); value != "" {
return value
}
}
appointments := tenantAssignmentAppointmentsFromTraits(traits)
for _, appointment := range appointments {
if tenantAssignmentBool(appointment, "isPrimary", "primary", "representative", "isRepresentative") {
if value := tenantAssignmentTenantID(appointment); value != "" {
return value
}
}
}
for _, appointment := range appointments {
if value := tenantAssignmentTenantID(appointment); value != "" {
return value
}
}
for _, tenantID := range tenantNamespaceIDsFromTraits(traits) {
return tenantID
}
return ""
}
func joinedTenantIDsFromTraits(traits map[string]any, representativeTenantID string) []string {
values := make([]string, 0)
if representativeTenantID != "" {
values = append(values, representativeTenantID)
}
if value := tenantClaimString(traits, "tenant_id"); value != "" {
values = append(values, value)
}
for _, appointment := range tenantAssignmentAppointmentsFromTraits(traits) {
if value := tenantAssignmentTenantID(appointment); value != "" {
values = append(values, value)
}
}
values = append(values, tenantNamespaceIDsFromTraits(traits)...)
return uniqueSortedStrings(values)
}
func tenantAssignmentAppointmentsFromTraits(traits map[string]any) []map[string]any {
raw := rawAdditionalAppointments(traits)
switch values := raw.(type) {
case []any:
appointments := make([]map[string]any, 0, len(values))
for _, item := range values {
if appointment, ok := item.(map[string]any); ok {
appointments = append(appointments, appointment)
}
}
return appointments
case []map[string]any:
return values
default:
return nil
}
}
func tenantAssignmentTenantID(appointment map[string]any) string {
for _, key := range []string{"tenantId", "tenant_id"} {
if value := tenantClaimString(appointment, key); value != "" {
return value
}
}
return ""
}
func tenantAssignmentBool(values map[string]any, keys ...string) bool {
for _, key := range keys {
raw, ok := values[key]
if !ok || raw == nil {
continue
}
switch value := raw.(type) {
case bool:
if value {
return true
}
case string:
normalized := strings.ToLower(strings.TrimSpace(value))
if normalized == "true" || normalized == "1" || normalized == "yes" {
return true
}
}
}
return false
}
func tenantNamespaceIDsFromTraits(traits map[string]any) []string {
if traits == nil {
return nil
}
ids := make([]string, 0)
for key, value := range traits {
if key == "" || key == "metadata" {
continue
}
switch value.(type) {
case map[string]any:
ids = append(ids, key)
}
}
return uniqueSortedStrings(ids)
}
func createPersonalTenantForUser(ctx context.Context, tenantService service.TenantService, email string) (*domain.Tenant, error) {
if tenantService == nil {
return nil, errors.New("tenant service unavailable")
}
normalizedEmail := strings.ToLower(strings.TrimSpace(email))
if normalizedEmail == "" {
normalizedEmail = "user"
}
slug := "personal-" + strings.ReplaceAll(uuid.NewString(), "-", "")
tenant, err := tenantService.RegisterTenant(
ctx,
fmt.Sprintf("Personal - %s", normalizedEmail),
slug,
domain.TenantTypePersonal,
"Automatically provisioned personal tenant",
nil,
nil,
"",
)
if err != nil {
return nil, err
}
if tenant == nil {
return nil, errors.New("personal tenant not created")
}
return tenant, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,159 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/testsupport"
"bytes"
"context"
"encoding/json"
"log"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/testcontainers/testcontainers-go"
postgres_module "github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
gorm_postgres "gorm.io/driver/postgres"
"gorm.io/gorm"
)
func newTenantHandlerSeedDeleteDB(t *testing.T) *gorm.DB {
t.Helper()
if !testsupport.DockerAvailable() {
t.Skip("Docker provider is unavailable in this environment")
}
ctx := context.Background()
postgresContainer, err := postgres_module.Run(ctx,
"postgres:16-alpine",
postgres_module.WithDatabase("testdb"),
postgres_module.WithUsername("user"),
postgres_module.WithPassword("password"),
testcontainers.WithWaitStrategy(
wait.ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(30*time.Second)),
)
if err != nil {
t.Fatalf("failed to start postgres container: %v", err)
}
t.Cleanup(func() {
if err := postgresContainer.Terminate(ctx); err != nil {
log.Printf("failed to terminate postgres container: %v", err)
}
})
connStr, err := postgresContainer.ConnectionString(ctx, "sslmode=disable")
if err != nil {
t.Fatalf("failed to get postgres connection string: %v", err)
}
db, err := gorm.Open(gorm_postgres.Open(connStr), &gorm.Config{})
if err != nil {
t.Fatalf("failed to open postgres connection: %v", err)
}
if err := db.AutoMigrate(&domain.Tenant{}); err != nil {
t.Fatalf("failed to migrate tenants: %v", err)
}
return db
}
func setSeedTenantCSVForDeleteGuard(t *testing.T, slug string) {
t.Helper()
dir := t.TempDir()
path := filepath.Join(dir, "seed-tenant.csv")
csv := "name,type,parent_tenant_slug,slug,memo,email_domain\n" +
"Protected,COMPANY_GROUP,," + slug + ",Protected seed,\n"
if err := os.WriteFile(path, []byte(csv), 0o600); err != nil {
t.Fatalf("failed to write seed csv: %v", err)
}
t.Setenv("SEED_TENANT_CSV_PATH", path)
}
func TestTenantHandlerDeleteTenantRejectsSeedTenant(t *testing.T) {
setSeedTenantCSVForDeleteGuard(t, "protected-root")
db := newTenantHandlerSeedDeleteDB(t)
tenant := domain.Tenant{
ID: "00000000-0000-0000-0000-000000000001",
Name: "Protected",
Slug: "protected-root",
Type: domain.TenantTypeCompanyGroup,
Status: domain.TenantStatusActive,
}
if err := db.Create(&tenant).Error; err != nil {
t.Fatalf("failed to create tenant: %v", err)
}
app := fiber.New()
app.Delete("/tenants/:id", (&TenantHandler{DB: db}).DeleteTenant)
req := httptest.NewRequest(http.MethodDelete, "/tenants/"+tenant.ID, nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
if resp.StatusCode != http.StatusConflict {
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusConflict)
}
var count int64
if err := db.Model(&domain.Tenant{}).Where("id = ?", tenant.ID).Count(&count).Error; err != nil {
t.Fatalf("count tenant: %v", err)
}
if count != 1 {
t.Fatalf("seed tenant count = %d, want 1", count)
}
}
func TestTenantHandlerDeleteTenantsBulkRejectsSeedTenant(t *testing.T) {
setSeedTenantCSVForDeleteGuard(t, "protected-root")
db := newTenantHandlerSeedDeleteDB(t)
seed := domain.Tenant{
ID: "00000000-0000-0000-0000-000000000011",
Name: "Protected",
Slug: "protected-root",
Type: domain.TenantTypeCompanyGroup,
Status: domain.TenantStatusActive,
}
normal := domain.Tenant{
ID: "00000000-0000-0000-0000-000000000012",
Name: "Normal",
Slug: "normal",
Type: domain.TenantTypeCompany,
Status: domain.TenantStatusActive,
}
if err := db.Create(&seed).Error; err != nil {
t.Fatalf("failed to create seed tenant: %v", err)
}
if err := db.Create(&normal).Error; err != nil {
t.Fatalf("failed to create normal tenant: %v", err)
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Delete("/tenants/bulk", (&TenantHandler{DB: db}).DeleteTenantsBulk)
body, _ := json.Marshal(map[string][]string{"ids": {seed.ID, normal.ID}})
req := httptest.NewRequest(http.MethodDelete, "/tenants/bulk", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
if resp.StatusCode != http.StatusConflict {
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusConflict)
}
var count int64
if err := db.Model(&domain.Tenant{}).Where("id IN ?", []string{seed.ID, normal.ID}).Count(&count).Error; err != nil {
t.Fatalf("count tenants: %v", err)
}
if count != 2 {
t.Fatalf("remaining tenant count = %d, want 2", count)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,133 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"github.com/gofiber/fiber/v2"
)
type UserGroupHandler struct {
Service service.UserGroupService
}
func NewUserGroupHandler(s service.UserGroupService) *UserGroupHandler {
return &UserGroupHandler{Service: s}
}
func (h *UserGroupHandler) List(c *fiber.Ctx) error {
tenantID := c.Params("tenantId")
groups, err := h.Service.List(c.Context(), tenantID)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.JSON(groups)
}
func (h *UserGroupHandler) Create(c *fiber.Ctx) error {
tenantID := c.Params("tenantId")
var req domain.GroupCreateRequest
if err := c.BodyParser(&req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid body"})
}
group, err := h.Service.Create(c.Context(), tenantID, req.ParentID, req.Name, req.Description, req.UnitType)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.Status(fiber.StatusCreated).JSON(group)
}
func (h *UserGroupHandler) Get(c *fiber.Ctx) error {
id := c.Params("id")
group, err := h.Service.Get(c.Context(), id)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, "failed to get group: "+err.Error())
}
return c.JSON(group)
}
func (h *UserGroupHandler) Update(c *fiber.Ctx) error {
tenantID := c.Params("tenantId")
groupID := c.Params("id")
var req domain.GroupCreateRequest // Using create request for update fields
if err := c.BodyParser(&req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid body"})
}
group, err := h.Service.Update(c.Context(), tenantID, groupID, req.Name, req.Description, req.UnitType, req.ParentID)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(group)
}
func (h *UserGroupHandler) Delete(c *fiber.Ctx) error {
tenantID := c.Params("tenantId")
groupID := c.Params("id")
if err := h.Service.Delete(c.Context(), tenantID, groupID); err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.SendStatus(fiber.StatusNoContent)
}
func (h *UserGroupHandler) AddMember(c *fiber.Ctx) error {
groupID := c.Params("id")
var req struct {
UserID string `json:"userId"`
}
if err := c.BodyParser(&req); err != nil {
return errorJSON(c, fiber.StatusBadRequest, "userId is required")
}
if err := h.Service.AddMember(c.Context(), groupID, req.UserID); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.SendStatus(fiber.StatusOK)
}
func (h *UserGroupHandler) RemoveMember(c *fiber.Ctx) error {
groupID := c.Params("id")
userID := c.Params("userId")
if err := h.Service.RemoveMember(c.Context(), groupID, userID); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.SendStatus(fiber.StatusNoContent)
}
func (h *UserGroupHandler) AssignRole(c *fiber.Ctx) error {
groupID := c.Params("id")
var req struct {
TenantID string `json:"tenantId"`
Relation string `json:"relation"`
}
if err := c.BodyParser(&req); err != nil {
return errorJSON(c, fiber.StatusBadRequest, "invalid body")
}
if err := h.Service.AssignRoleToTenant(c.Context(), groupID, req.TenantID, req.Relation); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.SendStatus(fiber.StatusOK)
}
func (h *UserGroupHandler) ListRoles(c *fiber.Ctx) error {
groupID := c.Params("id")
roles, err := h.Service.ListRoles(c.Context(), groupID)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.JSON(roles)
}
func (h *UserGroupHandler) RemoveRole(c *fiber.Ctx) error {
groupID := c.Params("id")
tenantID := c.Params("tenantId")
relation := c.Params("relation")
if err := h.Service.RemoveRoleFromTenant(c.Context(), groupID, tenantID, relation); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.SendStatus(fiber.StatusNoContent)
}

View File

@@ -0,0 +1,155 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
// --- Mocks ---
type MockUserGroupService struct {
mock.Mock
}
func (m *MockUserGroupService) Create(ctx context.Context, tenantID string, parentID *string, name, description, unitType string) (*domain.UserGroup, error) {
args := m.Called(ctx, tenantID, parentID, name, description, unitType)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.UserGroup), args.Error(1)
}
func (m *MockUserGroupService) Update(ctx context.Context, tenantID, groupID string, name, description, unitType string, parentID *string) (*domain.UserGroup, error) {
args := m.Called(ctx, tenantID, groupID, name, description, unitType, parentID)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.UserGroup), args.Error(1)
}
func (m *MockUserGroupService) Delete(ctx context.Context, tenantID, groupID string) error {
return m.Called(ctx, tenantID, groupID).Error(0)
}
func (m *MockUserGroupService) Get(ctx context.Context, id string) (*domain.UserGroup, error) {
args := m.Called(ctx, id)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.UserGroup), args.Error(1)
}
func (m *MockUserGroupService) List(ctx context.Context, tenantID string) ([]domain.UserGroup, error) {
args := m.Called(ctx, tenantID)
return args.Get(0).([]domain.UserGroup), args.Error(1)
}
func (m *MockUserGroupService) SetWorksmobileSyncer(syncer service.WorksmobileSyncer) {}
func (m *MockUserGroupService) AddMember(ctx context.Context, groupID, userID string) error {
return m.Called(ctx, groupID, userID).Error(0)
}
func (m *MockUserGroupService) RemoveMember(ctx context.Context, groupID, userID string) error {
return m.Called(ctx, groupID, userID).Error(0)
}
func (m *MockUserGroupService) ListRoles(ctx context.Context, groupID string) ([]domain.GroupRole, error) {
args := m.Called(ctx, groupID)
return args.Get(0).([]domain.GroupRole), args.Error(1)
}
func (m *MockUserGroupService) AssignRoleToTenant(ctx context.Context, groupID, tenantID, relation string) error {
return m.Called(ctx, groupID, tenantID, relation).Error(0)
}
func (m *MockUserGroupService) RemoveRoleFromTenant(ctx context.Context, groupID, tenantID, relation string) error {
return m.Called(ctx, groupID, tenantID, relation).Error(0)
}
// --- Tests ---
func TestUserGroupHandler_List(t *testing.T) {
mockSvc := new(MockUserGroupService)
h := NewUserGroupHandler(mockSvc)
app := fiber.New()
app.Get("/tenants/:tenantId/user-groups", h.List)
tenantID := "t1"
groups := []domain.UserGroup{{ID: "g1", Name: "Group 1"}}
mockSvc.On("List", mock.Anything, tenantID).Return(groups, nil)
req := httptest.NewRequest("GET", "/tenants/t1/user-groups", nil)
resp, _ := app.Test(req)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var result []domain.UserGroup
json.NewDecoder(resp.Body).Decode(&result)
assert.Len(t, result, 1)
assert.Equal(t, "Group 1", result[0].Name)
}
func TestUserGroupHandler_Create(t *testing.T) {
mockSvc := new(MockUserGroupService)
h := NewUserGroupHandler(mockSvc)
app := fiber.New()
app.Post("/tenants/:tenantId/user-groups", h.Create)
body, _ := json.Marshal(map[string]string{"name": "New Group"})
mockSvc.On("Create", mock.Anything, "t1", mock.Anything, "New Group", mock.Anything, mock.Anything).Return(&domain.UserGroup{ID: "g1", Name: "New Group"}, nil)
req := httptest.NewRequest("POST", "/tenants/t1/user-groups", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req)
assert.Equal(t, http.StatusCreated, resp.StatusCode)
}
func TestUserGroupHandler_AddMember(t *testing.T) {
mockSvc := new(MockUserGroupService)
h := NewUserGroupHandler(mockSvc)
app := fiber.New()
app.Post("/user-groups/:id/members", h.AddMember)
groupID := "g1"
userID := "u1"
body, _ := json.Marshal(map[string]string{"userId": userID})
mockSvc.On("AddMember", mock.Anything, groupID, userID).Return(nil)
req := httptest.NewRequest("POST", "/user-groups/g1/members", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req)
assert.Equal(t, http.StatusOK, resp.StatusCode)
}
func TestUserGroupHandler_AssignRole(t *testing.T) {
mockSvc := new(MockUserGroupService)
h := NewUserGroupHandler(mockSvc)
app := fiber.New()
app.Post("/user-groups/:id/roles", h.AssignRole)
groupID := "g1"
targetTenantID := "t2"
relation := "manage"
body, _ := json.Marshal(map[string]string{"tenantId": targetTenantID, "relation": relation})
mockSvc.On("AssignRoleToTenant", mock.Anything, groupID, targetTenantID, relation).Return(nil)
req := httptest.NewRequest("POST", "/user-groups/g1/roles", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req)
assert.Equal(t, http.StatusOK, resp.StatusCode)
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,68 @@
package handler
import (
"baron-sso-backend/internal/domain"
"bytes"
"encoding/json"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func TestUserHandler_BulkCreateUsers_UUIDImportPolicy(t *testing.T) {
tests := []struct {
name string
field string
}{
{name: "id 필드 차단", field: "id"},
{name: "uuid 필드 차단", field: "uuid"},
{name: "userId 필드 차단", field: "userId"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
app := fiber.New()
mockKratos := new(MockKratosAdmin)
mockOry := new(MockOryProvider)
h := &UserHandler{
KratosAdmin: mockKratos,
OryProvider: mockOry,
}
app.Post("/users/bulk", h.BulkCreateUsers)
mockOry.On("GetPasswordPolicy").Return(&domain.PasswordPolicy{MinLength: 8}, nil).Once()
payload := map[string]any{
"users": []map[string]any{
{
"email": "uuid-import@test.com",
"name": "UUID Import User",
tt.field: "550e8400-e29b-41d4-a716-446655440000",
"metadata": map[string]any{},
},
},
}
body, _ := json.Marshal(payload)
req := httptest.NewRequest("POST", "/users/bulk", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req)
assert.Equal(t, 200, resp.StatusCode)
var result struct {
Results []bulkUserResult `json:"results"`
}
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&result))
assert.Len(t, result.Results, 1)
assert.False(t, result.Results[0].Success)
assert.Contains(t, result.Results[0].Message, "사용자 UUID 가져오기는 지원하지 않습니다")
mockOry.AssertExpectations(t)
mockKratos.AssertNotCalled(t, "FindIdentityIDByIdentifier", mock.Anything, mock.Anything)
})
}
}

View File

@@ -0,0 +1,16 @@
package handler
import (
"baron-sso-backend/internal/repository"
"context"
"log/slog"
)
func markUserProjectionFailed(ctx context.Context, repo repository.UserProjectionRepository, syncErr error) {
if repo == nil || syncErr == nil {
return
}
if err := repo.MarkFailed(ctx, syncErr); err != nil {
slog.Error("Failed to mark user projection as failed", "syncError", syncErr, "error", err)
}
}

View File

@@ -0,0 +1,217 @@
package handler
import (
"baron-sso-backend/internal/service"
"bytes"
"context"
"encoding/csv"
"errors"
"log/slog"
"strings"
"github.com/gofiber/fiber/v2"
)
type WorksmobileHandler struct {
Service service.WorksmobileAdminService
}
func NewWorksmobileHandler(svc service.WorksmobileAdminService) *WorksmobileHandler {
return &WorksmobileHandler{Service: svc}
}
func (h *WorksmobileHandler) GetOverview(c *fiber.Ctx) error {
overview, err := h.Service.GetTenantOverview(c.Context(), strings.TrimSpace(c.Params("tenantId")))
if err != nil {
return worksmobileGuardError(c, err, "get_overview")
}
if !worksmobileOverviewAllowed(overview) {
return errorJSON(c, fiber.StatusNotFound, "worksmobile is only available for hanmac-family root tenant")
}
return c.JSON(overview)
}
func (h *WorksmobileHandler) GetComparison(c *fiber.Ctx) error {
includeMatched := strings.EqualFold(strings.TrimSpace(c.Query("includeMatched")), "true")
comparison, err := h.Service.GetComparison(c.Context(), strings.TrimSpace(c.Params("tenantId")), includeMatched)
if err != nil {
return worksmobileGuardError(c, err, "get_comparison")
}
return c.JSON(comparison)
}
func (h *WorksmobileHandler) OAuthCallback(c *fiber.Ctx) error {
return c.Type("html").SendString("<!doctype html><html><body>Worksmobile OAuth callback reachable</body></html>")
}
func (h *WorksmobileHandler) BackfillDryRun(c *fiber.Ctx) error {
result, err := h.Service.EnqueueBackfillDryRun(c.Context(), strings.TrimSpace(c.Params("tenantId")))
if err != nil {
return worksmobileGuardError(c, err, "backfill_dry_run")
}
return c.JSON(result)
}
func (h *WorksmobileHandler) SyncOrgUnit(c *fiber.Ctx) error {
orgUnitID := strings.TrimSpace(c.Params("orgUnitId"))
job, err := h.Service.EnqueueOrgUnitSync(c.Context(), strings.TrimSpace(c.Params("tenantId")), orgUnitID)
if err != nil {
return worksmobileGuardError(c, err, "sync_orgunit", "org_unit_id", orgUnitID)
}
return c.Status(fiber.StatusAccepted).JSON(job)
}
func (h *WorksmobileHandler) DeleteOrgUnit(c *fiber.Ctx) error {
orgUnitID := strings.TrimSpace(c.Params("orgUnitId"))
job, err := h.Service.EnqueueOrgUnitDelete(c.Context(), strings.TrimSpace(c.Params("tenantId")), orgUnitID)
if err != nil {
return worksmobileGuardError(c, err, "delete_orgunit", "org_unit_id", orgUnitID)
}
return c.Status(fiber.StatusAccepted).JSON(job)
}
func (h *WorksmobileHandler) SyncUser(c *fiber.Ctx) error {
userID := strings.TrimSpace(c.Params("userId"))
credentialRequest, err := parseWorksmobileCredentialRequest(c)
if err != nil {
return errorJSON(c, fiber.StatusBadRequest, err.Error())
}
job, err := h.Service.EnqueueUserSync(
c.Context(),
strings.TrimSpace(c.Params("tenantId")),
userID,
credentialRequest.CredentialBatchID,
credentialRequest.InitialPassword,
)
if err != nil {
return worksmobileGuardError(c, err, "sync_user", "user_id", userID)
}
return c.Status(fiber.StatusAccepted).JSON(job)
}
func (h *WorksmobileHandler) ResetUserPassword(c *fiber.Ctx) error {
userID := strings.TrimSpace(c.Params("userId"))
credentialBatchID, err := parseWorksmobileCredentialBatchID(c)
if err != nil {
return errorJSON(c, fiber.StatusBadRequest, err.Error())
}
job, err := h.Service.EnqueueUserPasswordReset(c.Context(), strings.TrimSpace(c.Params("tenantId")), userID, credentialBatchID)
if err != nil {
return worksmobileGuardError(c, err, "reset_user_password", "user_id", userID)
}
return c.Status(fiber.StatusAccepted).JSON(job)
}
func (h *WorksmobileHandler) RetryJob(c *fiber.Ctx) error {
jobID := strings.TrimSpace(c.Params("jobId"))
job, err := h.Service.RetryJob(c.Context(), strings.TrimSpace(c.Params("tenantId")), jobID)
if err != nil {
return worksmobileGuardError(c, err, "retry_job", "job_id", jobID)
}
return c.JSON(job)
}
func (h *WorksmobileHandler) DeletePendingJobs(c *fiber.Ctx) error {
result, err := h.Service.DeletePendingJobs(c.Context(), strings.TrimSpace(c.Params("tenantId")))
if err != nil {
return worksmobileGuardError(c, err, "delete_pending_jobs")
}
return c.JSON(result)
}
func (h *WorksmobileHandler) DownloadInitialPasswordsCSV(c *fiber.Ctx) error {
credentials, err := h.Service.ListInitialPasswordCredentials(c.Context(), strings.TrimSpace(c.Params("tenantId")), strings.TrimSpace(c.Query("batchId")))
if err != nil {
return worksmobileGuardError(c, err, "download_initial_passwords")
}
var buf bytes.Buffer
writer := csv.NewWriter(&buf)
if err := writer.Write([]string{"email", "name", "primaryLeafOrgName", "initialPassword", "status", "lastError"}); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
for _, credential := range credentials {
if err := writer.Write([]string{credential.Email, credential.Name, credential.PrimaryLeafOrgName, credential.InitialPassword, credential.Status, credential.LastError}); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
}
writer.Flush()
if err := writer.Error(); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
c.Set(fiber.HeaderContentType, "text/csv; charset=utf-8")
c.Set(fiber.HeaderContentDisposition, `attachment; filename="worksmobile_initial_passwords.csv"`)
return c.Send(buf.Bytes())
}
func (h *WorksmobileHandler) ListCredentialBatches(c *fiber.Ctx) error {
batches, err := h.Service.ListCredentialBatches(c.Context(), strings.TrimSpace(c.Params("tenantId")))
if err != nil {
return worksmobileGuardError(c, err, "list_credential_batches")
}
return c.JSON(batches)
}
func (h *WorksmobileHandler) DeleteCredentialBatchPasswords(c *fiber.Ctx) error {
batchID := strings.TrimSpace(c.Params("batchId"))
batch, err := h.Service.DeleteCredentialBatchPasswords(c.Context(), strings.TrimSpace(c.Params("tenantId")), batchID)
if err != nil {
return worksmobileGuardError(c, err, "delete_credential_batch_passwords", "batch_id", batchID)
}
return c.JSON(batch)
}
type worksmobileCredentialBatchRequest struct {
CredentialBatchID string `json:"credentialBatchId"`
InitialPassword string `json:"initialPassword"`
}
func parseWorksmobileCredentialBatchID(c *fiber.Ctx) (string, error) {
req, err := parseWorksmobileCredentialRequest(c)
return req.CredentialBatchID, err
}
func parseWorksmobileCredentialRequest(c *fiber.Ctx) (worksmobileCredentialBatchRequest, error) {
batchID := strings.TrimSpace(c.Query("credentialBatchId"))
req := worksmobileCredentialBatchRequest{CredentialBatchID: batchID}
if len(bytes.TrimSpace(c.Body())) == 0 {
return req, nil
}
if err := c.BodyParser(&req); err != nil {
return worksmobileCredentialBatchRequest{}, err
}
req.InitialPassword = strings.TrimSpace(req.InitialPassword)
if bodyBatchID := strings.TrimSpace(req.CredentialBatchID); bodyBatchID != "" {
req.CredentialBatchID = bodyBatchID
return req, nil
}
req.CredentialBatchID = batchID
return req, nil
}
func worksmobileOverviewAllowed(overview service.WorksmobileTenantOverview) bool {
return overview.Tenant.Slug == service.HanmacFamilyTenantSlug && overview.Tenant.ParentID == nil
}
func worksmobileGuardError(c *fiber.Ctx, err error, operation string, attrs ...any) error {
if err == nil {
return nil
}
logAttrs := []any{
"operation", operation,
"tenant_id", strings.TrimSpace(c.Params("tenantId")),
"path", c.Path(),
"error", err,
}
logAttrs = append(logAttrs, attrs...)
if errors.Is(err, context.Canceled) {
slog.Warn("worksmobile admin operation failed", logAttrs...)
return errorJSON(c, fiber.StatusRequestTimeout, err.Error())
}
slog.Error("worksmobile admin operation failed", logAttrs...)
if strings.Contains(err.Error(), "hanmac-family root") {
return errorJSON(c, fiber.StatusNotFound, err.Error())
}
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}

View File

@@ -0,0 +1,267 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes"
"context"
"errors"
"io"
"log/slog"
"net/http/httptest"
"strings"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/require"
)
func TestWorksmobileHandlerRejectsNonHanmacTenant(t *testing.T) {
h := NewWorksmobileHandler(&fakeWorksmobileAdminService{
overview: service.WorksmobileTenantOverview{
Tenant: domain.Tenant{ID: "tenant-1", Slug: "other"},
},
})
app := fiber.New()
app.Get("/tenants/:tenantId/worksmobile", h.GetOverview)
resp, err := app.Test(httptest.NewRequest("GET", "/tenants/tenant-1/worksmobile", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
}
func TestWorksmobileHandlerReturnsOverviewForHanmacTenant(t *testing.T) {
h := NewWorksmobileHandler(&fakeWorksmobileAdminService{
overview: service.WorksmobileTenantOverview{
Tenant: domain.Tenant{ID: "hanmac-id", Slug: "hanmac-family"},
Config: service.WorksmobileConfigSummary{
Enabled: true,
},
},
})
app := fiber.New()
app.Get("/tenants/:tenantId/worksmobile", h.GetOverview)
resp, err := app.Test(httptest.NewRequest("GET", "/tenants/hanmac-id/worksmobile", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
}
func TestWorksmobileHandlerDownloadsInitialPasswordCSV(t *testing.T) {
h := NewWorksmobileHandler(&fakeWorksmobileAdminService{
credentials: []service.WorksmobileInitialPasswordCredential{
{
Email: "user@hanmaceng.co.kr",
Name: "홍길동",
PrimaryLeafOrgName: "인재성장",
InitialPassword: "Aa1!Aa1!Aa1!Aa1!",
Status: "processed",
},
},
})
app := fiber.New()
app.Get("/tenants/:tenantId/worksmobile/initial-passwords.csv", h.DownloadInitialPasswordsCSV)
resp, err := app.Test(httptest.NewRequest("GET", "/tenants/hanmac-id/worksmobile/initial-passwords.csv", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
require.Contains(t, resp.Header.Get("Content-Disposition"), "worksmobile_initial_passwords.csv")
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Contains(t, string(body), "email,name,primaryLeafOrgName,initialPassword,status,lastError")
require.Contains(t, string(body), "user@hanmaceng.co.kr,홍길동,인재성장,Aa1!Aa1!Aa1!Aa1!,processed,")
}
func TestWorksmobileHandlerPassesInitialPasswordBatchID(t *testing.T) {
fakeService := &fakeWorksmobileAdminService{
credentials: []service.WorksmobileInitialPasswordCredential{
{Email: "batch-user@hanmaceng.co.kr", InitialPassword: "BatchPass1!", Status: "pending"},
},
}
h := NewWorksmobileHandler(fakeService)
app := fiber.New()
app.Get("/tenants/:tenantId/worksmobile/initial-passwords.csv", h.DownloadInitialPasswordsCSV)
resp, err := app.Test(httptest.NewRequest("GET", "/tenants/hanmac-id/worksmobile/initial-passwords.csv?batchId=batch-1", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
require.Equal(t, "batch-1", fakeService.downloadCredentialBatchID)
}
func TestWorksmobileHandlerPassesSyncUserCredentialBatchID(t *testing.T) {
fakeService := &fakeWorksmobileAdminService{}
h := NewWorksmobileHandler(fakeService)
app := fiber.New()
app.Post("/tenants/:tenantId/worksmobile/users/:userId/sync", h.SyncUser)
req := httptest.NewRequest("POST", "/tenants/hanmac-id/worksmobile/users/user-1/sync", strings.NewReader(`{"credentialBatchId":"batch-1","initialPassword":"InputPass1!"}`))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusAccepted, resp.StatusCode)
require.Equal(t, "batch-1", fakeService.syncUserCredentialBatchID)
require.Equal(t, "InputPass1!", fakeService.syncUserInitialPassword)
}
func TestWorksmobileHandlerPassesPasswordResetCredentialBatchID(t *testing.T) {
fakeService := &fakeWorksmobileAdminService{}
h := NewWorksmobileHandler(fakeService)
app := fiber.New()
app.Post("/tenants/:tenantId/worksmobile/users/:userId/password/reset", h.ResetUserPassword)
req := httptest.NewRequest("POST", "/tenants/hanmac-id/worksmobile/users/user-1/password/reset", strings.NewReader(`{"credentialBatchId":"batch-1"}`))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusAccepted, resp.StatusCode)
require.Equal(t, "batch-1", fakeService.resetPasswordCredentialBatchID)
}
func TestWorksmobileHandlerReturnsCredentialBatchHistory(t *testing.T) {
h := NewWorksmobileHandler(&fakeWorksmobileAdminService{
credentialBatches: []service.WorksmobileCredentialBatch{
{BatchID: "batch-1", UserCount: 2, HasPasswords: true},
},
})
app := fiber.New()
app.Get("/tenants/:tenantId/worksmobile/credential-batches", h.ListCredentialBatches)
resp, err := app.Test(httptest.NewRequest("GET", "/tenants/hanmac-id/worksmobile/credential-batches", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Contains(t, string(body), `"batchId":"batch-1"`)
require.Contains(t, string(body), `"userCount":2`)
}
func TestWorksmobileHandlerDeletesCredentialBatchPasswords(t *testing.T) {
fakeService := &fakeWorksmobileAdminService{}
h := NewWorksmobileHandler(fakeService)
app := fiber.New()
app.Delete("/tenants/:tenantId/worksmobile/credential-batches/:batchId/passwords", h.DeleteCredentialBatchPasswords)
resp, err := app.Test(httptest.NewRequest("DELETE", "/tenants/hanmac-id/worksmobile/credential-batches/batch-1/passwords", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
require.Equal(t, "batch-1", fakeService.deletedCredentialBatchID)
}
func TestWorksmobileHandlerDeletesPendingJobs(t *testing.T) {
fakeService := &fakeWorksmobileAdminService{
pendingJobsDeleteResult: service.WorksmobilePendingJobDeleteResult{DeletedCount: 3},
}
h := NewWorksmobileHandler(fakeService)
app := fiber.New()
app.Delete("/tenants/:tenantId/worksmobile/jobs/pending", h.DeletePendingJobs)
resp, err := app.Test(httptest.NewRequest("DELETE", "/tenants/hanmac-id/worksmobile/jobs/pending", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
require.Equal(t, "hanmac-id", fakeService.deletedPendingJobsTenantID)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Contains(t, string(body), `"deletedCount":3`)
}
func TestWorksmobileHandlerLogsActionFailures(t *testing.T) {
var logs bytes.Buffer
previous := slog.Default()
slog.SetDefault(slog.New(slog.NewJSONHandler(&logs, nil)))
t.Cleanup(func() {
slog.SetDefault(previous)
})
h := NewWorksmobileHandler(&fakeWorksmobileAdminService{
syncUserErr: errors.New("works user sync failed"),
})
app := fiber.New()
app.Post("/tenants/:tenantId/worksmobile/users/:userId/sync", h.SyncUser)
resp, err := app.Test(httptest.NewRequest("POST", "/tenants/hanmac-id/worksmobile/users/user-1/sync", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
require.Contains(t, logs.String(), "worksmobile admin operation failed")
require.Contains(t, logs.String(), "sync_user")
require.Contains(t, logs.String(), "works user sync failed")
}
type fakeWorksmobileAdminService struct {
overview service.WorksmobileTenantOverview
credentials []service.WorksmobileInitialPasswordCredential
syncUserErr error
syncUserCredentialBatchID string
syncUserInitialPassword string
resetPasswordCredentialBatchID string
downloadCredentialBatchID string
deletedCredentialBatchID string
deletedPendingJobsTenantID string
pendingJobsDeleteResult service.WorksmobilePendingJobDeleteResult
credentialBatches []service.WorksmobileCredentialBatch
}
func (f *fakeWorksmobileAdminService) GetTenantOverview(ctx context.Context, tenantID string) (service.WorksmobileTenantOverview, error) {
return f.overview, nil
}
func (f *fakeWorksmobileAdminService) GetComparison(ctx context.Context, tenantID string, includeMatched bool) (service.WorksmobileComparison, error) {
return service.WorksmobileComparison{}, nil
}
func (f *fakeWorksmobileAdminService) EnqueueBackfillDryRun(ctx context.Context, tenantID string) (service.WorksmobileBackfillDryRun, error) {
return service.WorksmobileBackfillDryRun{}, nil
}
func (f *fakeWorksmobileAdminService) EnqueueOrgUnitSync(ctx context.Context, tenantID, orgUnitID string) (*domain.WorksmobileOutbox, error) {
return &domain.WorksmobileOutbox{ID: "job-orgunit", ResourceID: orgUnitID}, nil
}
func (f *fakeWorksmobileAdminService) EnqueueOrgUnitDelete(ctx context.Context, tenantID, orgUnitID string) (*domain.WorksmobileOutbox, error) {
return &domain.WorksmobileOutbox{ID: "job-orgunit-delete", ResourceID: orgUnitID, Action: domain.WorksmobileActionDelete}, nil
}
func (f *fakeWorksmobileAdminService) EnqueueUserSync(ctx context.Context, tenantID, userID, credentialBatchID, initialPassword string) (*domain.WorksmobileOutbox, error) {
f.syncUserCredentialBatchID = credentialBatchID
f.syncUserInitialPassword = initialPassword
if f.syncUserErr != nil {
return nil, f.syncUserErr
}
return &domain.WorksmobileOutbox{ID: "job-user", ResourceID: userID}, nil
}
func (f *fakeWorksmobileAdminService) EnqueueUserPasswordReset(ctx context.Context, tenantID, userID, credentialBatchID string) (*domain.WorksmobileOutbox, error) {
f.resetPasswordCredentialBatchID = credentialBatchID
return &domain.WorksmobileOutbox{ID: "job-user-password-reset", ResourceID: userID}, nil
}
func (f *fakeWorksmobileAdminService) RetryJob(ctx context.Context, tenantID, jobID string) (*domain.WorksmobileOutbox, error) {
return &domain.WorksmobileOutbox{ID: jobID}, nil
}
func (f *fakeWorksmobileAdminService) ListInitialPasswordCredentials(ctx context.Context, tenantID, credentialBatchID string) ([]service.WorksmobileInitialPasswordCredential, error) {
f.downloadCredentialBatchID = credentialBatchID
return f.credentials, nil
}
func (f *fakeWorksmobileAdminService) ListCredentialBatches(ctx context.Context, tenantID string) ([]service.WorksmobileCredentialBatch, error) {
return f.credentialBatches, nil
}
func (f *fakeWorksmobileAdminService) DeleteCredentialBatchPasswords(ctx context.Context, tenantID, credentialBatchID string) (service.WorksmobileCredentialBatch, error) {
f.deletedCredentialBatchID = credentialBatchID
return service.WorksmobileCredentialBatch{BatchID: credentialBatchID}, nil
}
func (f *fakeWorksmobileAdminService) DeletePendingJobs(ctx context.Context, tenantID string) (service.WorksmobilePendingJobDeleteResult, error) {
f.deletedPendingJobsTenantID = tenantID
return f.pendingJobsDeleteResult, nil
}