1
0
forked from baron/baron-sso

refactor: backend tenant_group 제거 및 리팩터 반영

This commit is contained in:
Lectom C Han
2026-02-12 22:14:34 +09:00
parent b0792113ae
commit a8a219d7ef
26 changed files with 494 additions and 1001 deletions

View File

@@ -245,9 +245,7 @@ func main() {
// 2. Initialize Handlers // 2. Initialize Handlers
tenantRepo := repository.NewTenantRepository(db) tenantRepo := repository.NewTenantRepository(db)
tenantGroupRepo := repository.NewTenantGroupRepository(db)
tenantService := service.NewTenantService(tenantRepo) tenantService := service.NewTenantService(tenantRepo)
tenantGroupService := service.NewTenantGroupService(tenantGroupRepo, ketoService)
tenantService.SetKetoService(ketoService) // Keto 주입 tenantService.SetKetoService(ketoService) // Keto 주입
userRepo := repository.NewUserRepository(db) userRepo := repository.NewUserRepository(db)
// relyingPartyRepo removed as SSOT is now Hydra+Keto // relyingPartyRepo removed as SSOT is now Hydra+Keto
@@ -256,16 +254,14 @@ func main() {
secretRepo := repository.NewClientSecretRepository(db) secretRepo := repository.NewClientSecretRepository(db)
consentRepo := repository.NewClientConsentRepository(db) consentRepo := repository.NewClientConsentRepository(db)
kratosAdminService := service.NewKratosAdminService()
oryAdminProvider := service.NewOryProvider()
auditHandler := handler.NewAuditHandler(auditRepo) auditHandler := handler.NewAuditHandler(auditRepo)
authHandler := handler.NewAuthHandler(redisService, idpProvider, auditRepo, oathkeeperRepo, tenantService, ketoService, userRepo, consentRepo) authHandler := handler.NewAuthHandler(redisService, idpProvider, auditRepo, oathkeeperRepo, tenantService, ketoService, userRepo, consentRepo)
adminHandler := handler.NewAdminHandler(ketoService) adminHandler := handler.NewAdminHandler()
devHandler := handler.NewDevHandler(redisService, secretRepo, consentRepo, relyingPartyService) devHandler := handler.NewDevHandler(redisService, secretRepo, consentRepo)
tenantHandler := handler.NewTenantHandler(db, tenantService, ketoService, kratosAdminService) tenantHandler := handler.NewTenantHandler(db, tenantService)
tenantGroupHandler := handler.NewTenantGroupHandler(tenantGroupService, kratosAdminService) relyingPartyHandler := handler.NewRelyingPartyHandler(relyingPartyService)
relyingPartyHandler := handler.NewRelyingPartyHandler(relyingPartyService, kratosAdminService) kratosAdminService := service.NewKratosAdminService()
oryAdminProvider := service.NewOryProvider()
userHandler := handler.NewUserHandler(kratosAdminService, oryAdminProvider, tenantService, ketoService, userRepo) userHandler := handler.NewUserHandler(kratosAdminService, oryAdminProvider, tenantService, ketoService, userRepo)
apiKeyHandler := handler.NewApiKeyHandler(db) apiKeyHandler := handler.NewApiKeyHandler(db)
@@ -489,7 +485,6 @@ func main() {
// Auth Proxy Routes // Auth Proxy Routes
auth := api.Group("/auth") auth := api.Group("/auth")
auth.All("/oidc/*", authHandler.ProxyOidc)
auth.Post("/enchanted-link/init", authHandler.InitEnchantedLink) auth.Post("/enchanted-link/init", authHandler.InitEnchantedLink)
auth.Post("/enchanted-link/poll", authHandler.PollEnchantedLink) auth.Post("/enchanted-link/poll", authHandler.PollEnchantedLink)
auth.Post("/magic-link/verify", authHandler.VerifyMagicLink) auth.Post("/magic-link/verify", authHandler.VerifyMagicLink)
@@ -554,14 +549,13 @@ func main() {
KetoService: ketoService, KetoService: ketoService,
}) })
requireAdmin := middleware.RequireRole(middleware.RBACConfig{ requireAdmin := middleware.RequireRole(middleware.RBACConfig{
AllowedRoles: []string{domain.RoleSuperAdmin, domain.RoleTenantAdmin, domain.RoleRPAdmin}, AllowedRoles: []string{domain.RoleSuperAdmin, domain.RoleTenantAdmin},
AuthHandler: authHandler, AuthHandler: authHandler,
KetoService: ketoService, KetoService: ketoService,
}) })
admin.Get("/check", adminHandler.CheckAuth) // 기본 Admin 체크는 requireAdmin 없이 ApiKeyAuth로만 보호될 수 있음 (또는 추가 가능) admin.Get("/check", adminHandler.CheckAuth) // 기본 Admin 체크는 requireAdmin 없이 ApiKeyAuth로만 보호될 수 있음 (또는 추가 가능)
admin.Get("/stats", requireSuperAdmin, adminHandler.GetSystemStats) admin.Get("/stats", requireSuperAdmin, adminHandler.GetSystemStats)
admin.Get("/debug/check-permission", requireSuperAdmin, adminHandler.CheckPermission)
// Tenant Management (Super Admin Only) // Tenant Management (Super Admin Only)
admin.Get("/tenants", requireSuperAdmin, tenantHandler.ListTenants) admin.Get("/tenants", requireSuperAdmin, tenantHandler.ListTenants)
@@ -570,27 +564,9 @@ func main() {
admin.Get("/tenants/:id", requireAdmin, middleware.RequireKetoPermission(middleware.RBACConfig{AuthHandler: authHandler, KetoService: ketoService}, "Tenant", "view"), tenantHandler.GetTenant) admin.Get("/tenants/:id", requireAdmin, middleware.RequireKetoPermission(middleware.RBACConfig{AuthHandler: authHandler, KetoService: ketoService}, "Tenant", "view"), tenantHandler.GetTenant)
admin.Put("/tenants/:id", requireSuperAdmin, tenantHandler.UpdateTenant) admin.Put("/tenants/:id", requireSuperAdmin, tenantHandler.UpdateTenant)
admin.Delete("/tenants/:id", requireSuperAdmin, tenantHandler.DeleteTenant) admin.Delete("/tenants/:id", requireSuperAdmin, tenantHandler.DeleteTenant)
admin.Get("/tenants/:id/admins", requireAdmin, middleware.RequireKetoPermission(middleware.RBACConfig{AuthHandler: authHandler, KetoService: ketoService}, "Tenant", "manage"), tenantHandler.ListAdmins)
admin.Post("/tenants/:id/admins/:userId", requireAdmin, middleware.RequireKetoPermission(middleware.RBACConfig{AuthHandler: authHandler, KetoService: ketoService}, "Tenant", "manage"), tenantHandler.AddAdmin)
admin.Delete("/tenants/:id/admins/:userId", requireAdmin, middleware.RequireKetoPermission(middleware.RBACConfig{AuthHandler: authHandler, KetoService: ketoService}, "Tenant", "manage"), tenantHandler.RemoveAdmin)
// Tenant Group Management (Super Admin Only)
admin.Get("/tenant-groups", requireSuperAdmin, tenantGroupHandler.ListGroups)
admin.Post("/tenant-groups", requireSuperAdmin, tenantGroupHandler.CreateGroup)
admin.Get("/tenant-groups/:id", requireSuperAdmin, tenantGroupHandler.GetGroup)
admin.Put("/tenant-groups/:id", requireSuperAdmin, tenantGroupHandler.UpdateGroup)
admin.Delete("/tenant-groups/:id", requireSuperAdmin, tenantGroupHandler.DeleteGroup)
admin.Post("/tenant-groups/:id/tenants/:tenantId", requireSuperAdmin, tenantGroupHandler.AddTenantToGroup)
admin.Delete("/tenant-groups/:id/tenants/:tenantId", requireSuperAdmin, tenantGroupHandler.RemoveTenantFromGroup)
admin.Get("/tenant-groups/:id/admins", requireSuperAdmin, tenantGroupHandler.ListAdmins)
admin.Post("/tenant-groups/:id/admins/:userId", requireSuperAdmin, tenantGroupHandler.AddAdmin)
admin.Delete("/tenant-groups/:id/admins/:userId", requireSuperAdmin, tenantGroupHandler.RemoveAdmin)
// Relying Party Management (Global List) // Relying Party Management (Global List)
admin.Get("/relying-parties", requireAdmin, relyingPartyHandler.ListAll) admin.Get("/relying-parties", requireAdmin, relyingPartyHandler.ListAll)
admin.Get("/relying-parties/:id/owners", requireAdmin, middleware.RequireKetoPermission(middleware.RBACConfig{AuthHandler: authHandler, KetoService: ketoService}, "RelyingParty", "manage"), relyingPartyHandler.ListOwners)
admin.Post("/relying-parties/:id/owners/:subject", requireAdmin, middleware.RequireKetoPermission(middleware.RBACConfig{AuthHandler: authHandler, KetoService: ketoService}, "RelyingParty", "manage"), relyingPartyHandler.AddOwner)
admin.Delete("/relying-parties/:id/owners/:subject", requireAdmin, middleware.RequireKetoPermission(middleware.RBACConfig{AuthHandler: authHandler, KetoService: ketoService}, "RelyingParty", "manage"), relyingPartyHandler.RemoveOwner)
// Relying Party Management (Tenant Context) // Relying Party Management (Tenant Context)
admin.Post("/tenants/:tenantId/relying-parties", admin.Post("/tenants/:tenantId/relying-parties",
@@ -631,24 +607,14 @@ func main() {
admin.Delete("/api-keys/:id", requireSuperAdmin, apiKeyHandler.DeleteApiKey) admin.Delete("/api-keys/:id", requireSuperAdmin, apiKeyHandler.DeleteApiKey)
// 개발자 포털 라우트 (RP/Consent 관리 및 IdP 설정) // 개발자 포털 라우트 (RP/Consent 관리 및 IdP 설정)
dev := api.Group("/dev", requireAdmin) dev := api.Group("/dev")
dev.Get("/clients", devHandler.ListClients) dev.Get("/clients", devHandler.ListClients)
dev.Post("/clients", devHandler.CreateClient) dev.Post("/clients", devHandler.CreateClient)
dev.Get("/clients/:id", dev.Get("/clients/:id", devHandler.GetClient)
middleware.RequireKetoPermission(middleware.RBACConfig{AuthHandler: authHandler, KetoService: ketoService}, "RelyingParty", "view"), dev.Put("/clients/:id", devHandler.UpdateClient)
devHandler.GetClient) dev.Post("/clients/:id/secret/rotate", devHandler.RotateClientSecret)
dev.Put("/clients/:id", dev.Patch("/clients/:id/status", devHandler.UpdateClientStatus)
middleware.RequireKetoPermission(middleware.RBACConfig{AuthHandler: authHandler, KetoService: ketoService}, "RelyingParty", "manage"), dev.Delete("/clients/:id", devHandler.DeleteClient)
devHandler.UpdateClient)
dev.Post("/clients/:id/secret/rotate",
middleware.RequireKetoPermission(middleware.RBACConfig{AuthHandler: authHandler, KetoService: ketoService}, "RelyingParty", "manage"),
devHandler.RotateClientSecret)
dev.Patch("/clients/:id/status",
middleware.RequireKetoPermission(middleware.RBACConfig{AuthHandler: authHandler, KetoService: ketoService}, "RelyingParty", "manage"),
devHandler.UpdateClientStatus)
dev.Delete("/clients/:id",
middleware.RequireKetoPermission(middleware.RBACConfig{AuthHandler: authHandler, KetoService: ketoService}, "RelyingParty", "manage"),
devHandler.DeleteClient)
dev.Get("/consents", devHandler.ListConsents) dev.Get("/consents", devHandler.ListConsents)
dev.Delete("/consents", devHandler.RevokeConsents) dev.Delete("/consents", devHandler.RevokeConsents)

View File

@@ -31,7 +31,6 @@ func migrateSchemas(db *gorm.DB) error {
slog.Info("[Bootstrap] Migrating database schemas...") slog.Info("[Bootstrap] Migrating database schemas...")
// Add all domain models here // Add all domain models here
return db.AutoMigrate( return db.AutoMigrate(
&domain.TenantGroup{},
&domain.Tenant{}, &domain.Tenant{},
&domain.TenantDomain{}, &domain.TenantDomain{},
&domain.User{}, &domain.User{},

View File

@@ -25,18 +25,6 @@ func SyncKetoRelations(db *gorm.DB, keto service.KetoService) error {
if t.ParentID != nil { if t.ParentID != nil {
_ = keto.CreateRelation(ctx, "Tenant", t.ID, "parent", *t.ParentID) _ = keto.CreateRelation(ctx, "Tenant", t.ID, "parent", *t.ParentID)
} }
if t.TenantGroupID != nil {
_ = keto.CreateRelation(ctx, "Tenant", t.ID, "parent_group", *t.TenantGroupID)
}
}
// 1.1 Sync Tenant Groups (Group Admins)
var groups []domain.TenantGroup
if err := db.Find(&groups).Error; err == nil {
slog.Info("Syncing tenant groups to Keto", "count", len(groups))
for range groups {
// 그룹 관리자 개념 확정 후 관계 생성 로직 추가 예정
}
} }
// 2. Sync All Users // 2. Sync All Users

View File

@@ -34,7 +34,6 @@ func SeedAdminIdentity(idp domain.IdentityProvider) error {
"affiliationType": "internal", "affiliationType": "internal",
"companyCode": "", "companyCode": "",
"grade": "admin", "grade": "admin",
"role": domain.RoleSuperAdmin,
}, },
} }

View File

@@ -68,19 +68,18 @@ type SignupRequest struct {
// User Profile Models // User Profile Models
type UserProfileResponse struct { type UserProfileResponse struct {
ID string `json:"id"` ID string `json:"id"`
Email string `json:"email"` Email string `json:"email"`
Name string `json:"name"` Name string `json:"name"`
Phone string `json:"phone"` Phone string `json:"phone"`
Role string `json:"role"` // 추가 Role string `json:"role"` // 추가
Department string `json:"department"` Department string `json:"department"`
AffiliationType string `json:"affiliationType"` AffiliationType string `json:"affiliationType"`
CompanyCode string `json:"companyCode,omitempty"` CompanyCode string `json:"companyCode,omitempty"`
TenantID *string `json:"tenantId,omitempty"` // 추가 TenantID *string `json:"tenantId,omitempty"` // 추가
RelyingPartyID *string `json:"relyingPartyId,omitempty"` // 추가 RelyingPartyID *string `json:"relyingPartyId,omitempty"` // 추가
Metadata map[string]any `json:"metadata,omitempty"` Metadata map[string]any `json:"metadata,omitempty"`
Tenant *Tenant `json:"tenant,omitempty"` Tenant *Tenant `json:"tenant,omitempty"`
ManageableTenants []Tenant `json:"manageableTenants,omitempty"` // 추가: 관리 가능한 테넌트 목록
} }
type UpdateUserRequest struct { type UpdateUserRequest struct {

View File

@@ -17,47 +17,23 @@ const (
// Tenant represents a tenant model stored in PostgreSQL. // Tenant represents a tenant model stored in PostgreSQL.
type Tenant struct { type Tenant struct {
ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"` ID string `gorm:"primaryKey;type:uuid;default:gen_random_uuid()" json:"id"`
ParentID *string `gorm:"type:uuid;index" json:"parentId,omitempty"` // 부모 테넌트 ID ParentID *string `gorm:"type:uuid;index" json:"parentId,omitempty"` // 부모 테넌트 ID
TenantGroupID *string `gorm:"type:uuid;index" json:"tenantGroupId,omitempty"` Name string `gorm:"not null" json:"name"`
TenantGroup *TenantGroup `gorm:"foreignKey:TenantGroupID" json:"tenantGroup,omitempty"` Slug string `gorm:"uniqueIndex;not null" json:"slug"`
Name string `gorm:"not null" json:"name"` Description string `json:"description"`
Slug string `gorm:"uniqueIndex;not null" json:"slug"` Status string `gorm:"default:'pending'" json:"status"`
Description string `json:"description"` Domains []TenantDomain `gorm:"foreignKey:TenantID" json:"domains,omitempty"`
Status string `gorm:"default:'pending'" json:"status"` Config JSONMap `gorm:"type:jsonb" json:"config,omitempty"`
Domains []TenantDomain `gorm:"foreignKey:TenantID" json:"domains,omitempty"` CreatedAt time.Time `json:"createdAt"`
Config JSONMap `gorm:"type:jsonb" json:"config,omitempty"` UpdatedAt time.Time `json:"updatedAt"`
CreatedAt time.Time `json:"createdAt"` DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
UpdatedAt time.Time `json:"updatedAt"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
} }
func (t *Tenant) IsActive() bool { func (t *Tenant) IsActive() bool {
return t.Status == TenantStatusActive return t.Status == TenantStatusActive
} }
// GetMergedConfig merges the group-level config with tenant-level config.
// Tenant config takes precedence.
func (t *Tenant) GetMergedConfig() JSONMap {
merged := make(JSONMap)
// 1. Apply Group Config (Base)
if t.TenantGroup != nil && t.TenantGroup.Config != nil {
for k, v := range t.TenantGroup.Config {
merged[k] = v
}
}
// 2. Apply Tenant Config (Overrides)
if t.Config != nil {
for k, v := range t.Config {
merged[k] = v
}
}
return merged
}
// BeforeCreate hook to generate UUID if not present. // BeforeCreate hook to generate UUID if not present.
func (t *Tenant) BeforeCreate(tx *gorm.DB) (err error) { func (t *Tenant) BeforeCreate(tx *gorm.DB) (err error) {
if t.ID == "" { if t.ID == "" {

View File

@@ -1,51 +1,22 @@
package handler package handler
import ( import (
"baron-sso-backend/internal/service"
"runtime" "runtime"
"time" "time"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
) )
type AdminHandler struct { type AdminHandler struct{}
Keto service.KetoService
}
func NewAdminHandler(keto service.KetoService) *AdminHandler { func NewAdminHandler() *AdminHandler {
return &AdminHandler{Keto: keto} return &AdminHandler{}
} }
func (h *AdminHandler) CheckAuth(c *fiber.Ctx) error { func (h *AdminHandler) CheckAuth(c *fiber.Ctx) error {
return c.Status(fiber.StatusOK).JSON(fiber.Map{"status": "ok"}) return c.Status(fiber.StatusOK).JSON(fiber.Map{"status": "ok"})
} }
func (h *AdminHandler) CheckPermission(c *fiber.Ctx) error {
namespace := c.Query("namespace")
object := c.Query("object")
relation := c.Query("relation")
subject := c.Query("subject")
if namespace == "" || object == "" || relation == "" || subject == "" {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "namespace, object, relation, and subject are required"})
}
allowed, err := h.Keto.CheckPermission(c.Context(), subject, namespace, object, relation)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(fiber.Map{
"allowed": allowed,
"query": fiber.Map{
"namespace": namespace,
"object": object,
"relation": relation,
"subject": subject,
},
})
}
// GetSystemStats returns runtime statistics for monitoring // GetSystemStats returns runtime statistics for monitoring
func (h *AdminHandler) GetSystemStats(c *fiber.Ctx) error { func (h *AdminHandler) GetSystemStats(c *fiber.Ctx) error {
var m runtime.MemStats var m runtime.MemStats

View File

@@ -125,11 +125,10 @@ func GenerateSecureAlnumToken(length int) string {
func GenerateUserCode() string { func GenerateUserCode() string {
const letters = "ABCDEFGHJKLMNPQRSTUVWXYZ" const letters = "ABCDEFGHJKLMNPQRSTUVWXYZ"
// [Fixed] 요청하신 포맷 (영문 2자리 + 숫자 6자리, 하이픈 없음)으로 변경 return fmt.Sprintf("%c%c-%03d",
return fmt.Sprintf("%c%c%06d",
letters[rand.Intn(len(letters))], letters[rand.Intn(len(letters))],
letters[rand.Intn(len(letters))], letters[rand.Intn(len(letters))],
rand.Intn(1000000), rand.Intn(1000),
) )
} }
@@ -455,7 +454,8 @@ func (h *AuthHandler) Signup(c *fiber.Ctx) error {
slog.Info("[Signup] New user registered", "email", req.Email, "type", req.AffiliationType, "provider", h.IdpProvider.Name(), "subject", providerID) slog.Info("[Signup] New user registered", "email", req.Email, "type", req.AffiliationType, "provider", h.IdpProvider.Name(), "subject", providerID)
// [New] Local DB Sync // [SoT Policy] Kratos가 SoT이므로 로컬 DB 저장은 비동기 Read-Model 동기화로 처리합니다.
// 로컬 DB 저장이 실패하더라도 회원가입 프로세스는 성공으로 간주합니다.
localUser := &domain.User{ localUser := &domain.User{
ID: providerID, // Match IDP Subject ID: providerID, // Match IDP Subject
Email: req.Email, Email: req.Email,
@@ -471,9 +471,17 @@ func (h *AuthHandler) Signup(c *fiber.Ctx) error {
} }
if h.UserRepo != nil { if h.UserRepo != nil {
if err := h.UserRepo.Create(c.Context(), localUser); err != nil { go func(u *domain.User) {
slog.Error("[Signup] Failed to sync user to local DB", "email", req.Email, "error", err) // 요청 Context가 취소될 수 있으므로 Background Context 사용
} ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.UserRepo.Create(ctx, u); err != nil {
slog.Error("[Signup] Failed to sync user to Read-Model (Local DB)", "email", u.Email, "error", err)
} else {
slog.Debug("[Signup] Synced user to Read-Model", "email", u.Email)
}
}(localUser)
} }
// [Keto] Sync user-tenant relationship // [Keto] Sync user-tenant relationship
@@ -959,20 +967,13 @@ func (h *AuthHandler) InitEnchantedLink(c *fiber.Ctx) error {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "Identity provider unavailable"}) return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{"error": "Identity provider unavailable"})
} }
// [Changed] 토큰 길이를 사용자의 요청에 맞춰 6글자(3바이트)로, pendingRef를 8글자(4바이트)로 조정
userCode := GenerateUserCode() userCode := GenerateUserCode()
token := GenerateSecureToken(3) token := GenerateSecureToken(3)
pendingRef := GenerateSecureToken(3) pendingRef := GenerateSecureToken(3)
slog.Info("[Enchanted] Initiating enchanted link", "loginID", loginID, "token", token, "pendingRef", pendingRef) slog.Info("[Enchanted] Initiating enchanted link", "loginID", loginID, "token", token, "pendingRef", pendingRef)
// [Added] 사용자가 입력할 간편 코드를 Redis에 저장합니다. (이게 없으면 인증이 안 됩니다)
shortCodePayload, _ := json.Marshal(shortLoginCodePayload{
LoginID: lookupLoginID,
Code: token,
PendingRef: pendingRef,
})
h.RedisService.Set(prefixLoginCodeShort+userCode, string(shortCodePayload), defaultExpiration)
// Store in Redis // Store in Redis
sessionData, _ := json.Marshal(map[string]string{ sessionData, _ := json.Marshal(map[string]string{
"status": statusPending, "status": statusPending,
@@ -1026,13 +1027,12 @@ func (h *AuthHandler) InitEnchantedLink(c *fiber.Ctx) error {
} }
} else { } else {
// Send SMS // Send SMS
phone := sanitizePhoneForSms(loginID) content := fmt.Sprintf("[Baron 로그인] 로그인 링크: %s | 코드: %s", link, userCode)
content := fmt.Sprintf("[Baron 로그인] 로그인 링크: %s | 간편 코드: %s", link, userCode)
if drySend { if drySend {
slog.Info("[Enchanted][DrySend] SMS send skipped", "loginID", phone, "content", content) slog.Info("[Enchanted][DrySend] SMS send skipped", "loginID", loginID, "content", content)
} else { } else {
slog.Info("[Enchanted] Sending SMS via Naver Cloud", "to", phone) slog.Info("[Enchanted] Sending SMS via Naver Cloud", "loginID", loginID)
if err := h.SmsService.SendSms(phone, content); err != nil { if err := h.SmsService.SendSms(loginID, content); err != nil {
slog.Error("[Enchanted] SMS Failed", "error", err) slog.Error("[Enchanted] SMS Failed", "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to send SMS"}) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to send SMS"})
} }
@@ -1526,7 +1526,7 @@ func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error {
loginID := strings.TrimSpace(req.LoginID) loginID := strings.TrimSpace(req.LoginID)
ale.LoginIDs["loginId"] = req.LoginID // 원문 ale.LoginIDs["loginId"] = req.LoginID // 원문
ale.LoginIDs["loginId_normalized"] = loginID ale.LoginIDs["loginId_normalized"] = loginID
// ale.NewPassword = req.Password // For test only, logging password (sensitive) ale.NewPassword = req.Password // For test only, logging password (sensitive)
ale.Log(slog.LevelInfo, "Attempting to login") ale.Log(slog.LevelInfo, "Attempting to login")
@@ -1568,25 +1568,22 @@ func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error {
// --- OIDC 로그인 흐름 처리 --- // --- OIDC 로그인 흐름 처리 ---
if req.LoginChallenge != "" { if req.LoginChallenge != "" {
slog.Info("OIDC login flow detected", "challenge", req.LoginChallenge, "subject", subject) slog.Info("OIDC login flow detected", "challenge", req.LoginChallenge)
// Check if the client is active // Check if the client is active
loginReq, err := h.Hydra.GetLoginRequest(c.Context(), req.LoginChallenge) loginReq, err := h.Hydra.GetLoginRequest(c.Context(), req.LoginChallenge)
if err == nil && loginReq != nil { if err == nil && loginReq != nil && loginReq.Client.Metadata != nil {
slog.Info("OIDC Client Info", "client_id", loginReq.Client.ClientID, "name", loginReq.Client.ClientName) if status, ok := loginReq.Client.Metadata["status"].(string); ok {
if loginReq.Client.Metadata != nil { if strings.ToLower(status) == "inactive" {
if status, ok := loginReq.Client.Metadata["status"].(string); ok { slog.Warn("Login rejected for inactive client in PasswordLogin", "client_id", loginReq.Client.ClientID)
if strings.ToLower(status) == "inactive" { return fiber.NewError(fiber.StatusForbidden, "The client application is disabled.")
slog.Warn("Login rejected for inactive client in PasswordLogin", "client_id", loginReq.Client.ClientID)
return fiber.NewError(fiber.StatusForbidden, "The client application is disabled.")
}
} }
} }
} }
acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), req.LoginChallenge, subject) acceptResp, err := h.Hydra.AcceptLoginRequest(c.Context(), req.LoginChallenge, subject)
if err != nil { if err != nil {
slog.Error("failed to accept hydra login request", "error", err, "challenge", req.LoginChallenge) slog.Error("failed to accept hydra login request", "error", err)
return fiber.NewError(fiber.StatusInternalServerError, "Failed to accept OIDC login request") return fiber.NewError(fiber.StatusInternalServerError, "Failed to accept OIDC login request")
} }
slog.Info("Hydra login request accepted", "redirectTo", acceptResp.RedirectTo) slog.Info("Hydra login request accepted", "redirectTo", acceptResp.RedirectTo)
@@ -1597,13 +1594,12 @@ func (h *AuthHandler) PasswordLogin(c *fiber.Ctx) error {
// --- OIDC 로그인 흐름 처리 끝 --- // --- OIDC 로그인 흐름 처리 끝 ---
resp := fiber.Map{ resp := fiber.Map{
"sessionToken": authInfo.SessionToken.JWT, "sessionJwt": authInfo.SessionToken.JWT,
"sessionJwt": authInfo.SessionToken.JWT, // Frontend compatibility "status": "ok",
"status": "ok", "provider": h.IdpProvider.Name(),
"provider": h.IdpProvider.Name(),
} }
if authInfo.RefreshToken != nil { if authInfo.RefreshToken != nil {
resp["refreshToken"] = authInfo.RefreshToken.JWT resp["refreshJwt"] = authInfo.RefreshToken.JWT
} }
if authInfo.Subject != "" { if authInfo.Subject != "" {
resp["subject"] = authInfo.Subject resp["subject"] = authInfo.Subject
@@ -2079,16 +2075,6 @@ type kratosCourierRequest struct {
Body string `json:"body"` Body string `json:"body"`
} }
// sanitizePhoneForSms - 네이버 SMS 등 국내 발송기를 위해 +82 형식을 010 형식으로 변환합니다.
func sanitizePhoneForSms(phone string) string {
p := strings.ReplaceAll(phone, "-", "")
p = strings.ReplaceAll(p, " ", "")
if strings.HasPrefix(p, "+82") {
return "0" + p[3:]
}
return p
}
// HandleKratosCourierRelay - Kratos courier HTTP 요청을 받아 메일/SMS 발송으로 변환합니다. // HandleKratosCourierRelay - Kratos courier HTTP 요청을 받아 메일/SMS 발송으로 변환합니다.
func (h *AuthHandler) HandleKratosCourierRelay(c *fiber.Ctx) error { func (h *AuthHandler) HandleKratosCourierRelay(c *fiber.Ctx) error {
var req kratosCourierRequest var req kratosCourierRequest
@@ -2467,6 +2453,16 @@ func extractFirstString(data map[string]interface{}, keys ...string) string {
return "" return ""
} }
func sanitizePhoneForSms(phone string) string {
sanitized := strings.TrimSpace(phone)
if strings.HasPrefix(sanitized, "+82") {
sanitized = "0" + sanitized[3:]
}
sanitized = strings.ReplaceAll(sanitized, "-", "")
sanitized = strings.ReplaceAll(sanitized, " ", "")
return sanitized
}
// --- User Profile Handlers --- // --- User Profile Handlers ---
func (h *AuthHandler) formatPhoneForDisplay(phone string) string { func (h *AuthHandler) formatPhoneForDisplay(phone string) string {
@@ -2484,56 +2480,7 @@ func (h *AuthHandler) formatPhoneForStorage(phone string) string {
return phone return phone
} }
// ProxyOidc - 프론트엔드의 OIDC 요청을 내부 Hydra 서비스로 프록시합니다. // GetMe - Returns current user's profile with enriched data from local DB
func (h *AuthHandler) ProxyOidc(c *fiber.Ctx) error {
path := c.Params("*")
// [Strict] Always use internal Docker network address for proxying to avoid external loops
targetURL := "http://hydra:4444"
// 프록시 URL 구성
u, err := url.Parse(targetURL)
if err != nil {
return fiber.NewError(fiber.StatusInternalServerError, "invalid hydra public url")
}
u.Path = strings.TrimRight(u.Path, "/") + "/" + path
u.RawQuery = string(c.Request().URI().QueryString())
slog.Debug("Proxying OIDC request", "from", c.Path(), "to", u.String())
// 요청 준비
req, err := http.NewRequestWithContext(c.Context(), c.Method(), u.String(), bytes.NewReader(c.Body()))
if err != nil {
return fiber.NewError(fiber.StatusInternalServerError, "failed to create proxy request")
}
// 헤더 복사
c.Request().Header.VisitAll(func(key, value []byte) {
k := string(key)
if k != "Host" && k != "Connection" {
req.Header.Add(k, string(value))
}
})
// 요청 실행 (Hydra 내부 HttpClient 사용)
resp, err := h.Hydra.HttpClient().Do(req)
if err != nil {
return fiber.NewError(fiber.StatusServiceUnavailable, "hydra public api unavailable")
}
defer resp.Body.Close()
// 응답 헤더 복사
for k, values := range resp.Header {
for _, v := range values {
c.Set(k, v)
}
}
// 상태 코드 및 바디 설정
c.Status(resp.StatusCode)
_, err = io.Copy(c.Response().BodyWriter(), resp.Body)
return err
}
func (h *AuthHandler) GetMe(c *fiber.Ctx) error { func (h *AuthHandler) GetMe(c *fiber.Ctx) error {
profile, err := h.resolveCurrentProfile(c) profile, err := h.resolveCurrentProfile(c)
if err != nil { if err != nil {
@@ -4006,13 +3953,6 @@ func (h *AuthHandler) resolveCurrentProfile(c *fiber.Ctx) (*domain.UserProfileRe
} }
} }
// Fetch Manageable Tenants for Admins
if profile.Role == domain.RoleSuperAdmin || profile.Role == domain.RoleTenantAdmin || profile.Role == domain.RoleRPAdmin {
if tenants, err := h.TenantService.ListManageableTenants(c.Context(), profile.ID); err == nil {
profile.ManageableTenants = tenants
}
}
// 4. Save to Redis Cache (Short TTL) // 4. Save to Redis Cache (Short TTL)
if h.RedisService != nil && cacheKey != "" { if h.RedisService != nil && cacheKey != "" {
if data, err := json.Marshal(profile); err == nil { if data, err := json.Marshal(profile); err == nil {
@@ -4842,7 +4782,10 @@ func extractLoginIDFromClaims(claims map[string]any) string {
} }
func (h *AuthHandler) getKratosIdentity(sessionToken string) (string, map[string]interface{}, error) { func (h *AuthHandler) getKratosIdentity(sessionToken string) (string, map[string]interface{}, error) {
kratosURL := strings.TrimRight(utils.GetEnv("KRATOS_PUBLIC_URL", "http://kratos:4433"), "/") kratosURL := strings.TrimRight(os.Getenv("KRATOS_PUBLIC_URL"), "/")
if kratosURL == "" {
kratosURL = "http://kratos:4433"
}
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, kratosURL+"/sessions/whoami", nil) req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, kratosURL+"/sessions/whoami", nil)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
@@ -4850,44 +4793,33 @@ func (h *AuthHandler) getKratosIdentity(sessionToken string) (string, map[string
req.Header.Set("X-Session-Token", sessionToken) req.Header.Set("X-Session-Token", sessionToken)
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err == nil { if err != nil {
defer resp.Body.Close() return "", nil, err
if resp.StatusCode == http.StatusOK { }
var result struct { defer resp.Body.Close()
Identity struct { if resp.StatusCode >= 300 {
ID string `json:"id"` body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
Traits map[string]interface{} `json:"traits"` return "", nil, fmt.Errorf("kratos whoami failed status=%d body=%s", resp.StatusCode, string(body))
} `json:"identity"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err == nil {
return result.Identity.ID, result.Identity.Traits, nil
}
}
} }
// 2. Kratos 실패 시 Hydra Introspection 시도 (OIDC Access Token 대응) var result struct {
if h.Hydra != nil { Identity struct {
slog.Debug("[Auth] Kratos whoami failed, trying Hydra introspection", "token_prefix", sessionToken[:min(len(sessionToken), 10)]) ID string `json:"id"`
introspection, err := h.Hydra.IntrospectToken(context.Background(), sessionToken) Traits map[string]interface{} `json:"traits"`
if err == nil && introspection["active"] == true { } `json:"identity"`
subject, _ := introspection["sub"].(string) }
if subject != "" { if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
// Hydra는 Traits를 직접 주지 않으므로, Kratos Admin API로 상세 정보를 가져옴 return "", nil, err
identity, err := h.KratosAdmin.GetIdentity(context.Background(), subject)
if err == nil && identity != nil {
return identity.ID, identity.Traits, nil
}
// Identity 정보가 없더라도 최소한 Subject는 반환
return subject, map[string]interface{}{}, nil
}
}
} }
return "", nil, fmt.Errorf("invalid session or token") return result.Identity.ID, result.Identity.Traits, nil
} }
func (h *AuthHandler) getKratosSessionID(sessionToken string) (string, error) { func (h *AuthHandler) getKratosSessionID(sessionToken string) (string, error) {
kratosURL := strings.TrimRight(utils.GetEnv("KRATOS_PUBLIC_URL", "http://kratos:4433"), "/") kratosURL := strings.TrimRight(os.Getenv("KRATOS_PUBLIC_URL"), "/")
if kratosURL == "" {
kratosURL = "http://kratos:4433"
}
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, kratosURL+"/sessions/whoami", nil) req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, kratosURL+"/sessions/whoami", nil)
if err != nil { if err != nil {
return "", err return "", err
@@ -4910,7 +4842,6 @@ func (h *AuthHandler) getKratosSessionID(sessionToken string) (string, error) {
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", err return "", err
} }
return result.ID, nil return result.ID, nil
} }
@@ -4919,7 +4850,10 @@ func (h *AuthHandler) issueKratosSession(ctx context.Context, identityID string)
return "", fmt.Errorf("kratos identity id is empty") return "", fmt.Errorf("kratos identity id is empty")
} }
kratosAdminURL := strings.TrimRight(utils.GetEnv("KRATOS_ADMIN_URL", "http://kratos:4434"), "/") kratosAdminURL := strings.TrimRight(os.Getenv("KRATOS_ADMIN_URL"), "/")
if kratosAdminURL == "" {
kratosAdminURL = "http://kratos:4434"
}
payload := map[string]interface{}{ payload := map[string]interface{}{
"identity_id": identityID, "identity_id": identityID,

View File

@@ -0,0 +1,251 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
// --- Async Test Mocks ---
type AsyncMockIdpProvider struct {
mock.Mock
}
func (m *AsyncMockIdpProvider) Name() string { return "mock-idp" }
func (m *AsyncMockIdpProvider) GetMetadata() (*domain.IDPMetadata, error) {
return &domain.IDPMetadata{}, nil
}
func (m *AsyncMockIdpProvider) UserExists(loginID string) (bool, error) {
args := m.Called(loginID)
return args.Bool(0), args.Error(1)
}
func (m *AsyncMockIdpProvider) CreateUser(user *domain.BrokerUser, password string) (string, error) {
args := m.Called(user, password)
return args.String(0), args.Error(1)
}
func (m *AsyncMockIdpProvider) SignIn(loginID, password string) (*domain.AuthInfo, error) {
return nil, nil
}
func (m *AsyncMockIdpProvider) IssueSession(loginID string) (*domain.AuthInfo, error) { return nil, nil }
func (m *AsyncMockIdpProvider) InitiateLinkLogin(loginID, returnTo string) (*domain.LinkLoginInit, error) {
return nil, nil
}
func (m *AsyncMockIdpProvider) VerifyLoginCode(loginID, flowID, code string) (*domain.AuthInfo, error) {
return nil, nil
}
func (m *AsyncMockIdpProvider) GetPasswordPolicy() (*domain.PasswordPolicy, error) {
return &domain.PasswordPolicy{MinLength: 12}, nil
}
func (m *AsyncMockIdpProvider) InitiatePasswordReset(loginID, redirectUrl string) error { return nil }
func (m *AsyncMockIdpProvider) VerifyPasswordResetToken(token string) (*domain.AuthInfo, error) {
return nil, nil
}
func (m *AsyncMockIdpProvider) UpdateUserPassword(loginID, newPassword string, r *http.Request) error {
return nil
}
type AsyncMockUserRepo struct {
mock.Mock
createCalled chan bool
}
func (m *AsyncMockUserRepo) Create(ctx context.Context, user *domain.User) error {
// Simulate DB latency
time.Sleep(50 * time.Millisecond)
args := m.Called(ctx, user)
if m.createCalled != nil {
m.createCalled <- true
}
return args.Error(0)
}
func (m *AsyncMockUserRepo) Update(ctx context.Context, user *domain.User) error { return nil }
func (m *AsyncMockUserRepo) FindByEmail(ctx context.Context, email string) (*domain.User, error) {
return nil, nil
}
func (m *AsyncMockUserRepo) FindByID(ctx context.Context, id string) (*domain.User, error) {
return nil, nil
}
func (m *AsyncMockUserRepo) FindByIDs(ctx context.Context, ids []string) ([]domain.User, error) {
return nil, nil
}
func (m *AsyncMockUserRepo) ListByTenant(ctx context.Context, tenantID string) ([]domain.User, error) {
return nil, nil
}
func (m *AsyncMockUserRepo) List(ctx context.Context, offset, limit int, search string) ([]domain.User, int64, error) {
return nil, 0, nil
}
type AsyncMockRedisRepo struct {
mock.Mock
}
func (m *AsyncMockRedisRepo) Set(key string, value string, expiration time.Duration) error {
args := m.Called(key, value, expiration)
return args.Error(0)
}
func (m *AsyncMockRedisRepo) Get(key string) (string, error) {
args := m.Called(key)
return args.String(0), args.Error(1)
}
func (m *AsyncMockRedisRepo) Delete(key string) error {
args := m.Called(key)
return args.Error(0)
}
func (m *AsyncMockRedisRepo) StoreVerificationCode(phone, code string) error { return nil }
func (m *AsyncMockRedisRepo) GetVerificationCode(phone string) (string, error) { return "", nil }
func (m *AsyncMockRedisRepo) DeleteVerificationCode(phone string) error { return nil }
type AsyncMockTenantService struct {
mock.Mock
}
func (m *AsyncMockTenantService) RegisterTenant(ctx context.Context, name, slug, description string, domains []string) (*domain.Tenant, error) {
return nil, nil
}
func (m *AsyncMockTenantService) RequestRegistration(ctx context.Context, name, slug, description string, domainName string, adminEmail string) (*domain.Tenant, error) {
return nil, nil
}
func (m *AsyncMockTenantService) GetTenantByDomain(ctx context.Context, emailDomain string) (*domain.Tenant, error) {
args := m.Called(ctx, emailDomain)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.Tenant), args.Error(1)
}
func (m *AsyncMockTenantService) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
return nil, nil
}
func (m *AsyncMockTenantService) GetTenant(ctx context.Context, id string) (*domain.Tenant, error) {
return nil, nil
}
func (m *AsyncMockTenantService) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
return nil, nil
}
func (m *AsyncMockTenantService) ApproveTenant(ctx context.Context, id string) error { return nil }
func (m *AsyncMockTenantService) SetKetoService(keto service.KetoService) {}
func (m *AsyncMockTenantService) AddTenantAdmin(ctx context.Context, tenantID, userID string) error {
return nil
}
func (m *AsyncMockTenantService) RemoveTenantAdmin(ctx context.Context, tenantID, userID string) error {
return nil
}
func (m *AsyncMockTenantService) ListTenantAdmins(ctx context.Context, tenantID string) ([]string, error) {
return nil, nil
}
type AsyncMockKetoService struct {
mock.Mock
}
func (m *AsyncMockKetoService) CreateRelation(ctx context.Context, namespace, object, relation, subject string) error {
args := m.Called(ctx, namespace, object, relation, subject)
return args.Error(0)
}
func (m *AsyncMockKetoService) DeleteRelation(ctx context.Context, namespace, object, relation, subject string) error {
return nil
}
func (m *AsyncMockKetoService) CheckPermission(ctx context.Context, namespace, object, relation, subject string) (bool, error) {
return false, nil
}
func (m *AsyncMockKetoService) ListObjects(ctx context.Context, namespace, relation, subject string) ([]string, error) {
return nil, nil
}
func (m *AsyncMockKetoService) ListRelations(ctx context.Context, namespace, object, relation, subject string) ([]service.RelationTuple, error) {
return nil, nil
}
// --- Tests ---
func TestSignup_AsyncDB_Isolation(t *testing.T) {
mockIdp := new(AsyncMockIdpProvider)
mockUserRepo := new(AsyncMockUserRepo)
mockRedis := new(AsyncMockRedisRepo)
mockTenant := new(AsyncMockTenantService)
mockKeto := new(AsyncMockKetoService)
h := &AuthHandler{
IdpProvider: mockIdp,
UserRepo: mockUserRepo,
RedisService: mockRedis,
TenantService: mockTenant,
KetoService: mockKeto,
}
app := fiber.New()
app.Post("/signup", h.Signup)
t.Run("SoT_DB_Failure_Ignored_And_Async", func(t *testing.T) {
email := "test@example.com"
phone := "010-1234-5678"
emailKey := "signup:email:" + email
phoneKey := "signup:phone:" + "01012345678"
// Redis Mocks
mockRedis.On("Get", emailKey).Return(`{"verified": true, "expires_at": 9999999999}`, nil)
mockRedis.On("Get", phoneKey).Return(`{"verified": true, "expires_at": 9999999999}`, nil)
mockRedis.On("Delete", emailKey).Return(nil)
mockRedis.On("Delete", phoneKey).Return(nil)
// Tenant Mocks
mockTenant.On("GetTenantByDomain", mock.Anything, "example.com").Return(nil, errors.New("not found"))
// Kratos Mocks (Success)
mockIdp.On("CreateUser", mock.Anything, "Password123!").Return("new-user-uuid", nil)
// UserRepo Mocks (Async & Failure)
mockUserRepo.createCalled = make(chan bool, 1)
mockUserRepo.On("Create", mock.Anything, mock.MatchedBy(func(u *domain.User) bool {
return u.Email == email
})).Return(errors.New("db connection error"))
// Keto Mocks (Optional, since it's also async)
// We won't block on this either
body, _ := json.Marshal(domain.SignupRequest{
Email: email,
Password: "Password123!",
Name: "Test User",
Phone: phone,
TermsAccepted: true,
})
req := httptest.NewRequest("POST", "/signup", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
start := time.Now()
resp, err := app.Test(req, 5000)
elapsed := time.Since(start)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
assert.Equal(t, 200, resp.StatusCode)
// Ensure API responded faster than DB latency (50ms)
assert.Less(t, int64(elapsed), int64(60*time.Millisecond), "API should return before DB timeout")
// Wait for async execution
select {
case <-mockUserRepo.createCalled:
// Pass
case <-time.After(2 * time.Second):
t.Fatal("UserRepo.Create was not called asynchronously")
}
mockRedis.AssertExpectations(t)
mockIdp.AssertExpectations(t)
mockUserRepo.AssertExpectations(t)
})
}

View File

@@ -288,8 +288,8 @@ func TestPasswordLogin_NoOIDC_Success(t *testing.T) {
} }
var got map[string]string var got map[string]string
json.NewDecoder(resp.Body).Decode(&got) json.NewDecoder(resp.Body).Decode(&got)
if got["sessionToken"] != "valid-jwt" { if got["sessionJwt"] != "valid-jwt" {
t.Errorf("expected jwt valid-jwt, got %s", got["sessionToken"]) t.Errorf("expected jwt valid-jwt, got %s", got["sessionJwt"])
} }
// No redirectTo // No redirectTo
if _, ok := got["redirectTo"]; ok { if _, ok := got["redirectTo"]; ok {

View File

@@ -22,17 +22,15 @@ type DevHandler struct {
SecretRepo domain.ClientSecretRepository SecretRepo domain.ClientSecretRepository
KratosAdmin *service.KratosAdminService KratosAdmin *service.KratosAdminService
ConsentRepo repository.ClientConsentRepository ConsentRepo repository.ClientConsentRepository
RPService service.RelyingPartyService
} }
func NewDevHandler(redis domain.RedisRepository, secretRepo domain.ClientSecretRepository, consentRepo repository.ClientConsentRepository, rpService service.RelyingPartyService) *DevHandler { func NewDevHandler(redis domain.RedisRepository, secretRepo domain.ClientSecretRepository, consentRepo repository.ClientConsentRepository) *DevHandler {
return &DevHandler{ return &DevHandler{
Hydra: service.NewHydraAdminService(), Hydra: service.NewHydraAdminService(),
Redis: redis, Redis: redis,
SecretRepo: secretRepo, SecretRepo: secretRepo,
KratosAdmin: service.NewKratosAdminService(), KratosAdmin: service.NewKratosAdminService(),
ConsentRepo: consentRepo, ConsentRepo: consentRepo,
RPService: rpService,
} }
} }
@@ -97,58 +95,38 @@ type clientUpsertRequest struct {
} }
func (h *DevHandler) ListClients(c *fiber.Ctx) error { func (h *DevHandler) ListClients(c *fiber.Ctx) error {
profile, ok := c.Locals("user_profile").(*domain.UserProfileResponse) limit := c.QueryInt("limit", 50)
if !ok { offset := c.QueryInt("offset", 0)
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "unauthorized: user profile not found"}) if limit <= 0 {
limit = 50
}
if offset < 0 {
offset = 0
} }
// Super Admin sees all (best effort via Hydra list for now, or we can use RPService if it's improved) clients, err := h.Hydra.ListClients(c.Context(), limit, offset)
if profile.Role == domain.RoleSuperAdmin {
limit := c.QueryInt("limit", 50)
offset := c.QueryInt("offset", 0)
clients, err := h.Hydra.ListClients(c.Context(), limit, offset)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
items := make([]clientSummary, 0, len(clients))
for _, client := range clients {
items = append(items, h.mapClientSummary(client))
}
return c.JSON(clientListResponse{Items: items, Limit: limit, Offset: offset})
}
// For others, only show manageable tenants' clients
var tenantIDs []string
for _, t := range profile.ManageableTenants {
tenantIDs = append(tenantIDs, t.ID)
}
if len(tenantIDs) == 0 && profile.TenantID != nil {
tenantIDs = append(tenantIDs, *profile.TenantID)
}
if len(tenantIDs) == 0 {
return c.JSON(clientListResponse{Items: []clientSummary{}, Limit: 50, Offset: 0})
}
rps, err := h.RPService.ListByTenantIDs(c.Context(), tenantIDs)
if err != nil { if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()}) if errors.Is(err, service.ErrHydraNotFound) {
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "clients not found"})
}
errMsg := err.Error()
if strings.Contains(errMsg, "connection refused") || strings.Contains(errMsg, "dial tcp") {
return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{
"error": "Hydra service is unavailable. Please check if Ory Hydra is running.",
})
}
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": errMsg})
} }
items := make([]clientSummary, 0, len(rps)) items := make([]clientSummary, 0, len(clients))
for _, rp := range rps { for _, client := range clients {
// We need HydraClient details for the summary items = append(items, h.mapClientSummary(client))
client, err := h.Hydra.GetClient(c.Context(), rp.ClientID)
if err == nil {
items = append(items, h.mapClientSummary(*client))
}
} }
return c.JSON(clientListResponse{ return c.JSON(clientListResponse{
Items: items, Items: items,
Limit: len(items), Limit: limit,
Offset: 0, Offset: offset,
}) })
} }
@@ -166,11 +144,6 @@ func (h *DevHandler) GetClient(c *fiber.Ctx) error {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()}) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
} }
// Set for audit logging
if tid, ok := client.Metadata["tenant_id"].(string); ok {
c.Locals("tenant_id", tid)
}
summary := h.mapClientSummary(*client) summary := h.mapClientSummary(*client)
return c.JSON(clientDetailResponse{ return c.JSON(clientDetailResponse{
Client: summary, Client: summary,
@@ -224,49 +197,11 @@ func (h *DevHandler) UpdateClientStatus(c *fiber.Ctx) error {
} }
func (h *DevHandler) CreateClient(c *fiber.Ctx) error { func (h *DevHandler) CreateClient(c *fiber.Ctx) error {
profile, ok := c.Locals("user_profile").(*domain.UserProfileResponse)
if !ok {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "unauthorized"})
}
var req clientUpsertRequest var req clientUpsertRequest
if err := c.BodyParser(&req); err != nil { if err := c.BodyParser(&req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid request body"}) return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid request body"})
} }
// Determine Tenant ID
targetTenantID := c.Get("X-Tenant-ID")
if targetTenantID == "" && profile.TenantID != nil {
targetTenantID = *profile.TenantID
}
if targetTenantID == "" {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "X-Tenant-ID header is required"})
}
// Set for audit logging
c.Locals("tenant_id", targetTenantID)
// Validate Permission
isAllowed := false
if profile.Role == domain.RoleSuperAdmin {
isAllowed = true
} else {
for _, t := range profile.ManageableTenants {
if t.ID == targetTenantID {
isAllowed = true
break
}
}
if !isAllowed && profile.TenantID != nil && *profile.TenantID == targetTenantID {
isAllowed = true
}
}
if !isAllowed {
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{"error": "you do not have permission to create clients for this tenant"})
}
clientID := strings.TrimSpace(valueOr(req.ID, "")) clientID := strings.TrimSpace(valueOr(req.ID, ""))
if clientID == "" { if clientID == "" {
clientID = uuid.NewString() clientID = uuid.NewString()
@@ -322,18 +257,11 @@ func (h *DevHandler) CreateClient(c *fiber.Ctx) error {
Metadata: metadata, Metadata: metadata,
} }
// Use RPService to ensure Keto relations are created created, err := h.Hydra.CreateClient(c.Context(), clientReq)
rp, err := h.RPService.Create(c.Context(), targetTenantID, clientReq)
if err != nil { if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()}) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
} }
// Fetch back the Hydra client to get the secret (RPService.Create returns domain.RelyingParty which has limited fields)
created, err := h.Hydra.GetClient(c.Context(), rp.ClientID)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "client created but failed to retrieve details"})
}
// Store secret in metadata for later retrieval // Store secret in metadata for later retrieval
if created.ClientSecret != "" { if created.ClientSecret != "" {
// 1. Store in PostgreSQL (Source of Truth) // 1. Store in PostgreSQL (Source of Truth)
@@ -379,11 +307,6 @@ func (h *DevHandler) UpdateClient(c *fiber.Ctx) error {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()}) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
} }
// Set for audit logging
if tid, ok := current.Metadata["tenant_id"].(string); ok {
c.Locals("tenant_id", tid)
}
clientType := "" clientType := ""
if req.Type != nil { if req.Type != nil {
clientType = strings.ToLower(strings.TrimSpace(*req.Type)) clientType = strings.ToLower(strings.TrimSpace(*req.Type))
@@ -459,14 +382,6 @@ func (h *DevHandler) DeleteClient(c *fiber.Ctx) error {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "client id is required"}) return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "client id is required"})
} }
// Fetch first for audit log tenant_id
client, err := h.Hydra.GetClient(c.Context(), clientID)
if err == nil {
if tid, ok := client.Metadata["tenant_id"].(string); ok {
c.Locals("tenant_id", tid)
}
}
if err := h.Hydra.DeleteClient(c.Context(), clientID); err != nil { if err := h.Hydra.DeleteClient(c.Context(), clientID); err != nil {
if errors.Is(err, service.ErrHydraNotFound) { if errors.Is(err, service.ErrHydraNotFound) {
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "client not found"}) return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "client not found"})
@@ -488,24 +403,11 @@ func (h *DevHandler) DeleteClient(c *fiber.Ctx) error {
} }
func (h *DevHandler) ListConsents(c *fiber.Ctx) error { func (h *DevHandler) ListConsents(c *fiber.Ctx) error {
profile, ok := c.Locals("user_profile").(*domain.UserProfileResponse)
if !ok {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "unauthorized"})
}
clientID := strings.TrimSpace(c.Query("client_id")) clientID := strings.TrimSpace(c.Query("client_id"))
if clientID == "" { if clientID == "" {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "client_id is required"}) return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "client_id is required"})
} }
// Permission Check
if profile.Role != domain.RoleSuperAdmin {
allowed, err := h.RPService.CheckPermission(c.Context(), profile.ID, clientID, "view")
if err != nil || !allowed {
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{"error": "forbidden: you do not have permission to view consents for this client"})
}
}
subject := strings.TrimSpace(c.Query("subject")) subject := strings.TrimSpace(c.Query("subject"))
limit := c.QueryInt("limit", 50) limit := c.QueryInt("limit", 50)
offset := c.QueryInt("offset", 0) offset := c.QueryInt("offset", 0)
@@ -582,28 +484,12 @@ func (h *DevHandler) ListConsents(c *fiber.Ctx) error {
} }
func (h *DevHandler) RevokeConsents(c *fiber.Ctx) error { func (h *DevHandler) RevokeConsents(c *fiber.Ctx) error {
profile, ok := c.Locals("user_profile").(*domain.UserProfileResponse)
if !ok {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "unauthorized"})
}
subject := strings.TrimSpace(c.Query("subject")) subject := strings.TrimSpace(c.Query("subject"))
if subject == "" { if subject == "" {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "subject is required"}) return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "subject is required"})
} }
clientID := strings.TrimSpace(c.Query("client_id")) clientID := strings.TrimSpace(c.Query("client_id"))
// Permission Check (if clientID is provided)
if clientID != "" && profile.Role != domain.RoleSuperAdmin {
allowed, err := h.RPService.CheckPermission(c.Context(), profile.ID, clientID, "manage")
if err != nil || !allowed {
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{"error": "forbidden: you do not have permission to revoke consents for this client"})
}
} else if clientID == "" && profile.Role != domain.RoleSuperAdmin {
// If clientID is not provided, we might need a more global check or just disallow it for non-superadmins
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "client_id is required for non-superadmins"})
}
// If subject is not a UUID, try to resolve it as an identifier (email/username) // If subject is not a UUID, try to resolve it as an identifier (email/username)
if _, err := uuid.Parse(subject); err != nil { if _, err := uuid.Parse(subject); err != nil {
resolved, err := h.KratosAdmin.FindIdentityIDByIdentifier(c.Context(), subject) resolved, err := h.KratosAdmin.FindIdentityIDByIdentifier(c.Context(), subject)
@@ -646,11 +532,6 @@ func (h *DevHandler) RotateClientSecret(c *fiber.Ctx) error {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()}) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
} }
// Set for audit logging
if tid, ok := current.Metadata["tenant_id"].(string); ok {
c.Locals("tenant_id", tid)
}
// 3. Update Hydra // 3. Update Hydra
current.ClientSecret = newSecret current.ClientSecret = newSecret
updated, err := h.Hydra.UpdateClient(c.Context(), clientID, *current) updated, err := h.Hydra.UpdateClient(c.Context(), clientID, *current)

View File

@@ -1,10 +1,8 @@
package handler package handler
import ( import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service" "baron-sso-backend/internal/service"
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -12,75 +10,8 @@ import (
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
) )
type MockRPService struct {
mock.Mock
}
func (m *MockRPService) Create(ctx context.Context, tenantID string, client domain.HydraClient) (*domain.RelyingParty, error) {
args := m.Called(ctx, tenantID, client)
return args.Get(0).(*domain.RelyingParty), args.Error(1)
}
func (m *MockRPService) Get(ctx context.Context, clientID string) (*domain.RelyingParty, *domain.HydraClient, error) {
args := m.Called(ctx, clientID)
return args.Get(0).(*domain.RelyingParty), args.Get(1).(*domain.HydraClient), args.Error(2)
}
func (m *MockRPService) List(ctx context.Context, tenantID string) ([]domain.RelyingParty, error) {
args := m.Called(ctx, tenantID)
return args.Get(0).([]domain.RelyingParty), args.Error(1)
}
func (m *MockRPService) ListAll(ctx context.Context) ([]domain.RelyingParty, error) {
args := m.Called(ctx)
return args.Get(0).([]domain.RelyingParty), args.Error(1)
}
func (m *MockRPService) ListByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.RelyingParty, error) {
args := m.Called(ctx, tenantIDs)
return args.Get(0).([]domain.RelyingParty), args.Error(1)
}
func (m *MockRPService) Update(ctx context.Context, clientID string, client domain.HydraClient) (*domain.RelyingParty, error) {
args := m.Called(ctx, clientID, client)
return args.Get(0).(*domain.RelyingParty), args.Error(1)
}
func (m *MockRPService) Delete(ctx context.Context, clientID string) error {
args := m.Called(ctx, clientID)
return args.Error(0)
}
func (m *MockRPService) CheckPermission(ctx context.Context, userID, clientID, relation string) (bool, error) {
args := m.Called(ctx, userID, clientID, relation)
return args.Bool(0), args.Error(1)
}
func (m *MockRPService) AddOwner(ctx context.Context, clientID, subject string) error {
args := m.Called(ctx, clientID, subject)
return args.Error(0)
}
func (m *MockRPService) RemoveOwner(ctx context.Context, clientID, subject string) error {
args := m.Called(ctx, clientID, subject)
return args.Error(0)
}
func (m *MockRPService) ListOwners(ctx context.Context, clientID string) ([]string, error) {
args := m.Called(ctx, clientID)
return args.Get(0).([]string), args.Error(1)
}
func withMockProfile(profile *domain.UserProfileResponse) fiber.Handler {
return func(c *fiber.Ctx) error {
c.Locals("user_profile", profile)
return c.Next()
}
}
func TestListClients_Success(t *testing.T) { func TestListClients_Success(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/clients" { if r.URL.Path == "/clients" {
@@ -99,11 +30,7 @@ func TestListClients_Success(t *testing.T) {
}, },
} }
app := fiber.New() app := fiber.New()
adminProfile := &domain.UserProfileResponse{ app.Get("/api/v1/dev/clients", h.ListClients)
ID: "admin-1",
Role: domain.RoleSuperAdmin,
}
app.Get("/api/v1/dev/clients", withMockProfile(adminProfile), h.ListClients)
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients", nil) req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients", nil)
resp, _ := app.Test(req, -1) resp, _ := app.Test(req, -1)
@@ -139,11 +66,7 @@ func TestGetClient_Success(t *testing.T) {
}, },
} }
app := fiber.New() app := fiber.New()
adminProfile := &domain.UserProfileResponse{ app.Get("/api/v1/dev/clients/:id", h.GetClient)
ID: "admin-1",
Role: domain.RoleSuperAdmin,
}
app.Get("/api/v1/dev/clients/:id", withMockProfile(adminProfile), h.GetClient)
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-123", nil) req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-123", nil)
resp, _ := app.Test(req, -1) resp, _ := app.Test(req, -1)
@@ -169,11 +92,7 @@ func TestGetClient_NotFound(t *testing.T) {
}, },
} }
app := fiber.New() app := fiber.New()
adminProfile := &domain.UserProfileResponse{ app.Get("/api/v1/dev/clients/:id", h.GetClient)
ID: "admin-1",
Role: domain.RoleSuperAdmin,
}
app.Get("/api/v1/dev/clients/:id", withMockProfile(adminProfile), h.GetClient)
req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/non-existent", nil) req := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/non-existent", nil)
resp, _ := app.Test(req, -1) resp, _ := app.Test(req, -1)
@@ -190,49 +109,30 @@ func TestCreateClient_Success(t *testing.T) {
"client_secret": "secret-123", "client_secret": "secret-123",
}), nil }), nil
} }
if r.Method == http.MethodGet && r.URL.Path == "/clients/new-client-123" { return httpJSONAny(r, http.StatusInternalServerError, map[string]string{"error": "hydra error"}), nil
return httpJSONAny(r, http.StatusOK, map[string]interface{}{
"client_id": "new-client-123",
"client_name": "New App",
"client_secret": "secret-123",
"metadata": map[string]interface{}{"status": "active"},
}), nil
}
return httpJSONAny(r, http.StatusInternalServerError, map[string]string{"error": "hydra error path: " + r.URL.Path}), nil
}) })
secretRepo := &mockSecretRepo{secrets: make(map[string]string)} secretRepo := &mockSecretRepo{secrets: make(map[string]string)}
redisRepo := &mockRedisRepo{data: make(map[string]string)} redisRepo := &mockRedisRepo{data: make(map[string]string)}
mockRP := new(MockRPService)
h := &DevHandler{ h := &DevHandler{
Hydra: &service.HydraAdminService{ Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test", AdminURL: "http://hydra.test",
PublicURL: "http://hydra-public.test",
HTTPClient: &http.Client{Transport: transport}, HTTPClient: &http.Client{Transport: transport},
}, },
SecretRepo: secretRepo, SecretRepo: secretRepo,
Redis: redisRepo, Redis: redisRepo,
RPService: mockRP,
} }
app := fiber.New() app := fiber.New()
adminProfile := &domain.UserProfileResponse{ app.Post("/api/v1/dev/clients", h.CreateClient)
ID: "admin-1",
Role: domain.RoleSuperAdmin,
}
app.Post("/api/v1/dev/clients", withMockProfile(adminProfile), h.CreateClient)
body, _ := json.Marshal(map[string]interface{}{ body, _ := json.Marshal(map[string]interface{}{
"client_name": "New App", "client_name": "New App",
"type": "confidential", "type": "confidential",
"redirectUris": []string{"http://localhost/cb"}, "redirectUris": []string{"http://localhost/cb"},
}) })
mockRP.On("Create", mock.Anything, "t1", mock.Anything).Return(&domain.RelyingParty{ClientID: "new-client-123"}, nil)
req := httptest.NewRequest(http.MethodPost, "/api/v1/dev/clients", bytes.NewReader(body)) req := httptest.NewRequest(http.MethodPost, "/api/v1/dev/clients", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Tenant-ID", "t1")
resp, _ := app.Test(req, -1) resp, _ := app.Test(req, -1)
assert.Equal(t, http.StatusCreated, resp.StatusCode) assert.Equal(t, http.StatusCreated, resp.StatusCode)

View File

@@ -10,11 +10,10 @@ import (
type RelyingPartyHandler struct { type RelyingPartyHandler struct {
Service service.RelyingPartyService Service service.RelyingPartyService
UserSvc *service.KratosAdminService
} }
func NewRelyingPartyHandler(s service.RelyingPartyService, userSvc *service.KratosAdminService) *RelyingPartyHandler { func NewRelyingPartyHandler(s service.RelyingPartyService) *RelyingPartyHandler {
return &RelyingPartyHandler{Service: s, UserSvc: userSvc} return &RelyingPartyHandler{Service: s}
} }
func (h *RelyingPartyHandler) Create(c *fiber.Ctx) error { func (h *RelyingPartyHandler) Create(c *fiber.Ctx) error {
@@ -111,58 +110,3 @@ func (h *RelyingPartyHandler) Delete(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusNoContent) return c.SendStatus(fiber.StatusNoContent)
} }
func (h *RelyingPartyHandler) ListOwners(c *fiber.Ctx) error {
clientID := c.Params("id")
subjects, err := h.Service.ListOwners(c.Context(), clientID)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
type ownerInfo struct {
Subject string `json:"subject"`
Name string `json:"name,omitempty"`
Email string `json:"email,omitempty"`
Type string `json:"type"` // "user" or "group"
}
owners := make([]ownerInfo, 0, len(subjects))
for _, s := range subjects {
info := ownerInfo{Subject: s, Type: "unknown"}
if len(s) > 5 && s[:5] == "User:" {
info.Type = "user"
userID := s[5:]
identity, err := h.UserSvc.GetIdentity(c.Context(), userID)
if err == nil && identity != nil {
info.Name, _ = identity.Traits["name"].(string)
info.Email, _ = identity.Traits["email"].(string)
}
} else if len(s) > 10 && s[:10] == "UserGroup:" {
info.Type = "group"
// Group name enrichment could be added if we have a GroupService here
}
owners = append(owners, info)
}
return c.JSON(owners)
}
func (h *RelyingPartyHandler) AddOwner(c *fiber.Ctx) error {
clientID := c.Params("id")
subject := c.Params("subject") // e.g. "User:uuid"
if err := h.Service.AddOwner(c.Context(), clientID, subject); err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(fiber.Map{"message": "owner added"})
}
func (h *RelyingPartyHandler) RemoveOwner(c *fiber.Ctx) error {
clientID := c.Params("id")
subject := c.Params("subject")
if err := h.Service.RemoveOwner(c.Context(), clientID, subject); err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(fiber.Map{"message": "owner removed"})
}

View File

@@ -14,25 +14,22 @@ import (
type TenantHandler struct { type TenantHandler struct {
DB *gorm.DB DB *gorm.DB
Service service.TenantService Service service.TenantService
Keto service.KetoService
UserSvc *service.KratosAdminService
} }
func NewTenantHandler(db *gorm.DB, svc service.TenantService, keto service.KetoService, userSvc *service.KratosAdminService) *TenantHandler { func NewTenantHandler(db *gorm.DB, svc service.TenantService) *TenantHandler {
return &TenantHandler{DB: db, Service: svc, Keto: keto, UserSvc: userSvc} return &TenantHandler{DB: db, Service: svc}
} }
type tenantSummary struct { type tenantSummary struct {
ID string `json:"id"` ID string `json:"id"`
Name string `json:"name"` Name string `json:"name"`
Slug string `json:"slug"` Slug string `json:"slug"`
Description string `json:"description"` Description string `json:"description"`
Status string `json:"status"` Status string `json:"status"`
TenantGroupID *string `json:"tenantGroupId,omitempty"` Domains []string `json:"domains,omitempty"`
Domains []string `json:"domains,omitempty"` Config domain.JSONMap `json:"config,omitempty"`
Config domain.JSONMap `json:"config,omitempty"` CreatedAt string `json:"createdAt"`
CreatedAt string `json:"createdAt"` UpdatedAt string `json:"updatedAt"`
UpdatedAt string `json:"updatedAt"`
} }
type tenantListResponse struct { type tenantListResponse struct {
@@ -103,7 +100,7 @@ func (h *TenantHandler) ListTenants(c *fiber.Ctx) error {
} }
var tenants []domain.Tenant var tenants []domain.Tenant
if err := h.DB.Order("created_at desc").Limit(limit).Offset(offset).Preload("Domains").Preload("TenantGroup").Find(&tenants).Error; err != nil { if err := h.DB.Order("created_at desc").Limit(limit).Offset(offset).Preload("Domains").Find(&tenants).Error; err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()}) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
} }
@@ -126,7 +123,7 @@ func (h *TenantHandler) GetTenant(c *fiber.Ctx) error {
} }
var tenant domain.Tenant var tenant domain.Tenant
if err := h.DB.Preload("Domains").Preload("TenantGroup").First(&tenant, "id = ?", tenantID).Error; err != nil { if err := h.DB.Preload("Domains").First(&tenant, "id = ?", tenantID).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "tenant not found"}) return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "tenant not found"})
} }
@@ -207,13 +204,12 @@ func (h *TenantHandler) UpdateTenant(c *fiber.Ctx) error {
} }
var req struct { var req struct {
Name *string `json:"name"` Name *string `json:"name"`
Slug *string `json:"slug"` Slug *string `json:"slug"`
Description *string `json:"description"` Description *string `json:"description"`
Status *string `json:"status"` Status *string `json:"status"`
TenantGroupID *string `json:"tenantGroupId"` Domains []string `json:"domains"`
Domains []string `json:"domains"` Config map[string]any `json:"config"`
Config map[string]any `json:"config"`
} }
if err := c.BodyParser(&req); err != nil { if err := c.BodyParser(&req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid request body"}) return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid request body"})
@@ -255,29 +251,6 @@ func (h *TenantHandler) UpdateTenant(c *fiber.Ctx) error {
tenant.Config = req.Config tenant.Config = req.Config
} }
// Handle Group Change
if req.TenantGroupID != nil {
oldGroupID := tenant.TenantGroupID
newGroupID := req.TenantGroupID
if *newGroupID == "" {
newGroupID = nil
}
// Update Keto if group changed
if h.Keto != nil {
// Remove old group relation if existed
if oldGroupID != nil && (newGroupID == nil || *oldGroupID != *newGroupID) {
_ = h.Keto.DeleteRelation(c.Context(), "Tenant", tenant.ID, "parent_group", *oldGroupID)
}
// Add new group relation
if newGroupID != nil && (oldGroupID == nil || *oldGroupID != *newGroupID) {
_ = h.Keto.CreateRelation(c.Context(), "Tenant", tenant.ID, "parent_group", *newGroupID)
}
}
tenant.TenantGroupID = newGroupID
}
if err := h.DB.Save(&tenant).Error; err != nil { if err := h.DB.Save(&tenant).Error; err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()}) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
} }
@@ -328,58 +301,6 @@ func (h *TenantHandler) DeleteTenant(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusNoContent) return c.SendStatus(fiber.StatusNoContent)
} }
func (h *TenantHandler) ListAdmins(c *fiber.Ctx) error {
tenantID := c.Params("id")
userIDs, err := h.Service.ListTenantAdmins(c.Context(), tenantID)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
type adminInfo struct {
ID string `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
}
admins := make([]adminInfo, 0, len(userIDs))
for _, uid := range userIDs {
identity, err := h.UserSvc.GetIdentity(c.Context(), uid)
if err == nil && identity != nil {
name, _ := identity.Traits["name"].(string)
email, _ := identity.Traits["email"].(string)
admins = append(admins, adminInfo{
ID: uid,
Name: name,
Email: email,
})
} else {
admins = append(admins, adminInfo{ID: uid})
}
}
return c.JSON(admins)
}
func (h *TenantHandler) AddAdmin(c *fiber.Ctx) error {
tenantID := c.Params("id")
userID := c.Params("userId")
if err := h.Service.AddTenantAdmin(c.Context(), tenantID, userID); err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(fiber.Map{"message": "admin added to tenant"})
}
func (h *TenantHandler) RemoveAdmin(c *fiber.Ctx) error {
tenantID := c.Params("id")
userID := c.Params("userId")
if err := h.Service.RemoveTenantAdmin(c.Context(), tenantID, userID); err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(fiber.Map{"message": "admin removed from tenant"})
}
func mapTenantSummary(t domain.Tenant) tenantSummary { func mapTenantSummary(t domain.Tenant) tenantSummary {
domains := make([]string, 0, len(t.Domains)) domains := make([]string, 0, len(t.Domains))
for _, d := range t.Domains { for _, d := range t.Domains {
@@ -387,16 +308,15 @@ func mapTenantSummary(t domain.Tenant) tenantSummary {
} }
return tenantSummary{ return tenantSummary{
ID: t.ID, ID: t.ID,
Name: t.Name, Name: t.Name,
Slug: t.Slug, Slug: t.Slug,
Description: t.Description, Description: t.Description,
Status: t.Status, Status: t.Status,
TenantGroupID: t.TenantGroupID, Domains: domains,
Domains: domains, Config: t.Config,
Config: t.GetMergedConfig(), CreatedAt: t.CreatedAt.Format(time.RFC3339),
CreatedAt: t.CreatedAt.Format(time.RFC3339), UpdatedAt: t.UpdatedAt.Format(time.RFC3339),
UpdatedAt: t.UpdatedAt.Format(time.RFC3339),
} }
} }

View File

@@ -70,26 +70,6 @@ func (m *MockTenantService) SetKetoService(keto service.KetoService) {
m.Called(keto) m.Called(keto)
} }
func (m *MockTenantService) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
args := m.Called(ctx, userID)
return args.Get(0).([]domain.Tenant), args.Error(1)
}
func (m *MockTenantService) AddTenantAdmin(ctx context.Context, tenantID, userID string) error {
args := m.Called(ctx, tenantID, userID)
return args.Error(0)
}
func (m *MockTenantService) RemoveTenantAdmin(ctx context.Context, tenantID, userID string) error {
args := m.Called(ctx, tenantID, userID)
return args.Error(0)
}
func (m *MockTenantService) ListTenantAdmins(ctx context.Context, tenantID string) ([]string, error) {
args := m.Called(ctx, tenantID)
return args.Get(0).([]string), args.Error(1)
}
func TestTenantHandler_CreateTenant(t *testing.T) { func TestTenantHandler_CreateTenant(t *testing.T) {
app := fiber.New() app := fiber.New()
mockSvc := new(MockTenantService) mockSvc := new(MockTenantService)

View File

@@ -304,10 +304,15 @@ func (h *UserHandler) CreateUser(c *fiber.Ctx) error {
localUser.TenantID = &tenantID localUser.TenantID = &tenantID
} }
// [SoT Policy] Kratos가 SoT이므로 로컬 DB 저장은 비동기 Read-Model 동기화로 처리합니다.
if h.UserRepo != nil { if h.UserRepo != nil {
if err := h.UserRepo.Create(c.Context(), localUser); err != nil { go func(u *domain.User) {
slog.Error("[UserHandler] Failed to sync user to local DB", "email", email, "error", err) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
} defer cancel()
if err := h.UserRepo.Create(ctx, u); err != nil {
slog.Error("[UserHandler] Failed to sync user to local DB", "email", u.Email, "error", err)
}
}(localUser)
} }
// [Keto] Sync relations // [Keto] Sync relations
@@ -483,27 +488,32 @@ func (h *UserHandler) UpdateUser(c *fiber.Ctx) error {
localUser.Metadata = req.Metadata localUser.Metadata = req.Metadata
} }
if err := h.UserRepo.Update(c.Context(), localUser); err == nil { // [SoT Policy] Kratos가 SoT이므로 로컬 DB 저장은 비동기 Read-Model 동기화로 처리합니다.
// [Keto Sync on Role Change] go func(u *domain.User, rRole *string, oRole string, oTenantID string) {
if h.KetoService != nil && req.Role != nil && *req.Role != oldRole { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
go func(uID, oldR, newR, tID string) { defer cancel()
ctx := context.Background()
if oldR == domain.RoleSuperAdmin { if err := h.UserRepo.Update(ctx, u); err == nil {
// [Keto Sync on Role Change]
if h.KetoService != nil && rRole != nil && *rRole != oRole {
uID := u.ID
newR := *rRole
if oRole == domain.RoleSuperAdmin {
_ = h.KetoService.DeleteRelation(ctx, "System", "global", "super_admins", uID) _ = h.KetoService.DeleteRelation(ctx, "System", "global", "super_admins", uID)
} else if oldR == domain.RoleTenantAdmin && tID != "" { } else if oRole == domain.RoleTenantAdmin && oTenantID != "" {
_ = h.KetoService.DeleteRelation(ctx, "Tenant", tID, "admins", uID) _ = h.KetoService.DeleteRelation(ctx, "Tenant", oTenantID, "admins", uID)
} }
if newR == domain.RoleSuperAdmin { if newR == domain.RoleSuperAdmin {
_ = h.KetoService.CreateRelation(ctx, "System", "global", "super_admins", uID) _ = h.KetoService.CreateRelation(ctx, "System", "global", "super_admins", uID)
} else if newR == domain.RoleTenantAdmin && tID != "" { } else if newR == domain.RoleTenantAdmin && u.TenantID != nil {
_ = h.KetoService.CreateRelation(ctx, "Tenant", tID, "admins", uID) _ = h.KetoService.CreateRelation(ctx, "Tenant", *u.TenantID, "admins", uID)
} }
}(userID, oldRole, *req.Role, oldTenantID) }
} else {
slog.Error("[UserHandler] Failed to sync user update to local DB", "userID", u.ID, "error", err)
} }
} else { }(localUser, req.Role, oldRole, oldTenantID)
slog.Error("[UserHandler] Failed to sync user update to local DB", "userID", userID, "error", err)
}
} }
} }

View File

@@ -3,7 +3,6 @@ package middleware
import ( import (
"baron-sso-backend/internal/domain" "baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service" "baron-sso-backend/internal/service"
"fmt"
"log/slog" "log/slog"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
@@ -26,7 +25,7 @@ func RequireKetoPermission(config RBACConfig, namespace, relation string) fiber.
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
profile, err := config.AuthHandler.GetEnrichedProfile(c) profile, err := config.AuthHandler.GetEnrichedProfile(c)
if err != nil { if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "인증에 실패했습니다. (rbac_keto)"}) return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "unauthorized (trace:rbac_keto)"})
} }
// Store profile in locals for further use in handlers // Store profile in locals for further use in handlers
@@ -44,21 +43,14 @@ func RequireKetoPermission(config RBACConfig, namespace, relation string) fiber.
} }
if objectID == "" { if objectID == "" {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "권한 검증을 위한 대상 ID가 누락되었습니다."}) return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "missing object id for permission check"})
}
// Set tenant_id for audit logging if namespace is Tenant
if namespace == "Tenant" {
c.Locals("tenant_id", objectID)
} }
// Check with Keto // Check with Keto
allowed, err := config.KetoService.CheckPermission(c.Context(), profile.ID, namespace, objectID, relation) allowed, err := config.KetoService.CheckPermission(c.Context(), profile.ID, namespace, objectID, relation)
if err != nil || !allowed { if err != nil || !allowed {
slog.Warn("Keto permission denied", "userID", profile.ID, "namespace", namespace, "objectID", objectID, "relation", relation) slog.Warn("Keto permission denied", "userID", profile.ID, "namespace", namespace, "objectID", objectID, "relation", relation)
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{ return c.Status(fiber.StatusForbidden).JSON(fiber.Map{"error": "forbidden: keto permission denied"})
"error": fmt.Sprintf("접근 권한이 없습니다. 현재 '%s' 권한으로는 요청하신 리소스에 대한 상세 권한(Keto)이 부족합니다. 관리자에게 문의하세요.", profile.Role),
})
} }
return c.Next() return c.Next()
@@ -76,7 +68,7 @@ func RequireRole(config RBACConfig) fiber.Handler {
profile, err := config.AuthHandler.GetEnrichedProfile(c) profile, err := config.AuthHandler.GetEnrichedProfile(c)
if err != nil { if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{ return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
"error": "인증 정보 조회에 실패했습니다: " + err.Error(), "error": "unauthorized (trace:rbac_role): " + err.Error(),
}) })
} }
@@ -105,7 +97,7 @@ func RequireRole(config RBACConfig) fiber.Handler {
"path", c.Path(), "path", c.Path(),
) )
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{ return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
"error": fmt.Sprintf("접근 권한이 없습니다. 현재 '%s' 권한으로는 이 기능을 사용할 수 없습니다. 관리자에게 문의하여 'rp_admin' 이상의 권한을 확보하세요.", profile.Role), "error": "forbidden: insufficient permissions",
}) })
} }
@@ -126,7 +118,7 @@ func RequireTenantMatch(config RBACConfig) fiber.Handler {
profile, err := config.AuthHandler.GetEnrichedProfile(c) profile, err := config.AuthHandler.GetEnrichedProfile(c)
if err != nil { if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "인증에 실패했습니다. (rbac_match)"}) return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "unauthorized (trace:rbac_match)"})
} }
// Store profile in locals for further use in handlers // Store profile in locals for further use in handlers
@@ -146,12 +138,12 @@ func RequireTenantMatch(config RBACConfig) fiber.Handler {
if profile.TenantID == nil || *profile.TenantID != targetTenantID { if profile.TenantID == nil || *profile.TenantID != targetTenantID {
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{ return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
"error": fmt.Sprintf("해당 테넌트에 대한 접근 권한이 없습니다. 사용자님의 '%s' 권한은 소속된 테넌트의 리소스만 관리할 수 있습니다.", profile.Role), "error": "forbidden: you do not have access to this tenant",
}) })
} }
return c.Next() return c.Next()
} }
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{"error": "요청하신 리소스에 접근할 수 없습니다."}) return c.Status(fiber.StatusForbidden).JSON(fiber.Map{"error": "forbidden"})
} }
} }

View File

@@ -54,14 +54,6 @@ func (m *MockKetoService) ListRelations(ctx context.Context, namespace, object,
return args.Get(0).([]service.RelationTuple), args.Error(1) return args.Get(0).([]service.RelationTuple), args.Error(1)
} }
func (m *MockKetoService) ListObjects(ctx context.Context, namespace, relation, subject string) ([]string, error) {
args := m.Called(ctx, namespace, relation, subject)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).([]string), args.Error(1)
}
// Fixed MockKetoService to match service.KetoService exactly if possible. // Fixed MockKetoService to match service.KetoService exactly if possible.
// Wait, middleware/rbac.go imports baron-sso-backend/internal/service. // Wait, middleware/rbac.go imports baron-sso-backend/internal/service.
// So I should use service.RelationTuple. // So I should use service.RelationTuple.

View File

@@ -14,7 +14,6 @@ type TenantRepository interface {
FindBySlug(ctx context.Context, slug string) (*domain.Tenant, error) FindBySlug(ctx context.Context, slug string) (*domain.Tenant, error)
FindByName(ctx context.Context, name string) (*domain.Tenant, error) FindByName(ctx context.Context, name string) (*domain.Tenant, error)
FindByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) FindByDomain(ctx context.Context, domainName string) (*domain.Tenant, error)
FindByIDs(ctx context.Context, ids []string) ([]domain.Tenant, error)
AddDomain(ctx context.Context, tenantID string, domainName string) error AddDomain(ctx context.Context, tenantID string, domainName string) error
} }
@@ -42,17 +41,6 @@ func (r *tenantRepository) FindByID(ctx context.Context, id string) (*domain.Ten
return &tenant, nil return &tenant, nil
} }
func (r *tenantRepository) FindByIDs(ctx context.Context, ids []string) ([]domain.Tenant, error) {
var tenants []domain.Tenant
if len(ids) == 0 {
return tenants, nil
}
if err := r.db.WithContext(ctx).Preload("Domains").Where("id IN ?", ids).Find(&tenants).Error; err != nil {
return nil, err
}
return tenants, nil
}
func (r *tenantRepository) FindBySlug(ctx context.Context, slug string) (*domain.Tenant, error) { func (r *tenantRepository) FindBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
var tenant domain.Tenant var tenant domain.Tenant
if err := r.db.WithContext(ctx).Preload("Domains").Where("slug = ?", slug).First(&tenant).Error; err != nil { if err := r.db.WithContext(ctx).Preload("Domains").Where("slug = ?", slug).First(&tenant).Error; err != nil {

View File

@@ -2,7 +2,6 @@ package service
import ( import (
"baron-sso-backend/internal/domain" "baron-sso-backend/internal/domain"
"baron-sso-backend/internal/utils"
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
@@ -28,8 +27,8 @@ type HydraAdminService struct {
func NewHydraAdminService() *HydraAdminService { func NewHydraAdminService() *HydraAdminService {
return &HydraAdminService{ return &HydraAdminService{
AdminURL: utils.GetEnv("HYDRA_ADMIN_URL", "http://hydra:4445"), AdminURL: getenv("HYDRA_ADMIN_URL", "http://hydra:4445"),
PublicURL: utils.GetEnv("HYDRA_PUBLIC_URL", "http://hydra:4444"), PublicURL: getenv("HYDRA_PUBLIC_URL", "http://hydra:4444"),
} }
} }
@@ -47,7 +46,7 @@ func (s *HydraAdminService) ListClients(ctx context.Context, limit, offset int)
return nil, err return nil, err
} }
resp, err := s.HttpClient().Do(req) resp, err := s.httpClient().Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -75,7 +74,7 @@ func (s *HydraAdminService) GetClient(ctx context.Context, clientID string) (*do
return nil, err return nil, err
} }
resp, err := s.HttpClient().Do(req) resp, err := s.httpClient().Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -114,7 +113,7 @@ func (s *HydraAdminService) PatchClientStatus(ctx context.Context, clientID, sta
} }
req.Header.Set("Content-Type", "application/json-patch+json") req.Header.Set("Content-Type", "application/json-patch+json")
resp, err := s.HttpClient().Do(req) resp, err := s.httpClient().Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -145,7 +144,7 @@ func (s *HydraAdminService) CreateClient(ctx context.Context, client domain.Hydr
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
resp, err := s.HttpClient().Do(req) resp, err := s.httpClient().Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -174,7 +173,7 @@ func (s *HydraAdminService) UpdateClient(ctx context.Context, clientID string, c
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
resp, err := s.HttpClient().Do(req) resp, err := s.httpClient().Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -202,7 +201,7 @@ func (s *HydraAdminService) DeleteClient(ctx context.Context, clientID string) e
return err return err
} }
resp, err := s.HttpClient().Do(req) resp, err := s.httpClient().Do(req)
if err != nil { if err != nil {
return err return err
} }
@@ -235,7 +234,7 @@ func (s *HydraAdminService) ListConsentSessions(ctx context.Context, subject, cl
return nil, err return nil, err
} }
resp, err := s.HttpClient().Do(req) resp, err := s.httpClient().Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -276,7 +275,7 @@ func (s *HydraAdminService) RevokeConsentSessions(ctx context.Context, subject,
return err return err
} }
resp, err := s.HttpClient().Do(req) resp, err := s.httpClient().Do(req)
if err != nil { if err != nil {
return err return err
} }
@@ -289,7 +288,7 @@ func (s *HydraAdminService) RevokeConsentSessions(ctx context.Context, subject,
return nil return nil
} }
func (s *HydraAdminService) HttpClient() *http.Client { func (s *HydraAdminService) httpClient() *http.Client {
if s.HTTPClient != nil { if s.HTTPClient != nil {
return s.HTTPClient return s.HTTPClient
} }
@@ -367,7 +366,7 @@ func (s *HydraAdminService) GetConsentRequest(ctx context.Context, challenge str
return nil, fmt.Errorf("hydra admin: create request for get consent failed: %w", err) return nil, fmt.Errorf("hydra admin: create request for get consent failed: %w", err)
} }
resp, err := s.HttpClient().Do(req) resp, err := s.httpClient().Do(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("hydra admin: get consent request failed: %w", err) return nil, fmt.Errorf("hydra admin: get consent request failed: %w", err)
} }
@@ -407,7 +406,7 @@ func (s *HydraAdminService) RejectConsentRequest(ctx context.Context, challenge
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
resp, err := s.HttpClient().Do(req) resp, err := s.httpClient().Do(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("hydra admin: reject consent request failed: %w", err) return nil, fmt.Errorf("hydra admin: reject consent request failed: %w", err)
} }
@@ -449,7 +448,7 @@ func (s *HydraAdminService) RejectLoginRequest(ctx context.Context, challenge, e
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
resp, err := s.HttpClient().Do(req) resp, err := s.httpClient().Do(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("hydra admin: reject login request failed: %w", err) return nil, fmt.Errorf("hydra admin: reject login request failed: %w", err)
} }
@@ -484,7 +483,7 @@ func (s *HydraAdminService) GetLoginRequest(ctx context.Context, challenge strin
return nil, fmt.Errorf("hydra admin: create request for get login failed: %w", err) return nil, fmt.Errorf("hydra admin: create request for get login failed: %w", err)
} }
resp, err := s.HttpClient().Do(req) resp, err := s.httpClient().Do(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("hydra admin: get login request failed: %w", err) return nil, fmt.Errorf("hydra admin: get login request failed: %w", err)
} }
@@ -532,7 +531,7 @@ func (s *HydraAdminService) AcceptConsentRequest(ctx context.Context, challenge
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
resp, err := s.HttpClient().Do(req) resp, err := s.httpClient().Do(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("hydra admin: accept consent request failed: %w", err) return nil, fmt.Errorf("hydra admin: accept consent request failed: %w", err)
} }
@@ -576,7 +575,7 @@ func (s *HydraAdminService) AcceptLoginRequest(ctx context.Context, challenge st
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
resp, err := s.HttpClient().Do(req) resp, err := s.httpClient().Do(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("hydra admin: accept login request failed: %w", err) return nil, fmt.Errorf("hydra admin: accept login request failed: %w", err)
} }
@@ -597,34 +596,3 @@ func (s *HydraAdminService) AcceptLoginRequest(ctx context.Context, challenge st
return &AcceptLoginRequestResponse{RedirectTo: hydraResp.RedirectTo}, nil return &AcceptLoginRequestResponse{RedirectTo: hydraResp.RedirectTo}, nil
} }
func (s *HydraAdminService) IntrospectToken(ctx context.Context, token string) (map[string]interface{}, error) {
endpoint := fmt.Sprintf("%s/admin/oauth2/introspect", strings.TrimRight(s.AdminURL, "/"))
data := url.Values{}
data.Set("token", token)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(data.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := s.HttpClient().Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return nil, fmt.Errorf("hydra admin: introspect failed status=%d body=%s", resp.StatusCode, string(body))
}
var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, err
}
return result, nil
}

View File

@@ -1,7 +1,6 @@
package service package service
import ( import (
"baron-sso-backend/internal/utils"
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
@@ -10,6 +9,7 @@ import (
"log/slog" "log/slog"
"net/http" "net/http"
"net/url" "net/url"
"os"
"time" "time"
) )
@@ -18,7 +18,6 @@ type KetoService interface {
CreateRelation(ctx context.Context, namespace, object, relation, subject string) error CreateRelation(ctx context.Context, namespace, object, relation, subject string) error
DeleteRelation(ctx context.Context, namespace, object, relation, subject string) error DeleteRelation(ctx context.Context, namespace, object, relation, subject string) error
ListRelations(ctx context.Context, namespace, object, relation, subject string) ([]RelationTuple, error) ListRelations(ctx context.Context, namespace, object, relation, subject string) ([]RelationTuple, error)
ListObjects(ctx context.Context, namespace, relation, subject string) ([]string, error)
} }
type ketoService struct { type ketoService struct {
@@ -28,8 +27,14 @@ type ketoService struct {
} }
func NewKetoService() KetoService { func NewKetoService() KetoService {
readURL := utils.GetEnv("KETO_READ_URL", "http://keto:4466") readURL := os.Getenv("KETO_READ_URL")
writeURL := utils.GetEnv("KETO_WRITE_URL", "http://keto:4467") if readURL == "" {
readURL = "http://keto:4466"
}
writeURL := os.Getenv("KETO_WRITE_URL")
if writeURL == "" {
writeURL = "http://keto:4467"
}
return &ketoService{ return &ketoService{
readURL: readURL, readURL: readURL,
@@ -187,40 +192,3 @@ func (s *ketoService) DeleteRelation(ctx context.Context, namespace, object, rel
slog.Info("Keto relation deleted", "namespace", namespace, "object", object, "relation", relation, "subject", subject) slog.Info("Keto relation deleted", "namespace", namespace, "object", object, "relation", relation, "subject", subject)
return nil return nil
} }
func (s *ketoService) ListObjects(ctx context.Context, namespace, relation, subject string) ([]string, error) {
u, _ := url.Parse(fmt.Sprintf("%s/relation-tuples", s.readURL))
q := u.Query()
q.Set("namespace", namespace)
q.Set("relation", relation)
q.Set("subject_id", subject)
u.RawQuery = q.Encode()
req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
resp, err := s.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("keto returned status %d: %s", resp.StatusCode, string(body))
}
var res relationTuplesResponse
if err := json.NewDecoder(resp.Body).Decode(&res); err != nil {
return nil, err
}
objects := make([]string, 0, len(res.RelationTuples))
seen := make(map[string]bool)
for _, rt := range res.RelationTuples {
if !seen[rt.Object] {
objects = append(objects, rt.Object)
seen[rt.Object] = true
}
}
return objects, nil
}

View File

@@ -1,7 +1,6 @@
package service package service
import ( import (
"baron-sso-backend/internal/utils"
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
@@ -29,7 +28,7 @@ type KratosAdminService struct {
func NewKratosAdminService() *KratosAdminService { func NewKratosAdminService() *KratosAdminService {
return &KratosAdminService{ return &KratosAdminService{
AdminURL: utils.GetEnv("KRATOS_ADMIN_URL", "http://kratos:4434"), AdminURL: getenvKratos("KRATOS_ADMIN_URL", "http://kratos:4434"),
} }
} }
@@ -228,9 +227,8 @@ func (s *KratosAdminService) httpClient() *http.Client {
} }
func getenvKratos(key, fallback string) string { func getenvKratos(key, fallback string) string {
v := os.Getenv(key) if v := os.Getenv(key); v != "" {
if v == "" { return v
return fallback
} }
return strings.Trim(v, "\"") return fallback
} }

View File

@@ -2,7 +2,6 @@ package service
import ( import (
"baron-sso-backend/internal/domain" "baron-sso-backend/internal/domain"
"baron-sso-backend/internal/utils"
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
@@ -28,9 +27,9 @@ type OryProvider struct {
func NewOryProvider() *OryProvider { func NewOryProvider() *OryProvider {
return &OryProvider{ return &OryProvider{
KratosAdminURL: utils.GetEnv("KRATOS_ADMIN_URL", "http://kratos:4434"), KratosAdminURL: getenv("KRATOS_ADMIN_URL", "http://kratos:4434"),
KratosPublicURL: utils.GetEnv("KRATOS_PUBLIC_URL", "http://kratos:4433"), KratosPublicURL: getenv("KRATOS_PUBLIC_URL", "http://kratos:4433"),
HydraAdminURL: utils.GetEnv("HYDRA_ADMIN_URL", "http://hydra:4445"), HydraAdminURL: getenv("HYDRA_ADMIN_URL", "http://hydra:4445"),
} }
} }
@@ -320,7 +319,6 @@ func (o *OryProvider) submitLoginCodeInit(loginID, returnTo string) (*domain.Lin
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
if resp.StatusCode >= 300 { if resp.StatusCode >= 300 {
slog.Warn("Ory link login init failed", "loginID", loginID, "flow_id", flowID, "status", resp.StatusCode, "body", string(respBody))
init, ok := parseKratosLinkLoginResponse(flowID, respBody) init, ok := parseKratosLinkLoginResponse(flowID, respBody)
if ok { if ok {
slog.Info("Ory link login initiated with non-2xx response", "loginID", loginID, "flow_id", flowID, "status", resp.StatusCode) slog.Info("Ory link login initiated with non-2xx response", "loginID", loginID, "flow_id", flowID, "status", resp.StatusCode)
@@ -729,12 +727,10 @@ func (o *OryProvider) UpdateUserPassword(loginID, newPassword string, r *http.Re
} }
func getenv(key, fallback string) string { func getenv(key, fallback string) string {
v := os.Getenv(key) if v := os.Getenv(key); v != "" {
if v == "" { return v
return fallback
} }
// Strip surrounding double quotes if present return fallback
return strings.Trim(v, "\"")
} }
// findIdentityID: Kratos Admin API에서 credentials_identifier로 검색 후 첫 번째 identity id 반환 // findIdentityID: Kratos Admin API에서 credentials_identifier로 검색 후 첫 번째 identity id 반환

View File

@@ -15,10 +15,6 @@ type RelyingPartyService interface {
ListByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.RelyingParty, error) ListByTenantIDs(ctx context.Context, tenantIDs []string) ([]domain.RelyingParty, error)
Update(ctx context.Context, clientID string, client domain.HydraClient) (*domain.RelyingParty, error) Update(ctx context.Context, clientID string, client domain.HydraClient) (*domain.RelyingParty, error)
Delete(ctx context.Context, clientID string) error Delete(ctx context.Context, clientID string) error
CheckPermission(ctx context.Context, userID, clientID, relation string) (bool, error)
AddOwner(ctx context.Context, clientID, subject string) error
RemoveOwner(ctx context.Context, clientID, subject string) error
ListOwners(ctx context.Context, clientID string) ([]string, error)
} }
type relyingPartyService struct { type relyingPartyService struct {
@@ -162,31 +158,6 @@ func (s *relyingPartyService) Delete(ctx context.Context, clientID string) error
return nil return nil
} }
func (s *relyingPartyService) CheckPermission(ctx context.Context, userID, clientID, relation string) (bool, error) {
return s.ketoService.CheckPermission(ctx, userID, "RelyingParty", clientID, relation)
}
func (s *relyingPartyService) AddOwner(ctx context.Context, clientID, subject string) error {
return s.ketoService.CreateRelation(ctx, "RelyingParty", clientID, "owners", subject)
}
func (s *relyingPartyService) RemoveOwner(ctx context.Context, clientID, subject string) error {
return s.ketoService.DeleteRelation(ctx, "RelyingParty", clientID, "owners", subject)
}
func (s *relyingPartyService) ListOwners(ctx context.Context, clientID string) ([]string, error) {
tuples, err := s.ketoService.ListRelations(ctx, "RelyingParty", clientID, "owners", "")
if err != nil {
return nil, err
}
subjects := make([]string, 0, len(tuples))
for _, t := range tuples {
subjects = append(subjects, t.SubjectID)
}
return subjects, nil
}
func (s *relyingPartyService) mapHydraToDomain(client *domain.HydraClient) *domain.RelyingParty { func (s *relyingPartyService) mapHydraToDomain(client *domain.HydraClient) *domain.RelyingParty {
if client == nil { if client == nil {
return nil return nil

View File

@@ -54,14 +54,6 @@ func (m *MockKetoService) ListRelations(ctx context.Context, namespace, object,
return args.Get(0).([]RelationTuple), args.Error(1) return args.Get(0).([]RelationTuple), args.Error(1)
} }
func (m *MockKetoService) ListObjects(ctx context.Context, namespace, relation, subject string) ([]string, error) {
args := m.Called(ctx, namespace, relation, subject)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).([]string), args.Error(1)
}
// --- Test Helpers --- // --- Test Helpers ---
type hydraRoundTripperFunc func(*http.Request) (*http.Response, error) type hydraRoundTripperFunc func(*http.Request) (*http.Response, error)

View File

@@ -18,12 +18,8 @@ type TenantService interface {
GetTenantByDomain(ctx context.Context, emailDomain string) (*domain.Tenant, error) GetTenantByDomain(ctx context.Context, emailDomain string) (*domain.Tenant, error)
GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error)
GetTenant(ctx context.Context, id string) (*domain.Tenant, error) GetTenant(ctx context.Context, id string) (*domain.Tenant, error)
ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error)
ApproveTenant(ctx context.Context, id string) error ApproveTenant(ctx context.Context, id string) error
SetKetoService(keto KetoService) // 추가 SetKetoService(keto KetoService) // 추가
AddTenantAdmin(ctx context.Context, tenantID, userID string) error
RemoveTenantAdmin(ctx context.Context, tenantID, userID string) error
ListTenantAdmins(ctx context.Context, tenantID string) ([]string, error)
} }
type tenantService struct { type tenantService struct {
@@ -43,60 +39,6 @@ func (s *tenantService) GetTenant(ctx context.Context, id string) (*domain.Tenan
return s.repo.FindByID(ctx, id) return s.repo.FindByID(ctx, id)
} }
func (s *tenantService) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
if s.keto == nil {
return nil, errors.New("keto service not initialized")
}
// 1. Get directly managed tenants
directTenantIDs, err := s.keto.ListObjects(ctx, "Tenant", "admins", userID)
if err != nil {
slog.Error("Failed to list directly managed tenants from Keto", "userID", userID, "error", err)
}
// 2. Get managed tenant groups
groupIDs, err := s.keto.ListObjects(ctx, "TenantGroup", "admins", userID)
if err != nil {
slog.Error("Failed to list managed tenant groups from Keto", "userID", userID, "error", err)
}
// 3. Get tenants belonging to those groups
var groupInheritedTenantIDs []string
for _, groupID := range groupIDs {
// In Keto, we defined: Tenant#parent_group@TenantGroup:GroupID#_
// To find tenants in a group, we look for relations where namespace=Tenant, relation=parent_group, subject=TenantGroup:GroupID#_
// Wait, my ListObjects lists objects given a subject.
// So subject="TenantGroup:"+groupID+"#_"
// Object is Tenant ID.
ts, err := s.keto.ListRelations(ctx, "Tenant", "", "parent_group", "TenantGroup:"+groupID)
if err == nil {
for _, t := range ts {
groupInheritedTenantIDs = append(groupInheritedTenantIDs, t.Object)
}
}
}
// Combine and deduplicate IDs
allIDsMap := make(map[string]bool)
for _, id := range directTenantIDs {
allIDsMap[id] = true
}
for _, id := range groupInheritedTenantIDs {
allIDsMap[id] = true
}
allIDs := make([]string, 0, len(allIDsMap))
for id := range allIDsMap {
allIDs = append(allIDs, id)
}
if len(allIDs) == 0 {
return []domain.Tenant{}, nil
}
return s.repo.FindByIDs(ctx, allIDs)
}
func (s *tenantService) RegisterTenant(ctx context.Context, name, slug, description string, domains []string) (*domain.Tenant, error) { func (s *tenantService) RegisterTenant(ctx context.Context, name, slug, description string, domains []string) (*domain.Tenant, error) {
// Validate Slug // Validate Slug
if ok, msg := utils.ValidateSlug(slug); !ok { if ok, msg := utils.ValidateSlug(slug); !ok {
@@ -211,35 +153,3 @@ func (s *tenantService) GetTenantByDomain(ctx context.Context, emailDomain strin
func (s *tenantService) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) { func (s *tenantService) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
return s.repo.FindBySlug(ctx, slug) return s.repo.FindBySlug(ctx, slug)
} }
func (s *tenantService) AddTenantAdmin(ctx context.Context, tenantID, userID string) error {
if s.keto == nil {
return errors.New("keto service not initialized")
}
return s.keto.CreateRelation(ctx, "Tenant", tenantID, "admins", "User:"+userID)
}
func (s *tenantService) RemoveTenantAdmin(ctx context.Context, tenantID, userID string) error {
if s.keto == nil {
return errors.New("keto service not initialized")
}
return s.keto.DeleteRelation(ctx, "Tenant", tenantID, "admins", "User:"+userID)
}
func (s *tenantService) ListTenantAdmins(ctx context.Context, tenantID string) ([]string, error) {
if s.keto == nil {
return nil, errors.New("keto service not initialized")
}
tuples, err := s.keto.ListRelations(ctx, "Tenant", tenantID, "admins", "")
if err != nil {
return nil, err
}
userIDs := make([]string, 0, len(tuples))
for _, t := range tuples {
if len(t.SubjectID) > 5 && t.SubjectID[:5] == "User:" {
userIDs = append(userIDs, t.SubjectID[5:])
}
}
return userIDs, nil
}