forked from baron/baron-sso
feat(backend): implement dynamic multi-tenancy routing and CORS
This commit is contained in:
@@ -123,7 +123,7 @@ func TestRequireKetoPermission_Success(t *testing.T) {
|
||||
|
||||
profile := &domain.UserProfileResponse{ID: "user1", Role: "user"}
|
||||
mockAuth.On("GetEnrichedProfile", mock.Anything).Return(profile, nil)
|
||||
mockKeto.On("CheckPermission", mock.Anything, "user1", "tenants", "tenant1", "read").Return(true, nil)
|
||||
mockKeto.On("CheckPermission", mock.Anything, "User:user1", "tenants", "tenant1", "read").Return(true, nil)
|
||||
|
||||
app.Get("/tenants/:id", RequireKetoPermission(config, "tenants", "read"), func(c *fiber.Ctx) error {
|
||||
return c.SendString("ok")
|
||||
|
||||
69
backend/internal/middleware/tenant_middleware.go
Normal file
69
backend/internal/middleware/tenant_middleware.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/service"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// TenantContextConfig defines the configuration for Tenant context middleware
|
||||
type TenantContextConfig struct {
|
||||
TenantService service.TenantService
|
||||
}
|
||||
|
||||
// TenantContextMiddleware identifies the tenant based on the Host header (subdomain or custom domain)
|
||||
func TenantContextMiddleware(config TenantContextConfig) fiber.Handler {
|
||||
userfrontURL := os.Getenv("USERFRONT_URL")
|
||||
baseDomain := ""
|
||||
if userfrontURL != "" {
|
||||
if u, err := url.Parse(userfrontURL); err == nil {
|
||||
baseDomain = u.Hostname()
|
||||
}
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
hostname := c.Hostname()
|
||||
if hostname == "" {
|
||||
hostname = string(c.Request().Header.Host())
|
||||
}
|
||||
|
||||
if hostname == "" {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// 1. If it's the exact base domain or localhost, no specific tenant context from host
|
||||
if hostname == baseDomain || hostname == "localhost" || hostname == "127.0.0.1" {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// 2. Try to find by registered custom domain
|
||||
tenant, err := config.TenantService.GetTenantByDomain(c.Context(), hostname)
|
||||
if err == nil && tenant != nil {
|
||||
slog.Debug("Tenant identified by custom domain", "hostname", hostname, "tenantID", tenant.ID)
|
||||
c.Locals("tenant_id", tenant.ID)
|
||||
c.Locals("tenant_slug", tenant.Slug)
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// 3. Try to find by subdomain (slug.baseDomain)
|
||||
if baseDomain != "" && strings.HasSuffix(hostname, "."+baseDomain) {
|
||||
slug := strings.TrimSuffix(hostname, "."+baseDomain)
|
||||
// Handle cases like "www.sso.hmac.kr" if baseDomain is "sso.hmac.kr"
|
||||
if slug != "" && slug != "www" {
|
||||
tenant, err := config.TenantService.GetTenantBySlug(c.Context(), slug)
|
||||
if err == nil && tenant != nil {
|
||||
slog.Debug("Tenant identified by subdomain slug", "slug", slug, "tenantID", tenant.ID)
|
||||
c.Locals("tenant_id", tenant.ID)
|
||||
c.Locals("tenant_slug", tenant.Slug)
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
103
backend/internal/middleware/tenant_middleware_test.go
Normal file
103
backend/internal/middleware/tenant_middleware_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type MockTenantServiceForMiddleware struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForMiddleware) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string, creatorID string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *MockTenantServiceForMiddleware) RequestRegistration(ctx context.Context, name, slug, description string, domainName string, adminEmail string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *MockTenantServiceForMiddleware) GetTenantByDomain(ctx context.Context, emailDomain string) (*domain.Tenant, error) {
|
||||
args := m.Called(mock.Anything, emailDomain)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.Tenant), args.Error(1)
|
||||
}
|
||||
func (m *MockTenantServiceForMiddleware) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) {
|
||||
args := m.Called(mock.Anything, slug)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.Tenant), args.Error(1)
|
||||
}
|
||||
func (m *MockTenantServiceForMiddleware) GetTenant(ctx context.Context, id string) (*domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *MockTenantServiceForMiddleware) ListTenants(ctx context.Context, limit, offset int, parentID string) ([]domain.Tenant, int64, error) {
|
||||
return nil, 0, nil
|
||||
}
|
||||
func (m *MockTenantServiceForMiddleware) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *MockTenantServiceForMiddleware) IsDomainAllowed(ctx context.Context, domainName string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (m *MockTenantServiceForMiddleware) ApproveTenant(ctx context.Context, id string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *MockTenantServiceForMiddleware) SetKetoService(keto service.KetoService) {}
|
||||
|
||||
func TestTenantContextMiddleware(t *testing.T) {
|
||||
os.Setenv("USERFRONT_URL", "https://sso.hmac.kr")
|
||||
defer os.Unsetenv("USERFRONT_URL")
|
||||
|
||||
mockSvc := new(MockTenantServiceForMiddleware)
|
||||
app := fiber.New()
|
||||
app.Use(TenantContextMiddleware(TenantContextConfig{TenantService: mockSvc}))
|
||||
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return c.JSON(fiber.Map{
|
||||
"tenant_id": c.Locals("tenant_id"),
|
||||
"tenant_slug": c.Locals("tenant_slug"),
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Base Domain - No Tenant Context", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://sso.hmac.kr/test", nil)
|
||||
req.Host = "sso.hmac.kr"
|
||||
resp, _ := app.Test(req)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
mockSvc.AssertNotCalled(t, "GetTenantByDomain", mock.Anything, "sso.hmac.kr")
|
||||
})
|
||||
|
||||
t.Run("Subdomain - Identify by Slug", func(t *testing.T) {
|
||||
mockSvc.On("GetTenantByDomain", mock.Anything, "tenant1.sso.hmac.kr").Return(nil, nil).Once()
|
||||
mockSvc.On("GetTenantBySlug", mock.Anything, "tenant1").Return(&domain.Tenant{ID: "t1", Slug: "tenant1"}, nil).Once()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Host = "tenant1.sso.hmac.kr"
|
||||
req.Header.Set("Host", "tenant1.sso.hmac.kr")
|
||||
resp, _ := app.Test(req)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
mockSvc.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("Custom Domain - Identify by Domain", func(t *testing.T) {
|
||||
mockSvc.On("GetTenantByDomain", mock.Anything, "company.com").Return(&domain.Tenant{ID: "t2", Slug: "company"}, nil).Once()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Host = "company.com"
|
||||
req.Header.Set("Host", "company.com")
|
||||
resp, _ := app.Test(req)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
mockSvc.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user