diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index d1fc8581..428445df 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -3944,6 +3944,70 @@ func (h *AuthHandler) GetMe(c *fiber.Ctx) error { return c.JSON(profile) } +func (h *AuthHandler) resolveProfileForSubject(ctx context.Context, subject string) (*domain.UserProfileResponse, error) { + subject = strings.TrimSpace(subject) + if subject == "" || h.KratosAdmin == nil { + return nil, fmt.Errorf("subject profile unavailable") + } + + identity, err := h.KratosAdmin.GetIdentity(ctx, subject) + if err != nil { + return nil, err + } + if identity == nil { + return nil, fmt.Errorf("identity not found") + } + + profile := h.mapKratosIdentityToProfile(identity.ID, identity.Traits) + if profile == nil { + return nil, fmt.Errorf("failed to map identity profile") + } + return h.hydrateResolvedProfile(ctx, profile), nil +} + +func (h *AuthHandler) hydrateResolvedProfile(ctx context.Context, profile *domain.UserProfileResponse) *domain.UserProfileResponse { + if profile == nil { + return nil + } + + profile.Role = domain.NormalizeRole(profile.Role) + if profile.Role == "" { + profile.Role = domain.RoleUser + } + + if h.TenantService != nil { + if profile.Tenant == nil && profile.TenantID != nil && *profile.TenantID != "" { + if tenant, err := h.TenantService.GetTenant(ctx, *profile.TenantID); err == nil { + profile.Tenant = tenant + } + } + if profile.Tenant == nil && profile.CompanyCode != "" { + if tenant, err := h.TenantService.GetTenantBySlug(ctx, profile.CompanyCode); err == nil && tenant != nil { + profile.Tenant = tenant + if profile.TenantID == nil || *profile.TenantID == "" { + profile.TenantID = &tenant.ID + } + } + } + } + + if h.TenantService != nil { + if profile.Role == domain.RoleTenantAdmin { + manageable, err := h.TenantService.ListManageableTenants(ctx, profile.ID) + if err == nil { + profile.ManageableTenants = manageable + } + } + + joined, err := h.TenantService.ListJoinedTenants(ctx, profile.ID) + if err == nil { + profile.JoinedTenants = joined + } + } + + return profile +} + // GetEnrichedProfile - Exported wrapper for resolveCurrentProfile used by middlewares func (h *AuthHandler) GetEnrichedProfile(c *fiber.Ctx) (*domain.UserProfileResponse, error) { return h.resolveCurrentProfile(c) @@ -5132,8 +5196,14 @@ func (h *AuthHandler) GetConsentRequest(c *fiber.Ctx) error { ) profile, err := h.resolveCurrentProfile(c) - if tenantErr := enforceClientTenantAccess(c, consentRequest.Client, profile, err); tenantErr != nil { - return tenantErr + if (err != nil || profile == nil) && consentRequest.Subject != "" { + if fallbackProfile, fallbackErr := h.resolveProfileForSubject(c.Context(), consentRequest.Subject); fallbackErr == nil { + profile = fallbackProfile + err = nil + } + } + if enforceClientTenantAccess(c, h.TenantService, consentRequest.Client, profile, err) { + return nil } // [New] 로컬 DB에서 기존 동의 내역 확인 (강제 자동 승인 전략) @@ -5342,8 +5412,14 @@ func (h *AuthHandler) AcceptConsentRequest(c *fiber.Ctx) error { consentRequest.RequestedScope = mergeRequestedScopesWithClientRequirements(consentRequest.Client, consentRequest.RequestedScope) profile, err := h.resolveCurrentProfile(c) - if tenantErr := enforceClientTenantAccess(c, consentRequest.Client, profile, err); tenantErr != nil { - return tenantErr + if (err != nil || profile == nil) && consentRequest.Subject != "" { + if fallbackProfile, fallbackErr := h.resolveProfileForSubject(c.Context(), consentRequest.Subject); fallbackErr == nil { + profile = fallbackProfile + err = nil + } + } + if enforceClientTenantAccess(c, h.TenantService, consentRequest.Client, profile, err) { + return nil } // 3. Hydra에 승인 요청 @@ -5484,9 +5560,15 @@ func (h *AuthHandler) AcceptOidcLoginRequest(c *fiber.Ctx) error { } profile, err := h.resolveCurrentProfile(c) + if (err != nil || profile == nil) && loginReq != nil && strings.TrimSpace(loginReq.Subject) != "" { + if fallbackProfile, fallbackErr := h.resolveProfileForSubject(c.Context(), loginReq.Subject); fallbackErr == nil { + profile = fallbackProfile + err = nil + } + } if loginReq != nil { - if tenantErr := enforceClientTenantAccess(c, loginReq.Client, profile, err); tenantErr != nil { - return tenantErr + if enforceClientTenantAccess(c, h.TenantService, loginReq.Client, profile, err) { + return nil } } @@ -5631,37 +5713,7 @@ func (h *AuthHandler) resolveCurrentProfile(c *fiber.Ctx) (*domain.UserProfileRe delete(profile.Metadata, "_used_identifier") // Cleanup } - // Fetch Tenant Metadata if missing - if h.TenantService != nil { - if profile.Tenant == nil && profile.TenantID != nil && *profile.TenantID != "" { - if tenant, err := h.TenantService.GetTenant(c.Context(), *profile.TenantID); err == nil { - profile.Tenant = tenant - } - } - if profile.Tenant == nil && profile.CompanyCode != "" { - if tenant, err := h.TenantService.GetTenantBySlug(c.Context(), profile.CompanyCode); err == nil && tenant != nil { - profile.Tenant = tenant - if profile.TenantID == nil || *profile.TenantID == "" { - profile.TenantID = &tenant.ID - } - } - } - } - - // [New] Fetch manageable and joined tenants - if h.TenantService != nil { - if profile.Role == domain.RoleTenantAdmin { - manageable, err := h.TenantService.ListManageableTenants(c.Context(), profile.ID) - if err == nil { - profile.ManageableTenants = manageable - } - } - - joined, err := h.TenantService.ListJoinedTenants(c.Context(), profile.ID) - if err == nil { - profile.JoinedTenants = joined - } - } + profile = h.hydrateResolvedProfile(c.Context(), profile) // 4. Save to Redis Cache (Short TTL) // IMPORTANT: In dev mode, if role was overridden, we should NOT cache it under the token key diff --git a/backend/internal/handler/client_tenant_access.go b/backend/internal/handler/client_tenant_access.go index c2037260..097be621 100644 --- a/backend/internal/handler/client_tenant_access.go +++ b/backend/internal/handler/client_tenant_access.go @@ -2,6 +2,8 @@ package handler import ( "baron-sso-backend/internal/domain" + "baron-sso-backend/internal/response" + "baron-sso-backend/internal/service" "encoding/json" "errors" "sort" @@ -133,8 +135,32 @@ func clientTenantAccessAllowed(profile *domain.UserProfileResponse, client domai return false } -func tenantNotAllowedError(c *fiber.Ctx) error { - return errorJSONCode(c, fiber.StatusForbidden, "tenant_not_allowed", "허용되지 않은 테넌트입니다.") +type tenantAccessDeniedDetails struct { + Account tenantAccessDeniedAccount `json:"account"` + CurrentTenant tenantAccessDeniedTenant `json:"current_tenant"` + AffiliatedTenants []tenantAccessDeniedTenant `json:"affiliated_tenants,omitempty"` + AllowedTenants []tenantAccessDeniedTenant `json:"allowed_tenants,omitempty"` +} + +type tenantAccessDeniedAccount struct { + Email string `json:"email,omitempty"` +} + +type tenantAccessDeniedTenant struct { + ID string `json:"id,omitempty"` + Slug string `json:"slug,omitempty"` + Name string `json:"name,omitempty"` + Identifier string `json:"identifier,omitempty"` +} + +func tenantNotAllowedError(c *fiber.Ctx, details tenantAccessDeniedDetails) error { + return response.ErrorWithDetails( + c, + fiber.StatusForbidden, + "tenant_not_allowed", + "허용되지 않은 테넌트입니다.", + details, + ) } func isClientTenantAccessAllowed(profile *domain.UserProfileResponse, client domain.HydraClient) bool { @@ -144,17 +170,162 @@ func isClientTenantAccessAllowed(profile *domain.UserProfileResponse, client dom return clientTenantAccessAllowed(profile, client) } -func enforceClientTenantAccess(c *fiber.Ctx, client domain.HydraClient, profile *domain.UserProfileResponse, resolveErr error) error { +func enforceClientTenantAccess(c *fiber.Ctx, tenantSvc service.TenantService, client domain.HydraClient, profile *domain.UserProfileResponse, resolveErr error) bool { if !clientTenantAccessRestricted(client.Metadata) { - return nil + return false } + details := buildTenantAccessDeniedDetails(c, tenantSvc, client, profile) if resolveErr != nil || profile == nil { - return tenantNotAllowedError(c) + _ = tenantNotAllowedError(c, details) + return true } if !clientTenantAccessAllowed(profile, client) { - return tenantNotAllowedError(c) + _ = tenantNotAllowedError(c, details) + return true } - return nil + return false +} + +func buildTenantAccessDeniedDetails(c *fiber.Ctx, tenantSvc service.TenantService, client domain.HydraClient, profile *domain.UserProfileResponse) tenantAccessDeniedDetails { + details := tenantAccessDeniedDetails{ + Account: tenantAccessDeniedAccount{Email: strings.TrimSpace(profileEmail(profile))}, + CurrentTenant: resolveCurrentTenantDetails(c, tenantSvc, profile), + AffiliatedTenants: resolveAffiliatedTenantDetails(c, tenantSvc, profile), + } + + for _, identifier := range clientAllowedTenants(client.Metadata) { + details.AllowedTenants = append(details.AllowedTenants, resolveAllowedTenantDetails(c, tenantSvc, identifier)) + } + + return details +} + +func resolveAffiliatedTenantDetails(c *fiber.Ctx, tenantSvc service.TenantService, profile *domain.UserProfileResponse) []tenantAccessDeniedTenant { + if profile == nil { + return nil + } + + seen := make(map[string]struct{}) + out := make([]tenantAccessDeniedTenant, 0, len(profile.JoinedTenants)+1) + appendTenant := func(tenant tenantAccessDeniedTenant) { + key := strings.ToLower(firstNonEmptyString(tenant.ID, tenant.Slug, tenant.Identifier, tenant.Name)) + if key == "" { + return + } + if _, ok := seen[key]; ok { + return + } + seen[key] = struct{}{} + out = append(out, tenant) + } + + appendTenant(resolveCurrentTenantDetails(c, tenantSvc, profile)) + + for _, joined := range profile.JoinedTenants { + appendTenant(tenantAccessDeniedTenant{ + ID: strings.TrimSpace(joined.ID), + Slug: strings.TrimSpace(joined.Slug), + Name: strings.TrimSpace(joined.Name), + Identifier: firstNonEmptyString(strings.TrimSpace(joined.Slug), strings.TrimSpace(joined.ID)), + }) + } + + return out +} + +func resolveCurrentTenantDetails(c *fiber.Ctx, tenantSvc service.TenantService, profile *domain.UserProfileResponse) tenantAccessDeniedTenant { + if profile == nil { + return tenantAccessDeniedTenant{} + } + + if profile.Tenant != nil { + return tenantAccessDeniedTenant{ + ID: strings.TrimSpace(profile.Tenant.ID), + Slug: strings.TrimSpace(profile.Tenant.Slug), + Name: strings.TrimSpace(profile.Tenant.Name), + Identifier: firstNonEmptyString(strings.TrimSpace(profile.Tenant.Slug), strings.TrimSpace(profile.Tenant.ID)), + } + } + + if tenantSvc != nil { + if profile.TenantID != nil && strings.TrimSpace(*profile.TenantID) != "" { + if tenant, err := tenantSvc.GetTenant(c.Context(), strings.TrimSpace(*profile.TenantID)); err == nil && tenant != nil { + return tenantAccessDeniedTenant{ + ID: strings.TrimSpace(tenant.ID), + Slug: strings.TrimSpace(tenant.Slug), + Name: strings.TrimSpace(tenant.Name), + Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID)), + } + } + } + if strings.TrimSpace(profile.CompanyCode) != "" { + if tenant, err := tenantSvc.GetTenantBySlug(c.Context(), strings.TrimSpace(profile.CompanyCode)); err == nil && tenant != nil { + return tenantAccessDeniedTenant{ + ID: strings.TrimSpace(tenant.ID), + Slug: strings.TrimSpace(tenant.Slug), + Name: strings.TrimSpace(tenant.Name), + Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID)), + } + } + } + } + + return tenantAccessDeniedTenant{ + ID: strings.TrimSpace(pointerValue(profile.TenantID)), + Slug: strings.TrimSpace(profile.CompanyCode), + Identifier: firstNonEmptyString(strings.TrimSpace(profile.CompanyCode), strings.TrimSpace(pointerValue(profile.TenantID))), + } +} + +func resolveAllowedTenantDetails(c *fiber.Ctx, tenantSvc service.TenantService, identifier string) tenantAccessDeniedTenant { + identifier = strings.TrimSpace(identifier) + if identifier == "" { + return tenantAccessDeniedTenant{} + } + + if tenantSvc != nil { + if tenant, err := tenantSvc.GetTenant(c.Context(), identifier); err == nil && tenant != nil { + return tenantAccessDeniedTenant{ + ID: strings.TrimSpace(tenant.ID), + Slug: strings.TrimSpace(tenant.Slug), + Name: strings.TrimSpace(tenant.Name), + Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID), identifier), + } + } + if tenant, err := tenantSvc.GetTenantBySlug(c.Context(), identifier); err == nil && tenant != nil { + return tenantAccessDeniedTenant{ + ID: strings.TrimSpace(tenant.ID), + Slug: strings.TrimSpace(tenant.Slug), + Name: strings.TrimSpace(tenant.Name), + Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID), identifier), + } + } + } + + return tenantAccessDeniedTenant{Identifier: identifier} +} + +func profileEmail(profile *domain.UserProfileResponse) string { + if profile == nil { + return "" + } + return profile.Email +} + +func pointerValue(value *string) string { + if value == nil { + return "" + } + return *value +} + +func firstNonEmptyString(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return strings.TrimSpace(value) + } + } + return "" } type clientStructuredScope struct { diff --git a/backend/internal/handler/client_tenant_access_test.go b/backend/internal/handler/client_tenant_access_test.go index 7d47f3fd..a27af74e 100644 --- a/backend/internal/handler/client_tenant_access_test.go +++ b/backend/internal/handler/client_tenant_access_test.go @@ -12,6 +12,7 @@ import ( "github.com/gofiber/fiber/v2" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) func TestCreateClient_NormalizesTenantAccessMetadata(t *testing.T) { @@ -190,6 +191,18 @@ func TestGetConsentRequest_DeniesTenantAccess(t *testing.T) { resp, err := app.Test(req) assert.NoError(t, err) assert.Equal(t, http.StatusForbidden, resp.StatusCode) + + var body map[string]any + assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body)) + assert.Equal(t, "tenant_not_allowed", body["code"]) + details, ok := body["details"].(map[string]any) + assert.True(t, ok) + account, ok := details["account"].(map[string]any) + assert.True(t, ok) + assert.NotEmpty(t, account["email"]) + currentTenant, ok := details["current_tenant"].(map[string]any) + assert.True(t, ok) + assert.NotEmpty(t, currentTenant["identifier"]) } func TestGetConsentRequest_DeniesRestrictedClientWhenProfileResolutionFails(t *testing.T) { @@ -232,6 +245,37 @@ func TestGetConsentRequest_DeniesRestrictedClientWhenProfileResolutionFails(t *t AdminURL: "http://hydra.test", HTTPClient: client, }, + KratosAdmin: func() service.KratosAdminService { + mockKratos := new(MockKratosAdminService) + mockKratos.On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{ + ID: "user-123", + Traits: map[string]interface{}{ + "email": "user@test.com", + "tenant_id": "tenant-a", + "companyCode": "tenant-a", + }, + }, nil).Once() + return mockKratos + }(), + TenantService: func() service.TenantService { + tenantSvc := new(MockTenantService) + tenantSvc.On("GetTenant", mock.Anything, "tenant-a").Return(&domain.Tenant{ + ID: "tenant-a", + Slug: "tenant-a", + Name: "Tenant A", + }, nil).Twice() + tenantSvc.On("ListJoinedTenants", mock.Anything, "user-123").Return([]domain.Tenant{ + {ID: "tenant-a", Slug: "tenant-a", Name: "Tenant A"}, + {ID: "tenant-c", Slug: "tenant-c", Name: "Tenant C"}, + }, nil).Once() + tenantSvc.On("GetTenant", mock.Anything, "tenant-b").Return(nil, assert.AnError).Once() + tenantSvc.On("GetTenantBySlug", mock.Anything, "tenant-b").Return(&domain.Tenant{ + ID: "tenant-b-id", + Slug: "tenant-b", + Name: "Tenant B", + }, nil).Once() + return tenantSvc + }(), ConsentRepo: &mockConsentRepo{ consents: []domain.ClientConsent{ { @@ -253,6 +297,22 @@ func TestGetConsentRequest_DeniesRestrictedClientWhenProfileResolutionFails(t *t assert.NoError(t, err) assert.Equal(t, http.StatusForbidden, resp.StatusCode) assert.False(t, acceptCalled) + + var body map[string]any + assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body)) + assert.Equal(t, "tenant_not_allowed", body["code"]) + + details, ok := body["details"].(map[string]any) + assert.True(t, ok) + account, ok := details["account"].(map[string]any) + assert.True(t, ok) + assert.Equal(t, "user@test.com", account["email"]) + currentTenant, ok := details["current_tenant"].(map[string]any) + assert.True(t, ok) + assert.Equal(t, "Tenant A", currentTenant["name"]) + affiliatedTenants, ok := details["affiliated_tenants"].([]any) + assert.True(t, ok) + assert.Len(t, affiliatedTenants, 2) } func TestAcceptOidcLoginRequest_DeniesTenantAccess(t *testing.T) { @@ -260,9 +320,15 @@ func TestAcceptOidcLoginRequest_DeniesTenantAccess(t *testing.T) { app.Get("/deny", func(c *fiber.Ctx) error { tenantID := "tenant-a" profile := &domain.UserProfileResponse{ - ID: "user-123", - Role: domain.RoleUser, - TenantID: &tenantID, + ID: "user-123", + Role: domain.RoleUser, + Email: "user@test.com", + TenantID: &tenantID, + CompanyCode: "tenant-a", + JoinedTenants: []domain.Tenant{ + {ID: "tenant-a", Slug: "tenant-a", Name: "Tenant A"}, + {ID: "tenant-c", Slug: "tenant-c", Name: "Tenant C"}, + }, } client := domain.HydraClient{ ClientID: "client-tenant", @@ -271,11 +337,49 @@ func TestAcceptOidcLoginRequest_DeniesTenantAccess(t *testing.T) { "allowed_tenants": []string{"tenant-b"}, }, } - return enforceClientTenantAccess(c, client, profile, nil) + tenantSvc := new(MockTenantService) + tenantSvc.On("GetTenant", mock.Anything, "tenant-a").Return(&domain.Tenant{ + ID: "tenant-a", + Slug: "tenant-a", + Name: "Tenant A", + }, nil).Twice() + tenantSvc.On("GetTenant", mock.Anything, "tenant-b").Return(nil, assert.AnError).Once() + tenantSvc.On("GetTenantBySlug", mock.Anything, "tenant-b").Return(&domain.Tenant{ + ID: "tenant-b-id", + Slug: "tenant-b", + Name: "Tenant B", + }, nil).Once() + enforceClientTenantAccess(c, tenantSvc, client, profile, nil) + return nil }) req := httptest.NewRequest(http.MethodGet, "/deny", nil) resp, err := app.Test(req) assert.NoError(t, err) assert.Equal(t, http.StatusForbidden, resp.StatusCode) + + var body map[string]any + assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body)) + assert.Equal(t, "tenant_not_allowed", body["code"]) + + details, ok := body["details"].(map[string]any) + assert.True(t, ok) + + account, ok := details["account"].(map[string]any) + assert.True(t, ok) + assert.Equal(t, "user@test.com", account["email"]) + + currentTenant, ok := details["current_tenant"].(map[string]any) + assert.True(t, ok) + assert.Equal(t, "Tenant A", currentTenant["name"]) + affiliatedTenants, ok := details["affiliated_tenants"].([]any) + assert.True(t, ok) + assert.Len(t, affiliatedTenants, 2) + + allowedTenants, ok := details["allowed_tenants"].([]any) + assert.True(t, ok) + assert.Len(t, allowedTenants, 1) + allowedTenant, ok := allowedTenants[0].(map[string]any) + assert.True(t, ok) + assert.Equal(t, "Tenant B", allowedTenant["name"]) }