forked from baron/baron-sso
테넌트 접근 제한 로직 보강
This commit is contained in:
@@ -3944,6 +3944,70 @@ func (h *AuthHandler) GetMe(c *fiber.Ctx) error {
|
|||||||
return c.JSON(profile)
|
return c.JSON(profile)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) resolveProfileForSubject(ctx context.Context, subject string) (*domain.UserProfileResponse, error) {
|
||||||
|
subject = strings.TrimSpace(subject)
|
||||||
|
if subject == "" || h.KratosAdmin == nil {
|
||||||
|
return nil, fmt.Errorf("subject profile unavailable")
|
||||||
|
}
|
||||||
|
|
||||||
|
identity, err := h.KratosAdmin.GetIdentity(ctx, subject)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if identity == nil {
|
||||||
|
return nil, fmt.Errorf("identity not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
profile := h.mapKratosIdentityToProfile(identity.ID, identity.Traits)
|
||||||
|
if profile == nil {
|
||||||
|
return nil, fmt.Errorf("failed to map identity profile")
|
||||||
|
}
|
||||||
|
return h.hydrateResolvedProfile(ctx, profile), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) hydrateResolvedProfile(ctx context.Context, profile *domain.UserProfileResponse) *domain.UserProfileResponse {
|
||||||
|
if profile == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
profile.Role = domain.NormalizeRole(profile.Role)
|
||||||
|
if profile.Role == "" {
|
||||||
|
profile.Role = domain.RoleUser
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.TenantService != nil {
|
||||||
|
if profile.Tenant == nil && profile.TenantID != nil && *profile.TenantID != "" {
|
||||||
|
if tenant, err := h.TenantService.GetTenant(ctx, *profile.TenantID); err == nil {
|
||||||
|
profile.Tenant = tenant
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if profile.Tenant == nil && profile.CompanyCode != "" {
|
||||||
|
if tenant, err := h.TenantService.GetTenantBySlug(ctx, profile.CompanyCode); err == nil && tenant != nil {
|
||||||
|
profile.Tenant = tenant
|
||||||
|
if profile.TenantID == nil || *profile.TenantID == "" {
|
||||||
|
profile.TenantID = &tenant.ID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.TenantService != nil {
|
||||||
|
if profile.Role == domain.RoleTenantAdmin {
|
||||||
|
manageable, err := h.TenantService.ListManageableTenants(ctx, profile.ID)
|
||||||
|
if err == nil {
|
||||||
|
profile.ManageableTenants = manageable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
joined, err := h.TenantService.ListJoinedTenants(ctx, profile.ID)
|
||||||
|
if err == nil {
|
||||||
|
profile.JoinedTenants = joined
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return profile
|
||||||
|
}
|
||||||
|
|
||||||
// GetEnrichedProfile - Exported wrapper for resolveCurrentProfile used by middlewares
|
// GetEnrichedProfile - Exported wrapper for resolveCurrentProfile used by middlewares
|
||||||
func (h *AuthHandler) GetEnrichedProfile(c *fiber.Ctx) (*domain.UserProfileResponse, error) {
|
func (h *AuthHandler) GetEnrichedProfile(c *fiber.Ctx) (*domain.UserProfileResponse, error) {
|
||||||
return h.resolveCurrentProfile(c)
|
return h.resolveCurrentProfile(c)
|
||||||
@@ -5132,8 +5196,14 @@ func (h *AuthHandler) GetConsentRequest(c *fiber.Ctx) error {
|
|||||||
)
|
)
|
||||||
|
|
||||||
profile, err := h.resolveCurrentProfile(c)
|
profile, err := h.resolveCurrentProfile(c)
|
||||||
if tenantErr := enforceClientTenantAccess(c, consentRequest.Client, profile, err); tenantErr != nil {
|
if (err != nil || profile == nil) && consentRequest.Subject != "" {
|
||||||
return tenantErr
|
if fallbackProfile, fallbackErr := h.resolveProfileForSubject(c.Context(), consentRequest.Subject); fallbackErr == nil {
|
||||||
|
profile = fallbackProfile
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if enforceClientTenantAccess(c, h.TenantService, consentRequest.Client, profile, err) {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// [New] 로컬 DB에서 기존 동의 내역 확인 (강제 자동 승인 전략)
|
// [New] 로컬 DB에서 기존 동의 내역 확인 (강제 자동 승인 전략)
|
||||||
@@ -5342,8 +5412,14 @@ func (h *AuthHandler) AcceptConsentRequest(c *fiber.Ctx) error {
|
|||||||
consentRequest.RequestedScope = mergeRequestedScopesWithClientRequirements(consentRequest.Client, consentRequest.RequestedScope)
|
consentRequest.RequestedScope = mergeRequestedScopesWithClientRequirements(consentRequest.Client, consentRequest.RequestedScope)
|
||||||
|
|
||||||
profile, err := h.resolveCurrentProfile(c)
|
profile, err := h.resolveCurrentProfile(c)
|
||||||
if tenantErr := enforceClientTenantAccess(c, consentRequest.Client, profile, err); tenantErr != nil {
|
if (err != nil || profile == nil) && consentRequest.Subject != "" {
|
||||||
return tenantErr
|
if fallbackProfile, fallbackErr := h.resolveProfileForSubject(c.Context(), consentRequest.Subject); fallbackErr == nil {
|
||||||
|
profile = fallbackProfile
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if enforceClientTenantAccess(c, h.TenantService, consentRequest.Client, profile, err) {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. Hydra에 승인 요청
|
// 3. Hydra에 승인 요청
|
||||||
@@ -5484,9 +5560,15 @@ func (h *AuthHandler) AcceptOidcLoginRequest(c *fiber.Ctx) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
profile, err := h.resolveCurrentProfile(c)
|
profile, err := h.resolveCurrentProfile(c)
|
||||||
|
if (err != nil || profile == nil) && loginReq != nil && strings.TrimSpace(loginReq.Subject) != "" {
|
||||||
|
if fallbackProfile, fallbackErr := h.resolveProfileForSubject(c.Context(), loginReq.Subject); fallbackErr == nil {
|
||||||
|
profile = fallbackProfile
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
if loginReq != nil {
|
if loginReq != nil {
|
||||||
if tenantErr := enforceClientTenantAccess(c, loginReq.Client, profile, err); tenantErr != nil {
|
if enforceClientTenantAccess(c, h.TenantService, loginReq.Client, profile, err) {
|
||||||
return tenantErr
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -5631,37 +5713,7 @@ func (h *AuthHandler) resolveCurrentProfile(c *fiber.Ctx) (*domain.UserProfileRe
|
|||||||
delete(profile.Metadata, "_used_identifier") // Cleanup
|
delete(profile.Metadata, "_used_identifier") // Cleanup
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch Tenant Metadata if missing
|
profile = h.hydrateResolvedProfile(c.Context(), profile)
|
||||||
if h.TenantService != nil {
|
|
||||||
if profile.Tenant == nil && profile.TenantID != nil && *profile.TenantID != "" {
|
|
||||||
if tenant, err := h.TenantService.GetTenant(c.Context(), *profile.TenantID); err == nil {
|
|
||||||
profile.Tenant = tenant
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if profile.Tenant == nil && profile.CompanyCode != "" {
|
|
||||||
if tenant, err := h.TenantService.GetTenantBySlug(c.Context(), profile.CompanyCode); err == nil && tenant != nil {
|
|
||||||
profile.Tenant = tenant
|
|
||||||
if profile.TenantID == nil || *profile.TenantID == "" {
|
|
||||||
profile.TenantID = &tenant.ID
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// [New] Fetch manageable and joined tenants
|
|
||||||
if h.TenantService != nil {
|
|
||||||
if profile.Role == domain.RoleTenantAdmin {
|
|
||||||
manageable, err := h.TenantService.ListManageableTenants(c.Context(), profile.ID)
|
|
||||||
if err == nil {
|
|
||||||
profile.ManageableTenants = manageable
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
joined, err := h.TenantService.ListJoinedTenants(c.Context(), profile.ID)
|
|
||||||
if err == nil {
|
|
||||||
profile.JoinedTenants = joined
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 4. Save to Redis Cache (Short TTL)
|
// 4. Save to Redis Cache (Short TTL)
|
||||||
// IMPORTANT: In dev mode, if role was overridden, we should NOT cache it under the token key
|
// IMPORTANT: In dev mode, if role was overridden, we should NOT cache it under the token key
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"baron-sso-backend/internal/domain"
|
"baron-sso-backend/internal/domain"
|
||||||
|
"baron-sso-backend/internal/response"
|
||||||
|
"baron-sso-backend/internal/service"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"sort"
|
"sort"
|
||||||
@@ -133,8 +135,32 @@ func clientTenantAccessAllowed(profile *domain.UserProfileResponse, client domai
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func tenantNotAllowedError(c *fiber.Ctx) error {
|
type tenantAccessDeniedDetails struct {
|
||||||
return errorJSONCode(c, fiber.StatusForbidden, "tenant_not_allowed", "허용되지 않은 테넌트입니다.")
|
Account tenantAccessDeniedAccount `json:"account"`
|
||||||
|
CurrentTenant tenantAccessDeniedTenant `json:"current_tenant"`
|
||||||
|
AffiliatedTenants []tenantAccessDeniedTenant `json:"affiliated_tenants,omitempty"`
|
||||||
|
AllowedTenants []tenantAccessDeniedTenant `json:"allowed_tenants,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type tenantAccessDeniedAccount struct {
|
||||||
|
Email string `json:"email,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type tenantAccessDeniedTenant struct {
|
||||||
|
ID string `json:"id,omitempty"`
|
||||||
|
Slug string `json:"slug,omitempty"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Identifier string `json:"identifier,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func tenantNotAllowedError(c *fiber.Ctx, details tenantAccessDeniedDetails) error {
|
||||||
|
return response.ErrorWithDetails(
|
||||||
|
c,
|
||||||
|
fiber.StatusForbidden,
|
||||||
|
"tenant_not_allowed",
|
||||||
|
"허용되지 않은 테넌트입니다.",
|
||||||
|
details,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func isClientTenantAccessAllowed(profile *domain.UserProfileResponse, client domain.HydraClient) bool {
|
func isClientTenantAccessAllowed(profile *domain.UserProfileResponse, client domain.HydraClient) bool {
|
||||||
@@ -144,17 +170,162 @@ func isClientTenantAccessAllowed(profile *domain.UserProfileResponse, client dom
|
|||||||
return clientTenantAccessAllowed(profile, client)
|
return clientTenantAccessAllowed(profile, client)
|
||||||
}
|
}
|
||||||
|
|
||||||
func enforceClientTenantAccess(c *fiber.Ctx, client domain.HydraClient, profile *domain.UserProfileResponse, resolveErr error) error {
|
func enforceClientTenantAccess(c *fiber.Ctx, tenantSvc service.TenantService, client domain.HydraClient, profile *domain.UserProfileResponse, resolveErr error) bool {
|
||||||
if !clientTenantAccessRestricted(client.Metadata) {
|
if !clientTenantAccessRestricted(client.Metadata) {
|
||||||
return nil
|
return false
|
||||||
}
|
}
|
||||||
|
details := buildTenantAccessDeniedDetails(c, tenantSvc, client, profile)
|
||||||
if resolveErr != nil || profile == nil {
|
if resolveErr != nil || profile == nil {
|
||||||
return tenantNotAllowedError(c)
|
_ = tenantNotAllowedError(c, details)
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
if !clientTenantAccessAllowed(profile, client) {
|
if !clientTenantAccessAllowed(profile, client) {
|
||||||
return tenantNotAllowedError(c)
|
_ = tenantNotAllowedError(c, details)
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
return nil
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildTenantAccessDeniedDetails(c *fiber.Ctx, tenantSvc service.TenantService, client domain.HydraClient, profile *domain.UserProfileResponse) tenantAccessDeniedDetails {
|
||||||
|
details := tenantAccessDeniedDetails{
|
||||||
|
Account: tenantAccessDeniedAccount{Email: strings.TrimSpace(profileEmail(profile))},
|
||||||
|
CurrentTenant: resolveCurrentTenantDetails(c, tenantSvc, profile),
|
||||||
|
AffiliatedTenants: resolveAffiliatedTenantDetails(c, tenantSvc, profile),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, identifier := range clientAllowedTenants(client.Metadata) {
|
||||||
|
details.AllowedTenants = append(details.AllowedTenants, resolveAllowedTenantDetails(c, tenantSvc, identifier))
|
||||||
|
}
|
||||||
|
|
||||||
|
return details
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveAffiliatedTenantDetails(c *fiber.Ctx, tenantSvc service.TenantService, profile *domain.UserProfileResponse) []tenantAccessDeniedTenant {
|
||||||
|
if profile == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := make(map[string]struct{})
|
||||||
|
out := make([]tenantAccessDeniedTenant, 0, len(profile.JoinedTenants)+1)
|
||||||
|
appendTenant := func(tenant tenantAccessDeniedTenant) {
|
||||||
|
key := strings.ToLower(firstNonEmptyString(tenant.ID, tenant.Slug, tenant.Identifier, tenant.Name))
|
||||||
|
if key == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, ok := seen[key]; ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
seen[key] = struct{}{}
|
||||||
|
out = append(out, tenant)
|
||||||
|
}
|
||||||
|
|
||||||
|
appendTenant(resolveCurrentTenantDetails(c, tenantSvc, profile))
|
||||||
|
|
||||||
|
for _, joined := range profile.JoinedTenants {
|
||||||
|
appendTenant(tenantAccessDeniedTenant{
|
||||||
|
ID: strings.TrimSpace(joined.ID),
|
||||||
|
Slug: strings.TrimSpace(joined.Slug),
|
||||||
|
Name: strings.TrimSpace(joined.Name),
|
||||||
|
Identifier: firstNonEmptyString(strings.TrimSpace(joined.Slug), strings.TrimSpace(joined.ID)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveCurrentTenantDetails(c *fiber.Ctx, tenantSvc service.TenantService, profile *domain.UserProfileResponse) tenantAccessDeniedTenant {
|
||||||
|
if profile == nil {
|
||||||
|
return tenantAccessDeniedTenant{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if profile.Tenant != nil {
|
||||||
|
return tenantAccessDeniedTenant{
|
||||||
|
ID: strings.TrimSpace(profile.Tenant.ID),
|
||||||
|
Slug: strings.TrimSpace(profile.Tenant.Slug),
|
||||||
|
Name: strings.TrimSpace(profile.Tenant.Name),
|
||||||
|
Identifier: firstNonEmptyString(strings.TrimSpace(profile.Tenant.Slug), strings.TrimSpace(profile.Tenant.ID)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tenantSvc != nil {
|
||||||
|
if profile.TenantID != nil && strings.TrimSpace(*profile.TenantID) != "" {
|
||||||
|
if tenant, err := tenantSvc.GetTenant(c.Context(), strings.TrimSpace(*profile.TenantID)); err == nil && tenant != nil {
|
||||||
|
return tenantAccessDeniedTenant{
|
||||||
|
ID: strings.TrimSpace(tenant.ID),
|
||||||
|
Slug: strings.TrimSpace(tenant.Slug),
|
||||||
|
Name: strings.TrimSpace(tenant.Name),
|
||||||
|
Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(profile.CompanyCode) != "" {
|
||||||
|
if tenant, err := tenantSvc.GetTenantBySlug(c.Context(), strings.TrimSpace(profile.CompanyCode)); err == nil && tenant != nil {
|
||||||
|
return tenantAccessDeniedTenant{
|
||||||
|
ID: strings.TrimSpace(tenant.ID),
|
||||||
|
Slug: strings.TrimSpace(tenant.Slug),
|
||||||
|
Name: strings.TrimSpace(tenant.Name),
|
||||||
|
Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tenantAccessDeniedTenant{
|
||||||
|
ID: strings.TrimSpace(pointerValue(profile.TenantID)),
|
||||||
|
Slug: strings.TrimSpace(profile.CompanyCode),
|
||||||
|
Identifier: firstNonEmptyString(strings.TrimSpace(profile.CompanyCode), strings.TrimSpace(pointerValue(profile.TenantID))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveAllowedTenantDetails(c *fiber.Ctx, tenantSvc service.TenantService, identifier string) tenantAccessDeniedTenant {
|
||||||
|
identifier = strings.TrimSpace(identifier)
|
||||||
|
if identifier == "" {
|
||||||
|
return tenantAccessDeniedTenant{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tenantSvc != nil {
|
||||||
|
if tenant, err := tenantSvc.GetTenant(c.Context(), identifier); err == nil && tenant != nil {
|
||||||
|
return tenantAccessDeniedTenant{
|
||||||
|
ID: strings.TrimSpace(tenant.ID),
|
||||||
|
Slug: strings.TrimSpace(tenant.Slug),
|
||||||
|
Name: strings.TrimSpace(tenant.Name),
|
||||||
|
Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID), identifier),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if tenant, err := tenantSvc.GetTenantBySlug(c.Context(), identifier); err == nil && tenant != nil {
|
||||||
|
return tenantAccessDeniedTenant{
|
||||||
|
ID: strings.TrimSpace(tenant.ID),
|
||||||
|
Slug: strings.TrimSpace(tenant.Slug),
|
||||||
|
Name: strings.TrimSpace(tenant.Name),
|
||||||
|
Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID), identifier),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tenantAccessDeniedTenant{Identifier: identifier}
|
||||||
|
}
|
||||||
|
|
||||||
|
func profileEmail(profile *domain.UserProfileResponse) string {
|
||||||
|
if profile == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return profile.Email
|
||||||
|
}
|
||||||
|
|
||||||
|
func pointerValue(value *string) string {
|
||||||
|
if value == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return *value
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstNonEmptyString(values ...string) string {
|
||||||
|
for _, value := range values {
|
||||||
|
if strings.TrimSpace(value) != "" {
|
||||||
|
return strings.TrimSpace(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
type clientStructuredScope struct {
|
type clientStructuredScope struct {
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCreateClient_NormalizesTenantAccessMetadata(t *testing.T) {
|
func TestCreateClient_NormalizesTenantAccessMetadata(t *testing.T) {
|
||||||
@@ -190,6 +191,18 @@ func TestGetConsentRequest_DeniesTenantAccess(t *testing.T) {
|
|||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
|
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||||
|
assert.Equal(t, "tenant_not_allowed", body["code"])
|
||||||
|
details, ok := body["details"].(map[string]any)
|
||||||
|
assert.True(t, ok)
|
||||||
|
account, ok := details["account"].(map[string]any)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.NotEmpty(t, account["email"])
|
||||||
|
currentTenant, ok := details["current_tenant"].(map[string]any)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.NotEmpty(t, currentTenant["identifier"])
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetConsentRequest_DeniesRestrictedClientWhenProfileResolutionFails(t *testing.T) {
|
func TestGetConsentRequest_DeniesRestrictedClientWhenProfileResolutionFails(t *testing.T) {
|
||||||
@@ -232,6 +245,37 @@ func TestGetConsentRequest_DeniesRestrictedClientWhenProfileResolutionFails(t *t
|
|||||||
AdminURL: "http://hydra.test",
|
AdminURL: "http://hydra.test",
|
||||||
HTTPClient: client,
|
HTTPClient: client,
|
||||||
},
|
},
|
||||||
|
KratosAdmin: func() service.KratosAdminService {
|
||||||
|
mockKratos := new(MockKratosAdminService)
|
||||||
|
mockKratos.On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
|
||||||
|
ID: "user-123",
|
||||||
|
Traits: map[string]interface{}{
|
||||||
|
"email": "user@test.com",
|
||||||
|
"tenant_id": "tenant-a",
|
||||||
|
"companyCode": "tenant-a",
|
||||||
|
},
|
||||||
|
}, nil).Once()
|
||||||
|
return mockKratos
|
||||||
|
}(),
|
||||||
|
TenantService: func() service.TenantService {
|
||||||
|
tenantSvc := new(MockTenantService)
|
||||||
|
tenantSvc.On("GetTenant", mock.Anything, "tenant-a").Return(&domain.Tenant{
|
||||||
|
ID: "tenant-a",
|
||||||
|
Slug: "tenant-a",
|
||||||
|
Name: "Tenant A",
|
||||||
|
}, nil).Twice()
|
||||||
|
tenantSvc.On("ListJoinedTenants", mock.Anything, "user-123").Return([]domain.Tenant{
|
||||||
|
{ID: "tenant-a", Slug: "tenant-a", Name: "Tenant A"},
|
||||||
|
{ID: "tenant-c", Slug: "tenant-c", Name: "Tenant C"},
|
||||||
|
}, nil).Once()
|
||||||
|
tenantSvc.On("GetTenant", mock.Anything, "tenant-b").Return(nil, assert.AnError).Once()
|
||||||
|
tenantSvc.On("GetTenantBySlug", mock.Anything, "tenant-b").Return(&domain.Tenant{
|
||||||
|
ID: "tenant-b-id",
|
||||||
|
Slug: "tenant-b",
|
||||||
|
Name: "Tenant B",
|
||||||
|
}, nil).Once()
|
||||||
|
return tenantSvc
|
||||||
|
}(),
|
||||||
ConsentRepo: &mockConsentRepo{
|
ConsentRepo: &mockConsentRepo{
|
||||||
consents: []domain.ClientConsent{
|
consents: []domain.ClientConsent{
|
||||||
{
|
{
|
||||||
@@ -253,6 +297,22 @@ func TestGetConsentRequest_DeniesRestrictedClientWhenProfileResolutionFails(t *t
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
|
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||||
assert.False(t, acceptCalled)
|
assert.False(t, acceptCalled)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||||
|
assert.Equal(t, "tenant_not_allowed", body["code"])
|
||||||
|
|
||||||
|
details, ok := body["details"].(map[string]any)
|
||||||
|
assert.True(t, ok)
|
||||||
|
account, ok := details["account"].(map[string]any)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "user@test.com", account["email"])
|
||||||
|
currentTenant, ok := details["current_tenant"].(map[string]any)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "Tenant A", currentTenant["name"])
|
||||||
|
affiliatedTenants, ok := details["affiliated_tenants"].([]any)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Len(t, affiliatedTenants, 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAcceptOidcLoginRequest_DeniesTenantAccess(t *testing.T) {
|
func TestAcceptOidcLoginRequest_DeniesTenantAccess(t *testing.T) {
|
||||||
@@ -260,9 +320,15 @@ func TestAcceptOidcLoginRequest_DeniesTenantAccess(t *testing.T) {
|
|||||||
app.Get("/deny", func(c *fiber.Ctx) error {
|
app.Get("/deny", func(c *fiber.Ctx) error {
|
||||||
tenantID := "tenant-a"
|
tenantID := "tenant-a"
|
||||||
profile := &domain.UserProfileResponse{
|
profile := &domain.UserProfileResponse{
|
||||||
ID: "user-123",
|
ID: "user-123",
|
||||||
Role: domain.RoleUser,
|
Role: domain.RoleUser,
|
||||||
TenantID: &tenantID,
|
Email: "user@test.com",
|
||||||
|
TenantID: &tenantID,
|
||||||
|
CompanyCode: "tenant-a",
|
||||||
|
JoinedTenants: []domain.Tenant{
|
||||||
|
{ID: "tenant-a", Slug: "tenant-a", Name: "Tenant A"},
|
||||||
|
{ID: "tenant-c", Slug: "tenant-c", Name: "Tenant C"},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
client := domain.HydraClient{
|
client := domain.HydraClient{
|
||||||
ClientID: "client-tenant",
|
ClientID: "client-tenant",
|
||||||
@@ -271,11 +337,49 @@ func TestAcceptOidcLoginRequest_DeniesTenantAccess(t *testing.T) {
|
|||||||
"allowed_tenants": []string{"tenant-b"},
|
"allowed_tenants": []string{"tenant-b"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
return enforceClientTenantAccess(c, client, profile, nil)
|
tenantSvc := new(MockTenantService)
|
||||||
|
tenantSvc.On("GetTenant", mock.Anything, "tenant-a").Return(&domain.Tenant{
|
||||||
|
ID: "tenant-a",
|
||||||
|
Slug: "tenant-a",
|
||||||
|
Name: "Tenant A",
|
||||||
|
}, nil).Twice()
|
||||||
|
tenantSvc.On("GetTenant", mock.Anything, "tenant-b").Return(nil, assert.AnError).Once()
|
||||||
|
tenantSvc.On("GetTenantBySlug", mock.Anything, "tenant-b").Return(&domain.Tenant{
|
||||||
|
ID: "tenant-b-id",
|
||||||
|
Slug: "tenant-b",
|
||||||
|
Name: "Tenant B",
|
||||||
|
}, nil).Once()
|
||||||
|
enforceClientTenantAccess(c, tenantSvc, client, profile, nil)
|
||||||
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/deny", nil)
|
req := httptest.NewRequest(http.MethodGet, "/deny", nil)
|
||||||
resp, err := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
|
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||||
|
assert.Equal(t, "tenant_not_allowed", body["code"])
|
||||||
|
|
||||||
|
details, ok := body["details"].(map[string]any)
|
||||||
|
assert.True(t, ok)
|
||||||
|
|
||||||
|
account, ok := details["account"].(map[string]any)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "user@test.com", account["email"])
|
||||||
|
|
||||||
|
currentTenant, ok := details["current_tenant"].(map[string]any)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "Tenant A", currentTenant["name"])
|
||||||
|
affiliatedTenants, ok := details["affiliated_tenants"].([]any)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Len(t, affiliatedTenants, 2)
|
||||||
|
|
||||||
|
allowedTenants, ok := details["allowed_tenants"].([]any)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Len(t, allowedTenants, 1)
|
||||||
|
allowedTenant, ok := allowedTenants[0].(map[string]any)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "Tenant B", allowedTenant["name"])
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user