1
0
forked from baron/baron-sso

feat(backend): implement dynamic multi-tenancy routing and CORS

This commit is contained in:
2026-03-03 15:27:05 +09:00
parent a6e7f1253c
commit 5423f920b7
6 changed files with 277 additions and 34 deletions

View File

@@ -123,7 +123,7 @@ func TestRequireKetoPermission_Success(t *testing.T) {
profile := &domain.UserProfileResponse{ID: "user1", Role: "user"}
mockAuth.On("GetEnrichedProfile", mock.Anything).Return(profile, nil)
mockKeto.On("CheckPermission", mock.Anything, "user1", "tenants", "tenant1", "read").Return(true, nil)
mockKeto.On("CheckPermission", mock.Anything, "User:user1", "tenants", "tenant1", "read").Return(true, nil)
app.Get("/tenants/:id", RequireKetoPermission(config, "tenants", "read"), func(c *fiber.Ctx) error {
return c.SendString("ok")

View File

@@ -0,0 +1,69 @@
package middleware
import (
"baron-sso-backend/internal/service"
"log/slog"
"net/url"
"os"
"strings"
"github.com/gofiber/fiber/v2"
)
// TenantContextConfig defines the configuration for Tenant context middleware
type TenantContextConfig struct {
TenantService service.TenantService
}
// TenantContextMiddleware identifies the tenant based on the Host header (subdomain or custom domain)
func TenantContextMiddleware(config TenantContextConfig) fiber.Handler {
userfrontURL := os.Getenv("USERFRONT_URL")
baseDomain := ""
if userfrontURL != "" {
if u, err := url.Parse(userfrontURL); err == nil {
baseDomain = u.Hostname()
}
}
return func(c *fiber.Ctx) error {
hostname := c.Hostname()
if hostname == "" {
hostname = string(c.Request().Header.Host())
}
if hostname == "" {
return c.Next()
}
// 1. If it's the exact base domain or localhost, no specific tenant context from host
if hostname == baseDomain || hostname == "localhost" || hostname == "127.0.0.1" {
return c.Next()
}
// 2. Try to find by registered custom domain
tenant, err := config.TenantService.GetTenantByDomain(c.Context(), hostname)
if err == nil && tenant != nil {
slog.Debug("Tenant identified by custom domain", "hostname", hostname, "tenantID", tenant.ID)
c.Locals("tenant_id", tenant.ID)
c.Locals("tenant_slug", tenant.Slug)
return c.Next()
}
// 3. Try to find by subdomain (slug.baseDomain)
if baseDomain != "" && strings.HasSuffix(hostname, "."+baseDomain) {
slug := strings.TrimSuffix(hostname, "."+baseDomain)
// Handle cases like "www.sso.hmac.kr" if baseDomain is "sso.hmac.kr"
if slug != "" && slug != "www" {
tenant, err := config.TenantService.GetTenantBySlug(c.Context(), slug)
if err == nil && tenant != nil {
slog.Debug("Tenant identified by subdomain slug", "slug", slug, "tenantID", tenant.ID)
c.Locals("tenant_id", tenant.ID)
c.Locals("tenant_slug", tenant.Slug)
return c.Next()
}
}
}
return c.Next()
}
}

View File

@@ -0,0 +1,103 @@
package middleware
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"context"
"log/slog"
"net/http/httptest"
"os"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
type MockTenantServiceForMiddleware struct {
mock.Mock
}
func (m *MockTenantServiceForMiddleware) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string, creatorID string) (*domain.Tenant, error) {
return nil, nil
}
func (m *MockTenantServiceForMiddleware) RequestRegistration(ctx context.Context, name, slug, description string, domainName string, adminEmail string) (*domain.Tenant, error) {
return nil, nil
}
func (m *MockTenantServiceForMiddleware) GetTenantByDomain(ctx context.Context, emailDomain string) (*domain.Tenant, error) {
args := m.Called(mock.Anything, emailDomain)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.Tenant), args.Error(1)
}
func (m *MockTenantServiceForMiddleware) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
args := m.Called(mock.Anything, slug)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.Tenant), args.Error(1)
}
func (m *MockTenantServiceForMiddleware) GetTenant(ctx context.Context, id string) (*domain.Tenant, error) {
return nil, nil
}
func (m *MockTenantServiceForMiddleware) ListTenants(ctx context.Context, limit, offset int, parentID string) ([]domain.Tenant, int64, error) {
return nil, 0, nil
}
func (m *MockTenantServiceForMiddleware) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
return nil, nil
}
func (m *MockTenantServiceForMiddleware) IsDomainAllowed(ctx context.Context, domainName string) (bool, error) {
return false, nil
}
func (m *MockTenantServiceForMiddleware) ApproveTenant(ctx context.Context, id string) error {
return nil
}
func (m *MockTenantServiceForMiddleware) SetKetoService(keto service.KetoService) {}
func TestTenantContextMiddleware(t *testing.T) {
os.Setenv("USERFRONT_URL", "https://sso.hmac.kr")
defer os.Unsetenv("USERFRONT_URL")
mockSvc := new(MockTenantServiceForMiddleware)
app := fiber.New()
app.Use(TenantContextMiddleware(TenantContextConfig{TenantService: mockSvc}))
app.Get("/test", func(c *fiber.Ctx) error {
return c.JSON(fiber.Map{
"tenant_id": c.Locals("tenant_id"),
"tenant_slug": c.Locals("tenant_slug"),
})
})
t.Run("Base Domain - No Tenant Context", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://sso.hmac.kr/test", nil)
req.Host = "sso.hmac.kr"
resp, _ := app.Test(req)
assert.Equal(t, 200, resp.StatusCode)
mockSvc.AssertNotCalled(t, "GetTenantByDomain", mock.Anything, "sso.hmac.kr")
})
t.Run("Subdomain - Identify by Slug", func(t *testing.T) {
mockSvc.On("GetTenantByDomain", mock.Anything, "tenant1.sso.hmac.kr").Return(nil, nil).Once()
mockSvc.On("GetTenantBySlug", mock.Anything, "tenant1").Return(&domain.Tenant{ID: "t1", Slug: "tenant1"}, nil).Once()
req := httptest.NewRequest("GET", "/test", nil)
req.Host = "tenant1.sso.hmac.kr"
req.Header.Set("Host", "tenant1.sso.hmac.kr")
resp, _ := app.Test(req)
assert.Equal(t, 200, resp.StatusCode)
mockSvc.AssertExpectations(t)
})
t.Run("Custom Domain - Identify by Domain", func(t *testing.T) {
mockSvc.On("GetTenantByDomain", mock.Anything, "company.com").Return(&domain.Tenant{ID: "t2", Slug: "company"}, nil).Once()
req := httptest.NewRequest("GET", "/test", nil)
req.Host = "company.com"
req.Header.Set("Host", "company.com")
resp, _ := app.Test(req)
assert.Equal(t, 200, resp.StatusCode)
mockSvc.AssertExpectations(t)
})
}