첫 커밋: 로컬 프로젝트 업로드
This commit is contained in:
491
baron-sso/backend/internal/handler/admin_handler.go
Normal file
491
baron-sso/backend/internal/handler/admin_handler.go
Normal file
@@ -0,0 +1,491 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/repository"
|
||||
"baron-sso-backend/internal/service"
|
||||
"context"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type adminHydraClientLister interface {
|
||||
ListClients(ctx context.Context, limit, offset int) ([]domain.HydraClient, error)
|
||||
}
|
||||
|
||||
type identityCacheAdmin interface {
|
||||
GetIdentityCacheStatus(ctx context.Context) (domain.IdentityCacheStatus, error)
|
||||
FlushIdentityCache(ctx context.Context) (domain.IdentityCacheFlushResult, error)
|
||||
}
|
||||
|
||||
type AdminHandler struct {
|
||||
DB *gorm.DB
|
||||
Keto service.KetoService
|
||||
KetoOutbox repository.KetoOutboxRepository
|
||||
RPUsageQueries domain.RPUsageQueryRepository
|
||||
TenantRepo repository.TenantRepository
|
||||
Hydra adminHydraClientLister
|
||||
AuditRepo domain.AuditRepository
|
||||
UserProjectionRepo repository.UserProjectionRepository
|
||||
IdentityCache identityCacheAdmin
|
||||
IntegrityChecker repository.DataIntegrityChecker
|
||||
}
|
||||
|
||||
const globalCustomClaimsSettingKey = "global_custom_claim_definitions"
|
||||
|
||||
type globalCustomClaimDefinition struct {
|
||||
Key string `json:"key"`
|
||||
Label string `json:"label"`
|
||||
ValueType string `json:"valueType"`
|
||||
ReadPermission string `json:"readPermission"`
|
||||
WritePermission string `json:"writePermission"`
|
||||
Description string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
type globalCustomClaimDefinitionsResponse struct {
|
||||
Items []globalCustomClaimDefinition `json:"items"`
|
||||
}
|
||||
|
||||
func NewAdminHandler(keto service.KetoService, ketoOutbox repository.KetoOutboxRepository) *AdminHandler {
|
||||
return &AdminHandler{
|
||||
Keto: keto,
|
||||
KetoOutbox: ketoOutbox,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AdminHandler) GetRPUsageDaily(c *fiber.Ctx) error {
|
||||
if h == nil || h.RPUsageQueries == nil {
|
||||
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{
|
||||
"error": "rp usage query service unavailable",
|
||||
})
|
||||
}
|
||||
days := 14
|
||||
if raw := c.Query("days"); raw != "" {
|
||||
if parsed, err := strconv.Atoi(raw); err == nil {
|
||||
days = parsed
|
||||
}
|
||||
}
|
||||
period := normalizeRPUsagePeriod(c.Query("period"))
|
||||
tenantID, allowed := h.authorizedRPUsageTenantID(c, strings.TrimSpace(c.Query("tenantId")))
|
||||
if !allowed {
|
||||
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
|
||||
"error": "forbidden: tenant rp usage stats permission denied",
|
||||
})
|
||||
}
|
||||
items, err := h.RPUsageQueries.FindRPUsage(c.Context(), domain.RPUsageQuery{
|
||||
Days: days,
|
||||
Period: period,
|
||||
TenantID: tenantID,
|
||||
})
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
return c.JSON(fiber.Map{
|
||||
"items": items,
|
||||
"days": days,
|
||||
"period": period,
|
||||
"tenantId": tenantID,
|
||||
})
|
||||
}
|
||||
|
||||
func normalizeRPUsagePeriod(period string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(period)) {
|
||||
case "week":
|
||||
return "week"
|
||||
case "month":
|
||||
return "month"
|
||||
default:
|
||||
return "day"
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AdminHandler) authorizedRPUsageTenantID(c *fiber.Ctx, requestedTenantID string) (string, bool) {
|
||||
profile, _ := c.Locals("user_profile").(*domain.UserProfileResponse)
|
||||
if profile != nil && domain.NormalizeRole(profile.Role) == domain.RoleSuperAdmin {
|
||||
return requestedTenantID, true
|
||||
}
|
||||
tenantID := requestedTenantID
|
||||
if tenantID == "" && profile != nil && profile.TenantID != nil {
|
||||
tenantID = strings.TrimSpace(*profile.TenantID)
|
||||
}
|
||||
if tenantID == "" {
|
||||
return "", false
|
||||
}
|
||||
if h == nil || h.Keto == nil || profile == nil || strings.TrimSpace(profile.ID) == "" {
|
||||
return "", false
|
||||
}
|
||||
allowed, err := h.Keto.CheckPermission(c.Context(), "User:"+profile.ID, "Tenant", tenantID, "view_rp_usage_stats")
|
||||
if err != nil || !allowed {
|
||||
return "", false
|
||||
}
|
||||
return tenantID, true
|
||||
}
|
||||
|
||||
func (h *AdminHandler) CheckAuth(c *fiber.Ctx) error {
|
||||
return c.Status(fiber.StatusOK).JSON(fiber.Map{"status": "ok"})
|
||||
}
|
||||
|
||||
func (h *AdminHandler) GetGlobalCustomClaimDefinitions(c *fiber.Ctx) error {
|
||||
if h == nil || h.DB == nil {
|
||||
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{
|
||||
"error": "settings store unavailable",
|
||||
})
|
||||
}
|
||||
|
||||
var setting domain.SystemSetting
|
||||
if err := h.DB.WithContext(c.Context()).First(&setting, "key = ?", globalCustomClaimsSettingKey).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return c.JSON(globalCustomClaimDefinitionsResponse{Items: []globalCustomClaimDefinition{}})
|
||||
}
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
return c.JSON(globalCustomClaimDefinitionsResponse{
|
||||
Items: normalizeGlobalCustomClaimDefinitions(setting.Value["items"]),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AdminHandler) UpdateGlobalCustomClaimDefinitions(c *fiber.Ctx) error {
|
||||
if h == nil || h.DB == nil {
|
||||
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{
|
||||
"error": "settings store unavailable",
|
||||
})
|
||||
}
|
||||
|
||||
var req globalCustomClaimDefinitionsResponse
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid request body"})
|
||||
}
|
||||
items, err := validateGlobalCustomClaimDefinitions(req.Items)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
setting := domain.SystemSetting{
|
||||
Key: globalCustomClaimsSettingKey,
|
||||
Value: domain.JSONMap{"items": globalCustomClaimDefinitionsToJSON(items)},
|
||||
}
|
||||
if err := h.DB.WithContext(c.Context()).Save(&setting).Error; err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
return c.JSON(globalCustomClaimDefinitionsResponse{Items: items})
|
||||
}
|
||||
|
||||
func normalizeGlobalCustomClaimDefinitions(value any) []globalCustomClaimDefinition {
|
||||
rawItems, ok := value.([]any)
|
||||
if !ok {
|
||||
return []globalCustomClaimDefinition{}
|
||||
}
|
||||
items := make([]globalCustomClaimDefinition, 0, len(rawItems))
|
||||
for _, item := range rawItems {
|
||||
raw, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
def := globalCustomClaimDefinition{
|
||||
Key: strings.TrimSpace(stringValue(raw["key"])),
|
||||
Label: strings.TrimSpace(stringValue(raw["label"])),
|
||||
ValueType: normalizeGlobalCustomClaimType(stringValue(raw["valueType"])),
|
||||
ReadPermission: adminNormalizeCustomClaimPermission(stringValue(raw["readPermission"])),
|
||||
WritePermission: adminNormalizeCustomClaimPermission(stringValue(raw["writePermission"])),
|
||||
Description: strings.TrimSpace(stringValue(raw["description"])),
|
||||
}
|
||||
if def.Key != "" {
|
||||
items = append(items, def)
|
||||
}
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
func validateGlobalCustomClaimDefinitions(items []globalCustomClaimDefinition) ([]globalCustomClaimDefinition, error) {
|
||||
seen := map[string]struct{}{}
|
||||
normalized := make([]globalCustomClaimDefinition, 0, len(items))
|
||||
for _, item := range items {
|
||||
key := strings.TrimSpace(item.Key)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if !isValidCustomClaimKey(key) {
|
||||
return nil, fiber.NewError(fiber.StatusBadRequest, "claim key must use letters, numbers, underscore, dot, or hyphen")
|
||||
}
|
||||
if _, exists := seen[key]; exists {
|
||||
return nil, fiber.NewError(fiber.StatusBadRequest, "duplicate claim key: "+key)
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
normalized = append(normalized, globalCustomClaimDefinition{
|
||||
Key: key,
|
||||
Label: strings.TrimSpace(item.Label),
|
||||
ValueType: normalizeGlobalCustomClaimType(item.ValueType),
|
||||
ReadPermission: adminNormalizeCustomClaimPermission(item.ReadPermission),
|
||||
WritePermission: adminNormalizeCustomClaimPermission(item.WritePermission),
|
||||
Description: strings.TrimSpace(item.Description),
|
||||
})
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
func globalCustomClaimDefinitionsToJSON(items []globalCustomClaimDefinition) []any {
|
||||
values := make([]any, 0, len(items))
|
||||
for _, item := range items {
|
||||
values = append(values, map[string]any{
|
||||
"key": item.Key,
|
||||
"label": item.Label,
|
||||
"valueType": item.ValueType,
|
||||
"readPermission": item.ReadPermission,
|
||||
"writePermission": item.WritePermission,
|
||||
"description": item.Description,
|
||||
})
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
func normalizeGlobalCustomClaimType(value string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||
case "number", "boolean", "array", "object", "date", "datetime":
|
||||
return strings.ToLower(strings.TrimSpace(value))
|
||||
default:
|
||||
return "text"
|
||||
}
|
||||
}
|
||||
|
||||
func adminNormalizeCustomClaimPermission(value string) string {
|
||||
if strings.TrimSpace(value) == "user_and_admin" {
|
||||
return "user_and_admin"
|
||||
}
|
||||
return "admin_only"
|
||||
}
|
||||
|
||||
func isValidCustomClaimKey(value string) bool {
|
||||
for _, r := range value {
|
||||
if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' || r == '_' || r == '-' || r == '.' {
|
||||
continue
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func stringValue(value any) string {
|
||||
if text, ok := value.(string); ok {
|
||||
return text
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func requireSuperAdminProfile(c *fiber.Ctx) bool {
|
||||
profile, _ := c.Locals("user_profile").(*domain.UserProfileResponse)
|
||||
if profile == nil || domain.NormalizeRole(profile.Role) != domain.RoleSuperAdmin {
|
||||
_ = c.Status(fiber.StatusForbidden).JSON(fiber.Map{"error": "forbidden: super_admin required"})
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *AdminHandler) GetUserProjectionStatus(c *fiber.Ctx) error {
|
||||
if !requireSuperAdminProfile(c) {
|
||||
return nil
|
||||
}
|
||||
if h == nil || h.UserProjectionRepo == nil {
|
||||
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "user projection service unavailable"})
|
||||
}
|
||||
status, err := h.UserProjectionRepo.GetStatus(c.Context())
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(status)
|
||||
}
|
||||
|
||||
func (h *AdminHandler) GetOrySSOTSystemStatus(c *fiber.Ctx) error {
|
||||
if !requireSuperAdminProfile(c) {
|
||||
return nil
|
||||
}
|
||||
if h == nil || h.UserProjectionRepo == nil {
|
||||
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "user projection service unavailable"})
|
||||
}
|
||||
projectionStatus, err := h.UserProjectionRepo.GetStatus(c.Context())
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
cacheStatus := domain.IdentityCacheStatus{
|
||||
Status: "unavailable",
|
||||
RedisReady: false,
|
||||
LastError: "identity cache service unavailable",
|
||||
}
|
||||
if h.IdentityCache != nil {
|
||||
cacheStatus, err = h.IdentityCache.GetIdentityCacheStatus(c.Context())
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
"userProjection": projectionStatus,
|
||||
"identityCache": cacheStatus,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AdminHandler) FlushIdentityCache(c *fiber.Ctx) error {
|
||||
if !requireSuperAdminProfile(c) {
|
||||
return nil
|
||||
}
|
||||
if h == nil || h.IdentityCache == nil {
|
||||
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "identity cache service unavailable"})
|
||||
}
|
||||
result, err := h.IdentityCache.FlushIdentityCache(c.Context())
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(result)
|
||||
}
|
||||
|
||||
func (h *AdminHandler) GetDataIntegrity(c *fiber.Ctx) error {
|
||||
if !requireSuperAdminProfile(c) {
|
||||
return nil
|
||||
}
|
||||
if h == nil || h.IntegrityChecker == nil {
|
||||
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "data integrity checker unavailable"})
|
||||
}
|
||||
report, err := h.IntegrityChecker.CheckDataIntegrity(c.Context())
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(report)
|
||||
}
|
||||
|
||||
func (h *AdminHandler) ListOrphanUserLoginIDs(c *fiber.Ctx) error {
|
||||
if !requireSuperAdminProfile(c) {
|
||||
return nil
|
||||
}
|
||||
if h == nil || h.IntegrityChecker == nil {
|
||||
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "data integrity checker unavailable"})
|
||||
}
|
||||
items, err := h.IntegrityChecker.ListOrphanUserLoginIDs(c.Context())
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(fiber.Map{
|
||||
"items": items,
|
||||
"total": len(items),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AdminHandler) DeleteOrphanUserLoginIDs(c *fiber.Ctx) error {
|
||||
if !requireSuperAdminProfile(c) {
|
||||
return nil
|
||||
}
|
||||
if h == nil || h.IntegrityChecker == nil {
|
||||
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "data integrity checker unavailable"})
|
||||
}
|
||||
var req struct {
|
||||
IDs []string `json:"ids"`
|
||||
}
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
|
||||
}
|
||||
result, err := h.IntegrityChecker.DeleteOrphanUserLoginIDs(c.Context(), req.IDs)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(result)
|
||||
}
|
||||
|
||||
// GetSystemStats returns runtime statistics for monitoring
|
||||
func (h *AdminHandler) GetSystemStats(c *fiber.Ctx) error {
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
ctx := c.Context()
|
||||
|
||||
stats := fiber.Map{
|
||||
"totalTenants": h.countTenants(ctx),
|
||||
"totalUsers": h.countUsers(ctx),
|
||||
"oidcClients": h.countOIDCClients(ctx),
|
||||
"auditEvents24h": h.countAuditEventsSince(ctx, time.Now().UTC().Add(-24*time.Hour)),
|
||||
"goroutines": runtime.NumGoroutine(),
|
||||
"cpus": runtime.NumCPU(),
|
||||
"memory": fiber.Map{
|
||||
"alloc": m.Alloc,
|
||||
"totalAlign": m.TotalAlloc,
|
||||
"sys": m.Sys,
|
||||
"numGC": m.NumGC,
|
||||
},
|
||||
"timestamp": time.Now(),
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(stats)
|
||||
}
|
||||
|
||||
func (h *AdminHandler) countTenants(ctx context.Context) int64 {
|
||||
if h == nil || h.TenantRepo == nil {
|
||||
return 0
|
||||
}
|
||||
_, total, err := h.TenantRepo.List(ctx, 1, 0, "", "")
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func (h *AdminHandler) countUsers(ctx context.Context) int64 {
|
||||
if h == nil || h.UserProjectionRepo == nil {
|
||||
return 0
|
||||
}
|
||||
status, err := h.UserProjectionRepo.GetStatus(ctx)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return status.ProjectedUsers
|
||||
}
|
||||
|
||||
func (h *AdminHandler) countOIDCClients(ctx context.Context) int64 {
|
||||
if h == nil || h.Hydra == nil {
|
||||
return 0
|
||||
}
|
||||
const pageSize = 500
|
||||
var total int64
|
||||
for offset := 0; ; offset += pageSize {
|
||||
clients, err := h.Hydra.ListClients(ctx, pageSize, offset)
|
||||
if err != nil {
|
||||
return total
|
||||
}
|
||||
for _, client := range clients {
|
||||
if isHiddenSystemClient(client) {
|
||||
continue
|
||||
}
|
||||
total++
|
||||
}
|
||||
if len(clients) < pageSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func (h *AdminHandler) countAuditEventsSince(ctx context.Context, since time.Time) int64 {
|
||||
if h == nil || h.AuditRepo == nil {
|
||||
return 0
|
||||
}
|
||||
count, err := h.AuditRepo.CountEventsSince(ctx, since)
|
||||
if err == nil && count > 0 {
|
||||
return count
|
||||
}
|
||||
logs, pageErr := h.AuditRepo.FindPage(ctx, 10000, nil, "")
|
||||
if pageErr != nil {
|
||||
return count
|
||||
}
|
||||
var fallbackCount int64
|
||||
for _, log := range logs {
|
||||
if !log.Timestamp.Before(since) {
|
||||
fallbackCount++
|
||||
}
|
||||
}
|
||||
return fallbackCount
|
||||
}
|
||||
343
baron-sso/backend/internal/handler/admin_handler_test.go
Normal file
343
baron-sso/backend/internal/handler/admin_handler_test.go
Normal file
@@ -0,0 +1,343 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeRPUsageQueryRepo struct {
|
||||
query domain.RPUsageQuery
|
||||
items []domain.RPUsageDailyMetric
|
||||
}
|
||||
|
||||
func (f *fakeRPUsageQueryRepo) FindRPUsage(ctx context.Context, query domain.RPUsageQuery) ([]domain.RPUsageDailyMetric, error) {
|
||||
f.query = query
|
||||
return f.items, nil
|
||||
}
|
||||
|
||||
type fakeAdminKeto struct {
|
||||
allowed bool
|
||||
subject string
|
||||
object string
|
||||
relation string
|
||||
}
|
||||
|
||||
func (f *fakeAdminKeto) CheckPermission(ctx context.Context, subject, namespace, object, relation string) (bool, error) {
|
||||
f.subject = subject
|
||||
f.object = object
|
||||
f.relation = relation
|
||||
return f.allowed, nil
|
||||
}
|
||||
|
||||
func (f *fakeAdminKeto) CreateRelation(ctx context.Context, namespace, object, relation, subject string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeAdminKeto) DeleteRelation(ctx context.Context, namespace, object, relation, subject string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeAdminKeto) ListRelations(ctx context.Context, namespace, object, relation, subject string) ([]service.RelationTuple, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeAdminKeto) ListObjects(ctx context.Context, namespace, relation, subject string) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type fakeOverviewAuditRepo struct {
|
||||
mockAuditRepo
|
||||
since time.Time
|
||||
count int64
|
||||
}
|
||||
|
||||
func (f *fakeOverviewAuditRepo) CountEventsSince(ctx context.Context, since time.Time) (int64, error) {
|
||||
f.since = since
|
||||
return f.count, nil
|
||||
}
|
||||
|
||||
type fakeAdminUserProjectionRepo struct {
|
||||
status domain.UserProjectionStatus
|
||||
}
|
||||
|
||||
func (f *fakeAdminUserProjectionRepo) IsReady(ctx context.Context) (bool, error) {
|
||||
return f.status.Ready, nil
|
||||
}
|
||||
|
||||
func (f *fakeAdminUserProjectionRepo) CountTenantMembers(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeAdminUserProjectionRepo) CountTenantMembersRecursive(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeAdminUserProjectionRepo) ReplaceAllFromKratos(ctx context.Context, users []domain.User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeAdminUserProjectionRepo) MarkFailed(ctx context.Context, syncErr error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeAdminUserProjectionRepo) GetStatus(ctx context.Context) (domain.UserProjectionStatus, error) {
|
||||
return f.status, nil
|
||||
}
|
||||
|
||||
type fakeIdentityCacheAdmin struct {
|
||||
status domain.IdentityCacheStatus
|
||||
flush domain.IdentityCacheFlushResult
|
||||
err error
|
||||
statusHit int
|
||||
flushCalls int
|
||||
}
|
||||
|
||||
func (f *fakeIdentityCacheAdmin) GetIdentityCacheStatus(ctx context.Context) (domain.IdentityCacheStatus, error) {
|
||||
f.statusHit++
|
||||
return f.status, f.err
|
||||
}
|
||||
|
||||
func (f *fakeIdentityCacheAdmin) FlushIdentityCache(ctx context.Context) (domain.IdentityCacheFlushResult, error) {
|
||||
f.flushCalls++
|
||||
return f.flush, f.err
|
||||
}
|
||||
|
||||
func TestAdminHandler_GetRPUsageDaily(t *testing.T) {
|
||||
repo := &fakeRPUsageQueryRepo{
|
||||
items: []domain.RPUsageDailyMetric{
|
||||
{
|
||||
Date: "2026-05-06",
|
||||
TenantID: "tenant-1",
|
||||
TenantType: domain.TenantTypeCompany,
|
||||
ClientID: "orgfront",
|
||||
ClientName: "OrgFront",
|
||||
LoginRequests: 12,
|
||||
OtherRequests: 4,
|
||||
UniqueSubjects: 8,
|
||||
},
|
||||
},
|
||||
}
|
||||
h := &AdminHandler{RPUsageQueries: repo}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/api/v1/admin/rp-usage/daily", h.GetRPUsageDaily)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/rp-usage/daily?days=7&period=week&tenantId=tenant-1", nil)
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, 7, repo.query.Days)
|
||||
require.Equal(t, "week", repo.query.Period)
|
||||
require.Equal(t, "tenant-1", repo.query.TenantID)
|
||||
|
||||
var body struct {
|
||||
Items []domain.RPUsageDailyMetric `json:"items"`
|
||||
Days int `json:"days"`
|
||||
Period string `json:"period"`
|
||||
TenantID string `json:"tenantId"`
|
||||
}
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
require.Equal(t, 7, body.Days)
|
||||
require.Equal(t, "week", body.Period)
|
||||
require.Equal(t, "tenant-1", body.TenantID)
|
||||
require.Len(t, body.Items, 1)
|
||||
require.Equal(t, "orgfront", body.Items[0].ClientID)
|
||||
require.Equal(t, uint64(12), body.Items[0].LoginRequests)
|
||||
}
|
||||
|
||||
func TestAdminHandler_UserProjectionStatusRequiresSuperAdmin(t *testing.T) {
|
||||
h := &AdminHandler{
|
||||
UserProjectionRepo: &fakeAdminUserProjectionRepo{
|
||||
status: domain.UserProjectionStatus{Name: domain.UserProjectionNameKratos, Status: domain.UserProjectionStatusReady, Ready: true},
|
||||
},
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "tenant-admin", Role: "tenant_admin"})
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/api/v1/admin/projections/users", h.GetUserProjectionStatus)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/projections/users", nil)
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestAdminHandler_UserProjectionStatusReturnsProjectionStateForSuperAdmin(t *testing.T) {
|
||||
syncedAt := time.Date(2026, 5, 11, 3, 0, 0, 0, time.UTC)
|
||||
h := &AdminHandler{
|
||||
UserProjectionRepo: &fakeAdminUserProjectionRepo{
|
||||
status: domain.UserProjectionStatus{
|
||||
Name: domain.UserProjectionNameKratos,
|
||||
Status: domain.UserProjectionStatusReady,
|
||||
Ready: true,
|
||||
LastSyncedAt: &syncedAt,
|
||||
ProjectedUsers: 152,
|
||||
},
|
||||
},
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/api/v1/admin/projections/users", h.GetUserProjectionStatus)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/projections/users", nil)
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body domain.UserProjectionStatus
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
require.Equal(t, domain.UserProjectionNameKratos, body.Name)
|
||||
require.Equal(t, domain.UserProjectionStatusReady, body.Status)
|
||||
require.True(t, body.Ready)
|
||||
require.Equal(t, int64(152), body.ProjectedUsers)
|
||||
}
|
||||
|
||||
func TestAdminHandler_GetOrySSOTSystemStatusReturnsProjectionAndIdentityCache(t *testing.T) {
|
||||
syncedAt := time.Date(2026, 5, 11, 3, 0, 0, 0, time.UTC)
|
||||
cache := &fakeIdentityCacheAdmin{
|
||||
status: domain.IdentityCacheStatus{
|
||||
Status: "ready",
|
||||
RedisReady: true,
|
||||
ObservedCount: 151,
|
||||
KeyCount: 153,
|
||||
LastRefreshedAt: &syncedAt,
|
||||
UpdatedAt: &syncedAt,
|
||||
},
|
||||
}
|
||||
h := &AdminHandler{
|
||||
UserProjectionRepo: &fakeAdminUserProjectionRepo{
|
||||
status: domain.UserProjectionStatus{
|
||||
Name: domain.UserProjectionNameKratos,
|
||||
Status: domain.UserProjectionStatusReady,
|
||||
Ready: true,
|
||||
LastSyncedAt: &syncedAt,
|
||||
ProjectedUsers: 152,
|
||||
},
|
||||
},
|
||||
IdentityCache: cache,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/api/v1/admin/ory/ssot", h.GetOrySSOTSystemStatus)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/ory/ssot", nil)
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body struct {
|
||||
UserProjection domain.UserProjectionStatus `json:"userProjection"`
|
||||
IdentityCache domain.IdentityCacheStatus `json:"identityCache"`
|
||||
}
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
require.Equal(t, int64(152), body.UserProjection.ProjectedUsers)
|
||||
require.True(t, body.IdentityCache.RedisReady)
|
||||
require.Equal(t, int64(151), body.IdentityCache.ObservedCount)
|
||||
require.Equal(t, int64(153), body.IdentityCache.KeyCount)
|
||||
require.Equal(t, 1, cache.statusHit)
|
||||
}
|
||||
|
||||
func TestAdminHandler_FlushIdentityCacheRequiresSuperAdminAndFlushesCacheOnly(t *testing.T) {
|
||||
cache := &fakeIdentityCacheAdmin{
|
||||
flush: domain.IdentityCacheFlushResult{
|
||||
Status: "success",
|
||||
FlushedKeys: 7,
|
||||
UpdatedAt: time.Date(2026, 5, 11, 3, 2, 0, 0, time.UTC),
|
||||
},
|
||||
}
|
||||
h := &AdminHandler{
|
||||
IdentityCache: cache,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Post("/api/v1/admin/ory/ssot/identity-cache/flush", h.FlushIdentityCache)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/ory/ssot/identity-cache/flush", nil)
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body domain.IdentityCacheFlushResult
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
require.Equal(t, int64(7), body.FlushedKeys)
|
||||
require.Equal(t, 1, cache.flushCalls)
|
||||
}
|
||||
|
||||
func TestAdminHandler_GetRPUsageDailyChecksTenantPermission(t *testing.T) {
|
||||
repo := &fakeRPUsageQueryRepo{}
|
||||
keto := &fakeAdminKeto{allowed: true}
|
||||
h := &AdminHandler{RPUsageQueries: repo, Keto: keto}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{
|
||||
ID: "user-1",
|
||||
Role: "tenant_admin",
|
||||
})
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/api/v1/admin/rp-usage/daily", h.GetRPUsageDaily)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/rp-usage/daily?tenantId=tenant-allowed", nil)
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, "User:user-1", keto.subject)
|
||||
require.Equal(t, "tenant-allowed", keto.object)
|
||||
require.Equal(t, "view_rp_usage_stats", keto.relation)
|
||||
require.Equal(t, "tenant-allowed", repo.query.TenantID)
|
||||
}
|
||||
|
||||
func TestAdminHandler_GetSystemStatsIncludesOverviewMetrics(t *testing.T) {
|
||||
auditRepo := &fakeOverviewAuditRepo{count: 22}
|
||||
h := &AdminHandler{
|
||||
AuditRepo: auditRepo,
|
||||
UserProjectionRepo: &fakeAdminUserProjectionRepo{
|
||||
status: domain.UserProjectionStatus{
|
||||
Name: domain.UserProjectionNameKratos,
|
||||
Status: domain.UserProjectionStatusReady,
|
||||
Ready: true,
|
||||
ProjectedUsers: 152,
|
||||
},
|
||||
},
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/admin/stats", h.GetSystemStats)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/stats", nil)
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body map[string]any
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
require.Contains(t, body, "totalTenants")
|
||||
require.Contains(t, body, "totalUsers")
|
||||
require.Contains(t, body, "oidcClients")
|
||||
require.Contains(t, body, "auditEvents24h")
|
||||
require.Equal(t, float64(152), body["totalUsers"])
|
||||
require.Equal(t, float64(22), body["auditEvents24h"])
|
||||
require.Equal(t, time.UTC, auditRepo.since.Location())
|
||||
}
|
||||
196
baron-sso/backend/internal/handler/admin_integrity_test.go
Normal file
196
baron-sso/backend/internal/handler/admin_integrity_test.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeDataIntegrityChecker struct {
|
||||
calls int
|
||||
listCalls int
|
||||
deleteCalls int
|
||||
deletedIDs []string
|
||||
report domain.DataIntegrityReport
|
||||
orphans []domain.OrphanUserLoginID
|
||||
deleteResult domain.DeleteOrphanUserLoginIDsResult
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *fakeDataIntegrityChecker) CheckDataIntegrity(ctx context.Context) (domain.DataIntegrityReport, error) {
|
||||
f.calls++
|
||||
return f.report, f.err
|
||||
}
|
||||
|
||||
func (f *fakeDataIntegrityChecker) ListOrphanUserLoginIDs(ctx context.Context) ([]domain.OrphanUserLoginID, error) {
|
||||
f.listCalls++
|
||||
return f.orphans, f.err
|
||||
}
|
||||
|
||||
func (f *fakeDataIntegrityChecker) DeleteOrphanUserLoginIDs(ctx context.Context, ids []string) (domain.DeleteOrphanUserLoginIDsResult, error) {
|
||||
f.deleteCalls++
|
||||
f.deletedIDs = append([]string(nil), ids...)
|
||||
return f.deleteResult, f.err
|
||||
}
|
||||
|
||||
func TestAdminHandler_GetDataIntegrityRequiresSuperAdmin(t *testing.T) {
|
||||
checker := &fakeDataIntegrityChecker{}
|
||||
h := &AdminHandler{IntegrityChecker: checker}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "tenant-admin", Role: "tenant_admin"})
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/api/v1/admin/integrity", h.GetDataIntegrity)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/integrity", nil)
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
require.Equal(t, 0, checker.calls)
|
||||
}
|
||||
|
||||
func TestAdminHandler_GetDataIntegrityReturnsReportForSuperAdmin(t *testing.T) {
|
||||
checkedAt := time.Date(2026, 5, 14, 0, 0, 0, 0, time.UTC)
|
||||
checker := &fakeDataIntegrityChecker{
|
||||
report: domain.DataIntegrityReport{
|
||||
Status: domain.DataIntegrityStatusFail,
|
||||
CheckedAt: checkedAt,
|
||||
Summary: domain.DataIntegritySummary{
|
||||
TotalChecks: 1,
|
||||
Failures: 1,
|
||||
},
|
||||
Sections: []domain.DataIntegritySection{
|
||||
{
|
||||
Key: "tenant_integrity",
|
||||
Label: "테넌트 정합성",
|
||||
Status: domain.DataIntegrityStatusFail,
|
||||
Checks: []domain.DataIntegrityCheck{
|
||||
{
|
||||
Key: "duplicate_tenant_slugs",
|
||||
Label: "중복 테넌트 slug",
|
||||
Status: domain.DataIntegrityStatusFail,
|
||||
Count: 1,
|
||||
Severity: "error",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
h := &AdminHandler{IntegrityChecker: checker}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/api/v1/admin/integrity", h.GetDataIntegrity)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/integrity", nil)
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, 1, checker.calls)
|
||||
|
||||
var body domain.DataIntegrityReport
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
require.Equal(t, domain.DataIntegrityStatusFail, body.Status)
|
||||
require.Equal(t, int64(1), body.Summary.Failures)
|
||||
require.Len(t, body.Sections, 1)
|
||||
require.Equal(t, "tenant_integrity", body.Sections[0].Key)
|
||||
}
|
||||
|
||||
func TestAdminHandler_ListOrphanUserLoginIDsReturnsTargetsForSuperAdmin(t *testing.T) {
|
||||
checker := &fakeDataIntegrityChecker{
|
||||
orphans: []domain.OrphanUserLoginID{
|
||||
{
|
||||
ID: "login-id-1",
|
||||
UserID: "user-1",
|
||||
TenantID: "tenant-1",
|
||||
FieldKey: "emp_id",
|
||||
LoginID: "EMP001",
|
||||
Reasons: []string{"missing_tenant"},
|
||||
},
|
||||
},
|
||||
}
|
||||
h := &AdminHandler{IntegrityChecker: checker}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/api/v1/admin/integrity/orphan-user-login-ids", h.ListOrphanUserLoginIDs)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/integrity/orphan-user-login-ids", nil)
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, 1, checker.listCalls)
|
||||
|
||||
var body struct {
|
||||
Items []domain.OrphanUserLoginID `json:"items"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
require.Equal(t, 1, body.Total)
|
||||
require.Equal(t, "login-id-1", body.Items[0].ID)
|
||||
require.Equal(t, []string{"missing_tenant"}, body.Items[0].Reasons)
|
||||
}
|
||||
|
||||
func TestAdminHandler_DeleteOrphanUserLoginIDsRequiresSuperAdminAndDeletesSelectedTargets(t *testing.T) {
|
||||
checker := &fakeDataIntegrityChecker{
|
||||
deleteResult: domain.DeleteOrphanUserLoginIDsResult{
|
||||
DeletedCount: 1,
|
||||
Deleted: []domain.OrphanUserLoginID{
|
||||
{ID: "login-id-1", LoginID: "EMP001", Reasons: []string{"missing_user"}},
|
||||
},
|
||||
SkippedIDs: []string{"valid-login-id"},
|
||||
},
|
||||
}
|
||||
h := &AdminHandler{IntegrityChecker: checker}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "super", Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Delete("/api/v1/admin/integrity/orphan-user-login-ids", h.DeleteOrphanUserLoginIDs)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/v1/admin/integrity/orphan-user-login-ids", strings.NewReader(`{"ids":["login-id-1","valid-login-id"]}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, 1, checker.deleteCalls)
|
||||
require.Equal(t, []string{"login-id-1", "valid-login-id"}, checker.deletedIDs)
|
||||
|
||||
var body domain.DeleteOrphanUserLoginIDsResult
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
require.Equal(t, int64(1), body.DeletedCount)
|
||||
require.Equal(t, []string{"valid-login-id"}, body.SkippedIDs)
|
||||
}
|
||||
|
||||
func TestAdminHandler_DeleteOrphanUserLoginIDsRejectsTenantAdmin(t *testing.T) {
|
||||
checker := &fakeDataIntegrityChecker{}
|
||||
h := &AdminHandler{IntegrityChecker: checker}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "tenant-admin", Role: "tenant_admin"})
|
||||
return c.Next()
|
||||
})
|
||||
app.Delete("/api/v1/admin/integrity/orphan-user-login-ids", h.DeleteOrphanUserLoginIDs)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/v1/admin/integrity/orphan-user-login-ids", strings.NewReader(`{"ids":["login-id-1"]}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
require.Equal(t, 0, checker.deleteCalls)
|
||||
}
|
||||
288
baron-sso/backend/internal/handler/api_key_handler.go
Normal file
288
baron-sso/backend/internal/handler/api_key_handler.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/pagination"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ApiKeyHandler struct {
|
||||
DB *gorm.DB
|
||||
}
|
||||
|
||||
func NewApiKeyHandler(db *gorm.DB) *ApiKeyHandler {
|
||||
return &ApiKeyHandler{DB: db}
|
||||
}
|
||||
|
||||
type apiKeySummary struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
ClientID string `json:"client_id"`
|
||||
Scopes []string `json:"scopes"`
|
||||
Status string `json:"status"`
|
||||
LastUsedAt *string `json:"lastUsedAt"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
type apiKeyListResponse struct {
|
||||
Items []apiKeySummary `json:"items"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
Cursor string `json:"cursor,omitempty"`
|
||||
NextCursor string `json:"nextCursor,omitempty"`
|
||||
}
|
||||
|
||||
func apiKeyToSummary(k domain.ApiKey) apiKeySummary {
|
||||
lastUsed := ""
|
||||
if k.LastUsedAt != nil {
|
||||
lastUsed = k.LastUsedAt.Format(time.RFC3339)
|
||||
}
|
||||
return apiKeySummary{
|
||||
ID: k.ID,
|
||||
Name: k.Name,
|
||||
ClientID: k.ClientID,
|
||||
Scopes: strings.Fields(strings.ReplaceAll(k.Scopes, ",", " ")),
|
||||
Status: k.Status,
|
||||
LastUsedAt: &lastUsed,
|
||||
CreatedAt: k.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func apiKeyWithUpdatedScopes(k domain.ApiKey, scopes []string) domain.ApiKey {
|
||||
k.Scopes = strings.Join(normalizeApiKeyScopes(scopes), " ")
|
||||
return k
|
||||
}
|
||||
|
||||
func apiKeyWithRotatedSecretHash(k domain.ApiKey, hashedSecret string) domain.ApiKey {
|
||||
k.ClientSecretHash = hashedSecret
|
||||
return k
|
||||
}
|
||||
|
||||
func normalizeApiKeyScopes(scopes []string) []string {
|
||||
seen := make(map[string]struct{}, len(scopes))
|
||||
normalized := make([]string, 0, len(scopes))
|
||||
for _, scope := range scopes {
|
||||
scope = strings.TrimSpace(scope)
|
||||
if scope == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[scope]; exists {
|
||||
continue
|
||||
}
|
||||
seen[scope] = struct{}{}
|
||||
normalized = append(normalized, scope)
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func (h *ApiKeyHandler) ListApiKeys(c *fiber.Ctx) error {
|
||||
if h.DB == nil {
|
||||
return errorJSON(c, fiber.StatusServiceUnavailable, "database not available")
|
||||
}
|
||||
|
||||
limit := c.QueryInt("limit", 50)
|
||||
offset := c.QueryInt("offset", 0)
|
||||
cursorRaw := strings.TrimSpace(c.Query("cursor"))
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
|
||||
var total int64
|
||||
if err := h.DB.Model(&domain.ApiKey{}).Count(&total).Error; err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
var keys []domain.ApiKey
|
||||
query := h.DB.Order("created_at desc, id desc").Limit(limit + 1)
|
||||
if cursorRaw != "" {
|
||||
cursor, err := pagination.Decode(cursorRaw)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "invalid cursor")
|
||||
}
|
||||
query = pagination.ApplyCreatedAtIDCursor(query, cursor, "created_at", "id")
|
||||
offset = 0
|
||||
} else {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
if err := query.Find(&keys).Error; err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
nextCursor := ""
|
||||
hasMore := len(keys) > limit
|
||||
if len(keys) > limit {
|
||||
keys = keys[:limit]
|
||||
}
|
||||
if cursorRaw == "" && total > int64(offset+len(keys)) {
|
||||
hasMore = true
|
||||
}
|
||||
if hasMore && len(keys) > 0 {
|
||||
last := keys[len(keys)-1]
|
||||
nextCursor = pagination.Encode(last.CreatedAt, last.ID)
|
||||
}
|
||||
|
||||
items := make([]apiKeySummary, 0, len(keys))
|
||||
for _, k := range keys {
|
||||
items = append(items, apiKeyToSummary(k))
|
||||
}
|
||||
|
||||
return c.JSON(apiKeyListResponse{
|
||||
Items: items,
|
||||
Total: total,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
Cursor: cursorRaw,
|
||||
NextCursor: nextCursor,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *ApiKeyHandler) CreateApiKey(c *fiber.Ctx) error {
|
||||
if h.DB == nil {
|
||||
return errorJSON(c, fiber.StatusServiceUnavailable, "database not available")
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
Scopes []string `json:"scopes"`
|
||||
}
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
|
||||
}
|
||||
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "name is required")
|
||||
}
|
||||
req.Scopes = normalizeApiKeyScopes(req.Scopes)
|
||||
if len(req.Scopes) == 0 {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "at least one scope is required")
|
||||
}
|
||||
|
||||
// Generate Client ID (16 chars hex)
|
||||
clientID := GenerateSecureToken(8)
|
||||
|
||||
// Generate plain secret (16 chars hex)
|
||||
plainSecret := GenerateSecureToken(8)
|
||||
|
||||
hashedSecret, err := bcrypt.GenerateFromPassword([]byte(plainSecret), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "failed to hash secret")
|
||||
}
|
||||
|
||||
apiKey := domain.ApiKey{
|
||||
Name: req.Name,
|
||||
ClientID: clientID,
|
||||
ClientSecretHash: string(hashedSecret),
|
||||
Scopes: strings.Join(req.Scopes, " "),
|
||||
Status: "active",
|
||||
}
|
||||
|
||||
if err := h.DB.Create(&apiKey).Error; err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
// Return summary + PLAIN SECRET (only this time)
|
||||
return c.Status(fiber.StatusCreated).JSON(fiber.Map{
|
||||
"apiKey": apiKeyToSummary(apiKey),
|
||||
"clientSecret": plainSecret, // VERY IMPORTANT: user must save this now
|
||||
})
|
||||
}
|
||||
|
||||
func (h *ApiKeyHandler) UpdateApiKey(c *fiber.Ctx) error {
|
||||
if h.DB == nil {
|
||||
return errorJSON(c, fiber.StatusServiceUnavailable, "database not available")
|
||||
}
|
||||
|
||||
id := c.Params("id")
|
||||
if id == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "id is required")
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Scopes []string `json:"scopes"`
|
||||
}
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
|
||||
}
|
||||
req.Scopes = normalizeApiKeyScopes(req.Scopes)
|
||||
if len(req.Scopes) == 0 {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "at least one scope is required")
|
||||
}
|
||||
|
||||
var apiKey domain.ApiKey
|
||||
if err := h.DB.First(&apiKey, "id = ?", id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errorJSON(c, fiber.StatusNotFound, "api key not found")
|
||||
}
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
apiKey = apiKeyWithUpdatedScopes(apiKey, req.Scopes)
|
||||
if err := h.DB.Model(&domain.ApiKey{}).Where("id = ?", id).Update("scopes", apiKey.Scopes).Error; err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(apiKeyToSummary(apiKey))
|
||||
}
|
||||
|
||||
func (h *ApiKeyHandler) RotateApiKeySecret(c *fiber.Ctx) error {
|
||||
if h.DB == nil {
|
||||
return errorJSON(c, fiber.StatusServiceUnavailable, "database not available")
|
||||
}
|
||||
|
||||
id := c.Params("id")
|
||||
if id == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "id is required")
|
||||
}
|
||||
|
||||
var apiKey domain.ApiKey
|
||||
if err := h.DB.First(&apiKey, "id = ?", id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errorJSON(c, fiber.StatusNotFound, "api key not found")
|
||||
}
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
plainSecret := GenerateSecureToken(8)
|
||||
hashedSecret, err := bcrypt.GenerateFromPassword([]byte(plainSecret), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "failed to hash secret")
|
||||
}
|
||||
|
||||
apiKey = apiKeyWithRotatedSecretHash(apiKey, string(hashedSecret))
|
||||
if err := h.DB.Model(&domain.ApiKey{}).Where("id = ?", id).Update("client_secret_hash", apiKey.ClientSecretHash).Error; err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
"apiKey": apiKeyToSummary(apiKey),
|
||||
"clientSecret": plainSecret,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *ApiKeyHandler) DeleteApiKey(c *fiber.Ctx) error {
|
||||
if h.DB == nil {
|
||||
return errorJSON(c, fiber.StatusServiceUnavailable, "database not available")
|
||||
}
|
||||
|
||||
id := c.Params("id")
|
||||
if id == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "id is required")
|
||||
}
|
||||
|
||||
if err := h.DB.Delete(&domain.ApiKey{}, "id = ?", id).Error; err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.SendStatus(fiber.StatusNoContent)
|
||||
}
|
||||
133
baron-sso/backend/internal/handler/api_key_handler_test.go
Normal file
133
baron-sso/backend/internal/handler/api_key_handler_test.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Mock DB for ApiKey tests using a real GORM instance but with a hijacked connection
|
||||
// or just a simple mock if we only check nil.
|
||||
// For ApiKeyHandler, it uses DB for Create/List/Delete.
|
||||
|
||||
func TestApiKeyHandler_CreateApiKey(t *testing.T) {
|
||||
app := fiber.New()
|
||||
// ApiKeyHandler requires a valid DB connection to perform h.DB.Create
|
||||
// Since we don't have a real DB here, we'll check if it fails gracefully
|
||||
// or we can use sqlite in-memory for a more realistic test.
|
||||
h := &ApiKeyHandler{DB: nil} // Testing ServiceUnavailable
|
||||
|
||||
app.Post("/api-keys", h.CreateApiKey)
|
||||
|
||||
input := map[string]any{
|
||||
"name": "M2M Test",
|
||||
"scopes": []string{"read", "write"},
|
||||
}
|
||||
body, _ := json.Marshal(input)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api-keys", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestApiKeyHandler_Validation(t *testing.T) {
|
||||
app := fiber.New()
|
||||
// Using a dummy DB pointer to pass the nil check
|
||||
h := &ApiKeyHandler{DB: &gorm.DB{}}
|
||||
|
||||
app.Post("/api-keys", h.CreateApiKey)
|
||||
|
||||
// Missing name
|
||||
input := map[string]any{
|
||||
"scopes": []string{"read"},
|
||||
}
|
||||
body, _ := json.Marshal(input)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api-keys", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestApiKeyHandler_UpdateApiKeyScopesRequiresDatabase(t *testing.T) {
|
||||
app := fiber.New()
|
||||
h := &ApiKeyHandler{DB: nil}
|
||||
|
||||
app.Patch("/api-keys/:id", h.UpdateApiKey)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"scopes": []string{"org-context:read"},
|
||||
})
|
||||
req := httptest.NewRequest("PATCH", "/api-keys/api-key-id", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestApiKeyHandler_RotateApiKeySecretRequiresDatabase(t *testing.T) {
|
||||
app := fiber.New()
|
||||
h := &ApiKeyHandler{DB: nil}
|
||||
|
||||
app.Post("/api-keys/:id/secret/rotate", h.RotateApiKeySecret)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api-keys/api-key-id/secret/rotate", nil)
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestApiKeyWithUpdatedScopesPreservesClientID(t *testing.T) {
|
||||
key := domain.ApiKey{
|
||||
ID: "api-key-id",
|
||||
Name: "M2M Test",
|
||||
ClientID: "client-id-stable",
|
||||
ClientSecretHash: "old-secret-hash",
|
||||
Scopes: "audit:read",
|
||||
Status: "active",
|
||||
}
|
||||
|
||||
updated := apiKeyWithUpdatedScopes(key, []string{"audit:read", "org-context:read"})
|
||||
|
||||
assert.Equal(t, "client-id-stable", updated.ClientID)
|
||||
assert.Equal(t, "old-secret-hash", updated.ClientSecretHash)
|
||||
assert.Equal(t, "audit:read org-context:read", updated.Scopes)
|
||||
}
|
||||
|
||||
func TestApiKeyWithRotatedSecretHashPreservesClientIDAndScopes(t *testing.T) {
|
||||
key := domain.ApiKey{
|
||||
ID: "api-key-id",
|
||||
Name: "M2M Test",
|
||||
ClientID: "client-id-stable",
|
||||
ClientSecretHash: "old-secret-hash",
|
||||
Scopes: "audit:read org-context:read",
|
||||
Status: "active",
|
||||
}
|
||||
|
||||
updated := apiKeyWithRotatedSecretHash(key, "new-secret-hash")
|
||||
|
||||
assert.Equal(t, "client-id-stable", updated.ClientID)
|
||||
assert.Equal(t, "audit:read org-context:read", updated.Scopes)
|
||||
assert.Equal(t, "new-secret-hash", updated.ClientSecretHash)
|
||||
}
|
||||
|
||||
func TestNormalizeApiKeyScopesTrimsAndDeduplicates(t *testing.T) {
|
||||
scopes := normalizeApiKeyScopes([]string{
|
||||
" audit:read ",
|
||||
"",
|
||||
"org-context:read",
|
||||
"audit:read",
|
||||
})
|
||||
|
||||
assert.Equal(t, []string{"audit:read", "org-context:read"}, scopes)
|
||||
}
|
||||
139
baron-sso/backend/internal/handler/audit_handler.go
Normal file
139
baron-sso/backend/internal/handler/audit_handler.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditHandler struct {
|
||||
repo domain.AuditRepository
|
||||
}
|
||||
|
||||
func NewAuditHandler(repo domain.AuditRepository) *AuditHandler {
|
||||
return &AuditHandler{repo: repo}
|
||||
}
|
||||
|
||||
// CreateLog handles POST /api/v1/audit
|
||||
func (h *AuditHandler) CreateLog(c *fiber.Ctx) error {
|
||||
var req domain.AuditLog
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "Cannot parse JSON")
|
||||
}
|
||||
|
||||
// Auto-fill metadata if missing
|
||||
if req.IPAddress == "" {
|
||||
req.IPAddress = c.IP()
|
||||
}
|
||||
if req.UserAgent == "" {
|
||||
req.UserAgent = c.Get("User-Agent")
|
||||
}
|
||||
if req.Timestamp.IsZero() {
|
||||
req.Timestamp = time.Now()
|
||||
}
|
||||
if req.EventID == "" {
|
||||
req.EventID = ensureRequestID(c)
|
||||
}
|
||||
|
||||
if h.repo == nil {
|
||||
return errorJSON(c, fiber.StatusServiceUnavailable, "Audit service unavailable")
|
||||
}
|
||||
|
||||
if err := h.repo.Create(&req); err != nil {
|
||||
// Log internal error but don't expose details
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "Failed to save audit log")
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusCreated).JSON(fiber.Map{
|
||||
"message": "Audit log saved",
|
||||
})
|
||||
}
|
||||
|
||||
// ListLogs handles GET /api/v1/audit
|
||||
func (h *AuditHandler) ListLogs(c *fiber.Ctx) error {
|
||||
limit := c.QueryInt("limit", 50)
|
||||
cursorRaw := c.Query("cursor")
|
||||
requestedTenantID := c.Query("tenantId")
|
||||
|
||||
cursor, err := parseAuditCursor(cursorRaw)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "Invalid cursor")
|
||||
}
|
||||
|
||||
if h.repo == nil {
|
||||
return errorJSON(c, fiber.StatusServiceUnavailable, "Audit service unavailable")
|
||||
}
|
||||
|
||||
// [New] Role-based Filtering
|
||||
profile, _ := c.Locals("user_profile").(*domain.UserProfileResponse)
|
||||
var filterTenantID string
|
||||
|
||||
if profile != nil {
|
||||
if profile.Role == domain.RoleSuperAdmin {
|
||||
// Super Admin can see everything or filter by a specific tenant if requested
|
||||
filterTenantID = requestedTenantID
|
||||
} else {
|
||||
return errorJSON(c, fiber.StatusForbidden, "forbidden")
|
||||
}
|
||||
}
|
||||
|
||||
logs, err := h.repo.FindPage(c.Context(), limit+1, cursor, filterTenantID)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "Failed to retrieve audit logs")
|
||||
}
|
||||
|
||||
nextCursor := ""
|
||||
if len(logs) > limit {
|
||||
last := logs[limit-1]
|
||||
nextCursor = encodeAuditCursor(last)
|
||||
logs = logs[:limit]
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
"items": logs,
|
||||
"limit": limit,
|
||||
"cursor": cursorRaw,
|
||||
"next_cursor": nextCursor,
|
||||
})
|
||||
}
|
||||
|
||||
func ensureRequestID(c *fiber.Ctx) string {
|
||||
reqID := c.Get("X-Request-Id")
|
||||
if reqID == "" {
|
||||
reqID = uuid.New().String()
|
||||
c.Set("X-Request-Id", reqID)
|
||||
}
|
||||
return reqID
|
||||
}
|
||||
|
||||
func parseAuditCursor(raw string) (*domain.AuditCursor, error) {
|
||||
if raw == "" {
|
||||
return nil, nil
|
||||
}
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parts := strings.SplitN(string(decoded), "|", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, errors.New("invalid cursor")
|
||||
}
|
||||
ts, err := time.Parse(time.RFC3339Nano, parts[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &domain.AuditCursor{
|
||||
Timestamp: ts,
|
||||
EventID: parts[1],
|
||||
}, nil
|
||||
}
|
||||
|
||||
func encodeAuditCursor(log domain.AuditLog) string {
|
||||
payload := log.Timestamp.UTC().Format(time.RFC3339Nano) + "|" + log.EventID
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(payload))
|
||||
}
|
||||
9066
baron-sso/backend/internal/handler/auth_handler.go
Normal file
9066
baron-sso/backend/internal/handler/auth_handler.go
Normal file
File diff suppressed because it is too large
Load Diff
371
baron-sso/backend/internal/handler/auth_handler_async_test.go
Normal file
371
baron-sso/backend/internal/handler/auth_handler_async_test.go
Normal file
@@ -0,0 +1,371 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// --- Async Test Mocks ---
|
||||
|
||||
type AsyncMockIdpProvider struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *AsyncMockIdpProvider) Name() string { return "mock-idp" }
|
||||
func (m *AsyncMockIdpProvider) GetMetadata() (*domain.IDPMetadata, error) {
|
||||
return &domain.IDPMetadata{}, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockIdpProvider) UserExists(loginID string) (bool, error) {
|
||||
args := m.Called(loginID)
|
||||
return args.Bool(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *AsyncMockIdpProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
|
||||
args := m.Called(user, password)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *AsyncMockIdpProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockIdpProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockIdpProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockIdpProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockIdpProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
|
||||
return &domain.PasswordPolicy{MinLength: 12}, nil
|
||||
}
|
||||
func (m *AsyncMockIdpProvider) InitiatePasswordReset(loginID, redirectUrl string) error { return nil }
|
||||
func (m *AsyncMockIdpProvider) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockIdpProvider) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type AsyncMockUserRepo struct {
|
||||
mock.Mock
|
||||
createCalled chan bool
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) Create(ctx context.Context, user *domain.User) error {
|
||||
// Simulate DB latency
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
args := m.Called(ctx, user)
|
||||
if m.createCalled != nil {
|
||||
m.createCalled <- true
|
||||
}
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) Update(ctx context.Context, user *domain.User) error {
|
||||
args := m.Called(ctx, user)
|
||||
if m.createCalled != nil {
|
||||
m.createCalled <- true
|
||||
}
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) Delete(ctx context.Context, id string) error { return nil }
|
||||
func (m *AsyncMockUserRepo) FindByEmail(ctx context.Context, email string) (*domain.User, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) FindByID(ctx context.Context, id string) (*domain.User, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) List(ctx context.Context, offset, limit int, search string, tenantIDs []string, cursor string) ([]domain.User, int64, string, error) {
|
||||
return nil, 0, "", nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) CountByTenant(ctx context.Context, tenantID string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) FindByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.User, error) {
|
||||
args := m.Called(ctx, tenantIDs)
|
||||
return args.Get(0).([]domain.User), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) CountByTenantIDs(ctx context.Context, tenantIDs []string) (map[string]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) FindByCompanyCodes(ctx context.Context, codes []string) ([]domain.User, error) {
|
||||
args := m.Called(ctx, codes)
|
||||
return args.Get(0).([]domain.User), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) CountByCompanyCodes(ctx context.Context, codes []string) (map[string]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) DB() *gorm.DB {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) GetUserLoginIDs(ctx context.Context, userID string) ([]domain.UserLoginID, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) IsLoginIDTaken(ctx context.Context, loginID string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockUserRepo) FindTenantIDByLoginID(ctx context.Context, loginID string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
type AsyncMockRedisRepo struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *AsyncMockRedisRepo) Set(key string, value string, expiration time.Duration) error {
|
||||
args := m.Called(key, value, expiration)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *AsyncMockRedisRepo) Get(key string) (string, error) {
|
||||
args := m.Called(key)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *AsyncMockRedisRepo) Delete(key string) error {
|
||||
args := m.Called(key)
|
||||
return args.Error(0)
|
||||
}
|
||||
func (m *AsyncMockRedisRepo) StoreVerificationCode(phone, code string) error { return nil }
|
||||
func (m *AsyncMockRedisRepo) GetVerificationCode(phone string) (string, error) { return "", nil }
|
||||
func (m *AsyncMockRedisRepo) DeleteVerificationCode(phone string) error { return nil }
|
||||
|
||||
type AsyncMockTenantService struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *AsyncMockTenantService) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string, creatorID string) (*domain.Tenant, error) {
|
||||
args := m.Called(ctx, name, slug, tenantType, description, domains, parentID, creatorID)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.Tenant), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *AsyncMockTenantService) RequestRegistration(ctx context.Context, name, slug, description string, domainName string, adminEmail string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockTenantService) GetTenantByDomain(ctx context.Context, emailDomain string) (*domain.Tenant, error) {
|
||||
args := m.Called(ctx, emailDomain)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.Tenant), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *AsyncMockTenantService) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockTenantService) GetTenant(ctx context.Context, id string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockTenantService) ListTenants(ctx context.Context, limit, offset int, parentID string, search string) ([]domain.Tenant, int64, error) {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockTenantService) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockTenantService) IsDomainAllowed(ctx context.Context, domainName string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (m *AsyncMockTenantService) ApproveTenant(ctx context.Context, id string) error { return nil }
|
||||
func (m *AsyncMockTenantService) ProvisionTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *AsyncMockTenantService) SetKetoService(keto service.KetoService) {}
|
||||
func (m *AsyncMockTenantService) AddTenantAdmin(ctx context.Context, tenantID, userID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockTenantService) RemoveTenantAdmin(ctx context.Context, tenantID, userID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockTenantService) ListTenantAdmins(ctx context.Context, tenantID string) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockTenantService) DeleteTenantsBulk(ctx context.Context, ids []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockTenantService) ListJoinedTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
|
||||
args := m.Called(ctx, userID)
|
||||
if args.Get(0) != nil {
|
||||
return args.Get(0).([]domain.Tenant), args.Error(1)
|
||||
}
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
type AsyncMockKetoService struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *AsyncMockKetoService) CreateRelation(ctx context.Context, namespace, object, relation, subject string) error {
|
||||
args := m.Called(ctx, namespace, object, relation, subject)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *AsyncMockKetoService) DeleteRelation(ctx context.Context, namespace, object, relation, subject string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockKetoService) CheckPermission(ctx context.Context, namespace, object, relation, subject string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockKetoService) ListObjects(ctx context.Context, namespace, relation, subject string) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *AsyncMockKetoService) ListRelations(ctx context.Context, namespace, object, relation, subject string) ([]service.RelationTuple, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// --- Tests ---
|
||||
|
||||
func TestSignup_AsyncDB_Isolation(t *testing.T) {
|
||||
mockIdp := new(AsyncMockIdpProvider)
|
||||
mockUserRepo := new(AsyncMockUserRepo)
|
||||
mockRedis := new(AsyncMockRedisRepo)
|
||||
mockTenant := new(AsyncMockTenantService)
|
||||
mockKeto := new(AsyncMockKetoService)
|
||||
|
||||
h := &AuthHandler{
|
||||
IdpProvider: mockIdp,
|
||||
UserRepo: mockUserRepo,
|
||||
RedisService: mockRedis,
|
||||
TenantService: mockTenant,
|
||||
KetoService: mockKeto,
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Post("/signup", h.Signup)
|
||||
|
||||
t.Run("SoT_DB_Failure_Ignored_And_Async", func(t *testing.T) {
|
||||
email := "test@example.com"
|
||||
phone := "010-1234-5678"
|
||||
emailKey := "signup:email:" + email
|
||||
phoneKey := "signup:phone:" + "01012345678"
|
||||
|
||||
// Redis Mocks
|
||||
mockRedis.On("Get", emailKey).Return(`{"verified": true, "expires_at": 9999999999}`, nil)
|
||||
mockRedis.On("Get", phoneKey).Return(`{"verified": true, "expires_at": 9999999999}`, nil)
|
||||
mockRedis.On("Delete", emailKey).Return(nil)
|
||||
mockRedis.On("Delete", phoneKey).Return(nil)
|
||||
|
||||
// Tenant Mocks
|
||||
personalTenant := &domain.Tenant{ID: "personal-t1", Slug: "personal-test", Type: domain.TenantTypePersonal, Status: domain.TenantStatusActive}
|
||||
mockTenant.On("GetTenantByDomain", mock.Anything, "example.com").Return(nil, nil)
|
||||
mockTenant.On(
|
||||
"RegisterTenant",
|
||||
mock.Anything,
|
||||
"Personal - test@example.com",
|
||||
mock.MatchedBy(func(slug string) bool { return strings.HasPrefix(slug, "personal-") }),
|
||||
domain.TenantTypePersonal,
|
||||
"Automatically provisioned personal tenant",
|
||||
[]string(nil),
|
||||
(*string)(nil),
|
||||
"",
|
||||
).Return(personalTenant, nil)
|
||||
mockTenant.On("GetTenant", mock.Anything, "personal-t1").Return(personalTenant, nil)
|
||||
|
||||
// Kratos Mocks (Success)
|
||||
mockIdp.On("CreateUser", mock.Anything, "Password123!").Return("new-user-uuid", nil)
|
||||
|
||||
// UserRepo Mocks (Async & Failure)
|
||||
mockUserRepo.createCalled = make(chan bool, 1)
|
||||
mockUserRepo.On("Update", mock.Anything, mock.MatchedBy(func(u *domain.User) bool {
|
||||
return u.Email == email
|
||||
})).Return(errors.New("db connection error"))
|
||||
|
||||
// Keto Mocks (Optional, since it's also async)
|
||||
// We won't block on this either
|
||||
|
||||
body, _ := json.Marshal(domain.SignupRequest{
|
||||
Email: email,
|
||||
Password: "Password123!",
|
||||
Name: "Test User",
|
||||
Phone: phone,
|
||||
TermsAccepted: true,
|
||||
})
|
||||
req := httptest.NewRequest("POST", "/signup", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
start := time.Now()
|
||||
resp, err := app.Test(req, 5000)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
// Ensure API responded faster than DB latency (50ms)
|
||||
assert.Less(t, int64(elapsed), int64(60*time.Millisecond), "API should return before DB timeout")
|
||||
|
||||
// Wait for async execution
|
||||
select {
|
||||
case <-mockUserRepo.createCalled:
|
||||
// Pass
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("UserRepo.Create was not called asynchronously")
|
||||
}
|
||||
|
||||
mockRedis.AssertExpectations(t)
|
||||
mockIdp.AssertExpectations(t)
|
||||
mockUserRepo.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
200
baron-sso/backend/internal/handler/auth_handler_client_test.go
Normal file
200
baron-sso/backend/internal/handler/auth_handler_client_test.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"baron-sso-backend/internal/utils"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRevokeLinkedRp_Success(t *testing.T) {
|
||||
// Mock Hydra transport for revocation
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
// 1. Kratos whoami
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{"id": "user-123"},
|
||||
}), nil
|
||||
}
|
||||
// 2. Hydra Revoke
|
||||
if r.Method == http.MethodDelete && r.URL.Path == "/oauth2/auth/sessions/consent" {
|
||||
assert.Equal(t, "user-123", r.URL.Query().Get("subject"))
|
||||
assert.Equal(t, "app-1", r.URL.Query().Get("client"))
|
||||
return httpResponse(r, http.StatusNoContent, ""), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
auditRepo := &mockAuditRepo{}
|
||||
rpUsageSink := &mockRPUsageEventSink{}
|
||||
consentRepo := &mockConsentRepo{
|
||||
consents: []domain.ClientConsent{
|
||||
{
|
||||
ClientID: "app-1",
|
||||
Subject: "user-123",
|
||||
GrantedScopes: []string{"openid", "profile"},
|
||||
},
|
||||
},
|
||||
}
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
AuditRepo: auditRepo,
|
||||
ConsentRepo: consentRepo,
|
||||
RPUsageSink: rpUsageSink,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Delete("/api/v1/user/rp/linked/:id", h.RevokeLinkedRp)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/rp/linked/app-1", nil)
|
||||
req.Header.Set("Cookie", "valid")
|
||||
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Equal(t, 1, len(auditRepo.logs))
|
||||
assert.Equal(t, "consent.revoked", auditRepo.logs[0].EventType)
|
||||
assert.Equal(t, "user-123", auditRepo.logs[0].UserID)
|
||||
assert.Equal(t, "success", auditRepo.logs[0].Status)
|
||||
auditDetails, err := utils.ParseAuditDetails(auditRepo.logs[0].Details)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "app-1", auditDetails["client_id"])
|
||||
assert.Equal(t, 1, len(rpUsageSink.events))
|
||||
assert.Equal(t, domain.RPUsageEventTypeAuthorizationRevoked, rpUsageSink.events[0].EventType)
|
||||
assert.Equal(t, "user-123", rpUsageSink.events[0].Subject)
|
||||
assert.Equal(t, "app-1", rpUsageSink.events[0].ClientID)
|
||||
remaining, err := consentRepo.Find(req.Context(), "app-1", "user-123")
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, remaining)
|
||||
}
|
||||
|
||||
func TestRevokeLinkedRp_SendsBackchannelLogoutTokenWhenConfigured(t *testing.T) {
|
||||
t.Setenv("BACKCHANNEL_LOGOUT_ISSUER", "https://sso.example.com/oidc")
|
||||
|
||||
var receivedBody string
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{"id": "user-123"},
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Host == "hydra.test" && r.Method == http.MethodDelete && r.URL.Path == "/oauth2/auth/sessions/consent" {
|
||||
return httpResponse(r, http.StatusNoContent, ""), nil
|
||||
}
|
||||
if r.URL.Host == "hydra.test" && r.Method == http.MethodGet && r.URL.Path == "/clients/app-1" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": "app-1",
|
||||
"backchannel_logout_uri": "https://rp.example.com/backchannel-logout",
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Host == "rp.example.com" && r.Method == http.MethodPost && r.URL.Path == "/backchannel-logout" {
|
||||
raw, _ := io.ReadAll(r.Body)
|
||||
receivedBody = string(raw)
|
||||
return httpResponse(r, http.StatusNoContent, ""), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
backchannelLogout, err := service.NewBackchannelLogoutService()
|
||||
assert.NoError(t, err)
|
||||
backchannelLogout.HTTPClient = client
|
||||
|
||||
auditRepo := &mockAuditRepo{}
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
BackchannelLogout: backchannelLogout,
|
||||
AuditRepo: auditRepo,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Delete("/api/v1/user/rp/linked/:id", h.RevokeLinkedRp)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/rp/linked/app-1", nil)
|
||||
req.Header.Set("Cookie", "valid")
|
||||
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.True(t, strings.Contains(receivedBody, "logout_token="))
|
||||
|
||||
values, err := url.ParseQuery(receivedBody)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, values.Get("logout_token"))
|
||||
|
||||
assert.Len(t, auditRepo.logs, 2)
|
||||
assert.Equal(t, "backchannel_logout.sent", auditRepo.logs[1].EventType)
|
||||
}
|
||||
|
||||
func TestListRpHistory_Aggregation(t *testing.T) {
|
||||
now := time.Now()
|
||||
auditRepo := &mockAuditRepo{
|
||||
logs: []domain.AuditLog{
|
||||
{
|
||||
UserID: "user-123",
|
||||
EventType: "consent.revoked", // Newest
|
||||
Timestamp: now,
|
||||
Details: `{"client_id":"app-1"}`,
|
||||
},
|
||||
{
|
||||
UserID: "user-123",
|
||||
EventType: "consent.granted", // Oldest
|
||||
Timestamp: now.Add(-1 * time.Hour),
|
||||
Details: `{"client_id":"app-1", "client_name":"App One"}`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
h := &AuthHandler{
|
||||
AuditRepo: auditRepo,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/user/rp/history", h.ListRpHistory)
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{"id": "user-123"},
|
||||
}), nil
|
||||
})
|
||||
http.DefaultClient = &http.Client{Transport: transport}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/rp/history", nil)
|
||||
req.Header.Set("Cookie", "valid")
|
||||
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var res struct {
|
||||
Items []struct {
|
||||
ClientID string `json:"client_id"`
|
||||
Status string `json:"status"`
|
||||
} `json:"items"`
|
||||
}
|
||||
json.NewDecoder(resp.Body).Decode(&res)
|
||||
|
||||
assert.Equal(t, 1, len(res.Items))
|
||||
assert.Equal(t, "app-1", res.Items[0].ClientID)
|
||||
// Newest event (revoked) should win
|
||||
assert.Equal(t, "revoked", res.Items[0].Status)
|
||||
}
|
||||
515
baron-sso/backend/internal/handler/auth_handler_consent_test.go
Normal file
515
baron-sso/backend/internal/handler/auth_handler_consent_test.go
Normal file
@@ -0,0 +1,515 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"baron-sso-backend/internal/utils"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// --- Mocks ---
|
||||
|
||||
type MockKratosAdminServiceForConsent struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceForConsent) FindIdentityIDByIdentifier(ctx context.Context, identifier string) (string, error) {
|
||||
args := m.Called(ctx, identifier)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceForConsent) GetIdentity(ctx context.Context, id string) (*service.KratosIdentity, error) {
|
||||
args := m.Called(ctx, id)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*service.KratosIdentity), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceForConsent) ListIdentities(ctx context.Context) ([]service.KratosIdentity, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceForConsent) UpdateIdentity(ctx context.Context, identityID string, traits map[string]any, state string) (*service.KratosIdentity, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceForConsent) CreateIdentity(ctx context.Context, traits map[string]any) (*service.KratosIdentity, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceForConsent) DeleteIdentity(ctx context.Context, identityID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceForConsent) UpdateIdentityPassword(ctx context.Context, identityID, newPassword string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceForConsent) ListIdentitySessions(ctx context.Context, identityID string) ([]service.KratosSession, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceForConsent) GetSession(ctx context.Context, sessionID string) (*service.KratosSession, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceForConsent) DeleteSession(ctx context.Context, sessionID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockKratosAdminServiceForConsent) CreateUser(ctx context.Context, user *domain.BrokerUser, password string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
type MockTenantServiceForConsent struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForConsent) GetTenant(ctx context.Context, id string) (*domain.Tenant, error) {
|
||||
args := m.Called(ctx, id)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.Tenant), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForConsent) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForConsent) GetTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForConsent) ListTenants(ctx context.Context, limit, offset int, parentID string, search string) ([]domain.Tenant, int64, error) {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForConsent) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string, creatorID string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForConsent) RequestRegistration(ctx context.Context, name, slug, description string, domainName string, adminEmail string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForConsent) ApproveTenant(ctx context.Context, id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForConsent) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
|
||||
args := m.Called(ctx, userID)
|
||||
return args.Get(0).([]domain.Tenant), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForConsent) ListJoinedTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
|
||||
args := m.Called(ctx, userID)
|
||||
return args.Get(0).([]domain.Tenant), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForConsent) IsDomainAllowed(ctx context.Context, domainName string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForConsent) ProvisionTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForConsent) SetKetoService(keto service.KetoService) {}
|
||||
|
||||
func (m *MockTenantServiceForConsent) DeleteTenantsBulk(ctx context.Context, ids []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Test Helpers ---
|
||||
|
||||
func newConsentTestApp(h *AuthHandler) *fiber.App {
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/auth/consent", h.GetConsentRequest)
|
||||
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
|
||||
return app
|
||||
}
|
||||
|
||||
// --- Tests ---
|
||||
|
||||
func TestGetConsentRequest_Normal(t *testing.T) {
|
||||
// Mock Hydra transport
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-123" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-123",
|
||||
"requested_scope": []string{"openid", "profile"},
|
||||
"skip": false,
|
||||
"subject": "user-123",
|
||||
"client": map[string]any{
|
||||
"client_id": "client-app",
|
||||
"client_name": "Test App",
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
}
|
||||
app := newConsentTestApp(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-123", nil)
|
||||
resp, err := app.Test(req)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body map[string]any
|
||||
json.NewDecoder(resp.Body).Decode(&body)
|
||||
|
||||
assert.Equal(t, "challenge-123", body["challenge"])
|
||||
assert.Equal(t, false, body["skip"])
|
||||
}
|
||||
|
||||
func TestGetConsentRequest_AddsMandatoryTenantScope(t *testing.T) {
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-tenant-scope" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-tenant-scope",
|
||||
"requested_scope": []string{"openid", "profile"},
|
||||
"skip": false,
|
||||
"subject": "user-123",
|
||||
"client": map[string]any{
|
||||
"client_id": "client-app",
|
||||
"client_name": "Test App",
|
||||
"metadata": map[string]any{
|
||||
"tenant_access_restricted": true,
|
||||
"allowed_tenants": []string{"tenant-allow"},
|
||||
"structured_scopes": []map[string]any{
|
||||
{"name": "openid", "mandatory": true},
|
||||
{"name": "tenant", "mandatory": true, "locked": true},
|
||||
{"name": "profile", "mandatory": false},
|
||||
},
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
mockTenantSvc := &MockTenantServiceForConsent{}
|
||||
mockKratosAdmin := &MockKratosAdminServiceForConsent{}
|
||||
|
||||
// Mock profile resolution to allow tenant access
|
||||
mockKratosAdmin.On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
|
||||
ID: "user-123",
|
||||
Traits: map[string]any{
|
||||
"email": "user@example.com",
|
||||
},
|
||||
}, nil)
|
||||
|
||||
mockTenantSvc.On("GetTenant", mock.Anything, "tenant-allow").Return(&domain.Tenant{
|
||||
ID: "tenant-allow",
|
||||
Slug: "tenant-allow",
|
||||
Name: "Allowed Tenant",
|
||||
}, nil)
|
||||
|
||||
// Mock hydration calls
|
||||
mockTenantSvc.On("ListJoinedTenants", mock.Anything, mock.Anything).Return([]domain.Tenant{
|
||||
{ID: "tenant-allow", Slug: "tenant-allow", Name: "Allowed Tenant"},
|
||||
}, nil)
|
||||
mockTenantSvc.On("ListManageableTenants", mock.Anything, mock.Anything).Return([]domain.Tenant{}, nil)
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
TenantService: mockTenantSvc,
|
||||
KratosAdmin: mockKratosAdmin,
|
||||
}
|
||||
app := newConsentTestApp(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-tenant-scope", nil)
|
||||
req.Header.Set("X-Mock-Role", "user")
|
||||
req.Header.Set("X-Tenant-ID", "tenant-allow")
|
||||
resp, err := app.Test(req)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body map[string]any
|
||||
json.NewDecoder(resp.Body).Decode(&body)
|
||||
|
||||
assert.Equal(t, []any{"openid", "tenant", "profile"}, body["requested_scope"])
|
||||
scopeDetails := body["scope_details"].(map[string]any)
|
||||
tenantDetail := scopeDetails["tenant"].(map[string]any)
|
||||
assert.Equal(t, true, tenantDetail["mandatory"])
|
||||
}
|
||||
|
||||
func TestGetConsentRequest_Skip_AutoAccept(t *testing.T) {
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
// Hydra: Get Consent Request
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-skip" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-skip",
|
||||
"requested_scope": []string{"openid"},
|
||||
"skip": true,
|
||||
"subject": "user-123",
|
||||
"client": map[string]any{
|
||||
"client_id": "client-app",
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
// Kratos: Get Identity
|
||||
if r.URL.Path == "/admin/identities/user-123" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@test.com",
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
// Hydra: Accept Consent Request
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-skip" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"redirect_to": "http://rp/cb",
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
consentRepo := &mockConsentRepo{}
|
||||
rpUsageSink := &mockRPUsageEventSink{}
|
||||
mockKratosAdmin := &MockKratosAdminServiceForConsent{}
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
KratosAdmin: mockKratosAdmin,
|
||||
ConsentRepo: consentRepo,
|
||||
RPUsageSink: rpUsageSink,
|
||||
}
|
||||
mockKratosAdmin.On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
|
||||
ID: "user-123",
|
||||
Traits: map[string]any{
|
||||
"email": "user@test.com",
|
||||
},
|
||||
}, nil)
|
||||
|
||||
app := newConsentTestApp(h)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-skip", nil)
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body map[string]any
|
||||
json.NewDecoder(resp.Body).Decode(&body)
|
||||
assert.Equal(t, "http://rp/cb", body["redirectTo"])
|
||||
assert.Equal(t, 1, len(rpUsageSink.events))
|
||||
assert.Equal(t, domain.RPUsageEventTypeAuthorizationGranted, rpUsageSink.events[0].EventType)
|
||||
assert.Equal(t, "client-app", rpUsageSink.events[0].ClientID)
|
||||
assert.Equal(t, "challenge-skip", rpUsageSink.events[0].CorrelationID)
|
||||
assert.Equal(t, true, rpUsageSink.events[0].Payload["auto_accepted"])
|
||||
}
|
||||
|
||||
func TestAcceptConsentRequest_Normal(t *testing.T) {
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-accept" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-accept",
|
||||
"requested_scope": []string{"openid", "profile"},
|
||||
"subject": "user-123",
|
||||
"client": map[string]any{
|
||||
"client_id": "client-app",
|
||||
"client_name": "Test App",
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Path == "/admin/identities/user-123" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@test.com",
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-accept" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"redirect_to": "http://rp/cb",
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
auditRepo := &mockAuditRepo{}
|
||||
consentRepo := &mockConsentRepo{}
|
||||
rpUsageSink := &mockRPUsageEventSink{}
|
||||
mockKratosAdmin := &MockKratosAdminServiceForConsent{}
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
KratosAdmin: mockKratosAdmin,
|
||||
AuditRepo: auditRepo,
|
||||
ConsentRepo: consentRepo,
|
||||
RPUsageSink: rpUsageSink,
|
||||
}
|
||||
mockKratosAdmin.On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
|
||||
ID: "user-123",
|
||||
Traits: map[string]any{
|
||||
"email": "user@test.com",
|
||||
},
|
||||
}, nil)
|
||||
|
||||
app := newConsentTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"consent_challenge": "challenge-accept",
|
||||
"grant_scope": []string{"openid"},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
assert.Equal(t, 1, len(auditRepo.logs))
|
||||
assert.Equal(t, "consent.granted", auditRepo.logs[0].EventType)
|
||||
assert.Equal(t, "user-123", auditRepo.logs[0].UserID)
|
||||
assert.Equal(t, "success", auditRepo.logs[0].Status)
|
||||
auditDetails, err := utils.ParseAuditDetails(auditRepo.logs[0].Details)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "client-app", auditDetails["client_id"])
|
||||
assert.Equal(t, "Test App", auditDetails["client_name"])
|
||||
assert.Equal(t, []any{"openid"}, auditDetails["scopes"])
|
||||
assert.Equal(t, 1, len(rpUsageSink.events))
|
||||
assert.Equal(t, domain.RPUsageEventTypeAuthorizationGranted, rpUsageSink.events[0].EventType)
|
||||
assert.Equal(t, "user-123", rpUsageSink.events[0].Subject)
|
||||
assert.Equal(t, "client-app", rpUsageSink.events[0].ClientID)
|
||||
assert.Equal(t, "Test App", rpUsageSink.events[0].ClientName)
|
||||
assert.Equal(t, []string{"openid"}, []string(rpUsageSink.events[0].Scopes))
|
||||
assert.Equal(t, "hydra_consent", rpUsageSink.events[0].Source)
|
||||
}
|
||||
|
||||
func TestAcceptConsentRequest_EnforcesMandatoryTenantScope(t *testing.T) {
|
||||
t.Setenv("APP_ENV", "dev")
|
||||
|
||||
var capturedGrantScopes []string
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-tenant-accept" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-tenant-accept",
|
||||
"requested_scope": []string{"openid", "profile"},
|
||||
"subject": "user-123",
|
||||
"client": map[string]any{
|
||||
"client_id": "client-app",
|
||||
"metadata": map[string]any{
|
||||
"tenant_id": "tenant-abc",
|
||||
"tenant_access_restricted": true,
|
||||
"allowed_tenants": []string{"tenant-abc"},
|
||||
"structured_scopes": []map[string]any{
|
||||
{"name": "openid", "mandatory": true},
|
||||
{"name": "tenant", "mandatory": true, "locked": true},
|
||||
{"name": "profile", "mandatory": false},
|
||||
},
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Path == "/admin/identities/user-123" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@test.com",
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-tenant-accept" {
|
||||
var payload map[string]any
|
||||
assert.NoError(t, json.NewDecoder(r.Body).Decode(&payload))
|
||||
for _, scope := range payload["grant_scope"].([]any) {
|
||||
capturedGrantScopes = append(capturedGrantScopes, scope.(string))
|
||||
}
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"redirect_to": "http://rp/cb",
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
mockKratosAdmin := &MockKratosAdminServiceForConsent{}
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
KratosAdmin: mockKratosAdmin,
|
||||
}
|
||||
mockKratosAdmin.On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
|
||||
ID: "user-123",
|
||||
Traits: map[string]any{
|
||||
"email": "user@test.com",
|
||||
},
|
||||
}, nil)
|
||||
|
||||
app := newConsentTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"consent_challenge": "challenge-tenant-accept",
|
||||
"grant_scope": []string{"openid", "profile"},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Mock-Role", "user")
|
||||
req.Header.Set("X-Tenant-ID", "tenant-abc")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Equal(t, []string{"openid", "tenant", "profile"}, capturedGrantScopes)
|
||||
}
|
||||
@@ -0,0 +1,829 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func TestBuildOidcClaimsFromTraits_DynamicClaims(t *testing.T) {
|
||||
traits := map[string]any{
|
||||
"email": "user@baron.com",
|
||||
"name": "홍길동",
|
||||
"tenant_id": "primary-tenant-999", // Added primary tenant
|
||||
"tenant-1": map[string]any{
|
||||
"department": "개발팀",
|
||||
"grade": "선임",
|
||||
},
|
||||
"tenant-2": map[string]any{
|
||||
"department": "재무팀",
|
||||
"grade": "팀장",
|
||||
},
|
||||
}
|
||||
scopes := []string{"openid", "profile"}
|
||||
|
||||
t.Run("No tenantID", func(t *testing.T) {
|
||||
claims := buildOidcClaimsFromTraits(traits, scopes, "")
|
||||
assert.Equal(t, "user@baron.com", claims["email"])
|
||||
assert.Equal(t, "홍길동", claims["name"])
|
||||
assert.Equal(t, "primary-tenant-999", claims["tenant_id"])
|
||||
assert.Nil(t, claims["department"])
|
||||
assert.Nil(t, claims["grade"])
|
||||
|
||||
assert.Nil(t, claims["tenants"])
|
||||
assert.Contains(t, claims["joined_tenants"], "tenant-1")
|
||||
assert.Contains(t, claims["joined_tenants"], "tenant-2")
|
||||
assert.Contains(t, claims["joined_tenants"], "primary-tenant-999") // Should contain primary
|
||||
})
|
||||
|
||||
t.Run("With tenant-1", func(t *testing.T) {
|
||||
claims := buildOidcClaimsFromTraits(traits, scopes, "tenant-1")
|
||||
assert.Equal(t, "user@baron.com", claims["email"])
|
||||
assert.Equal(t, "홍길동", claims["name"])
|
||||
assert.Equal(t, "tenant-1", claims["tenant_id"])
|
||||
assert.Nil(t, claims["department"])
|
||||
assert.Nil(t, claims["grade"])
|
||||
|
||||
assert.Nil(t, claims["tenants"])
|
||||
assert.Contains(t, claims["joined_tenants"], "tenant-1")
|
||||
assert.Contains(t, claims["joined_tenants"], "tenant-2")
|
||||
assert.Contains(t, claims["joined_tenants"], "primary-tenant-999")
|
||||
})
|
||||
|
||||
t.Run("With tenant-2", func(t *testing.T) {
|
||||
claims := buildOidcClaimsFromTraits(traits, scopes, "tenant-2")
|
||||
assert.Equal(t, "user@baron.com", claims["email"])
|
||||
assert.Equal(t, "홍길동", claims["name"])
|
||||
assert.Equal(t, "tenant-2", claims["tenant_id"])
|
||||
assert.Nil(t, claims["department"])
|
||||
assert.Nil(t, claims["grade"])
|
||||
|
||||
assert.Nil(t, claims["tenants"])
|
||||
assert.Contains(t, claims["joined_tenants"], "primary-tenant-999")
|
||||
})
|
||||
|
||||
t.Run("With non-existent tenant", func(t *testing.T) {
|
||||
claims := buildOidcClaimsFromTraits(traits, scopes, "tenant-3")
|
||||
assert.Equal(t, "user@baron.com", claims["email"])
|
||||
assert.Equal(t, "홍길동", claims["name"])
|
||||
assert.Equal(t, "tenant-3", claims["tenant_id"])
|
||||
assert.Nil(t, claims["department"])
|
||||
assert.Nil(t, claims["grade"])
|
||||
|
||||
assert.Nil(t, claims["tenants"])
|
||||
assert.Contains(t, claims["joined_tenants"], "tenant-1")
|
||||
assert.Contains(t, claims["joined_tenants"], "primary-tenant-999")
|
||||
})
|
||||
|
||||
t.Run("Tenant scope includes detailed tenant metadata", func(t *testing.T) {
|
||||
claims := buildOidcClaimsFromTraits(traits, []string{"openid", "profile", "tenant"}, "tenant-1")
|
||||
assert.Equal(t, "tenant-1", claims["tenant_id"])
|
||||
assert.Equal(t, "개발팀", claims["department"])
|
||||
assert.Equal(t, "선임", claims["grade"])
|
||||
assert.NotNil(t, claims["tenants"])
|
||||
assert.Contains(t, claims["joined_tenants"], "tenant-1")
|
||||
assert.Contains(t, claims["joined_tenants"], "tenant-2")
|
||||
assert.Contains(t, claims["joined_tenants"], "primary-tenant-999")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRepresentativeTenantIDFromTraits(t *testing.T) {
|
||||
t.Run("explicit tenant_id wins", func(t *testing.T) {
|
||||
traits := map[string]any{
|
||||
"tenant_id": "01970f0a-5c28-74d8-a73a-f6e9e9a7b210",
|
||||
"additionalAppointments": []any{
|
||||
map[string]any{"tenantId": "01970f0b-3448-7bb8-bdc7-16b6a1d2e661", "isPrimary": true},
|
||||
},
|
||||
}
|
||||
assert.Equal(t, "01970f0a-5c28-74d8-a73a-f6e9e9a7b210", representativeTenantIDFromTraits(traits))
|
||||
})
|
||||
|
||||
t.Run("primary appointment wins when tenant_id is absent", func(t *testing.T) {
|
||||
traits := map[string]any{
|
||||
"additionalAppointments": []any{
|
||||
map[string]any{"tenantId": "01970f0b-3448-7bb8-bdc7-16b6a1d2e661"},
|
||||
map[string]any{"tenantId": "01970f0c-8c44-7069-9f20-7d28c0b8e630", "representative": true},
|
||||
},
|
||||
}
|
||||
assert.Equal(t, "01970f0c-8c44-7069-9f20-7d28c0b8e630", representativeTenantIDFromTraits(traits))
|
||||
})
|
||||
|
||||
t.Run("first appointment is fallback", func(t *testing.T) {
|
||||
traits := map[string]any{
|
||||
"additionalAppointments": []any{
|
||||
map[string]any{"tenantId": "01970f0b-3448-7bb8-bdc7-16b6a1d2e661"},
|
||||
map[string]any{"tenantId": "01970f0c-8c44-7069-9f20-7d28c0b8e630"},
|
||||
},
|
||||
}
|
||||
assert.Equal(t, "01970f0b-3448-7bb8-bdc7-16b6a1d2e661", representativeTenantIDFromTraits(traits))
|
||||
})
|
||||
}
|
||||
|
||||
func TestAcceptConsentRequest_DynamicClaims(t *testing.T) {
|
||||
var capturedClaims map[string]any
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
// Hydra: Get Consent Request
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-dynamic" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-dynamic",
|
||||
"requested_scope": []string{"openid", "profile", "tenant"},
|
||||
"subject": "user-123",
|
||||
"client": map[string]any{
|
||||
"client_id": "client-app",
|
||||
"metadata": map[string]any{
|
||||
"tenant_id": "tenant-abc",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
// Kratos: Get Identity
|
||||
if r.URL.Path == "/admin/identities/user-123" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@test.com",
|
||||
"name": "Test User",
|
||||
"tenant-abc": map[string]any{
|
||||
"department": "Innovation",
|
||||
"position": "Architect",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
// Hydra: Accept Consent Request
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-dynamic" {
|
||||
// Capture the claims sent to Hydra
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var acceptReq map[string]any
|
||||
json.Unmarshal(body, &acceptReq)
|
||||
if session, ok := acceptReq["session"].(map[string]any); ok {
|
||||
capturedClaims = session["id_token"].(map[string]any)
|
||||
}
|
||||
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"redirect_to": "http://rp/cb",
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
KratosAdmin: new(MockKratosAdminService),
|
||||
}
|
||||
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
|
||||
ID: "user-123",
|
||||
Traits: map[string]any{
|
||||
"email": "user@test.com",
|
||||
"name": "Test User",
|
||||
"tenant-abc": map[string]any{
|
||||
"department": "Innovation",
|
||||
"position": "Architect",
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
|
||||
|
||||
reqBody, _ := json.Marshal(map[string]any{
|
||||
"consent_challenge": "challenge-dynamic",
|
||||
"grant_scope": []string{"openid", "profile", "tenant"},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Verify captured claims
|
||||
assert.NotNil(t, capturedClaims)
|
||||
assert.Equal(t, "user@test.com", capturedClaims["email"])
|
||||
assert.Equal(t, "tenant-abc", capturedClaims["tenant_id"])
|
||||
assert.Equal(t, "Innovation", capturedClaims["department"])
|
||||
assert.Equal(t, "Architect", capturedClaims["position"])
|
||||
}
|
||||
|
||||
func TestAcceptConsentRequest_UsesRepresentativeTenantIDInsteadOfClientTenantContext(t *testing.T) {
|
||||
var capturedClaims map[string]any
|
||||
|
||||
representativeTenantID := "01970f0a-5c28-74d8-a73a-f6e9e9a7b210"
|
||||
rpContextTenantID := "01970f0b-3448-7bb8-bdc7-16b6a1d2e661"
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-representative-tenant" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-representative-tenant",
|
||||
"requested_scope": []string{"openid", "profile", "tenant"},
|
||||
"subject": "user-representative",
|
||||
"client": map[string]any{
|
||||
"client_id": "client-app",
|
||||
"metadata": map[string]any{
|
||||
"tenant_id": rpContextTenantID,
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-representative-tenant" {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var acceptReq map[string]any
|
||||
json.Unmarshal(body, &acceptReq)
|
||||
if session, ok := acceptReq["session"].(map[string]any); ok {
|
||||
capturedClaims = session["id_token"].(map[string]any)
|
||||
}
|
||||
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"redirect_to": "http://rp/cb",
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
KratosAdmin: new(MockKratosAdminService),
|
||||
}
|
||||
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-representative").Return(&service.KratosIdentity{
|
||||
ID: "user-representative",
|
||||
Traits: map[string]any{
|
||||
"email": "user@test.com",
|
||||
"name": "Test User",
|
||||
"additionalAppointments": []any{
|
||||
map[string]any{"tenantId": representativeTenantID, "isPrimary": true},
|
||||
map[string]any{"tenantId": rpContextTenantID},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
|
||||
|
||||
reqBody, _ := json.Marshal(map[string]any{
|
||||
"consent_challenge": "challenge-representative-tenant",
|
||||
"grant_scope": []string{"openid", "profile"},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
assert.NotNil(t, capturedClaims)
|
||||
assert.Equal(t, representativeTenantID, capturedClaims["tenant_id"])
|
||||
assert.Contains(t, capturedClaims["joined_tenants"], representativeTenantID)
|
||||
assert.Contains(t, capturedClaims["joined_tenants"], rpContextTenantID)
|
||||
assert.Nil(t, capturedClaims["tenants"])
|
||||
}
|
||||
|
||||
func TestAcceptConsentRequest_IncludesHanmacFamilyTenantClaimDetails(t *testing.T) {
|
||||
var capturedClaims map[string]any
|
||||
deptID := "01970f0a-5c28-74d8-a73a-f6e9e9a7b210"
|
||||
secondDeptID := "01970f0b-3448-7bb8-bdc7-16b6a1d2e661"
|
||||
companyID := "01970f08-91da-7286-bd19-882fb98d1f2c"
|
||||
rootID := "01970f07-4f01-7d9a-a71e-b53ad508f345"
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-hanmac-tenant-claim" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-hanmac-tenant-claim",
|
||||
"requested_scope": []string{"openid", "profile", "tenant"},
|
||||
"subject": "user-hanmac",
|
||||
"client": map[string]any{
|
||||
"client_id": "hanmac-rp",
|
||||
"metadata": map[string]any{
|
||||
"tenant_id": deptID,
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-hanmac-tenant-claim" {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var acceptReq map[string]any
|
||||
json.Unmarshal(body, &acceptReq)
|
||||
if session, ok := acceptReq["session"].(map[string]any); ok {
|
||||
capturedClaims = session["id_token"].(map[string]any)
|
||||
}
|
||||
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"redirect_to": "http://rp/cb",
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
KratosAdmin: new(MockKratosAdminService),
|
||||
}
|
||||
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-hanmac").Return(&service.KratosIdentity{
|
||||
ID: "user-hanmac",
|
||||
Traits: map[string]any{
|
||||
"email": "hanmac-user@example.com",
|
||||
"name": "한맥 사용자",
|
||||
"additionalAppointments": []any{
|
||||
map[string]any{
|
||||
"tenantId": deptID,
|
||||
"isPrimary": true,
|
||||
"isOwner": true,
|
||||
"grade": "책임",
|
||||
"jobTitle": "기술기획",
|
||||
"position": "팀장",
|
||||
},
|
||||
map[string]any{
|
||||
"tenantId": secondDeptID,
|
||||
"isPrimary": false,
|
||||
"isOwner": false,
|
||||
"grade": "선임",
|
||||
"jobTitle": "품질관리",
|
||||
"position": "파트원",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
mockTenantSvc := new(MockTenantService)
|
||||
mockTenantSvc.On("ListJoinedTenants", mock.Anything, "user-hanmac").Return([]domain.Tenant{}, nil)
|
||||
mockTenantSvc.On("GetTenant", mock.Anything, deptID).Return(&domain.Tenant{
|
||||
ID: deptID,
|
||||
Slug: "tech-planning",
|
||||
Name: "기술기획팀",
|
||||
Type: domain.TenantTypeUserGroup,
|
||||
ParentID: &companyID,
|
||||
}, nil)
|
||||
mockTenantSvc.On("GetTenant", mock.Anything, secondDeptID).Return(&domain.Tenant{
|
||||
ID: secondDeptID,
|
||||
Slug: "quality",
|
||||
Name: "품질관리팀",
|
||||
Type: domain.TenantTypeUserGroup,
|
||||
ParentID: &companyID,
|
||||
}, nil)
|
||||
mockTenantSvc.On("GetTenant", mock.Anything, companyID).Return(&domain.Tenant{
|
||||
ID: companyID,
|
||||
Slug: "hanmac",
|
||||
Name: "한맥기술",
|
||||
Type: domain.TenantTypeCompany,
|
||||
ParentID: &rootID,
|
||||
}, nil)
|
||||
mockTenantSvc.On("GetTenant", mock.Anything, rootID).Return(&domain.Tenant{
|
||||
ID: rootID,
|
||||
Slug: "hanmac-family",
|
||||
Name: "한맥가족",
|
||||
Type: domain.TenantTypeCompanyGroup,
|
||||
}, nil)
|
||||
h.TenantService = mockTenantSvc
|
||||
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
|
||||
|
||||
reqBody, _ := json.Marshal(map[string]any{
|
||||
"consent_challenge": "challenge-hanmac-tenant-claim",
|
||||
"grant_scope": []string{"openid", "profile", "tenant"},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
assert.NotNil(t, capturedClaims)
|
||||
assert.Equal(t, []any{deptID}, capturedClaims["lead_tenants"])
|
||||
assert.ElementsMatch(t, []any{deptID, secondDeptID}, capturedClaims["joined_tenants"])
|
||||
tenants := capturedClaims["tenants"].(map[string]any)
|
||||
dept := tenants[deptID].(map[string]any)
|
||||
assert.Equal(t, true, dept["lead"])
|
||||
assert.Equal(t, true, dept["representative"])
|
||||
assert.Equal(t, "책임", dept["grade"])
|
||||
assert.Equal(t, "기술기획", dept["jobTitle"])
|
||||
assert.Equal(t, "팀장", dept["position"])
|
||||
assert.Equal(t, companyID, dept["parentTenantId"])
|
||||
assert.NotContains(t, dept, "parentTenant")
|
||||
|
||||
ancestors := dept["ancestors"].([]any)
|
||||
assert.Len(t, ancestors, 2)
|
||||
companyAncestor := ancestors[0].(map[string]any)
|
||||
assert.Equal(t, companyID, companyAncestor["id"])
|
||||
assert.Equal(t, "hanmac", companyAncestor["slug"])
|
||||
assert.Equal(t, rootID, companyAncestor["parentTenantId"])
|
||||
assert.NotContains(t, companyAncestor, "parentTenant")
|
||||
rootAncestor := ancestors[1].(map[string]any)
|
||||
assert.Equal(t, rootID, rootAncestor["id"])
|
||||
assert.Equal(t, "hanmac-family", rootAncestor["slug"])
|
||||
assert.Contains(t, rootAncestor, "parentTenantId")
|
||||
assert.Nil(t, rootAncestor["parentTenantId"])
|
||||
assert.NotContains(t, rootAncestor, "parentTenant")
|
||||
|
||||
secondDept := tenants[secondDeptID].(map[string]any)
|
||||
assert.Equal(t, false, secondDept["lead"])
|
||||
assert.Equal(t, false, secondDept["representative"])
|
||||
assert.Equal(t, "선임", secondDept["grade"])
|
||||
assert.Equal(t, "품질관리", secondDept["jobTitle"])
|
||||
assert.Equal(t, "파트원", secondDept["position"])
|
||||
assert.Equal(t, companyID, secondDept["parentTenantId"])
|
||||
}
|
||||
|
||||
func TestWithHanmacFamilyTenantClaims_DefaultClaimsOnlyWithoutTenantScope(t *testing.T) {
|
||||
deptID := "01970f0a-5c28-74d8-a73a-f6e9e9a7b210"
|
||||
secondDeptID := "01970f0b-3448-7bb8-bdc7-16b6a1d2e661"
|
||||
companyID := "01970f08-91da-7286-bd19-882fb98d1f2c"
|
||||
rootID := "01970f07-4f01-7d9a-a71e-b53ad508f345"
|
||||
|
||||
mockTenantSvc := new(MockTenantService)
|
||||
mockTenantSvc.On("GetTenant", mock.Anything, deptID).Return(&domain.Tenant{
|
||||
ID: deptID,
|
||||
Slug: "tech-planning",
|
||||
Name: "기술기획팀",
|
||||
Type: domain.TenantTypeUserGroup,
|
||||
ParentID: &companyID,
|
||||
}, nil)
|
||||
mockTenantSvc.On("GetTenant", mock.Anything, secondDeptID).Return(&domain.Tenant{
|
||||
ID: secondDeptID,
|
||||
Slug: "quality",
|
||||
Name: "품질관리팀",
|
||||
Type: domain.TenantTypeUserGroup,
|
||||
ParentID: &companyID,
|
||||
}, nil)
|
||||
mockTenantSvc.On("GetTenant", mock.Anything, companyID).Return(&domain.Tenant{
|
||||
ID: companyID,
|
||||
Slug: "hanmac",
|
||||
Name: "한맥기술",
|
||||
Type: domain.TenantTypeCompany,
|
||||
ParentID: &rootID,
|
||||
}, nil)
|
||||
mockTenantSvc.On("GetTenant", mock.Anything, rootID).Return(&domain.Tenant{
|
||||
ID: rootID,
|
||||
Slug: "hanmac-family",
|
||||
Name: "한맥가족",
|
||||
Type: domain.TenantTypeCompanyGroup,
|
||||
}, nil)
|
||||
|
||||
h := &AuthHandler{TenantService: mockTenantSvc}
|
||||
claims := map[string]any{"tenant_id": deptID}
|
||||
traits := map[string]any{
|
||||
"additionalAppointments": []any{
|
||||
map[string]any{
|
||||
"tenantId": deptID,
|
||||
"isPrimary": true,
|
||||
"isOwner": true,
|
||||
"grade": "책임",
|
||||
},
|
||||
map[string]any{
|
||||
"tenantId": secondDeptID,
|
||||
"grade": "선임",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
claims = h.withHanmacFamilyTenantClaims(context.Background(), claims, traits, []string{"openid", "profile"})
|
||||
|
||||
assert.Equal(t, deptID, claims["tenant_id"])
|
||||
assert.ElementsMatch(t, []string{deptID, secondDeptID}, claims["joined_tenants"])
|
||||
assert.NotContains(t, claims, "tenants")
|
||||
assert.NotContains(t, claims, "lead_tenants")
|
||||
}
|
||||
|
||||
func TestAcceptConsentRequest_IncludesRPProfileClaims(t *testing.T) {
|
||||
var capturedClaims map[string]any
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-rp-profile" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-rp-profile",
|
||||
"requested_scope": []string{"openid", "profile", "tenant"},
|
||||
"subject": "user-123",
|
||||
"client": map[string]any{
|
||||
"client_id": "client-app",
|
||||
"metadata": map[string]any{
|
||||
"customUserSchema": []map[string]any{
|
||||
{
|
||||
"key": "approvalLevel",
|
||||
"label": "승인 등급",
|
||||
"type": "text",
|
||||
"claimEnabled": true,
|
||||
},
|
||||
{
|
||||
"key": "internalMemo",
|
||||
"label": "내부 메모",
|
||||
"type": "text",
|
||||
"claimEnabled": false,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-rp-profile" {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var acceptReq map[string]any
|
||||
json.Unmarshal(body, &acceptReq)
|
||||
if session, ok := acceptReq["session"].(map[string]any); ok {
|
||||
capturedClaims = session["id_token"].(map[string]any)
|
||||
}
|
||||
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"redirect_to": "http://rp/cb",
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
KratosAdmin: new(MockKratosAdminService),
|
||||
}
|
||||
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
|
||||
ID: "user-123",
|
||||
Traits: map[string]any{
|
||||
"email": "user@test.com",
|
||||
"name": "Test User",
|
||||
},
|
||||
}, nil)
|
||||
repo := new(devMockRPUserMetadataRepo)
|
||||
repo.On("Get", mock.Anything, "client-app", "user-123").Return(&domain.RPUserMetadata{
|
||||
ClientID: "client-app",
|
||||
UserID: "user-123",
|
||||
Metadata: domain.JSONMap{
|
||||
"approvalLevel": "A",
|
||||
"internalMemo": "관리자 전용",
|
||||
},
|
||||
}, nil).Once()
|
||||
h.RPUserMetadataRepo = repo
|
||||
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
|
||||
|
||||
reqBody, _ := json.Marshal(map[string]any{
|
||||
"consent_challenge": "challenge-rp-profile",
|
||||
"grant_scope": []string{"openid", "profile"},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
assert.NotNil(t, capturedClaims)
|
||||
rpProfiles, ok := capturedClaims["rp_profiles"].([]any)
|
||||
assert.True(t, ok)
|
||||
assert.Len(t, rpProfiles, 1)
|
||||
profile := rpProfiles[0].(map[string]any)
|
||||
assert.Equal(t, "client-app", profile["client_id"])
|
||||
fields := profile["fields"].(map[string]any)
|
||||
assert.Equal(t, "A", fields["approvalLevel"])
|
||||
assert.NotContains(t, fields, "internalMemo")
|
||||
repo.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestGetConsentRequest_Skip_DynamicClaims(t *testing.T) {
|
||||
var capturedClaims map[string]any
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
// Hydra: Get Consent Request
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-skip-dynamic" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-skip-dynamic",
|
||||
"requested_scope": []string{"openid", "profile", "tenant"},
|
||||
"skip": true,
|
||||
"subject": "user-456",
|
||||
"client": map[string]any{
|
||||
"client_id": "skip-app",
|
||||
"metadata": map[string]any{
|
||||
"tenant_id": "tenant-xyz",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
// Kratos: Get Identity
|
||||
if r.URL.Path == "/admin/identities/user-456" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "user-456",
|
||||
"traits": map[string]any{
|
||||
"email": "skip@test.com",
|
||||
"tenant-xyz": map[string]any{
|
||||
"department": "Security",
|
||||
"position": "Officer",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
// Hydra: Accept Consent Request
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-skip-dynamic" {
|
||||
// Capture the claims sent to Hydra
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var acceptReq map[string]any
|
||||
json.Unmarshal(body, &acceptReq)
|
||||
if session, ok := acceptReq["session"].(map[string]any); ok {
|
||||
capturedClaims = session["id_token"].(map[string]any)
|
||||
}
|
||||
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"redirect_to": "http://rp/cb",
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
KratosAdmin: new(MockKratosAdminService),
|
||||
}
|
||||
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-456").Return(&service.KratosIdentity{
|
||||
ID: "user-456",
|
||||
Traits: map[string]any{
|
||||
"email": "skip@test.com",
|
||||
"tenant-xyz": map[string]any{
|
||||
"department": "Security",
|
||||
"position": "Officer",
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/auth/consent", h.GetConsentRequest)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-skip-dynamic", nil)
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Verify captured claims
|
||||
assert.NotNil(t, capturedClaims)
|
||||
assert.Equal(t, "skip@test.com", capturedClaims["email"])
|
||||
assert.Equal(t, "tenant-xyz", capturedClaims["tenant_id"])
|
||||
assert.Equal(t, "Security", capturedClaims["department"])
|
||||
assert.Equal(t, "Officer", capturedClaims["position"])
|
||||
}
|
||||
|
||||
func TestBuildOidcClaimsFromTraits_IncludesGlobalCustomClaims(t *testing.T) {
|
||||
claims := buildOidcClaimsFromTraits(map[string]any{
|
||||
"email": "user@test.com",
|
||||
"name": "Test User",
|
||||
"global_custom_claims": map[string]any{
|
||||
"contract_date": "2026-06-09",
|
||||
"approved_at": "2026-06-09T09:30:00+09:00",
|
||||
"email": "override@test.com",
|
||||
"rp_claims": "reserved",
|
||||
},
|
||||
"global_custom_claim_permissions": map[string]any{
|
||||
"contract_date": map[string]any{
|
||||
"readPermission": "user_and_admin",
|
||||
"writePermission": "admin_only",
|
||||
},
|
||||
},
|
||||
}, []string{"openid", "profile", "email"}, "")
|
||||
|
||||
assert.Equal(t, "2026-06-09", claims["contract_date"])
|
||||
assert.Equal(t, "2026-06-09T09:30:00+09:00", claims["approved_at"])
|
||||
assert.Equal(t, "user@test.com", claims["email"])
|
||||
assert.NotEqual(t, "reserved", claims["rp_claims"])
|
||||
assert.NotContains(t, claims, "global_custom_claim_permissions")
|
||||
}
|
||||
|
||||
func TestAcceptConsentRequest_AppliesConfiguredIDTokenClaims(t *testing.T) {
|
||||
var capturedClaims map[string]any
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-configured-claims" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-configured-claims",
|
||||
"requested_scope": []string{"openid", "profile"},
|
||||
"subject": "user-789",
|
||||
"client": map[string]any{
|
||||
"client_id": "client-configured-claims",
|
||||
"metadata": map[string]any{
|
||||
"tenant_id": "tenant-claims",
|
||||
"id_token_claims": []map[string]any{
|
||||
{
|
||||
"namespace": "top_level",
|
||||
"key": "locale",
|
||||
"value": "ko-KR",
|
||||
"valueType": "text",
|
||||
},
|
||||
{
|
||||
"namespace": "top_level",
|
||||
"key": "email",
|
||||
"value": "should-not-override@example.com",
|
||||
"valueType": "text",
|
||||
},
|
||||
{
|
||||
"namespace": "rp_claims",
|
||||
"key": "tier",
|
||||
"value": "2",
|
||||
"valueType": "number",
|
||||
},
|
||||
{
|
||||
"namespace": "rp_claims",
|
||||
"key": "features",
|
||||
"value": "[\"sso\",\"claims\"]",
|
||||
"valueType": "array",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-configured-claims" {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var acceptReq map[string]any
|
||||
json.Unmarshal(body, &acceptReq)
|
||||
if session, ok := acceptReq["session"].(map[string]any); ok {
|
||||
capturedClaims = session["id_token"].(map[string]any)
|
||||
}
|
||||
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"redirect_to": "http://rp/cb",
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
KratosAdmin: new(MockKratosAdminService),
|
||||
}
|
||||
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-789").Return(&service.KratosIdentity{
|
||||
ID: "user-789",
|
||||
Traits: map[string]any{
|
||||
"email": "real-user@example.com",
|
||||
"name": "Configured User",
|
||||
"tenant-claims": map[string]any{
|
||||
"department": "Platform",
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
|
||||
|
||||
reqBody, _ := json.Marshal(map[string]any{
|
||||
"consent_challenge": "challenge-configured-claims",
|
||||
"grant_scope": []string{"openid", "profile"},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
assert.NotNil(t, capturedClaims)
|
||||
assert.Equal(t, "real-user@example.com", capturedClaims["email"])
|
||||
assert.Equal(t, "ko-KR", capturedClaims["locale"])
|
||||
assert.Equal(t, "tenant-claims", capturedClaims["tenant_id"])
|
||||
|
||||
rpClaims, ok := capturedClaims["rp_claims"].(map[string]any)
|
||||
if assert.True(t, ok) {
|
||||
assert.Equal(t, float64(2), rpClaims["tier"])
|
||||
assert.Equal(t, []any{"sso", "claims"}, rpClaims["features"])
|
||||
}
|
||||
}
|
||||
904
baron-sso/backend/internal/handler/auth_handler_link_test.go
Normal file
904
baron-sso/backend/internal/handler/auth_handler_link_test.go
Normal file
@@ -0,0 +1,904 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"baron-sso-backend/internal/testsupport"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// Mock services
|
||||
type mockEmailService struct {
|
||||
lastTo string
|
||||
lastSubject string
|
||||
lastBody string
|
||||
}
|
||||
|
||||
func (m *mockEmailService) SendEmail(to, subject, body string) error {
|
||||
m.lastTo = to
|
||||
m.lastSubject = subject
|
||||
m.lastBody = body
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockSmsService struct {
|
||||
lastTo string
|
||||
lastContent string
|
||||
}
|
||||
|
||||
func (m *mockSmsService) SendSms(to, content string) error {
|
||||
m.lastTo = to
|
||||
m.lastContent = content
|
||||
return nil
|
||||
}
|
||||
|
||||
func newHeadlessLinkTestApp(h *AuthHandler) *fiber.App {
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/headless/link/init", h.HeadlessLinkInit)
|
||||
app.Post("/api/v1/auth/headless/link/poll", h.HeadlessLinkPoll)
|
||||
app.Post("/api/v1/auth/magic-link/verify", h.VerifyMagicLink)
|
||||
return app
|
||||
}
|
||||
|
||||
func newKratosWhoamiTestServer(t *testing.T, identityID string) *httptest.Server {
|
||||
t.Helper()
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/sessions/whoami" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
if r.Header.Get("Cookie") == "" && r.Header.Get("X-Session-Token") == "" {
|
||||
http.Error(w, "missing session", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "session-123",
|
||||
"authenticated_at": "2026-05-21T00:00:00Z",
|
||||
"identity": map[string]any{
|
||||
"id": identityID,
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
origDefaultClient := http.DefaultClient
|
||||
http.DefaultClient = server.Client()
|
||||
t.Cleanup(func() {
|
||||
http.DefaultClient = origDefaultClient
|
||||
})
|
||||
t.Cleanup(server.Close)
|
||||
return server
|
||||
}
|
||||
|
||||
func TestEnchantedLinkFlow_Email_Success(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
// Force "Not Supported" for InitiateLinkLogin only to trigger custom Enchanted Link logic
|
||||
idp := &mockIdpProvider{
|
||||
userExists: true,
|
||||
initiateLinkErr: domain.ErrNotSupported,
|
||||
}
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: idp,
|
||||
EmailService: &mockEmailService{},
|
||||
SmsService: &mockSmsService{},
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/enchanted-link/init", h.InitEnchantedLink)
|
||||
app.Post("/api/v1/auth/enchanted-link/poll", h.PollEnchantedLink)
|
||||
app.Post("/api/v1/auth/magic-link/verify", h.VerifyMagicLink)
|
||||
|
||||
t.Setenv("USERFRONT_URL", "http://userfront.test")
|
||||
|
||||
// 1. Init Enchanted Link (Email)
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"loginId": "user@example.com",
|
||||
"method": "email",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/enchanted-link/init", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var initResp map[string]any
|
||||
json.NewDecoder(resp.Body).Decode(&initResp)
|
||||
pendingRef := initResp["pendingRef"].(string)
|
||||
assert.NotEmpty(t, pendingRef)
|
||||
|
||||
// Find the token key "enchanted_token:..." in mock redis
|
||||
var token string
|
||||
for k := range redis.data {
|
||||
if len(k) > 16 && k[:16] == "enchanted_token:" {
|
||||
token = k[16:]
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
// 2. Verify Magic Link
|
||||
verifyBody, _ := json.Marshal(map[string]any{
|
||||
"token": token,
|
||||
"verifyOnly": true,
|
||||
})
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(verifyBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ = app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// 3. Poll (Success)
|
||||
pollBody, _ := json.Marshal(map[string]string{"pendingRef": pendingRef})
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/enchanted-link/poll", bytes.NewReader(pollBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ = app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
var pollResp map[string]any
|
||||
json.NewDecoder(resp.Body).Decode(&pollResp)
|
||||
assert.Equal(t, "ok", pollResp["status"])
|
||||
assert.Equal(t, "valid-jwt", pollResp["sessionJwt"])
|
||||
}
|
||||
|
||||
func TestEnchantedLinkFlow_Sms_Success(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
idp := &mockIdpProvider{
|
||||
userExists: true,
|
||||
initiateLinkErr: domain.ErrNotSupported,
|
||||
}
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: idp,
|
||||
SmsService: &mockSmsService{},
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/enchanted-link/init", h.InitEnchantedLink)
|
||||
|
||||
// 1. Init Enchanted Link (SMS)
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"loginId": "010-1234-5678",
|
||||
"method": "sms",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/enchanted-link/init", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var initResp map[string]any
|
||||
json.NewDecoder(resp.Body).Decode(&initResp)
|
||||
assert.NotEmpty(t, initResp["userCode"])
|
||||
}
|
||||
|
||||
func TestVerifyMagicLink_VerifyOnlyWithoutSharedBrowserSessionApprovesOnly(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: map[string]string{
|
||||
prefixToken + "token-123": `{"pendingRef":"pending-123","loginId":"user@example.com"}`,
|
||||
}}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: &mockIdpProvider{},
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/magic-link/verify", h.VerifyMagicLink)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"token": "token-123",
|
||||
"verifyOnly": true,
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Empty(t, resp.Cookies())
|
||||
|
||||
var got map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&got)
|
||||
assert.Equal(t, "approved", got["status"])
|
||||
assert.Nil(t, got["sessionJwt"])
|
||||
assert.Nil(t, got["token"])
|
||||
}
|
||||
|
||||
func TestVerifyMagicLink_VerifyOnlySharedBrowserSameSubjectApprovesOnly(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: map[string]string{
|
||||
prefixToken + "token-123": `{"pendingRef":"pending-123","loginId":"user@example.com"}`,
|
||||
}}
|
||||
kratosPublic := newKratosWhoamiTestServer(t, "kratos-user-1")
|
||||
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: &mockIdpProvider{},
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/magic-link/verify", h.VerifyMagicLink)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"token": "token-123",
|
||||
"verifyOnly": true,
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Cookie", "ory_kratos_session=shared-browser-session")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Empty(t, resp.Cookies())
|
||||
|
||||
var got map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&got)
|
||||
assert.Equal(t, "approved", got["status"])
|
||||
assert.Nil(t, got["sessionJwt"])
|
||||
assert.Nil(t, got["token"])
|
||||
}
|
||||
|
||||
func TestVerifyMagicLink_VerifyOnlySharedBrowserDifferentSubjectApprovesOnly(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: map[string]string{
|
||||
prefixToken + "token-123": `{"pendingRef":"pending-123","loginId":"user@example.com"}`,
|
||||
}}
|
||||
kratosPublic := newKratosWhoamiTestServer(t, "kratos-other-user")
|
||||
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: &mockIdpProvider{},
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/magic-link/verify", h.VerifyMagicLink)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"token": "token-123",
|
||||
"verifyOnly": true,
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Cookie", "ory_kratos_session=shared-browser-session")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Empty(t, resp.Cookies())
|
||||
|
||||
var got map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&got)
|
||||
assert.Equal(t, "approved", got["status"])
|
||||
assert.Nil(t, got["sessionJwt"])
|
||||
assert.Nil(t, got["token"])
|
||||
assert.Contains(t, redis.data[prefixSession+"pending-123"], "approved")
|
||||
}
|
||||
|
||||
func TestResolveUserfrontURL_DevLocalhostUsesConfiguredPort(t *testing.T) {
|
||||
t.Setenv("APP_ENV", "dev")
|
||||
t.Setenv("USERFRONT_URL", "http://localhost:5000")
|
||||
|
||||
h := &AuthHandler{}
|
||||
app := fiber.New()
|
||||
app.Get("/probe", func(c *fiber.Ctx) error {
|
||||
return c.SendString(h.resolveUserfrontURL(c))
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost/probe", nil)
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
assert.Equal(t, "http://localhost:5000", string(body))
|
||||
}
|
||||
|
||||
func TestVerifyLoginCode_VerifyOnlySharedBrowserDifferentSubjectApprovesOnly(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: map[string]string{
|
||||
prefixLoginCode + "user@example.com": "flow-123",
|
||||
prefixLoginCodePending + "user@example.com": "pending-123",
|
||||
prefixLoginCodeValue + "pending-123": "569765",
|
||||
}}
|
||||
kratosPublic := newKratosWhoamiTestServer(t, "kratos-other-user")
|
||||
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: &mockIdpProvider{},
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/login/code/verify", h.VerifyLoginCode)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"loginId": "user@example.com",
|
||||
"code": "569765",
|
||||
"pendingRef": "pending-123",
|
||||
"verifyOnly": true,
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Cookie", "ory_kratos_session=shared-browser-session")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Empty(t, resp.Cookies())
|
||||
|
||||
var got map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&got)
|
||||
assert.Equal(t, "approved", got["status"])
|
||||
assert.Nil(t, got["sessionJwt"])
|
||||
assert.Nil(t, got["token"])
|
||||
assert.Contains(t, redis.data[prefixSession+"pending-123"], "approved")
|
||||
}
|
||||
|
||||
func TestVerifyLoginCode_MapsSmsPhoneBeforeFlowLookup(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: map[string]string{
|
||||
prefixLoginCode + "su-@samaneng.com": "flow-123",
|
||||
prefixLoginCodePending + "su-@samaneng.com": "pending-123",
|
||||
prefixLoginCodeSmsLookup + "+821041585840": "su-@samaneng.com",
|
||||
prefixLoginCodeSmsTarget + "su-@samaneng.com": "+821041585840",
|
||||
prefixLoginCodeValue + "pending-123": "569765",
|
||||
}}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: &mockIdpProvider{},
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/login/code/verify", h.VerifyLoginCode)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"loginId": "01041585840",
|
||||
"code": "569765",
|
||||
"pendingRef": "pending-123",
|
||||
"verifyOnly": true,
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
var got map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&got)
|
||||
assert.Equal(t, "approved", got["status"])
|
||||
assert.Equal(t, "pending-123", got["pendingRef"])
|
||||
}
|
||||
|
||||
func TestPollEnchantedLink_ExpiredToken_ReturnsCode(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/enchanted-link/poll", h.PollEnchantedLink)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"pendingRef": "missing-ref",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/enchanted-link/poll", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var got map[string]any
|
||||
json.NewDecoder(resp.Body).Decode(&got)
|
||||
assert.Equal(t, "expired_token", got["error"])
|
||||
assert.Equal(t, "expired_token", got["code"])
|
||||
}
|
||||
|
||||
func TestPollEnchantedLink_SharedBrowserSameSubjectIssuesSession(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: map[string]string{
|
||||
prefixSession + "pending-123": `{"status":"approved","loginId":"user@example.com"}`,
|
||||
}}
|
||||
kratosPublic := newKratosWhoamiTestServer(t, "kratos-user-1")
|
||||
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: &mockIdpProvider{
|
||||
issueSession: &domain.AuthInfo{
|
||||
SessionToken: &domain.Token{JWT: "valid-jwt", SessionID: "new-session-id"},
|
||||
Subject: "kratos-user-1",
|
||||
},
|
||||
},
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/enchanted-link/poll", h.PollEnchantedLink)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"pendingRef": "pending-123"})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/enchanted-link/poll", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Cookie", "ory_kratos_session=shared-browser-session")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
var got map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&got)
|
||||
assert.Equal(t, "ok", got["status"])
|
||||
assert.Equal(t, "valid-jwt", got["sessionJwt"])
|
||||
}
|
||||
|
||||
func TestPollEnchantedLink_SharedBrowserDifferentSubjectConflicts(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: map[string]string{
|
||||
prefixSession + "pending-123": `{"status":"approved","loginId":"user@example.com"}`,
|
||||
}}
|
||||
kratosPublic := newKratosWhoamiTestServer(t, "kratos-other-user")
|
||||
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: &mockIdpProvider{
|
||||
issueSession: &domain.AuthInfo{
|
||||
SessionToken: &domain.Token{JWT: "valid-jwt", SessionID: "new-session-id"},
|
||||
Subject: "kratos-user-1",
|
||||
},
|
||||
},
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/enchanted-link/poll", h.PollEnchantedLink)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"pendingRef": "pending-123"})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/enchanted-link/poll", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Cookie", "ory_kratos_session=shared-browser-session")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusConflict, resp.StatusCode)
|
||||
var got map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&got)
|
||||
assert.Equal(t, "session_subject_conflict", got["code"])
|
||||
assert.NotContains(t, redis.data[prefixSession+"pending-123"], "valid-jwt")
|
||||
}
|
||||
|
||||
func TestHeadlessLinkInit_HeadlessLoginClientSuccess(t *testing.T) {
|
||||
t.Setenv("BACKEND_PUBLIC_URL", "")
|
||||
|
||||
if !testsupport.PortBindingAvailable() {
|
||||
t.Skip("skipping headless link tests because this environment cannot bind local TCP listeners")
|
||||
}
|
||||
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
privateKey, jwks := mustHeadlessRSAJWK(t)
|
||||
jwksBody, _ := json.Marshal(jwks)
|
||||
jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write(jwksBody)
|
||||
}))
|
||||
defer jwksServer.Close()
|
||||
|
||||
idp := &mockIdpProvider{
|
||||
userExists: true,
|
||||
initiateLinkErr: domain.ErrNotSupported,
|
||||
}
|
||||
|
||||
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet {
|
||||
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
|
||||
Challenge: "challenge-123",
|
||||
Client: domain.HydraClient{
|
||||
ClientID: "headless-login-client",
|
||||
TokenEndpointAuthMethod: "none",
|
||||
Metadata: map[string]any{
|
||||
"status": "active",
|
||||
"headless_login_enabled": true,
|
||||
"headless_token_endpoint_auth_method": "private_key_jwt",
|
||||
"headless_jwks_uri": jwksServer.URL + "/.well-known/jwks.json",
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: idp,
|
||||
SmsService: &mockSmsService{},
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
|
||||
},
|
||||
}
|
||||
|
||||
app := newHeadlessLinkTestApp(h)
|
||||
t.Setenv("USERFRONT_URL", "http://userfront.test")
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"client_id": "headless-login-client",
|
||||
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/init"),
|
||||
"loginId": "010-1234-5678",
|
||||
"login_challenge": "challenge-123",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var got map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&got)
|
||||
assert.NotEmpty(t, got["pendingRef"])
|
||||
_, hasUserCode := got["userCode"]
|
||||
assert.False(t, hasUserCode)
|
||||
}
|
||||
|
||||
func TestHeadlessLinkPoll_AfterApprovalReturnsRedirect(t *testing.T) {
|
||||
t.Setenv("BACKEND_PUBLIC_URL", "")
|
||||
|
||||
if !testsupport.PortBindingAvailable() {
|
||||
t.Skip("skipping headless link tests because this environment cannot bind local TCP listeners")
|
||||
}
|
||||
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
privateKey, jwks := mustHeadlessRSAJWK(t)
|
||||
jwksBody, _ := json.Marshal(jwks)
|
||||
|
||||
idp := &mockIdpProvider{
|
||||
userExists: true,
|
||||
initiateLinkErr: domain.ErrNotSupported,
|
||||
}
|
||||
|
||||
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
|
||||
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
|
||||
Challenge: "challenge-123",
|
||||
Client: domain.HydraClient{
|
||||
ClientID: "headless-login-client",
|
||||
ClientName: "local-demo-rp",
|
||||
TokenEndpointAuthMethod: "none",
|
||||
Metadata: map[string]any{
|
||||
"status": "active",
|
||||
"headless_login_enabled": true,
|
||||
"headless_token_endpoint_auth_method": "private_key_jwt",
|
||||
"headless_jwks_uri": "https://rp.example.com/.well-known/jwks.json",
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "+821012345678").Return("kratos-identity-id", nil)
|
||||
auditRepo := &mockAuditRepo{}
|
||||
headlessClient := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Host == "rp.example.com" && r.URL.Path == "/.well-known/jwks.json" {
|
||||
return httpResponse(r, http.StatusOK, string(jwksBody)), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})}
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: idp,
|
||||
SmsService: &mockSmsService{},
|
||||
KratosAdmin: mockKratos,
|
||||
AuditRepo: auditRepo,
|
||||
HeadlessJWKS: service.NewHeadlessJWKSCacheService(nil, headlessClient),
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
|
||||
},
|
||||
}
|
||||
|
||||
app := newHeadlessLinkTestApp(h)
|
||||
t.Setenv("USERFRONT_URL", "http://userfront.test")
|
||||
|
||||
initBody, _ := json.Marshal(map[string]string{
|
||||
"client_id": "headless-login-client",
|
||||
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/init"),
|
||||
"loginId": "010-1234-5678",
|
||||
"login_challenge": "challenge-123",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(initBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var initResp map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&initResp)
|
||||
pendingRef := initResp["pendingRef"].(string)
|
||||
assert.NotEmpty(t, pendingRef)
|
||||
|
||||
var token string
|
||||
for k := range redis.data {
|
||||
if len(k) > 16 && k[:16] == "enchanted_token:" {
|
||||
token = k[16:]
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
verifyBody, _ := json.Marshal(map[string]any{
|
||||
"token": token,
|
||||
"verifyOnly": true,
|
||||
})
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(verifyBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ = app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
pollBody, _ := json.Marshal(map[string]string{
|
||||
"client_id": "headless-login-client",
|
||||
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/poll"),
|
||||
"pendingRef": pendingRef,
|
||||
})
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/poll", bytes.NewReader(pollBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ = app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
var pollResp map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&pollResp)
|
||||
assert.Equal(t, "http://rp/cb", pollResp["redirectTo"])
|
||||
assert.Equal(t, "ok", pollResp["status"])
|
||||
assert.Nil(t, pollResp["sessionJwt"])
|
||||
assert.Nil(t, pollResp["token"])
|
||||
assert.Empty(t, resp.Cookies())
|
||||
if assert.Len(t, auditRepo.logs, 1) {
|
||||
assert.Contains(t, auditRepo.logs[0].EventType, "/api/v1/auth/")
|
||||
details, err := parseAuditDetails(auditRepo.logs[0].Details)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse audit details: %v", err)
|
||||
}
|
||||
assert.Equal(t, "headless-login-client", details["client_id"])
|
||||
assert.Equal(t, "local-demo-rp", details["client_name"])
|
||||
assert.Equal(t, "challenge-123", details["login_challenge"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadlessLinkPoll_ApproverSubjectConflictBlocksMixedRP(t *testing.T) {
|
||||
t.Setenv("BACKEND_PUBLIC_URL", "")
|
||||
|
||||
if !testsupport.PortBindingAvailable() {
|
||||
t.Skip("skipping headless link tests because this environment cannot bind local TCP listeners")
|
||||
}
|
||||
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
privateKey, jwks := mustHeadlessRSAJWK(t)
|
||||
jwksBody, _ := json.Marshal(jwks)
|
||||
acceptCalled := false
|
||||
|
||||
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
|
||||
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
|
||||
Challenge: "challenge-123",
|
||||
Client: domain.HydraClient{
|
||||
ClientID: "headless-login-client",
|
||||
ClientName: "local-demo-rp",
|
||||
TokenEndpointAuthMethod: "none",
|
||||
Metadata: map[string]any{
|
||||
"status": "active",
|
||||
"headless_login_enabled": true,
|
||||
"headless_token_endpoint_auth_method": "private_key_jwt",
|
||||
"headless_jwks_uri": "https://rp.example.com/.well-known/jwks.json",
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
|
||||
acceptCalled = true
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "+821012345678").Return("kratos-target-b", nil)
|
||||
auditRepo := &mockAuditRepo{}
|
||||
headlessClient := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Host == "rp.example.com" && r.URL.Path == "/.well-known/jwks.json" {
|
||||
return httpResponse(r, http.StatusOK, string(jwksBody)), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})}
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: &mockIdpProvider{
|
||||
userExists: true,
|
||||
initiateLinkErr: domain.ErrNotSupported,
|
||||
},
|
||||
SmsService: &mockSmsService{},
|
||||
KratosAdmin: mockKratos,
|
||||
AuditRepo: auditRepo,
|
||||
HeadlessJWKS: service.NewHeadlessJWKSCacheService(nil, headlessClient),
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
|
||||
},
|
||||
}
|
||||
|
||||
app := newHeadlessLinkTestApp(h)
|
||||
t.Setenv("USERFRONT_URL", "http://userfront.test")
|
||||
|
||||
initBody, _ := json.Marshal(map[string]string{
|
||||
"client_id": "headless-login-client",
|
||||
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/init"),
|
||||
"loginId": "010-1234-5678",
|
||||
"login_challenge": "challenge-123",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(initBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var initResp map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&initResp)
|
||||
pendingRef := initResp["pendingRef"].(string)
|
||||
assert.NotEmpty(t, pendingRef)
|
||||
|
||||
var token string
|
||||
for k := range redis.data {
|
||||
if len(k) > 16 && k[:16] == "enchanted_token:" {
|
||||
token = k[16:]
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
kratosPublic := newKratosWhoamiTestServer(t, "kratos-userfront-a")
|
||||
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
|
||||
|
||||
verifyBody, _ := json.Marshal(map[string]any{
|
||||
"token": token,
|
||||
"verifyOnly": true,
|
||||
})
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(verifyBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Cookie", "ory_kratos_session=userfront-a-session")
|
||||
resp, _ = app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
pollBody, _ := json.Marshal(map[string]string{
|
||||
"client_id": "headless-login-client",
|
||||
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/poll"),
|
||||
"pendingRef": pendingRef,
|
||||
})
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/poll", bytes.NewReader(pollBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ = app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusConflict, resp.StatusCode)
|
||||
assert.False(t, acceptCalled)
|
||||
assert.Empty(t, resp.Cookies())
|
||||
var got map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&got)
|
||||
assert.Equal(t, "oidc_subject_conflict", got["code"])
|
||||
assert.Equal(t, "redirect_to_userfront_login", got["recommendedAction"])
|
||||
assert.Equal(t, "kratos-userfront-a", got["currentSubject"])
|
||||
assert.Equal(t, "kratos-target-b", got["targetSubject"])
|
||||
assert.Empty(t, auditRepo.logs)
|
||||
}
|
||||
|
||||
func TestHeadlessLinkPoll_RequestCookieSubjectConflictBlocksMixedRP(t *testing.T) {
|
||||
t.Setenv("BACKEND_PUBLIC_URL", "")
|
||||
|
||||
if !testsupport.PortBindingAvailable() {
|
||||
t.Skip("skipping headless link tests because this environment cannot bind local TCP listeners")
|
||||
}
|
||||
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
privateKey, jwks := mustHeadlessRSAJWK(t)
|
||||
jwksBody, _ := json.Marshal(jwks)
|
||||
acceptCalled := false
|
||||
|
||||
hydraHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login") && r.Method == http.MethodGet:
|
||||
_ = json.NewEncoder(w).Encode(domain.HydraLoginRequest{
|
||||
Challenge: "challenge-123",
|
||||
Client: domain.HydraClient{
|
||||
ClientID: "headless-login-client",
|
||||
TokenEndpointAuthMethod: "none",
|
||||
Metadata: map[string]any{
|
||||
"status": "active",
|
||||
"headless_login_enabled": true,
|
||||
"headless_token_endpoint_auth_method": "private_key_jwt",
|
||||
"headless_jwks_uri": "https://rp.example.com/.well-known/jwks.json",
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
case strings.Contains(r.URL.Path, "/oauth2/auth/requests/login/accept") && r.Method == http.MethodPut:
|
||||
acceptCalled = true
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"redirect_to": "http://rp/cb"})
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("FindIdentityIDByIdentifier", mock.Anything, "+821012345678").Return("kratos-target-b", nil)
|
||||
headlessClient := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Host == "rp.example.com" && r.URL.Path == "/.well-known/jwks.json" {
|
||||
return httpResponse(r, http.StatusOK, string(jwksBody)), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})}
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: &mockIdpProvider{
|
||||
userExists: true,
|
||||
initiateLinkErr: domain.ErrNotSupported,
|
||||
},
|
||||
SmsService: &mockSmsService{},
|
||||
KratosAdmin: mockKratos,
|
||||
HeadlessJWKS: service.NewHeadlessJWKSCacheService(nil, headlessClient),
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: mockHydraTransport(hydraHandler)},
|
||||
},
|
||||
}
|
||||
|
||||
app := newHeadlessLinkTestApp(h)
|
||||
t.Setenv("USERFRONT_URL", "http://userfront.test")
|
||||
|
||||
initBody, _ := json.Marshal(map[string]string{
|
||||
"client_id": "headless-login-client",
|
||||
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/init"),
|
||||
"loginId": "010-1234-5678",
|
||||
"login_challenge": "challenge-123",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/init", bytes.NewReader(initBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var initResp map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&initResp)
|
||||
pendingRef := initResp["pendingRef"].(string)
|
||||
assert.NotEmpty(t, pendingRef)
|
||||
|
||||
var token string
|
||||
for k := range redis.data {
|
||||
if len(k) > 16 && k[:16] == "enchanted_token:" {
|
||||
token = k[16:]
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
verifyBody, _ := json.Marshal(map[string]any{
|
||||
"token": token,
|
||||
"verifyOnly": true,
|
||||
})
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/magic-link/verify", bytes.NewReader(verifyBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ = app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
kratosPublic := newKratosWhoamiTestServer(t, "kratos-userfront-a")
|
||||
t.Setenv("KRATOS_PUBLIC_URL", kratosPublic.URL)
|
||||
|
||||
pollBody, _ := json.Marshal(map[string]string{
|
||||
"client_id": "headless-login-client",
|
||||
"client_assertion": mustHeadlessClientAssertion(t, privateKey, "headless-login-client", "http://example.com/api/v1/auth/headless/link/poll"),
|
||||
"pendingRef": pendingRef,
|
||||
})
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/headless/link/poll", bytes.NewReader(pollBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Cookie", "ory_kratos_session=userfront-a-session")
|
||||
resp, _ = app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusConflict, resp.StatusCode)
|
||||
assert.False(t, acceptCalled)
|
||||
assert.Empty(t, resp.Cookies())
|
||||
var got map[string]any
|
||||
_ = json.NewDecoder(resp.Body).Decode(&got)
|
||||
assert.Equal(t, "oidc_subject_conflict", got["code"])
|
||||
assert.Equal(t, "kratos-userfront-a", got["currentSubject"])
|
||||
assert.Equal(t, "kratos-target-b", got["targetSubject"])
|
||||
}
|
||||
287
baron-sso/backend/internal/handler/auth_handler_linked_test.go
Normal file
287
baron-sso/backend/internal/handler/auth_handler_linked_test.go
Normal file
@@ -0,0 +1,287 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// --- Helper ---
|
||||
|
||||
func newLinkedRpTestApp(h *AuthHandler) *fiber.App {
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/user/rp/linked", h.ListLinkedRps)
|
||||
return app
|
||||
}
|
||||
|
||||
// --- Tests ---
|
||||
|
||||
func TestListLinkedRps_PriorityAndAggregation(t *testing.T) {
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
switch r.URL.Host {
|
||||
case "kratos.test":
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
if r.Header.Get("X-Session-Token") == "" && r.Header.Get("Cookie") == "" {
|
||||
return httpResponse(r, http.StatusUnauthorized, "unauthorized"), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@test.com",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
case "hydra.test":
|
||||
if r.URL.Path == "/oauth2/auth/sessions/consent" {
|
||||
return httpJSONAny(r, http.StatusOK, []map[string]any{
|
||||
{
|
||||
"client": map[string]any{
|
||||
"client_id": "devfront",
|
||||
"client_name": "DevFront",
|
||||
"redirect_uris": []string{
|
||||
"https://active.example.com/callback",
|
||||
},
|
||||
},
|
||||
"grant_scope": []string{"openid", "profile"},
|
||||
"handled_at": time.Now().Format(time.RFC3339),
|
||||
},
|
||||
{
|
||||
"client": map[string]any{
|
||||
"client_id": "orgfront",
|
||||
"client_name": "OrgFront",
|
||||
"metadata": map[string]any{
|
||||
"auto_login_supported": true,
|
||||
"auto_login_url": "http://localhost:5175/login",
|
||||
},
|
||||
"redirect_uris": []string{
|
||||
"http://localhost:5175/auth/callback",
|
||||
},
|
||||
},
|
||||
"grant_scope": []string{"openid", "profile"},
|
||||
"handled_at": time.Now().Format(time.RFC3339),
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Path == "/admin/clients/client-audit" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": "client-audit",
|
||||
"client_name": "Audit App",
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Path == "/admin/clients/client-consent" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": "client-consent",
|
||||
"client_name": "Consent App",
|
||||
}), nil
|
||||
}
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() {
|
||||
http.DefaultClient = origDefault
|
||||
}()
|
||||
|
||||
auditRepo := &mockAuditRepo{
|
||||
logs: []domain.AuditLog{
|
||||
{
|
||||
UserID: "user-123",
|
||||
EventType: "consent.granted",
|
||||
Timestamp: time.Now().Add(-10 * time.Hour),
|
||||
Details: `{"client_id":"client-audit", "scopes":["audit_scope"]}`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
consentRepo := &mockConsentRepo{
|
||||
consents: []domain.ClientConsent{
|
||||
{
|
||||
Subject: "user-123",
|
||||
ClientID: "client-consent",
|
||||
GrantedScopes: []string{"consent_scope"},
|
||||
UpdatedAt: time.Now().Add(-2 * time.Hour),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
AuditRepo: auditRepo,
|
||||
ConsentRepo: consentRepo,
|
||||
KratosAdmin: new(MockKratosAdminService),
|
||||
}
|
||||
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
t.Setenv("KRATOS_ADMIN_URL", "http://kratos.test")
|
||||
t.Setenv("HYDRA_PUBLIC_URL", "https://sso.example.com/oidc")
|
||||
t.Setenv("DEVFRONT_URL", "http://localhost:5174")
|
||||
|
||||
app := newLinkedRpTestApp(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/rp/linked", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var res struct {
|
||||
Items []struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
Scopes []string `json:"scopes"`
|
||||
InitURL string `json:"init_url"`
|
||||
AutoLoginSupported bool `json:"auto_login_supported"`
|
||||
AutoLoginURL string `json:"auto_login_url"`
|
||||
} `json:"items"`
|
||||
}
|
||||
json.NewDecoder(resp.Body).Decode(&res)
|
||||
|
||||
assert.Equal(t, 4, len(res.Items))
|
||||
|
||||
statusMap := make(map[string]string)
|
||||
for _, item := range res.Items {
|
||||
statusMap[item.ID] = item.Status
|
||||
}
|
||||
|
||||
assert.Equal(t, "active", statusMap["devfront"])
|
||||
assert.Equal(t, "active", statusMap["orgfront"])
|
||||
assert.Equal(t, "inactive", statusMap["client-consent"])
|
||||
assert.Equal(t, "inactive", statusMap["client-audit"])
|
||||
|
||||
var activeInitURL string
|
||||
for _, item := range res.Items {
|
||||
if item.ID == "devfront" {
|
||||
activeInitURL = item.InitURL
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
parsedInitURL, err := url.Parse(activeInitURL)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "http", parsedInitURL.Scheme)
|
||||
assert.Equal(t, "localhost:5174", parsedInitURL.Host)
|
||||
assert.Equal(t, "/login", parsedInitURL.Path)
|
||||
assert.Equal(t, "1", parsedInitURL.Query().Get("auto"))
|
||||
assert.Equal(t, "/clients", parsedInitURL.Query().Get("returnTo"))
|
||||
|
||||
var orgfrontItem struct {
|
||||
InitURL string
|
||||
AutoLoginSupported bool
|
||||
AutoLoginURL string
|
||||
}
|
||||
for _, item := range res.Items {
|
||||
if item.ID == "orgfront" {
|
||||
orgfrontItem.InitURL = item.InitURL
|
||||
orgfrontItem.AutoLoginSupported = item.AutoLoginSupported
|
||||
orgfrontItem.AutoLoginURL = item.AutoLoginURL
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, orgfrontItem.AutoLoginSupported)
|
||||
assert.Equal(t, "http://localhost:5175/login?auto=1", orgfrontItem.AutoLoginURL)
|
||||
assert.Equal(t, orgfrontItem.AutoLoginURL, orgfrontItem.InitURL)
|
||||
}
|
||||
|
||||
func TestListLinkedRps_EnrichesLogoFromHydraClientWhenConsentSessionOmitsMetadata(t *testing.T) {
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
switch r.URL.Host {
|
||||
case "kratos.test":
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
case "hydra.test":
|
||||
if r.URL.Path == "/oauth2/auth/sessions/consent" {
|
||||
return httpJSONAny(r, http.StatusOK, []map[string]any{
|
||||
{
|
||||
"client": map[string]any{
|
||||
"client_id": "gitea-client",
|
||||
"client_name": "Gitea",
|
||||
"redirect_uris": []string{
|
||||
"https://gitea.example.com/callback",
|
||||
},
|
||||
},
|
||||
"grant_scope": []string{"openid", "profile"},
|
||||
"handled_at": time.Now().Format(time.RFC3339),
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Path == "/clients/gitea-client" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": "gitea-client",
|
||||
"client_name": "Gitea",
|
||||
"redirect_uris": []string{
|
||||
"https://gitea.example.com/callback",
|
||||
},
|
||||
"metadata": map[string]any{
|
||||
"logo_url": "https://cdn.example.com/gitea.svg",
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() {
|
||||
http.DefaultClient = origDefault
|
||||
}()
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
KratosAdmin: new(MockKratosAdminService),
|
||||
}
|
||||
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
t.Setenv("KRATOS_ADMIN_URL", "http://kratos.test")
|
||||
t.Setenv("HYDRA_PUBLIC_URL", "https://sso.example.com/oidc")
|
||||
|
||||
app := newLinkedRpTestApp(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/rp/linked", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var res struct {
|
||||
Items []struct {
|
||||
ID string `json:"id"`
|
||||
Logo string `json:"logo"`
|
||||
} `json:"items"`
|
||||
}
|
||||
json.NewDecoder(resp.Body).Decode(&res)
|
||||
|
||||
assert.Len(t, res.Items, 1)
|
||||
assert.Equal(t, "gitea-client", res.Items[0].ID)
|
||||
assert.Equal(t, "https://cdn.example.com/gitea.svg", res.Items[0].Logo)
|
||||
}
|
||||
@@ -0,0 +1,206 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func newVerifyLoginCodeTestApp(h *AuthHandler) *fiber.App {
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/login/code/verify", h.VerifyLoginCode)
|
||||
app.Post("/api/v1/auth/login/code/verify-short", h.VerifyLoginShortCode)
|
||||
return app
|
||||
}
|
||||
|
||||
func decodeJSONBody(t *testing.T, resp *http.Response) map[string]any {
|
||||
t.Helper()
|
||||
|
||||
var got map[string]any
|
||||
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
|
||||
t.Fatalf("failed to decode response body: %v", err)
|
||||
}
|
||||
return got
|
||||
}
|
||||
|
||||
func TestVerifyLoginCode_InvalidBody_ReturnsExplicitCode(t *testing.T) {
|
||||
h := &AuthHandler{}
|
||||
app := newVerifyLoginCodeTestApp(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify", bytes.NewBufferString("{"))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
got := decodeJSONBody(t, resp)
|
||||
if got["code"] != "bad_request" {
|
||||
t.Fatalf("expected code=bad_request, got %v", got["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyLoginCode_IdpUnavailable_ReturnsExplicitCode(t *testing.T) {
|
||||
h := &AuthHandler{}
|
||||
app := newVerifyLoginCodeTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"loginId": "user@example.com",
|
||||
"code": "AA-111111",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusServiceUnavailable {
|
||||
t.Fatalf("expected 503, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
got := decodeJSONBody(t, resp)
|
||||
if got["code"] != "service_unavailable" {
|
||||
t.Fatalf("expected code=service_unavailable, got %v", got["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyLoginCode_VerifyOnlyInvalidCode_ReturnsExplicitCode(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
redis.data[prefixLoginCode+"user@example.com"] = "flow-1"
|
||||
redis.data[prefixLoginCodePending+"user@example.com"] = "pending-1"
|
||||
redis.data[prefixLoginCodeValue+"pending-1"] = "AB-123"
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: &mockIdpProvider{},
|
||||
}
|
||||
app := newVerifyLoginCodeTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"loginId": "user@example.com",
|
||||
"code": "ZZ-999",
|
||||
"verifyOnly": true,
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
got := decodeJSONBody(t, resp)
|
||||
if got["code"] != "invalid_code" {
|
||||
t.Fatalf("expected code=invalid_code, got %v", got["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyLoginShortCode_MissingShortCode_ReturnsExplicitCode(t *testing.T) {
|
||||
h := &AuthHandler{
|
||||
RedisService: &mockRedisRepo{data: make(map[string]string)},
|
||||
}
|
||||
app := newVerifyLoginCodeTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"shortCode": "",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify-short", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
got := decodeJSONBody(t, resp)
|
||||
if got["code"] != "bad_request" {
|
||||
t.Fatalf("expected code=bad_request, got %v", got["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyLoginShortCode_InvalidOrExpired_ReturnsExplicitCode(t *testing.T) {
|
||||
h := &AuthHandler{
|
||||
RedisService: &mockRedisRepo{data: make(map[string]string)},
|
||||
}
|
||||
app := newVerifyLoginCodeTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"shortCode": "AB-123456",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify-short", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
got := decodeJSONBody(t, resp)
|
||||
if got["code"] != "invalid_or_expired_code" {
|
||||
t.Fatalf("expected code=invalid_or_expired_code, got %v", got["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyLoginShortCode_VerifyOnlyMissingPendingRef_ReturnsExplicitCode(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
payload, _ := json.Marshal(shortLoginCodePayload{
|
||||
LoginID: "user@example.com",
|
||||
Code: "AB-123",
|
||||
})
|
||||
redis.data[prefixLoginCodeShort+"AB-123456"] = string(payload)
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
}
|
||||
app := newVerifyLoginCodeTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"shortCode": "AB-123456",
|
||||
"verifyOnly": true,
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/code/verify-short", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
got := decodeJSONBody(t, resp)
|
||||
if got["code"] != "invalid_session_reference" {
|
||||
t.Fatalf("expected code=invalid_session_reference, got %v", got["code"])
|
||||
}
|
||||
}
|
||||
2398
baron-sso/backend/internal/handler/auth_handler_login_test.go
Normal file
2398
baron-sso/backend/internal/handler/auth_handler_login_test.go
Normal file
File diff suppressed because it is too large
Load Diff
178
baron-sso/backend/internal/handler/auth_handler_oidc_test.go
Normal file
178
baron-sso/backend/internal/handler/auth_handler_oidc_test.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func newOidcLoginTestApp(h *AuthHandler) *fiber.App {
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/oidc/login/accept", h.AcceptOidcLoginRequest)
|
||||
return app
|
||||
}
|
||||
|
||||
func TestAcceptOidcLoginRequest_CookieOnly(t *testing.T) {
|
||||
var gotSubject string
|
||||
var gotChallenge string
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
switch r.URL.Host {
|
||||
case "kratos.test":
|
||||
if r.URL.Path != "/sessions/whoami" {
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}
|
||||
if r.Header.Get("X-Session-Token") != "" {
|
||||
return httpResponse(r, http.StatusUnauthorized, "invalid token"), nil
|
||||
}
|
||||
if r.Header.Get("Cookie") == "" {
|
||||
return httpResponse(r, http.StatusUnauthorized, "missing cookie"), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{
|
||||
"id": "kratos-123",
|
||||
"traits": map[string]any{},
|
||||
},
|
||||
}), nil
|
||||
case "hydra.test":
|
||||
if r.URL.Path != "/oauth2/auth/requests/login/accept" {
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}
|
||||
gotChallenge = r.URL.Query().Get("login_challenge")
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var payload map[string]any
|
||||
_ = json.Unmarshal(body, &payload)
|
||||
if subject, ok := payload["subject"].(string); ok {
|
||||
gotSubject = subject
|
||||
}
|
||||
return httpResponse(r, http.StatusOK, `{"redirect_to":"http://rp/cb"}`), nil
|
||||
default:
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}
|
||||
})
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() {
|
||||
http.DefaultClient = origDefault
|
||||
}()
|
||||
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
}
|
||||
app := newOidcLoginTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"login_challenge": "challenge-123",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oidc/login/accept", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Cookie", "ory_kratos_session=abc123")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
var got map[string]string
|
||||
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if got["redirectTo"] != "http://rp/cb" {
|
||||
t.Fatalf("unexpected redirectTo: %v", got["redirectTo"])
|
||||
}
|
||||
if gotSubject != "kratos-123" {
|
||||
t.Fatalf("unexpected subject: %v", gotSubject)
|
||||
}
|
||||
if gotChallenge != "challenge-123" {
|
||||
t.Fatalf("unexpected login_challenge: %v", gotChallenge)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAcceptOidcLoginRequest_TokenFallbackToCookie(t *testing.T) {
|
||||
var gotSubject string
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
switch r.URL.Host {
|
||||
case "kratos.test":
|
||||
if r.URL.Path != "/sessions/whoami" {
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}
|
||||
if r.Header.Get("X-Session-Token") != "" {
|
||||
return httpResponse(r, http.StatusUnauthorized, "invalid token"), nil
|
||||
}
|
||||
if r.Header.Get("Cookie") == "" {
|
||||
return httpResponse(r, http.StatusUnauthorized, "missing cookie"), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{
|
||||
"id": "kratos-456",
|
||||
"traits": map[string]any{},
|
||||
},
|
||||
}), nil
|
||||
case "hydra.test":
|
||||
if r.URL.Path != "/oauth2/auth/requests/login/accept" {
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var payload map[string]any
|
||||
_ = json.Unmarshal(body, &payload)
|
||||
if subject, ok := payload["subject"].(string); ok {
|
||||
gotSubject = subject
|
||||
}
|
||||
return httpResponse(r, http.StatusOK, `{"redirect_to":"http://rp/cb"}`), nil
|
||||
default:
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}
|
||||
})
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() {
|
||||
http.DefaultClient = origDefault
|
||||
}()
|
||||
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
}
|
||||
app := newOidcLoginTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"login_challenge": "challenge-456",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oidc/login/accept", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer invalid-token")
|
||||
req.Header.Set("Cookie", "ory_kratos_session=def456")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
if gotSubject != "kratos-456" {
|
||||
t.Fatalf("unexpected subject: %v", gotSubject)
|
||||
}
|
||||
}
|
||||
110
baron-sso/backend/internal/handler/auth_handler_otp_test.go
Normal file
110
baron-sso/backend/internal/handler/auth_handler_otp_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestHandleKratosCourierRelay_Email(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
emailSvc := &mockEmailService{}
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
EmailService: emailSvc,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/kratos/courier", h.HandleKratosCourierRelay)
|
||||
|
||||
// Simulate Kratos Courier Request for Email
|
||||
reqBody := map[string]any{
|
||||
"recipient": "user@example.com",
|
||||
"template_type": "verification_code",
|
||||
"template_data": map[string]any{
|
||||
"verification_code": "123456",
|
||||
},
|
||||
"subject": "Verify your email",
|
||||
"body": "Your code is 123456",
|
||||
}
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/kratos/courier", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestVerifySignupCode_Success(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/signup/verify", h.VerifySignupCode)
|
||||
|
||||
// Mock stored code in redis
|
||||
// signup:email:user@test.com -> {"code":"654321", "verified":false, "expires_at":...}
|
||||
state := map[string]any{
|
||||
"code": "654321",
|
||||
"verified": false,
|
||||
"expires_at": 9999999999, // far future
|
||||
}
|
||||
stateJSON, _ := json.Marshal(state)
|
||||
redis.data["signup:email:user@test.com"] = string(stateJSON)
|
||||
|
||||
// Verify Code
|
||||
verifyBody := map[string]string{
|
||||
"type": "email",
|
||||
"target": "user@test.com",
|
||||
"code": "654321",
|
||||
}
|
||||
body, _ := json.Marshal(verifyBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/signup/verify", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var res map[string]any
|
||||
json.NewDecoder(resp.Body).Decode(&res)
|
||||
assert.True(t, res["success"].(bool))
|
||||
|
||||
// Check redis state updated to verified
|
||||
val, _ := redis.Get("signup:email:user@test.com")
|
||||
var updatedState map[string]any
|
||||
json.Unmarshal([]byte(val), &updatedState)
|
||||
assert.True(t, updatedState["verified"].(bool))
|
||||
}
|
||||
|
||||
func TestVerifySignupCode_Invalid(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/signup/verify", h.VerifySignupCode)
|
||||
|
||||
stateJSON, _ := json.Marshal(map[string]any{
|
||||
"code": "111111",
|
||||
"expires_at": 9999999999,
|
||||
})
|
||||
redis.data["signup:email:user@test.com"] = string(stateJSON)
|
||||
|
||||
verifyBody := map[string]string{
|
||||
"type": "email",
|
||||
"target": "user@test.com",
|
||||
"code": "222222", // wrong code
|
||||
}
|
||||
body, _ := json.Marshal(verifyBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/signup/verify", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
@@ -0,0 +1,244 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type recordingUpdateMeUserRepo struct {
|
||||
MockUserRepoForHandler
|
||||
updated *domain.User
|
||||
loginIDs []domain.UserLoginID
|
||||
}
|
||||
|
||||
func (r *recordingUpdateMeUserRepo) Update(ctx context.Context, user *domain.User) error {
|
||||
copied := *user
|
||||
r.updated = &copied
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *recordingUpdateMeUserRepo) UpdateUserLoginIDs(ctx context.Context, userID string, loginIDs []domain.UserLoginID) error {
|
||||
r.loginIDs = append([]domain.UserLoginID(nil), loginIDs...)
|
||||
return nil
|
||||
}
|
||||
|
||||
type recordingUpdateMeKratosAdmin struct {
|
||||
MockKratosAdminService
|
||||
updatedIdentityID string
|
||||
updatedTraits map[string]any
|
||||
updatedState string
|
||||
storedTraits map[string]any
|
||||
}
|
||||
|
||||
func (r *recordingUpdateMeKratosAdmin) UpdateIdentity(ctx context.Context, identityID string, traits map[string]any, state string) (*service.KratosIdentity, error) {
|
||||
r.updatedIdentityID = identityID
|
||||
r.updatedTraits = maps.Clone(traits)
|
||||
r.updatedState = state
|
||||
if r.storedTraits != nil {
|
||||
maps.Copy(r.storedTraits, traits)
|
||||
}
|
||||
return &service.KratosIdentity{
|
||||
ID: identityID,
|
||||
Traits: traits,
|
||||
State: state,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestUpdateMe_InvalidatesProfileCacheForTokenSession(t *testing.T) {
|
||||
token := "token-abc"
|
||||
identityID := "user-1"
|
||||
traits := map[string]any{
|
||||
"email": "qa@example.com",
|
||||
"name": "QA User",
|
||||
"phone_number": "+821012345678",
|
||||
"department": "Old Dept",
|
||||
"affiliationType": "employee",
|
||||
"companyCode": "",
|
||||
"role": domain.RoleUser,
|
||||
}
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
switch {
|
||||
case r.URL.Host == "kratos.test" &&
|
||||
r.URL.Path == "/sessions/whoami" &&
|
||||
r.Method == http.MethodGet:
|
||||
if r.Header.Get("X-Session-Token") != token {
|
||||
return httpResponse(r, http.StatusUnauthorized, `{"error":"invalid token"}`), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{
|
||||
"id": identityID,
|
||||
"traits": traits,
|
||||
},
|
||||
}), nil
|
||||
|
||||
case r.URL.Host == "kratos.test" &&
|
||||
r.URL.Path == "/admin/identities/"+identityID &&
|
||||
r.Method == http.MethodPut:
|
||||
var payload struct {
|
||||
Traits map[string]any `json:"traits"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
return httpResponse(r, http.StatusBadRequest, `{"error":"invalid body"}`), nil
|
||||
}
|
||||
maps.Copy(traits, payload.Traits)
|
||||
return httpResponse(r, http.StatusOK, `{"ok":true}`), nil
|
||||
}
|
||||
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
setDefaultHTTPClientForTest(t, transport)
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
t.Setenv("KRATOS_ADMIN_URL", "http://kratos.test")
|
||||
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
kratosAdmin := &recordingUpdateMeKratosAdmin{storedTraits: traits}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
KratosAdmin: kratosAdmin,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/user/me", h.GetMe)
|
||||
app.Put("/api/v1/user/me", h.UpdateMe)
|
||||
|
||||
// 1) 첫 조회로 Old Dept가 캐시에 저장됨
|
||||
getReq1 := httptest.NewRequest(http.MethodGet, "/api/v1/user/me", nil)
|
||||
getReq1.Header.Set("Authorization", "Bearer "+token)
|
||||
getResp1, err := app.Test(getReq1, -1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, getResp1.StatusCode)
|
||||
var profile1 map[string]any
|
||||
require.NoError(t, json.NewDecoder(getResp1.Body).Decode(&profile1))
|
||||
require.Equal(t, "Old Dept", profile1["department"])
|
||||
|
||||
// 2) 소속을 New Dept로 변경
|
||||
updateBody, _ := json.Marshal(map[string]string{
|
||||
"name": "QA User",
|
||||
"phone": "01012345678",
|
||||
"department": "New Dept",
|
||||
})
|
||||
updateReq := httptest.NewRequest(
|
||||
http.MethodPut,
|
||||
"/api/v1/user/me",
|
||||
bytes.NewReader(updateBody),
|
||||
)
|
||||
updateReq.Header.Set("Content-Type", "application/json")
|
||||
updateReq.Header.Set("Authorization", "Bearer "+token)
|
||||
updateResp, err := app.Test(updateReq, -1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, updateResp.StatusCode)
|
||||
require.Equal(t, "New Dept", traits["department"])
|
||||
require.Equal(t, identityID, kratosAdmin.updatedIdentityID)
|
||||
require.Equal(t, "New Dept", kratosAdmin.updatedTraits["department"])
|
||||
|
||||
// 3) 새로고침 재조회 시 New Dept가 보여야 함(캐시 무효화 회귀 방지)
|
||||
getReq2 := httptest.NewRequest(http.MethodGet, "/api/v1/user/me", nil)
|
||||
getReq2.Header.Set("Authorization", "Bearer "+token)
|
||||
getResp2, err := app.Test(getReq2, -1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, getResp2.StatusCode)
|
||||
var profile2 map[string]any
|
||||
require.NoError(t, json.NewDecoder(getResp2.Body).Decode(&profile2))
|
||||
require.Equal(t, "New Dept", profile2["department"])
|
||||
}
|
||||
|
||||
func TestUpdateMe_SyncsLocalReadModelFields(t *testing.T) {
|
||||
token := "token-sync"
|
||||
identityID := "user-sync"
|
||||
traits := map[string]any{
|
||||
"email": "sync@example.com",
|
||||
"name": "Old Name",
|
||||
"phone_number": "+821012345678",
|
||||
"department": "Old Dept",
|
||||
"affiliationType": "employee",
|
||||
"companyCode": "saman",
|
||||
"tenant_id": "11111111-1111-1111-1111-111111111111",
|
||||
"role": domain.RoleUser,
|
||||
}
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
switch {
|
||||
case r.URL.Host == "kratos.test" &&
|
||||
r.URL.Path == "/sessions/whoami" &&
|
||||
r.Method == http.MethodGet:
|
||||
if r.Header.Get("X-Session-Token") != token {
|
||||
return httpResponse(r, http.StatusUnauthorized, `{"error":"invalid token"}`), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{
|
||||
"id": identityID,
|
||||
"traits": traits,
|
||||
},
|
||||
}), nil
|
||||
|
||||
case r.URL.Host == "kratos.test" &&
|
||||
r.URL.Path == "/admin/identities/"+identityID &&
|
||||
r.Method == http.MethodPut:
|
||||
var payload struct {
|
||||
Traits map[string]any `json:"traits"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
return httpResponse(r, http.StatusBadRequest, `{"error":"invalid body"}`), nil
|
||||
}
|
||||
maps.Copy(traits, payload.Traits)
|
||||
return httpResponse(r, http.StatusOK, `{"ok":true}`), nil
|
||||
}
|
||||
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
setDefaultHTTPClientForTest(t, transport)
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
t.Setenv("KRATOS_ADMIN_URL", "http://kratos.test")
|
||||
|
||||
redis := &mockRedisRepo{data: map[string]string{
|
||||
"verify_update_phone:" + identityID + ":+821087654321": "verified",
|
||||
}}
|
||||
userRepo := &recordingUpdateMeUserRepo{}
|
||||
kratosAdmin := &recordingUpdateMeKratosAdmin{storedTraits: traits}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
UserRepo: userRepo,
|
||||
KratosAdmin: kratosAdmin,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Put("/api/v1/user/me", h.UpdateMe)
|
||||
|
||||
updateBody, _ := json.Marshal(map[string]any{
|
||||
"name": "New Name",
|
||||
"phone": "01087654321",
|
||||
"department": "New Dept",
|
||||
})
|
||||
updateReq := httptest.NewRequest(
|
||||
http.MethodPut,
|
||||
"/api/v1/user/me",
|
||||
bytes.NewReader(updateBody),
|
||||
)
|
||||
updateReq.Header.Set("Content-Type", "application/json")
|
||||
updateReq.Header.Set("Authorization", "Bearer "+token)
|
||||
updateResp, err := app.Test(updateReq, -1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, updateResp.StatusCode)
|
||||
require.Equal(t, identityID, kratosAdmin.updatedIdentityID)
|
||||
require.Equal(t, "New Name", kratosAdmin.updatedTraits["name"])
|
||||
require.Equal(t, "+821087654321", kratosAdmin.updatedTraits["phone_number"])
|
||||
|
||||
require.NotNil(t, userRepo.updated)
|
||||
require.Equal(t, identityID, userRepo.updated.ID)
|
||||
require.Equal(t, "sync@example.com", userRepo.updated.Email)
|
||||
require.Equal(t, "New Name", userRepo.updated.Name)
|
||||
require.Equal(t, "+821087654321", userRepo.updated.Phone)
|
||||
require.Equal(t, "New Dept", userRepo.updated.Department)
|
||||
require.Empty(t, userRepo.updated.CompanyCode)
|
||||
require.NotNil(t, userRepo.updated.TenantID)
|
||||
require.Equal(t, "11111111-1111-1111-1111-111111111111", *userRepo.updated.TenantID)
|
||||
}
|
||||
206
baron-sso/backend/internal/handler/auth_handler_qr_test.go
Normal file
206
baron-sso/backend/internal/handler/auth_handler_qr_test.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// --- Mock Redis ---
|
||||
|
||||
type mockRedisRepo struct {
|
||||
data map[string]string
|
||||
}
|
||||
|
||||
func (m *mockRedisRepo) Set(key, value string, ttl time.Duration) error {
|
||||
if m.data == nil {
|
||||
m.data = make(map[string]string)
|
||||
}
|
||||
m.data[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockRedisRepo) Get(key string) (string, error) {
|
||||
// Bypass rate limiting for tests
|
||||
if strings.HasPrefix(key, "poll_meta:") {
|
||||
return "", nil
|
||||
}
|
||||
return m.data[key], nil
|
||||
}
|
||||
|
||||
func (m *mockRedisRepo) Delete(key string) error {
|
||||
delete(m.data, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockRedisRepo) StoreVerificationCode(phone, code string) error {
|
||||
return m.Set("sms:"+phone, code, time.Minute)
|
||||
}
|
||||
|
||||
func (m *mockRedisRepo) GetVerificationCode(phone string) (string, error) {
|
||||
return m.Get("sms:" + phone)
|
||||
}
|
||||
|
||||
func (m *mockRedisRepo) DeleteVerificationCode(phone string) error {
|
||||
return m.Delete("sms:" + phone)
|
||||
}
|
||||
|
||||
// --- Tests ---
|
||||
|
||||
func TestQRLoginFlow_Success(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/qr/init", h.InitQRLogin)
|
||||
app.Post("/api/v1/auth/qr/poll", h.PollQRLogin)
|
||||
|
||||
// 1. Init QR Login
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/qr/init", nil)
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var initResp map[string]any
|
||||
json.NewDecoder(resp.Body).Decode(&initResp)
|
||||
pendingRef := initResp["pendingRef"].(string)
|
||||
|
||||
// 2. Poll (Pending)
|
||||
body, _ := json.Marshal(map[string]string{"pendingRef": pendingRef})
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/qr/poll", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ = app.Test(req, -1)
|
||||
|
||||
// Expect authorization_pending (400)
|
||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
var pollResp map[string]any
|
||||
json.NewDecoder(resp.Body).Decode(&pollResp)
|
||||
assert.Equal(t, "authorization_pending", pollResp["error"])
|
||||
assert.Equal(t, "authorization_pending", pollResp["code"])
|
||||
|
||||
// 3. Mock Approval
|
||||
sessionData, _ := json.Marshal(map[string]string{
|
||||
"status": "success",
|
||||
"jwt": "mock-session-jwt",
|
||||
})
|
||||
redis.data["enchanted_session:"+pendingRef] = string(sessionData)
|
||||
|
||||
// 4. Poll (Success)
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/qr/poll", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ = app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var successResp map[string]any
|
||||
json.NewDecoder(resp.Body).Decode(&successResp)
|
||||
assert.Equal(t, "ok", successResp["status"])
|
||||
assert.Equal(t, "mock-session-jwt", successResp["sessionJwt"])
|
||||
}
|
||||
|
||||
func TestScanQRLogin_Success(t *testing.T) {
|
||||
redis := &mockRedisRepo{data: make(map[string]string)}
|
||||
idp := &mockIdpProvider{userExists: true}
|
||||
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: idp,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/qr/approve", h.ScanQRLogin)
|
||||
|
||||
pendingRef := "test-ref"
|
||||
redis.data["enchanted_session:"+pendingRef] = `{"status":"pending"}`
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = &http.Client{Transport: transport}
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"pendingRef": pendingRef,
|
||||
"token": "valid-token",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/qr/approve", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestResolveConsentSubjects_TokenAndCookie(t *testing.T) {
|
||||
h := &AuthHandler{}
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.Header.Get("X-Session-Token") == "token-123" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{
|
||||
"id": "user-token",
|
||||
"traits": map[string]any{
|
||||
"email": "token@test.com",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if r.Header.Get("Cookie") == "ory_kratos_session=cookie-123" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{
|
||||
"id": "user-cookie",
|
||||
"traits": map[string]any{
|
||||
"email": "cookie@test.com",
|
||||
"phone": "010-1234-5678",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusUnauthorized, "unauthorized"), nil
|
||||
})
|
||||
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = &http.Client{Transport: transport}
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
// Token case
|
||||
app.Get("/test-token", func(c *fiber.Ctx) error {
|
||||
subjects, err := h.resolveConsentSubjects(c)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, subjects, "user-token")
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
req := httptest.NewRequest("GET", "/test-token", nil)
|
||||
req.Header.Set("Authorization", "Bearer token-123")
|
||||
app.Test(req, -1)
|
||||
|
||||
// Cookie case
|
||||
app.Get("/test-cookie", func(c *fiber.Ctx) error {
|
||||
subjects, err := h.resolveConsentSubjects(c)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, subjects, "user-cookie")
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
req = httptest.NewRequest("GET", "/test-cookie", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=cookie-123")
|
||||
app.Test(req, -1)
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetMe_IncludesSessionAuthenticatedAtFromKratosSession(t *testing.T) {
|
||||
const (
|
||||
token = "token-session"
|
||||
identityID = "user-session"
|
||||
sessionAuthenticated = "2026-03-23T15:30:00Z"
|
||||
)
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Host == "kratos.test" &&
|
||||
r.URL.Path == "/sessions/whoami" &&
|
||||
r.Method == http.MethodGet {
|
||||
require.Equal(t, token, r.Header.Get("X-Session-Token"))
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "kratos-session-1",
|
||||
"authenticated_at": sessionAuthenticated,
|
||||
"identity": map[string]any{
|
||||
"id": identityID,
|
||||
"traits": map[string]any{
|
||||
"email": "qa@example.com",
|
||||
"name": "QA User",
|
||||
"department": "Platform",
|
||||
"affiliationType": "GENERAL",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
setDefaultHTTPClientForTest(t, transport)
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
|
||||
h := &AuthHandler{}
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/user/me", h.GetMe)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/me", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
resp, err := app.Test(req, -1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var profile map[string]any
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&profile))
|
||||
require.Equal(t, sessionAuthenticated, profile["sessionAuthenticatedAt"])
|
||||
}
|
||||
|
||||
func TestGetMe_IncludesSessionAuthenticatedAtForCookieSession(t *testing.T) {
|
||||
const (
|
||||
cookieHeader = "ory_kratos_session=session-cookie"
|
||||
identityID = "user-cookie"
|
||||
sessionAuthenticated = "2026-03-24T01:20:00Z"
|
||||
)
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Host == "kratos.test" &&
|
||||
r.URL.Path == "/sessions/whoami" &&
|
||||
r.Method == http.MethodGet {
|
||||
require.Equal(t, cookieHeader, r.Header.Get("Cookie"))
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "kratos-session-cookie",
|
||||
"authenticated_at": sessionAuthenticated,
|
||||
"identity": map[string]any{
|
||||
"id": identityID,
|
||||
"traits": map[string]any{
|
||||
"email": "cookie@example.com",
|
||||
"name": "Cookie User",
|
||||
"department": "Platform",
|
||||
"affiliationType": "GENERAL",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
setDefaultHTTPClientForTest(t, transport)
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
|
||||
h := &AuthHandler{}
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/user/me", h.GetMe)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/me", nil)
|
||||
req.Header.Set("Cookie", cookieHeader)
|
||||
resp, err := app.Test(req, -1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var profile map[string]any
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&profile))
|
||||
require.Equal(t, sessionAuthenticated, profile["sessionAuthenticatedAt"])
|
||||
}
|
||||
944
baron-sso/backend/internal/handler/auth_handler_sessions_test.go
Normal file
944
baron-sso/backend/internal/handler/auth_handler_sessions_test.go
Normal file
@@ -0,0 +1,944 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func TestListMySessions_Success(t *testing.T) {
|
||||
now := time.Date(2026, 4, 2, 1, 2, 3, 0, time.UTC)
|
||||
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "current-sid",
|
||||
"authenticated_at": now.Format(time.RFC3339),
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "User",
|
||||
"role": "user",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}))
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
|
||||
{
|
||||
ID: "current-sid",
|
||||
Active: true,
|
||||
AuthenticatedAt: now,
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
},
|
||||
{
|
||||
ID: "other-sid",
|
||||
Active: true,
|
||||
AuthenticatedAt: now.Add(-2 * time.Hour),
|
||||
ExpiresAt: now.Add(22 * time.Hour),
|
||||
},
|
||||
}, nil).Once()
|
||||
|
||||
auditRepo := &mockAuditRepo{
|
||||
logs: []domain.AuditLog{
|
||||
{
|
||||
UserID: "user-123",
|
||||
EventType: "login_success",
|
||||
SessionID: "other-sid",
|
||||
Timestamp: now.Add(-30 * time.Minute),
|
||||
IPAddress: "203.0.113.10",
|
||||
UserAgent: "Mozilla/5.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
h := &AuthHandler{
|
||||
KratosAdmin: mockKratos,
|
||||
AuditRepo: auditRepo,
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/user/sessions", h.ListMySessions)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/sessions", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body struct {
|
||||
Items []struct {
|
||||
SessionID string `json:"session_id"`
|
||||
IsCurrent bool `json:"is_current"`
|
||||
IsActive bool `json:"is_active"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
} `json:"items"`
|
||||
}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, body.Items, 2) {
|
||||
assert.Equal(t, "current-sid", body.Items[0].SessionID)
|
||||
assert.True(t, body.Items[0].IsCurrent)
|
||||
assert.Equal(t, "other-sid", body.Items[1].SessionID)
|
||||
assert.True(t, body.Items[1].IsActive)
|
||||
assert.Equal(t, "203.0.113.10", body.Items[1].IPAddress)
|
||||
assert.Equal(t, "Mozilla/5.0", body.Items[1].UserAgent)
|
||||
}
|
||||
|
||||
mockKratos.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestListMySessions_UsesConsentGrantForAppName(t *testing.T) {
|
||||
now := time.Date(2026, 4, 2, 4, 40, 0, 0, time.UTC)
|
||||
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "current-sid",
|
||||
"authenticated_at": now.Format(time.RFC3339),
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "User",
|
||||
"role": "user",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}))
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
|
||||
{
|
||||
ID: "current-sid",
|
||||
Active: true,
|
||||
AuthenticatedAt: now,
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
},
|
||||
{
|
||||
ID: "c7c721ea-session",
|
||||
Active: true,
|
||||
AuthenticatedAt: now.Add(-5 * time.Minute),
|
||||
ExpiresAt: now.Add(23*time.Hour + 55*time.Minute),
|
||||
},
|
||||
}, nil).Once()
|
||||
|
||||
auditRepo := &mockAuditRepo{
|
||||
logs: []domain.AuditLog{
|
||||
{
|
||||
UserID: "user-123",
|
||||
EventType: "consent.granted",
|
||||
SessionID: "c7c721ea-session",
|
||||
Timestamp: now,
|
||||
Details: `{"client_id":"devfront","client_name":"DevFront","session_id":"c7c721ea-session","approved_session_id":"c7c721ea-session"}`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
h := &AuthHandler{
|
||||
KratosAdmin: mockKratos,
|
||||
AuditRepo: auditRepo,
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/user/sessions", h.ListMySessions)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/sessions", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body struct {
|
||||
Items []struct {
|
||||
SessionID string `json:"session_id"`
|
||||
AppName string `json:"app_name"`
|
||||
ClientID string `json:"client_id"`
|
||||
} `json:"items"`
|
||||
}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, body.Items, 2) {
|
||||
assert.Equal(t, "c7c721ea-session", body.Items[1].SessionID)
|
||||
assert.Equal(t, "DevFront", body.Items[1].AppName)
|
||||
assert.Equal(t, "devfront", body.Items[1].ClientID)
|
||||
}
|
||||
|
||||
mockKratos.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestListMySessions_PreservesAppNameFromOlderConsentGrant(t *testing.T) {
|
||||
now := time.Date(2026, 4, 2, 4, 40, 0, 0, time.UTC)
|
||||
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "current-sid",
|
||||
"authenticated_at": now.Format(time.RFC3339),
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "User",
|
||||
"role": "user",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}))
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
|
||||
{
|
||||
ID: "current-sid",
|
||||
Active: true,
|
||||
AuthenticatedAt: now,
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
},
|
||||
{
|
||||
ID: "c7c721ea-session",
|
||||
Active: true,
|
||||
AuthenticatedAt: now.Add(-5 * time.Minute),
|
||||
ExpiresAt: now.Add(23*time.Hour + 55*time.Minute),
|
||||
},
|
||||
}, nil).Once()
|
||||
|
||||
auditRepo := &mockAuditRepo{
|
||||
logs: []domain.AuditLog{
|
||||
{
|
||||
UserID: "user-123",
|
||||
EventType: "consent.granted",
|
||||
SessionID: "c7c721ea-session",
|
||||
Timestamp: now.Add(-30 * time.Second),
|
||||
IPAddress: "203.0.113.10",
|
||||
Details: `{"client_id":"devfront","client_name":"DevFront","session_id":"c7c721ea-session"}`,
|
||||
},
|
||||
{
|
||||
UserID: "user-123",
|
||||
EventType: "login_success",
|
||||
SessionID: "c7c721ea-session",
|
||||
Timestamp: now,
|
||||
IPAddress: "10.0.0.12",
|
||||
UserAgent: "Mozilla/5.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
h := &AuthHandler{
|
||||
KratosAdmin: mockKratos,
|
||||
AuditRepo: auditRepo,
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/user/sessions", h.ListMySessions)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/sessions", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body struct {
|
||||
Items []struct {
|
||||
SessionID string `json:"session_id"`
|
||||
AppName string `json:"app_name"`
|
||||
ClientID string `json:"client_id"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
} `json:"items"`
|
||||
}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, body.Items, 2) {
|
||||
assert.Equal(t, "c7c721ea-session", body.Items[1].SessionID)
|
||||
assert.Equal(t, "DevFront", body.Items[1].AppName)
|
||||
assert.Equal(t, "devfront", body.Items[1].ClientID)
|
||||
assert.Equal(t, "203.0.113.10", body.Items[1].IPAddress)
|
||||
}
|
||||
|
||||
mockKratos.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestListMySessions_CurrentSessionFallsBackToRequestMetadata(t *testing.T) {
|
||||
now := time.Date(2026, 4, 6, 1, 2, 3, 0, time.UTC)
|
||||
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "current-sid",
|
||||
"authenticated_at": now.Format(time.RFC3339),
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "User",
|
||||
"role": "user",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}))
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
|
||||
{
|
||||
ID: "current-sid",
|
||||
Active: true,
|
||||
AuthenticatedAt: now,
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
},
|
||||
}, nil).Once()
|
||||
|
||||
h := &AuthHandler{
|
||||
KratosAdmin: mockKratos,
|
||||
AuditRepo: &mockAuditRepo{},
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/user/sessions", h.ListMySessions)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/sessions", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) Chrome/146.0.0.0 Safari/537.36")
|
||||
req.Header.Set("X-Forwarded-For", "100.100.100.1, 203.0.113.25")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body struct {
|
||||
Items []struct {
|
||||
SessionID string `json:"session_id"`
|
||||
IsCurrent bool `json:"is_current"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
ClientID string `json:"client_id"`
|
||||
AppName string `json:"app_name"`
|
||||
} `json:"items"`
|
||||
}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, body.Items, 1) {
|
||||
assert.Equal(t, "current-sid", body.Items[0].SessionID)
|
||||
assert.True(t, body.Items[0].IsCurrent)
|
||||
assert.Equal(t, "203.0.113.25", body.Items[0].IPAddress)
|
||||
assert.Contains(t, body.Items[0].UserAgent, "Mozilla/5.0")
|
||||
assert.Equal(t, "userfront", body.Items[0].ClientID)
|
||||
assert.Equal(t, "UserFront", body.Items[0].AppName)
|
||||
}
|
||||
|
||||
mockKratos.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestDeleteMySession_Success(t *testing.T) {
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
var hydraRevokeCalls int
|
||||
client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
switch r.URL.Host {
|
||||
case "kratos.test":
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "current-sid",
|
||||
"authenticated_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "User",
|
||||
"role": "user",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
case "hydra.test":
|
||||
if r.Method == http.MethodDelete && r.URL.Path == "/oauth2/auth/sessions/consent" {
|
||||
if r.URL.Query().Get("subject") != "user-123" {
|
||||
t.Fatalf("unexpected revoke subject: %s", r.URL.Query().Get("subject"))
|
||||
}
|
||||
if r.URL.Query().Get("client") != "devfront" {
|
||||
t.Fatalf("unexpected revoke client: %s", r.URL.Query().Get("client"))
|
||||
}
|
||||
hydraRevokeCalls++
|
||||
return httpResponse(r, http.StatusNoContent, ""), nil
|
||||
}
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})}
|
||||
setDefaultHTTPClientForTest(t, client.Transport)
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
|
||||
{ID: "target-sid", Active: true},
|
||||
}, nil).Once()
|
||||
mockKratos.On("GetSession", mock.Anything, "target-sid").Return(&service.KratosSession{
|
||||
ID: "target-sid",
|
||||
Active: true,
|
||||
}, nil).Once()
|
||||
mockKratos.On("DeleteSession", mock.Anything, "target-sid").Return(nil).Once()
|
||||
|
||||
auditRepo := &mockAuditRepo{}
|
||||
h := &AuthHandler{
|
||||
KratosAdmin: mockKratos,
|
||||
AuditRepo: auditRepo,
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
}
|
||||
auditRepo.logs = append(auditRepo.logs, domain.AuditLog{
|
||||
UserID: "user-123",
|
||||
EventType: "POST /api/v1/auth/oidc/login/accept",
|
||||
SessionID: "target-sid",
|
||||
Details: `{"client_id":"devfront","client_name":"Devfront"}`,
|
||||
})
|
||||
|
||||
app := fiber.New()
|
||||
app.Delete("/api/v1/user/sessions/:id", h.DeleteMySession)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/sessions/target-sid", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
req.Header.Set("User-Agent", "session-test-agent")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
if assert.Len(t, auditRepo.logs, 2) {
|
||||
assert.Equal(t, "session.revoked", auditRepo.logs[len(auditRepo.logs)-1].EventType)
|
||||
assert.Equal(t, "user-123", auditRepo.logs[len(auditRepo.logs)-1].UserID)
|
||||
assert.Equal(t, "current-sid", auditRepo.logs[len(auditRepo.logs)-1].SessionID)
|
||||
assert.Contains(t, auditRepo.logs[len(auditRepo.logs)-1].Details, "target-sid")
|
||||
}
|
||||
assert.Equal(t, 1, hydraRevokeCalls)
|
||||
|
||||
mockKratos.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestDeleteMySession_DoesNotRevokeAllHydraSessionsWhenClientBindingMissing(t *testing.T) {
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
var hydraRevokeCalls int
|
||||
client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
switch r.URL.Host {
|
||||
case "kratos.test":
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "current-sid",
|
||||
"authenticated_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "User",
|
||||
"role": "user",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
case "hydra.test":
|
||||
if r.Method == http.MethodDelete && r.URL.Path == "/oauth2/auth/sessions/consent" {
|
||||
hydraRevokeCalls++
|
||||
return httpResponse(r, http.StatusNoContent, ""), nil
|
||||
}
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})}
|
||||
setDefaultHTTPClientForTest(t, client.Transport)
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
|
||||
{ID: "target-sid", Active: true},
|
||||
}, nil).Once()
|
||||
mockKratos.On("GetSession", mock.Anything, "target-sid").Return(&service.KratosSession{
|
||||
ID: "target-sid",
|
||||
Active: true,
|
||||
}, nil).Once()
|
||||
mockKratos.On("DeleteSession", mock.Anything, "target-sid").Return(nil).Once()
|
||||
|
||||
auditRepo := &mockAuditRepo{}
|
||||
h := &AuthHandler{
|
||||
KratosAdmin: mockKratos,
|
||||
AuditRepo: auditRepo,
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Delete("/api/v1/user/sessions/:id", h.DeleteMySession)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/sessions/target-sid", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
req.Header.Set("User-Agent", "session-test-agent")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Equal(t, 0, hydraRevokeCalls)
|
||||
if assert.Len(t, auditRepo.logs, 1) {
|
||||
assert.Equal(t, "session.revoked", auditRepo.logs[0].EventType)
|
||||
assert.Equal(t, "user-123", auditRepo.logs[0].UserID)
|
||||
assert.Contains(t, auditRepo.logs[0].Details, "target-sid")
|
||||
}
|
||||
|
||||
mockKratos.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestDeleteMySession_SendsBackchannelLogoutTokenWhenClientConfigured(t *testing.T) {
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
t.Setenv("BACKCHANNEL_LOGOUT_ISSUER", "https://sso.example.com/oidc")
|
||||
|
||||
var receivedBody string
|
||||
client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
switch r.URL.Host {
|
||||
case "kratos.test":
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "current-sid",
|
||||
"authenticated_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "User",
|
||||
"role": "user",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
case "hydra.test":
|
||||
if r.Method == http.MethodDelete && r.URL.Path == "/oauth2/auth/sessions/consent" {
|
||||
return httpResponse(r, http.StatusNoContent, ""), nil
|
||||
}
|
||||
if r.Method == http.MethodGet && r.URL.Path == "/clients/devfront" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": "devfront",
|
||||
"backchannel_logout_uri": "https://rp.example.com/backchannel-logout",
|
||||
}), nil
|
||||
}
|
||||
case "rp.example.com":
|
||||
if r.Method == http.MethodPost && r.URL.Path == "/backchannel-logout" {
|
||||
raw, _ := io.ReadAll(r.Body)
|
||||
receivedBody = string(raw)
|
||||
return httpResponse(r, http.StatusNoContent, ""), nil
|
||||
}
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})}
|
||||
setDefaultHTTPClientForTest(t, client.Transport)
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
|
||||
{ID: "target-sid", Active: true},
|
||||
}, nil).Once()
|
||||
mockKratos.On("GetSession", mock.Anything, "target-sid").Return(&service.KratosSession{
|
||||
ID: "target-sid",
|
||||
Active: true,
|
||||
}, nil).Once()
|
||||
mockKratos.On("DeleteSession", mock.Anything, "target-sid").Return(nil).Once()
|
||||
|
||||
backchannelLogout, err := service.NewBackchannelLogoutService()
|
||||
assert.NoError(t, err)
|
||||
backchannelLogout.HTTPClient = client
|
||||
|
||||
auditRepo := &mockAuditRepo{}
|
||||
h := &AuthHandler{
|
||||
KratosAdmin: mockKratos,
|
||||
AuditRepo: auditRepo,
|
||||
BackchannelLogout: backchannelLogout,
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
}
|
||||
auditRepo.logs = append(auditRepo.logs, domain.AuditLog{
|
||||
UserID: "user-123",
|
||||
EventType: "POST /api/v1/auth/oidc/login/accept",
|
||||
SessionID: "target-sid",
|
||||
Details: `{"client_id":"devfront","client_name":"Devfront"}`,
|
||||
})
|
||||
|
||||
app := fiber.New()
|
||||
app.Delete("/api/v1/user/sessions/:id", h.DeleteMySession)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/sessions/target-sid", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
req.Header.Set("User-Agent", "session-test-agent")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.True(t, strings.Contains(receivedBody, "logout_token="))
|
||||
|
||||
values, err := url.ParseQuery(receivedBody)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, values.Get("logout_token"))
|
||||
|
||||
foundBackchannelAudit := false
|
||||
for _, log := range auditRepo.logs {
|
||||
if log.EventType == "backchannel_logout.sent" {
|
||||
foundBackchannelAudit = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, foundBackchannelAudit)
|
||||
|
||||
mockKratos.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestDeleteMySession_RevokesHydraClientBoundFromPasswordLoginAudit(t *testing.T) {
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
var hydraRevokeCalls int
|
||||
var revokedClient string
|
||||
client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
switch r.URL.Host {
|
||||
case "kratos.test":
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "current-sid",
|
||||
"authenticated_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "User",
|
||||
"role": "user",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
case "hydra.test":
|
||||
if r.Method == http.MethodDelete && r.URL.Path == "/oauth2/auth/sessions/consent" {
|
||||
revokedClient = r.URL.Query().Get("client")
|
||||
hydraRevokeCalls++
|
||||
return httpResponse(r, http.StatusNoContent, ""), nil
|
||||
}
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})}
|
||||
setDefaultHTTPClientForTest(t, client.Transport)
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
|
||||
{ID: "target-sid", Active: true},
|
||||
}, nil).Once()
|
||||
mockKratos.On("GetSession", mock.Anything, "target-sid").Return(&service.KratosSession{
|
||||
ID: "target-sid",
|
||||
Active: true,
|
||||
}, nil).Once()
|
||||
mockKratos.On("DeleteSession", mock.Anything, "target-sid").Return(nil).Once()
|
||||
|
||||
auditRepo := &mockAuditRepo{}
|
||||
h := &AuthHandler{
|
||||
KratosAdmin: mockKratos,
|
||||
AuditRepo: auditRepo,
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
}
|
||||
auditRepo.logs = append(auditRepo.logs, domain.AuditLog{
|
||||
UserID: "user-123",
|
||||
EventType: "POST /api/v1/auth/password/login",
|
||||
SessionID: "target-sid",
|
||||
Details: `{"client_id":"adminfront","client_name":"AdminFront","session_id":"target-sid"}`,
|
||||
})
|
||||
|
||||
app := fiber.New()
|
||||
app.Delete("/api/v1/user/sessions/:id", h.DeleteMySession)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/sessions/target-sid", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
req.Header.Set("User-Agent", "session-test-agent")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Equal(t, 1, hydraRevokeCalls)
|
||||
assert.Equal(t, "adminfront", revokedClient)
|
||||
|
||||
mockKratos.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestGetHydraProfile_RejectsInactiveLinkedSession(t *testing.T) {
|
||||
client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Host == "hydra.test" && r.URL.Path == "/oauth2/introspect" {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
if string(body) != "token=opaque-token" {
|
||||
t.Fatalf("unexpected introspect body: %s", string(body))
|
||||
}
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"active": true,
|
||||
"sub": "user-123",
|
||||
"client_id": "devfront",
|
||||
"ext": map[string]any{
|
||||
"session_id": "target-sid",
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})}
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("GetSession", mock.Anything, "target-sid").Return(&service.KratosSession{
|
||||
ID: "target-sid",
|
||||
Active: false,
|
||||
Identity: &service.KratosIdentity{
|
||||
ID: "user-123",
|
||||
},
|
||||
}, nil).Once()
|
||||
|
||||
h := &AuthHandler{
|
||||
KratosAdmin: mockKratos,
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
}
|
||||
|
||||
profile, err := h.getHydraProfile(context.Background(), "opaque-token")
|
||||
assert.Nil(t, profile)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "inactive")
|
||||
mockKratos.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestGetAuthTimeline_FillsSessionIDFromOathkeeperRaw(t *testing.T) {
|
||||
now := time.Date(2026, 4, 7, 4, 39, 0, 0, time.UTC)
|
||||
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "current-sid",
|
||||
"authenticated_at": now.Format(time.RFC3339),
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "User",
|
||||
"role": "user",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}))
|
||||
|
||||
h := &AuthHandler{
|
||||
AuditRepo: &mockAuditRepo{},
|
||||
OathkeeperRepo: &mockOathkeeperRepo{
|
||||
logs: []domain.OathkeeperAccessLog{
|
||||
{
|
||||
Timestamp: now,
|
||||
RequestID: "req-1",
|
||||
Method: http.MethodGet,
|
||||
Path: "/api/v1/dev/sessions",
|
||||
Status: http.StatusOK,
|
||||
Subject: "user-123",
|
||||
ClientIP: "203.0.113.7",
|
||||
UserAgent: "Mozilla/5.0",
|
||||
Raw: `{"request":{"url":"https://devfront.example.com/callback?client_id=devfront"},"extra":{"session_id":"target-sid"}}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/audit/auth/timeline", h.GetAuthTimeline)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/audit/auth/timeline", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body struct {
|
||||
Items []struct {
|
||||
SessionID string `json:"session_id"`
|
||||
ClientID string `json:"client_id"`
|
||||
AppName string `json:"app_name"`
|
||||
Source string `json:"source"`
|
||||
} `json:"items"`
|
||||
}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, body.Items, 1) {
|
||||
assert.Equal(t, "target-sid", body.Items[0].SessionID)
|
||||
assert.Equal(t, "devfront", body.Items[0].ClientID)
|
||||
assert.Equal(t, "devfront", body.Items[0].AppName)
|
||||
assert.Equal(t, "oathkeeper", body.Items[0].Source)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAuthTimeline_IncludesHeadlessPasswordLogin(t *testing.T) {
|
||||
now := time.Date(2026, 4, 7, 5, 10, 0, 0, time.UTC)
|
||||
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "current-sid",
|
||||
"authenticated_at": now.Format(time.RFC3339),
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "User",
|
||||
"role": "user",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}))
|
||||
|
||||
h := &AuthHandler{
|
||||
AuditRepo: &mockAuditRepo{
|
||||
logs: []domain.AuditLog{
|
||||
{
|
||||
EventID: "audit-1",
|
||||
Timestamp: now,
|
||||
UserID: "user-123",
|
||||
SessionID: "headless-session-1",
|
||||
EventType: "POST /api/v1/auth/headless/password/login",
|
||||
Status: "success",
|
||||
IPAddress: "203.0.113.20",
|
||||
UserAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/146.0.0.0 Safari/537.36",
|
||||
Details: `{"client_id":"headless-login-client","client_name":"Headless Login Portal","session_id":"headless-session-1","login_id":"user@example.com","login_challenge":"challenge-123"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/audit/auth/timeline", h.GetAuthTimeline)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/audit/auth/timeline", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body struct {
|
||||
Items []struct {
|
||||
SessionID string `json:"session_id"`
|
||||
ClientID string `json:"client_id"`
|
||||
AppName string `json:"app_name"`
|
||||
AuthMethod string `json:"auth_method"`
|
||||
EventType string `json:"event_type"`
|
||||
} `json:"items"`
|
||||
}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, body.Items, 1) {
|
||||
assert.Equal(t, "headless-session-1", body.Items[0].SessionID)
|
||||
assert.Equal(t, "headless-login-client", body.Items[0].ClientID)
|
||||
assert.Equal(t, "Headless Login Portal", body.Items[0].AppName)
|
||||
assert.Equal(t, "비밀번호(Email)", body.Items[0].AuthMethod)
|
||||
assert.Equal(t, "POST /api/v1/auth/headless/password/login", body.Items[0].EventType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListMySessions_UsesHeadlessPasswordLoginForClientBinding(t *testing.T) {
|
||||
now := time.Date(2026, 4, 7, 5, 35, 0, 0, time.UTC)
|
||||
setDefaultHTTPClientForTest(t, roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/sessions/whoami" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"id": "current-sid",
|
||||
"authenticated_at": now.Format(time.RFC3339),
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "User",
|
||||
"role": "user",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
}))
|
||||
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("ListIdentitySessions", mock.Anything, "user-123").Return([]service.KratosSession{
|
||||
{
|
||||
ID: "current-sid",
|
||||
Active: true,
|
||||
AuthenticatedAt: now,
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
},
|
||||
{
|
||||
ID: "headless-session-1",
|
||||
Active: true,
|
||||
AuthenticatedAt: now.Add(-10 * time.Minute),
|
||||
ExpiresAt: now.Add(23*time.Hour + 50*time.Minute),
|
||||
},
|
||||
}, nil).Once()
|
||||
|
||||
auditRepo := &mockAuditRepo{
|
||||
logs: []domain.AuditLog{
|
||||
{
|
||||
UserID: "user-123",
|
||||
EventType: "POST /api/v1/auth/headless/password/login",
|
||||
SessionID: "headless-session-1",
|
||||
Timestamp: now,
|
||||
IPAddress: "203.0.113.20",
|
||||
UserAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/146.0.0.0 Safari/537.36",
|
||||
Details: `{"client_id":"headless-login-client","client_name":"Headless Login Portal","session_id":"headless-session-1"}`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
h := &AuthHandler{
|
||||
KratosAdmin: mockKratos,
|
||||
AuditRepo: auditRepo,
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/user/sessions", h.ListMySessions)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/user/sessions", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=valid")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var body struct {
|
||||
Items []struct {
|
||||
SessionID string `json:"session_id"`
|
||||
AppName string `json:"app_name"`
|
||||
ClientID string `json:"client_id"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
} `json:"items"`
|
||||
}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, body.Items, 2) {
|
||||
assert.Equal(t, "headless-session-1", body.Items[1].SessionID)
|
||||
assert.Equal(t, "Headless Login Portal", body.Items[1].AppName)
|
||||
assert.Equal(t, "headless-login-client", body.Items[1].ClientID)
|
||||
assert.Equal(t, "203.0.113.20", body.Items[1].IPAddress)
|
||||
assert.Equal(t, "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/146.0.0.0 Safari/537.36", body.Items[1].UserAgent)
|
||||
}
|
||||
|
||||
mockKratos.AssertExpectations(t)
|
||||
}
|
||||
144
baron-sso/backend/internal/handler/auth_handler_signup_test.go
Normal file
144
baron-sso/backend/internal/handler/auth_handler_signup_test.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// --- Local Mocks for Signup Test ---
|
||||
|
||||
type MockRedisForSignup struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockRedisForSignup) Set(key string, value string, ttl time.Duration) error {
|
||||
return m.Called(key, value, ttl).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockRedisForSignup) Get(key string) (string, error) {
|
||||
args := m.Called(key)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockRedisForSignup) Delete(key string) error {
|
||||
return m.Called(key).Error(0)
|
||||
}
|
||||
func (m *MockRedisForSignup) StoreVerificationCode(phone, code string) error { return nil }
|
||||
func (m *MockRedisForSignup) GetVerificationCode(phone string) (string, error) { return "", nil }
|
||||
func (m *MockRedisForSignup) DeleteVerificationCode(phone string) error { return nil }
|
||||
func (m *MockRedisForSignup) Ping(ctx context.Context) error { return nil }
|
||||
|
||||
type MockIdpForSignup struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockIdpForSignup) Name() string { return "mock-idp" }
|
||||
func (m *MockIdpForSignup) GetMetadata() (*domain.IDPMetadata, error) {
|
||||
return &domain.IDPMetadata{SupportedFields: []string{"email", "name", "phoneNumber", "grade", "department"}}, nil
|
||||
}
|
||||
|
||||
func (m *MockIdpForSignup) CreateUser(user *domain.BrokerUser, password string) (string, error) {
|
||||
args := m.Called(user, password)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockIdpForSignup) SignIn(loginID, password string) (*domain.AuthInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *MockIdpForSignup) UserExists(loginID string) (bool, error) { return false, nil }
|
||||
func (m *MockIdpForSignup) IssueSession(loginID string) (*domain.AuthInfo, error) { return nil, nil }
|
||||
func (m *MockIdpForSignup) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockIdpForSignup) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockIdpForSignup) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
|
||||
return &domain.PasswordPolicy{MinLength: 12}, nil
|
||||
}
|
||||
func (m *MockIdpForSignup) InitiatePasswordReset(loginID, redirectUrl string) error { return nil }
|
||||
func (m *MockIdpForSignup) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockIdpForSignup) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSignup_TenantSlugValidation(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockTenantSvc := new(MockTenantService)
|
||||
mockRedis := new(MockRedisForSignup)
|
||||
mockIdp := new(MockIdpForSignup)
|
||||
|
||||
h := &AuthHandler{
|
||||
TenantService: mockTenantSvc,
|
||||
RedisService: mockRedis,
|
||||
IdpProvider: mockIdp,
|
||||
}
|
||||
|
||||
app.Post("/signup", h.Signup)
|
||||
|
||||
// Prepare mock state (already verified email/phone)
|
||||
verifiedState, _ := json.Marshal(map[string]any{
|
||||
"verified": true,
|
||||
"expires_at": time.Now().Add(time.Hour).Unix(),
|
||||
})
|
||||
mockRedis.On("Get", mock.Anything).Return(string(verifiedState), nil)
|
||||
|
||||
t.Run("Rejects legacy CompanyCode", func(t *testing.T) {
|
||||
reqBody := domain.SignupRequest{
|
||||
Email: "user@gmail.com",
|
||||
Password: "StrongPass123!",
|
||||
Name: "Test User",
|
||||
Phone: "010-1234-5678",
|
||||
TermsAccepted: true,
|
||||
CompanyCode: "new-slug",
|
||||
}
|
||||
body, _ := json.Marshal(reqBody)
|
||||
|
||||
req := httptest.NewRequest("POST", "/signup", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("Active Tenant Slug", func(t *testing.T) {
|
||||
reqBody := domain.SignupRequest{
|
||||
Email: "user@hanmaceng.co.kr",
|
||||
Password: "StrongPass123!",
|
||||
Name: "Test User",
|
||||
Phone: "010-1234-5678",
|
||||
TermsAccepted: true,
|
||||
TenantSlug: "hanmac",
|
||||
}
|
||||
body, _ := json.Marshal(reqBody)
|
||||
|
||||
validTenant := &domain.Tenant{ID: "t1", Slug: "hanmac", Status: domain.TenantStatusActive}
|
||||
mockTenantSvc.On("GetTenantByDomain", mock.Anything, "hanmaceng.co.kr").Return(&domain.Tenant{Slug: "hanmac"}, nil).Once()
|
||||
mockTenantSvc.On("ProvisionTenantByDomain", mock.Anything, "hanmaceng.co.kr").Return(validTenant, nil).Maybe()
|
||||
mockTenantSvc.On("GetTenantBySlug", mock.Anything, "hanmac").Return(validTenant, nil).Once()
|
||||
mockTenantSvc.On("GetTenant", mock.Anything, "t1").Return(validTenant, nil).Once()
|
||||
mockIdp.On("CreateUser", mock.Anything, mock.Anything).Return("user-id", nil).Once()
|
||||
mockRedis.On("Delete", mock.Anything).Return(nil)
|
||||
|
||||
req := httptest.NewRequest("POST", "/signup", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
})
|
||||
}
|
||||
521
baron-sso/backend/internal/handler/auth_handler_test.go
Normal file
521
baron-sso/backend/internal/handler/auth_handler_test.go
Normal file
@@ -0,0 +1,521 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/middleware"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// helper to build a Fiber app with the handler route mounted.
|
||||
func newTestApp(h *AuthHandler) *fiber.App {
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/password/reset/complete", h.CompletePasswordReset)
|
||||
return app
|
||||
}
|
||||
|
||||
func newResetFlowTestApp(h *AuthHandler) *fiber.App {
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/password/reset/verify", h.ProcessPasswordResetToken)
|
||||
app.Post("/api/v1/auth/password/reset/complete", h.CompletePasswordReset)
|
||||
return app
|
||||
}
|
||||
|
||||
func newResetInitAppWithErrorCodeEnricher(h *AuthHandler) *fiber.App {
|
||||
app := fiber.New()
|
||||
app.Use(middleware.ErrorCodeEnricher())
|
||||
app.Post("/api/v1/auth/password/reset/init", h.InitiatePasswordReset)
|
||||
return app
|
||||
}
|
||||
|
||||
type testRedisRepo struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func (m *testRedisRepo) Set(key string, value string, expiration time.Duration) error {
|
||||
if m.values == nil {
|
||||
m.values = map[string]string{}
|
||||
}
|
||||
m.values[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testRedisRepo) Get(key string) (string, error) {
|
||||
if m.values == nil {
|
||||
return "", nil
|
||||
}
|
||||
return m.values[key], nil
|
||||
}
|
||||
|
||||
func (m *testRedisRepo) Delete(key string) error {
|
||||
if m.values != nil {
|
||||
delete(m.values, key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testRedisRepo) StoreVerificationCode(phone, code string) error {
|
||||
return m.Set("sms:"+phone, code, time.Minute)
|
||||
}
|
||||
|
||||
func (m *testRedisRepo) GetVerificationCode(phone string) (string, error) {
|
||||
return m.Get("sms:" + phone)
|
||||
}
|
||||
|
||||
func (m *testRedisRepo) DeleteVerificationCode(phone string) error {
|
||||
return m.Delete("sms:" + phone)
|
||||
}
|
||||
|
||||
func TestCompletePasswordReset_MissingLoginID(t *testing.T) {
|
||||
h := &AuthHandler{}
|
||||
app := newTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"newPassword": "Password1!",
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/password/reset/complete", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for missing loginId, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var got map[string]string
|
||||
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if got["error"] != "Login ID and new password are required" {
|
||||
t.Fatalf("unexpected error message: %v", got["error"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletePasswordReset_InvalidPasswordPolicy(t *testing.T) {
|
||||
h := &AuthHandler{}
|
||||
app := newTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"newPassword": "short", // too short + missing complexity
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/password/reset/complete?loginId=user@example.com", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for weak password, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var got map[string]string
|
||||
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if got["error"] != "비밀번호는 최소 12자 이상이어야 합니다" {
|
||||
t.Fatalf("unexpected error message: %v", got["error"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletePasswordReset_NilIDPProvider(t *testing.T) {
|
||||
h := &AuthHandler{} // IdpProvider intentionally nil to hit the configuration error branch
|
||||
app := newTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"newPassword": "StrongPass1!",
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/password/reset/complete?loginId=user@example.com", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusInternalServerError {
|
||||
t.Fatalf("expected 500 when IDP provider is nil, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var got map[string]string
|
||||
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if got["error"] != "Authentication service not configured" {
|
||||
t.Fatalf("unexpected error message: %v", got["error"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletePasswordReset_TokenValueOverridesLoginIDQuery(t *testing.T) {
|
||||
const resetToken = "tok-reset-1"
|
||||
const tokenLoginID = "user@example.com"
|
||||
const wrongLoginID = "wrong@example.com"
|
||||
const newPassword = "StrongPass1!"
|
||||
|
||||
redis := &testRedisRepo{
|
||||
values: map[string]string{
|
||||
prefixPwdResetToken + resetToken: tokenLoginID,
|
||||
},
|
||||
}
|
||||
idp := &mockIdpProvider{
|
||||
userExists: true,
|
||||
err: nil,
|
||||
}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: idp,
|
||||
}
|
||||
app := newResetFlowTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"newPassword": newPassword,
|
||||
})
|
||||
url := fmt.Sprintf(
|
||||
"/api/v1/auth/password/reset/complete?loginId=%s&token=%s",
|
||||
wrongLoginID,
|
||||
resetToken,
|
||||
)
|
||||
req := httptest.NewRequest(http.MethodPost, url, bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
if !idp.updateCalled {
|
||||
t.Fatal("expected UpdateUserPassword to be called")
|
||||
}
|
||||
if idp.updatedLoginID != tokenLoginID {
|
||||
t.Fatalf("expected loginId from token(%s), got %s", tokenLoginID, idp.updatedLoginID)
|
||||
}
|
||||
if idp.updatedPassword != newPassword {
|
||||
t.Fatalf("expected newPassword propagated, got %s", idp.updatedPassword)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletePasswordReset_InvalidTokenRejectedEvenWhenLoginIDExists(t *testing.T) {
|
||||
const resetToken = "invalid-token"
|
||||
|
||||
redis := &testRedisRepo{
|
||||
values: map[string]string{},
|
||||
}
|
||||
idp := &mockIdpProvider{
|
||||
userExists: true,
|
||||
err: nil,
|
||||
}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: idp,
|
||||
}
|
||||
app := newResetFlowTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"newPassword": "StrongPass1!",
|
||||
})
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/auth/password/reset/complete?loginId=user@example.com&token="+resetToken,
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401 for invalid token, got %d", resp.StatusCode)
|
||||
}
|
||||
if idp.updateCalled {
|
||||
t.Fatal("UpdateUserPassword must not be called when token is invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletePasswordReset_DuplicateTokenSubmitIsIdempotent(t *testing.T) {
|
||||
const resetToken = "dup-token"
|
||||
const loginID = "user@example.com"
|
||||
const newPassword = "StrongPass1!"
|
||||
|
||||
redis := &testRedisRepo{
|
||||
values: map[string]string{
|
||||
prefixPwdResetToken + resetToken: loginID,
|
||||
},
|
||||
}
|
||||
idp := &mockIdpProvider{
|
||||
userExists: true,
|
||||
err: nil,
|
||||
}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: idp,
|
||||
}
|
||||
app := newResetFlowTestApp(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"newPassword": newPassword,
|
||||
})
|
||||
url := fmt.Sprintf(
|
||||
"/api/v1/auth/password/reset/complete?token=%s",
|
||||
resetToken,
|
||||
)
|
||||
|
||||
firstReq := httptest.NewRequest(http.MethodPost, url, bytes.NewReader(body))
|
||||
firstReq.Header.Set("Content-Type", "application/json")
|
||||
firstResp, err := app.Test(firstReq)
|
||||
if err != nil {
|
||||
t.Fatalf("first request failed: %v", err)
|
||||
}
|
||||
defer firstResp.Body.Close()
|
||||
|
||||
if firstResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected first response to be 200, got %d", firstResp.StatusCode)
|
||||
}
|
||||
if idp.updateCallCount != 1 {
|
||||
t.Fatalf("expected first request to update password once, got %d", idp.updateCallCount)
|
||||
}
|
||||
|
||||
secondReq := httptest.NewRequest(http.MethodPost, url, bytes.NewReader(body))
|
||||
secondReq.Header.Set("Content-Type", "application/json")
|
||||
secondResp, err := app.Test(secondReq)
|
||||
if err != nil {
|
||||
t.Fatalf("second request failed: %v", err)
|
||||
}
|
||||
defer secondResp.Body.Close()
|
||||
|
||||
if secondResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected duplicate response to be 200, got %d", secondResp.StatusCode)
|
||||
}
|
||||
if idp.updateCallCount != 1 {
|
||||
t.Fatalf("expected duplicate request not to update password again, got %d", idp.updateCallCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessPasswordResetToken_EncodesLoginIDInRedirect(t *testing.T) {
|
||||
const token = "tok-enc"
|
||||
const loginID = "user+alias@example.com"
|
||||
|
||||
t.Setenv("USERFRONT_URL", "https://sss.hmac.kr")
|
||||
|
||||
redis := &testRedisRepo{
|
||||
values: map[string]string{
|
||||
prefixPwdResetToken + token: loginID,
|
||||
},
|
||||
}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
}
|
||||
app := newResetFlowTestApp(h)
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/auth/password/reset/verify?token="+token,
|
||||
nil,
|
||||
)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusFound {
|
||||
t.Fatalf("expected 302, got %d", resp.StatusCode)
|
||||
}
|
||||
location := resp.Header.Get("Location")
|
||||
if location == "" {
|
||||
t.Fatal("missing redirect location")
|
||||
}
|
||||
redirectReq := httptest.NewRequest(http.MethodGet, location, nil)
|
||||
gotLoginID := redirectReq.URL.Query().Get("loginId")
|
||||
if gotLoginID != loginID {
|
||||
t.Fatalf("expected encoded loginId round-trip=%s, got %s (location=%s)", loginID, gotLoginID, location)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetVerifyAlias_AcceptsShortVePath(t *testing.T) {
|
||||
const token = "tok-ve"
|
||||
const loginID = "user@example.com"
|
||||
|
||||
redis := &testRedisRepo{
|
||||
values: map[string]string{
|
||||
prefixPwdResetToken + token: loginID,
|
||||
},
|
||||
}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/auth/password/reset/ve", h.VerifyPasswordResetPage)
|
||||
app.Post("/api/v1/auth/password/reset/ve", h.ProcessPasswordResetToken)
|
||||
|
||||
getReq := httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
"/api/v1/auth/password/reset/ve?token="+token,
|
||||
nil,
|
||||
)
|
||||
getResp, err := app.Test(getReq)
|
||||
if err != nil {
|
||||
t.Fatalf("get request failed: %v", err)
|
||||
}
|
||||
defer getResp.Body.Close()
|
||||
|
||||
if getResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected alias GET to return 200, got %d", getResp.StatusCode)
|
||||
}
|
||||
|
||||
postReq := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/auth/password/reset/ve?token="+token,
|
||||
nil,
|
||||
)
|
||||
postResp, err := app.Test(postReq)
|
||||
if err != nil {
|
||||
t.Fatalf("post request failed: %v", err)
|
||||
}
|
||||
defer postResp.Body.Close()
|
||||
|
||||
if postResp.StatusCode != http.StatusFound {
|
||||
t.Fatalf("expected alias POST to return 302, got %d", postResp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetVerifyPathToken_AcceptsShortVPath(t *testing.T) {
|
||||
const token = "tok-path"
|
||||
const loginID = "user@example.com"
|
||||
|
||||
redis := &testRedisRepo{
|
||||
values: map[string]string{
|
||||
prefixPwdResetToken + token: loginID,
|
||||
},
|
||||
}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/auth/password/reset/v/:token", h.VerifyPasswordResetPage)
|
||||
app.Post("/api/v1/auth/password/reset/v/:token", h.ProcessPasswordResetToken)
|
||||
|
||||
getReq := httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
"/api/v1/auth/password/reset/v/"+token,
|
||||
nil,
|
||||
)
|
||||
getResp, err := app.Test(getReq)
|
||||
if err != nil {
|
||||
t.Fatalf("get request failed: %v", err)
|
||||
}
|
||||
defer getResp.Body.Close()
|
||||
|
||||
if getResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected path-token GET to return 200, got %d", getResp.StatusCode)
|
||||
}
|
||||
|
||||
postReq := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/auth/password/reset/v/"+token,
|
||||
nil,
|
||||
)
|
||||
postResp, err := app.Test(postReq)
|
||||
if err != nil {
|
||||
t.Fatalf("post request failed: %v", err)
|
||||
}
|
||||
defer postResp.Body.Close()
|
||||
|
||||
if postResp.StatusCode != http.StatusFound {
|
||||
t.Fatalf("expected path-token POST to return 302, got %d", postResp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetInit_LegacyErrorResponseHasCodeViaMiddleware(t *testing.T) {
|
||||
h := &AuthHandler{}
|
||||
app := newResetInitAppWithErrorCodeEnricher(h)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"loginId": "",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/password/reset/init", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var got map[string]any
|
||||
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if got["error"] != "Login ID is required" {
|
||||
t.Fatalf("unexpected error message: %v", got["error"])
|
||||
}
|
||||
if got["code"] != "bad_request" {
|
||||
t.Fatalf("expected code=bad_request, got %v", got["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitiatePasswordReset_SmsContainsVerifyLink(t *testing.T) {
|
||||
t.Setenv("USERFRONT_URL", "https://sss.hmac.kr")
|
||||
|
||||
redis := &testRedisRepo{values: map[string]string{}}
|
||||
smsSvc := &mockSmsService{}
|
||||
h := &AuthHandler{
|
||||
RedisService: redis,
|
||||
IdpProvider: &mockIdpProvider{},
|
||||
SmsService: smsSvc,
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/password/reset/init", h.InitiatePasswordReset)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"loginId": "01012345678",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/password/reset/init", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
if !strings.Contains(smsSvc.lastContent, "/api/v1/auth/password/reset/v/") {
|
||||
t.Fatalf("expected SMS to contain short path verify link, got %q", smsSvc.lastContent)
|
||||
}
|
||||
if strings.Contains(smsSvc.lastContent, "/reset-password?token=") {
|
||||
t.Fatalf("expected direct reset-password link to be removed, got %q", smsSvc.lastContent)
|
||||
}
|
||||
}
|
||||
537
baron-sso/backend/internal/handler/client_tenant_access.go
Normal file
537
baron-sso/backend/internal/handler/client_tenant_access.go
Normal file
@@ -0,0 +1,537 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/response"
|
||||
"baron-sso-backend/internal/service"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
clientTenantAccessRestrictedKey = "tenant_access_restricted"
|
||||
clientAllowedTenantsKey = "allowed_tenants"
|
||||
)
|
||||
|
||||
func normalizeClientTenantAccessMetadata(metadata map[string]any) (map[string]any, error) {
|
||||
if metadata == nil {
|
||||
metadata = map[string]any{}
|
||||
}
|
||||
|
||||
restricted := readMetadataBoolValue(metadata, clientTenantAccessRestrictedKey)
|
||||
allowedTenants := normalizeMetadataStringSlice(metadata[clientAllowedTenantsKey])
|
||||
ownerTenantID := normalizeMetadataString(metadata["tenant_id"])
|
||||
|
||||
if len(allowedTenants) > 0 {
|
||||
restricted = true
|
||||
}
|
||||
|
||||
if !restricted {
|
||||
delete(metadata, clientAllowedTenantsKey)
|
||||
metadata[clientTenantAccessRestrictedKey] = false
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
if ownerTenantID != "" {
|
||||
allowedTenants = append(allowedTenants, ownerTenantID)
|
||||
}
|
||||
allowedTenants = uniqueSortedStrings(allowedTenants)
|
||||
if len(allowedTenants) == 0 {
|
||||
return nil, errors.New("allowed_tenants is required when tenant_access_restricted is enabled")
|
||||
}
|
||||
|
||||
metadata[clientTenantAccessRestrictedKey] = true
|
||||
metadata[clientAllowedTenantsKey] = allowedTenants
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
func clientTenantAccessRestricted(metadata map[string]any) bool {
|
||||
if metadata == nil {
|
||||
return false
|
||||
}
|
||||
if readMetadataBoolValue(metadata, clientTenantAccessRestrictedKey) {
|
||||
return true
|
||||
}
|
||||
return len(normalizeMetadataStringSlice(metadata[clientAllowedTenantsKey])) > 0
|
||||
}
|
||||
|
||||
func clientAllowedTenants(metadata map[string]any) []string {
|
||||
if metadata == nil {
|
||||
return nil
|
||||
}
|
||||
if !clientTenantAccessRestricted(metadata) {
|
||||
return nil
|
||||
}
|
||||
return uniqueSortedStrings(normalizeMetadataStringSlice(metadata[clientAllowedTenantsKey]))
|
||||
}
|
||||
|
||||
func normalizeMetadataStringSlice(raw any) []string {
|
||||
switch value := raw.(type) {
|
||||
case []string:
|
||||
return uniqueSortedStrings(value)
|
||||
case []any:
|
||||
items := make([]string, 0, len(value))
|
||||
for _, item := range value {
|
||||
if s, ok := item.(string); ok {
|
||||
items = append(items, s)
|
||||
}
|
||||
}
|
||||
return uniqueSortedStrings(items)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeMetadataString(raw any) string {
|
||||
s, ok := raw.(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
|
||||
func uniqueSortedStrings(values []string) []string {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[string]struct{}, len(values))
|
||||
out := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[trimmed]; ok {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = struct{}{}
|
||||
out = append(out, trimmed)
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
|
||||
func clientTenantAccessAllowed(profile *domain.UserProfileResponse, client domain.HydraClient) bool {
|
||||
if !clientTenantAccessRestricted(client.Metadata) {
|
||||
return true
|
||||
}
|
||||
allowed := clientAllowedTenants(client.Metadata)
|
||||
if len(allowed) == 0 {
|
||||
return false
|
||||
}
|
||||
keys := manageableTenantKeysFromProfile(profile)
|
||||
if len(keys) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, tenantID := range allowed {
|
||||
if _, ok := keys[strings.ToLower(strings.TrimSpace(tenantID))]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func clientTenantAccessAllowedForSubtree(c *fiber.Ctx, tenantSvc service.TenantService, profile *domain.UserProfileResponse, client domain.HydraClient) bool {
|
||||
if clientTenantAccessAllowed(profile, client) {
|
||||
return true
|
||||
}
|
||||
if tenantSvc == nil || profile == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
allowedTenants := make([]domain.Tenant, 0)
|
||||
for _, identifier := range clientAllowedTenants(client.Metadata) {
|
||||
if tenant, ok := resolveTenantAccessTenant(c, tenantSvc, domain.Tenant{ID: identifier, Slug: identifier}); ok {
|
||||
allowedTenants = append(allowedTenants, tenant)
|
||||
}
|
||||
}
|
||||
if len(allowedTenants) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, candidate := range tenantAccessProfileTenants(profile) {
|
||||
resolvedCandidate, ok := resolveTenantAccessTenant(c, tenantSvc, candidate)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, allowed := range allowedTenants {
|
||||
if tenantMatchesOrDescendsFrom(c, tenantSvc, resolvedCandidate, allowed) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func tenantAccessProfileTenants(profile *domain.UserProfileResponse) []domain.Tenant {
|
||||
if profile == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{})
|
||||
tenants := make([]domain.Tenant, 0, len(profile.ManageableTenants)+len(profile.JoinedTenants)+2)
|
||||
add := func(tenant domain.Tenant) {
|
||||
key := strings.ToLower(firstNonEmptyString(tenant.ID, tenant.Slug, tenant.Name))
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := seen[key]; ok {
|
||||
return
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
tenants = append(tenants, tenant)
|
||||
}
|
||||
|
||||
if profile.Tenant != nil {
|
||||
add(*profile.Tenant)
|
||||
}
|
||||
if profile.TenantID != nil {
|
||||
add(domain.Tenant{ID: strings.TrimSpace(*profile.TenantID)})
|
||||
}
|
||||
for _, tenant := range profile.ManageableTenants {
|
||||
add(tenant)
|
||||
}
|
||||
for _, tenant := range profile.JoinedTenants {
|
||||
add(tenant)
|
||||
}
|
||||
return tenants
|
||||
}
|
||||
|
||||
func resolveTenantAccessTenant(c *fiber.Ctx, tenantSvc service.TenantService, tenant domain.Tenant) (domain.Tenant, bool) {
|
||||
if tenantSvc == nil {
|
||||
return tenant, firstNonEmptyString(tenant.ID, tenant.Slug) != ""
|
||||
}
|
||||
if strings.TrimSpace(tenant.ID) != "" {
|
||||
if resolved, err := tenantSvc.GetTenant(c.Context(), strings.TrimSpace(tenant.ID)); err == nil && resolved != nil {
|
||||
return *resolved, true
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(tenant.Slug) != "" {
|
||||
if resolved, err := tenantSvc.GetTenantBySlug(c.Context(), strings.TrimSpace(tenant.Slug)); err == nil && resolved != nil {
|
||||
return *resolved, true
|
||||
}
|
||||
}
|
||||
return tenant, firstNonEmptyString(tenant.ID, tenant.Slug) != ""
|
||||
}
|
||||
|
||||
func tenantMatchesOrDescendsFrom(c *fiber.Ctx, tenantSvc service.TenantService, tenant domain.Tenant, ancestor domain.Tenant) bool {
|
||||
if tenantAccessTenantMatches(tenant, ancestor) {
|
||||
return true
|
||||
}
|
||||
if tenantSvc == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
visited := make(map[string]struct{})
|
||||
current := tenant
|
||||
for current.ParentID != nil && strings.TrimSpace(*current.ParentID) != "" {
|
||||
parentID := strings.TrimSpace(*current.ParentID)
|
||||
if _, ok := visited[parentID]; ok {
|
||||
return false
|
||||
}
|
||||
visited[parentID] = struct{}{}
|
||||
|
||||
parent, err := tenantSvc.GetTenant(c.Context(), parentID)
|
||||
if err != nil || parent == nil {
|
||||
return false
|
||||
}
|
||||
if tenantAccessTenantMatches(*parent, ancestor) {
|
||||
return true
|
||||
}
|
||||
current = *parent
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func tenantAccessTenantMatches(left, right domain.Tenant) bool {
|
||||
leftID := strings.ToLower(strings.TrimSpace(left.ID))
|
||||
rightID := strings.ToLower(strings.TrimSpace(right.ID))
|
||||
if leftID != "" && rightID != "" && leftID == rightID {
|
||||
return true
|
||||
}
|
||||
|
||||
leftSlug := strings.ToLower(strings.TrimSpace(left.Slug))
|
||||
rightSlug := strings.ToLower(strings.TrimSpace(right.Slug))
|
||||
return leftSlug != "" && rightSlug != "" && leftSlug == rightSlug
|
||||
}
|
||||
|
||||
type tenantAccessDeniedDetails struct {
|
||||
Account tenantAccessDeniedAccount `json:"account"`
|
||||
CurrentTenant tenantAccessDeniedTenant `json:"current_tenant"`
|
||||
AffiliatedTenants []tenantAccessDeniedTenant `json:"affiliated_tenants,omitempty"`
|
||||
AllowedTenants []tenantAccessDeniedTenant `json:"allowed_tenants,omitempty"`
|
||||
}
|
||||
|
||||
type tenantAccessDeniedAccount struct {
|
||||
Email string `json:"email,omitempty"`
|
||||
}
|
||||
|
||||
type tenantAccessDeniedTenant struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Slug string `json:"slug,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Identifier string `json:"identifier,omitempty"`
|
||||
}
|
||||
|
||||
func tenantNotAllowedError(c *fiber.Ctx, details tenantAccessDeniedDetails) error {
|
||||
return response.ErrorWithDetails(
|
||||
c,
|
||||
fiber.StatusForbidden,
|
||||
"tenant_not_allowed",
|
||||
"허용되지 않은 테넌트입니다.",
|
||||
details,
|
||||
)
|
||||
}
|
||||
|
||||
func isClientTenantAccessAllowed(profile *domain.UserProfileResponse, client domain.HydraClient) bool {
|
||||
if profile == nil {
|
||||
return false
|
||||
}
|
||||
return clientTenantAccessAllowed(profile, client)
|
||||
}
|
||||
|
||||
func enforceClientTenantAccess(c *fiber.Ctx, tenantSvc service.TenantService, client domain.HydraClient, profile *domain.UserProfileResponse, resolveErr error) bool {
|
||||
if !clientTenantAccessRestricted(client.Metadata) {
|
||||
return false
|
||||
}
|
||||
details := buildTenantAccessDeniedDetails(c, tenantSvc, client, profile)
|
||||
if resolveErr != nil || profile == nil {
|
||||
_ = tenantNotAllowedError(c, details)
|
||||
return true
|
||||
}
|
||||
if !clientTenantAccessAllowedForSubtree(c, tenantSvc, profile, client) {
|
||||
_ = tenantNotAllowedError(c, details)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func buildTenantAccessDeniedDetails(c *fiber.Ctx, tenantSvc service.TenantService, client domain.HydraClient, profile *domain.UserProfileResponse) tenantAccessDeniedDetails {
|
||||
details := tenantAccessDeniedDetails{
|
||||
Account: tenantAccessDeniedAccount{Email: strings.TrimSpace(profileEmail(profile))},
|
||||
CurrentTenant: resolveCurrentTenantDetails(c, tenantSvc, profile),
|
||||
AffiliatedTenants: resolveAffiliatedTenantDetails(c, tenantSvc, profile),
|
||||
}
|
||||
|
||||
for _, identifier := range clientAllowedTenants(client.Metadata) {
|
||||
details.AllowedTenants = append(details.AllowedTenants, resolveAllowedTenantDetails(c, tenantSvc, identifier))
|
||||
}
|
||||
|
||||
return details
|
||||
}
|
||||
|
||||
func resolveAffiliatedTenantDetails(c *fiber.Ctx, tenantSvc service.TenantService, profile *domain.UserProfileResponse) []tenantAccessDeniedTenant {
|
||||
if profile == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{})
|
||||
out := make([]tenantAccessDeniedTenant, 0, len(profile.JoinedTenants)+1)
|
||||
appendTenant := func(tenant tenantAccessDeniedTenant) {
|
||||
key := strings.ToLower(firstNonEmptyString(tenant.ID, tenant.Slug, tenant.Identifier, tenant.Name))
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := seen[key]; ok {
|
||||
return
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, tenant)
|
||||
}
|
||||
|
||||
appendTenant(resolveCurrentTenantDetails(c, tenantSvc, profile))
|
||||
|
||||
for _, joined := range profile.JoinedTenants {
|
||||
appendTenant(tenantAccessDeniedTenant{
|
||||
ID: strings.TrimSpace(joined.ID),
|
||||
Slug: strings.TrimSpace(joined.Slug),
|
||||
Name: strings.TrimSpace(joined.Name),
|
||||
Identifier: firstNonEmptyString(strings.TrimSpace(joined.Slug), strings.TrimSpace(joined.ID)),
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func resolveCurrentTenantDetails(c *fiber.Ctx, tenantSvc service.TenantService, profile *domain.UserProfileResponse) tenantAccessDeniedTenant {
|
||||
if profile == nil {
|
||||
return tenantAccessDeniedTenant{}
|
||||
}
|
||||
|
||||
if profile.Tenant != nil {
|
||||
return tenantAccessDeniedTenant{
|
||||
ID: strings.TrimSpace(profile.Tenant.ID),
|
||||
Slug: strings.TrimSpace(profile.Tenant.Slug),
|
||||
Name: strings.TrimSpace(profile.Tenant.Name),
|
||||
Identifier: firstNonEmptyString(strings.TrimSpace(profile.Tenant.Slug), strings.TrimSpace(profile.Tenant.ID)),
|
||||
}
|
||||
}
|
||||
|
||||
if tenantSvc != nil {
|
||||
if profile.TenantID != nil && strings.TrimSpace(*profile.TenantID) != "" {
|
||||
if tenant, err := tenantSvc.GetTenant(c.Context(), strings.TrimSpace(*profile.TenantID)); err == nil && tenant != nil {
|
||||
return tenantAccessDeniedTenant{
|
||||
ID: strings.TrimSpace(tenant.ID),
|
||||
Slug: strings.TrimSpace(tenant.Slug),
|
||||
Name: strings.TrimSpace(tenant.Name),
|
||||
Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tenantAccessDeniedTenant{
|
||||
ID: strings.TrimSpace(pointerValue(profile.TenantID)),
|
||||
Identifier: strings.TrimSpace(pointerValue(profile.TenantID)),
|
||||
}
|
||||
}
|
||||
|
||||
func resolveAllowedTenantDetails(c *fiber.Ctx, tenantSvc service.TenantService, identifier string) tenantAccessDeniedTenant {
|
||||
identifier = strings.TrimSpace(identifier)
|
||||
if identifier == "" {
|
||||
return tenantAccessDeniedTenant{}
|
||||
}
|
||||
|
||||
if tenantSvc != nil {
|
||||
if tenant, err := tenantSvc.GetTenant(c.Context(), identifier); err == nil && tenant != nil {
|
||||
return tenantAccessDeniedTenant{
|
||||
ID: strings.TrimSpace(tenant.ID),
|
||||
Slug: strings.TrimSpace(tenant.Slug),
|
||||
Name: strings.TrimSpace(tenant.Name),
|
||||
Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID), identifier),
|
||||
}
|
||||
}
|
||||
if tenant, err := tenantSvc.GetTenantBySlug(c.Context(), identifier); err == nil && tenant != nil {
|
||||
return tenantAccessDeniedTenant{
|
||||
ID: strings.TrimSpace(tenant.ID),
|
||||
Slug: strings.TrimSpace(tenant.Slug),
|
||||
Name: strings.TrimSpace(tenant.Name),
|
||||
Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID), identifier),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tenantAccessDeniedTenant{Identifier: identifier}
|
||||
}
|
||||
|
||||
func profileEmail(profile *domain.UserProfileResponse) string {
|
||||
if profile == nil {
|
||||
return ""
|
||||
}
|
||||
return profile.Email
|
||||
}
|
||||
|
||||
func pointerValue(value *string) string {
|
||||
if value == nil {
|
||||
return ""
|
||||
}
|
||||
return *value
|
||||
}
|
||||
|
||||
func firstNonEmptyString(values ...string) string {
|
||||
for _, value := range values {
|
||||
if strings.TrimSpace(value) != "" {
|
||||
return strings.TrimSpace(value)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type clientStructuredScope struct {
|
||||
Name string `json:"name"`
|
||||
Mandatory bool `json:"mandatory"`
|
||||
Locked bool `json:"locked"`
|
||||
}
|
||||
|
||||
func mergeRequestedScopesWithClientRequirements(client domain.HydraClient, requested []string) []string {
|
||||
combined := make([]string, 0, len(requested)+2)
|
||||
combined = append(combined, requested...)
|
||||
combined = append(combined, requiredClientScopes(client)...)
|
||||
|
||||
return normalizeScopesInConsentOrder(combined)
|
||||
}
|
||||
|
||||
func normalizeScopesInConsentOrder(scopes []string) []string {
|
||||
combined := make([]string, 0, len(scopes))
|
||||
combined = append(combined, scopes...)
|
||||
|
||||
seen := make(map[string]struct{}, len(combined))
|
||||
out := make([]string, 0, len(combined))
|
||||
|
||||
appendIfPresent := func(scope string) {
|
||||
scope = strings.TrimSpace(scope)
|
||||
if scope == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := seen[scope]; ok {
|
||||
return
|
||||
}
|
||||
for _, candidate := range combined {
|
||||
if strings.TrimSpace(candidate) != scope {
|
||||
continue
|
||||
}
|
||||
seen[scope] = struct{}{}
|
||||
out = append(out, scope)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
appendIfPresent("openid")
|
||||
appendIfPresent("tenant")
|
||||
|
||||
for _, scope := range combined {
|
||||
scope = strings.TrimSpace(scope)
|
||||
if scope == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[scope]; ok {
|
||||
continue
|
||||
}
|
||||
seen[scope] = struct{}{}
|
||||
out = append(out, scope)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func requiredClientScopes(client domain.HydraClient) []string {
|
||||
required := make([]string, 0, 4)
|
||||
if clientTenantAccessRestricted(client.Metadata) {
|
||||
required = append(required, "tenant")
|
||||
}
|
||||
|
||||
if client.Metadata == nil {
|
||||
return normalizeScopesInConsentOrder(required)
|
||||
}
|
||||
|
||||
rawStructuredScopes, ok := client.Metadata["structured_scopes"]
|
||||
if !ok || rawStructuredScopes == nil {
|
||||
return normalizeScopesInConsentOrder(required)
|
||||
}
|
||||
|
||||
rawBytes, err := json.Marshal(rawStructuredScopes)
|
||||
if err != nil {
|
||||
return normalizeScopesInConsentOrder(required)
|
||||
}
|
||||
|
||||
var scopes []clientStructuredScope
|
||||
if err := json.Unmarshal(rawBytes, &scopes); err != nil {
|
||||
return normalizeScopesInConsentOrder(required)
|
||||
}
|
||||
|
||||
for _, scope := range scopes {
|
||||
name := strings.TrimSpace(scope.Name)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
if scope.Mandatory || scope.Locked {
|
||||
required = append(required, name)
|
||||
}
|
||||
}
|
||||
|
||||
return normalizeScopesInConsentOrder(required)
|
||||
}
|
||||
458
baron-sso/backend/internal/handler/client_tenant_access_test.go
Normal file
458
baron-sso/backend/internal/handler/client_tenant_access_test.go
Normal file
@@ -0,0 +1,458 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func TestCreateClient_NormalizesTenantAccessMetadata(t *testing.T) {
|
||||
var captured domain.HydraClient
|
||||
ownerTenantID := "tenant-owner"
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.Method == http.MethodPost && r.URL.Path == "/clients" {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, json.Unmarshal(body, &captured))
|
||||
return httpJSONAny(r, http.StatusCreated, map[string]any{
|
||||
"client_id": captured.ClientID,
|
||||
"client_name": captured.ClientName,
|
||||
"redirect_uris": captured.RedirectURIs,
|
||||
"grant_types": captured.GrantTypes,
|
||||
"response_types": captured.ResponseTypes,
|
||||
"scope": captured.Scope,
|
||||
"token_endpoint_auth_method": captured.TokenEndpointAuthMethod,
|
||||
"metadata": captured.Metadata,
|
||||
}), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusNotFound, nil), nil
|
||||
})
|
||||
|
||||
h := &DevHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: transport},
|
||||
},
|
||||
Keto: new(devMockKetoService),
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{
|
||||
ID: "user-1",
|
||||
Role: domain.RoleSuperAdmin,
|
||||
TenantID: &ownerTenantID,
|
||||
})
|
||||
return c.Next()
|
||||
})
|
||||
app.Post("/api/v1/dev/clients", h.CreateClient)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"id": "client-tenant",
|
||||
"name": "Tenant Client",
|
||||
"type": "pkce",
|
||||
"redirectUris": []string{"https://rp.example.com/cb"},
|
||||
"metadata": map[string]any{
|
||||
"tenant_access_restricted": true,
|
||||
"allowed_tenants": []string{"tenant-b", "tenant-a", "tenant-b"},
|
||||
},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/dev/clients", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusCreated, resp.StatusCode)
|
||||
assert.True(t, clientTenantAccessRestricted(captured.Metadata))
|
||||
assert.Equal(t, []string{"tenant-a", "tenant-b", "tenant-owner"}, clientAllowedTenants(captured.Metadata))
|
||||
}
|
||||
|
||||
func TestCreateClient_RejectsTenantAccessWithoutAllowedTenants(t *testing.T) {
|
||||
hydraCalled := false
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.Method == http.MethodPost && r.URL.Path == "/clients" {
|
||||
hydraCalled = true
|
||||
}
|
||||
return httpJSONAny(r, http.StatusNotFound, nil), nil
|
||||
})
|
||||
|
||||
h := &DevHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: transport},
|
||||
},
|
||||
Keto: new(devMockKetoService),
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "user-1", Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Post("/api/v1/dev/clients", h.CreateClient)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"id": "client-tenant",
|
||||
"name": "Tenant Client",
|
||||
"type": "pkce",
|
||||
"redirectUris": []string{"https://rp.example.com/cb"},
|
||||
"metadata": map[string]any{
|
||||
"tenant_access_restricted": true,
|
||||
},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/dev/clients", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
assert.False(t, hydraCalled)
|
||||
}
|
||||
|
||||
func TestMergeRequestedScopesWithClientRequirements_AddsTenantScope(t *testing.T) {
|
||||
client := domain.HydraClient{
|
||||
Metadata: map[string]any{
|
||||
"tenant_access_restricted": true,
|
||||
"structured_scopes": []map[string]any{
|
||||
{"name": "openid", "mandatory": true},
|
||||
{"name": "tenant", "mandatory": true, "locked": true},
|
||||
{"name": "profile", "mandatory": false},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
merged := mergeRequestedScopesWithClientRequirements(client, []string{"openid", "profile"})
|
||||
assert.Equal(t, []string{"openid", "tenant", "profile"}, merged)
|
||||
}
|
||||
|
||||
func TestGetConsentRequest_DeniesTenantAccess(t *testing.T) {
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
switch {
|
||||
case r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-tenant":
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-tenant",
|
||||
"requested_scope": []string{"openid", "profile"},
|
||||
"skip": false,
|
||||
"subject": "user-123",
|
||||
"client": map[string]any{
|
||||
"client_id": "client-tenant",
|
||||
"metadata": map[string]any{
|
||||
"tenant_access_restricted": true,
|
||||
"allowed_tenants": []string{"tenant-b"},
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
case r.URL.Host == "kratos.test" && r.URL.Path == "/sessions/whoami":
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"identity": map[string]any{
|
||||
"id": "user-123",
|
||||
"traits": map[string]any{
|
||||
"email": "user@test.com",
|
||||
"tenant_id": "tenant-a",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
default:
|
||||
return httpJSONAny(r, http.StatusNotFound, nil), nil
|
||||
}
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/auth/consent", h.GetConsentRequest)
|
||||
|
||||
t.Setenv("APP_ENV", "dev")
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-tenant", nil)
|
||||
req.Header.Set("X-Mock-Role", "user")
|
||||
req.Header.Set("X-Tenant-ID", "tenant-a")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
|
||||
var body map[string]any
|
||||
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
assert.Equal(t, "tenant_not_allowed", body["code"])
|
||||
details, ok := body["details"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
account, ok := details["account"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.NotEmpty(t, account["email"])
|
||||
currentTenant, ok := details["current_tenant"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.NotEmpty(t, currentTenant["identifier"])
|
||||
}
|
||||
|
||||
func TestGetConsentRequest_DeniesRestrictedClientWhenProfileResolutionFails(t *testing.T) {
|
||||
acceptCalled := false
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
switch {
|
||||
case r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-profile-missing":
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"challenge": "challenge-profile-missing",
|
||||
"requested_scope": []string{"openid", "profile"},
|
||||
"skip": false,
|
||||
"subject": "user-123",
|
||||
"client": map[string]any{
|
||||
"client_id": "client-tenant",
|
||||
"metadata": map[string]any{
|
||||
"tenant_access_restricted": true,
|
||||
"allowed_tenants": []string{"tenant-b"},
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
case r.URL.Path == "/oauth2/auth/requests/consent/accept":
|
||||
acceptCalled = true
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"redirect_to": "http://rp/cb",
|
||||
}), nil
|
||||
default:
|
||||
return httpJSONAny(r, http.StatusNotFound, nil), nil
|
||||
}
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = client
|
||||
defer func() { http.DefaultClient = origDefault }()
|
||||
|
||||
t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test")
|
||||
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
KratosAdmin: func() service.KratosAdminService {
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
|
||||
ID: "user-123",
|
||||
Traits: map[string]any{
|
||||
"email": "user@test.com",
|
||||
"tenant_id": "tenant-a",
|
||||
"companyCode": "tenant-a",
|
||||
},
|
||||
}, nil).Once()
|
||||
return mockKratos
|
||||
}(),
|
||||
TenantService: func() service.TenantService {
|
||||
tenantSvc := new(MockTenantService)
|
||||
tenantSvc.On("GetTenant", mock.Anything, "tenant-a").Return(&domain.Tenant{
|
||||
ID: "tenant-a",
|
||||
Slug: "tenant-a",
|
||||
Name: "Tenant A",
|
||||
}, nil)
|
||||
tenantSvc.On("GetTenant", mock.Anything, "tenant-c").Return(&domain.Tenant{
|
||||
ID: "tenant-c",
|
||||
Slug: "tenant-c",
|
||||
Name: "Tenant C",
|
||||
}, nil)
|
||||
tenantSvc.On("ListJoinedTenants", mock.Anything, "user-123").Return([]domain.Tenant{
|
||||
{ID: "tenant-a", Slug: "tenant-a", Name: "Tenant A"},
|
||||
{ID: "tenant-c", Slug: "tenant-c", Name: "Tenant C"},
|
||||
}, nil).Once()
|
||||
tenantSvc.On("GetTenant", mock.Anything, "tenant-b").Return(nil, assert.AnError)
|
||||
tenantSvc.On("GetTenantBySlug", mock.Anything, "tenant-b").Return(&domain.Tenant{
|
||||
ID: "tenant-b-id",
|
||||
Slug: "tenant-b",
|
||||
Name: "Tenant B",
|
||||
}, nil)
|
||||
return tenantSvc
|
||||
}(),
|
||||
ConsentRepo: &mockConsentRepo{
|
||||
consents: []domain.ClientConsent{
|
||||
{
|
||||
ClientID: "client-tenant",
|
||||
Subject: "user-123",
|
||||
GrantedScopes: []string{"openid", "profile"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/auth/consent", h.GetConsentRequest)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-profile-missing", nil)
|
||||
req.Header.Set("Cookie", "ory_kratos_session=invalid-session")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
assert.False(t, acceptCalled)
|
||||
|
||||
var body map[string]any
|
||||
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
assert.Equal(t, "tenant_not_allowed", body["code"])
|
||||
|
||||
details, ok := body["details"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
account, ok := details["account"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "user@test.com", account["email"])
|
||||
currentTenant, ok := details["current_tenant"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "Tenant A", currentTenant["name"])
|
||||
affiliatedTenants, ok := details["affiliated_tenants"].([]any)
|
||||
assert.True(t, ok)
|
||||
assert.Len(t, affiliatedTenants, 2)
|
||||
}
|
||||
|
||||
func TestAcceptOidcLoginRequest_DeniesTenantAccess(t *testing.T) {
|
||||
app := fiber.New()
|
||||
app.Get("/deny", func(c *fiber.Ctx) error {
|
||||
tenantID := "tenant-a"
|
||||
profile := &domain.UserProfileResponse{
|
||||
ID: "user-123",
|
||||
Role: domain.RoleUser,
|
||||
Email: "user@test.com",
|
||||
TenantID: &tenantID,
|
||||
CompanyCode: "tenant-a",
|
||||
JoinedTenants: []domain.Tenant{
|
||||
{ID: "tenant-a", Slug: "tenant-a", Name: "Tenant A"},
|
||||
{ID: "tenant-c", Slug: "tenant-c", Name: "Tenant C"},
|
||||
},
|
||||
}
|
||||
client := domain.HydraClient{
|
||||
ClientID: "client-tenant",
|
||||
Metadata: map[string]any{
|
||||
"tenant_access_restricted": true,
|
||||
"allowed_tenants": []string{"tenant-b"},
|
||||
},
|
||||
}
|
||||
tenantSvc := new(MockTenantService)
|
||||
tenantSvc.On("GetTenant", mock.Anything, "tenant-a").Return(&domain.Tenant{
|
||||
ID: "tenant-a",
|
||||
Slug: "tenant-a",
|
||||
Name: "Tenant A",
|
||||
}, nil)
|
||||
tenantSvc.On("GetTenant", mock.Anything, "tenant-c").Return(&domain.Tenant{
|
||||
ID: "tenant-c",
|
||||
Slug: "tenant-c",
|
||||
Name: "Tenant C",
|
||||
}, nil)
|
||||
tenantSvc.On("GetTenant", mock.Anything, "tenant-b").Return(nil, assert.AnError)
|
||||
tenantSvc.On("GetTenantBySlug", mock.Anything, "tenant-b").Return(&domain.Tenant{
|
||||
ID: "tenant-b-id",
|
||||
Slug: "tenant-b",
|
||||
Name: "Tenant B",
|
||||
}, nil)
|
||||
enforceClientTenantAccess(c, tenantSvc, client, profile, nil)
|
||||
return nil
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/deny", nil)
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
|
||||
var body map[string]any
|
||||
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
assert.Equal(t, "tenant_not_allowed", body["code"])
|
||||
|
||||
details, ok := body["details"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
|
||||
account, ok := details["account"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "user@test.com", account["email"])
|
||||
|
||||
currentTenant, ok := details["current_tenant"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "Tenant A", currentTenant["name"])
|
||||
affiliatedTenants, ok := details["affiliated_tenants"].([]any)
|
||||
assert.True(t, ok)
|
||||
assert.Len(t, affiliatedTenants, 2)
|
||||
|
||||
allowedTenants, ok := details["allowed_tenants"].([]any)
|
||||
assert.True(t, ok)
|
||||
assert.Len(t, allowedTenants, 1)
|
||||
allowedTenant, ok := allowedTenants[0].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "Tenant B", allowedTenant["name"])
|
||||
}
|
||||
|
||||
func TestAcceptOidcLoginRequest_AllowsRestrictedClientForHanmacFamilyDescendant(t *testing.T) {
|
||||
app := fiber.New()
|
||||
app.Get("/allow-descendant", func(c *fiber.Ctx) error {
|
||||
hanmacFamilyID := "hanmac-family-id"
|
||||
samanID := "saman-id"
|
||||
profile := &domain.UserProfileResponse{
|
||||
ID: "user-123",
|
||||
Role: domain.RoleUser,
|
||||
Email: "user@samaneng.com",
|
||||
TenantID: &samanID,
|
||||
Tenant: &domain.Tenant{
|
||||
ID: samanID,
|
||||
Slug: "saman",
|
||||
Name: "삼안",
|
||||
ParentID: &hanmacFamilyID,
|
||||
},
|
||||
JoinedTenants: []domain.Tenant{
|
||||
{
|
||||
ID: samanID,
|
||||
Slug: "saman",
|
||||
Name: "삼안",
|
||||
ParentID: &hanmacFamilyID,
|
||||
},
|
||||
},
|
||||
}
|
||||
client := domain.HydraClient{
|
||||
ClientID: "orgfront",
|
||||
Metadata: map[string]any{
|
||||
"tenant_access_restricted": true,
|
||||
"allowed_tenants": []string{"hanmac-family"},
|
||||
},
|
||||
}
|
||||
tenantSvc := new(MockTenantService)
|
||||
tenantSvc.On("GetTenant", mock.Anything, "hanmac-family").Return(nil, assert.AnError).Maybe()
|
||||
tenantSvc.On("GetTenantBySlug", mock.Anything, "hanmac-family").Return(&domain.Tenant{
|
||||
ID: hanmacFamilyID,
|
||||
Slug: "hanmac-family",
|
||||
Name: "한맥가족",
|
||||
}, nil).Maybe()
|
||||
tenantSvc.On("GetTenant", mock.Anything, samanID).Return(&domain.Tenant{
|
||||
ID: samanID,
|
||||
Slug: "saman",
|
||||
Name: "삼안",
|
||||
ParentID: &hanmacFamilyID,
|
||||
}, nil).Maybe()
|
||||
tenantSvc.On("GetTenant", mock.Anything, hanmacFamilyID).Return(&domain.Tenant{
|
||||
ID: hanmacFamilyID,
|
||||
Slug: "hanmac-family",
|
||||
Name: "한맥가족",
|
||||
}, nil).Maybe()
|
||||
|
||||
blocked := enforceClientTenantAccess(c, tenantSvc, client, profile, nil)
|
||||
assert.False(t, blocked)
|
||||
return c.SendStatus(http.StatusNoContent)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/allow-descendant", nil)
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusNoContent, resp.StatusCode)
|
||||
}
|
||||
303
baron-sso/backend/internal/handler/common_test.go
Normal file
303
baron-sso/backend/internal/handler/common_test.go
Normal file
@@ -0,0 +1,303 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
"time"
|
||||
)
|
||||
|
||||
// --- Mock IDP Provider ---
|
||||
|
||||
type mockIdpProvider struct {
|
||||
userExists bool
|
||||
name string
|
||||
signInInfo *domain.AuthInfo
|
||||
issueSession *domain.AuthInfo
|
||||
verifyCodeInfo *domain.AuthInfo
|
||||
err error
|
||||
initiateLinkErr error
|
||||
updateCalled bool
|
||||
updateCallCount int
|
||||
updatedLoginID string
|
||||
updatedPassword string
|
||||
}
|
||||
|
||||
func (m *mockIdpProvider) Name() string {
|
||||
if m.name != "" {
|
||||
return m.name
|
||||
}
|
||||
return "mock-idp"
|
||||
}
|
||||
|
||||
func (m *mockIdpProvider) GetMetadata() (*domain.IDPMetadata, error) { return nil, m.err }
|
||||
func (m *mockIdpProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
|
||||
return "mock-user-id", m.err
|
||||
}
|
||||
|
||||
func (m *mockIdpProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
|
||||
return m.signInInfo, m.err
|
||||
}
|
||||
func (m *mockIdpProvider) UserExists(loginID string) (bool, error) { return m.userExists, m.err }
|
||||
func (m *mockIdpProvider) IssueSession(loginID string) (*domain.AuthInfo, error) {
|
||||
if m.issueSession != nil {
|
||||
return m.issueSession, m.err
|
||||
}
|
||||
return &domain.AuthInfo{
|
||||
SessionToken: &domain.Token{JWT: "valid-jwt", SessionID: "valid-sid"},
|
||||
}, m.err
|
||||
}
|
||||
|
||||
func (m *mockIdpProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
|
||||
if m.initiateLinkErr != nil {
|
||||
return nil, m.initiateLinkErr
|
||||
}
|
||||
return &domain.LinkLoginInit{FlowID: "mock-flow-id", Mode: "code"}, m.err
|
||||
}
|
||||
|
||||
func (m *mockIdpProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
|
||||
return m.verifyCodeInfo, m.err
|
||||
}
|
||||
func (m *mockIdpProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) { return nil, m.err }
|
||||
func (m *mockIdpProvider) InitiatePasswordReset(loginID, redirectUrl string) error { return m.err }
|
||||
func (m *mockIdpProvider) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
|
||||
return nil, m.err
|
||||
}
|
||||
|
||||
func (m *mockIdpProvider) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
|
||||
m.updateCalled = true
|
||||
m.updateCallCount++
|
||||
m.updatedLoginID = loginID
|
||||
m.updatedPassword = newPassword
|
||||
return m.err
|
||||
}
|
||||
|
||||
// --- Mock Audit Repository ---
|
||||
|
||||
type mockAuditRepo struct {
|
||||
logs []domain.AuditLog
|
||||
}
|
||||
|
||||
func (m *mockAuditRepo) Create(log *domain.AuditLog) error {
|
||||
m.logs = append(m.logs, *log)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAuditRepo) FindPage(ctx context.Context, limit int, cursor *domain.AuditCursor, tenantID string) ([]domain.AuditLog, error) {
|
||||
return m.logs, nil
|
||||
}
|
||||
|
||||
func (m *mockAuditRepo) FindByUserAndEvents(ctx context.Context, userID string, eventTypes []string, limit int) ([]domain.AuditLog, error) {
|
||||
var results []domain.AuditLog
|
||||
for _, log := range m.logs {
|
||||
if log.UserID == userID {
|
||||
if slices.Contains(eventTypes, log.EventType) {
|
||||
results = append(results, log)
|
||||
}
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (m *mockAuditRepo) CountFailuresSince(ctx context.Context, since time.Time, tenantID string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockAuditRepo) CountEventsSince(ctx context.Context, since time.Time) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockAuditRepo) CountActiveSessionsSince(ctx context.Context, since time.Time, tenantID string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockAuditRepo) Ping(ctx context.Context) error { return nil }
|
||||
|
||||
type mockRPUsageEventSink struct {
|
||||
events []domain.RPUsageEvent
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockRPUsageEventSink) EmitRPUsageEvent(ctx context.Context, event domain.RPUsageEvent) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
m.events = append(m.events, event)
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockOathkeeperRepo struct {
|
||||
logs []domain.OathkeeperAccessLog
|
||||
}
|
||||
|
||||
func (m *mockOathkeeperRepo) FindPageBySubject(ctx context.Context, subject string, limit int, cursor *domain.AuditCursor) ([]domain.OathkeeperAccessLog, error) {
|
||||
if subject == "" {
|
||||
return m.logs, nil
|
||||
}
|
||||
results := make([]domain.OathkeeperAccessLog, 0, len(m.logs))
|
||||
for _, log := range m.logs {
|
||||
if log.Subject == subject {
|
||||
results = append(results, log)
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (m *mockOathkeeperRepo) Ping(ctx context.Context) error { return nil }
|
||||
|
||||
// --- Mock Consent Repository ---
|
||||
|
||||
type mockConsentRepo struct {
|
||||
consents []domain.ClientConsent
|
||||
}
|
||||
|
||||
func (m *mockConsentRepo) Upsert(ctx context.Context, consent *domain.ClientConsent) error {
|
||||
m.consents = append(m.consents, *consent)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConsentRepo) ListBySubject(ctx context.Context, subject string) ([]domain.ClientConsent, error) {
|
||||
var results []domain.ClientConsent
|
||||
for _, c := range m.consents {
|
||||
if c.Subject == subject {
|
||||
results = append(results, c)
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (m *mockConsentRepo) ListSubjectsByClient(ctx context.Context, clientID string) ([]string, error) {
|
||||
seen := map[string]struct{}{}
|
||||
subjects := make([]string, 0, len(m.consents))
|
||||
for _, consent := range m.consents {
|
||||
if consent.ClientID != clientID {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[consent.Subject]; ok {
|
||||
continue
|
||||
}
|
||||
seen[consent.Subject] = struct{}{}
|
||||
subjects = append(subjects, consent.Subject)
|
||||
}
|
||||
return subjects, nil
|
||||
}
|
||||
|
||||
func (m *mockConsentRepo) Find(ctx context.Context, clientID, subject string) (*domain.ClientConsent, error) {
|
||||
for _, consent := range m.consents {
|
||||
if consent.ClientID == clientID && consent.Subject == subject {
|
||||
found := consent
|
||||
return &found, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockConsentRepo) Delete(ctx context.Context, subject, clientID string) error {
|
||||
filtered := m.consents[:0]
|
||||
for _, consent := range m.consents {
|
||||
if consent.Subject == subject && (clientID == "" || consent.ClientID == clientID) {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, consent)
|
||||
}
|
||||
m.consents = filtered
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConsentRepo) DeleteByClient(ctx context.Context, clientID string) error {
|
||||
filtered := m.consents[:0]
|
||||
for _, consent := range m.consents {
|
||||
if consent.ClientID != clientID {
|
||||
filtered = append(filtered, consent)
|
||||
}
|
||||
}
|
||||
m.consents = filtered
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConsentRepo) List(ctx context.Context, clientID string, limit, offset int) ([]domain.ClientConsentWithTenantInfo, int64, error) {
|
||||
results := make([]domain.ClientConsentWithTenantInfo, 0, len(m.consents))
|
||||
for _, consent := range m.consents {
|
||||
if consent.ClientID == clientID {
|
||||
results = append(results, domain.ClientConsentWithTenantInfo{ClientConsent: consent})
|
||||
}
|
||||
}
|
||||
return results, int64(len(results)), nil
|
||||
}
|
||||
|
||||
func (m *mockConsentRepo) ListByTenant(ctx context.Context, clientID, tenantID string, limit, offset int) ([]domain.ClientConsentWithTenantInfo, int64, error) {
|
||||
results := make([]domain.ClientConsentWithTenantInfo, 0, len(m.consents))
|
||||
for _, consent := range m.consents {
|
||||
if consent.ClientID == clientID {
|
||||
results = append(results, domain.ClientConsentWithTenantInfo{
|
||||
ClientConsent: consent,
|
||||
TenantID: tenantID,
|
||||
})
|
||||
}
|
||||
}
|
||||
return results, int64(len(results)), nil
|
||||
}
|
||||
|
||||
// --- Mock Secret Repository ---
|
||||
|
||||
type mockSecretRepo struct {
|
||||
secrets map[string]string
|
||||
}
|
||||
|
||||
func (m *mockSecretRepo) Upsert(ctx context.Context, clientID, secret string) error {
|
||||
if m.secrets == nil {
|
||||
m.secrets = make(map[string]string)
|
||||
}
|
||||
m.secrets[clientID] = secret
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSecretRepo) GetByID(ctx context.Context, clientID string) (string, error) {
|
||||
return m.secrets[clientID], nil
|
||||
}
|
||||
|
||||
func (m *mockSecretRepo) Delete(ctx context.Context, clientID string) error {
|
||||
delete(m.secrets, clientID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- HTTP Mock Helpers ---
|
||||
|
||||
type roundTripFunc func(req *http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func setDefaultHTTPClientForTest(t interface{ Cleanup(func()) }, transport http.RoundTripper) {
|
||||
origDefault := http.DefaultClient
|
||||
http.DefaultClient = &http.Client{Transport: transport}
|
||||
t.Cleanup(func() {
|
||||
http.DefaultClient = origDefault
|
||||
})
|
||||
}
|
||||
|
||||
func httpResponse(r *http.Request, code int, body string) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: code,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(bytes.NewBufferString(body)),
|
||||
Request: r,
|
||||
}
|
||||
}
|
||||
|
||||
func httpJSONAny(r *http.Request, code int, data any) *http.Response {
|
||||
body, _ := json.Marshal(data)
|
||||
return &http.Response{
|
||||
StatusCode: code,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
Body: io.NopCloser(bytes.NewBuffer(body)),
|
||||
Request: r,
|
||||
}
|
||||
}
|
||||
4165
baron-sso/backend/internal/handler/dev_handler.go
Normal file
4165
baron-sso/backend/internal/handler/dev_handler.go
Normal file
File diff suppressed because it is too large
Load Diff
242
baron-sso/backend/internal/handler/dev_handler_isolation_test.go
Normal file
242
baron-sso/backend/internal/handler/dev_handler_isolation_test.go
Normal file
@@ -0,0 +1,242 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func TestDevHandler_Isolation(t *testing.T) {
|
||||
createHandler := func(mockKeto *devMockKetoService) *DevHandler {
|
||||
return &DevHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{
|
||||
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.Method == http.MethodGet && r.URL.Path == "/clients" {
|
||||
return httpJSONAny(r, http.StatusOK, []map[string]any{
|
||||
{
|
||||
"client_id": "client-tenant-a",
|
||||
"client_name": "App Tenant A",
|
||||
"token_endpoint_auth_method": "none", // PKCE
|
||||
"metadata": map[string]any{"tenant_id": "tenant-a"},
|
||||
},
|
||||
{
|
||||
"client_id": "client-tenant-b",
|
||||
"client_name": "App Tenant B",
|
||||
"token_endpoint_auth_method": "none", // PKCE
|
||||
"metadata": map[string]any{"tenant_id": "tenant-b"},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if (r.Method == http.MethodGet || r.Method == http.MethodPut) && strings.HasPrefix(r.URL.Path, "/clients/") {
|
||||
id := strings.TrimPrefix(r.URL.Path, "/clients/")
|
||||
tenantID := "tenant-a"
|
||||
if id == "client-tenant-b" {
|
||||
tenantID = "tenant-b"
|
||||
}
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": id,
|
||||
"client_name": "App " + id,
|
||||
"token_endpoint_auth_method": "none",
|
||||
"metadata": map[string]any{"tenant_id": tenantID},
|
||||
}), nil
|
||||
}
|
||||
if r.Method == http.MethodPost && r.URL.Path == "/clients" {
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
return httpJSONAny(r, http.StatusCreated, body), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusNotFound, nil), nil
|
||||
}),
|
||||
},
|
||||
},
|
||||
Keto: mockKeto,
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("Local bypass should be removed", func(t *testing.T) {
|
||||
mockKeto := new(devMockKetoService)
|
||||
h := createHandler(mockKeto)
|
||||
app := fiber.New()
|
||||
app.Get("/api/v1/dev/clients", h.ListClients)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients", nil)
|
||||
req.Header.Set("Origin", "http://localhost:5174")
|
||||
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("ListClients should show all for SuperAdmin", func(t *testing.T) {
|
||||
mockKeto := new(devMockKetoService)
|
||||
h := createHandler(mockKeto)
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{
|
||||
ID: "super-user",
|
||||
Role: domain.RoleSuperAdmin,
|
||||
})
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/api/v1/dev/clients", h.ListClients)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients", nil)
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
var res struct {
|
||||
Items []clientSummary `json:"items"`
|
||||
}
|
||||
json.NewDecoder(resp.Body).Decode(&res)
|
||||
|
||||
// Should see both clients
|
||||
assert.Equal(t, 2, len(res.Items))
|
||||
})
|
||||
|
||||
t.Run("ListClients should filter by permit for non-SuperAdmin", func(t *testing.T) {
|
||||
mockKeto := new(devMockKetoService)
|
||||
h := createHandler(mockKeto)
|
||||
app := fiber.New()
|
||||
tenantA := "tenant-a"
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{
|
||||
ID: "user-a",
|
||||
Role: domain.RoleUser,
|
||||
TenantID: &tenantA,
|
||||
})
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/api/v1/dev/clients", h.ListClients)
|
||||
|
||||
// Explicit permission for private client check bypass
|
||||
mockKeto.On("CheckPermission", mock.Anything, "User:user-a", "System", "global", "manage_all").Return(true, nil).Maybe()
|
||||
// Mock permit for the specific client
|
||||
mockKeto.On("CheckPermission", mock.Anything, "User:user-a", "RelyingParty", "client-tenant-a", "view").Return(true, nil).Maybe()
|
||||
// Deny for other clients
|
||||
mockKeto.On("CheckPermission", mock.Anything, "User:user-a", "RelyingParty", "client-tenant-b", "view").Return(false, nil).Maybe()
|
||||
|
||||
mockKeto.On("ListRelations", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]service.RelationTuple{}, nil).Maybe()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients", nil)
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
var res struct {
|
||||
Items []clientSummary `json:"items"`
|
||||
}
|
||||
json.NewDecoder(resp.Body).Decode(&res)
|
||||
|
||||
// Should only see client-tenant-a (tenant permit)
|
||||
assert.Equal(t, 1, len(res.Items))
|
||||
assert.Equal(t, "client-tenant-a", res.Items[0].ID)
|
||||
})
|
||||
|
||||
t.Run("Tenant member should see empty list from DevFront clients if no relation", func(t *testing.T) {
|
||||
mockKeto := new(devMockKetoService)
|
||||
h := createHandler(mockKeto)
|
||||
app := fiber.New()
|
||||
tenantA := "tenant-a"
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{
|
||||
ID: "user-member",
|
||||
Role: domain.RoleUser,
|
||||
TenantID: &tenantA,
|
||||
})
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/api/v1/dev/clients", h.ListClients)
|
||||
|
||||
// Deny all by default
|
||||
mockKeto.On("CheckPermission", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(false, nil).Maybe()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients", nil)
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var res struct {
|
||||
Items []clientSummary `json:"items"`
|
||||
}
|
||||
json.NewDecoder(resp.Body).Decode(&res)
|
||||
// Empty list because we didn't mock any specific 'view' permissions for this user
|
||||
assert.Equal(t, 0, len(res.Items))
|
||||
})
|
||||
|
||||
t.Run("GetClient should enforce isolation for non-SuperAdmin", func(t *testing.T) {
|
||||
mockKeto := new(devMockKetoService)
|
||||
h := createHandler(mockKeto)
|
||||
app := fiber.New()
|
||||
tenantA := "tenant-a"
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{
|
||||
ID: "user-a",
|
||||
Role: domain.RoleUser,
|
||||
TenantID: &tenantA,
|
||||
})
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/api/v1/dev/clients/:id", h.GetClient)
|
||||
|
||||
// Case 1: Same tenant BUT no permit (Normal users need permit now)
|
||||
mockKeto.On("CheckPermission", mock.Anything, "User:user-a", "RelyingParty", "client-tenant-a", "view").Return(false, nil).Once()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-tenant-a", nil)
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
|
||||
// Case 2: Same tenant WITH permit
|
||||
mockKeto.On("CheckPermission", mock.Anything, "User:user-a", "RelyingParty", "client-tenant-a", "view").Return(true, nil).Maybe()
|
||||
mockKeto.On("ListRelations", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]service.RelationTuple{}, nil).Maybe()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-tenant-a", nil)
|
||||
resp, _ = app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Case 3: Different tenant
|
||||
mockKeto.On("CheckPermission", mock.Anything, "User:user-a", "RelyingParty", "client-tenant-b", "view").Return(false, nil).Maybe()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-tenant-b", nil)
|
||||
resp, _ = app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("CreateClient should record user_id and tenant_id", func(t *testing.T) {
|
||||
mockKeto := new(devMockKetoService)
|
||||
h := createHandler(mockKeto)
|
||||
app := fiber.New()
|
||||
tenantA := "tenant-a"
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{
|
||||
ID: "user-a",
|
||||
Role: domain.RoleSuperAdmin, // Bypass for creation permission
|
||||
TenantID: &tenantA,
|
||||
})
|
||||
return c.Next()
|
||||
})
|
||||
app.Post("/api/v1/dev/clients", h.CreateClient)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"client_name": "New App",
|
||||
"type": "pkce",
|
||||
"redirectUris": []string{"http://localhost/cb"},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/dev/clients", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Tenant-ID", "tenant-a")
|
||||
|
||||
resp, _ := app.Test(req, -1)
|
||||
assert.Equal(t, http.StatusCreated, resp.StatusCode)
|
||||
|
||||
var res clientDetailResponse
|
||||
json.NewDecoder(resp.Body).Decode(&res)
|
||||
|
||||
assert.Equal(t, "tenant-a", res.Client.Metadata["tenant_id"])
|
||||
assert.Equal(t, "user-a", res.Client.Metadata["user_id"])
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,278 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type devMockRPUserMetadataRepo struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *devMockRPUserMetadataRepo) Get(ctx context.Context, clientID, userID string) (*domain.RPUserMetadata, error) {
|
||||
args := m.Called(ctx, clientID, userID)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.RPUserMetadata), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *devMockRPUserMetadataRepo) Upsert(ctx context.Context, metadata *domain.RPUserMetadata) error {
|
||||
args := m.Called(ctx, metadata)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestDevHandler_RPUserMetadataRoundTrip(t *testing.T) {
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/clients/client-1" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": "client-1",
|
||||
"client_name": "Client One",
|
||||
"metadata": map[string]any{
|
||||
"tenant_id": "tenant-1",
|
||||
"id_token_claims": []map[string]any{
|
||||
{
|
||||
"namespace": "rp_claims",
|
||||
"key": "approvalLevel",
|
||||
"valueType": "text",
|
||||
"value": "A",
|
||||
},
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusNotFound, nil), nil
|
||||
})
|
||||
|
||||
repo := new(devMockRPUserMetadataRepo)
|
||||
repo.On("Upsert", mock.Anything, mock.MatchedBy(func(row *domain.RPUserMetadata) bool {
|
||||
return row.ClientID == "client-1" &&
|
||||
row.UserID == "user-1" &&
|
||||
row.Metadata["approvalLevel"] == "A" &&
|
||||
row.Metadata["approvalLevel_permissions"].(map[string]any)["readPermission"] == "admin_only" &&
|
||||
row.Metadata["approvalLevel_permissions"].(map[string]any)["writePermission"] == "user_and_admin"
|
||||
})).Return(nil).Once()
|
||||
repo.On("Get", mock.Anything, "client-1", "user-1").Return(&domain.RPUserMetadata{
|
||||
ClientID: "client-1",
|
||||
UserID: "user-1",
|
||||
Metadata: domain.JSONMap{"approvalLevel": "A"},
|
||||
}, nil).Once()
|
||||
|
||||
h := &DevHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: transport},
|
||||
},
|
||||
RPUserMetadataRepo: repo,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "admin", Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Put("/api/v1/dev/clients/:id/users/:userId/metadata", h.UpsertRPUserMetadata)
|
||||
app.Get("/api/v1/dev/clients/:id/users/:userId/metadata", h.GetRPUserMetadata)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"metadata": map[string]any{
|
||||
"approvalLevel": "A",
|
||||
"approvalLevel_permissions": map[string]any{
|
||||
"writePermission": "user_and_admin",
|
||||
},
|
||||
},
|
||||
})
|
||||
putReq := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-1/users/user-1/metadata", bytes.NewReader(body))
|
||||
putReq.Header.Set("Content-Type", "application/json")
|
||||
putResp, _ := app.Test(putReq, -1)
|
||||
assert.Equal(t, http.StatusOK, putResp.StatusCode)
|
||||
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-1/users/user-1/metadata", nil)
|
||||
getResp, _ := app.Test(getReq, -1)
|
||||
assert.Equal(t, http.StatusOK, getResp.StatusCode)
|
||||
|
||||
var got map[string]any
|
||||
assert.NoError(t, json.NewDecoder(getResp.Body).Decode(&got))
|
||||
assert.Equal(t, "client-1", got["clientId"])
|
||||
assert.Equal(t, "user-1", got["userId"])
|
||||
assert.Equal(t, "A", got["metadata"].(map[string]any)["approvalLevel"])
|
||||
repo.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestDevHandler_RPUserMetadataMirrorsToKratosTraits(t *testing.T) {
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/clients/client-1" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": "client-1",
|
||||
"client_name": "Client One",
|
||||
"metadata": map[string]any{
|
||||
"tenant_id": "tenant-1",
|
||||
"id_token_claims": []map[string]any{
|
||||
{
|
||||
"namespace": "rp_claims",
|
||||
"key": "approvalLevel",
|
||||
"valueType": "text",
|
||||
"value": "A",
|
||||
"readPermission": "user_and_admin",
|
||||
"writePermission": "admin_only",
|
||||
},
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusNotFound, nil), nil
|
||||
})
|
||||
|
||||
repo := new(devMockRPUserMetadataRepo)
|
||||
repo.On("Upsert", mock.Anything, mock.AnythingOfType("*domain.RPUserMetadata")).Return(nil).Once()
|
||||
kratos := new(MockKratosAdmin)
|
||||
kratos.On("GetIdentity", mock.Anything, "user-1").Return(&service.KratosIdentity{
|
||||
ID: "user-1",
|
||||
State: "active",
|
||||
Traits: map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "User One",
|
||||
},
|
||||
}, nil).Once()
|
||||
var capturedTraits map[string]any
|
||||
kratos.On("UpdateIdentity", mock.Anything, "user-1", mock.Anything, "active").Run(func(args mock.Arguments) {
|
||||
capturedTraits = args.Get(2).(map[string]any)
|
||||
}).Return(&service.KratosIdentity{ID: "user-1", State: "active", Traits: map[string]any{}}, nil).Once()
|
||||
|
||||
h := &DevHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: transport},
|
||||
},
|
||||
KratosAdmin: kratos,
|
||||
RPUserMetadataRepo: repo,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "admin", Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Put("/api/v1/dev/clients/:id/users/:userId/metadata", h.UpsertRPUserMetadata)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"metadata": map[string]any{"approvalLevel": "B"},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-1/users/user-1/metadata", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
rpClaims := capturedTraits["rp_custom_claims"].(map[string]any)
|
||||
clientClaims := rpClaims["client-1"].(domain.JSONMap)
|
||||
require.Equal(t, "B", clientClaims["approvalLevel"])
|
||||
require.Equal(t, map[string]any{
|
||||
"readPermission": "user_and_admin",
|
||||
"writePermission": "admin_only",
|
||||
}, clientClaims["approvalLevel_permissions"])
|
||||
repo.AssertExpectations(t)
|
||||
kratos.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestDevHandler_RPUserMetadataRejectsUndefinedClaimKey(t *testing.T) {
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/clients/client-1" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": "client-1",
|
||||
"client_name": "Client One",
|
||||
"metadata": map[string]any{
|
||||
"id_token_claims": []map[string]any{
|
||||
{
|
||||
"namespace": "rp_claims",
|
||||
"key": "contract_date",
|
||||
"valueType": "date",
|
||||
"value": "2026-06-09",
|
||||
},
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusNotFound, nil), nil
|
||||
})
|
||||
|
||||
repo := new(devMockRPUserMetadataRepo)
|
||||
h := &DevHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: transport},
|
||||
},
|
||||
RPUserMetadataRepo: repo,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "admin", Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Put("/api/v1/dev/clients/:id/users/:userId/metadata", h.UpsertRPUserMetadata)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"metadata": map[string]any{"unknown_claim": "A"},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-1/users/user-1/metadata", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
repo.AssertNotCalled(t, "Upsert", mock.Anything, mock.Anything)
|
||||
}
|
||||
|
||||
func TestDevHandler_RPUserMetadataRejectsInvalidTypedClaimValue(t *testing.T) {
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/clients/client-1" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]any{
|
||||
"client_id": "client-1",
|
||||
"client_name": "Client One",
|
||||
"metadata": map[string]any{
|
||||
"id_token_claims": []map[string]any{
|
||||
{
|
||||
"namespace": "rp_claims",
|
||||
"key": "contract_date",
|
||||
"valueType": "date",
|
||||
"value": "2026-06-09",
|
||||
},
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusNotFound, nil), nil
|
||||
})
|
||||
|
||||
repo := new(devMockRPUserMetadataRepo)
|
||||
h := &DevHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: transport},
|
||||
},
|
||||
RPUserMetadataRepo: repo,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "admin", Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Put("/api/v1/dev/clients/:id/users/:userId/metadata", h.UpsertRPUserMetadata)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"metadata": map[string]any{"contract_date": "2026/06/09"},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-1/users/user-1/metadata", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req, -1)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
repo.AssertNotCalled(t, "Upsert", mock.Anything, mock.Anything)
|
||||
}
|
||||
3791
baron-sso/backend/internal/handler/dev_handler_test.go
Normal file
3791
baron-sso/backend/internal/handler/dev_handler_test.go
Normal file
File diff suppressed because it is too large
Load Diff
17
baron-sso/backend/internal/handler/error_helper.go
Normal file
17
baron-sso/backend/internal/handler/error_helper.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/response"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// errorJSON은 기존 error 필드를 유지하면서 기계 판독용 code를 명시적으로 추가합니다.
|
||||
func errorJSON(c *fiber.Ctx, status int, message string) error {
|
||||
return response.Error(c, status, response.StatusCode(status), message)
|
||||
}
|
||||
|
||||
// errorJSONCode는 상태코드 기반 매핑만으로 부족한 경우 명시 코드를 강제할 때 사용합니다.
|
||||
func errorJSONCode(c *fiber.Ctx, status int, code, message string) error {
|
||||
return response.Error(c, status, code, message)
|
||||
}
|
||||
161
baron-sso/backend/internal/handler/federation_handler.go
Normal file
161
baron-sso/backend/internal/handler/federation_handler.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/repository"
|
||||
"baron-sso-backend/internal/service"
|
||||
"errors"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// FederationHandler handles API requests for IdP federation.
|
||||
type FederationHandler struct {
|
||||
fedSvc *service.FederationService
|
||||
repo repository.FederationRepository // For IdP Config CRUD
|
||||
db *gorm.DB // For tenant existence checks, etc. in CRUD
|
||||
}
|
||||
|
||||
// NewFederationHandler creates a new FederationHandler.
|
||||
func NewFederationHandler(fedSvc *service.FederationService, repo repository.FederationRepository, db *gorm.DB) *FederationHandler {
|
||||
return &FederationHandler{
|
||||
fedSvc: fedSvc,
|
||||
repo: repo,
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// InitiateOIDCLogin handles the start of the OIDC login flow.
|
||||
// It expects `provider_id` and `login_challenge` as query parameters.
|
||||
func (h *FederationHandler) InitiateOIDCLogin(c *fiber.Ctx) error {
|
||||
providerID := c.Query("provider_id")
|
||||
loginChallenge := c.Query("login_challenge")
|
||||
|
||||
if providerID == "" || loginChallenge == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "provider_id and login_challenge are required")
|
||||
}
|
||||
|
||||
redirectURL, err := h.fedSvc.InitiateOIDCLogin(c.Context(), providerID, loginChallenge)
|
||||
if err != nil {
|
||||
// Log the error properly in a real application
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "failed to initiate OIDC login")
|
||||
}
|
||||
|
||||
return c.Redirect(redirectURL, fiber.StatusFound)
|
||||
}
|
||||
|
||||
// HandleOIDCCallback handles the OIDC callback from the IdP.
|
||||
func (h *FederationHandler) HandleOIDCCallback(c *fiber.Ctx) error {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
|
||||
if code == "" || state == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "code and state are required")
|
||||
}
|
||||
|
||||
redirectURL, err := h.fedSvc.HandleOIDCCallback(c.Context(), code, state)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "failed to handle OIDC callback")
|
||||
}
|
||||
|
||||
return c.Redirect(redirectURL, fiber.StatusFound)
|
||||
}
|
||||
|
||||
// --- New Client-based IdP Config Methods ---
|
||||
|
||||
// ListIdpConfigsForClient handles listing all IdP configurations for a client.
|
||||
func (h *FederationHandler) ListIdpConfigsForClient(c *fiber.Ctx) error {
|
||||
clientID := c.Params("clientId")
|
||||
if clientID == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "clientId is required")
|
||||
}
|
||||
|
||||
var configs []domain.IdentityProviderConfig
|
||||
if err := h.db.Where("client_id = ?", clientID).Find(&configs).Error; err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(configs)
|
||||
}
|
||||
|
||||
// CreateIdpConfigForClient handles the creation of a new IdP configuration for a client.
|
||||
func (h *FederationHandler) CreateIdpConfigForClient(c *fiber.Ctx) error {
|
||||
clientID := c.Params("clientId")
|
||||
if clientID == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "clientId is required in path")
|
||||
}
|
||||
|
||||
var req domain.IdentityProviderConfig
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
|
||||
}
|
||||
|
||||
// Assign clientID from path parameter
|
||||
req.ClientID = clientID
|
||||
|
||||
// Basic validation
|
||||
if req.DisplayName == "" || req.ProviderType == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "display_name and provider_type are required")
|
||||
}
|
||||
|
||||
// TODO: Optionally, validate if the clientID exists in Hydra
|
||||
|
||||
// Create in DB
|
||||
if err := h.db.Create(&req).Error; err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusCreated).JSON(req)
|
||||
}
|
||||
|
||||
// --- Deprecated Tenant-based IdP Config Methods ---
|
||||
|
||||
// ListIdpConfigsForTenant handles listing all IdP configurations for a tenant.
|
||||
func (h *FederationHandler) ListIdpConfigsForTenant(c *fiber.Ctx) error {
|
||||
tenantID := c.Params("tenantId")
|
||||
if tenantID == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "tenantId is required")
|
||||
}
|
||||
|
||||
// This is a temporary solution. We should create a proper method in the repository.
|
||||
var configs []domain.IdentityProviderConfig
|
||||
// Note: This now queries client_id, which is incorrect for tenants.
|
||||
// This method is deprecated.
|
||||
if err := h.db.Where("tenant_id = ?", tenantID).Find(&configs).Error; err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(configs)
|
||||
}
|
||||
|
||||
// CreateIdpConfig handles the creation of a new IdP configuration.
|
||||
func (h *FederationHandler) CreateIdpConfig(c *fiber.Ctx) error {
|
||||
var req domain.IdentityProviderConfig
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
|
||||
}
|
||||
|
||||
// Basic validation - This is the old validation logic
|
||||
if req.ClientID == "" || req.DisplayName == "" || req.ProviderType == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "client_id, display_name, and provider_type are required")
|
||||
}
|
||||
|
||||
// This check is now incorrect and deprecated.
|
||||
var tenant domain.Tenant
|
||||
if err := h.db.First(&tenant, "id = ?", req.ClientID).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "tenant not found")
|
||||
}
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
// Create in DB
|
||||
if err := h.db.Create(&req).Error; err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusCreated).JSON(req)
|
||||
}
|
||||
|
||||
// TODO: Re-implement Update, Delete handlers for IdP Configs for Clients
|
||||
243
baron-sso/backend/internal/handler/hanmac_email_policy.go
Normal file
243
baron-sso/backend/internal/handler/hanmac_email_policy.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const hanmacFamilyTenantSlug = "hanmac-family"
|
||||
|
||||
type hanmacEmailScope struct {
|
||||
TenantIDs map[string]bool
|
||||
Slugs map[string]bool
|
||||
IDList []string
|
||||
SlugList []string
|
||||
}
|
||||
|
||||
type hanmacEmailEvaluation struct {
|
||||
Email string
|
||||
OriginalEmail string
|
||||
SuggestedEmail string
|
||||
Status string
|
||||
Warnings []string
|
||||
Message string
|
||||
Blocking bool
|
||||
LocalPart string
|
||||
}
|
||||
|
||||
func (h *UserHandler) evaluateHanmacImportEmail(ctx context.Context, item bulkUserItem, scope *hanmacEmailScope, usedLocalParts map[string]bool) hanmacEmailEvaluation {
|
||||
originalEmail := strings.TrimSpace(item.Email)
|
||||
name := strings.TrimSpace(item.Name)
|
||||
evaluation := hanmacEmailEvaluation{
|
||||
Email: originalEmail,
|
||||
OriginalEmail: originalEmail,
|
||||
Status: "valid",
|
||||
}
|
||||
|
||||
localPart, domainPart, err := domain.SplitEmailDomain(originalEmail)
|
||||
if err != nil {
|
||||
evaluation.Status = "blockingError"
|
||||
evaluation.Message = "invalid email format"
|
||||
evaluation.Blocking = true
|
||||
return evaluation
|
||||
}
|
||||
|
||||
base, needsReview, _ := domain.BuildKoreanNameEmailBase(name)
|
||||
if needsReview {
|
||||
evaluation.Warnings = append(evaluation.Warnings, "needsReview")
|
||||
evaluation.Status = "needsReview"
|
||||
}
|
||||
|
||||
if localPart == "" {
|
||||
if base == "" {
|
||||
evaluation.Status = "blockingError"
|
||||
evaluation.Message = "이름으로 이메일 ID를 제안할 수 없습니다."
|
||||
evaluation.Blocking = true
|
||||
return evaluation
|
||||
}
|
||||
nextLocalPart := nextAvailableHanmacLocalPart(base, usedLocalParts)
|
||||
evaluation.Email = nextLocalPart + "@" + domainPart
|
||||
evaluation.SuggestedEmail = evaluation.Email
|
||||
evaluation.LocalPart = nextLocalPart
|
||||
evaluation.Status = "suggested"
|
||||
evaluation.Warnings = appendUniqueString(evaluation.Warnings, "suggested")
|
||||
return evaluation
|
||||
}
|
||||
|
||||
evaluation.LocalPart = localPart
|
||||
if usedLocalParts[localPart] {
|
||||
evaluation.Status = "blockingError"
|
||||
evaluation.Message = "한맥가족 내에서 이미 사용 중인 이메일 ID입니다."
|
||||
evaluation.Blocking = true
|
||||
return evaluation
|
||||
}
|
||||
|
||||
if base != "" && !domain.MatchesSuggestedNameRule(localPart, base) {
|
||||
evaluation.Status = "ruleMismatch"
|
||||
evaluation.Warnings = appendUniqueString(evaluation.Warnings, "ruleMismatch")
|
||||
}
|
||||
|
||||
if evaluation.Status == "needsReview" && len(evaluation.Warnings) == 0 {
|
||||
evaluation.Warnings = append(evaluation.Warnings, "needsReview")
|
||||
}
|
||||
_ = scope
|
||||
return evaluation
|
||||
}
|
||||
|
||||
func (h *UserHandler) ensureHanmacCreateEmailAllowed(ctx context.Context, email string, tenantSlug string, tenantID string) error {
|
||||
scope, err := h.resolveHanmacEmailScope(ctx)
|
||||
if err != nil || scope == nil || !scope.ContainsTenant(tenantID, tenantSlug) {
|
||||
return nil
|
||||
}
|
||||
|
||||
localPart, err := domain.ExtractNormalizedEmailLocalPart(email)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
usedLocalParts, err := h.loadHanmacLocalParts(ctx, scope)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usedLocalParts[localPart] {
|
||||
return fmt.Errorf("한맥가족 내에서 이미 사용 중인 이메일 ID입니다.")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *UserHandler) resolveHanmacEmailScope(ctx context.Context) (*hanmacEmailScope, error) {
|
||||
if h.TenantService == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
tenants, _, err := h.TenantService.ListTenants(ctx, 10000, 0, "", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var rootID string
|
||||
for _, tenant := range tenants {
|
||||
if strings.EqualFold(strings.TrimSpace(tenant.Slug), hanmacFamilyTenantSlug) {
|
||||
rootID = tenant.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
if rootID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
tenantByID := make(map[string]domain.Tenant, len(tenants))
|
||||
for _, tenant := range tenants {
|
||||
tenantByID[tenant.ID] = tenant
|
||||
}
|
||||
|
||||
scope := &hanmacEmailScope{
|
||||
TenantIDs: make(map[string]bool),
|
||||
Slugs: make(map[string]bool),
|
||||
}
|
||||
for _, tenant := range tenants {
|
||||
if isTenantDescendantOf(tenant, rootID, tenantByID) {
|
||||
scope.TenantIDs[tenant.ID] = true
|
||||
scope.Slugs[strings.ToLower(strings.TrimSpace(tenant.Slug))] = true
|
||||
scope.IDList = append(scope.IDList, tenant.ID)
|
||||
scope.SlugList = append(scope.SlugList, tenant.Slug)
|
||||
}
|
||||
}
|
||||
return scope, nil
|
||||
}
|
||||
|
||||
func (h *UserHandler) loadHanmacLocalParts(ctx context.Context, scope *hanmacEmailScope) (map[string]bool, error) {
|
||||
used := make(map[string]bool)
|
||||
if h.UserRepo == nil || scope == nil {
|
||||
return used, nil
|
||||
}
|
||||
|
||||
if len(scope.IDList) > 0 {
|
||||
users, err := h.UserRepo.FindByTenantIDs(ctx, scope.IDList)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addUserEmailLocalParts(used, users)
|
||||
}
|
||||
|
||||
if len(scope.SlugList) > 0 {
|
||||
users, err := h.UserRepo.FindByCompanyCodes(ctx, scope.SlugList)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addUserEmailLocalParts(used, users)
|
||||
}
|
||||
|
||||
return used, nil
|
||||
}
|
||||
|
||||
func (s *hanmacEmailScope) ContainsTenant(tenantID string, slug string) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
if tenantID != "" && s.TenantIDs[tenantID] {
|
||||
return true
|
||||
}
|
||||
return s.Slugs[strings.ToLower(strings.TrimSpace(slug))]
|
||||
}
|
||||
|
||||
func isTenantDescendantOf(tenant domain.Tenant, rootID string, tenantByID map[string]domain.Tenant) bool {
|
||||
if tenant.ID == rootID {
|
||||
return true
|
||||
}
|
||||
visited := make(map[string]bool)
|
||||
parentID := ""
|
||||
if tenant.ParentID != nil {
|
||||
parentID = *tenant.ParentID
|
||||
}
|
||||
for parentID != "" {
|
||||
if parentID == rootID {
|
||||
return true
|
||||
}
|
||||
if visited[parentID] {
|
||||
return false
|
||||
}
|
||||
visited[parentID] = true
|
||||
parent, ok := tenantByID[parentID]
|
||||
if !ok || parent.ParentID == nil {
|
||||
return false
|
||||
}
|
||||
parentID = *parent.ParentID
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func addUserEmailLocalParts(target map[string]bool, users []domain.User) {
|
||||
for _, user := range users {
|
||||
localPart, err := domain.ExtractNormalizedEmailLocalPart(user.Email)
|
||||
if err == nil && localPart != "" {
|
||||
target[localPart] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func nextAvailableHanmacLocalPart(base string, usedLocalParts map[string]bool) string {
|
||||
base = strings.ToLower(strings.TrimSpace(base))
|
||||
if base == "" {
|
||||
return ""
|
||||
}
|
||||
if !usedLocalParts[base] {
|
||||
return base
|
||||
}
|
||||
for index := 1; ; index++ {
|
||||
candidate := fmt.Sprintf("%s%d", base, index)
|
||||
if !usedLocalParts[candidate] {
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func appendUniqueString(values []string, value string) []string {
|
||||
if slices.Contains(values, value) {
|
||||
return values
|
||||
}
|
||||
return append(values, value)
|
||||
}
|
||||
98
baron-sso/backend/internal/handler/password_policy_test.go
Normal file
98
baron-sso/backend/internal/handler/password_policy_test.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"testing"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// 정책을 받아 필수 요구사항을 모두 포함하는 비밀번호를 생성한다.
|
||||
func generatePasswordFromPolicy(policy *domain.PasswordPolicy) string {
|
||||
minLen := policy.MinLength
|
||||
if minLen < 8 {
|
||||
minLen = 12 // 안전한 기본값
|
||||
}
|
||||
|
||||
pwd := make([]rune, 0, minLen)
|
||||
|
||||
if policy.Lowercase {
|
||||
pwd = append(pwd, 'a')
|
||||
}
|
||||
if policy.Uppercase {
|
||||
pwd = append(pwd, 'B')
|
||||
}
|
||||
if policy.Number {
|
||||
pwd = append(pwd, '3')
|
||||
}
|
||||
if policy.NonAlphanumeric {
|
||||
pwd = append(pwd, '!')
|
||||
}
|
||||
|
||||
charset := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*"
|
||||
for len(pwd) < minLen {
|
||||
pwd = append(pwd, rune(charset[randomInt(len(charset))]))
|
||||
}
|
||||
|
||||
// 섞어서 예측 가능성을 낮춘다.
|
||||
for i := range pwd {
|
||||
j := randomInt(len(pwd))
|
||||
pwd[i], pwd[j] = pwd[j], pwd[i]
|
||||
}
|
||||
return string(pwd)
|
||||
}
|
||||
|
||||
func randomInt(n int) int {
|
||||
if n <= 0 {
|
||||
return 0
|
||||
}
|
||||
var b [8]byte
|
||||
if _, err := rand.Read(b[:]); err != nil {
|
||||
return 0
|
||||
}
|
||||
return int(binary.BigEndian.Uint64(b[:]) % uint64(n))
|
||||
}
|
||||
|
||||
func TestGeneratePasswordUsesNonAlphanumericRequirement(t *testing.T) {
|
||||
policy := &domain.PasswordPolicy{
|
||||
MinLength: 8,
|
||||
Lowercase: true,
|
||||
Uppercase: true,
|
||||
Number: true,
|
||||
NonAlphanumeric: true,
|
||||
}
|
||||
|
||||
pwd := generatePasswordFromPolicy(policy)
|
||||
|
||||
if len(pwd) < policy.MinLength {
|
||||
t.Fatalf("비밀번호 길이가 정책 최소 길이 미만: got %d, want >= %d", len(pwd), policy.MinLength)
|
||||
}
|
||||
|
||||
var hasLower, hasUpper, hasNumber, hasSymbol bool
|
||||
for _, r := range pwd {
|
||||
switch {
|
||||
case unicode.IsLower(r):
|
||||
hasLower = true
|
||||
case unicode.IsUpper(r):
|
||||
hasUpper = true
|
||||
case unicode.IsNumber(r):
|
||||
hasNumber = true
|
||||
case !unicode.IsLetter(r) && !unicode.IsNumber(r):
|
||||
hasSymbol = true
|
||||
}
|
||||
}
|
||||
|
||||
if policy.Lowercase && !hasLower {
|
||||
t.Fatalf("소문자 요구사항 미충족: %q", pwd)
|
||||
}
|
||||
if policy.Uppercase && !hasUpper {
|
||||
t.Fatalf("대문자 요구사항 미충족: %q", pwd)
|
||||
}
|
||||
if policy.Number && !hasNumber {
|
||||
t.Fatalf("숫자 요구사항 미충족: %q", pwd)
|
||||
}
|
||||
if policy.NonAlphanumeric && !hasSymbol {
|
||||
t.Fatalf("비영문자 요구사항 미충족: %q", pwd)
|
||||
}
|
||||
}
|
||||
114
baron-sso/backend/internal/handler/relying_party_handler.go
Normal file
114
baron-sso/backend/internal/handler/relying_party_handler.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"log/slog"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type RelyingPartyHandler struct {
|
||||
Service service.RelyingPartyService
|
||||
KratosAdmin service.KratosAdminService
|
||||
}
|
||||
|
||||
func NewRelyingPartyHandler(s service.RelyingPartyService, kratos service.KratosAdminService) *RelyingPartyHandler {
|
||||
return &RelyingPartyHandler{Service: s, KratosAdmin: kratos}
|
||||
}
|
||||
|
||||
func (h *RelyingPartyHandler) Create(c *fiber.Ctx) error {
|
||||
tenantID := c.Params("tenantId")
|
||||
if tenantID == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "tenantId is required")
|
||||
}
|
||||
|
||||
var req domain.HydraClient
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
|
||||
}
|
||||
|
||||
rp, err := h.Service.Create(c.Context(), tenantID, req)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusCreated).JSON(rp)
|
||||
}
|
||||
|
||||
func (h *RelyingPartyHandler) ListAll(c *fiber.Ctx) error {
|
||||
profile, ok := c.Locals("user_profile").(*domain.UserProfileResponse)
|
||||
if !ok {
|
||||
return errorJSON(c, fiber.StatusUnauthorized, "unauthorized: user profile not found in context")
|
||||
}
|
||||
|
||||
var rps []domain.RelyingParty
|
||||
var err error
|
||||
role := domain.NormalizeRole(profile.Role)
|
||||
|
||||
if role == domain.RoleSuperAdmin {
|
||||
rps, err = h.Service.ListAll(c.Context())
|
||||
} else if role == "tenant_admin" && profile.TenantID != nil {
|
||||
rps, err = h.Service.List(c.Context(), *profile.TenantID)
|
||||
} else {
|
||||
slog.Warn("Forbidden access to all applications", "userID", profile.ID, "role", role)
|
||||
return errorJSON(c, fiber.StatusForbidden, "forbidden: insufficient role to list all applications")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(rps)
|
||||
}
|
||||
|
||||
func (h *RelyingPartyHandler) List(c *fiber.Ctx) error {
|
||||
tenantID := c.Params("tenantId")
|
||||
if tenantID == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "tenantId is required")
|
||||
}
|
||||
|
||||
rps, err := h.Service.List(c.Context(), tenantID)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(rps)
|
||||
}
|
||||
|
||||
func (h *RelyingPartyHandler) Get(c *fiber.Ctx) error {
|
||||
id := c.Params("id")
|
||||
rp, hydraClient, err := h.Service.Get(c.Context(), id)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusNotFound, "relying party not found")
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
"relyingParty": rp,
|
||||
"oauth2Config": hydraClient,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *RelyingPartyHandler) Update(c *fiber.Ctx) error {
|
||||
id := c.Params("id")
|
||||
var req domain.HydraClient
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
|
||||
}
|
||||
|
||||
rp, err := h.Service.Update(c.Context(), id, req)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(rp)
|
||||
}
|
||||
|
||||
func (h *RelyingPartyHandler) Delete(c *fiber.Ctx) error {
|
||||
id := c.Params("id")
|
||||
if err := h.Service.Delete(c.Context(), id); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.SendStatus(fiber.StatusNoContent)
|
||||
}
|
||||
255
baron-sso/backend/internal/handler/rp_manifest_handler.go
Normal file
255
baron-sso/backend/internal/handler/rp_manifest_handler.go
Normal file
@@ -0,0 +1,255 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"html"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type RPManifestHandler struct{}
|
||||
|
||||
const rpObjectLookupMermaid = `flowchart TD
|
||||
A[RP request] --> B{obj_id supplied?}
|
||||
B -->|yes| C[Normalize object type and obj_id]
|
||||
B -->|no| D{Route has client_id?}
|
||||
D -->|yes| E[obj_id = RelyingParty:<client_id>]
|
||||
D -->|no| F{Route has tenant_id?}
|
||||
F -->|yes| G[obj_id = Tenant:<tenant_id>]
|
||||
F -->|no| H[Reject: explicit obj_id required]
|
||||
C --> I[Check Keto relation]
|
||||
E --> I
|
||||
G --> I
|
||||
I --> J{allowed?}
|
||||
J -->|yes| K[Inject trusted Baron headers]
|
||||
J -->|no| L[Reject request]
|
||||
K --> M[Write audit with obj_id, relation, client_id, X-Request-Id]`
|
||||
|
||||
const rpExternalKeyMermaid = `flowchart TD
|
||||
A[User authenticates through Baron SSO] --> B[Baron resolves internal identity]
|
||||
B --> C[Baron derives or loads Baron-issued alias]
|
||||
C --> D[Baron injects X-Baron-External-Key]
|
||||
D --> E[Baron injects X-Baron-Subject]
|
||||
E --> I[RP receives trusted headers from Baron gateway]
|
||||
I --> F[RP upserts local user with provider + X-Baron-External-Key]
|
||||
F --> G[RP stores the full external key as opaque value]
|
||||
G --> H[RP never parses or stores raw kratos_identity_id]`
|
||||
|
||||
func NewRPManifestHandler() *RPManifestHandler {
|
||||
return &RPManifestHandler{}
|
||||
}
|
||||
|
||||
func (h *RPManifestHandler) GetJSON(c *fiber.Ctx) error {
|
||||
c.Set(fiber.HeaderCacheControl, "public, max-age=300")
|
||||
return c.JSON(buildRPManifest(c))
|
||||
}
|
||||
|
||||
func (h *RPManifestHandler) GetSchema(c *fiber.Ctx) error {
|
||||
c.Set(fiber.HeaderCacheControl, "public, max-age=300")
|
||||
return c.JSON(rpManifestSchema())
|
||||
}
|
||||
|
||||
func (h *RPManifestHandler) GetHTML(c *fiber.Ctx) error {
|
||||
manifest := buildRPManifest(c)
|
||||
issuer, _ := manifest["issuer"].(string)
|
||||
c.Set(fiber.HeaderCacheControl, "public, max-age=300")
|
||||
c.Type("html", "utf-8")
|
||||
return c.SendString(`<!doctype html>
|
||||
<html lang="ko">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<title>Baron RP IAM Manifest</title>
|
||||
<style>
|
||||
body { font-family: system-ui, sans-serif; margin: 2rem; line-height: 1.6; max-width: 920px; }
|
||||
code, pre { background: #f5f5f5; border-radius: 4px; padding: .1rem .3rem; }
|
||||
pre { padding: 1rem; overflow: auto; }
|
||||
table { border-collapse: collapse; width: 100%; }
|
||||
th, td { border: 1px solid #ddd; padding: .5rem; text-align: left; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Baron RP IAM Manifest</h1>
|
||||
<p>외부 RP가 Baron SSO/Ory Stack/Keto 기반 공용 IAM을 연동하기 위한 공개 규격입니다.</p>
|
||||
<ul>
|
||||
<li>Machine-readable manifest: <a href="/.well-known/baron-rp-manifest.json">/.well-known/baron-rp-manifest.json</a></li>
|
||||
<li>JSON schema: <a href="/.well-known/baron-rp-manifest.schema.json">/.well-known/baron-rp-manifest.schema.json</a></li>
|
||||
</ul>
|
||||
<h2>Issuer</h2>
|
||||
<pre>` + html.EscapeString(issuer) + `</pre>
|
||||
<h2>Identity Contract</h2>
|
||||
<table>
|
||||
<tr><th>용도</th><th>Header</th><th>정책</th></tr>
|
||||
<tr><td>Keto subject</td><td><code>X-Baron-Subject</code></td><td><code>User:<baron_identity_id></code> 전체 문자열을 opaque subject로 취급합니다.</td></tr>
|
||||
<tr><td>RP upsert key</td><td><code>X-Baron-External-Key</code></td><td>Baron-issued alias입니다. RP가 만들거나 제출하지 않고, Baron이 주입한 전체 문자열을 local user external key로 저장합니다.</td></tr>
|
||||
<tr><td>RP client</td><td><code>X-Baron-Client-ID</code></td><td>현재 접근 중인 RP client id입니다.</td></tr>
|
||||
</table>
|
||||
<h2>External Key Flow</h2>
|
||||
<p><code>X-Baron-External-Key</code>는 RP 입력값이 아니라 Baron이 인증된 subject에서 발급/조회해 주입하는 opaque alias입니다. RP upserts local user from the Baron-issued alias.</p>
|
||||
<pre>` + "```mermaid\n" + html.EscapeString(rpExternalKeyMermaid) + "\n```" + `</pre>
|
||||
<h2>Object Lookup</h2>
|
||||
<pre>check(User:abc, viewers, RelyingParty:<client_id>)
|
||||
check(User:abc, members, Tenant:<tenant_id>)
|
||||
check(User:abc, viewers, Resource:<resource_type>:<resource_id>)</pre>
|
||||
<h2>audit_contract</h2>
|
||||
<p>권한과 설정을 변경하는 command는 sync audit write에 실패하면 요청도 실패해야 합니다. Read audit은 allowlist된 조회에 한해 best effort로 취급합니다.</p>
|
||||
<pre>{
|
||||
"mutating_command_mode": "fail_closed_sync",
|
||||
"missing_audit_sink_behavior": "reject_mutation",
|
||||
"correlation_header": "X-Request-Id"
|
||||
}</pre>
|
||||
<h2>Object Lookup Flow</h2>
|
||||
<pre>` + "```mermaid\n" + html.EscapeString(rpObjectLookupMermaid) + "\n```" + `</pre>
|
||||
</body>
|
||||
</html>`)
|
||||
}
|
||||
|
||||
func buildRPManifest(c *fiber.Ctx) map[string]any {
|
||||
issuer := resolvePublicRequestBaseURL(c, os.Getenv("BACKEND_PUBLIC_URL"))
|
||||
if issuer == "" {
|
||||
issuer = strings.TrimRight(os.Getenv("USERFRONT_URL"), "/")
|
||||
}
|
||||
if issuer == "" {
|
||||
issuer = "https://sso.hmac.kr"
|
||||
}
|
||||
issuer = strings.TrimRight(issuer, "/")
|
||||
|
||||
return map[string]any{
|
||||
"version": "2026-05-11",
|
||||
"issuer": issuer,
|
||||
"oidc": map[string]any{
|
||||
"discovery_url": issuer + "/.well-known/openid-configuration",
|
||||
"jwks_url": issuer + "/.well-known/jwks.json",
|
||||
"supported_flows": []string{"authorization_code_pkce"},
|
||||
"required_scopes": []string{"openid", "profile", "email"},
|
||||
},
|
||||
"iam": map[string]any{
|
||||
"authorization_engine": "ory-keto",
|
||||
"subject_format": "User:<baron_identity_id>",
|
||||
"target_object_patterns": []string{
|
||||
"RelyingParty:<client_id>",
|
||||
"Tenant:<tenant_id>",
|
||||
"Resource:<resource_type>:<resource_id>",
|
||||
},
|
||||
"supported_relations": []string{
|
||||
"admins",
|
||||
"users",
|
||||
"viewers",
|
||||
"operators",
|
||||
"members",
|
||||
"owners",
|
||||
"editors",
|
||||
},
|
||||
},
|
||||
"identity_contract": map[string]any{
|
||||
"subject_header": "X-Baron-Subject",
|
||||
"external_key_header": "X-Baron-External-Key",
|
||||
"external_key_is_opaque": true,
|
||||
"external_key_issuer": "baron",
|
||||
"external_key_delivery": "baron_injected_header",
|
||||
"external_key_lifecycle": "issued_or_loaded_after_successful_authentication_before_rp_request",
|
||||
"rp_supplied_external_key_allowed": false,
|
||||
"rp_user_upsert_source": "rp_must_upsert_from_header_value",
|
||||
"raw_kratos_identity_id_exposed": false,
|
||||
"rp_user_upsert_key": "provider + external_key",
|
||||
"email_is_stable_primary_key": false,
|
||||
"initial_external_key_expression": "X-Baron-External-Key",
|
||||
"fallback_to_subject_allowed": false,
|
||||
},
|
||||
"trusted_headers": map[string]any{
|
||||
"subject": "X-Baron-Subject",
|
||||
"external_key": "X-Baron-External-Key",
|
||||
"email": "X-Baron-Email",
|
||||
"tenant": "X-Baron-Tenant",
|
||||
"relations": "X-Baron-Relations",
|
||||
"client_id": "X-Baron-Client-ID",
|
||||
},
|
||||
"object_lookup": map[string]any{
|
||||
"rp_level": map[string]any{
|
||||
"object": "RelyingParty:<client_id>",
|
||||
"relations": []string{"viewers", "users", "operators", "admins"},
|
||||
"example": "check(User:abc, viewers, RelyingParty:mh-dashboard)",
|
||||
},
|
||||
"tenant_level": map[string]any{
|
||||
"object": "Tenant:<tenant_id>",
|
||||
"relations": []string{"members", "admins", "owners"},
|
||||
"example": "check(User:abc, members, Tenant:9caf62e1-297d-4e8f-870b-61780998bbe)",
|
||||
},
|
||||
"resource_level": map[string]any{
|
||||
"object": "Resource:<resource_type>:<resource_id>",
|
||||
"relations": []string{"viewers", "editors", "owners"},
|
||||
"example": "check(User:abc, viewers, Resource:dashboard:mh-monthly-2026-05)",
|
||||
},
|
||||
"recommended_order": []string{
|
||||
"authenticated",
|
||||
"rp_level",
|
||||
"tenant_or_resource_level",
|
||||
"trusted_header_injection",
|
||||
},
|
||||
},
|
||||
"object_lookup_flow": map[string]any{
|
||||
"format": "mermaid",
|
||||
"mermaid": rpObjectLookupMermaid,
|
||||
},
|
||||
"external_key_flow": map[string]any{
|
||||
"format": "mermaid",
|
||||
"mermaid": rpExternalKeyMermaid,
|
||||
},
|
||||
"audit_contract": map[string]any{
|
||||
"mutating_command_mode": "fail_closed_sync",
|
||||
"missing_audit_sink_behavior": "reject_mutation",
|
||||
"read_audit_mode": "best_effort_allowlisted",
|
||||
"correlation_header": "X-Request-Id",
|
||||
"rp_business_audit_required": true,
|
||||
"baron_gateway_audit_required": true,
|
||||
"required_detail_fields": []string{
|
||||
"obj_id",
|
||||
"relation",
|
||||
"client_id",
|
||||
"subject",
|
||||
"decision",
|
||||
},
|
||||
"guarantee_scope": "Baron-mediated IAM mutations fail closed on audit write failure; RP-owned business events must be emitted by the RP with the same correlation header.",
|
||||
},
|
||||
"security_requirements": map[string]any{
|
||||
"strip_external_identity_headers": true,
|
||||
"backend_direct_exposure_allowed": false,
|
||||
"static_snapshot_requires_auth": true,
|
||||
"email_as_primary_key_allowed": false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func rpManifestSchema() map[string]any {
|
||||
return map[string]any{
|
||||
"$schema": "https://json-schema.org/draft/2020-12/schema",
|
||||
"title": "Baron RP IAM Manifest",
|
||||
"type": "object",
|
||||
"required": []string{
|
||||
"version",
|
||||
"issuer",
|
||||
"oidc",
|
||||
"iam",
|
||||
"trusted_headers",
|
||||
"identity_contract",
|
||||
"object_lookup",
|
||||
"object_lookup_flow",
|
||||
"external_key_flow",
|
||||
"audit_contract",
|
||||
"security_requirements",
|
||||
},
|
||||
"properties": map[string]any{
|
||||
"version": map[string]any{"type": "string"},
|
||||
"issuer": map[string]any{"type": "string", "format": "uri"},
|
||||
"oidc": map[string]any{"type": "object"},
|
||||
"iam": map[string]any{"type": "object"},
|
||||
"trusted_headers": map[string]any{"type": "object"},
|
||||
"identity_contract": map[string]any{"type": "object"},
|
||||
"object_lookup": map[string]any{"type": "object"},
|
||||
"object_lookup_flow": map[string]any{"type": "object"},
|
||||
"external_key_flow": map[string]any{"type": "object"},
|
||||
"audit_contract": map[string]any{"type": "object"},
|
||||
"security_requirements": map[string]any{"type": "object"},
|
||||
},
|
||||
}
|
||||
}
|
||||
125
baron-sso/backend/internal/handler/rp_manifest_handler_test.go
Normal file
125
baron-sso/backend/internal/handler/rp_manifest_handler_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRPManifestJSONIncludesIAMAndExternalKeyContract(t *testing.T) {
|
||||
t.Setenv("BACKEND_PUBLIC_URL", "")
|
||||
|
||||
app := fiber.New()
|
||||
h := NewRPManifestHandler()
|
||||
app.Get("/.well-known/baron-rp-manifest.json", h.GetJSON)
|
||||
|
||||
req := httptest.NewRequest("GET", "/.well-known/baron-rp-manifest.json", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("X-Forwarded-Host", "sso.hmac.kr")
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
require.Contains(t, resp.Header.Get("Content-Type"), "application/json")
|
||||
|
||||
var body map[string]any
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
require.Equal(t, "https://sso.hmac.kr", body["issuer"])
|
||||
|
||||
oidc := body["oidc"].(map[string]any)
|
||||
require.Equal(t, "https://sso.hmac.kr/.well-known/openid-configuration", oidc["discovery_url"])
|
||||
require.Equal(t, "https://sso.hmac.kr/.well-known/jwks.json", oidc["jwks_url"])
|
||||
|
||||
iam := body["iam"].(map[string]any)
|
||||
require.Equal(t, "ory-keto", iam["authorization_engine"])
|
||||
require.Equal(t, "User:<baron_identity_id>", iam["subject_format"])
|
||||
require.Contains(t, iam["target_object_patterns"].([]any), "RelyingParty:<client_id>")
|
||||
require.Contains(t, iam["target_object_patterns"].([]any), "Tenant:<tenant_id>")
|
||||
require.Contains(t, iam["target_object_patterns"].([]any), "Resource:<resource_type>:<resource_id>")
|
||||
|
||||
identity := body["identity_contract"].(map[string]any)
|
||||
require.Equal(t, "X-Baron-External-Key", identity["external_key_header"])
|
||||
require.Equal(t, true, identity["external_key_is_opaque"])
|
||||
require.Equal(t, false, identity["raw_kratos_identity_id_exposed"])
|
||||
require.Equal(t, "baron", identity["external_key_issuer"])
|
||||
require.Equal(t, "baron_injected_header", identity["external_key_delivery"])
|
||||
require.Equal(t, false, identity["rp_supplied_external_key_allowed"])
|
||||
require.Equal(t, "rp_must_upsert_from_header_value", identity["rp_user_upsert_source"])
|
||||
|
||||
headers := body["trusted_headers"].(map[string]any)
|
||||
require.Equal(t, "X-Baron-Subject", headers["subject"])
|
||||
require.Equal(t, "X-Baron-External-Key", headers["external_key"])
|
||||
require.Equal(t, "X-Baron-Client-ID", headers["client_id"])
|
||||
|
||||
security := body["security_requirements"].(map[string]any)
|
||||
require.Equal(t, true, security["strip_external_identity_headers"])
|
||||
require.Equal(t, false, security["backend_direct_exposure_allowed"])
|
||||
|
||||
audit := body["audit_contract"].(map[string]any)
|
||||
require.Equal(t, "fail_closed_sync", audit["mutating_command_mode"])
|
||||
require.Equal(t, "reject_mutation", audit["missing_audit_sink_behavior"])
|
||||
require.Equal(t, "X-Request-Id", audit["correlation_header"])
|
||||
require.Contains(t, audit["required_detail_fields"].([]any), "obj_id")
|
||||
require.Contains(t, audit["required_detail_fields"].([]any), "client_id")
|
||||
|
||||
flow := body["object_lookup_flow"].(map[string]any)
|
||||
require.Contains(t, flow["mermaid"].(string), "flowchart TD")
|
||||
require.Contains(t, flow["mermaid"].(string), "obj_id")
|
||||
|
||||
aliasFlow := body["external_key_flow"].(map[string]any)
|
||||
require.Contains(t, aliasFlow["mermaid"].(string), "Baron resolves internal identity")
|
||||
require.Contains(t, aliasFlow["mermaid"].(string), "Baron injects X-Baron-External-Key")
|
||||
require.Contains(t, aliasFlow["mermaid"].(string), "RP upserts local user")
|
||||
require.NotContains(t, aliasFlow["mermaid"].(string), "RP creates external key")
|
||||
}
|
||||
|
||||
func TestRPManifestSchemaRequiresLookupAndIdentityContracts(t *testing.T) {
|
||||
app := fiber.New()
|
||||
h := NewRPManifestHandler()
|
||||
app.Get("/.well-known/baron-rp-manifest.schema.json", h.GetSchema)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/.well-known/baron-rp-manifest.schema.json", nil))
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
var body map[string]any
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
|
||||
required := body["required"].([]any)
|
||||
require.Contains(t, required, "iam")
|
||||
require.Contains(t, required, "trusted_headers")
|
||||
require.Contains(t, required, "identity_contract")
|
||||
require.Contains(t, required, "object_lookup")
|
||||
require.Contains(t, required, "audit_contract")
|
||||
require.Contains(t, required, "object_lookup_flow")
|
||||
require.Contains(t, required, "external_key_flow")
|
||||
}
|
||||
|
||||
func TestRPManifestHTMLLinksMachineReadableManifest(t *testing.T) {
|
||||
app := fiber.New()
|
||||
h := NewRPManifestHandler()
|
||||
app.Get("/.well-known/baron-rp-manifest", h.GetHTML)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/.well-known/baron-rp-manifest", nil))
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
require.Contains(t, resp.Header.Get("Content-Type"), "text/html")
|
||||
raw, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
text := string(raw)
|
||||
require.Contains(t, text, "/.well-known/baron-rp-manifest.json")
|
||||
require.Contains(t, text, "X-Baron-External-Key")
|
||||
require.Contains(t, text, "RelyingParty:<client_id>")
|
||||
require.Contains(t, text, "```mermaid")
|
||||
require.Contains(t, text, "audit_contract")
|
||||
require.Contains(t, text, "Baron-issued alias")
|
||||
require.Contains(t, text, "RP upserts local user")
|
||||
}
|
||||
154
baron-sso/backend/internal/handler/tenant_access_cleanup.go
Normal file
154
baron-sso/backend/internal/handler/tenant_access_cleanup.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/repository"
|
||||
"baron-sso-backend/internal/service"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const tenantAccessCleanupClientPageSize = 500
|
||||
|
||||
func cleanupDeletedTenantReferences(ctx context.Context, hydra *service.HydraAdminService, consentRepo repository.ClientConsentRepository, ketoOutbox repository.KetoOutboxRepository, deletedTenantIDs []string) error {
|
||||
if hydra == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
deletedTenantSet := make(map[string]struct{}, len(deletedTenantIDs))
|
||||
for _, tenantID := range deletedTenantIDs {
|
||||
tenantID = strings.TrimSpace(tenantID)
|
||||
if tenantID == "" {
|
||||
continue
|
||||
}
|
||||
deletedTenantSet[tenantID] = struct{}{}
|
||||
}
|
||||
if len(deletedTenantSet) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for offset := 0; ; offset += tenantAccessCleanupClientPageSize {
|
||||
clients, err := hydra.ListClients(ctx, tenantAccessCleanupClientPageSize, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list hydra clients for tenant cleanup: %w", err)
|
||||
}
|
||||
|
||||
for _, client := range clients {
|
||||
beforeMetadata := maps.Clone(client.Metadata)
|
||||
updatedMetadata, changed, removedOwnerTenantID := pruneDeletedTenantReferences(beforeMetadata, deletedTenantSet)
|
||||
if !changed {
|
||||
continue
|
||||
}
|
||||
|
||||
updatedClient := client
|
||||
updatedClient.Metadata = updatedMetadata
|
||||
if _, err := hydra.UpdateClient(ctx, client.ClientID, updatedClient); err != nil {
|
||||
return fmt.Errorf("failed to update hydra client %s during tenant cleanup: %w", client.ClientID, err)
|
||||
}
|
||||
if removedOwnerTenantID != "" {
|
||||
if err := enqueueDeletedTenantRelyingPartyParentCleanup(ctx, ketoOutbox, client.ClientID, removedOwnerTenantID); err != nil {
|
||||
return fmt.Errorf("failed to cleanup RP parent relation for client %s during tenant cleanup: %w", client.ClientID, err)
|
||||
}
|
||||
}
|
||||
|
||||
if tenantAccessPolicyChanged(beforeMetadata, updatedMetadata) {
|
||||
if err := revokeClientConsentsForPolicyChange(ctx, hydra, consentRepo, client.ClientID); err != nil {
|
||||
return fmt.Errorf("failed to revoke consent sessions for client %s during tenant cleanup: %w", client.ClientID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(clients) < tenantAccessCleanupClientPageSize {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func pruneDeletedTenantReferences(metadata map[string]any, deletedTenantSet map[string]struct{}) (map[string]any, bool, string) {
|
||||
if len(deletedTenantSet) == 0 {
|
||||
return metadata, false, ""
|
||||
}
|
||||
|
||||
ownerTenantID := normalizeMetadataString(metadata["tenant_id"])
|
||||
_, ownerDeleted := deletedTenantSet[ownerTenantID]
|
||||
|
||||
allowedTenants := normalizeMetadataStringSlice(metadata[clientAllowedTenantsKey])
|
||||
filtered := make([]string, 0, len(allowedTenants))
|
||||
for _, tenantID := range allowedTenants {
|
||||
if _, ok := deletedTenantSet[tenantID]; ok {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, tenantID)
|
||||
}
|
||||
allowedChanged := len(filtered) != len(allowedTenants)
|
||||
|
||||
if !ownerDeleted && !allowedChanged {
|
||||
return metadata, false, ""
|
||||
}
|
||||
|
||||
updated := maps.Clone(metadata)
|
||||
if ownerDeleted {
|
||||
delete(updated, "tenant_id")
|
||||
}
|
||||
|
||||
if len(filtered) == 0 {
|
||||
delete(updated, clientAllowedTenantsKey)
|
||||
updated[clientTenantAccessRestrictedKey] = false
|
||||
return updated, true, ownerTenantID
|
||||
}
|
||||
|
||||
updated[clientAllowedTenantsKey] = uniqueSortedStrings(filtered)
|
||||
updated[clientTenantAccessRestrictedKey] = true
|
||||
return updated, true, ownerTenantID
|
||||
}
|
||||
|
||||
func enqueueDeletedTenantRelyingPartyParentCleanup(ctx context.Context, ketoOutbox repository.KetoOutboxRepository, clientID, tenantID string) error {
|
||||
if ketoOutbox == nil {
|
||||
return nil
|
||||
}
|
||||
clientID = strings.TrimSpace(clientID)
|
||||
tenantID = strings.TrimSpace(tenantID)
|
||||
if clientID == "" || tenantID == "" {
|
||||
return nil
|
||||
}
|
||||
return ketoOutbox.Create(ctx, &domain.KetoOutbox{
|
||||
Namespace: "RelyingParty",
|
||||
Object: clientID,
|
||||
Relation: "parents",
|
||||
Subject: "Tenant:" + tenantID,
|
||||
Action: domain.KetoOutboxActionDelete,
|
||||
})
|
||||
}
|
||||
|
||||
func revokeClientConsentsForPolicyChange(ctx context.Context, hydra *service.HydraAdminService, consentRepo repository.ClientConsentRepository, clientID string) error {
|
||||
if consentRepo == nil || hydra == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
subjects, err := consentRepo.ListSubjectsByClient(ctx, clientID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, subject := range subjects {
|
||||
subject = strings.TrimSpace(subject)
|
||||
if subject == "" {
|
||||
continue
|
||||
}
|
||||
if err := hydra.RevokeConsentSessions(ctx, subject, clientID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return consentRepo.DeleteByClient(ctx, clientID)
|
||||
}
|
||||
|
||||
func logTenantCleanupFailure(err error, deletedTenantIDs []string) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
slog.Error("Failed to cleanup RP tenant restrictions after tenant deletion", "tenant_ids", deletedTenantIDs, "error", err)
|
||||
}
|
||||
178
baron-sso/backend/internal/handler/tenant_access_cleanup_test.go
Normal file
178
baron-sso/backend/internal/handler/tenant_access_cleanup_test.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/repository"
|
||||
"baron-sso-backend/internal/service"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestPruneDeletedTenantReferences_PreservesOtherAllowedTenants(t *testing.T) {
|
||||
metadata := map[string]any{
|
||||
clientTenantAccessRestrictedKey: true,
|
||||
clientAllowedTenantsKey: []string{"keep-tenant", "deleted-tenant"},
|
||||
"tenant_id": "deleted-tenant",
|
||||
}
|
||||
|
||||
updated, changed, removedOwnerTenantID := pruneDeletedTenantReferences(metadata, map[string]struct{}{
|
||||
"deleted-tenant": {},
|
||||
})
|
||||
|
||||
require.True(t, changed)
|
||||
assert.Equal(t, "deleted-tenant", removedOwnerTenantID)
|
||||
assert.Equal(t, true, updated[clientTenantAccessRestrictedKey])
|
||||
assert.Equal(t, []string{"keep-tenant"}, updated[clientAllowedTenantsKey])
|
||||
_, exists := updated["tenant_id"]
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestPruneDeletedTenantReferences_DisablesRestrictionWhenLastTenantRemoved(t *testing.T) {
|
||||
metadata := map[string]any{
|
||||
clientTenantAccessRestrictedKey: true,
|
||||
clientAllowedTenantsKey: []string{"deleted-tenant"},
|
||||
"tenant_id": "deleted-tenant",
|
||||
}
|
||||
|
||||
updated, changed, removedOwnerTenantID := pruneDeletedTenantReferences(metadata, map[string]struct{}{
|
||||
"deleted-tenant": {},
|
||||
})
|
||||
|
||||
require.True(t, changed)
|
||||
assert.Equal(t, "deleted-tenant", removedOwnerTenantID)
|
||||
assert.Equal(t, false, updated[clientTenantAccessRestrictedKey])
|
||||
_, exists := updated[clientAllowedTenantsKey]
|
||||
assert.False(t, exists)
|
||||
_, exists = updated["tenant_id"]
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestCleanupDeletedTenantReferences_PrunesClientsAndRevokesConsents(t *testing.T) {
|
||||
var (
|
||||
mu sync.Mutex
|
||||
page0Called bool
|
||||
updated = map[string]map[string]any{}
|
||||
revokes []string
|
||||
)
|
||||
|
||||
transport := roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
switch {
|
||||
case req.Method == http.MethodGet && req.URL.Path == "/clients":
|
||||
switch req.URL.Query().Get("offset") {
|
||||
case "":
|
||||
page0Called = true
|
||||
return httpJSONAny(req, http.StatusOK, []domain.HydraClient{
|
||||
{
|
||||
ClientID: "client-keep",
|
||||
Metadata: map[string]any{
|
||||
clientTenantAccessRestrictedKey: true,
|
||||
clientAllowedTenantsKey: []string{"keep-tenant", "deleted-tenant"},
|
||||
"tenant_id": "deleted-tenant",
|
||||
},
|
||||
},
|
||||
{
|
||||
ClientID: "client-drop",
|
||||
Metadata: map[string]any{
|
||||
clientTenantAccessRestrictedKey: true,
|
||||
clientAllowedTenantsKey: []string{"deleted-tenant"},
|
||||
"tenant_id": "deleted-tenant",
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
default:
|
||||
return httpResponse(req, http.StatusBadRequest, "unexpected offset"), nil
|
||||
}
|
||||
|
||||
case req.Method == http.MethodPut && strings.HasPrefix(req.URL.Path, "/clients/"):
|
||||
var client domain.HydraClient
|
||||
require.NoError(t, json.NewDecoder(req.Body).Decode(&client))
|
||||
updated[client.ClientID] = client.Metadata
|
||||
return httpJSONAny(req, http.StatusOK, client), nil
|
||||
|
||||
case req.Method == http.MethodDelete && req.URL.Path == "/oauth2/auth/sessions/consent":
|
||||
revokes = append(revokes, req.URL.Query().Get("subject")+"|"+req.URL.Query().Get("client"))
|
||||
return httpResponse(req, http.StatusNoContent, ""), nil
|
||||
|
||||
default:
|
||||
return httpResponse(req, http.StatusNotFound, "unexpected request"), nil
|
||||
}
|
||||
})
|
||||
|
||||
hydra := &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: transport},
|
||||
}
|
||||
consentRepo := &mockConsentRepo{
|
||||
consents: []domain.ClientConsent{
|
||||
{ClientID: "client-keep", Subject: "user-a"},
|
||||
{ClientID: "client-drop", Subject: "user-b"},
|
||||
},
|
||||
}
|
||||
outbox := &tenantCleanupMockKetoOutboxRepository{}
|
||||
|
||||
err := cleanupDeletedTenantReferences(context.Background(), hydra, consentRepo, outbox, []string{"deleted-tenant"})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, page0Called)
|
||||
assert.Equal(t, map[string]any{
|
||||
clientTenantAccessRestrictedKey: true,
|
||||
clientAllowedTenantsKey: []any{"keep-tenant"},
|
||||
}, updated["client-keep"])
|
||||
assert.Equal(t, map[string]any{
|
||||
clientTenantAccessRestrictedKey: false,
|
||||
}, updated["client-drop"])
|
||||
assert.ElementsMatch(t, []string{"user-a|client-keep", "user-b|client-drop"}, revokes)
|
||||
assert.Empty(t, consentRepo.consents)
|
||||
require.Len(t, outbox.entries, 2)
|
||||
assert.ElementsMatch(t, []string{"client-keep", "client-drop"}, []string{outbox.entries[0].Object, outbox.entries[1].Object})
|
||||
for _, entry := range outbox.entries {
|
||||
assert.Equal(t, "RelyingParty", entry.Namespace)
|
||||
assert.Equal(t, "parents", entry.Relation)
|
||||
assert.Equal(t, "Tenant:deleted-tenant", entry.Subject)
|
||||
assert.Equal(t, domain.KetoOutboxActionDelete, entry.Action)
|
||||
}
|
||||
}
|
||||
|
||||
type tenantCleanupMockKetoOutboxRepository struct {
|
||||
entries []domain.KetoOutbox
|
||||
}
|
||||
|
||||
var _ repository.KetoOutboxRepository = (*tenantCleanupMockKetoOutboxRepository)(nil)
|
||||
|
||||
func (m *tenantCleanupMockKetoOutboxRepository) Create(ctx context.Context, entry *domain.KetoOutbox) error {
|
||||
if entry == nil {
|
||||
return nil
|
||||
}
|
||||
m.entries = append(m.entries, *entry)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *tenantCleanupMockKetoOutboxRepository) CreateWithTx(tx *gorm.DB, entry *domain.KetoOutbox) error {
|
||||
return m.Create(context.Background(), entry)
|
||||
}
|
||||
|
||||
func (m *tenantCleanupMockKetoOutboxRepository) FindPending(ctx context.Context, limit int) ([]domain.KetoOutbox, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *tenantCleanupMockKetoOutboxRepository) ListCurrentBySubject(ctx context.Context, namespace, subject string) ([]domain.KetoOutbox, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *tenantCleanupMockKetoOutboxRepository) UpdateStatus(ctx context.Context, id string, status string, retryCount int, lastError string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *tenantCleanupMockKetoOutboxRepository) MarkProcessed(ctx context.Context, id string) error {
|
||||
return nil
|
||||
}
|
||||
155
baron-sso/backend/internal/handler/tenant_assignment_policy.go
Normal file
155
baron-sso/backend/internal/handler/tenant_assignment_policy.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func representativeTenantIDFromTraits(traits map[string]any) string {
|
||||
if value := tenantClaimString(traits, "tenant_id"); value != "" {
|
||||
return value
|
||||
}
|
||||
if value := tenantClaimString(traits, "primaryTenantId"); value != "" {
|
||||
return value
|
||||
}
|
||||
if metadata, ok := traits["metadata"].(map[string]any); ok {
|
||||
if value := tenantClaimString(metadata, "primaryTenantId"); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
appointments := tenantAssignmentAppointmentsFromTraits(traits)
|
||||
for _, appointment := range appointments {
|
||||
if tenantAssignmentBool(appointment, "isPrimary", "primary", "representative", "isRepresentative") {
|
||||
if value := tenantAssignmentTenantID(appointment); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, appointment := range appointments {
|
||||
if value := tenantAssignmentTenantID(appointment); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
for _, tenantID := range tenantNamespaceIDsFromTraits(traits) {
|
||||
return tenantID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func joinedTenantIDsFromTraits(traits map[string]any, representativeTenantID string) []string {
|
||||
values := make([]string, 0)
|
||||
if representativeTenantID != "" {
|
||||
values = append(values, representativeTenantID)
|
||||
}
|
||||
if value := tenantClaimString(traits, "tenant_id"); value != "" {
|
||||
values = append(values, value)
|
||||
}
|
||||
for _, appointment := range tenantAssignmentAppointmentsFromTraits(traits) {
|
||||
if value := tenantAssignmentTenantID(appointment); value != "" {
|
||||
values = append(values, value)
|
||||
}
|
||||
}
|
||||
values = append(values, tenantNamespaceIDsFromTraits(traits)...)
|
||||
return uniqueSortedStrings(values)
|
||||
}
|
||||
|
||||
func tenantAssignmentAppointmentsFromTraits(traits map[string]any) []map[string]any {
|
||||
raw := rawAdditionalAppointments(traits)
|
||||
switch values := raw.(type) {
|
||||
case []any:
|
||||
appointments := make([]map[string]any, 0, len(values))
|
||||
for _, item := range values {
|
||||
if appointment, ok := item.(map[string]any); ok {
|
||||
appointments = append(appointments, appointment)
|
||||
}
|
||||
}
|
||||
return appointments
|
||||
case []map[string]any:
|
||||
return values
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func tenantAssignmentTenantID(appointment map[string]any) string {
|
||||
for _, key := range []string{"tenantId", "tenant_id"} {
|
||||
if value := tenantClaimString(appointment, key); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func tenantAssignmentBool(values map[string]any, keys ...string) bool {
|
||||
for _, key := range keys {
|
||||
raw, ok := values[key]
|
||||
if !ok || raw == nil {
|
||||
continue
|
||||
}
|
||||
switch value := raw.(type) {
|
||||
case bool:
|
||||
if value {
|
||||
return true
|
||||
}
|
||||
case string:
|
||||
normalized := strings.ToLower(strings.TrimSpace(value))
|
||||
if normalized == "true" || normalized == "1" || normalized == "yes" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func tenantNamespaceIDsFromTraits(traits map[string]any) []string {
|
||||
if traits == nil {
|
||||
return nil
|
||||
}
|
||||
ids := make([]string, 0)
|
||||
for key, value := range traits {
|
||||
if key == "" || key == "metadata" {
|
||||
continue
|
||||
}
|
||||
switch value.(type) {
|
||||
case map[string]any:
|
||||
ids = append(ids, key)
|
||||
}
|
||||
}
|
||||
return uniqueSortedStrings(ids)
|
||||
}
|
||||
|
||||
func createPersonalTenantForUser(ctx context.Context, tenantService service.TenantService, email string) (*domain.Tenant, error) {
|
||||
if tenantService == nil {
|
||||
return nil, errors.New("tenant service unavailable")
|
||||
}
|
||||
normalizedEmail := strings.ToLower(strings.TrimSpace(email))
|
||||
if normalizedEmail == "" {
|
||||
normalizedEmail = "user"
|
||||
}
|
||||
slug := "personal-" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
tenant, err := tenantService.RegisterTenant(
|
||||
ctx,
|
||||
fmt.Sprintf("Personal - %s", normalizedEmail),
|
||||
slug,
|
||||
domain.TenantTypePersonal,
|
||||
"Automatically provisioned personal tenant",
|
||||
nil,
|
||||
nil,
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tenant == nil {
|
||||
return nil, errors.New("personal tenant not created")
|
||||
}
|
||||
return tenant, nil
|
||||
}
|
||||
3154
baron-sso/backend/internal/handler/tenant_handler.go
Normal file
3154
baron-sso/backend/internal/handler/tenant_handler.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,159 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/testsupport"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
postgres_module "github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
gorm_postgres "gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func newTenantHandlerSeedDeleteDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
if !testsupport.DockerAvailable() {
|
||||
t.Skip("Docker provider is unavailable in this environment")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
postgresContainer, err := postgres_module.Run(ctx,
|
||||
"postgres:16-alpine",
|
||||
postgres_module.WithDatabase("testdb"),
|
||||
postgres_module.WithUsername("user"),
|
||||
postgres_module.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(
|
||||
wait.ForLog("database system is ready to accept connections").
|
||||
WithOccurrence(2).
|
||||
WithStartupTimeout(30*time.Second)),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start postgres container: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := postgresContainer.Terminate(ctx); err != nil {
|
||||
log.Printf("failed to terminate postgres container: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
connStr, err := postgresContainer.ConnectionString(ctx, "sslmode=disable")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get postgres connection string: %v", err)
|
||||
}
|
||||
db, err := gorm.Open(gorm_postgres.Open(connStr), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open postgres connection: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(&domain.Tenant{}); err != nil {
|
||||
t.Fatalf("failed to migrate tenants: %v", err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func setSeedTenantCSVForDeleteGuard(t *testing.T, slug string) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "seed-tenant.csv")
|
||||
csv := "name,type,parent_tenant_slug,slug,memo,email_domain\n" +
|
||||
"Protected,COMPANY_GROUP,," + slug + ",Protected seed,\n"
|
||||
if err := os.WriteFile(path, []byte(csv), 0o600); err != nil {
|
||||
t.Fatalf("failed to write seed csv: %v", err)
|
||||
}
|
||||
t.Setenv("SEED_TENANT_CSV_PATH", path)
|
||||
}
|
||||
|
||||
func TestTenantHandlerDeleteTenantRejectsSeedTenant(t *testing.T) {
|
||||
setSeedTenantCSVForDeleteGuard(t, "protected-root")
|
||||
db := newTenantHandlerSeedDeleteDB(t)
|
||||
tenant := domain.Tenant{
|
||||
ID: "00000000-0000-0000-0000-000000000001",
|
||||
Name: "Protected",
|
||||
Slug: "protected-root",
|
||||
Type: domain.TenantTypeCompanyGroup,
|
||||
Status: domain.TenantStatusActive,
|
||||
}
|
||||
if err := db.Create(&tenant).Error; err != nil {
|
||||
t.Fatalf("failed to create tenant: %v", err)
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Delete("/tenants/:id", (&TenantHandler{DB: db}).DeleteTenant)
|
||||
req := httptest.NewRequest(http.MethodDelete, "/tenants/"+tenant.ID, nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusConflict {
|
||||
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusConflict)
|
||||
}
|
||||
var count int64
|
||||
if err := db.Model(&domain.Tenant{}).Where("id = ?", tenant.ID).Count(&count).Error; err != nil {
|
||||
t.Fatalf("count tenant: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Fatalf("seed tenant count = %d, want 1", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTenantHandlerDeleteTenantsBulkRejectsSeedTenant(t *testing.T) {
|
||||
setSeedTenantCSVForDeleteGuard(t, "protected-root")
|
||||
db := newTenantHandlerSeedDeleteDB(t)
|
||||
seed := domain.Tenant{
|
||||
ID: "00000000-0000-0000-0000-000000000011",
|
||||
Name: "Protected",
|
||||
Slug: "protected-root",
|
||||
Type: domain.TenantTypeCompanyGroup,
|
||||
Status: domain.TenantStatusActive,
|
||||
}
|
||||
normal := domain.Tenant{
|
||||
ID: "00000000-0000-0000-0000-000000000012",
|
||||
Name: "Normal",
|
||||
Slug: "normal",
|
||||
Type: domain.TenantTypeCompany,
|
||||
Status: domain.TenantStatusActive,
|
||||
}
|
||||
if err := db.Create(&seed).Error; err != nil {
|
||||
t.Fatalf("failed to create seed tenant: %v", err)
|
||||
}
|
||||
if err := db.Create(&normal).Error; err != nil {
|
||||
t.Fatalf("failed to create normal tenant: %v", err)
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Delete("/tenants/bulk", (&TenantHandler{DB: db}).DeleteTenantsBulk)
|
||||
body, _ := json.Marshal(map[string][]string{"ids": {seed.ID, normal.ID}})
|
||||
req := httptest.NewRequest(http.MethodDelete, "/tenants/bulk", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusConflict {
|
||||
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusConflict)
|
||||
}
|
||||
var count int64
|
||||
if err := db.Model(&domain.Tenant{}).Where("id IN ?", []string{seed.ID, normal.ID}).Count(&count).Error; err != nil {
|
||||
t.Fatalf("count tenants: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Fatalf("remaining tenant count = %d, want 2", count)
|
||||
}
|
||||
}
|
||||
1732
baron-sso/backend/internal/handler/tenant_handler_test.go
Normal file
1732
baron-sso/backend/internal/handler/tenant_handler_test.go
Normal file
File diff suppressed because it is too large
Load Diff
133
baron-sso/backend/internal/handler/user_group_handler.go
Normal file
133
baron-sso/backend/internal/handler/user_group_handler.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type UserGroupHandler struct {
|
||||
Service service.UserGroupService
|
||||
}
|
||||
|
||||
func NewUserGroupHandler(s service.UserGroupService) *UserGroupHandler {
|
||||
return &UserGroupHandler{Service: s}
|
||||
}
|
||||
|
||||
func (h *UserGroupHandler) List(c *fiber.Ctx) error {
|
||||
tenantID := c.Params("tenantId")
|
||||
groups, err := h.Service.List(c.Context(), tenantID)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.JSON(groups)
|
||||
}
|
||||
|
||||
func (h *UserGroupHandler) Create(c *fiber.Ctx) error {
|
||||
tenantID := c.Params("tenantId")
|
||||
var req domain.GroupCreateRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid body"})
|
||||
}
|
||||
|
||||
group, err := h.Service.Create(c.Context(), tenantID, req.ParentID, req.Name, req.Description, req.UnitType)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
return c.Status(fiber.StatusCreated).JSON(group)
|
||||
}
|
||||
|
||||
func (h *UserGroupHandler) Get(c *fiber.Ctx) error {
|
||||
id := c.Params("id")
|
||||
group, err := h.Service.Get(c.Context(), id)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "failed to get group: "+err.Error())
|
||||
}
|
||||
return c.JSON(group)
|
||||
}
|
||||
|
||||
func (h *UserGroupHandler) Update(c *fiber.Ctx) error {
|
||||
tenantID := c.Params("tenantId")
|
||||
groupID := c.Params("id")
|
||||
var req domain.GroupCreateRequest // Using create request for update fields
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid body"})
|
||||
}
|
||||
|
||||
group, err := h.Service.Update(c.Context(), tenantID, groupID, req.Name, req.Description, req.UnitType, req.ParentID)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(group)
|
||||
}
|
||||
|
||||
func (h *UserGroupHandler) Delete(c *fiber.Ctx) error {
|
||||
tenantID := c.Params("tenantId")
|
||||
groupID := c.Params("id")
|
||||
if err := h.Service.Delete(c.Context(), tenantID, groupID); err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
return c.SendStatus(fiber.StatusNoContent)
|
||||
}
|
||||
|
||||
func (h *UserGroupHandler) AddMember(c *fiber.Ctx) error {
|
||||
groupID := c.Params("id")
|
||||
var req struct {
|
||||
UserID string `json:"userId"`
|
||||
}
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "userId is required")
|
||||
}
|
||||
|
||||
if err := h.Service.AddMember(c.Context(), groupID, req.UserID); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
}
|
||||
|
||||
func (h *UserGroupHandler) RemoveMember(c *fiber.Ctx) error {
|
||||
groupID := c.Params("id")
|
||||
userID := c.Params("userId")
|
||||
|
||||
if err := h.Service.RemoveMember(c.Context(), groupID, userID); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.SendStatus(fiber.StatusNoContent)
|
||||
}
|
||||
|
||||
func (h *UserGroupHandler) AssignRole(c *fiber.Ctx) error {
|
||||
groupID := c.Params("id")
|
||||
var req struct {
|
||||
TenantID string `json:"tenantId"`
|
||||
Relation string `json:"relation"`
|
||||
}
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "invalid body")
|
||||
}
|
||||
|
||||
if err := h.Service.AssignRoleToTenant(c.Context(), groupID, req.TenantID, req.Relation); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
}
|
||||
|
||||
func (h *UserGroupHandler) ListRoles(c *fiber.Ctx) error {
|
||||
groupID := c.Params("id")
|
||||
roles, err := h.Service.ListRoles(c.Context(), groupID)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.JSON(roles)
|
||||
}
|
||||
|
||||
func (h *UserGroupHandler) RemoveRole(c *fiber.Ctx) error {
|
||||
groupID := c.Params("id")
|
||||
tenantID := c.Params("tenantId")
|
||||
relation := c.Params("relation")
|
||||
|
||||
if err := h.Service.RemoveRoleFromTenant(c.Context(), groupID, tenantID, relation); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.SendStatus(fiber.StatusNoContent)
|
||||
}
|
||||
155
baron-sso/backend/internal/handler/user_group_handler_test.go
Normal file
155
baron-sso/backend/internal/handler/user_group_handler_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// --- Mocks ---
|
||||
|
||||
type MockUserGroupService struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockUserGroupService) Create(ctx context.Context, tenantID string, parentID *string, name, description, unitType string) (*domain.UserGroup, error) {
|
||||
args := m.Called(ctx, tenantID, parentID, name, description, unitType)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.UserGroup), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockUserGroupService) Update(ctx context.Context, tenantID, groupID string, name, description, unitType string, parentID *string) (*domain.UserGroup, error) {
|
||||
args := m.Called(ctx, tenantID, groupID, name, description, unitType, parentID)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.UserGroup), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockUserGroupService) Delete(ctx context.Context, tenantID, groupID string) error {
|
||||
return m.Called(ctx, tenantID, groupID).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockUserGroupService) Get(ctx context.Context, id string) (*domain.UserGroup, error) {
|
||||
args := m.Called(ctx, id)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.UserGroup), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockUserGroupService) List(ctx context.Context, tenantID string) ([]domain.UserGroup, error) {
|
||||
args := m.Called(ctx, tenantID)
|
||||
return args.Get(0).([]domain.UserGroup), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockUserGroupService) SetWorksmobileSyncer(syncer service.WorksmobileSyncer) {}
|
||||
|
||||
func (m *MockUserGroupService) AddMember(ctx context.Context, groupID, userID string) error {
|
||||
return m.Called(ctx, groupID, userID).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockUserGroupService) RemoveMember(ctx context.Context, groupID, userID string) error {
|
||||
return m.Called(ctx, groupID, userID).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockUserGroupService) ListRoles(ctx context.Context, groupID string) ([]domain.GroupRole, error) {
|
||||
args := m.Called(ctx, groupID)
|
||||
return args.Get(0).([]domain.GroupRole), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockUserGroupService) AssignRoleToTenant(ctx context.Context, groupID, tenantID, relation string) error {
|
||||
return m.Called(ctx, groupID, tenantID, relation).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockUserGroupService) RemoveRoleFromTenant(ctx context.Context, groupID, tenantID, relation string) error {
|
||||
return m.Called(ctx, groupID, tenantID, relation).Error(0)
|
||||
}
|
||||
|
||||
// --- Tests ---
|
||||
|
||||
func TestUserGroupHandler_List(t *testing.T) {
|
||||
mockSvc := new(MockUserGroupService)
|
||||
h := NewUserGroupHandler(mockSvc)
|
||||
app := fiber.New()
|
||||
app.Get("/tenants/:tenantId/user-groups", h.List)
|
||||
|
||||
tenantID := "t1"
|
||||
groups := []domain.UserGroup{{ID: "g1", Name: "Group 1"}}
|
||||
mockSvc.On("List", mock.Anything, tenantID).Return(groups, nil)
|
||||
|
||||
req := httptest.NewRequest("GET", "/tenants/t1/user-groups", nil)
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
var result []domain.UserGroup
|
||||
json.NewDecoder(resp.Body).Decode(&result)
|
||||
assert.Len(t, result, 1)
|
||||
assert.Equal(t, "Group 1", result[0].Name)
|
||||
}
|
||||
|
||||
func TestUserGroupHandler_Create(t *testing.T) {
|
||||
mockSvc := new(MockUserGroupService)
|
||||
h := NewUserGroupHandler(mockSvc)
|
||||
app := fiber.New()
|
||||
app.Post("/tenants/:tenantId/user-groups", h.Create)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"name": "New Group"})
|
||||
mockSvc.On("Create", mock.Anything, "t1", mock.Anything, "New Group", mock.Anything, mock.Anything).Return(&domain.UserGroup{ID: "g1", Name: "New Group"}, nil)
|
||||
|
||||
req := httptest.NewRequest("POST", "/tenants/t1/user-groups", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, http.StatusCreated, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestUserGroupHandler_AddMember(t *testing.T) {
|
||||
mockSvc := new(MockUserGroupService)
|
||||
h := NewUserGroupHandler(mockSvc)
|
||||
app := fiber.New()
|
||||
app.Post("/user-groups/:id/members", h.AddMember)
|
||||
|
||||
groupID := "g1"
|
||||
userID := "u1"
|
||||
body, _ := json.Marshal(map[string]string{"userId": userID})
|
||||
|
||||
mockSvc.On("AddMember", mock.Anything, groupID, userID).Return(nil)
|
||||
|
||||
req := httptest.NewRequest("POST", "/user-groups/g1/members", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestUserGroupHandler_AssignRole(t *testing.T) {
|
||||
mockSvc := new(MockUserGroupService)
|
||||
h := NewUserGroupHandler(mockSvc)
|
||||
app := fiber.New()
|
||||
app.Post("/user-groups/:id/roles", h.AssignRole)
|
||||
|
||||
groupID := "g1"
|
||||
targetTenantID := "t2"
|
||||
relation := "manage"
|
||||
body, _ := json.Marshal(map[string]string{"tenantId": targetTenantID, "relation": relation})
|
||||
|
||||
mockSvc.On("AssignRoleToTenant", mock.Anything, groupID, targetTenantID, relation).Return(nil)
|
||||
|
||||
req := httptest.NewRequest("POST", "/user-groups/g1/roles", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
3702
baron-sso/backend/internal/handler/user_handler.go
Normal file
3702
baron-sso/backend/internal/handler/user_handler.go
Normal file
File diff suppressed because it is too large
Load Diff
2707
baron-sso/backend/internal/handler/user_handler_test.go
Normal file
2707
baron-sso/backend/internal/handler/user_handler_test.go
Normal file
File diff suppressed because it is too large
Load Diff
68
baron-sso/backend/internal/handler/user_handler_uuid_test.go
Normal file
68
baron-sso/backend/internal/handler/user_handler_uuid_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func TestUserHandler_BulkCreateUsers_UUIDImportPolicy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
field string
|
||||
}{
|
||||
{name: "id 필드 차단", field: "id"},
|
||||
{name: "uuid 필드 차단", field: "uuid"},
|
||||
{name: "userId 필드 차단", field: "userId"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockKratos := new(MockKratosAdmin)
|
||||
mockOry := new(MockOryProvider)
|
||||
|
||||
h := &UserHandler{
|
||||
KratosAdmin: mockKratos,
|
||||
OryProvider: mockOry,
|
||||
}
|
||||
app.Post("/users/bulk", h.BulkCreateUsers)
|
||||
|
||||
mockOry.On("GetPasswordPolicy").Return(&domain.PasswordPolicy{MinLength: 8}, nil).Once()
|
||||
|
||||
payload := map[string]any{
|
||||
"users": []map[string]any{
|
||||
{
|
||||
"email": "uuid-import@test.com",
|
||||
"name": "UUID Import User",
|
||||
tt.field: "550e8400-e29b-41d4-a716-446655440000",
|
||||
"metadata": map[string]any{},
|
||||
},
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
req := httptest.NewRequest("POST", "/users/bulk", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, _ := app.Test(req)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
var result struct {
|
||||
Results []bulkUserResult `json:"results"`
|
||||
}
|
||||
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&result))
|
||||
assert.Len(t, result.Results, 1)
|
||||
assert.False(t, result.Results[0].Success)
|
||||
assert.Contains(t, result.Results[0].Message, "사용자 UUID 가져오기는 지원하지 않습니다")
|
||||
|
||||
mockOry.AssertExpectations(t)
|
||||
mockKratos.AssertNotCalled(t, "FindIdentityIDByIdentifier", mock.Anything, mock.Anything)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/repository"
|
||||
"context"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
func markUserProjectionFailed(ctx context.Context, repo repository.UserProjectionRepository, syncErr error) {
|
||||
if repo == nil || syncErr == nil {
|
||||
return
|
||||
}
|
||||
if err := repo.MarkFailed(ctx, syncErr); err != nil {
|
||||
slog.Error("Failed to mark user projection as failed", "syncError", syncErr, "error", err)
|
||||
}
|
||||
}
|
||||
217
baron-sso/backend/internal/handler/worksmobile_handler.go
Normal file
217
baron-sso/backend/internal/handler/worksmobile_handler.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/csv"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type WorksmobileHandler struct {
|
||||
Service service.WorksmobileAdminService
|
||||
}
|
||||
|
||||
func NewWorksmobileHandler(svc service.WorksmobileAdminService) *WorksmobileHandler {
|
||||
return &WorksmobileHandler{Service: svc}
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) GetOverview(c *fiber.Ctx) error {
|
||||
overview, err := h.Service.GetTenantOverview(c.Context(), strings.TrimSpace(c.Params("tenantId")))
|
||||
if err != nil {
|
||||
return worksmobileGuardError(c, err, "get_overview")
|
||||
}
|
||||
if !worksmobileOverviewAllowed(overview) {
|
||||
return errorJSON(c, fiber.StatusNotFound, "worksmobile is only available for hanmac-family root tenant")
|
||||
}
|
||||
return c.JSON(overview)
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) GetComparison(c *fiber.Ctx) error {
|
||||
includeMatched := strings.EqualFold(strings.TrimSpace(c.Query("includeMatched")), "true")
|
||||
comparison, err := h.Service.GetComparison(c.Context(), strings.TrimSpace(c.Params("tenantId")), includeMatched)
|
||||
if err != nil {
|
||||
return worksmobileGuardError(c, err, "get_comparison")
|
||||
}
|
||||
return c.JSON(comparison)
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) OAuthCallback(c *fiber.Ctx) error {
|
||||
return c.Type("html").SendString("<!doctype html><html><body>Worksmobile OAuth callback reachable</body></html>")
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) BackfillDryRun(c *fiber.Ctx) error {
|
||||
result, err := h.Service.EnqueueBackfillDryRun(c.Context(), strings.TrimSpace(c.Params("tenantId")))
|
||||
if err != nil {
|
||||
return worksmobileGuardError(c, err, "backfill_dry_run")
|
||||
}
|
||||
return c.JSON(result)
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) SyncOrgUnit(c *fiber.Ctx) error {
|
||||
orgUnitID := strings.TrimSpace(c.Params("orgUnitId"))
|
||||
job, err := h.Service.EnqueueOrgUnitSync(c.Context(), strings.TrimSpace(c.Params("tenantId")), orgUnitID)
|
||||
if err != nil {
|
||||
return worksmobileGuardError(c, err, "sync_orgunit", "org_unit_id", orgUnitID)
|
||||
}
|
||||
return c.Status(fiber.StatusAccepted).JSON(job)
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) DeleteOrgUnit(c *fiber.Ctx) error {
|
||||
orgUnitID := strings.TrimSpace(c.Params("orgUnitId"))
|
||||
job, err := h.Service.EnqueueOrgUnitDelete(c.Context(), strings.TrimSpace(c.Params("tenantId")), orgUnitID)
|
||||
if err != nil {
|
||||
return worksmobileGuardError(c, err, "delete_orgunit", "org_unit_id", orgUnitID)
|
||||
}
|
||||
return c.Status(fiber.StatusAccepted).JSON(job)
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) SyncUser(c *fiber.Ctx) error {
|
||||
userID := strings.TrimSpace(c.Params("userId"))
|
||||
credentialRequest, err := parseWorksmobileCredentialRequest(c)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, err.Error())
|
||||
}
|
||||
job, err := h.Service.EnqueueUserSync(
|
||||
c.Context(),
|
||||
strings.TrimSpace(c.Params("tenantId")),
|
||||
userID,
|
||||
credentialRequest.CredentialBatchID,
|
||||
credentialRequest.InitialPassword,
|
||||
)
|
||||
if err != nil {
|
||||
return worksmobileGuardError(c, err, "sync_user", "user_id", userID)
|
||||
}
|
||||
return c.Status(fiber.StatusAccepted).JSON(job)
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) ResetUserPassword(c *fiber.Ctx) error {
|
||||
userID := strings.TrimSpace(c.Params("userId"))
|
||||
credentialBatchID, err := parseWorksmobileCredentialBatchID(c)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, err.Error())
|
||||
}
|
||||
job, err := h.Service.EnqueueUserPasswordReset(c.Context(), strings.TrimSpace(c.Params("tenantId")), userID, credentialBatchID)
|
||||
if err != nil {
|
||||
return worksmobileGuardError(c, err, "reset_user_password", "user_id", userID)
|
||||
}
|
||||
return c.Status(fiber.StatusAccepted).JSON(job)
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) RetryJob(c *fiber.Ctx) error {
|
||||
jobID := strings.TrimSpace(c.Params("jobId"))
|
||||
job, err := h.Service.RetryJob(c.Context(), strings.TrimSpace(c.Params("tenantId")), jobID)
|
||||
if err != nil {
|
||||
return worksmobileGuardError(c, err, "retry_job", "job_id", jobID)
|
||||
}
|
||||
return c.JSON(job)
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) DeletePendingJobs(c *fiber.Ctx) error {
|
||||
result, err := h.Service.DeletePendingJobs(c.Context(), strings.TrimSpace(c.Params("tenantId")))
|
||||
if err != nil {
|
||||
return worksmobileGuardError(c, err, "delete_pending_jobs")
|
||||
}
|
||||
return c.JSON(result)
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) DownloadInitialPasswordsCSV(c *fiber.Ctx) error {
|
||||
credentials, err := h.Service.ListInitialPasswordCredentials(c.Context(), strings.TrimSpace(c.Params("tenantId")), strings.TrimSpace(c.Query("batchId")))
|
||||
if err != nil {
|
||||
return worksmobileGuardError(c, err, "download_initial_passwords")
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer := csv.NewWriter(&buf)
|
||||
if err := writer.Write([]string{"email", "name", "primaryLeafOrgName", "initialPassword", "status", "lastError"}); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
for _, credential := range credentials {
|
||||
if err := writer.Write([]string{credential.Email, credential.Name, credential.PrimaryLeafOrgName, credential.InitialPassword, credential.Status, credential.LastError}); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
}
|
||||
writer.Flush()
|
||||
if err := writer.Error(); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
c.Set(fiber.HeaderContentType, "text/csv; charset=utf-8")
|
||||
c.Set(fiber.HeaderContentDisposition, `attachment; filename="worksmobile_initial_passwords.csv"`)
|
||||
return c.Send(buf.Bytes())
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) ListCredentialBatches(c *fiber.Ctx) error {
|
||||
batches, err := h.Service.ListCredentialBatches(c.Context(), strings.TrimSpace(c.Params("tenantId")))
|
||||
if err != nil {
|
||||
return worksmobileGuardError(c, err, "list_credential_batches")
|
||||
}
|
||||
return c.JSON(batches)
|
||||
}
|
||||
|
||||
func (h *WorksmobileHandler) DeleteCredentialBatchPasswords(c *fiber.Ctx) error {
|
||||
batchID := strings.TrimSpace(c.Params("batchId"))
|
||||
batch, err := h.Service.DeleteCredentialBatchPasswords(c.Context(), strings.TrimSpace(c.Params("tenantId")), batchID)
|
||||
if err != nil {
|
||||
return worksmobileGuardError(c, err, "delete_credential_batch_passwords", "batch_id", batchID)
|
||||
}
|
||||
return c.JSON(batch)
|
||||
}
|
||||
|
||||
type worksmobileCredentialBatchRequest struct {
|
||||
CredentialBatchID string `json:"credentialBatchId"`
|
||||
InitialPassword string `json:"initialPassword"`
|
||||
}
|
||||
|
||||
func parseWorksmobileCredentialBatchID(c *fiber.Ctx) (string, error) {
|
||||
req, err := parseWorksmobileCredentialRequest(c)
|
||||
return req.CredentialBatchID, err
|
||||
}
|
||||
|
||||
func parseWorksmobileCredentialRequest(c *fiber.Ctx) (worksmobileCredentialBatchRequest, error) {
|
||||
batchID := strings.TrimSpace(c.Query("credentialBatchId"))
|
||||
req := worksmobileCredentialBatchRequest{CredentialBatchID: batchID}
|
||||
if len(bytes.TrimSpace(c.Body())) == 0 {
|
||||
return req, nil
|
||||
}
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return worksmobileCredentialBatchRequest{}, err
|
||||
}
|
||||
req.InitialPassword = strings.TrimSpace(req.InitialPassword)
|
||||
if bodyBatchID := strings.TrimSpace(req.CredentialBatchID); bodyBatchID != "" {
|
||||
req.CredentialBatchID = bodyBatchID
|
||||
return req, nil
|
||||
}
|
||||
req.CredentialBatchID = batchID
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func worksmobileOverviewAllowed(overview service.WorksmobileTenantOverview) bool {
|
||||
return overview.Tenant.Slug == service.HanmacFamilyTenantSlug && overview.Tenant.ParentID == nil
|
||||
}
|
||||
|
||||
func worksmobileGuardError(c *fiber.Ctx, err error, operation string, attrs ...any) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
logAttrs := []any{
|
||||
"operation", operation,
|
||||
"tenant_id", strings.TrimSpace(c.Params("tenantId")),
|
||||
"path", c.Path(),
|
||||
"error", err,
|
||||
}
|
||||
logAttrs = append(logAttrs, attrs...)
|
||||
if errors.Is(err, context.Canceled) {
|
||||
slog.Warn("worksmobile admin operation failed", logAttrs...)
|
||||
return errorJSON(c, fiber.StatusRequestTimeout, err.Error())
|
||||
}
|
||||
slog.Error("worksmobile admin operation failed", logAttrs...)
|
||||
if strings.Contains(err.Error(), "hanmac-family root") {
|
||||
return errorJSON(c, fiber.StatusNotFound, err.Error())
|
||||
}
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
267
baron-sso/backend/internal/handler/worksmobile_handler_test.go
Normal file
267
baron-sso/backend/internal/handler/worksmobile_handler_test.go
Normal file
@@ -0,0 +1,267 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWorksmobileHandlerRejectsNonHanmacTenant(t *testing.T) {
|
||||
h := NewWorksmobileHandler(&fakeWorksmobileAdminService{
|
||||
overview: service.WorksmobileTenantOverview{
|
||||
Tenant: domain.Tenant{ID: "tenant-1", Slug: "other"},
|
||||
},
|
||||
})
|
||||
app := fiber.New()
|
||||
app.Get("/tenants/:tenantId/worksmobile", h.GetOverview)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/tenants/tenant-1/worksmobile", nil))
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestWorksmobileHandlerReturnsOverviewForHanmacTenant(t *testing.T) {
|
||||
h := NewWorksmobileHandler(&fakeWorksmobileAdminService{
|
||||
overview: service.WorksmobileTenantOverview{
|
||||
Tenant: domain.Tenant{ID: "hanmac-id", Slug: "hanmac-family"},
|
||||
Config: service.WorksmobileConfigSummary{
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
app := fiber.New()
|
||||
app.Get("/tenants/:tenantId/worksmobile", h.GetOverview)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/tenants/hanmac-id/worksmobile", nil))
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestWorksmobileHandlerDownloadsInitialPasswordCSV(t *testing.T) {
|
||||
h := NewWorksmobileHandler(&fakeWorksmobileAdminService{
|
||||
credentials: []service.WorksmobileInitialPasswordCredential{
|
||||
{
|
||||
Email: "user@hanmaceng.co.kr",
|
||||
Name: "홍길동",
|
||||
PrimaryLeafOrgName: "인재성장",
|
||||
InitialPassword: "Aa1!Aa1!Aa1!Aa1!",
|
||||
Status: "processed",
|
||||
},
|
||||
},
|
||||
})
|
||||
app := fiber.New()
|
||||
app.Get("/tenants/:tenantId/worksmobile/initial-passwords.csv", h.DownloadInitialPasswordsCSV)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/tenants/hanmac-id/worksmobile/initial-passwords.csv", nil))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
require.Contains(t, resp.Header.Get("Content-Disposition"), "worksmobile_initial_passwords.csv")
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(body), "email,name,primaryLeafOrgName,initialPassword,status,lastError")
|
||||
require.Contains(t, string(body), "user@hanmaceng.co.kr,홍길동,인재성장,Aa1!Aa1!Aa1!Aa1!,processed,")
|
||||
}
|
||||
|
||||
func TestWorksmobileHandlerPassesInitialPasswordBatchID(t *testing.T) {
|
||||
fakeService := &fakeWorksmobileAdminService{
|
||||
credentials: []service.WorksmobileInitialPasswordCredential{
|
||||
{Email: "batch-user@hanmaceng.co.kr", InitialPassword: "BatchPass1!", Status: "pending"},
|
||||
},
|
||||
}
|
||||
h := NewWorksmobileHandler(fakeService)
|
||||
app := fiber.New()
|
||||
app.Get("/tenants/:tenantId/worksmobile/initial-passwords.csv", h.DownloadInitialPasswordsCSV)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/tenants/hanmac-id/worksmobile/initial-passwords.csv?batchId=batch-1", nil))
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, "batch-1", fakeService.downloadCredentialBatchID)
|
||||
}
|
||||
|
||||
func TestWorksmobileHandlerPassesSyncUserCredentialBatchID(t *testing.T) {
|
||||
fakeService := &fakeWorksmobileAdminService{}
|
||||
h := NewWorksmobileHandler(fakeService)
|
||||
app := fiber.New()
|
||||
app.Post("/tenants/:tenantId/worksmobile/users/:userId/sync", h.SyncUser)
|
||||
|
||||
req := httptest.NewRequest("POST", "/tenants/hanmac-id/worksmobile/users/user-1/sync", strings.NewReader(`{"credentialBatchId":"batch-1","initialPassword":"InputPass1!"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := app.Test(req)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fiber.StatusAccepted, resp.StatusCode)
|
||||
require.Equal(t, "batch-1", fakeService.syncUserCredentialBatchID)
|
||||
require.Equal(t, "InputPass1!", fakeService.syncUserInitialPassword)
|
||||
}
|
||||
|
||||
func TestWorksmobileHandlerPassesPasswordResetCredentialBatchID(t *testing.T) {
|
||||
fakeService := &fakeWorksmobileAdminService{}
|
||||
h := NewWorksmobileHandler(fakeService)
|
||||
app := fiber.New()
|
||||
app.Post("/tenants/:tenantId/worksmobile/users/:userId/password/reset", h.ResetUserPassword)
|
||||
|
||||
req := httptest.NewRequest("POST", "/tenants/hanmac-id/worksmobile/users/user-1/password/reset", strings.NewReader(`{"credentialBatchId":"batch-1"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := app.Test(req)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fiber.StatusAccepted, resp.StatusCode)
|
||||
require.Equal(t, "batch-1", fakeService.resetPasswordCredentialBatchID)
|
||||
}
|
||||
|
||||
func TestWorksmobileHandlerReturnsCredentialBatchHistory(t *testing.T) {
|
||||
h := NewWorksmobileHandler(&fakeWorksmobileAdminService{
|
||||
credentialBatches: []service.WorksmobileCredentialBatch{
|
||||
{BatchID: "batch-1", UserCount: 2, HasPasswords: true},
|
||||
},
|
||||
})
|
||||
app := fiber.New()
|
||||
app.Get("/tenants/:tenantId/worksmobile/credential-batches", h.ListCredentialBatches)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/tenants/hanmac-id/worksmobile/credential-batches", nil))
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(body), `"batchId":"batch-1"`)
|
||||
require.Contains(t, string(body), `"userCount":2`)
|
||||
}
|
||||
|
||||
func TestWorksmobileHandlerDeletesCredentialBatchPasswords(t *testing.T) {
|
||||
fakeService := &fakeWorksmobileAdminService{}
|
||||
h := NewWorksmobileHandler(fakeService)
|
||||
app := fiber.New()
|
||||
app.Delete("/tenants/:tenantId/worksmobile/credential-batches/:batchId/passwords", h.DeleteCredentialBatchPasswords)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("DELETE", "/tenants/hanmac-id/worksmobile/credential-batches/batch-1/passwords", nil))
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, "batch-1", fakeService.deletedCredentialBatchID)
|
||||
}
|
||||
|
||||
func TestWorksmobileHandlerDeletesPendingJobs(t *testing.T) {
|
||||
fakeService := &fakeWorksmobileAdminService{
|
||||
pendingJobsDeleteResult: service.WorksmobilePendingJobDeleteResult{DeletedCount: 3},
|
||||
}
|
||||
h := NewWorksmobileHandler(fakeService)
|
||||
app := fiber.New()
|
||||
app.Delete("/tenants/:tenantId/worksmobile/jobs/pending", h.DeletePendingJobs)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("DELETE", "/tenants/hanmac-id/worksmobile/jobs/pending", nil))
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, "hanmac-id", fakeService.deletedPendingJobsTenantID)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(body), `"deletedCount":3`)
|
||||
}
|
||||
|
||||
func TestWorksmobileHandlerLogsActionFailures(t *testing.T) {
|
||||
var logs bytes.Buffer
|
||||
previous := slog.Default()
|
||||
slog.SetDefault(slog.New(slog.NewJSONHandler(&logs, nil)))
|
||||
t.Cleanup(func() {
|
||||
slog.SetDefault(previous)
|
||||
})
|
||||
|
||||
h := NewWorksmobileHandler(&fakeWorksmobileAdminService{
|
||||
syncUserErr: errors.New("works user sync failed"),
|
||||
})
|
||||
app := fiber.New()
|
||||
app.Post("/tenants/:tenantId/worksmobile/users/:userId/sync", h.SyncUser)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("POST", "/tenants/hanmac-id/worksmobile/users/user-1/sync", nil))
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
|
||||
require.Contains(t, logs.String(), "worksmobile admin operation failed")
|
||||
require.Contains(t, logs.String(), "sync_user")
|
||||
require.Contains(t, logs.String(), "works user sync failed")
|
||||
}
|
||||
|
||||
type fakeWorksmobileAdminService struct {
|
||||
overview service.WorksmobileTenantOverview
|
||||
credentials []service.WorksmobileInitialPasswordCredential
|
||||
syncUserErr error
|
||||
syncUserCredentialBatchID string
|
||||
syncUserInitialPassword string
|
||||
resetPasswordCredentialBatchID string
|
||||
downloadCredentialBatchID string
|
||||
deletedCredentialBatchID string
|
||||
deletedPendingJobsTenantID string
|
||||
pendingJobsDeleteResult service.WorksmobilePendingJobDeleteResult
|
||||
credentialBatches []service.WorksmobileCredentialBatch
|
||||
}
|
||||
|
||||
func (f *fakeWorksmobileAdminService) GetTenantOverview(ctx context.Context, tenantID string) (service.WorksmobileTenantOverview, error) {
|
||||
return f.overview, nil
|
||||
}
|
||||
|
||||
func (f *fakeWorksmobileAdminService) GetComparison(ctx context.Context, tenantID string, includeMatched bool) (service.WorksmobileComparison, error) {
|
||||
return service.WorksmobileComparison{}, nil
|
||||
}
|
||||
|
||||
func (f *fakeWorksmobileAdminService) EnqueueBackfillDryRun(ctx context.Context, tenantID string) (service.WorksmobileBackfillDryRun, error) {
|
||||
return service.WorksmobileBackfillDryRun{}, nil
|
||||
}
|
||||
|
||||
func (f *fakeWorksmobileAdminService) EnqueueOrgUnitSync(ctx context.Context, tenantID, orgUnitID string) (*domain.WorksmobileOutbox, error) {
|
||||
return &domain.WorksmobileOutbox{ID: "job-orgunit", ResourceID: orgUnitID}, nil
|
||||
}
|
||||
|
||||
func (f *fakeWorksmobileAdminService) EnqueueOrgUnitDelete(ctx context.Context, tenantID, orgUnitID string) (*domain.WorksmobileOutbox, error) {
|
||||
return &domain.WorksmobileOutbox{ID: "job-orgunit-delete", ResourceID: orgUnitID, Action: domain.WorksmobileActionDelete}, nil
|
||||
}
|
||||
|
||||
func (f *fakeWorksmobileAdminService) EnqueueUserSync(ctx context.Context, tenantID, userID, credentialBatchID, initialPassword string) (*domain.WorksmobileOutbox, error) {
|
||||
f.syncUserCredentialBatchID = credentialBatchID
|
||||
f.syncUserInitialPassword = initialPassword
|
||||
if f.syncUserErr != nil {
|
||||
return nil, f.syncUserErr
|
||||
}
|
||||
return &domain.WorksmobileOutbox{ID: "job-user", ResourceID: userID}, nil
|
||||
}
|
||||
|
||||
func (f *fakeWorksmobileAdminService) EnqueueUserPasswordReset(ctx context.Context, tenantID, userID, credentialBatchID string) (*domain.WorksmobileOutbox, error) {
|
||||
f.resetPasswordCredentialBatchID = credentialBatchID
|
||||
return &domain.WorksmobileOutbox{ID: "job-user-password-reset", ResourceID: userID}, nil
|
||||
}
|
||||
|
||||
func (f *fakeWorksmobileAdminService) RetryJob(ctx context.Context, tenantID, jobID string) (*domain.WorksmobileOutbox, error) {
|
||||
return &domain.WorksmobileOutbox{ID: jobID}, nil
|
||||
}
|
||||
|
||||
func (f *fakeWorksmobileAdminService) ListInitialPasswordCredentials(ctx context.Context, tenantID, credentialBatchID string) ([]service.WorksmobileInitialPasswordCredential, error) {
|
||||
f.downloadCredentialBatchID = credentialBatchID
|
||||
return f.credentials, nil
|
||||
}
|
||||
|
||||
func (f *fakeWorksmobileAdminService) ListCredentialBatches(ctx context.Context, tenantID string) ([]service.WorksmobileCredentialBatch, error) {
|
||||
return f.credentialBatches, nil
|
||||
}
|
||||
|
||||
func (f *fakeWorksmobileAdminService) DeleteCredentialBatchPasswords(ctx context.Context, tenantID, credentialBatchID string) (service.WorksmobileCredentialBatch, error) {
|
||||
f.deletedCredentialBatchID = credentialBatchID
|
||||
return service.WorksmobileCredentialBatch{BatchID: credentialBatchID}, nil
|
||||
}
|
||||
|
||||
func (f *fakeWorksmobileAdminService) DeletePendingJobs(ctx context.Context, tenantID string) (service.WorksmobilePendingJobDeleteResult, error) {
|
||||
f.deletedPendingJobsTenantID = tenantID
|
||||
return f.pendingJobsDeleteResult, nil
|
||||
}
|
||||
Reference in New Issue
Block a user