diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index fb2b361b..ed869a0b 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -345,12 +345,41 @@ func main() { app.Use(middleware.ErrorCodeEnricher()) allowedOrigins := getEnv("CORS_ALLOWED_ORIGINS", "http://localhost:5000") - allowCredentials := allowedOrigins != "*" + userfrontURL := getEnv("USERFRONT_URL", "http://sso.hmac.kr") + baseDomain := "" + if u, err := url.Parse(userfrontURL); err == nil { + baseDomain = u.Hostname() + } + app.Use(cors.New(cors.Config{ - AllowOrigins: allowedOrigins, - AllowHeaders: "Origin, Content-Type, Accept, Authorization", + AllowOriginsFunc: func(origin string) bool { + // 1. Check static allowed list + for _, allowed := range strings.Split(allowedOrigins, ",") { + if origin == strings.TrimSpace(allowed) { + return true + } + } + + // Parse origin URL + u, err := url.Parse(origin) + if err != nil { + return false + } + hostname := u.Hostname() + + // 2. Check subdomains of base domain + if baseDomain != "" && (hostname == baseDomain || strings.HasSuffix(hostname, "."+baseDomain)) { + return true + } + + // 3. Check registered tenant domains + // Use context.Background() as we don't have request context here easily + allowed, _ := tenantService.IsDomainAllowed(context.Background(), hostname) + return allowed + }, + AllowHeaders: "Origin, Content-Type, Accept, Authorization, X-Test-Role, X-Mock-Role, X-Tenant-ID", AllowMethods: "GET, POST, HEAD, PUT, DELETE, PATCH, OPTIONS", - AllowCredentials: allowCredentials, + AllowCredentials: true, })) // Ensure COOKIE_SECRET is exactly 32 bytes for AES-256 @@ -482,6 +511,11 @@ func main() { // Public Tenant Registration api.Post("/tenants/registration", tenantHandler.RegisterTenantPublic) + // Tenant Context Middleware (identifies tenant from Host header) + api.Use(middleware.TenantContextMiddleware(middleware.TenantContextConfig{ + TenantService: tenantService, + })) + // Auth Proxy Routes auth := api.Group("/auth") auth.Post("/enchanted-link/init", authHandler.InitEnchantedLink) @@ -490,21 +524,13 @@ func main() { auth.Post("/login/code/verify", authHandler.VerifyLoginCode) auth.Post("/login/code/verify-short", authHandler.VerifyLoginShortCode) auth.Post("/password/login", authHandler.PasswordLogin) + auth.Get("/tenant-info", authHandler.GetTenantInfo) auth.Get("/consent", authHandler.GetConsentRequest) auth.Post("/consent/accept", authHandler.AcceptConsentRequest) auth.Post("/consent/reject", authHandler.RejectConsentRequest) auth.Post("/oidc/login/accept", authHandler.AcceptOidcLoginRequest) - auth.Post("/enchanted-link/init", authHandler.InitEnchantedLink) - auth.Post("/enchanted-link/poll", authHandler.PollEnchantedLink) - auth.Post("/magic-link/verify", authHandler.VerifyMagicLink) - auth.Post("/login/code/verify", authHandler.VerifyLoginCode) - auth.Post("/login/code/verify-short", authHandler.VerifyLoginShortCode) - auth.Post("/password/login", authHandler.PasswordLogin) - auth.Get("/consent", authHandler.GetConsentRequest) - auth.Post("/consent/accept", authHandler.AcceptConsentRequest) - auth.Post("/password/reset/initiate", authHandler.InitiatePasswordReset) // [Changed] Use Interstitial Page for GET to prevent Scanner consumption auth.Get("/password/reset/verify", authHandler.VerifyPasswordResetPage) diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 5e5b058f..d0d63baa 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -538,6 +538,55 @@ func (h *AuthHandler) getBearerToken(c *fiber.Ctx) string { return parts[1] } +func (h *AuthHandler) resolveUserfrontURL(c *fiber.Ctx) string { + // 1. Try to use the Host header from the request + host := c.Get("X-Forwarded-Host") + if host == "" { + host = c.Hostname() + } + + // 2. Determine scheme + scheme := "https" + if os.Getenv("APP_ENV") == "dev" || os.Getenv("APP_ENV") == "" || c.Protocol() == "http" { + scheme = "http" + } + + // 3. Fallback to env if host is not available or is localhost (and not in dev) + envURL := os.Getenv("USERFRONT_URL") + if envURL == "" { + envURL = "http://sso.hmac.kr" + } + + if host == "" || (host == "localhost" && os.Getenv("APP_ENV") != "dev") { + return strings.TrimRight(envURL, "/") + } + + return fmt.Sprintf("%s://%s", scheme, host) +} + +func (h *AuthHandler) GetTenantInfo(c *fiber.Ctx) error { + tenantID, _ := c.Locals("tenant_id").(string) + if tenantID == "" { + return c.JSON(fiber.Map{ + "isCentral": true, + }) + } + + tenant, err := h.TenantService.GetTenant(c.Context(), tenantID) + if err != nil { + return errorJSON(c, fiber.StatusNotFound, "Tenant not found") + } + + return c.JSON(fiber.Map{ + "isCentral": false, + "id": tenant.ID, + "name": tenant.Name, + "slug": tenant.Slug, + "description": tenant.Description, + "type": tenant.Type, + }) +} + // normalizePhoneForLoginID는 전화번호를 IDP 조회에 적합한 형태(E.164)로 정규화합니다. func normalizePhoneForLoginID(phone string) string { normalized := strings.ReplaceAll(phone, "-", "") @@ -920,10 +969,7 @@ func (h *AuthHandler) InitEnchantedLink(c *fiber.Ctx) error { return errorJSON(c, fiber.StatusNotFound, "User not registered") } - userfrontURL := os.Getenv("USERFRONT_URL") - if userfrontURL == "" { - userfrontURL = "http://sso.hmac.kr" - } + userfrontURL := h.resolveUserfrontURL(c) if req.URI != "" { userfrontURL = req.URI } @@ -1692,14 +1738,7 @@ func (h *AuthHandler) InitiatePasswordReset(c *fiber.Ctx) error { return errorJSON(c, fiber.StatusInternalServerError, "Authentication service not configured") } - userfrontURL := os.Getenv("USERFRONT_URL") - if userfrontURL == "" { - ale.Status = fiber.StatusInternalServerError - ale.LatencyMs = time.Since(startTime) - ale.ProviderError = "USERFRONT_URL is not set" - ale.Log(slog.LevelError, "USERFRONT_URL is not set") - return errorJSON(c, fiber.StatusInternalServerError, "USERFRONT_URL environment variable is not set") - } + userfrontURL := h.resolveUserfrontURL(c) // [Changed] Point to Backend API for verification (which then redirects to Frontend) redirectURL := fmt.Sprintf("%s/api/v1/auth/password/reset/verify", userfrontURL) ale.RedirectTo = redirectURL @@ -1863,10 +1902,7 @@ func (h *AuthHandler) ProcessPasswordResetToken(c *fiber.Ctx) error { ale.LoginIDs["loginId"] = loginID ale.LoginIDs["loginId_normalized"] = loginID - userfrontURL := strings.TrimRight(os.Getenv("USERFRONT_URL"), "/") - if userfrontURL == "" { - userfrontURL = "https://sso.hmac.kr" - } + userfrontURL := h.resolveUserfrontURL(c) redirectBase, parseErr := url.Parse(userfrontURL + "/reset-password") if parseErr != nil { ale.Status = fiber.StatusInternalServerError @@ -2000,10 +2036,7 @@ func (h *AuthHandler) InitQRLogin(c *fiber.Ctx) error { } // QR 코드 페이로드를 실제 접속 가능한 URL로 변경합니다. - userfrontURL := os.Getenv("USERFRONT_URL") - if userfrontURL == "" { - userfrontURL = "https://sso.hmac.kr" - } + userfrontURL := h.resolveUserfrontURL(c) qrPayload := fmt.Sprintf("%s/ql/%s", strings.TrimRight(userfrontURL, "/"), qrRef) slog.Info("[QR] Init", "pendingRef", pendingRef, "qrRef", qrRef, "url", qrPayload) diff --git a/backend/internal/middleware/rbac_test.go b/backend/internal/middleware/rbac_test.go index db9ff925..54bd4b9d 100644 --- a/backend/internal/middleware/rbac_test.go +++ b/backend/internal/middleware/rbac_test.go @@ -123,7 +123,7 @@ func TestRequireKetoPermission_Success(t *testing.T) { profile := &domain.UserProfileResponse{ID: "user1", Role: "user"} mockAuth.On("GetEnrichedProfile", mock.Anything).Return(profile, nil) - mockKeto.On("CheckPermission", mock.Anything, "user1", "tenants", "tenant1", "read").Return(true, nil) + mockKeto.On("CheckPermission", mock.Anything, "User:user1", "tenants", "tenant1", "read").Return(true, nil) app.Get("/tenants/:id", RequireKetoPermission(config, "tenants", "read"), func(c *fiber.Ctx) error { return c.SendString("ok") diff --git a/backend/internal/middleware/tenant_middleware.go b/backend/internal/middleware/tenant_middleware.go new file mode 100644 index 00000000..80eb346f --- /dev/null +++ b/backend/internal/middleware/tenant_middleware.go @@ -0,0 +1,69 @@ +package middleware + +import ( + "baron-sso-backend/internal/service" + "log/slog" + "net/url" + "os" + "strings" + + "github.com/gofiber/fiber/v2" +) + +// TenantContextConfig defines the configuration for Tenant context middleware +type TenantContextConfig struct { + TenantService service.TenantService +} + +// TenantContextMiddleware identifies the tenant based on the Host header (subdomain or custom domain) +func TenantContextMiddleware(config TenantContextConfig) fiber.Handler { + userfrontURL := os.Getenv("USERFRONT_URL") + baseDomain := "" + if userfrontURL != "" { + if u, err := url.Parse(userfrontURL); err == nil { + baseDomain = u.Hostname() + } + } + + return func(c *fiber.Ctx) error { + hostname := c.Hostname() + if hostname == "" { + hostname = string(c.Request().Header.Host()) + } + + if hostname == "" { + return c.Next() + } + + // 1. If it's the exact base domain or localhost, no specific tenant context from host + if hostname == baseDomain || hostname == "localhost" || hostname == "127.0.0.1" { + return c.Next() + } + + // 2. Try to find by registered custom domain + tenant, err := config.TenantService.GetTenantByDomain(c.Context(), hostname) + if err == nil && tenant != nil { + slog.Debug("Tenant identified by custom domain", "hostname", hostname, "tenantID", tenant.ID) + c.Locals("tenant_id", tenant.ID) + c.Locals("tenant_slug", tenant.Slug) + return c.Next() + } + + // 3. Try to find by subdomain (slug.baseDomain) + if baseDomain != "" && strings.HasSuffix(hostname, "."+baseDomain) { + slug := strings.TrimSuffix(hostname, "."+baseDomain) + // Handle cases like "www.sso.hmac.kr" if baseDomain is "sso.hmac.kr" + if slug != "" && slug != "www" { + tenant, err := config.TenantService.GetTenantBySlug(c.Context(), slug) + if err == nil && tenant != nil { + slog.Debug("Tenant identified by subdomain slug", "slug", slug, "tenantID", tenant.ID) + c.Locals("tenant_id", tenant.ID) + c.Locals("tenant_slug", tenant.Slug) + return c.Next() + } + } + } + + return c.Next() + } +} diff --git a/backend/internal/middleware/tenant_middleware_test.go b/backend/internal/middleware/tenant_middleware_test.go new file mode 100644 index 00000000..5d8f21a2 --- /dev/null +++ b/backend/internal/middleware/tenant_middleware_test.go @@ -0,0 +1,103 @@ +package middleware + +import ( + "baron-sso-backend/internal/domain" + "baron-sso-backend/internal/service" + "context" + "log/slog" + "net/http/httptest" + "os" + "testing" + + "github.com/gofiber/fiber/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type MockTenantServiceForMiddleware struct { + mock.Mock +} + +func (m *MockTenantServiceForMiddleware) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string, creatorID string) (*domain.Tenant, error) { + return nil, nil +} +func (m *MockTenantServiceForMiddleware) RequestRegistration(ctx context.Context, name, slug, description string, domainName string, adminEmail string) (*domain.Tenant, error) { + return nil, nil +} +func (m *MockTenantServiceForMiddleware) GetTenantByDomain(ctx context.Context, emailDomain string) (*domain.Tenant, error) { + args := m.Called(mock.Anything, emailDomain) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.Tenant), args.Error(1) +} +func (m *MockTenantServiceForMiddleware) GetTenantBySlug(ctx context.Context, slug string) (*domain.Tenant, error) { + args := m.Called(mock.Anything, slug) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.Tenant), args.Error(1) +} +func (m *MockTenantServiceForMiddleware) GetTenant(ctx context.Context, id string) (*domain.Tenant, error) { + return nil, nil +} +func (m *MockTenantServiceForMiddleware) ListTenants(ctx context.Context, limit, offset int, parentID string) ([]domain.Tenant, int64, error) { + return nil, 0, nil +} +func (m *MockTenantServiceForMiddleware) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) { + return nil, nil +} +func (m *MockTenantServiceForMiddleware) IsDomainAllowed(ctx context.Context, domainName string) (bool, error) { + return false, nil +} +func (m *MockTenantServiceForMiddleware) ApproveTenant(ctx context.Context, id string) error { + return nil +} +func (m *MockTenantServiceForMiddleware) SetKetoService(keto service.KetoService) {} + +func TestTenantContextMiddleware(t *testing.T) { + os.Setenv("USERFRONT_URL", "https://sso.hmac.kr") + defer os.Unsetenv("USERFRONT_URL") + + mockSvc := new(MockTenantServiceForMiddleware) + app := fiber.New() + app.Use(TenantContextMiddleware(TenantContextConfig{TenantService: mockSvc})) + + app.Get("/test", func(c *fiber.Ctx) error { + return c.JSON(fiber.Map{ + "tenant_id": c.Locals("tenant_id"), + "tenant_slug": c.Locals("tenant_slug"), + }) + }) + + t.Run("Base Domain - No Tenant Context", func(t *testing.T) { + req := httptest.NewRequest("GET", "http://sso.hmac.kr/test", nil) + req.Host = "sso.hmac.kr" + resp, _ := app.Test(req) + assert.Equal(t, 200, resp.StatusCode) + mockSvc.AssertNotCalled(t, "GetTenantByDomain", mock.Anything, "sso.hmac.kr") + }) + + t.Run("Subdomain - Identify by Slug", func(t *testing.T) { + mockSvc.On("GetTenantByDomain", mock.Anything, "tenant1.sso.hmac.kr").Return(nil, nil).Once() + mockSvc.On("GetTenantBySlug", mock.Anything, "tenant1").Return(&domain.Tenant{ID: "t1", Slug: "tenant1"}, nil).Once() + + req := httptest.NewRequest("GET", "/test", nil) + req.Host = "tenant1.sso.hmac.kr" + req.Header.Set("Host", "tenant1.sso.hmac.kr") + resp, _ := app.Test(req) + assert.Equal(t, 200, resp.StatusCode) + mockSvc.AssertExpectations(t) + }) + + t.Run("Custom Domain - Identify by Domain", func(t *testing.T) { + mockSvc.On("GetTenantByDomain", mock.Anything, "company.com").Return(&domain.Tenant{ID: "t2", Slug: "company"}, nil).Once() + + req := httptest.NewRequest("GET", "/test", nil) + req.Host = "company.com" + req.Header.Set("Host", "company.com") + resp, _ := app.Test(req) + assert.Equal(t, 200, resp.StatusCode) + mockSvc.AssertExpectations(t) + }) +} diff --git a/backend/internal/service/tenant_service.go b/backend/internal/service/tenant_service.go index 553e9554..ff104c91 100644 --- a/backend/internal/service/tenant_service.go +++ b/backend/internal/service/tenant_service.go @@ -20,6 +20,7 @@ type TenantService interface { GetTenant(ctx context.Context, id string) (*domain.Tenant, error) ListTenants(ctx context.Context, limit, offset int, parentID string) ([]domain.Tenant, int64, error) ListManageableTenants(ctx context.Context, userID string) ([]domain.Tenant, error) + IsDomainAllowed(ctx context.Context, domainName string) (bool, error) ApproveTenant(ctx context.Context, id string) error SetKetoService(keto KetoService) // 추가 } @@ -282,3 +283,14 @@ func (s *tenantService) ListTenants(ctx context.Context, limit, offset int, pare // Let the repository handle the query and pagination return s.repo.List(ctx, limit, offset, parentID) } + +func (s *tenantService) IsDomainAllowed(ctx context.Context, domainName string) (bool, error) { + tenant, err := s.repo.FindByDomain(ctx, domainName) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return false, nil + } + return false, err + } + return tenant != nil && tenant.Status == domain.TenantStatusActive, nil +}