1
0
forked from baron/baron-sso

테넌트 접근 제한 로직 보강

This commit is contained in:
2026-04-28 10:57:16 +09:00
parent 367368805a
commit 955128a25a
3 changed files with 375 additions and 48 deletions

View File

@@ -3944,6 +3944,70 @@ func (h *AuthHandler) GetMe(c *fiber.Ctx) error {
return c.JSON(profile)
}
func (h *AuthHandler) resolveProfileForSubject(ctx context.Context, subject string) (*domain.UserProfileResponse, error) {
subject = strings.TrimSpace(subject)
if subject == "" || h.KratosAdmin == nil {
return nil, fmt.Errorf("subject profile unavailable")
}
identity, err := h.KratosAdmin.GetIdentity(ctx, subject)
if err != nil {
return nil, err
}
if identity == nil {
return nil, fmt.Errorf("identity not found")
}
profile := h.mapKratosIdentityToProfile(identity.ID, identity.Traits)
if profile == nil {
return nil, fmt.Errorf("failed to map identity profile")
}
return h.hydrateResolvedProfile(ctx, profile), nil
}
func (h *AuthHandler) hydrateResolvedProfile(ctx context.Context, profile *domain.UserProfileResponse) *domain.UserProfileResponse {
if profile == nil {
return nil
}
profile.Role = domain.NormalizeRole(profile.Role)
if profile.Role == "" {
profile.Role = domain.RoleUser
}
if h.TenantService != nil {
if profile.Tenant == nil && profile.TenantID != nil && *profile.TenantID != "" {
if tenant, err := h.TenantService.GetTenant(ctx, *profile.TenantID); err == nil {
profile.Tenant = tenant
}
}
if profile.Tenant == nil && profile.CompanyCode != "" {
if tenant, err := h.TenantService.GetTenantBySlug(ctx, profile.CompanyCode); err == nil && tenant != nil {
profile.Tenant = tenant
if profile.TenantID == nil || *profile.TenantID == "" {
profile.TenantID = &tenant.ID
}
}
}
}
if h.TenantService != nil {
if profile.Role == domain.RoleTenantAdmin {
manageable, err := h.TenantService.ListManageableTenants(ctx, profile.ID)
if err == nil {
profile.ManageableTenants = manageable
}
}
joined, err := h.TenantService.ListJoinedTenants(ctx, profile.ID)
if err == nil {
profile.JoinedTenants = joined
}
}
return profile
}
// GetEnrichedProfile - Exported wrapper for resolveCurrentProfile used by middlewares
func (h *AuthHandler) GetEnrichedProfile(c *fiber.Ctx) (*domain.UserProfileResponse, error) {
return h.resolveCurrentProfile(c)
@@ -5132,8 +5196,14 @@ func (h *AuthHandler) GetConsentRequest(c *fiber.Ctx) error {
)
profile, err := h.resolveCurrentProfile(c)
if tenantErr := enforceClientTenantAccess(c, consentRequest.Client, profile, err); tenantErr != nil {
return tenantErr
if (err != nil || profile == nil) && consentRequest.Subject != "" {
if fallbackProfile, fallbackErr := h.resolveProfileForSubject(c.Context(), consentRequest.Subject); fallbackErr == nil {
profile = fallbackProfile
err = nil
}
}
if enforceClientTenantAccess(c, h.TenantService, consentRequest.Client, profile, err) {
return nil
}
// [New] 로컬 DB에서 기존 동의 내역 확인 (강제 자동 승인 전략)
@@ -5342,8 +5412,14 @@ func (h *AuthHandler) AcceptConsentRequest(c *fiber.Ctx) error {
consentRequest.RequestedScope = mergeRequestedScopesWithClientRequirements(consentRequest.Client, consentRequest.RequestedScope)
profile, err := h.resolveCurrentProfile(c)
if tenantErr := enforceClientTenantAccess(c, consentRequest.Client, profile, err); tenantErr != nil {
return tenantErr
if (err != nil || profile == nil) && consentRequest.Subject != "" {
if fallbackProfile, fallbackErr := h.resolveProfileForSubject(c.Context(), consentRequest.Subject); fallbackErr == nil {
profile = fallbackProfile
err = nil
}
}
if enforceClientTenantAccess(c, h.TenantService, consentRequest.Client, profile, err) {
return nil
}
// 3. Hydra에 승인 요청
@@ -5484,9 +5560,15 @@ func (h *AuthHandler) AcceptOidcLoginRequest(c *fiber.Ctx) error {
}
profile, err := h.resolveCurrentProfile(c)
if (err != nil || profile == nil) && loginReq != nil && strings.TrimSpace(loginReq.Subject) != "" {
if fallbackProfile, fallbackErr := h.resolveProfileForSubject(c.Context(), loginReq.Subject); fallbackErr == nil {
profile = fallbackProfile
err = nil
}
}
if loginReq != nil {
if tenantErr := enforceClientTenantAccess(c, loginReq.Client, profile, err); tenantErr != nil {
return tenantErr
if enforceClientTenantAccess(c, h.TenantService, loginReq.Client, profile, err) {
return nil
}
}
@@ -5631,37 +5713,7 @@ func (h *AuthHandler) resolveCurrentProfile(c *fiber.Ctx) (*domain.UserProfileRe
delete(profile.Metadata, "_used_identifier") // Cleanup
}
// Fetch Tenant Metadata if missing
if h.TenantService != nil {
if profile.Tenant == nil && profile.TenantID != nil && *profile.TenantID != "" {
if tenant, err := h.TenantService.GetTenant(c.Context(), *profile.TenantID); err == nil {
profile.Tenant = tenant
}
}
if profile.Tenant == nil && profile.CompanyCode != "" {
if tenant, err := h.TenantService.GetTenantBySlug(c.Context(), profile.CompanyCode); err == nil && tenant != nil {
profile.Tenant = tenant
if profile.TenantID == nil || *profile.TenantID == "" {
profile.TenantID = &tenant.ID
}
}
}
}
// [New] Fetch manageable and joined tenants
if h.TenantService != nil {
if profile.Role == domain.RoleTenantAdmin {
manageable, err := h.TenantService.ListManageableTenants(c.Context(), profile.ID)
if err == nil {
profile.ManageableTenants = manageable
}
}
joined, err := h.TenantService.ListJoinedTenants(c.Context(), profile.ID)
if err == nil {
profile.JoinedTenants = joined
}
}
profile = h.hydrateResolvedProfile(c.Context(), profile)
// 4. Save to Redis Cache (Short TTL)
// IMPORTANT: In dev mode, if role was overridden, we should NOT cache it under the token key

View File

@@ -2,6 +2,8 @@ package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/response"
"baron-sso-backend/internal/service"
"encoding/json"
"errors"
"sort"
@@ -133,8 +135,32 @@ func clientTenantAccessAllowed(profile *domain.UserProfileResponse, client domai
return false
}
func tenantNotAllowedError(c *fiber.Ctx) error {
return errorJSONCode(c, fiber.StatusForbidden, "tenant_not_allowed", "허용되지 않은 테넌트입니다.")
type tenantAccessDeniedDetails struct {
Account tenantAccessDeniedAccount `json:"account"`
CurrentTenant tenantAccessDeniedTenant `json:"current_tenant"`
AffiliatedTenants []tenantAccessDeniedTenant `json:"affiliated_tenants,omitempty"`
AllowedTenants []tenantAccessDeniedTenant `json:"allowed_tenants,omitempty"`
}
type tenantAccessDeniedAccount struct {
Email string `json:"email,omitempty"`
}
type tenantAccessDeniedTenant struct {
ID string `json:"id,omitempty"`
Slug string `json:"slug,omitempty"`
Name string `json:"name,omitempty"`
Identifier string `json:"identifier,omitempty"`
}
func tenantNotAllowedError(c *fiber.Ctx, details tenantAccessDeniedDetails) error {
return response.ErrorWithDetails(
c,
fiber.StatusForbidden,
"tenant_not_allowed",
"허용되지 않은 테넌트입니다.",
details,
)
}
func isClientTenantAccessAllowed(profile *domain.UserProfileResponse, client domain.HydraClient) bool {
@@ -144,17 +170,162 @@ func isClientTenantAccessAllowed(profile *domain.UserProfileResponse, client dom
return clientTenantAccessAllowed(profile, client)
}
func enforceClientTenantAccess(c *fiber.Ctx, client domain.HydraClient, profile *domain.UserProfileResponse, resolveErr error) error {
func enforceClientTenantAccess(c *fiber.Ctx, tenantSvc service.TenantService, client domain.HydraClient, profile *domain.UserProfileResponse, resolveErr error) bool {
if !clientTenantAccessRestricted(client.Metadata) {
return nil
return false
}
details := buildTenantAccessDeniedDetails(c, tenantSvc, client, profile)
if resolveErr != nil || profile == nil {
return tenantNotAllowedError(c)
_ = tenantNotAllowedError(c, details)
return true
}
if !clientTenantAccessAllowed(profile, client) {
return tenantNotAllowedError(c)
_ = tenantNotAllowedError(c, details)
return true
}
return nil
return false
}
func buildTenantAccessDeniedDetails(c *fiber.Ctx, tenantSvc service.TenantService, client domain.HydraClient, profile *domain.UserProfileResponse) tenantAccessDeniedDetails {
details := tenantAccessDeniedDetails{
Account: tenantAccessDeniedAccount{Email: strings.TrimSpace(profileEmail(profile))},
CurrentTenant: resolveCurrentTenantDetails(c, tenantSvc, profile),
AffiliatedTenants: resolveAffiliatedTenantDetails(c, tenantSvc, profile),
}
for _, identifier := range clientAllowedTenants(client.Metadata) {
details.AllowedTenants = append(details.AllowedTenants, resolveAllowedTenantDetails(c, tenantSvc, identifier))
}
return details
}
func resolveAffiliatedTenantDetails(c *fiber.Ctx, tenantSvc service.TenantService, profile *domain.UserProfileResponse) []tenantAccessDeniedTenant {
if profile == nil {
return nil
}
seen := make(map[string]struct{})
out := make([]tenantAccessDeniedTenant, 0, len(profile.JoinedTenants)+1)
appendTenant := func(tenant tenantAccessDeniedTenant) {
key := strings.ToLower(firstNonEmptyString(tenant.ID, tenant.Slug, tenant.Identifier, tenant.Name))
if key == "" {
return
}
if _, ok := seen[key]; ok {
return
}
seen[key] = struct{}{}
out = append(out, tenant)
}
appendTenant(resolveCurrentTenantDetails(c, tenantSvc, profile))
for _, joined := range profile.JoinedTenants {
appendTenant(tenantAccessDeniedTenant{
ID: strings.TrimSpace(joined.ID),
Slug: strings.TrimSpace(joined.Slug),
Name: strings.TrimSpace(joined.Name),
Identifier: firstNonEmptyString(strings.TrimSpace(joined.Slug), strings.TrimSpace(joined.ID)),
})
}
return out
}
func resolveCurrentTenantDetails(c *fiber.Ctx, tenantSvc service.TenantService, profile *domain.UserProfileResponse) tenantAccessDeniedTenant {
if profile == nil {
return tenantAccessDeniedTenant{}
}
if profile.Tenant != nil {
return tenantAccessDeniedTenant{
ID: strings.TrimSpace(profile.Tenant.ID),
Slug: strings.TrimSpace(profile.Tenant.Slug),
Name: strings.TrimSpace(profile.Tenant.Name),
Identifier: firstNonEmptyString(strings.TrimSpace(profile.Tenant.Slug), strings.TrimSpace(profile.Tenant.ID)),
}
}
if tenantSvc != nil {
if profile.TenantID != nil && strings.TrimSpace(*profile.TenantID) != "" {
if tenant, err := tenantSvc.GetTenant(c.Context(), strings.TrimSpace(*profile.TenantID)); err == nil && tenant != nil {
return tenantAccessDeniedTenant{
ID: strings.TrimSpace(tenant.ID),
Slug: strings.TrimSpace(tenant.Slug),
Name: strings.TrimSpace(tenant.Name),
Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID)),
}
}
}
if strings.TrimSpace(profile.CompanyCode) != "" {
if tenant, err := tenantSvc.GetTenantBySlug(c.Context(), strings.TrimSpace(profile.CompanyCode)); err == nil && tenant != nil {
return tenantAccessDeniedTenant{
ID: strings.TrimSpace(tenant.ID),
Slug: strings.TrimSpace(tenant.Slug),
Name: strings.TrimSpace(tenant.Name),
Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID)),
}
}
}
}
return tenantAccessDeniedTenant{
ID: strings.TrimSpace(pointerValue(profile.TenantID)),
Slug: strings.TrimSpace(profile.CompanyCode),
Identifier: firstNonEmptyString(strings.TrimSpace(profile.CompanyCode), strings.TrimSpace(pointerValue(profile.TenantID))),
}
}
func resolveAllowedTenantDetails(c *fiber.Ctx, tenantSvc service.TenantService, identifier string) tenantAccessDeniedTenant {
identifier = strings.TrimSpace(identifier)
if identifier == "" {
return tenantAccessDeniedTenant{}
}
if tenantSvc != nil {
if tenant, err := tenantSvc.GetTenant(c.Context(), identifier); err == nil && tenant != nil {
return tenantAccessDeniedTenant{
ID: strings.TrimSpace(tenant.ID),
Slug: strings.TrimSpace(tenant.Slug),
Name: strings.TrimSpace(tenant.Name),
Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID), identifier),
}
}
if tenant, err := tenantSvc.GetTenantBySlug(c.Context(), identifier); err == nil && tenant != nil {
return tenantAccessDeniedTenant{
ID: strings.TrimSpace(tenant.ID),
Slug: strings.TrimSpace(tenant.Slug),
Name: strings.TrimSpace(tenant.Name),
Identifier: firstNonEmptyString(strings.TrimSpace(tenant.Slug), strings.TrimSpace(tenant.ID), identifier),
}
}
}
return tenantAccessDeniedTenant{Identifier: identifier}
}
func profileEmail(profile *domain.UserProfileResponse) string {
if profile == nil {
return ""
}
return profile.Email
}
func pointerValue(value *string) string {
if value == nil {
return ""
}
return *value
}
func firstNonEmptyString(values ...string) string {
for _, value := range values {
if strings.TrimSpace(value) != "" {
return strings.TrimSpace(value)
}
}
return ""
}
type clientStructuredScope struct {

View File

@@ -12,6 +12,7 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func TestCreateClient_NormalizesTenantAccessMetadata(t *testing.T) {
@@ -190,6 +191,18 @@ func TestGetConsentRequest_DeniesTenantAccess(t *testing.T) {
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
var body map[string]any
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
assert.Equal(t, "tenant_not_allowed", body["code"])
details, ok := body["details"].(map[string]any)
assert.True(t, ok)
account, ok := details["account"].(map[string]any)
assert.True(t, ok)
assert.NotEmpty(t, account["email"])
currentTenant, ok := details["current_tenant"].(map[string]any)
assert.True(t, ok)
assert.NotEmpty(t, currentTenant["identifier"])
}
func TestGetConsentRequest_DeniesRestrictedClientWhenProfileResolutionFails(t *testing.T) {
@@ -232,6 +245,37 @@ func TestGetConsentRequest_DeniesRestrictedClientWhenProfileResolutionFails(t *t
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: func() service.KratosAdminService {
mockKratos := new(MockKratosAdminService)
mockKratos.On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
ID: "user-123",
Traits: map[string]interface{}{
"email": "user@test.com",
"tenant_id": "tenant-a",
"companyCode": "tenant-a",
},
}, nil).Once()
return mockKratos
}(),
TenantService: func() service.TenantService {
tenantSvc := new(MockTenantService)
tenantSvc.On("GetTenant", mock.Anything, "tenant-a").Return(&domain.Tenant{
ID: "tenant-a",
Slug: "tenant-a",
Name: "Tenant A",
}, nil).Twice()
tenantSvc.On("ListJoinedTenants", mock.Anything, "user-123").Return([]domain.Tenant{
{ID: "tenant-a", Slug: "tenant-a", Name: "Tenant A"},
{ID: "tenant-c", Slug: "tenant-c", Name: "Tenant C"},
}, nil).Once()
tenantSvc.On("GetTenant", mock.Anything, "tenant-b").Return(nil, assert.AnError).Once()
tenantSvc.On("GetTenantBySlug", mock.Anything, "tenant-b").Return(&domain.Tenant{
ID: "tenant-b-id",
Slug: "tenant-b",
Name: "Tenant B",
}, nil).Once()
return tenantSvc
}(),
ConsentRepo: &mockConsentRepo{
consents: []domain.ClientConsent{
{
@@ -253,6 +297,22 @@ func TestGetConsentRequest_DeniesRestrictedClientWhenProfileResolutionFails(t *t
assert.NoError(t, err)
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
assert.False(t, acceptCalled)
var body map[string]any
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
assert.Equal(t, "tenant_not_allowed", body["code"])
details, ok := body["details"].(map[string]any)
assert.True(t, ok)
account, ok := details["account"].(map[string]any)
assert.True(t, ok)
assert.Equal(t, "user@test.com", account["email"])
currentTenant, ok := details["current_tenant"].(map[string]any)
assert.True(t, ok)
assert.Equal(t, "Tenant A", currentTenant["name"])
affiliatedTenants, ok := details["affiliated_tenants"].([]any)
assert.True(t, ok)
assert.Len(t, affiliatedTenants, 2)
}
func TestAcceptOidcLoginRequest_DeniesTenantAccess(t *testing.T) {
@@ -260,9 +320,15 @@ func TestAcceptOidcLoginRequest_DeniesTenantAccess(t *testing.T) {
app.Get("/deny", func(c *fiber.Ctx) error {
tenantID := "tenant-a"
profile := &domain.UserProfileResponse{
ID: "user-123",
Role: domain.RoleUser,
TenantID: &tenantID,
ID: "user-123",
Role: domain.RoleUser,
Email: "user@test.com",
TenantID: &tenantID,
CompanyCode: "tenant-a",
JoinedTenants: []domain.Tenant{
{ID: "tenant-a", Slug: "tenant-a", Name: "Tenant A"},
{ID: "tenant-c", Slug: "tenant-c", Name: "Tenant C"},
},
}
client := domain.HydraClient{
ClientID: "client-tenant",
@@ -271,11 +337,49 @@ func TestAcceptOidcLoginRequest_DeniesTenantAccess(t *testing.T) {
"allowed_tenants": []string{"tenant-b"},
},
}
return enforceClientTenantAccess(c, client, profile, nil)
tenantSvc := new(MockTenantService)
tenantSvc.On("GetTenant", mock.Anything, "tenant-a").Return(&domain.Tenant{
ID: "tenant-a",
Slug: "tenant-a",
Name: "Tenant A",
}, nil).Twice()
tenantSvc.On("GetTenant", mock.Anything, "tenant-b").Return(nil, assert.AnError).Once()
tenantSvc.On("GetTenantBySlug", mock.Anything, "tenant-b").Return(&domain.Tenant{
ID: "tenant-b-id",
Slug: "tenant-b",
Name: "Tenant B",
}, nil).Once()
enforceClientTenantAccess(c, tenantSvc, client, profile, nil)
return nil
})
req := httptest.NewRequest(http.MethodGet, "/deny", nil)
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
var body map[string]any
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
assert.Equal(t, "tenant_not_allowed", body["code"])
details, ok := body["details"].(map[string]any)
assert.True(t, ok)
account, ok := details["account"].(map[string]any)
assert.True(t, ok)
assert.Equal(t, "user@test.com", account["email"])
currentTenant, ok := details["current_tenant"].(map[string]any)
assert.True(t, ok)
assert.Equal(t, "Tenant A", currentTenant["name"])
affiliatedTenants, ok := details["affiliated_tenants"].([]any)
assert.True(t, ok)
assert.Len(t, affiliatedTenants, 2)
allowedTenants, ok := details["allowed_tenants"].([]any)
assert.True(t, ok)
assert.Len(t, allowedTenants, 1)
allowedTenant, ok := allowedTenants[0].(map[string]any)
assert.True(t, ok)
assert.Equal(t, "Tenant B", allowedTenant["name"])
}