1
0
forked from baron/baron-sso

feat(backend): implement dynamic multi-tenancy routing and CORS

This commit is contained in:
2026-03-03 15:27:05 +09:00
parent a6e7f1253c
commit 5423f920b7
6 changed files with 277 additions and 34 deletions

View File

@@ -345,12 +345,41 @@ func main() {
app.Use(middleware.ErrorCodeEnricher())
allowedOrigins := getEnv("CORS_ALLOWED_ORIGINS", "http://localhost:5000")
allowCredentials := allowedOrigins != "*"
userfrontURL := getEnv("USERFRONT_URL", "http://sso.hmac.kr")
baseDomain := ""
if u, err := url.Parse(userfrontURL); err == nil {
baseDomain = u.Hostname()
}
app.Use(cors.New(cors.Config{
AllowOrigins: allowedOrigins,
AllowHeaders: "Origin, Content-Type, Accept, Authorization",
AllowOriginsFunc: func(origin string) bool {
// 1. Check static allowed list
for _, allowed := range strings.Split(allowedOrigins, ",") {
if origin == strings.TrimSpace(allowed) {
return true
}
}
// Parse origin URL
u, err := url.Parse(origin)
if err != nil {
return false
}
hostname := u.Hostname()
// 2. Check subdomains of base domain
if baseDomain != "" && (hostname == baseDomain || strings.HasSuffix(hostname, "."+baseDomain)) {
return true
}
// 3. Check registered tenant domains
// Use context.Background() as we don't have request context here easily
allowed, _ := tenantService.IsDomainAllowed(context.Background(), hostname)
return allowed
},
AllowHeaders: "Origin, Content-Type, Accept, Authorization, X-Test-Role, X-Mock-Role, X-Tenant-ID",
AllowMethods: "GET, POST, HEAD, PUT, DELETE, PATCH, OPTIONS",
AllowCredentials: allowCredentials,
AllowCredentials: true,
}))
// Ensure COOKIE_SECRET is exactly 32 bytes for AES-256
@@ -482,6 +511,11 @@ func main() {
// Public Tenant Registration
api.Post("/tenants/registration", tenantHandler.RegisterTenantPublic)
// Tenant Context Middleware (identifies tenant from Host header)
api.Use(middleware.TenantContextMiddleware(middleware.TenantContextConfig{
TenantService: tenantService,
}))
// Auth Proxy Routes
auth := api.Group("/auth")
auth.Post("/enchanted-link/init", authHandler.InitEnchantedLink)
@@ -490,21 +524,13 @@ func main() {
auth.Post("/login/code/verify", authHandler.VerifyLoginCode)
auth.Post("/login/code/verify-short", authHandler.VerifyLoginShortCode)
auth.Post("/password/login", authHandler.PasswordLogin)
auth.Get("/tenant-info", authHandler.GetTenantInfo)
auth.Get("/consent", authHandler.GetConsentRequest)
auth.Post("/consent/accept", authHandler.AcceptConsentRequest)
auth.Post("/consent/reject", authHandler.RejectConsentRequest)
auth.Post("/oidc/login/accept", authHandler.AcceptOidcLoginRequest)
auth.Post("/enchanted-link/init", authHandler.InitEnchantedLink)
auth.Post("/enchanted-link/poll", authHandler.PollEnchantedLink)
auth.Post("/magic-link/verify", authHandler.VerifyMagicLink)
auth.Post("/login/code/verify", authHandler.VerifyLoginCode)
auth.Post("/login/code/verify-short", authHandler.VerifyLoginShortCode)
auth.Post("/password/login", authHandler.PasswordLogin)
auth.Get("/consent", authHandler.GetConsentRequest)
auth.Post("/consent/accept", authHandler.AcceptConsentRequest)
auth.Post("/password/reset/initiate", authHandler.InitiatePasswordReset)
// [Changed] Use Interstitial Page for GET to prevent Scanner consumption
auth.Get("/password/reset/verify", authHandler.VerifyPasswordResetPage)

View File

@@ -538,6 +538,55 @@ func (h *AuthHandler) getBearerToken(c *fiber.Ctx) string {
return parts[1]
}
func (h *AuthHandler) resolveUserfrontURL(c *fiber.Ctx) string {
// 1. Try to use the Host header from the request
host := c.Get("X-Forwarded-Host")
if host == "" {
host = c.Hostname()
}
// 2. Determine scheme
scheme := "https"
if os.Getenv("APP_ENV") == "dev" || os.Getenv("APP_ENV") == "" || c.Protocol() == "http" {
scheme = "http"
}
// 3. Fallback to env if host is not available or is localhost (and not in dev)
envURL := os.Getenv("USERFRONT_URL")
if envURL == "" {
envURL = "http://sso.hmac.kr"
}
if host == "" || (host == "localhost" && os.Getenv("APP_ENV") != "dev") {
return strings.TrimRight(envURL, "/")
}
return fmt.Sprintf("%s://%s", scheme, host)
}
func (h *AuthHandler) GetTenantInfo(c *fiber.Ctx) error {
tenantID, _ := c.Locals("tenant_id").(string)
if tenantID == "" {
return c.JSON(fiber.Map{
"isCentral": true,
})
}
tenant, err := h.TenantService.GetTenant(c.Context(), tenantID)
if err != nil {
return errorJSON(c, fiber.StatusNotFound, "Tenant not found")
}
return c.JSON(fiber.Map{
"isCentral": false,
"id": tenant.ID,
"name": tenant.Name,
"slug": tenant.Slug,
"description": tenant.Description,
"type": tenant.Type,
})
}
// normalizePhoneForLoginID는 전화번호를 IDP 조회에 적합한 형태(E.164)로 정규화합니다.
func normalizePhoneForLoginID(phone string) string {
normalized := strings.ReplaceAll(phone, "-", "")
@@ -920,10 +969,7 @@ func (h *AuthHandler) InitEnchantedLink(c *fiber.Ctx) error {
return errorJSON(c, fiber.StatusNotFound, "User not registered")
}
userfrontURL := os.Getenv("USERFRONT_URL")
if userfrontURL == "" {
userfrontURL = "http://sso.hmac.kr"
}
userfrontURL := h.resolveUserfrontURL(c)
if req.URI != "" {
userfrontURL = req.URI
}
@@ -1692,14 +1738,7 @@ func (h *AuthHandler) InitiatePasswordReset(c *fiber.Ctx) error {
return errorJSON(c, fiber.StatusInternalServerError, "Authentication service not configured")
}
userfrontURL := os.Getenv("USERFRONT_URL")
if userfrontURL == "" {
ale.Status = fiber.StatusInternalServerError
ale.LatencyMs = time.Since(startTime)
ale.ProviderError = "USERFRONT_URL is not set"
ale.Log(slog.LevelError, "USERFRONT_URL is not set")
return errorJSON(c, fiber.StatusInternalServerError, "USERFRONT_URL environment variable is not set")
}
userfrontURL := h.resolveUserfrontURL(c)
// [Changed] Point to Backend API for verification (which then redirects to Frontend)
redirectURL := fmt.Sprintf("%s/api/v1/auth/password/reset/verify", userfrontURL)
ale.RedirectTo = redirectURL
@@ -1863,10 +1902,7 @@ func (h *AuthHandler) ProcessPasswordResetToken(c *fiber.Ctx) error {
ale.LoginIDs["loginId"] = loginID
ale.LoginIDs["loginId_normalized"] = loginID
userfrontURL := strings.TrimRight(os.Getenv("USERFRONT_URL"), "/")
if userfrontURL == "" {
userfrontURL = "https://sso.hmac.kr"
}
userfrontURL := h.resolveUserfrontURL(c)
redirectBase, parseErr := url.Parse(userfrontURL + "/reset-password")
if parseErr != nil {
ale.Status = fiber.StatusInternalServerError
@@ -2000,10 +2036,7 @@ func (h *AuthHandler) InitQRLogin(c *fiber.Ctx) error {
}
// QR 코드 페이로드를 실제 접속 가능한 URL로 변경합니다.
userfrontURL := os.Getenv("USERFRONT_URL")
if userfrontURL == "" {
userfrontURL = "https://sso.hmac.kr"
}
userfrontURL := h.resolveUserfrontURL(c)
qrPayload := fmt.Sprintf("%s/ql/%s", strings.TrimRight(userfrontURL, "/"), qrRef)
slog.Info("[QR] Init", "pendingRef", pendingRef, "qrRef", qrRef, "url", qrPayload)

View File

@@ -123,7 +123,7 @@ func TestRequireKetoPermission_Success(t *testing.T) {
profile := &domain.UserProfileResponse{ID: "user1", Role: "user"}
mockAuth.On("GetEnrichedProfile", mock.Anything).Return(profile, nil)
mockKeto.On("CheckPermission", mock.Anything, "user1", "tenants", "tenant1", "read").Return(true, nil)
mockKeto.On("CheckPermission", mock.Anything, "User:user1", "tenants", "tenant1", "read").Return(true, nil)
app.Get("/tenants/:id", RequireKetoPermission(config, "tenants", "read"), func(c *fiber.Ctx) error {
return c.SendString("ok")

View File

@@ -0,0 +1,69 @@
package middleware
import (
"baron-sso-backend/internal/service"
"log/slog"
"net/url"
"os"
"strings"
"github.com/gofiber/fiber/v2"
)
// TenantContextConfig defines the configuration for Tenant context middleware
type TenantContextConfig struct {
TenantService service.TenantService
}
// TenantContextMiddleware identifies the tenant based on the Host header (subdomain or custom domain)
func TenantContextMiddleware(config TenantContextConfig) fiber.Handler {
userfrontURL := os.Getenv("USERFRONT_URL")
baseDomain := ""
if userfrontURL != "" {
if u, err := url.Parse(userfrontURL); err == nil {
baseDomain = u.Hostname()
}
}
return func(c *fiber.Ctx) error {
hostname := c.Hostname()
if hostname == "" {
hostname = string(c.Request().Header.Host())
}
if hostname == "" {
return c.Next()
}
// 1. If it's the exact base domain or localhost, no specific tenant context from host
if hostname == baseDomain || hostname == "localhost" || hostname == "127.0.0.1" {
return c.Next()
}
// 2. Try to find by registered custom domain
tenant, err := config.TenantService.GetTenantByDomain(c.Context(), hostname)
if err == nil && tenant != nil {
slog.Debug("Tenant identified by custom domain", "hostname", hostname, "tenantID", tenant.ID)
c.Locals("tenant_id", tenant.ID)
c.Locals("tenant_slug", tenant.Slug)
return c.Next()
}
// 3. Try to find by subdomain (slug.baseDomain)
if baseDomain != "" && strings.HasSuffix(hostname, "."+baseDomain) {
slug := strings.TrimSuffix(hostname, "."+baseDomain)
// Handle cases like "www.sso.hmac.kr" if baseDomain is "sso.hmac.kr"
if slug != "" && slug != "www" {
tenant, err := config.TenantService.GetTenantBySlug(c.Context(), slug)
if err == nil && tenant != nil {
slog.Debug("Tenant identified by subdomain slug", "slug", slug, "tenantID", tenant.ID)
c.Locals("tenant_id", tenant.ID)
c.Locals("tenant_slug", tenant.Slug)
return c.Next()
}
}
}
return c.Next()
}
}

View File

@@ -0,0 +1,103 @@
package middleware
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"context"
"log/slog"
"net/http/httptest"
"os"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
type MockTenantServiceForMiddleware struct {
mock.Mock
}
func (m *MockTenantServiceForMiddleware) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string, creatorID string) (*domain.Tenant, error) {
return nil, nil
}
func (m *MockTenantServiceForMiddleware) RequestRegistration(ctx context.Context, name, slug, description string, domainName string, adminEmail string) (*domain.Tenant, error) {
return nil, nil
}
func (m *MockTenantServiceForMiddleware) GetTenantByDomain(ctx context.Context, emailDomain string) (*domain.Tenant, error) {
args := m.Called(mock.Anything, emailDomain)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.Tenant), args.Error(1)
}
func (m *MockTenantServiceForMiddleware) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
args := m.Called(mock.Anything, slug)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.Tenant), args.Error(1)
}
func (m *MockTenantServiceForMiddleware) GetTenant(ctx context.Context, id string) (*domain.Tenant, error) {
return nil, nil
}
func (m *MockTenantServiceForMiddleware) ListTenants(ctx context.Context, limit, offset int, parentID string) ([]domain.Tenant, int64, error) {
return nil, 0, nil
}
func (m *MockTenantServiceForMiddleware) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
return nil, nil
}
func (m *MockTenantServiceForMiddleware) IsDomainAllowed(ctx context.Context, domainName string) (bool, error) {
return false, nil
}
func (m *MockTenantServiceForMiddleware) ApproveTenant(ctx context.Context, id string) error {
return nil
}
func (m *MockTenantServiceForMiddleware) SetKetoService(keto service.KetoService) {}
func TestTenantContextMiddleware(t *testing.T) {
os.Setenv("USERFRONT_URL", "https://sso.hmac.kr")
defer os.Unsetenv("USERFRONT_URL")
mockSvc := new(MockTenantServiceForMiddleware)
app := fiber.New()
app.Use(TenantContextMiddleware(TenantContextConfig{TenantService: mockSvc}))
app.Get("/test", func(c *fiber.Ctx) error {
return c.JSON(fiber.Map{
"tenant_id": c.Locals("tenant_id"),
"tenant_slug": c.Locals("tenant_slug"),
})
})
t.Run("Base Domain - No Tenant Context", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://sso.hmac.kr/test", nil)
req.Host = "sso.hmac.kr"
resp, _ := app.Test(req)
assert.Equal(t, 200, resp.StatusCode)
mockSvc.AssertNotCalled(t, "GetTenantByDomain", mock.Anything, "sso.hmac.kr")
})
t.Run("Subdomain - Identify by Slug", func(t *testing.T) {
mockSvc.On("GetTenantByDomain", mock.Anything, "tenant1.sso.hmac.kr").Return(nil, nil).Once()
mockSvc.On("GetTenantBySlug", mock.Anything, "tenant1").Return(&domain.Tenant{ID: "t1", Slug: "tenant1"}, nil).Once()
req := httptest.NewRequest("GET", "/test", nil)
req.Host = "tenant1.sso.hmac.kr"
req.Header.Set("Host", "tenant1.sso.hmac.kr")
resp, _ := app.Test(req)
assert.Equal(t, 200, resp.StatusCode)
mockSvc.AssertExpectations(t)
})
t.Run("Custom Domain - Identify by Domain", func(t *testing.T) {
mockSvc.On("GetTenantByDomain", mock.Anything, "company.com").Return(&domain.Tenant{ID: "t2", Slug: "company"}, nil).Once()
req := httptest.NewRequest("GET", "/test", nil)
req.Host = "company.com"
req.Header.Set("Host", "company.com")
resp, _ := app.Test(req)
assert.Equal(t, 200, resp.StatusCode)
mockSvc.AssertExpectations(t)
})
}

View File

@@ -20,6 +20,7 @@ type TenantService interface {
GetTenant(ctx context.Context, id string) (*domain.Tenant, error)
ListTenants(ctx context.Context, limit, offset int, parentID string) ([]domain.Tenant, int64, error)
ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error)
IsDomainAllowed(ctx context.Context, domainName string) (bool, error)
ApproveTenant(ctx context.Context, id string) error
SetKetoService(keto KetoService) // 추가
}
@@ -282,3 +283,14 @@ func (s *tenantService) ListTenants(ctx context.Context, limit, offset int, pare
// Let the repository handle the query and pagination
return s.repo.List(ctx, limit, offset, parentID)
}
func (s *tenantService) IsDomainAllowed(ctx context.Context, domainName string) (bool, error) {
tenant, err := s.repo.FindByDomain(ctx, domainName)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return false, nil
}
return false, err
}
return tenant != nil && tenant.Status == domain.TenantStatusActive, nil
}