forked from baron/baron-sso
테넌트 등록 방식을 결정
This commit is contained in:
@@ -3,6 +3,7 @@ package handler
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/logger"
|
||||
"baron-sso-backend/internal/repository"
|
||||
"baron-sso-backend/internal/service"
|
||||
"baron-sso-backend/internal/utils"
|
||||
"bytes"
|
||||
@@ -78,6 +79,8 @@ type AuthHandler struct {
|
||||
IdpProvider domain.IdentityProvider
|
||||
AuditRepo domain.AuditRepository
|
||||
Hydra *service.HydraAdminService
|
||||
TenantService service.TenantService
|
||||
UserRepo repository.UserRepository
|
||||
}
|
||||
|
||||
type signupState struct {
|
||||
@@ -135,7 +138,7 @@ func checkPollInterval(redis *service.RedisService, key string, interval time.Du
|
||||
return false, int(interval.Seconds())
|
||||
}
|
||||
|
||||
func NewAuthHandler(redisService *service.RedisService, idpProvider domain.IdentityProvider, auditRepo domain.AuditRepository) *AuthHandler {
|
||||
func NewAuthHandler(redisService *service.RedisService, idpProvider domain.IdentityProvider, auditRepo domain.AuditRepository, tenantService service.TenantService, userRepo repository.UserRepository) *AuthHandler {
|
||||
projectID := os.Getenv("DESCOPE_PROJECT_ID")
|
||||
managementKey := os.Getenv("DESCOPE_MANAGEMENT_KEY")
|
||||
|
||||
@@ -161,6 +164,8 @@ func NewAuthHandler(redisService *service.RedisService, idpProvider domain.Ident
|
||||
IdpProvider: idpProvider,
|
||||
AuditRepo: auditRepo,
|
||||
Hydra: service.NewHydraAdminService(),
|
||||
TenantService: tenantService,
|
||||
UserRepo: userRepo,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -357,6 +362,9 @@ func (h *AuthHandler) Signup(c *fiber.Ctx) error {
|
||||
if req.Email == "" || req.Password == "" || req.Name == "" || req.Phone == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Missing required fields"})
|
||||
}
|
||||
if !strings.Contains(req.Email, "@") || !strings.Contains(req.Email, ".") {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid email format"})
|
||||
}
|
||||
if !req.TermsAccepted {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Terms must be accepted"})
|
||||
}
|
||||
@@ -385,6 +393,20 @@ func (h *AuthHandler) Signup(c *fiber.Ctx) error {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Identity provider unavailable"})
|
||||
}
|
||||
|
||||
// [New] Auto-Assign Tenant by Domain
|
||||
companyCode := req.CompanyCode
|
||||
if companyCode == "" {
|
||||
parts := strings.Split(req.Email, "@")
|
||||
if len(parts) == 2 {
|
||||
domainName := parts[1]
|
||||
tenant, err := h.TenantService.GetTenantByDomain(c.Context(), domainName)
|
||||
if err == nil && tenant != nil {
|
||||
slog.Info("[Signup] Auto-assigning tenant", "email", req.Email, "tenant", tenant.Slug)
|
||||
companyCode = tenant.Slug
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize Phone (E.164 형태로 보관)
|
||||
normalizedPhone := strings.ReplaceAll(req.Phone, "-", "")
|
||||
normalizedPhone = strings.ReplaceAll(normalizedPhone, " ", "")
|
||||
@@ -398,7 +420,7 @@ func (h *AuthHandler) Signup(c *fiber.Ctx) error {
|
||||
attributes := map[string]interface{}{
|
||||
"department": req.Department,
|
||||
"affiliationType": req.AffiliationType,
|
||||
"companyCode": req.CompanyCode,
|
||||
"companyCode": companyCode,
|
||||
// grade는 기존 스키마 필수 키이므로 기본값을 설정
|
||||
"grade": "member",
|
||||
}
|
||||
@@ -427,6 +449,31 @@ func (h *AuthHandler) Signup(c *fiber.Ctx) error {
|
||||
|
||||
slog.Info("[Signup] New user registered", "email", req.Email, "type", req.AffiliationType, "provider", h.IdpProvider.Name(), "subject", providerID)
|
||||
|
||||
// [New] Local DB Sync
|
||||
localUser := &domain.User{
|
||||
ID: providerID, // Match IDP Subject
|
||||
Email: req.Email,
|
||||
Name: req.Name,
|
||||
Phone: normalizedPhone,
|
||||
AffiliationType: req.AffiliationType,
|
||||
CompanyCode: companyCode,
|
||||
Department: req.Department,
|
||||
Role: "user",
|
||||
Status: "active",
|
||||
}
|
||||
|
||||
// Link TenantID if possible
|
||||
if companyCode != "" {
|
||||
if tenant, err := h.TenantService.GetTenantBySlug(c.Context(), companyCode); err == nil && tenant != nil {
|
||||
localUser.TenantID = &tenant.ID
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.UserRepo.Create(c.Context(), localUser); err != nil {
|
||||
slog.Error("[Signup] Failed to sync user to local DB", "email", req.Email, "error", err)
|
||||
// We don't fail the whole signup if local sync fails
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
"success": true,
|
||||
"message": "User registered successfully",
|
||||
@@ -2057,6 +2104,13 @@ func (h *AuthHandler) GetMe(c *fiber.Ctx) error {
|
||||
AffiliationType: affType,
|
||||
CompanyCode: compCode,
|
||||
}
|
||||
|
||||
if compCode != "" {
|
||||
if tenant, err := h.TenantService.GetTenantBySlug(c.Context(), compCode); err == nil && tenant != nil {
|
||||
resp.Tenant = tenant
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -11,21 +12,23 @@ import (
|
||||
)
|
||||
|
||||
type TenantHandler struct {
|
||||
DB *gorm.DB
|
||||
DB *gorm.DB
|
||||
Service service.TenantService
|
||||
}
|
||||
|
||||
func NewTenantHandler(db *gorm.DB) *TenantHandler {
|
||||
return &TenantHandler{DB: db}
|
||||
func NewTenantHandler(db *gorm.DB, svc service.TenantService) *TenantHandler {
|
||||
return &TenantHandler{DB: db, Service: svc}
|
||||
}
|
||||
|
||||
type tenantSummary struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
Description string `json:"description"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt string `json:"createdAt"`
|
||||
UpdatedAt string `json:"updatedAt"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
Description string `json:"description"`
|
||||
Status string `json:"status"`
|
||||
Domains []string `json:"domains,omitempty"`
|
||||
CreatedAt string `json:"createdAt"`
|
||||
UpdatedAt string `json:"updatedAt"`
|
||||
}
|
||||
|
||||
type tenantListResponse struct {
|
||||
@@ -55,7 +58,7 @@ func (h *TenantHandler) ListTenants(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
var tenants []domain.Tenant
|
||||
if err := h.DB.Order("created_at desc").Limit(limit).Offset(offset).Find(&tenants).Error; err != nil {
|
||||
if err := h.DB.Order("created_at desc").Limit(limit).Offset(offset).Preload("Domains").Find(&tenants).Error; err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
@@ -78,7 +81,7 @@ func (h *TenantHandler) GetTenant(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
var tenant domain.Tenant
|
||||
if err := h.DB.First(&tenant, "id = ?", tenantID).Error; err != nil {
|
||||
if err := h.DB.Preload("Domains").First(&tenant, "id = ?", tenantID).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "tenant not found"})
|
||||
}
|
||||
@@ -94,10 +97,11 @@ func (h *TenantHandler) CreateTenant(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
Description string `json:"description"`
|
||||
Status string `json:"status"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
Description string `json:"description"`
|
||||
Status string `json:"status"`
|
||||
Domains []string `json:"domains"`
|
||||
}
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid request body"})
|
||||
@@ -121,25 +125,16 @@ func (h *TenantHandler) CreateTenant(c *fiber.Ctx) error {
|
||||
status = "active"
|
||||
}
|
||||
|
||||
var exists domain.Tenant
|
||||
if err := h.DB.Unscoped().Where("slug = ?", slug).First(&exists).Error; err == nil {
|
||||
return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": "slug already exists"})
|
||||
} else if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// Use Service
|
||||
tenant, err := h.Service.RegisterTenant(c.Context(), name, slug, req.Description, req.Domains)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "already exists") {
|
||||
return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
tenant := domain.Tenant{
|
||||
Name: name,
|
||||
Slug: slug,
|
||||
Description: strings.TrimSpace(req.Description),
|
||||
Status: status,
|
||||
}
|
||||
|
||||
if err := h.DB.Create(&tenant).Error; err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusCreated).JSON(mapTenantSummary(tenant))
|
||||
return c.Status(fiber.StatusCreated).JSON(mapTenantSummary(*tenant))
|
||||
}
|
||||
|
||||
func (h *TenantHandler) UpdateTenant(c *fiber.Ctx) error {
|
||||
@@ -161,10 +156,11 @@ func (h *TenantHandler) UpdateTenant(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Name *string `json:"name"`
|
||||
Slug *string `json:"slug"`
|
||||
Description *string `json:"description"`
|
||||
Status *string `json:"status"`
|
||||
Name *string `json:"name"`
|
||||
Slug *string `json:"slug"`
|
||||
Description *string `json:"description"`
|
||||
Status *string `json:"status"`
|
||||
Domains []string `json:"domains"`
|
||||
}
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid request body"})
|
||||
@@ -207,6 +203,32 @@ func (h *TenantHandler) UpdateTenant(c *fiber.Ctx) error {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
// Update domains if provided
|
||||
if req.Domains != nil {
|
||||
// Simple approach: Delete existing and recreate
|
||||
if err := h.DB.Delete(&domain.TenantDomain{}, "tenant_id = ?", tenant.ID).Error; err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "failed to clear old domains"})
|
||||
}
|
||||
for _, d := range req.Domains {
|
||||
if strings.TrimSpace(d) == "" {
|
||||
continue
|
||||
}
|
||||
td := domain.TenantDomain{
|
||||
TenantID: tenant.ID,
|
||||
Domain: strings.TrimSpace(d),
|
||||
Verified: true,
|
||||
}
|
||||
if err := h.DB.Create(&td).Error; err != nil {
|
||||
// Log and continue or return error?
|
||||
// For now return error to be safe.
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "failed to add domain: " + d})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Refetch to get updated relations
|
||||
h.DB.Preload("Domains").First(&tenant, "id = ?", tenant.ID)
|
||||
|
||||
return c.JSON(mapTenantSummary(tenant))
|
||||
}
|
||||
|
||||
@@ -228,12 +250,18 @@ func (h *TenantHandler) DeleteTenant(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
func mapTenantSummary(t domain.Tenant) tenantSummary {
|
||||
domains := make([]string, 0, len(t.Domains))
|
||||
for _, d := range t.Domains {
|
||||
domains = append(domains, d.Domain)
|
||||
}
|
||||
|
||||
return tenantSummary{
|
||||
ID: t.ID,
|
||||
Name: t.Name,
|
||||
Slug: t.Slug,
|
||||
Description: t.Description,
|
||||
Status: t.Status,
|
||||
Domains: domains,
|
||||
CreatedAt: t.CreatedAt.Format(time.RFC3339),
|
||||
UpdatedAt: t.UpdatedAt.Format(time.RFC3339),
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"baron-sso-backend/internal/utils"
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -11,29 +13,32 @@ import (
|
||||
)
|
||||
|
||||
type UserHandler struct {
|
||||
KratosAdmin *service.KratosAdminService
|
||||
OryProvider *service.OryProvider
|
||||
KratosAdmin *service.KratosAdminService
|
||||
OryProvider *service.OryProvider
|
||||
TenantService service.TenantService
|
||||
}
|
||||
|
||||
func NewUserHandler(kratosAdmin *service.KratosAdminService, oryProvider *service.OryProvider) *UserHandler {
|
||||
func NewUserHandler(kratosAdmin *service.KratosAdminService, oryProvider *service.OryProvider, tenantService service.TenantService) *UserHandler {
|
||||
return &UserHandler{
|
||||
KratosAdmin: kratosAdmin,
|
||||
OryProvider: oryProvider,
|
||||
KratosAdmin: kratosAdmin,
|
||||
OryProvider: oryProvider,
|
||||
TenantService: tenantService,
|
||||
}
|
||||
}
|
||||
|
||||
type userSummary struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Phone string `json:"phone"`
|
||||
Role string `json:"role"`
|
||||
Status string `json:"status"`
|
||||
CompanyCode string `json:"companyCode"`
|
||||
Department string `json:"department"`
|
||||
CreatedAt string `json:"createdAt"`
|
||||
UpdatedAt string `json:"updatedAt"`
|
||||
InitialPassword string `json:"initialPassword,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Phone string `json:"phone"`
|
||||
Role string `json:"role"`
|
||||
Status string `json:"status"`
|
||||
CompanyCode string `json:"companyCode"`
|
||||
Tenant *domain.Tenant `json:"tenant,omitempty"`
|
||||
Department string `json:"department"`
|
||||
CreatedAt string `json:"createdAt"`
|
||||
UpdatedAt string `json:"updatedAt"`
|
||||
InitialPassword string `json:"initialPassword,omitempty"`
|
||||
}
|
||||
|
||||
type userListResponse struct {
|
||||
@@ -89,7 +94,8 @@ func (h *UserHandler) ListUsers(c *fiber.Ctx) error {
|
||||
|
||||
items := make([]userSummary, 0, end-offset)
|
||||
for _, identity := range filtered[offset:end] {
|
||||
items = append(items, mapIdentitySummary(identity))
|
||||
summary := h.mapIdentitySummary(c.Context(), identity)
|
||||
items = append(items, summary)
|
||||
}
|
||||
|
||||
return c.JSON(userListResponse{Items: items, Limit: limit, Offset: offset, Total: total})
|
||||
@@ -113,7 +119,7 @@ func (h *UserHandler) GetUser(c *fiber.Ctx) error {
|
||||
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "user not found"})
|
||||
}
|
||||
|
||||
return c.JSON(mapIdentitySummary(*identity))
|
||||
return c.JSON(h.mapIdentitySummary(c.Context(), *identity))
|
||||
}
|
||||
|
||||
func (h *UserHandler) CreateUser(c *fiber.Ctx) error {
|
||||
@@ -138,6 +144,9 @@ func (h *UserHandler) CreateUser(c *fiber.Ctx) error {
|
||||
if email == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "email is required"})
|
||||
}
|
||||
if !strings.Contains(email, "@") || !strings.Contains(email, ".") {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid email format"})
|
||||
}
|
||||
name := strings.TrimSpace(req.Name)
|
||||
if name == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "name is required"})
|
||||
@@ -205,7 +214,7 @@ func (h *UserHandler) CreateUser(c *fiber.Ctx) error {
|
||||
return c.Status(fiber.StatusCreated).JSON(fiber.Map{"id": identityID, "initialPassword": generatedPassword})
|
||||
}
|
||||
|
||||
response := mapIdentitySummary(*identity)
|
||||
response := h.mapIdentitySummary(c.Context(), *identity)
|
||||
if generatedPassword != "" {
|
||||
response.InitialPassword = generatedPassword
|
||||
}
|
||||
@@ -279,7 +288,7 @@ func (h *UserHandler) UpdateUser(c *fiber.Ctx) error {
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(mapIdentitySummary(*updated))
|
||||
return c.JSON(h.mapIdentitySummary(c.Context(), *updated))
|
||||
}
|
||||
|
||||
func (h *UserHandler) DeleteUser(c *fiber.Ctx) error {
|
||||
@@ -299,24 +308,35 @@ func (h *UserHandler) DeleteUser(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusNoContent)
|
||||
}
|
||||
|
||||
func mapIdentitySummary(identity service.KratosIdentity) userSummary {
|
||||
func (h *UserHandler) mapIdentitySummary(ctx context.Context, identity service.KratosIdentity) userSummary {
|
||||
traits := identity.Traits
|
||||
role := extractTraitString(traits, "grade")
|
||||
if role == "" {
|
||||
role = "user"
|
||||
}
|
||||
return userSummary{
|
||||
|
||||
compCode := extractTraitString(traits, "companyCode")
|
||||
slog.Debug("Mapping identity", "email", extractTraitString(traits, "email"), "compCode", compCode)
|
||||
summary := userSummary{
|
||||
ID: identity.ID,
|
||||
Email: extractTraitString(traits, "email"),
|
||||
Name: extractTraitString(traits, "name"),
|
||||
Phone: extractTraitString(traits, "phone_number"),
|
||||
Role: role,
|
||||
Status: normalizeStatus(identity.State),
|
||||
CompanyCode: extractTraitString(traits, "companyCode"),
|
||||
CompanyCode: compCode,
|
||||
Department: extractTraitString(traits, "department"),
|
||||
CreatedAt: formatTime(identity.CreatedAt),
|
||||
UpdatedAt: formatTime(identity.UpdatedAt),
|
||||
}
|
||||
|
||||
if compCode != "" && h.TenantService != nil {
|
||||
if tenant, err := h.TenantService.GetTenantBySlug(ctx, compCode); err == nil && tenant != nil {
|
||||
summary.Tenant = tenant
|
||||
}
|
||||
}
|
||||
|
||||
return summary
|
||||
}
|
||||
|
||||
func extractTraitString(traits map[string]interface{}, key string) string {
|
||||
|
||||
Reference in New Issue
Block a user