forked from baron/baron-sso
테넌트 접근 제한 로직 보강
This commit is contained in:
@@ -3944,6 +3944,70 @@ func (h *AuthHandler) GetMe(c *fiber.Ctx) error {
|
||||
return c.JSON(profile)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) resolveProfileForSubject(ctx context.Context, subject string) (*domain.UserProfileResponse, error) {
|
||||
subject = strings.TrimSpace(subject)
|
||||
if subject == "" || h.KratosAdmin == nil {
|
||||
return nil, fmt.Errorf("subject profile unavailable")
|
||||
}
|
||||
|
||||
identity, err := h.KratosAdmin.GetIdentity(ctx, subject)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if identity == nil {
|
||||
return nil, fmt.Errorf("identity not found")
|
||||
}
|
||||
|
||||
profile := h.mapKratosIdentityToProfile(identity.ID, identity.Traits)
|
||||
if profile == nil {
|
||||
return nil, fmt.Errorf("failed to map identity profile")
|
||||
}
|
||||
return h.hydrateResolvedProfile(ctx, profile), nil
|
||||
}
|
||||
|
||||
func (h *AuthHandler) hydrateResolvedProfile(ctx context.Context, profile *domain.UserProfileResponse) *domain.UserProfileResponse {
|
||||
if profile == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
profile.Role = domain.NormalizeRole(profile.Role)
|
||||
if profile.Role == "" {
|
||||
profile.Role = domain.RoleUser
|
||||
}
|
||||
|
||||
if h.TenantService != nil {
|
||||
if profile.Tenant == nil && profile.TenantID != nil && *profile.TenantID != "" {
|
||||
if tenant, err := h.TenantService.GetTenant(ctx, *profile.TenantID); err == nil {
|
||||
profile.Tenant = tenant
|
||||
}
|
||||
}
|
||||
if profile.Tenant == nil && profile.CompanyCode != "" {
|
||||
if tenant, err := h.TenantService.GetTenantBySlug(ctx, profile.CompanyCode); err == nil && tenant != nil {
|
||||
profile.Tenant = tenant
|
||||
if profile.TenantID == nil || *profile.TenantID == "" {
|
||||
profile.TenantID = &tenant.ID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if h.TenantService != nil {
|
||||
if profile.Role == domain.RoleTenantAdmin {
|
||||
manageable, err := h.TenantService.ListManageableTenants(ctx, profile.ID)
|
||||
if err == nil {
|
||||
profile.ManageableTenants = manageable
|
||||
}
|
||||
}
|
||||
|
||||
joined, err := h.TenantService.ListJoinedTenants(ctx, profile.ID)
|
||||
if err == nil {
|
||||
profile.JoinedTenants = joined
|
||||
}
|
||||
}
|
||||
|
||||
return profile
|
||||
}
|
||||
|
||||
// GetEnrichedProfile - Exported wrapper for resolveCurrentProfile used by middlewares
|
||||
func (h *AuthHandler) GetEnrichedProfile(c *fiber.Ctx) (*domain.UserProfileResponse, error) {
|
||||
return h.resolveCurrentProfile(c)
|
||||
@@ -5132,8 +5196,14 @@ func (h *AuthHandler) GetConsentRequest(c *fiber.Ctx) error {
|
||||
)
|
||||
|
||||
profile, err := h.resolveCurrentProfile(c)
|
||||
if tenantErr := enforceClientTenantAccess(c, consentRequest.Client, profile, err); tenantErr != nil {
|
||||
return tenantErr
|
||||
if (err != nil || profile == nil) && consentRequest.Subject != "" {
|
||||
if fallbackProfile, fallbackErr := h.resolveProfileForSubject(c.Context(), consentRequest.Subject); fallbackErr == nil {
|
||||
profile = fallbackProfile
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
if enforceClientTenantAccess(c, h.TenantService, consentRequest.Client, profile, err) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// [New] 로컬 DB에서 기존 동의 내역 확인 (강제 자동 승인 전략)
|
||||
@@ -5342,8 +5412,14 @@ func (h *AuthHandler) AcceptConsentRequest(c *fiber.Ctx) error {
|
||||
consentRequest.RequestedScope = mergeRequestedScopesWithClientRequirements(consentRequest.Client, consentRequest.RequestedScope)
|
||||
|
||||
profile, err := h.resolveCurrentProfile(c)
|
||||
if tenantErr := enforceClientTenantAccess(c, consentRequest.Client, profile, err); tenantErr != nil {
|
||||
return tenantErr
|
||||
if (err != nil || profile == nil) && consentRequest.Subject != "" {
|
||||
if fallbackProfile, fallbackErr := h.resolveProfileForSubject(c.Context(), consentRequest.Subject); fallbackErr == nil {
|
||||
profile = fallbackProfile
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
if enforceClientTenantAccess(c, h.TenantService, consentRequest.Client, profile, err) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 3. Hydra에 승인 요청
|
||||
@@ -5484,9 +5560,15 @@ func (h *AuthHandler) AcceptOidcLoginRequest(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
profile, err := h.resolveCurrentProfile(c)
|
||||
if (err != nil || profile == nil) && loginReq != nil && strings.TrimSpace(loginReq.Subject) != "" {
|
||||
if fallbackProfile, fallbackErr := h.resolveProfileForSubject(c.Context(), loginReq.Subject); fallbackErr == nil {
|
||||
profile = fallbackProfile
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
if loginReq != nil {
|
||||
if tenantErr := enforceClientTenantAccess(c, loginReq.Client, profile, err); tenantErr != nil {
|
||||
return tenantErr
|
||||
if enforceClientTenantAccess(c, h.TenantService, loginReq.Client, profile, err) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5631,37 +5713,7 @@ func (h *AuthHandler) resolveCurrentProfile(c *fiber.Ctx) (*domain.UserProfileRe
|
||||
delete(profile.Metadata, "_used_identifier") // Cleanup
|
||||
}
|
||||
|
||||
// Fetch Tenant Metadata if missing
|
||||
if h.TenantService != nil {
|
||||
if profile.Tenant == nil && profile.TenantID != nil && *profile.TenantID != "" {
|
||||
if tenant, err := h.TenantService.GetTenant(c.Context(), *profile.TenantID); err == nil {
|
||||
profile.Tenant = tenant
|
||||
}
|
||||
}
|
||||
if profile.Tenant == nil && profile.CompanyCode != "" {
|
||||
if tenant, err := h.TenantService.GetTenantBySlug(c.Context(), profile.CompanyCode); err == nil && tenant != nil {
|
||||
profile.Tenant = tenant
|
||||
if profile.TenantID == nil || *profile.TenantID == "" {
|
||||
profile.TenantID = &tenant.ID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// [New] Fetch manageable and joined tenants
|
||||
if h.TenantService != nil {
|
||||
if profile.Role == domain.RoleTenantAdmin {
|
||||
manageable, err := h.TenantService.ListManageableTenants(c.Context(), profile.ID)
|
||||
if err == nil {
|
||||
profile.ManageableTenants = manageable
|
||||
}
|
||||
}
|
||||
|
||||
joined, err := h.TenantService.ListJoinedTenants(c.Context(), profile.ID)
|
||||
if err == nil {
|
||||
profile.JoinedTenants = joined
|
||||
}
|
||||
}
|
||||
profile = h.hydrateResolvedProfile(c.Context(), profile)
|
||||
|
||||
// 4. Save to Redis Cache (Short TTL)
|
||||
// IMPORTANT: In dev mode, if role was overridden, we should NOT cache it under the token key
|
||||
|
||||
@@ -2,6 +2,8 @@ package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/response"
|
||||
"baron-sso-backend/internal/service"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"sort"
|
||||
@@ -133,8 +135,32 @@ func clientTenantAccessAllowed(profile *domain.UserProfileResponse, client domai
|
||||
return false
|
||||
}
|
||||
|
||||
func tenantNotAllowedError(c *fiber.Ctx) error {
|
||||
return errorJSONCode(c, fiber.StatusForbidden, "tenant_not_allowed", "허용되지 않은 테넌트입니다.")
|
||||
type tenantAccessDeniedDetails struct {
|
||||
Account tenantAccessDeniedAccount `json:"account"`
|
||||
CurrentTenant tenantAccessDeniedTenant `json:"current_tenant"`
|
||||
AffiliatedTenants []tenantAccessDeniedTenant `json:"affiliated_tenants,omitempty"`
|
||||
AllowedTenants []tenantAccessDeniedTenant `json:"allowed_tenants,omitempty"`
|
||||
}
|
||||
|
||||
type tenantAccessDeniedAccount struct {
|
||||
Email string `json:"email,omitempty"`
|
||||
}
|
||||
|
||||
type tenantAccessDeniedTenant struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Slug string `json:"slug,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Identifier string `json:"identifier,omitempty"`
|
||||
}
|
||||
|
||||
func tenantNotAllowedError(c *fiber.Ctx, details tenantAccessDeniedDetails) error {
|
||||
return response.ErrorWithDetails(
|
||||
c,
|
||||
fiber.StatusForbidden,
|
||||
"tenant_not_allowed",
|
||||
"허용되지 않은 테넌트입니다.",
|
||||
details,
|
||||
)
|
||||
}
|
||||
|
||||
func isClientTenantAccessAllowed(profile *domain.UserProfileResponse, client domain.HydraClient) bool {
|
||||
@@ -144,17 +170,162 @@ func isClientTenantAccessAllowed(profile *domain.UserProfileResponse, client dom
|
||||
return clientTenantAccessAllowed(profile, client)
|
||||
}
|
||||
|
||||
func enforceClientTenantAccess(c *fiber.Ctx, client domain.HydraClient, profile *domain.UserProfileResponse, resolveErr error) error {
|
||||
func enforceClientTenantAccess(c *fiber.Ctx, tenantSvc service.TenantService, client domain.HydraClient, profile *domain.UserProfileResponse, resolveErr error) bool {
|
||||
if !clientTenantAccessRestricted(client.Metadata) {
|
||||
return nil
|
||||
return false
|
||||
}
|
||||
details := buildTenantAccessDeniedDetails(c, tenantSvc, client, profile)
|
||||
if resolveErr != nil || profile == nil {
|
||||
return tenantNotAllowedError(c)
|
||||
_ = tenantNotAllowedError(c, details)
|
||||
return true
|
||||
}
|
||||
if !clientTenantAccessAllowed(profile, client) {
|
||||
return tenantNotAllowedError(c)
|
||||
_ = tenantNotAllowedError(c, details)
|
||||
return true
|
||||
}
|
||||
return nil
|
||||
return false
|
||||
}
|
||||
|
||||
func buildTenantAccessDeniedDetails(c *fiber.Ctx, tenantSvc service.TenantService, client domain.HydraClient, profile *domain.UserProfileResponse) tenantAccessDeniedDetails {
|
||||
details := tenantAccessDeniedDetails{
|
||||
Account: tenantAccessDeniedAccount{Email: strings.TrimSpace(profileEmail(profile))},
|
||||
CurrentTenant: resolveCurrentTenantDetails(c, tenantSvc, profile),
|
||||
AffiliatedTenants: resolveAffiliatedTenantDetails(c, tenantSvc, profile),
|
||||
}
|
||||
|
||||
for _, identifier := range clientAllowedTenants(client.Metadata) {
|
||||
details.AllowedTenants = append(details.AllowedTenants, resolveAllowedTenantDetails(c, tenantSvc, identifier))
|
||||
}
|
||||
|
||||
return details
|
||||
}
|
||||
|
||||
func resolveAffiliatedTenantDetails(c *fiber.Ctx, tenantSvc service.TenantService, profile *domain.UserProfileResponse) []tenantAccessDeniedTenant {
|
||||
if profile == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{})
|
||||
out := make([]tenantAccessDeniedTenant, 0, len(profile.JoinedTenants)+1)
|
||||
appendTenant := func(tenant tenantAccessDeniedTenant) {
|
||||
key := strings.ToLower(firstNonEmptyString(tenant.ID, tenant.Slug, tenant.Identifier, tenant.Name))
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := seen[key]; ok {
|
||||
return
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, tenant)
|
||||
}
|
||||
|
||||
appendTenant(resolveCurrentTenantDetails(c, tenantSvc, profile))
|
||||
|
||||
for _, joined := range profile.JoinedTenants {
|
||||
appendTenant(tenantAccessDeniedTenant{
|
||||
ID: strings.TrimSpace(joined.ID),
|
||||
Slug: strings.TrimSpace(joined.Slug),
|
||||
Name: strings.TrimSpace(joined.Name),
|
||||
Identifier: firstNonEmptyString(strings.TrimSpace(joined.Slug), strings.TrimSpace(joined.ID)),
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func resolveCurrentTenantDetails(c *fiber.Ctx, tenantSvc service.TenantService, profile *domain.UserProfileResponse) tenantAccessDeniedTenant {
|
||||
if profile == nil {
|
||||
return tenantAccessDeniedTenant{}
|
||||
}
|
||||
|
||||
if profile.Tenant != nil {
|
||||
return tenantAccessDeniedTenant{
|
||||
ID: strings.TrimSpace(profile.Tenant.ID),
|
||||
Slug: strings.TrimSpace(profile.Tenant.Slug),
|
||||
Name: strings.TrimSpace(profile.Tenant.Name),
|
||||
Identifier: firstNonEmptyString(strings.TrimSpace(profile.Tenant.Slug), strings.TrimSpace(profile.Tenant.ID)),
|
||||
}
|
||||
}
|
||||
|
||||
if tenantSvc != nil {
|
||||
if profile.TenantID != nil && strings.TrimSpace(*profile.TenantID) != "" {
|
||||
if tenant, err := tenantSvc.GetTenant(c.Context(), strings.TrimSpace(*profile.TenantID)); err == nil && tenant != nil {
|
||||
return tenantAccessDeniedTenant{
|
||||
ID: strings.TrimSpace(tenant.ID),
|
||||
Slug: strings.TrimSpace(tenant.Slug),
|
||||
Name: strings.TrimSpace(tenant.Name),
|
||||
Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID)),
|
||||
}
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(profile.CompanyCode) != "" {
|
||||
if tenant, err := tenantSvc.GetTenantBySlug(c.Context(), strings.TrimSpace(profile.CompanyCode)); err == nil && tenant != nil {
|
||||
return tenantAccessDeniedTenant{
|
||||
ID: strings.TrimSpace(tenant.ID),
|
||||
Slug: strings.TrimSpace(tenant.Slug),
|
||||
Name: strings.TrimSpace(tenant.Name),
|
||||
Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tenantAccessDeniedTenant{
|
||||
ID: strings.TrimSpace(pointerValue(profile.TenantID)),
|
||||
Slug: strings.TrimSpace(profile.CompanyCode),
|
||||
Identifier: firstNonEmptyString(strings.TrimSpace(profile.CompanyCode), strings.TrimSpace(pointerValue(profile.TenantID))),
|
||||
}
|
||||
}
|
||||
|
||||
func resolveAllowedTenantDetails(c *fiber.Ctx, tenantSvc service.TenantService, identifier string) tenantAccessDeniedTenant {
|
||||
identifier = strings.TrimSpace(identifier)
|
||||
if identifier == "" {
|
||||
return tenantAccessDeniedTenant{}
|
||||
}
|
||||
|
||||
if tenantSvc != nil {
|
||||
if tenant, err := tenantSvc.GetTenant(c.Context(), identifier); err == nil && tenant != nil {
|
||||
return tenantAccessDeniedTenant{
|
||||
ID: strings.TrimSpace(tenant.ID),
|
||||
Slug: strings.TrimSpace(tenant.Slug),
|
||||
Name: strings.TrimSpace(tenant.Name),
|
||||
Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID), identifier),
|
||||
}
|
||||
}
|
||||
if tenant, err := tenantSvc.GetTenantBySlug(c.Context(), identifier); err == nil && tenant != nil {
|
||||
return tenantAccessDeniedTenant{
|
||||
ID: strings.TrimSpace(tenant.ID),
|
||||
Slug: strings.TrimSpace(tenant.Slug),
|
||||
Name: strings.TrimSpace(tenant.Name),
|
||||
Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID), identifier),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tenantAccessDeniedTenant{Identifier: identifier}
|
||||
}
|
||||
|
||||
func profileEmail(profile *domain.UserProfileResponse) string {
|
||||
if profile == nil {
|
||||
return ""
|
||||
}
|
||||
return profile.Email
|
||||
}
|
||||
|
||||
func pointerValue(value *string) string {
|
||||
if value == nil {
|
||||
return ""
|
||||
}
|
||||
return *value
|
||||
}
|
||||
|
||||
func firstNonEmptyString(values ...string) string {
|
||||
for _, value := range values {
|
||||
if strings.TrimSpace(value) != "" {
|
||||
return strings.TrimSpace(value)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type clientStructuredScope struct {
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func TestCreateClient_NormalizesTenantAccessMetadata(t *testing.T) {
|
||||
@@ -190,6 +191,18 @@ func TestGetConsentRequest_DeniesTenantAccess(t *testing.T) {
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
|
||||
var body map[string]any
|
||||
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
assert.Equal(t, "tenant_not_allowed", body["code"])
|
||||
details, ok := body["details"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
account, ok := details["account"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.NotEmpty(t, account["email"])
|
||||
currentTenant, ok := details["current_tenant"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.NotEmpty(t, currentTenant["identifier"])
|
||||
}
|
||||
|
||||
func TestGetConsentRequest_DeniesRestrictedClientWhenProfileResolutionFails(t *testing.T) {
|
||||
@@ -232,6 +245,37 @@ func TestGetConsentRequest_DeniesRestrictedClientWhenProfileResolutionFails(t *t
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
KratosAdmin: func() service.KratosAdminService {
|
||||
mockKratos := new(MockKratosAdminService)
|
||||
mockKratos.On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
|
||||
ID: "user-123",
|
||||
Traits: map[string]interface{}{
|
||||
"email": "user@test.com",
|
||||
"tenant_id": "tenant-a",
|
||||
"companyCode": "tenant-a",
|
||||
},
|
||||
}, nil).Once()
|
||||
return mockKratos
|
||||
}(),
|
||||
TenantService: func() service.TenantService {
|
||||
tenantSvc := new(MockTenantService)
|
||||
tenantSvc.On("GetTenant", mock.Anything, "tenant-a").Return(&domain.Tenant{
|
||||
ID: "tenant-a",
|
||||
Slug: "tenant-a",
|
||||
Name: "Tenant A",
|
||||
}, nil).Twice()
|
||||
tenantSvc.On("ListJoinedTenants", mock.Anything, "user-123").Return([]domain.Tenant{
|
||||
{ID: "tenant-a", Slug: "tenant-a", Name: "Tenant A"},
|
||||
{ID: "tenant-c", Slug: "tenant-c", Name: "Tenant C"},
|
||||
}, nil).Once()
|
||||
tenantSvc.On("GetTenant", mock.Anything, "tenant-b").Return(nil, assert.AnError).Once()
|
||||
tenantSvc.On("GetTenantBySlug", mock.Anything, "tenant-b").Return(&domain.Tenant{
|
||||
ID: "tenant-b-id",
|
||||
Slug: "tenant-b",
|
||||
Name: "Tenant B",
|
||||
}, nil).Once()
|
||||
return tenantSvc
|
||||
}(),
|
||||
ConsentRepo: &mockConsentRepo{
|
||||
consents: []domain.ClientConsent{
|
||||
{
|
||||
@@ -253,6 +297,22 @@ func TestGetConsentRequest_DeniesRestrictedClientWhenProfileResolutionFails(t *t
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
assert.False(t, acceptCalled)
|
||||
|
||||
var body map[string]any
|
||||
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
assert.Equal(t, "tenant_not_allowed", body["code"])
|
||||
|
||||
details, ok := body["details"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
account, ok := details["account"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "user@test.com", account["email"])
|
||||
currentTenant, ok := details["current_tenant"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "Tenant A", currentTenant["name"])
|
||||
affiliatedTenants, ok := details["affiliated_tenants"].([]any)
|
||||
assert.True(t, ok)
|
||||
assert.Len(t, affiliatedTenants, 2)
|
||||
}
|
||||
|
||||
func TestAcceptOidcLoginRequest_DeniesTenantAccess(t *testing.T) {
|
||||
@@ -260,9 +320,15 @@ func TestAcceptOidcLoginRequest_DeniesTenantAccess(t *testing.T) {
|
||||
app.Get("/deny", func(c *fiber.Ctx) error {
|
||||
tenantID := "tenant-a"
|
||||
profile := &domain.UserProfileResponse{
|
||||
ID: "user-123",
|
||||
Role: domain.RoleUser,
|
||||
TenantID: &tenantID,
|
||||
ID: "user-123",
|
||||
Role: domain.RoleUser,
|
||||
Email: "user@test.com",
|
||||
TenantID: &tenantID,
|
||||
CompanyCode: "tenant-a",
|
||||
JoinedTenants: []domain.Tenant{
|
||||
{ID: "tenant-a", Slug: "tenant-a", Name: "Tenant A"},
|
||||
{ID: "tenant-c", Slug: "tenant-c", Name: "Tenant C"},
|
||||
},
|
||||
}
|
||||
client := domain.HydraClient{
|
||||
ClientID: "client-tenant",
|
||||
@@ -271,11 +337,49 @@ func TestAcceptOidcLoginRequest_DeniesTenantAccess(t *testing.T) {
|
||||
"allowed_tenants": []string{"tenant-b"},
|
||||
},
|
||||
}
|
||||
return enforceClientTenantAccess(c, client, profile, nil)
|
||||
tenantSvc := new(MockTenantService)
|
||||
tenantSvc.On("GetTenant", mock.Anything, "tenant-a").Return(&domain.Tenant{
|
||||
ID: "tenant-a",
|
||||
Slug: "tenant-a",
|
||||
Name: "Tenant A",
|
||||
}, nil).Twice()
|
||||
tenantSvc.On("GetTenant", mock.Anything, "tenant-b").Return(nil, assert.AnError).Once()
|
||||
tenantSvc.On("GetTenantBySlug", mock.Anything, "tenant-b").Return(&domain.Tenant{
|
||||
ID: "tenant-b-id",
|
||||
Slug: "tenant-b",
|
||||
Name: "Tenant B",
|
||||
}, nil).Once()
|
||||
enforceClientTenantAccess(c, tenantSvc, client, profile, nil)
|
||||
return nil
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/deny", nil)
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
|
||||
var body map[string]any
|
||||
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||
assert.Equal(t, "tenant_not_allowed", body["code"])
|
||||
|
||||
details, ok := body["details"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
|
||||
account, ok := details["account"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "user@test.com", account["email"])
|
||||
|
||||
currentTenant, ok := details["current_tenant"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "Tenant A", currentTenant["name"])
|
||||
affiliatedTenants, ok := details["affiliated_tenants"].([]any)
|
||||
assert.True(t, ok)
|
||||
assert.Len(t, affiliatedTenants, 2)
|
||||
|
||||
allowedTenants, ok := details["allowed_tenants"].([]any)
|
||||
assert.True(t, ok)
|
||||
assert.Len(t, allowedTenants, 1)
|
||||
allowedTenant, ok := allowedTenants[0].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "Tenant B", allowedTenant["name"])
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user