diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 03d87b7f..f8c294cc 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -5130,6 +5130,13 @@ func (h *AuthHandler) GetConsentRequest(c *fiber.Ctx) error { "scopes", consentRequest.RequestedScope, ) + profile, err := h.resolveCurrentProfile(c) + if err == nil && profile != nil { + if !isClientTenantAccessAllowed(profile, consentRequest.Client) { + return tenantNotAllowedError(c) + } + } + // [New] 로컬 DB에서 기존 동의 내역 확인 (강제 자동 승인 전략) // Hydra가 skip을 주지 않더라도, 우리 DB에 이미 기록이 있다면 승인 처리함 if !consentRequest.Skip && h.ConsentRepo != nil && consentRequest.Subject != "" { @@ -5333,6 +5340,13 @@ func (h *AuthHandler) AcceptConsentRequest(c *fiber.Ctx) error { consentRequest.RequestedScope = filteredScopes } + profile, err := h.resolveCurrentProfile(c) + if err == nil && profile != nil { + if !isClientTenantAccessAllowed(profile, consentRequest.Client) { + return tenantNotAllowedError(c) + } + } + // 3. Hydra에 승인 요청 if consentRequest.Subject == "" { return fiber.NewError(fiber.StatusInternalServerError, "Consent subject missing") @@ -5470,6 +5484,15 @@ func (h *AuthHandler) AcceptOidcLoginRequest(c *fiber.Ctx) error { } } + profile, err := h.resolveCurrentProfile(c) + if loginReq != nil { + if err == nil && profile != nil { + if !isClientTenantAccessAllowed(profile, loginReq.Client) { + return tenantNotAllowedError(c) + } + } + } + subject, err := h.resolveConsentSubject(c) if err != nil || subject == "" { return fiber.NewError(fiber.StatusUnauthorized, "Authentication required") @@ -5520,6 +5543,10 @@ func (h *AuthHandler) AcceptOidcLoginRequest(c *fiber.Ctx) error { } func (h *AuthHandler) resolveCurrentProfile(c *fiber.Ctx) (*domain.UserProfileResponse, error) { + if profile, ok := c.Locals("user_profile").(*domain.UserProfileResponse); ok && profile != nil { + return profile, nil + } + appEnv := strings.ToLower(os.Getenv("APP_ENV")) isDev := appEnv == "dev" || appEnv == "development" || appEnv == "" @@ -5608,16 +5635,18 @@ func (h *AuthHandler) resolveCurrentProfile(c *fiber.Ctx) (*domain.UserProfileRe } // Fetch Tenant Metadata if missing - if profile.Tenant == nil && profile.TenantID != nil && *profile.TenantID != "" { - if tenant, err := h.TenantService.GetTenant(c.Context(), *profile.TenantID); err == nil { - profile.Tenant = tenant + if h.TenantService != nil { + if profile.Tenant == nil && profile.TenantID != nil && *profile.TenantID != "" { + if tenant, err := h.TenantService.GetTenant(c.Context(), *profile.TenantID); err == nil { + profile.Tenant = tenant + } } - } - if profile.Tenant == nil && profile.CompanyCode != "" { - if tenant, err := h.TenantService.GetTenantBySlug(c.Context(), profile.CompanyCode); err == nil && tenant != nil { - profile.Tenant = tenant - if profile.TenantID == nil || *profile.TenantID == "" { - profile.TenantID = &tenant.ID + if profile.Tenant == nil && profile.CompanyCode != "" { + if tenant, err := h.TenantService.GetTenantBySlug(c.Context(), profile.CompanyCode); err == nil && tenant != nil { + profile.Tenant = tenant + if profile.TenantID == nil || *profile.TenantID == "" { + profile.TenantID = &tenant.ID + } } } } diff --git a/backend/internal/handler/client_tenant_access.go b/backend/internal/handler/client_tenant_access.go new file mode 100644 index 00000000..d58e2d36 --- /dev/null +++ b/backend/internal/handler/client_tenant_access.go @@ -0,0 +1,144 @@ +package handler + +import ( + "baron-sso-backend/internal/domain" + "errors" + "sort" + "strings" + + "github.com/gofiber/fiber/v2" +) + +const ( + clientTenantAccessRestrictedKey = "tenant_access_restricted" + clientAllowedTenantsKey = "allowed_tenants" +) + +func normalizeClientTenantAccessMetadata(metadata map[string]interface{}) (map[string]interface{}, error) { + if metadata == nil { + metadata = map[string]interface{}{} + } + + restricted := readMetadataBoolValue(metadata, clientTenantAccessRestrictedKey) + allowedTenants := normalizeMetadataStringSlice(metadata[clientAllowedTenantsKey]) + ownerTenantID := normalizeMetadataString(metadata["tenant_id"]) + + if len(allowedTenants) > 0 { + restricted = true + } + + if !restricted { + delete(metadata, clientAllowedTenantsKey) + metadata[clientTenantAccessRestrictedKey] = false + return metadata, nil + } + + if ownerTenantID != "" { + allowedTenants = append(allowedTenants, ownerTenantID) + } + allowedTenants = uniqueSortedStrings(allowedTenants) + if len(allowedTenants) == 0 { + return nil, errors.New("allowed_tenants is required when tenant_access_restricted is enabled") + } + + metadata[clientTenantAccessRestrictedKey] = true + metadata[clientAllowedTenantsKey] = allowedTenants + return metadata, nil +} + +func clientTenantAccessRestricted(metadata map[string]interface{}) bool { + if metadata == nil { + return false + } + if readMetadataBoolValue(metadata, clientTenantAccessRestrictedKey) { + return true + } + return len(normalizeMetadataStringSlice(metadata[clientAllowedTenantsKey])) > 0 +} + +func clientAllowedTenants(metadata map[string]interface{}) []string { + if metadata == nil { + return nil + } + if !clientTenantAccessRestricted(metadata) { + return nil + } + return uniqueSortedStrings(normalizeMetadataStringSlice(metadata[clientAllowedTenantsKey])) +} + +func normalizeMetadataStringSlice(raw any) []string { + switch value := raw.(type) { + case []string: + return uniqueSortedStrings(value) + case []any: + items := make([]string, 0, len(value)) + for _, item := range value { + if s, ok := item.(string); ok { + items = append(items, s) + } + } + return uniqueSortedStrings(items) + default: + return nil + } +} + +func normalizeMetadataString(raw any) string { + s, ok := raw.(string) + if !ok { + return "" + } + return strings.TrimSpace(s) +} + +func uniqueSortedStrings(values []string) []string { + if len(values) == 0 { + return nil + } + seen := make(map[string]struct{}, len(values)) + out := make([]string, 0, len(values)) + for _, value := range values { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + continue + } + if _, ok := seen[trimmed]; ok { + continue + } + seen[trimmed] = struct{}{} + out = append(out, trimmed) + } + sort.Strings(out) + return out +} + +func clientTenantAccessAllowed(profile *domain.UserProfileResponse, client domain.HydraClient) bool { + if !clientTenantAccessRestricted(client.Metadata) { + return true + } + allowed := clientAllowedTenants(client.Metadata) + if len(allowed) == 0 { + return false + } + keys := manageableTenantKeysFromProfile(profile) + if len(keys) == 0 { + return false + } + for _, tenantID := range allowed { + if _, ok := keys[strings.ToLower(strings.TrimSpace(tenantID))]; ok { + return true + } + } + return false +} + +func tenantNotAllowedError(c *fiber.Ctx) error { + return errorJSONCode(c, fiber.StatusForbidden, "tenant_not_allowed", "허용되지 않은 테넌트입니다.") +} + +func isClientTenantAccessAllowed(profile *domain.UserProfileResponse, client domain.HydraClient) bool { + if profile == nil { + return false + } + return clientTenantAccessAllowed(profile, client) +} diff --git a/backend/internal/handler/client_tenant_access_test.go b/backend/internal/handler/client_tenant_access_test.go new file mode 100644 index 00000000..51138845 --- /dev/null +++ b/backend/internal/handler/client_tenant_access_test.go @@ -0,0 +1,263 @@ +package handler + +import ( + "baron-sso-backend/internal/domain" + "baron-sso-backend/internal/service" + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gofiber/fiber/v2" + "github.com/stretchr/testify/assert" +) + +func TestCreateClient_NormalizesTenantAccessMetadata(t *testing.T) { + var captured domain.HydraClient + ownerTenantID := "tenant-owner" + + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + if r.Method == http.MethodPost && r.URL.Path == "/clients" { + body, err := io.ReadAll(r.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &captured)) + return httpJSONAny(r, http.StatusCreated, map[string]any{ + "client_id": captured.ClientID, + "client_name": captured.ClientName, + "redirect_uris": captured.RedirectURIs, + "grant_types": captured.GrantTypes, + "response_types": captured.ResponseTypes, + "scope": captured.Scope, + "token_endpoint_auth_method": captured.TokenEndpointAuthMethod, + "metadata": captured.Metadata, + }), nil + } + return httpJSONAny(r, http.StatusNotFound, nil), nil + }) + + h := &DevHandler{ + Hydra: &service.HydraAdminService{ + AdminURL: "http://hydra.test", + HTTPClient: &http.Client{Transport: transport}, + }, + Keto: new(devMockKetoService), + } + + app := fiber.New() + app.Use(func(c *fiber.Ctx) error { + c.Locals("user_profile", &domain.UserProfileResponse{ + ID: "user-1", + Role: domain.RoleSuperAdmin, + TenantID: &ownerTenantID, + }) + return c.Next() + }) + app.Post("/api/v1/dev/clients", h.CreateClient) + + body, _ := json.Marshal(map[string]any{ + "id": "client-tenant", + "name": "Tenant Client", + "type": "pkce", + "redirectUris": []string{"https://rp.example.com/cb"}, + "metadata": map[string]any{ + "tenant_access_restricted": true, + "allowed_tenants": []string{"tenant-b", "tenant-a", "tenant-b"}, + }, + }) + req := httptest.NewRequest(http.MethodPost, "/api/v1/dev/clients", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp, err := app.Test(req, -1) + assert.NoError(t, err) + assert.Equal(t, http.StatusCreated, resp.StatusCode) + assert.True(t, clientTenantAccessRestricted(captured.Metadata)) + assert.Equal(t, []string{"tenant-a", "tenant-b", "tenant-owner"}, clientAllowedTenants(captured.Metadata)) +} + +func TestCreateClient_RejectsTenantAccessWithoutAllowedTenants(t *testing.T) { + hydraCalled := false + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + if r.Method == http.MethodPost && r.URL.Path == "/clients" { + hydraCalled = true + } + return httpJSONAny(r, http.StatusNotFound, nil), nil + }) + + h := &DevHandler{ + Hydra: &service.HydraAdminService{ + AdminURL: "http://hydra.test", + HTTPClient: &http.Client{Transport: transport}, + }, + Keto: new(devMockKetoService), + } + + app := fiber.New() + app.Use(func(c *fiber.Ctx) error { + c.Locals("user_profile", &domain.UserProfileResponse{ID: "user-1", Role: domain.RoleSuperAdmin}) + return c.Next() + }) + app.Post("/api/v1/dev/clients", h.CreateClient) + + body, _ := json.Marshal(map[string]any{ + "id": "client-tenant", + "name": "Tenant Client", + "type": "pkce", + "redirectUris": []string{"https://rp.example.com/cb"}, + "metadata": map[string]any{ + "tenant_access_restricted": true, + }, + }) + req := httptest.NewRequest(http.MethodPost, "/api/v1/dev/clients", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp, err := app.Test(req, -1) + assert.NoError(t, err) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.False(t, hydraCalled) +} + +func TestGetConsentRequest_DeniesTenantAccess(t *testing.T) { + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + switch { + case r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-tenant": + return httpJSONAny(r, http.StatusOK, map[string]any{ + "challenge": "challenge-tenant", + "requested_scope": []string{"openid", "profile"}, + "skip": false, + "subject": "user-123", + "client": map[string]any{ + "client_id": "client-tenant", + "metadata": map[string]any{ + "tenant_access_restricted": true, + "allowed_tenants": []string{"tenant-b"}, + }, + }, + }), nil + case r.URL.Host == "kratos.test" && r.URL.Path == "/sessions/whoami": + return httpJSONAny(r, http.StatusOK, map[string]any{ + "identity": map[string]any{ + "id": "user-123", + "traits": map[string]any{ + "email": "user@test.com", + "tenant_id": "tenant-a", + }, + }, + }), nil + default: + return httpJSONAny(r, http.StatusNotFound, nil), nil + } + }) + + client := &http.Client{Transport: transport} + origDefault := http.DefaultClient + http.DefaultClient = client + defer func() { http.DefaultClient = origDefault }() + + t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test") + + h := &AuthHandler{ + Hydra: &service.HydraAdminService{ + AdminURL: "http://hydra.test", + HTTPClient: client, + }, + } + + app := fiber.New() + tenantID := "tenant-a" + app.Use(func(c *fiber.Ctx) error { + c.Locals("user_profile", &domain.UserProfileResponse{ + ID: "user-123", + Role: domain.RoleUser, + TenantID: &tenantID, + }) + return c.Next() + }) + app.Get("/api/v1/auth/consent", h.GetConsentRequest) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/consent?consent_challenge=challenge-tenant", nil) + req.Header.Set("Cookie", "ory_kratos_session=session-1") + + resp, err := app.Test(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusForbidden, resp.StatusCode) + + bodyBytes, _ := io.ReadAll(resp.Body) + var body map[string]any + assert.NoError(t, json.Unmarshal(bodyBytes, &body)) + assert.Equal(t, "tenant_not_allowed", body["code"]) +} + +func TestAcceptOidcLoginRequest_DeniesTenantAccess(t *testing.T) { + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + switch { + case r.URL.Path == "/oauth2/auth/requests/login" && r.URL.Query().Get("login_challenge") == "login-tenant": + return httpJSONAny(r, http.StatusOK, map[string]any{ + "challenge": "login-tenant", + "subject": "user-123", + "client": map[string]any{ + "client_id": "client-tenant", + "metadata": map[string]any{ + "tenant_access_restricted": true, + "allowed_tenants": []string{"tenant-b"}, + }, + }, + }), nil + case r.URL.Host == "kratos.test" && r.URL.Path == "/sessions/whoami": + return httpJSONAny(r, http.StatusOK, map[string]any{ + "identity": map[string]any{ + "id": "user-123", + "traits": map[string]any{ + "email": "user@test.com", + "tenant_id": "tenant-a", + }, + }, + }), nil + default: + return httpJSONAny(r, http.StatusNotFound, nil), nil + } + }) + + client := &http.Client{Transport: transport} + origDefault := http.DefaultClient + http.DefaultClient = client + defer func() { http.DefaultClient = origDefault }() + + t.Setenv("KRATOS_PUBLIC_URL", "http://kratos.test") + + h := &AuthHandler{ + Hydra: &service.HydraAdminService{ + AdminURL: "http://hydra.test", + HTTPClient: client, + }, + } + + app := fiber.New() + tenantID := "tenant-a" + app.Use(func(c *fiber.Ctx) error { + c.Locals("user_profile", &domain.UserProfileResponse{ + ID: "user-123", + Role: domain.RoleUser, + TenantID: &tenantID, + }) + return c.Next() + }) + app.Post("/api/v1/auth/oidc/login/accept", h.AcceptOidcLoginRequest) + + reqBody, _ := json.Marshal(map[string]any{ + "login_challenge": "login-tenant", + }) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oidc/login/accept", bytes.NewReader(reqBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Cookie", "ory_kratos_session=session-1") + + resp, err := app.Test(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusForbidden, resp.StatusCode) + + bodyBytes, _ := io.ReadAll(resp.Body) + var body map[string]any + assert.NoError(t, json.Unmarshal(bodyBytes, &body)) + assert.Equal(t, "tenant_not_allowed", body["code"]) +} diff --git a/backend/internal/handler/dev_handler.go b/backend/internal/handler/dev_handler.go index 1a0b8a95..cc18a86c 100644 --- a/backend/internal/handler/dev_handler.go +++ b/backend/internal/handler/dev_handler.go @@ -1528,6 +1528,11 @@ func (h *DevHandler) CreateClient(c *fiber.Ctx) error { } metadata["status"] = status metadata["created_at"] = time.Now().Format(time.RFC3339) + var err error + metadata, err = normalizeClientTenantAccessMetadata(metadata) + if err != nil { + return errorJSON(c, fiber.StatusBadRequest, err.Error()) + } tokenAuthMethod := strings.TrimSpace(valueOr(req.TokenEndpointAuthMethod, "")) if tokenAuthMethod == "" { @@ -1716,6 +1721,10 @@ func (h *DevHandler) UpdateClient(c *fiber.Ctx) error { } metadata["status"] = status } + metadata, err = normalizeClientTenantAccessMetadata(metadata) + if err != nil { + return errorJSON(c, fiber.StatusBadRequest, err.Error()) + } resolvedClientType := currentSummary.Type if clientType != "" { resolvedClientType = clientType