forked from baron/baron-sso
feat(backend): implement dynamic multi-tenancy routing and CORS
This commit is contained in:
@@ -345,12 +345,41 @@ func main() {
|
||||
app.Use(middleware.ErrorCodeEnricher())
|
||||
|
||||
allowedOrigins := getEnv("CORS_ALLOWED_ORIGINS", "http://localhost:5000")
|
||||
allowCredentials := allowedOrigins != "*"
|
||||
userfrontURL := getEnv("USERFRONT_URL", "http://sso.hmac.kr")
|
||||
baseDomain := ""
|
||||
if u, err := url.Parse(userfrontURL); err == nil {
|
||||
baseDomain = u.Hostname()
|
||||
}
|
||||
|
||||
app.Use(cors.New(cors.Config{
|
||||
AllowOrigins: allowedOrigins,
|
||||
AllowHeaders: "Origin, Content-Type, Accept, Authorization",
|
||||
AllowOriginsFunc: func(origin string) bool {
|
||||
// 1. Check static allowed list
|
||||
for _, allowed := range strings.Split(allowedOrigins, ",") {
|
||||
if origin == strings.TrimSpace(allowed) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Parse origin URL
|
||||
u, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
hostname := u.Hostname()
|
||||
|
||||
// 2. Check subdomains of base domain
|
||||
if baseDomain != "" && (hostname == baseDomain || strings.HasSuffix(hostname, "."+baseDomain)) {
|
||||
return true
|
||||
}
|
||||
|
||||
// 3. Check registered tenant domains
|
||||
// Use context.Background() as we don't have request context here easily
|
||||
allowed, _ := tenantService.IsDomainAllowed(context.Background(), hostname)
|
||||
return allowed
|
||||
},
|
||||
AllowHeaders: "Origin, Content-Type, Accept, Authorization, X-Test-Role, X-Mock-Role, X-Tenant-ID",
|
||||
AllowMethods: "GET, POST, HEAD, PUT, DELETE, PATCH, OPTIONS",
|
||||
AllowCredentials: allowCredentials,
|
||||
AllowCredentials: true,
|
||||
}))
|
||||
|
||||
// Ensure COOKIE_SECRET is exactly 32 bytes for AES-256
|
||||
@@ -482,6 +511,11 @@ func main() {
|
||||
// Public Tenant Registration
|
||||
api.Post("/tenants/registration", tenantHandler.RegisterTenantPublic)
|
||||
|
||||
// Tenant Context Middleware (identifies tenant from Host header)
|
||||
api.Use(middleware.TenantContextMiddleware(middleware.TenantContextConfig{
|
||||
TenantService: tenantService,
|
||||
}))
|
||||
|
||||
// Auth Proxy Routes
|
||||
auth := api.Group("/auth")
|
||||
auth.Post("/enchanted-link/init", authHandler.InitEnchantedLink)
|
||||
@@ -490,21 +524,13 @@ func main() {
|
||||
auth.Post("/login/code/verify", authHandler.VerifyLoginCode)
|
||||
auth.Post("/login/code/verify-short", authHandler.VerifyLoginShortCode)
|
||||
auth.Post("/password/login", authHandler.PasswordLogin)
|
||||
auth.Get("/tenant-info", authHandler.GetTenantInfo)
|
||||
auth.Get("/consent", authHandler.GetConsentRequest)
|
||||
auth.Post("/consent/accept", authHandler.AcceptConsentRequest)
|
||||
auth.Post("/consent/reject", authHandler.RejectConsentRequest)
|
||||
|
||||
auth.Post("/oidc/login/accept", authHandler.AcceptOidcLoginRequest)
|
||||
|
||||
auth.Post("/enchanted-link/init", authHandler.InitEnchantedLink)
|
||||
auth.Post("/enchanted-link/poll", authHandler.PollEnchantedLink)
|
||||
auth.Post("/magic-link/verify", authHandler.VerifyMagicLink)
|
||||
auth.Post("/login/code/verify", authHandler.VerifyLoginCode)
|
||||
auth.Post("/login/code/verify-short", authHandler.VerifyLoginShortCode)
|
||||
auth.Post("/password/login", authHandler.PasswordLogin)
|
||||
auth.Get("/consent", authHandler.GetConsentRequest)
|
||||
auth.Post("/consent/accept", authHandler.AcceptConsentRequest)
|
||||
|
||||
auth.Post("/password/reset/initiate", authHandler.InitiatePasswordReset)
|
||||
// [Changed] Use Interstitial Page for GET to prevent Scanner consumption
|
||||
auth.Get("/password/reset/verify", authHandler.VerifyPasswordResetPage)
|
||||
|
||||
@@ -538,6 +538,55 @@ func (h *AuthHandler) getBearerToken(c *fiber.Ctx) string {
|
||||
return parts[1]
|
||||
}
|
||||
|
||||
func (h *AuthHandler) resolveUserfrontURL(c *fiber.Ctx) string {
|
||||
// 1. Try to use the Host header from the request
|
||||
host := c.Get("X-Forwarded-Host")
|
||||
if host == "" {
|
||||
host = c.Hostname()
|
||||
}
|
||||
|
||||
// 2. Determine scheme
|
||||
scheme := "https"
|
||||
if os.Getenv("APP_ENV") == "dev" || os.Getenv("APP_ENV") == "" || c.Protocol() == "http" {
|
||||
scheme = "http"
|
||||
}
|
||||
|
||||
// 3. Fallback to env if host is not available or is localhost (and not in dev)
|
||||
envURL := os.Getenv("USERFRONT_URL")
|
||||
if envURL == "" {
|
||||
envURL = "http://sso.hmac.kr"
|
||||
}
|
||||
|
||||
if host == "" || (host == "localhost" && os.Getenv("APP_ENV") != "dev") {
|
||||
return strings.TrimRight(envURL, "/")
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s://%s", scheme, host)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) GetTenantInfo(c *fiber.Ctx) error {
|
||||
tenantID, _ := c.Locals("tenant_id").(string)
|
||||
if tenantID == "" {
|
||||
return c.JSON(fiber.Map{
|
||||
"isCentral": true,
|
||||
})
|
||||
}
|
||||
|
||||
tenant, err := h.TenantService.GetTenant(c.Context(), tenantID)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusNotFound, "Tenant not found")
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
"isCentral": false,
|
||||
"id": tenant.ID,
|
||||
"name": tenant.Name,
|
||||
"slug": tenant.Slug,
|
||||
"description": tenant.Description,
|
||||
"type": tenant.Type,
|
||||
})
|
||||
}
|
||||
|
||||
// normalizePhoneForLoginID는 전화번호를 IDP 조회에 적합한 형태(E.164)로 정규화합니다.
|
||||
func normalizePhoneForLoginID(phone string) string {
|
||||
normalized := strings.ReplaceAll(phone, "-", "")
|
||||
@@ -920,10 +969,7 @@ func (h *AuthHandler) InitEnchantedLink(c *fiber.Ctx) error {
|
||||
return errorJSON(c, fiber.StatusNotFound, "User not registered")
|
||||
}
|
||||
|
||||
userfrontURL := os.Getenv("USERFRONT_URL")
|
||||
if userfrontURL == "" {
|
||||
userfrontURL = "http://sso.hmac.kr"
|
||||
}
|
||||
userfrontURL := h.resolveUserfrontURL(c)
|
||||
if req.URI != "" {
|
||||
userfrontURL = req.URI
|
||||
}
|
||||
@@ -1692,14 +1738,7 @@ func (h *AuthHandler) InitiatePasswordReset(c *fiber.Ctx) error {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "Authentication service not configured")
|
||||
}
|
||||
|
||||
userfrontURL := os.Getenv("USERFRONT_URL")
|
||||
if userfrontURL == "" {
|
||||
ale.Status = fiber.StatusInternalServerError
|
||||
ale.LatencyMs = time.Since(startTime)
|
||||
ale.ProviderError = "USERFRONT_URL is not set"
|
||||
ale.Log(slog.LevelError, "USERFRONT_URL is not set")
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "USERFRONT_URL environment variable is not set")
|
||||
}
|
||||
userfrontURL := h.resolveUserfrontURL(c)
|
||||
// [Changed] Point to Backend API for verification (which then redirects to Frontend)
|
||||
redirectURL := fmt.Sprintf("%s/api/v1/auth/password/reset/verify", userfrontURL)
|
||||
ale.RedirectTo = redirectURL
|
||||
@@ -1863,10 +1902,7 @@ func (h *AuthHandler) ProcessPasswordResetToken(c *fiber.Ctx) error {
|
||||
ale.LoginIDs["loginId"] = loginID
|
||||
ale.LoginIDs["loginId_normalized"] = loginID
|
||||
|
||||
userfrontURL := strings.TrimRight(os.Getenv("USERFRONT_URL"), "/")
|
||||
if userfrontURL == "" {
|
||||
userfrontURL = "https://sso.hmac.kr"
|
||||
}
|
||||
userfrontURL := h.resolveUserfrontURL(c)
|
||||
redirectBase, parseErr := url.Parse(userfrontURL + "/reset-password")
|
||||
if parseErr != nil {
|
||||
ale.Status = fiber.StatusInternalServerError
|
||||
@@ -2000,10 +2036,7 @@ func (h *AuthHandler) InitQRLogin(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
// QR 코드 페이로드를 실제 접속 가능한 URL로 변경합니다.
|
||||
userfrontURL := os.Getenv("USERFRONT_URL")
|
||||
if userfrontURL == "" {
|
||||
userfrontURL = "https://sso.hmac.kr"
|
||||
}
|
||||
userfrontURL := h.resolveUserfrontURL(c)
|
||||
qrPayload := fmt.Sprintf("%s/ql/%s", strings.TrimRight(userfrontURL, "/"), qrRef)
|
||||
|
||||
slog.Info("[QR] Init", "pendingRef", pendingRef, "qrRef", qrRef, "url", qrPayload)
|
||||
|
||||
@@ -123,7 +123,7 @@ func TestRequireKetoPermission_Success(t *testing.T) {
|
||||
|
||||
profile := &domain.UserProfileResponse{ID: "user1", Role: "user"}
|
||||
mockAuth.On("GetEnrichedProfile", mock.Anything).Return(profile, nil)
|
||||
mockKeto.On("CheckPermission", mock.Anything, "user1", "tenants", "tenant1", "read").Return(true, nil)
|
||||
mockKeto.On("CheckPermission", mock.Anything, "User:user1", "tenants", "tenant1", "read").Return(true, nil)
|
||||
|
||||
app.Get("/tenants/:id", RequireKetoPermission(config, "tenants", "read"), func(c *fiber.Ctx) error {
|
||||
return c.SendString("ok")
|
||||
|
||||
69
backend/internal/middleware/tenant_middleware.go
Normal file
69
backend/internal/middleware/tenant_middleware.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/service"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// TenantContextConfig defines the configuration for Tenant context middleware
|
||||
type TenantContextConfig struct {
|
||||
TenantService service.TenantService
|
||||
}
|
||||
|
||||
// TenantContextMiddleware identifies the tenant based on the Host header (subdomain or custom domain)
|
||||
func TenantContextMiddleware(config TenantContextConfig) fiber.Handler {
|
||||
userfrontURL := os.Getenv("USERFRONT_URL")
|
||||
baseDomain := ""
|
||||
if userfrontURL != "" {
|
||||
if u, err := url.Parse(userfrontURL); err == nil {
|
||||
baseDomain = u.Hostname()
|
||||
}
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
hostname := c.Hostname()
|
||||
if hostname == "" {
|
||||
hostname = string(c.Request().Header.Host())
|
||||
}
|
||||
|
||||
if hostname == "" {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// 1. If it's the exact base domain or localhost, no specific tenant context from host
|
||||
if hostname == baseDomain || hostname == "localhost" || hostname == "127.0.0.1" {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// 2. Try to find by registered custom domain
|
||||
tenant, err := config.TenantService.GetTenantByDomain(c.Context(), hostname)
|
||||
if err == nil && tenant != nil {
|
||||
slog.Debug("Tenant identified by custom domain", "hostname", hostname, "tenantID", tenant.ID)
|
||||
c.Locals("tenant_id", tenant.ID)
|
||||
c.Locals("tenant_slug", tenant.Slug)
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// 3. Try to find by subdomain (slug.baseDomain)
|
||||
if baseDomain != "" && strings.HasSuffix(hostname, "."+baseDomain) {
|
||||
slug := strings.TrimSuffix(hostname, "."+baseDomain)
|
||||
// Handle cases like "www.sso.hmac.kr" if baseDomain is "sso.hmac.kr"
|
||||
if slug != "" && slug != "www" {
|
||||
tenant, err := config.TenantService.GetTenantBySlug(c.Context(), slug)
|
||||
if err == nil && tenant != nil {
|
||||
slog.Debug("Tenant identified by subdomain slug", "slug", slug, "tenantID", tenant.ID)
|
||||
c.Locals("tenant_id", tenant.ID)
|
||||
c.Locals("tenant_slug", tenant.Slug)
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
103
backend/internal/middleware/tenant_middleware_test.go
Normal file
103
backend/internal/middleware/tenant_middleware_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type MockTenantServiceForMiddleware struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string, creatorID string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *MockTenantServiceForMiddleware) RequestRegistration(ctx context.Context, name, slug, description string, domainName string, adminEmail string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *MockTenantServiceForMiddleware) GetTenantByDomain(ctx context.Context, emailDomain string) (*domain.Tenant, error) {
|
||||
args := m.Called(mock.Anything, emailDomain)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.Tenant), args.Error(1)
|
||||
}
|
||||
func (m *MockTenantServiceForMiddleware) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
|
||||
args := m.Called(mock.Anything, slug)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.Tenant), args.Error(1)
|
||||
}
|
||||
func (m *MockTenantServiceForMiddleware) GetTenant(ctx context.Context, id string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *MockTenantServiceForMiddleware) ListTenants(ctx context.Context, limit, offset int, parentID string) ([]domain.Tenant, int64, error) {
|
||||
return nil, 0, nil
|
||||
}
|
||||
func (m *MockTenantServiceForMiddleware) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *MockTenantServiceForMiddleware) IsDomainAllowed(ctx context.Context, domainName string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (m *MockTenantServiceForMiddleware) ApproveTenant(ctx context.Context, id string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *MockTenantServiceForMiddleware) SetKetoService(keto service.KetoService) {}
|
||||
|
||||
func TestTenantContextMiddleware(t *testing.T) {
|
||||
os.Setenv("USERFRONT_URL", "https://sso.hmac.kr")
|
||||
defer os.Unsetenv("USERFRONT_URL")
|
||||
|
||||
mockSvc := new(MockTenantServiceForMiddleware)
|
||||
app := fiber.New()
|
||||
app.Use(TenantContextMiddleware(TenantContextConfig{TenantService: mockSvc}))
|
||||
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return c.JSON(fiber.Map{
|
||||
"tenant_id": c.Locals("tenant_id"),
|
||||
"tenant_slug": c.Locals("tenant_slug"),
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Base Domain - No Tenant Context", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://sso.hmac.kr/test", nil)
|
||||
req.Host = "sso.hmac.kr"
|
||||
resp, _ := app.Test(req)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
mockSvc.AssertNotCalled(t, "GetTenantByDomain", mock.Anything, "sso.hmac.kr")
|
||||
})
|
||||
|
||||
t.Run("Subdomain - Identify by Slug", func(t *testing.T) {
|
||||
mockSvc.On("GetTenantByDomain", mock.Anything, "tenant1.sso.hmac.kr").Return(nil, nil).Once()
|
||||
mockSvc.On("GetTenantBySlug", mock.Anything, "tenant1").Return(&domain.Tenant{ID: "t1", Slug: "tenant1"}, nil).Once()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Host = "tenant1.sso.hmac.kr"
|
||||
req.Header.Set("Host", "tenant1.sso.hmac.kr")
|
||||
resp, _ := app.Test(req)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
mockSvc.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("Custom Domain - Identify by Domain", func(t *testing.T) {
|
||||
mockSvc.On("GetTenantByDomain", mock.Anything, "company.com").Return(&domain.Tenant{ID: "t2", Slug: "company"}, nil).Once()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Host = "company.com"
|
||||
req.Header.Set("Host", "company.com")
|
||||
resp, _ := app.Test(req)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
mockSvc.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
@@ -20,6 +20,7 @@ type TenantService interface {
|
||||
GetTenant(ctx context.Context, id string) (*domain.Tenant, error)
|
||||
ListTenants(ctx context.Context, limit, offset int, parentID string) ([]domain.Tenant, int64, error)
|
||||
ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error)
|
||||
IsDomainAllowed(ctx context.Context, domainName string) (bool, error)
|
||||
ApproveTenant(ctx context.Context, id string) error
|
||||
SetKetoService(keto KetoService) // 추가
|
||||
}
|
||||
@@ -282,3 +283,14 @@ func (s *tenantService) ListTenants(ctx context.Context, limit, offset int, pare
|
||||
// Let the repository handle the query and pagination
|
||||
return s.repo.List(ctx, limit, offset, parentID)
|
||||
}
|
||||
|
||||
func (s *tenantService) IsDomainAllowed(ctx context.Context, domainName string) (bool, error) {
|
||||
tenant, err := s.repo.FindByDomain(ctx, domainName)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return tenant != nil && tenant.Status == domain.TenantStatusActive, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user