forked from baron/baron-sso
70 lines
2.0 KiB
Go
70 lines
2.0 KiB
Go
package middleware
|
|
|
|
import (
	"log/slog"
	"net"
	"net/url"
	"os"
	"strings"

	"baron-sso-backend/internal/service"

	"github.com/gofiber/fiber/v2"
)
|
|
|
|
// TenantContextConfig defines the configuration for Tenant context middleware
|
|
type TenantContextConfig struct {
|
|
TenantService service.TenantService
|
|
}
|
|
|
|
// TenantContextMiddleware identifies the tenant based on the Host header (subdomain or custom domain)
|
|
func TenantContextMiddleware(config TenantContextConfig) fiber.Handler {
|
|
userfrontURL := os.Getenv("USERFRONT_URL")
|
|
baseDomain := ""
|
|
if userfrontURL != "" {
|
|
if u, err := url.Parse(userfrontURL); err == nil {
|
|
baseDomain = u.Hostname()
|
|
}
|
|
}
|
|
|
|
return func(c *fiber.Ctx) error {
|
|
hostname := c.Hostname()
|
|
if hostname == "" {
|
|
hostname = string(c.Request().Header.Host())
|
|
}
|
|
|
|
if hostname == "" {
|
|
return c.Next()
|
|
}
|
|
|
|
// 1. If it's the exact base domain or localhost, no specific tenant context from host
|
|
if hostname == baseDomain || hostname == "localhost" || hostname == "127.0.0.1" {
|
|
return c.Next()
|
|
}
|
|
|
|
// 2. Try to find by registered custom domain
|
|
tenant, err := config.TenantService.GetTenantByDomain(c.Context(), hostname)
|
|
if err == nil && tenant != nil {
|
|
slog.Debug("Tenant identified by custom domain", "hostname", hostname, "tenantID", tenant.ID)
|
|
c.Locals("tenant_id", tenant.ID)
|
|
c.Locals("tenant_slug", tenant.Slug)
|
|
return c.Next()
|
|
}
|
|
|
|
// 3. Try to find by subdomain (slug.baseDomain)
|
|
if baseDomain != "" && strings.HasSuffix(hostname, "."+baseDomain) {
|
|
slug := strings.TrimSuffix(hostname, "."+baseDomain)
|
|
// Handle cases like "www.sso.hmac.kr" if baseDomain is "sso.hmac.kr"
|
|
if slug != "" && slug != "www" {
|
|
tenant, err := config.TenantService.GetTenantBySlug(c.Context(), slug)
|
|
if err == nil && tenant != nil {
|
|
slog.Debug("Tenant identified by subdomain slug", "slug", slug, "tenantID", tenant.ID)
|
|
c.Locals("tenant_id", tenant.ID)
|
|
c.Locals("tenant_slug", tenant.Slug)
|
|
return c.Next()
|
|
}
|
|
}
|
|
}
|
|
|
|
return c.Next()
|
|
}
|
|
}
|