1
0
forked from baron/baron-sso

조직도 M2M조회 추가, 자동로그인 보완

This commit is contained in:
2026-05-13 13:44:30 +09:00
parent 72288f1d39
commit 8c2b2f71ef
29 changed files with 2985 additions and 81 deletions

View File

@@ -704,7 +704,7 @@ func (h *AuthHandler) Signup(c *fiber.Ctx) error {
return errorJSON(c, fiber.StatusInternalServerError, "Identity provider unavailable")
}
// [New Policy] Enforce Explicit Tenant Assignment (No Auto-Provisioning)
// 소속이 비어 있는 일반 가입자는 PERSONAL tenant를 자동 생성해 대표소속을 보장합니다.
companyCode := ""
var tenantID *string
@@ -765,6 +765,14 @@ func (h *AuthHandler) Signup(c *fiber.Ctx) error {
if tenantID == nil && req.AffiliationType == "AFFILIATE" {
return errorJSON(c, fiber.StatusBadRequest, "We couldn't verify your organization affiliation. Please check your choice.")
}
if tenantID == nil && req.AffiliationType == "GENERAL" {
tenant, err := createPersonalTenantForUser(c.Context(), h.TenantService, req.Email)
if err != nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "failed to create personal tenant")
}
companyCode = tenant.Slug
tenantID = &tenant.ID
}
// Normalize Phone (E.164 형태로 보관)
normalizedPhone := strings.ReplaceAll(req.Phone, "-", "")
@@ -785,6 +793,9 @@ func (h *AuthHandler) Signup(c *fiber.Ctx) error {
"grade": "",
"role": domain.RoleUser,
}
if tenantID != nil {
attributes["tenant_id"] = *tenantID
}
// Sync all custom login IDs based on tenant schemas
loginIDRecords := syncCustomLoginIDs(c.Context(), h.TenantService, attributes, req.Metadata, "")
@@ -1100,6 +1111,10 @@ func buildOidcClaimsFromTraits(traits map[string]any, scopes []string, tenantID
if traits == nil {
return claims
}
if tenantID == "" {
tenantID = representativeTenantIDFromTraits(traits)
}
includeTenantDetails := tenantClaimScopeRequested(scopes)
scopeSet := map[string]struct{}{}
for _, scope := range scopes {
@@ -1200,53 +1215,38 @@ func buildOidcClaimsFromTraits(traits map[string]any, scopes []string, tenantID
// [New] Dynamic Claim Injection for Multi-tenancy
if tenantID != "" {
claims["tenant_id"] = tenantID
// Extract namespaced metadata if available
// The key in traits is expected to be the tenantID
if namespaced, ok := traits[tenantID].(map[string]any); ok {
for k, v := range namespaced {
claims[k] = v
}
} else if namespaced, ok := traits[tenantID].(map[string]interface{}); ok {
for k, v := range namespaced {
claims[k] = v
if includeTenantDetails {
// tenant 스코프가 있을 때만 대표소속 namespace metadata를 top-level claim으로 펼칩니다.
if namespaced, ok := traits[tenantID].(map[string]any); ok {
for k, v := range namespaced {
claims[k] = v
}
}
}
}
// [Update] Pass ALL tenants the user belongs to
allTenants := map[string]any{}
var joinedTenants []string
joinedTenants := joinedTenantIDsFromTraits(traits, tenantID)
// Heuristic: if a trait value is a map, it's treated as namespaced metadata for a tenant
for k, v := range traits {
if k == "metadata" {
continue
}
if m, ok := v.(map[string]any); ok {
allTenants[k] = m
joinedTenants = append(joinedTenants, k)
} else if m, ok := v.(map[string]interface{}); ok {
allTenants[k] = m
joinedTenants = append(joinedTenants, k)
}
}
// [Fix] Include primary tenant_id in joined_tenants if it's not already there
if primaryTenantID := getString("tenant_id"); primaryTenantID != "" {
found := false
for _, id := range joinedTenants {
if id == primaryTenantID {
found = true
break
}
}
if !found {
joinedTenants = append(joinedTenants, primaryTenantID)
}
}
if len(allTenants) > 0 || len(joinedTenants) > 0 {
claims["tenants"] = allTenants
if len(joinedTenants) > 0 {
claims["joined_tenants"] = joinedTenants
}
if includeTenantDetails && len(allTenants) > 0 {
claims["tenants"] = allTenants
}
return claims
}
@@ -1268,6 +1268,311 @@ func composeOIDCSessionClaims(client domain.HydraClient, traits map[string]any,
return withOidcSessionMetadata(claims, sessionID)
}
func (h *AuthHandler) withHanmacFamilyTenantClaims(ctx context.Context, claims map[string]any, traits map[string]any, scopes []string) map[string]any {
if claims == nil {
claims = map[string]any{}
}
if h == nil || h.TenantService == nil {
return claims
}
appointments := tenantClaimAppointmentsFromTraits(traits)
includeTenantDetails := tenantClaimScopeRequested(scopes)
tenants, hadTenantClaims := claims["tenants"].(map[string]any)
if !hadTenantClaims {
tenants = map[string]any{}
}
createdTenantClaims := map[string]bool{}
if tenantID := tenantClaimString(claims, "tenant_id"); tenantID != "" {
if _, exists := tenants[tenantID]; !exists {
tenants[tenantID] = map[string]any{}
createdTenantClaims[tenantID] = true
}
}
for _, tenantKey := range tenantClaimAppointmentPrimaryKeys(appointments) {
if _, exists := tenants[tenantKey]; !exists {
tenants[tenantKey] = map[string]any{}
createdTenantClaims[tenantKey] = true
}
}
if len(tenants) == 0 {
return claims
}
leadTenantIDs := make([]string, 0)
joinedTenantIDs := make([]string, 0)
for tenantKey, rawTenantClaim := range tenants {
tenantClaim, ok := rawTenantClaim.(map[string]any)
if !ok {
continue
}
tenant, ancestors, inHanmacFamily := h.resolveHanmacFamilyTenantClaimAncestry(ctx, tenantKey)
if !inHanmacFamily || tenant == nil {
if createdTenantClaims[tenantKey] {
delete(tenants, tenantKey)
}
continue
}
joinedTenantIDs = append(joinedTenantIDs, tenant.ID)
if !includeTenantDetails {
if createdTenantClaims[tenantKey] {
delete(tenants, tenantKey)
}
continue
}
tenantClaim["id"] = tenant.ID
tenantClaim["slug"] = tenant.Slug
tenantClaim["name"] = tenant.Name
tenantClaim["type"] = tenant.Type
tenantClaim["ancestors"] = ancestors
if len(ancestors) > 0 {
tenantClaim["parentTenantId"] = ancestors[0]["id"]
} else {
tenantClaim["parentTenantId"] = nil
}
delete(tenantClaim, "parentTenant")
if appointment := lookupTenantClaimAppointment(appointments, tenantKey, tenant); appointment != nil {
mergeTenantAppointmentClaim(tenantClaim, appointment)
}
if lead, ok := metadataBoolFromMap(tenantClaim, "lead", "isLead", "isOwner", "isManager"); ok {
tenantClaim["lead"] = lead
if lead {
leadTenantIDs = append(leadTenantIDs, tenant.ID)
}
}
if representative, ok := metadataBoolFromMap(tenantClaim, "representative", "isPrimary", "primary"); ok {
tenantClaim["representative"] = representative
tenantClaim["isPrimary"] = representative
}
tenants[tenantKey] = tenantClaim
}
if len(leadTenantIDs) > 0 {
claims["lead_tenants"] = uniqueSortedStrings(leadTenantIDs)
}
if len(joinedTenantIDs) > 0 {
claims["joined_tenants"] = mergeClaimStringList(claims["joined_tenants"], joinedTenantIDs)
}
if !includeTenantDetails {
if !hadTenantClaims {
delete(claims, "tenants")
}
delete(claims, "lead_tenants")
return claims
}
if len(tenants) > 0 {
claims["tenants"] = tenants
} else if !hadTenantClaims {
delete(claims, "tenants")
}
return claims
}
func tenantClaimScopeRequested(scopes []string) bool {
for _, scope := range scopes {
if strings.EqualFold(strings.TrimSpace(scope), "tenant") {
return true
}
}
return false
}
func mergeClaimStringList(raw any, values []string) []string {
merged := make([]string, 0, len(values))
switch current := raw.(type) {
case []string:
merged = append(merged, current...)
case []any:
for _, item := range current {
if s, ok := item.(string); ok {
merged = append(merged, s)
}
}
}
merged = append(merged, values...)
return uniqueSortedStrings(merged)
}
func tenantClaimAppointmentPrimaryKeys(appointments map[string]map[string]any) []string {
if len(appointments) == 0 {
return nil
}
seen := map[string]bool{}
keys := make([]string, 0, len(appointments))
for _, appointment := range appointments {
for _, key := range []string{"tenantId", "tenant_id", "tenantSlug", "tenant_slug"} {
value := tenantClaimString(appointment, key)
if value == "" || seen[value] {
continue
}
seen[value] = true
keys = append(keys, value)
break
}
}
sort.Strings(keys)
return keys
}
func tenantClaimAppointmentsFromTraits(traits map[string]any) map[string]map[string]any {
raw := rawAdditionalAppointments(traits)
if raw == nil {
return nil
}
items, ok := raw.([]any)
if !ok {
return nil
}
appointments := make(map[string]map[string]any)
for _, item := range items {
appointment, ok := item.(map[string]any)
if !ok {
continue
}
for _, key := range []string{"tenantId", "tenant_id", "tenantSlug", "tenant_slug"} {
if id := tenantClaimString(appointment, key); id != "" {
appointments[id] = appointment
}
}
}
return appointments
}
func rawAdditionalAppointments(traits map[string]any) any {
if traits == nil {
return nil
}
if raw, ok := traits["additionalAppointments"]; ok {
return raw
}
if metadata, ok := traits["metadata"].(map[string]any); ok {
return metadata["additionalAppointments"]
}
return nil
}
func lookupTenantClaimAppointment(appointments map[string]map[string]any, tenantKey string, tenant *domain.Tenant) map[string]any {
if len(appointments) == 0 {
return nil
}
for _, key := range []string{tenantKey, tenant.ID, tenant.Slug} {
if appointment, ok := appointments[key]; ok {
return appointment
}
}
return nil
}
func mergeTenantAppointmentClaim(tenantClaim map[string]any, appointment map[string]any) {
for _, key := range []string{"grade", "jobTitle", "job_title", "position"} {
if value := tenantClaimString(appointment, key); value != "" {
switch key {
case "job_title":
tenantClaim["jobTitle"] = value
default:
tenantClaim[key] = value
}
}
}
if lead, ok := metadataBoolFromMap(appointment, "lead", "isLead", "isOwner", "isManager"); ok {
tenantClaim["lead"] = lead
}
if representative, ok := metadataBoolFromMap(appointment, "representative", "isPrimary", "primary"); ok {
tenantClaim["representative"] = representative
tenantClaim["isPrimary"] = representative
}
}
func tenantClaimString(values map[string]any, key string) string {
raw, ok := values[key]
if !ok || raw == nil {
return ""
}
switch value := raw.(type) {
case string:
return strings.TrimSpace(value)
default:
return strings.TrimSpace(fmt.Sprint(value))
}
}
func (h *AuthHandler) resolveHanmacFamilyTenantClaimAncestry(ctx context.Context, identifier string) (*domain.Tenant, []map[string]any, bool) {
tenant, err := h.resolveTenantClaimTenant(ctx, identifier)
if err != nil || tenant == nil {
return nil, nil, false
}
if strings.EqualFold(tenant.Slug, hanmacFamilyTenantSlug) {
return tenant, []map[string]any{}, true
}
ancestors := make([]*domain.Tenant, 0)
visited := map[string]bool{tenant.ID: true}
current := tenant
for current.ParentID != nil && strings.TrimSpace(*current.ParentID) != "" {
parentID := strings.TrimSpace(*current.ParentID)
if visited[parentID] {
return tenant, tenantClaimAncestorSummaries(ancestors), false
}
visited[parentID] = true
parent, err := h.TenantService.GetTenant(ctx, parentID)
if err != nil || parent == nil {
return tenant, tenantClaimAncestorSummaries(ancestors), false
}
ancestors = append(ancestors, parent)
if strings.EqualFold(parent.Slug, hanmacFamilyTenantSlug) {
return tenant, tenantClaimAncestorSummaries(ancestors), true
}
current = parent
}
return tenant, tenantClaimAncestorSummaries(ancestors), false
}
func (h *AuthHandler) resolveTenantClaimTenant(ctx context.Context, identifier string) (*domain.Tenant, error) {
identifier = strings.TrimSpace(identifier)
if identifier == "" {
return nil, errors.New("tenant identifier is required")
}
if tenant, err := h.TenantService.GetTenant(ctx, identifier); err == nil && tenant != nil {
return tenant, nil
}
return h.TenantService.GetTenantBySlug(ctx, identifier)
}
func tenantClaimTenantSummary(tenant *domain.Tenant) map[string]any {
return map[string]any{
"id": tenant.ID,
"slug": tenant.Slug,
"name": tenant.Name,
"type": tenant.Type,
}
}
func tenantClaimAncestorSummaries(ancestors []*domain.Tenant) []map[string]any {
if len(ancestors) == 0 {
return []map[string]any{}
}
items := make([]map[string]any, 0, len(ancestors))
for i, ancestor := range ancestors {
item := tenantClaimTenantSummary(ancestor)
if i+1 < len(ancestors) {
item["parentTenantId"] = ancestors[i+1].ID
} else {
item["parentTenantId"] = nil
}
items = append(items, item)
}
return items
}
func applyConfiguredIDTokenClaims(baseClaims map[string]any, metadata map[string]interface{}) map[string]any {
if baseClaims == nil {
baseClaims = map[string]any{}
@@ -5535,19 +5840,14 @@ func (h *AuthHandler) GetConsentRequest(c *fiber.Ctx) error {
identity, err := h.KratosAdmin.GetIdentity(c.Context(), consentRequest.Subject)
if err == nil && identity != nil {
currentSessionID := h.resolveCurrentSessionID(c)
var tenantID string
if consentRequest.Client.Metadata != nil {
if tid, ok := consentRequest.Client.Metadata["tenant_id"].(string); ok {
tenantID = tid
}
}
sessionClaims := composeOIDCSessionClaims(
consentRequest.Client,
identity.Traits,
consentRequest.RequestedScope,
tenantID,
representativeTenantIDFromTraits(identity.Traits),
currentSessionID,
)
sessionClaims = h.withHanmacFamilyTenantClaims(c.Context(), sessionClaims, identity.Traits, consentRequest.RequestedScope)
sessionClaims = h.withRPProfileClaims(c.Context(), sessionClaims, consentRequest.Client, consentRequest.Subject)
acceptResp, err := h.Hydra.AcceptConsentRequest(c.Context(), challenge, consentRequest, sessionClaims)
if err == nil {
@@ -5571,10 +5871,10 @@ func (h *AuthHandler) GetConsentRequest(c *fiber.Ctx) error {
// 신원 정보를 가져오지 못하면 자동 승인을 진행할 수 없으므로 일반 흐름(UI 노출)으로 진행
} else {
currentSessionID := h.resolveCurrentSessionID(c)
var tenantID string
var clientTenantID string
if consentRequest.Client.Metadata != nil {
if tid, ok := consentRequest.Client.Metadata["tenant_id"].(string); ok {
tenantID = tid
clientTenantID = tid
}
}
@@ -5582,9 +5882,10 @@ func (h *AuthHandler) GetConsentRequest(c *fiber.Ctx) error {
consentRequest.Client,
identity.Traits,
consentRequest.RequestedScope,
tenantID,
representativeTenantIDFromTraits(identity.Traits),
currentSessionID,
)
sessionClaims = h.withHanmacFamilyTenantClaims(c.Context(), sessionClaims, identity.Traits, consentRequest.RequestedScope)
sessionClaims = h.withRPProfileClaims(c.Context(), sessionClaims, consentRequest.Client, consentRequest.Subject)
// [Debug] 실제 생성된 클레임 출력 (요청사항 확인용 - 자동 승인 시)
@@ -5627,7 +5928,7 @@ func (h *AuthHandler) GetConsentRequest(c *fiber.Ctx) error {
EventID: GenerateSecureToken(16),
Timestamp: time.Now(),
UserID: consentRequest.Subject,
TenantID: tenantID, // Uses the tenantID extracted earlier
TenantID: clientTenantID,
SessionID: currentSessionID,
EventType: "consent.granted",
Status: "success",
@@ -5761,11 +6062,10 @@ func (h *AuthHandler) AcceptConsentRequest(c *fiber.Ctx) error {
c.Locals("login_id", loginID)
}
currentSessionID := h.resolveCurrentSessionID(c)
var tenantID string
var clientTenantID string
if consentRequest.Client.Metadata != nil {
if tid, ok := consentRequest.Client.Metadata["tenant_id"].(string); ok {
tenantID = tid
clientTenantID = tid
}
}
@@ -5773,9 +6073,10 @@ func (h *AuthHandler) AcceptConsentRequest(c *fiber.Ctx) error {
consentRequest.Client,
identity.Traits,
consentRequest.RequestedScope,
tenantID,
representativeTenantIDFromTraits(identity.Traits),
currentSessionID,
)
sessionClaims = h.withHanmacFamilyTenantClaims(c.Context(), sessionClaims, identity.Traits, consentRequest.RequestedScope)
sessionClaims = h.withRPProfileClaims(c.Context(), sessionClaims, consentRequest.Client, consentRequest.Subject)
// [Debug] 실제 생성된 클레임 출력 (요청사항 확인용)
@@ -5821,7 +6122,7 @@ func (h *AuthHandler) AcceptConsentRequest(c *fiber.Ctx) error {
EventID: GenerateSecureToken(16),
Timestamp: time.Now(),
UserID: consentRequest.Subject,
TenantID: tenantID, // [New] Add TenantID to AuditLog
TenantID: clientTenantID,
SessionID: currentSessionID,
EventType: "consent.granted",
Status: "success",

View File

@@ -9,6 +9,7 @@ import (
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
@@ -175,7 +176,11 @@ type AsyncMockTenantService struct {
}
func (m *AsyncMockTenantService) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string, creatorID string) (*domain.Tenant, error) {
return nil, nil
args := m.Called(ctx, name, slug, tenantType, description, domains, parentID, creatorID)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.Tenant), args.Error(1)
}
func (m *AsyncMockTenantService) RequestRegistration(ctx context.Context, name, slug, description string, domainName string, adminEmail string) (*domain.Tenant, error) {
@@ -284,9 +289,20 @@ func TestSignup_AsyncDB_Isolation(t *testing.T) {
mockRedis.On("Delete", phoneKey).Return(nil)
// Tenant Mocks
validTenant := &domain.Tenant{ID: "t1", Slug: "example", Status: domain.TenantStatusActive}
mockTenant.On("GetTenantByDomain", mock.Anything, "example.com").Return(validTenant, nil)
mockTenant.On("GetTenant", mock.Anything, "t1").Return(validTenant, nil)
personalTenant := &domain.Tenant{ID: "personal-t1", Slug: "personal-test", Type: domain.TenantTypePersonal, Status: domain.TenantStatusActive}
mockTenant.On("GetTenantByDomain", mock.Anything, "example.com").Return(nil, nil)
mockTenant.On(
"RegisterTenant",
mock.Anything,
"Personal - test@example.com",
mock.MatchedBy(func(slug string) bool { return strings.HasPrefix(slug, "personal-") }),
domain.TenantTypePersonal,
"Automatically provisioned personal tenant",
[]string(nil),
(*string)(nil),
"",
).Return(personalTenant, nil)
mockTenant.On("GetTenant", mock.Anything, "personal-t1").Return(personalTenant, nil)
// Kratos Mocks (Success)
mockIdp.On("CreateUser", mock.Anything, "Password123!").Return("new-user-uuid", nil)

View File

@@ -4,6 +4,7 @@ import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes"
"context"
"encoding/json"
"io"
"net/http"
@@ -35,10 +36,11 @@ func TestBuildOidcClaimsFromTraits_DynamicClaims(t *testing.T) {
claims := buildOidcClaimsFromTraits(traits, scopes, "")
assert.Equal(t, "user@baron.com", claims["email"])
assert.Equal(t, "홍길동", claims["name"])
assert.Equal(t, "primary-tenant-999", claims["tenant_id"])
assert.Nil(t, claims["department"])
assert.Nil(t, claims["grade"])
assert.NotNil(t, claims["tenants"])
assert.Nil(t, claims["tenants"])
assert.Contains(t, claims["joined_tenants"], "tenant-1")
assert.Contains(t, claims["joined_tenants"], "tenant-2")
assert.Contains(t, claims["joined_tenants"], "primary-tenant-999") // Should contain primary
@@ -48,11 +50,11 @@ func TestBuildOidcClaimsFromTraits_DynamicClaims(t *testing.T) {
claims := buildOidcClaimsFromTraits(traits, scopes, "tenant-1")
assert.Equal(t, "user@baron.com", claims["email"])
assert.Equal(t, "홍길동", claims["name"])
assert.Equal(t, "tenant-1", claims["tenant_id"]) // Dynamic tenant injection overwrites top-level for this context
assert.Equal(t, "개발팀", claims["department"])
assert.Equal(t, "선임", claims["grade"])
assert.Equal(t, "tenant-1", claims["tenant_id"])
assert.Nil(t, claims["department"])
assert.Nil(t, claims["grade"])
assert.NotNil(t, claims["tenants"])
assert.Nil(t, claims["tenants"])
assert.Contains(t, claims["joined_tenants"], "tenant-1")
assert.Contains(t, claims["joined_tenants"], "tenant-2")
assert.Contains(t, claims["joined_tenants"], "primary-tenant-999")
@@ -63,10 +65,10 @@ func TestBuildOidcClaimsFromTraits_DynamicClaims(t *testing.T) {
assert.Equal(t, "user@baron.com", claims["email"])
assert.Equal(t, "홍길동", claims["name"])
assert.Equal(t, "tenant-2", claims["tenant_id"])
assert.Equal(t, "재무팀", claims["department"])
assert.Equal(t, "팀장", claims["grade"])
assert.Nil(t, claims["department"])
assert.Nil(t, claims["grade"])
assert.NotNil(t, claims["tenants"])
assert.Nil(t, claims["tenants"])
assert.Contains(t, claims["joined_tenants"], "primary-tenant-999")
})
@@ -78,10 +80,53 @@ func TestBuildOidcClaimsFromTraits_DynamicClaims(t *testing.T) {
assert.Nil(t, claims["department"])
assert.Nil(t, claims["grade"])
assert.NotNil(t, claims["tenants"])
assert.Nil(t, claims["tenants"])
assert.Contains(t, claims["joined_tenants"], "tenant-1")
assert.Contains(t, claims["joined_tenants"], "primary-tenant-999")
})
t.Run("Tenant scope includes detailed tenant metadata", func(t *testing.T) {
claims := buildOidcClaimsFromTraits(traits, []string{"openid", "profile", "tenant"}, "tenant-1")
assert.Equal(t, "tenant-1", claims["tenant_id"])
assert.Equal(t, "개발팀", claims["department"])
assert.Equal(t, "선임", claims["grade"])
assert.NotNil(t, claims["tenants"])
assert.Contains(t, claims["joined_tenants"], "tenant-1")
assert.Contains(t, claims["joined_tenants"], "tenant-2")
assert.Contains(t, claims["joined_tenants"], "primary-tenant-999")
})
}
func TestRepresentativeTenantIDFromTraits(t *testing.T) {
t.Run("explicit tenant_id wins", func(t *testing.T) {
traits := map[string]any{
"tenant_id": "01970f0a-5c28-74d8-a73a-f6e9e9a7b210",
"additionalAppointments": []any{
map[string]any{"tenantId": "01970f0b-3448-7bb8-bdc7-16b6a1d2e661", "isPrimary": true},
},
}
assert.Equal(t, "01970f0a-5c28-74d8-a73a-f6e9e9a7b210", representativeTenantIDFromTraits(traits))
})
t.Run("primary appointment wins when tenant_id is absent", func(t *testing.T) {
traits := map[string]any{
"additionalAppointments": []any{
map[string]any{"tenantId": "01970f0b-3448-7bb8-bdc7-16b6a1d2e661"},
map[string]any{"tenantId": "01970f0c-8c44-7069-9f20-7d28c0b8e630", "representative": true},
},
}
assert.Equal(t, "01970f0c-8c44-7069-9f20-7d28c0b8e630", representativeTenantIDFromTraits(traits))
})
t.Run("first appointment is fallback", func(t *testing.T) {
traits := map[string]any{
"additionalAppointments": []any{
map[string]any{"tenantId": "01970f0b-3448-7bb8-bdc7-16b6a1d2e661"},
map[string]any{"tenantId": "01970f0c-8c44-7069-9f20-7d28c0b8e630"},
},
}
assert.Equal(t, "01970f0b-3448-7bb8-bdc7-16b6a1d2e661", representativeTenantIDFromTraits(traits))
})
}
func TestAcceptConsentRequest_DynamicClaims(t *testing.T) {
@@ -92,7 +137,7 @@ func TestAcceptConsentRequest_DynamicClaims(t *testing.T) {
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-dynamic" {
return httpJSONAny(r, http.StatusOK, map[string]interface{}{
"challenge": "challenge-dynamic",
"requested_scope": []string{"openid", "profile"},
"requested_scope": []string{"openid", "profile", "tenant"},
"subject": "user-123",
"client": map[string]interface{}{
"client_id": "client-app",
@@ -162,7 +207,7 @@ func TestAcceptConsentRequest_DynamicClaims(t *testing.T) {
reqBody, _ := json.Marshal(map[string]interface{}{
"consent_challenge": "challenge-dynamic",
"grant_scope": []string{"openid", "profile"},
"grant_scope": []string{"openid", "profile", "tenant"},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(reqBody))
req.Header.Set("Content-Type", "application/json")
@@ -179,6 +224,293 @@ func TestAcceptConsentRequest_DynamicClaims(t *testing.T) {
assert.Equal(t, "Architect", capturedClaims["position"])
}
func TestAcceptConsentRequest_UsesRepresentativeTenantIDInsteadOfClientTenantContext(t *testing.T) {
var capturedClaims map[string]interface{}
representativeTenantID := "01970f0a-5c28-74d8-a73a-f6e9e9a7b210"
rpContextTenantID := "01970f0b-3448-7bb8-bdc7-16b6a1d2e661"
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-representative-tenant" {
return httpJSONAny(r, http.StatusOK, map[string]interface{}{
"challenge": "challenge-representative-tenant",
"requested_scope": []string{"openid", "profile", "tenant"},
"subject": "user-representative",
"client": map[string]interface{}{
"client_id": "client-app",
"metadata": map[string]interface{}{
"tenant_id": rpContextTenantID,
},
},
}), nil
}
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-representative-tenant" {
body, _ := io.ReadAll(r.Body)
var acceptReq map[string]interface{}
json.Unmarshal(body, &acceptReq)
if session, ok := acceptReq["session"].(map[string]interface{}); ok {
capturedClaims = session["id_token"].(map[string]interface{})
}
return httpJSONAny(r, http.StatusOK, map[string]interface{}{
"redirect_to": "http://rp/cb",
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: new(MockKratosAdminService),
}
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-representative").Return(&service.KratosIdentity{
ID: "user-representative",
Traits: map[string]interface{}{
"email": "user@test.com",
"name": "Test User",
"additionalAppointments": []interface{}{
map[string]interface{}{"tenantId": representativeTenantID, "isPrimary": true},
map[string]interface{}{"tenantId": rpContextTenantID},
},
},
}, nil)
app := fiber.New()
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
reqBody, _ := json.Marshal(map[string]interface{}{
"consent_challenge": "challenge-representative-tenant",
"grant_scope": []string{"openid", "profile"},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(reqBody))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.NotNil(t, capturedClaims)
assert.Equal(t, representativeTenantID, capturedClaims["tenant_id"])
assert.Contains(t, capturedClaims["joined_tenants"], representativeTenantID)
assert.Contains(t, capturedClaims["joined_tenants"], rpContextTenantID)
assert.Nil(t, capturedClaims["tenants"])
}
func TestAcceptConsentRequest_IncludesHanmacFamilyTenantClaimDetails(t *testing.T) {
var capturedClaims map[string]interface{}
deptID := "01970f0a-5c28-74d8-a73a-f6e9e9a7b210"
secondDeptID := "01970f0b-3448-7bb8-bdc7-16b6a1d2e661"
companyID := "01970f08-91da-7286-bd19-882fb98d1f2c"
rootID := "01970f07-4f01-7d9a-a71e-b53ad508f345"
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-hanmac-tenant-claim" {
return httpJSONAny(r, http.StatusOK, map[string]interface{}{
"challenge": "challenge-hanmac-tenant-claim",
"requested_scope": []string{"openid", "profile", "tenant"},
"subject": "user-hanmac",
"client": map[string]interface{}{
"client_id": "hanmac-rp",
"metadata": map[string]interface{}{
"tenant_id": deptID,
},
},
}), nil
}
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-hanmac-tenant-claim" {
body, _ := io.ReadAll(r.Body)
var acceptReq map[string]interface{}
json.Unmarshal(body, &acceptReq)
if session, ok := acceptReq["session"].(map[string]interface{}); ok {
capturedClaims = session["id_token"].(map[string]interface{})
}
return httpJSONAny(r, http.StatusOK, map[string]interface{}{
"redirect_to": "http://rp/cb",
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: new(MockKratosAdminService),
}
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-hanmac").Return(&service.KratosIdentity{
ID: "user-hanmac",
Traits: map[string]interface{}{
"email": "hanmac-user@example.com",
"name": "한맥 사용자",
"additionalAppointments": []interface{}{
map[string]interface{}{
"tenantId": deptID,
"isPrimary": true,
"isOwner": true,
"grade": "책임",
"jobTitle": "기술기획",
"position": "팀장",
},
map[string]interface{}{
"tenantId": secondDeptID,
"isPrimary": false,
"isOwner": false,
"grade": "선임",
"jobTitle": "품질관리",
"position": "파트원",
},
},
},
}, nil)
mockTenantSvc := new(MockTenantService)
mockTenantSvc.On("ListJoinedTenants", mock.Anything, "user-hanmac").Return([]domain.Tenant{}, nil)
mockTenantSvc.On("GetTenant", mock.Anything, deptID).Return(&domain.Tenant{
ID: deptID,
Slug: "tech-planning",
Name: "기술기획팀",
Type: domain.TenantTypeUserGroup,
ParentID: &companyID,
}, nil)
mockTenantSvc.On("GetTenant", mock.Anything, secondDeptID).Return(&domain.Tenant{
ID: secondDeptID,
Slug: "quality",
Name: "품질관리팀",
Type: domain.TenantTypeUserGroup,
ParentID: &companyID,
}, nil)
mockTenantSvc.On("GetTenant", mock.Anything, companyID).Return(&domain.Tenant{
ID: companyID,
Slug: "hanmac",
Name: "한맥기술",
Type: domain.TenantTypeCompany,
ParentID: &rootID,
}, nil)
mockTenantSvc.On("GetTenant", mock.Anything, rootID).Return(&domain.Tenant{
ID: rootID,
Slug: "hanmac-family",
Name: "한맥가족",
Type: domain.TenantTypeCompanyGroup,
}, nil)
h.TenantService = mockTenantSvc
app := fiber.New()
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
reqBody, _ := json.Marshal(map[string]interface{}{
"consent_challenge": "challenge-hanmac-tenant-claim",
"grant_scope": []string{"openid", "profile", "tenant"},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(reqBody))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.NotNil(t, capturedClaims)
assert.Equal(t, []interface{}{deptID}, capturedClaims["lead_tenants"])
assert.ElementsMatch(t, []interface{}{deptID, secondDeptID}, capturedClaims["joined_tenants"])
tenants := capturedClaims["tenants"].(map[string]interface{})
dept := tenants[deptID].(map[string]interface{})
assert.Equal(t, true, dept["lead"])
assert.Equal(t, true, dept["representative"])
assert.Equal(t, "책임", dept["grade"])
assert.Equal(t, "기술기획", dept["jobTitle"])
assert.Equal(t, "팀장", dept["position"])
assert.Equal(t, companyID, dept["parentTenantId"])
assert.NotContains(t, dept, "parentTenant")
ancestors := dept["ancestors"].([]interface{})
assert.Len(t, ancestors, 2)
companyAncestor := ancestors[0].(map[string]interface{})
assert.Equal(t, companyID, companyAncestor["id"])
assert.Equal(t, "hanmac", companyAncestor["slug"])
assert.Equal(t, rootID, companyAncestor["parentTenantId"])
assert.NotContains(t, companyAncestor, "parentTenant")
rootAncestor := ancestors[1].(map[string]interface{})
assert.Equal(t, rootID, rootAncestor["id"])
assert.Equal(t, "hanmac-family", rootAncestor["slug"])
assert.Contains(t, rootAncestor, "parentTenantId")
assert.Nil(t, rootAncestor["parentTenantId"])
assert.NotContains(t, rootAncestor, "parentTenant")
secondDept := tenants[secondDeptID].(map[string]interface{})
assert.Equal(t, false, secondDept["lead"])
assert.Equal(t, false, secondDept["representative"])
assert.Equal(t, "선임", secondDept["grade"])
assert.Equal(t, "품질관리", secondDept["jobTitle"])
assert.Equal(t, "파트원", secondDept["position"])
assert.Equal(t, companyID, secondDept["parentTenantId"])
}
func TestWithHanmacFamilyTenantClaims_DefaultClaimsOnlyWithoutTenantScope(t *testing.T) {
deptID := "01970f0a-5c28-74d8-a73a-f6e9e9a7b210"
secondDeptID := "01970f0b-3448-7bb8-bdc7-16b6a1d2e661"
companyID := "01970f08-91da-7286-bd19-882fb98d1f2c"
rootID := "01970f07-4f01-7d9a-a71e-b53ad508f345"
mockTenantSvc := new(MockTenantService)
mockTenantSvc.On("GetTenant", mock.Anything, deptID).Return(&domain.Tenant{
ID: deptID,
Slug: "tech-planning",
Name: "기술기획팀",
Type: domain.TenantTypeUserGroup,
ParentID: &companyID,
}, nil)
mockTenantSvc.On("GetTenant", mock.Anything, secondDeptID).Return(&domain.Tenant{
ID: secondDeptID,
Slug: "quality",
Name: "품질관리팀",
Type: domain.TenantTypeUserGroup,
ParentID: &companyID,
}, nil)
mockTenantSvc.On("GetTenant", mock.Anything, companyID).Return(&domain.Tenant{
ID: companyID,
Slug: "hanmac",
Name: "한맥기술",
Type: domain.TenantTypeCompany,
ParentID: &rootID,
}, nil)
mockTenantSvc.On("GetTenant", mock.Anything, rootID).Return(&domain.Tenant{
ID: rootID,
Slug: "hanmac-family",
Name: "한맥가족",
Type: domain.TenantTypeCompanyGroup,
}, nil)
h := &AuthHandler{TenantService: mockTenantSvc}
claims := map[string]any{"tenant_id": deptID}
traits := map[string]any{
"additionalAppointments": []any{
map[string]any{
"tenantId": deptID,
"isPrimary": true,
"isOwner": true,
"grade": "책임",
},
map[string]any{
"tenantId": secondDeptID,
"grade": "선임",
},
},
}
claims = h.withHanmacFamilyTenantClaims(context.Background(), claims, traits, []string{"openid", "profile"})
assert.Equal(t, deptID, claims["tenant_id"])
assert.ElementsMatch(t, []string{deptID, secondDeptID}, claims["joined_tenants"])
assert.NotContains(t, claims, "tenants")
assert.NotContains(t, claims, "lead_tenants")
}
func TestAcceptConsentRequest_IncludesRPProfileClaims(t *testing.T) {
var capturedClaims map[string]interface{}
@@ -186,7 +518,7 @@ func TestAcceptConsentRequest_IncludesRPProfileClaims(t *testing.T) {
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-rp-profile" {
return httpJSONAny(r, http.StatusOK, map[string]interface{}{
"challenge": "challenge-rp-profile",
"requested_scope": []string{"openid", "profile"},
"requested_scope": []string{"openid", "profile", "tenant"},
"subject": "user-123",
"client": map[string]interface{}{
"client_id": "client-app",
@@ -284,7 +616,7 @@ func TestGetConsentRequest_Skip_DynamicClaims(t *testing.T) {
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-skip-dynamic" {
return httpJSONAny(r, http.StatusOK, map[string]interface{}{
"challenge": "challenge-skip-dynamic",
"requested_scope": []string{"openid", "profile"},
"requested_scope": []string{"openid", "profile", "tenant"},
"skip": true,
"subject": "user-456",
"client": map[string]interface{}{

View File

@@ -0,0 +1,155 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"context"
"errors"
"fmt"
"strings"
"github.com/google/uuid"
)
func representativeTenantIDFromTraits(traits map[string]any) string {
if value := tenantClaimString(traits, "tenant_id"); value != "" {
return value
}
if value := tenantClaimString(traits, "primaryTenantId"); value != "" {
return value
}
if metadata, ok := traits["metadata"].(map[string]any); ok {
if value := tenantClaimString(metadata, "primaryTenantId"); value != "" {
return value
}
}
appointments := tenantAssignmentAppointmentsFromTraits(traits)
for _, appointment := range appointments {
if tenantAssignmentBool(appointment, "isPrimary", "primary", "representative", "isRepresentative") {
if value := tenantAssignmentTenantID(appointment); value != "" {
return value
}
}
}
for _, appointment := range appointments {
if value := tenantAssignmentTenantID(appointment); value != "" {
return value
}
}
for _, tenantID := range tenantNamespaceIDsFromTraits(traits) {
return tenantID
}
return ""
}
func joinedTenantIDsFromTraits(traits map[string]any, representativeTenantID string) []string {
values := make([]string, 0)
if representativeTenantID != "" {
values = append(values, representativeTenantID)
}
if value := tenantClaimString(traits, "tenant_id"); value != "" {
values = append(values, value)
}
for _, appointment := range tenantAssignmentAppointmentsFromTraits(traits) {
if value := tenantAssignmentTenantID(appointment); value != "" {
values = append(values, value)
}
}
values = append(values, tenantNamespaceIDsFromTraits(traits)...)
return uniqueSortedStrings(values)
}
func tenantAssignmentAppointmentsFromTraits(traits map[string]any) []map[string]any {
raw := rawAdditionalAppointments(traits)
switch values := raw.(type) {
case []any:
appointments := make([]map[string]any, 0, len(values))
for _, item := range values {
if appointment, ok := item.(map[string]any); ok {
appointments = append(appointments, appointment)
}
}
return appointments
case []map[string]any:
return values
default:
return nil
}
}
func tenantAssignmentTenantID(appointment map[string]any) string {
for _, key := range []string{"tenantId", "tenant_id"} {
if value := tenantClaimString(appointment, key); value != "" {
return value
}
}
return ""
}
func tenantAssignmentBool(values map[string]any, keys ...string) bool {
for _, key := range keys {
raw, ok := values[key]
if !ok || raw == nil {
continue
}
switch value := raw.(type) {
case bool:
if value {
return true
}
case string:
normalized := strings.ToLower(strings.TrimSpace(value))
if normalized == "true" || normalized == "1" || normalized == "yes" {
return true
}
}
}
return false
}
func tenantNamespaceIDsFromTraits(traits map[string]any) []string {
if traits == nil {
return nil
}
ids := make([]string, 0)
for key, value := range traits {
if key == "" || key == "metadata" {
continue
}
switch value.(type) {
case map[string]any:
ids = append(ids, key)
}
}
return uniqueSortedStrings(ids)
}
func createPersonalTenantForUser(ctx context.Context, tenantService service.TenantService, email string) (*domain.Tenant, error) {
if tenantService == nil {
return nil, errors.New("tenant service unavailable")
}
normalizedEmail := strings.ToLower(strings.TrimSpace(email))
if normalizedEmail == "" {
normalizedEmail = "user"
}
slug := "personal-" + strings.ReplaceAll(uuid.NewString(), "-", "")
tenant, err := tenantService.RegisterTenant(
ctx,
fmt.Sprintf("Personal - %s", normalizedEmail),
slug,
domain.TenantTypePersonal,
"Automatically provisioned personal tenant",
nil,
nil,
"",
)
if err != nil {
return nil, err
}
if tenant == nil {
return nil, errors.New("personal tenant not created")
}
return tenant, nil
}

View File

@@ -114,6 +114,60 @@ type tenantCSVRecord struct {
OrgUnitType string
}
type orgContextTenant struct {
ID string `json:"id"`
Type string `json:"type"`
Name string `json:"name"`
Slug string `json:"slug"`
ParentID *string `json:"parentId"`
Status string `json:"status"`
Description string `json:"description"`
Domains []string `json:"domains,omitempty"`
MemberCount int64 `json:"memberCount"`
Visibility string `json:"visibility"`
OrgUnitType string `json:"orgUnitType,omitempty"`
Config domain.JSONMap `json:"config,omitempty"`
CreatedAt string `json:"createdAt"`
UpdatedAt string `json:"updatedAt"`
}
type orgContextUser struct {
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Role string `json:"role"`
Status string `json:"status"`
TenantIDs []string `json:"tenantIds"`
TenantSlugs []string `json:"tenantSlugs"`
Department string `json:"department,omitempty"`
Grade string `json:"grade,omitempty"`
Position string `json:"position,omitempty"`
JobTitle string `json:"jobTitle,omitempty"`
Metadata domain.JSONMap `json:"metadata,omitempty"`
CreatedAt string `json:"createdAt"`
UpdatedAt string `json:"updatedAt"`
}
type orgContextTreeNode struct {
orgContextTenant
DirectUserIDs []string `json:"directUserIds"`
Children []orgContextTreeNode `json:"children"`
}
type orgContextScope struct {
TenantID string `json:"tenantId"`
TenantSlug string `json:"tenantSlug"`
}
type orgContextResponse struct {
SchemaVersion string `json:"schemaVersion"`
IssuedAt string `json:"issuedAt"`
Scope orgContextScope `json:"scope"`
Tree *orgContextTreeNode `json:"tree"`
Tenants []orgContextTenant `json:"tenants"`
Users []orgContextUser `json:"users"`
}
func (h *TenantHandler) RegisterTenantPublic(c *fiber.Ctx) error {
var req struct {
Name string `json:"name"`
@@ -271,10 +325,12 @@ func (h *TenantHandler) ListTenants(c *fiber.Ctx) error {
}
func (h *TenantHandler) ExportTenantsCSV(c *fiber.Ctx) error {
tenants, _, err := h.Service.ListTenants(c.Context(), 10000, 0, "")
parentID := strings.TrimSpace(c.Query("parentId"))
allTenants, _, err := h.Service.ListTenants(c.Context(), 10000, 0, "")
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
tenants := filterTenantCSVDescendants(allTenants, parentID)
var buf bytes.Buffer
writer := csv.NewWriter(&buf)
@@ -286,8 +342,8 @@ func (h *TenantHandler) ExportTenantsCSV(c *fiber.Ctx) error {
} else if err := writer.Write([]string{"name", "type", "parent_tenant_slug", "slug", "memo", "email_domain", "visibility", "org_unit_type"}); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
slugByID := make(map[string]string, len(tenants))
for _, tenant := range tenants {
slugByID := make(map[string]string, len(allTenants))
for _, tenant := range allTenants {
slugByID[tenant.ID] = tenant.Slug
}
for _, tenant := range tenants {
@@ -343,6 +399,41 @@ func (h *TenantHandler) ExportTenantsCSV(c *fiber.Ctx) error {
return c.Send(buf.Bytes())
}
func filterTenantCSVDescendants(tenants []domain.Tenant, parentID string) []domain.Tenant {
parentID = strings.TrimSpace(parentID)
if parentID == "" {
return tenants
}
descendantIDs := map[string]bool{}
frontier := map[string]bool{parentID: true}
for len(frontier) > 0 {
next := map[string]bool{}
for _, tenant := range tenants {
if tenant.ParentID == nil {
continue
}
if !frontier[strings.TrimSpace(*tenant.ParentID)] {
continue
}
if descendantIDs[tenant.ID] {
continue
}
descendantIDs[tenant.ID] = true
next[tenant.ID] = true
}
frontier = next
}
filtered := make([]domain.Tenant, 0, len(descendantIDs))
for _, tenant := range tenants {
if descendantIDs[tenant.ID] {
filtered = append(filtered, tenant)
}
}
return filtered
}
func (h *TenantHandler) ImportTenantsCSV(c *fiber.Ctx) error {
reader, err := tenantCSVReaderFromRequest(c)
if err != nil {
@@ -1818,6 +1909,272 @@ func mapTenantSummary(t domain.Tenant) tenantSummary {
}
}
func (h *TenantHandler) GetOrgContext(c *fiber.Ctx) error {
if c.Locals("apiKeyName") == nil {
return errorJSON(c, fiber.StatusUnauthorized, "api key authentication is required")
}
if h.Service == nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "tenant service is not configured")
}
allTenants, _, err := h.Service.ListTenants(c.Context(), 10000, 0, "")
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
rootSlug := strings.TrimSpace(c.Query("tenantSlug"))
if rootSlug == "" {
rootSlug = "hanmac-family"
}
root, ok := findOrgContextTenantBySlug(allTenants, rootSlug)
if !ok {
return errorJSON(c, fiber.StatusNotFound, "tenant slug not found")
}
scopedTenants := filterOrgContextSubtree(allTenants, root.ID)
contextTenants := make([]orgContextTenant, 0, len(scopedTenants))
tenantIDs := make([]string, 0, len(scopedTenants))
tenantSlugs := make([]string, 0, len(scopedTenants))
tenantByID := make(map[string]orgContextTenant, len(scopedTenants))
tenantBySlug := make(map[string]orgContextTenant, len(scopedTenants))
for _, tenant := range scopedTenants {
summary := mapOrgContextTenant(tenant)
contextTenants = append(contextTenants, summary)
tenantIDs = append(tenantIDs, tenant.ID)
tenantSlugs = append(tenantSlugs, tenant.Slug)
tenantByID[tenant.ID] = summary
tenantBySlug[strings.ToLower(tenant.Slug)] = summary
}
includeUsers := !strings.EqualFold(strings.TrimSpace(c.Query("includeUsers")), "false")
contextUsers := []orgContextUser{}
if includeUsers {
if h.UserRepo == nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "user repository is not configured")
}
contextUsers, err = h.loadOrgContextUsers(c.Context(), tenantIDs, tenantSlugs, tenantByID, tenantBySlug)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
}
directUserIDsByTenantID := make(map[string][]string)
for _, user := range contextUsers {
for _, tenantID := range user.TenantIDs {
directUserIDsByTenantID[tenantID] = append(directUserIDsByTenantID[tenantID], user.ID)
}
}
tree := buildOrgContextTree(root.ID, scopedTenants, tenantByID, directUserIDsByTenantID)
return c.JSON(orgContextResponse{
SchemaVersion: "baron.org-context.v1",
IssuedAt: time.Now().UTC().Format(time.RFC3339),
Scope: orgContextScope{
TenantID: root.ID,
TenantSlug: root.Slug,
},
Tree: tree,
Tenants: contextTenants,
Users: contextUsers,
})
}
func (h *TenantHandler) loadOrgContextUsers(ctx context.Context, tenantIDs, tenantSlugs []string, tenantByID, tenantBySlug map[string]orgContextTenant) ([]orgContextUser, error) {
usersByID, err := h.UserRepo.FindByTenantIDs(ctx, tenantIDs)
if err != nil {
return nil, err
}
usersBySlug, err := h.UserRepo.FindByCompanyCodes(ctx, tenantSlugs)
if err != nil {
return nil, err
}
seen := make(map[string]bool)
contextUsers := make([]orgContextUser, 0, len(usersByID)+len(usersBySlug))
for _, user := range append(usersByID, usersBySlug...) {
if seen[user.ID] || user.Status != domain.UserStatusActive {
continue
}
mapped, ok := mapOrgContextUser(user, tenantByID, tenantBySlug)
if !ok {
continue
}
seen[user.ID] = true
contextUsers = append(contextUsers, mapped)
}
return contextUsers, nil
}
func findOrgContextTenantBySlug(tenants []domain.Tenant, slug string) (domain.Tenant, bool) {
normalized := strings.ToLower(strings.TrimSpace(slug))
for _, tenant := range tenants {
if strings.ToLower(tenant.Slug) == normalized && isOrgContextTenantType(tenant) {
return tenant, true
}
}
return domain.Tenant{}, false
}
func isOrgContextTenantType(tenant domain.Tenant) bool {
switch strings.ToUpper(tenant.Type) {
case domain.TenantTypeCompanyGroup, domain.TenantTypeCompany, domain.TenantTypeOrganization, domain.TenantTypeUserGroup:
return true
default:
return false
}
}
func filterOrgContextSubtree(tenants []domain.Tenant, rootID string) []domain.Tenant {
descendantIDs := map[string]bool{rootID: true}
frontier := map[string]bool{rootID: true}
for len(frontier) > 0 {
next := map[string]bool{}
for _, tenant := range tenants {
if tenant.ParentID == nil || !frontier[*tenant.ParentID] || descendantIDs[tenant.ID] {
continue
}
descendantIDs[tenant.ID] = true
next[tenant.ID] = true
}
frontier = next
}
excludedIDs := map[string]bool{}
for _, tenant := range tenants {
if descendantIDs[tenant.ID] && tenantVisibility(tenant.Config) == "private" {
excludedIDs[tenant.ID] = true
}
}
changed := true
for changed {
changed = false
for _, tenant := range tenants {
if tenant.ParentID == nil || !descendantIDs[tenant.ID] || excludedIDs[tenant.ID] {
continue
}
if excludedIDs[*tenant.ParentID] {
excludedIDs[tenant.ID] = true
changed = true
}
}
}
filtered := make([]domain.Tenant, 0, len(descendantIDs))
for _, tenant := range tenants {
if descendantIDs[tenant.ID] && !excludedIDs[tenant.ID] && isOrgContextTenantType(tenant) {
filtered = append(filtered, tenant)
}
}
return filtered
}
func mapOrgContextTenant(tenant domain.Tenant) orgContextTenant {
domains := make([]string, 0, len(tenant.Domains))
for _, domain := range tenant.Domains {
domains = append(domains, domain.Domain)
}
visibility, orgUnitType := tenantCSVOrgConfigValues(tenant.Config)
return orgContextTenant{
ID: tenant.ID,
Type: tenant.Type,
Name: tenant.Name,
Slug: tenant.Slug,
ParentID: tenant.ParentID,
Status: tenant.Status,
Description: tenant.Description,
Domains: domains,
Visibility: visibility,
OrgUnitType: orgUnitType,
Config: tenant.Config,
CreatedAt: tenant.CreatedAt.Format(time.RFC3339),
UpdatedAt: tenant.UpdatedAt.Format(time.RFC3339),
}
}
func mapOrgContextUser(user domain.User, tenantByID, tenantBySlug map[string]orgContextTenant) (orgContextUser, bool) {
matchedTenants := make([]orgContextTenant, 0, 2)
seenTenants := map[string]bool{}
addTenant := func(tenant orgContextTenant, ok bool) {
if !ok || seenTenants[tenant.ID] {
return
}
seenTenants[tenant.ID] = true
matchedTenants = append(matchedTenants, tenant)
}
if user.TenantID != nil {
addTenant(tenantByID[*user.TenantID], tenantByID[*user.TenantID].ID != "")
}
if user.Tenant != nil {
addTenant(tenantByID[user.Tenant.ID], tenantByID[user.Tenant.ID].ID != "")
addTenant(tenantBySlug[strings.ToLower(user.Tenant.Slug)], tenantBySlug[strings.ToLower(user.Tenant.Slug)].ID != "")
}
if user.CompanyCode != "" {
addTenant(tenantBySlug[strings.ToLower(strings.TrimSpace(user.CompanyCode))], tenantBySlug[strings.ToLower(strings.TrimSpace(user.CompanyCode))].ID != "")
}
for _, companyCode := range user.CompanyCodes {
addTenant(tenantBySlug[strings.ToLower(strings.TrimSpace(companyCode))], tenantBySlug[strings.ToLower(strings.TrimSpace(companyCode))].ID != "")
}
if len(matchedTenants) == 0 {
return orgContextUser{}, false
}
tenantIDs := make([]string, 0, len(matchedTenants))
tenantSlugs := make([]string, 0, len(matchedTenants))
for _, tenant := range matchedTenants {
tenantIDs = append(tenantIDs, tenant.ID)
tenantSlugs = append(tenantSlugs, tenant.Slug)
}
return orgContextUser{
ID: user.ID,
Email: user.Email,
Name: user.Name,
Role: user.Role,
Status: user.Status,
TenantIDs: tenantIDs,
TenantSlugs: tenantSlugs,
Department: user.Department,
Grade: user.Grade,
Position: user.Position,
JobTitle: user.JobTitle,
Metadata: user.Metadata,
CreatedAt: user.CreatedAt.Format(time.RFC3339),
UpdatedAt: user.UpdatedAt.Format(time.RFC3339),
}, true
}
func buildOrgContextTree(rootID string, tenants []domain.Tenant, tenantByID map[string]orgContextTenant, directUserIDsByTenantID map[string][]string) *orgContextTreeNode {
childrenByParentID := make(map[string][]domain.Tenant)
for _, tenant := range tenants {
if tenant.ParentID == nil {
continue
}
childrenByParentID[*tenant.ParentID] = append(childrenByParentID[*tenant.ParentID], tenant)
}
var build func(tenantID string) *orgContextTreeNode
build = func(tenantID string) *orgContextTreeNode {
tenant, ok := tenantByID[tenantID]
if !ok {
return nil
}
node := &orgContextTreeNode{
orgContextTenant: tenant,
DirectUserIDs: directUserIDsByTenantID[tenantID],
Children: []orgContextTreeNode{},
}
for _, child := range childrenByParentID[tenantID] {
childNode := build(child.ID)
if childNode != nil {
node.Children = append(node.Children, *childNode)
}
}
return node
}
return build(rootID)
}
func (h *TenantHandler) countTenantMembersFromProjection(ctx context.Context, tenants []domain.Tenant) (map[string]int64, error) {
counts := make(map[string]int64, len(tenants))
for _, tenant := range tenants {

View File

@@ -13,6 +13,7 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
@@ -213,6 +214,13 @@ func (m *MockUserProjectionRepoForHandler) MarkFailed(ctx context.Context, syncE
return args.Error(0)
}
func toJSONString(t *testing.T, value any) string {
t.Helper()
raw, err := json.Marshal(value)
require.NoError(t, err)
return string(raw)
}
func TestTenantHandler_CreateTenant(t *testing.T) {
app := fiber.New()
mockSvc := new(MockTenantService)
@@ -360,6 +368,121 @@ func TestTenantHandler_ListTenants(t *testing.T) {
}
}
func TestTenantHandler_GetOrgContextJSONDefaultsToHanmacFamilyForApiKey(t *testing.T) {
app := fiber.New()
mockSvc := new(MockTenantService)
mockUsers := new(MockUserRepoForHandler)
h := &TenantHandler{Service: mockSvc, UserRepo: mockUsers}
app.Use(func(c *fiber.Ctx) error {
c.Locals("apiKeyName", "orgfront-ssot-client")
return c.Next()
})
app.Get("/org-context", h.GetOrgContext)
now := time.Date(2026, 5, 13, 12, 0, 0, 0, time.UTC)
parent := func(id string) *string { return &id }
tenants := []domain.Tenant{
{ID: "root-other", Type: domain.TenantTypeCompanyGroup, Name: "다른그룹", Slug: "other-family", Status: domain.TenantStatusActive, CreatedAt: now, UpdatedAt: now},
{ID: "group-hanmac-family", Type: domain.TenantTypeCompanyGroup, Name: "한맥가족", Slug: "hanmac-family", Status: domain.TenantStatusActive, CreatedAt: now, UpdatedAt: now},
{ID: "company-hanmac", Type: domain.TenantTypeCompany, ParentID: parent("group-hanmac-family"), Name: "한맥기술", Slug: "hanmac", Status: domain.TenantStatusActive, CreatedAt: now, UpdatedAt: now},
{ID: "dept-platform", Type: domain.TenantTypeUserGroup, ParentID: parent("company-hanmac"), Name: "플랫폼실", Slug: "platform", Status: domain.TenantStatusActive, Config: domain.JSONMap{"orgUnitType": "실"}, CreatedAt: now, UpdatedAt: now},
{ID: "team-sso", Type: domain.TenantTypeUserGroup, ParentID: parent("dept-platform"), Name: "SSO팀", Slug: "sso", Status: domain.TenantStatusActive, CreatedAt: now, UpdatedAt: now},
{ID: "private-team", Type: domain.TenantTypeUserGroup, ParentID: parent("company-hanmac"), Name: "비공개", Slug: "private-team", Status: domain.TenantStatusActive, Config: domain.JSONMap{"visibility": "private"}, CreatedAt: now, UpdatedAt: now},
}
usersByTenantID := []domain.User{
{ID: "user-platform-lead", Email: "lead@example.com", Name: "플랫폼 리드", Status: domain.UserStatusActive, TenantID: parent("dept-platform"), CompanyCode: "platform", Grade: "책임", Position: "실장", CreatedAt: now, UpdatedAt: now},
}
usersBySlug := []domain.User{
{ID: "user-sso-member", Email: "member@example.com", Name: "SSO 구성원", Status: domain.UserStatusActive, CompanyCode: "sso", Grade: "선임", CreatedAt: now, UpdatedAt: now},
}
mockSvc.On("ListTenants", mock.Anything, 10000, 0, "").Return(tenants, int64(len(tenants)), nil)
mockUsers.On("FindByTenantIDs", mock.Anything, []string{"group-hanmac-family", "company-hanmac", "dept-platform", "team-sso"}).Return(usersByTenantID, nil)
mockUsers.On("FindByCompanyCodes", mock.Anything, []string{"hanmac-family", "hanmac", "platform", "sso"}).Return(usersBySlug, nil)
req := httptest.NewRequest(http.MethodGet, "/org-context", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
var got map[string]any
require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
require.Equal(t, "baron.org-context.v1", got["schemaVersion"])
scope := got["scope"].(map[string]any)
require.Equal(t, "group-hanmac-family", scope["tenantId"])
require.Equal(t, "hanmac-family", scope["tenantSlug"])
tenantsPayload := got["tenants"].([]any)
require.Len(t, tenantsPayload, 4)
require.Equal(t, "group-hanmac-family", tenantsPayload[0].(map[string]any)["id"])
require.Equal(t, "company-hanmac", tenantsPayload[1].(map[string]any)["id"])
require.Equal(t, "dept-platform", tenantsPayload[2].(map[string]any)["id"])
require.Equal(t, "team-sso", tenantsPayload[3].(map[string]any)["id"])
usersPayload := got["users"].([]any)
require.Len(t, usersPayload, 2)
require.Equal(t, "user-platform-lead", usersPayload[0].(map[string]any)["id"])
require.Equal(t, []any{"dept-platform"}, usersPayload[0].(map[string]any)["tenantIds"])
require.Equal(t, "user-sso-member", usersPayload[1].(map[string]any)["id"])
tree := got["tree"].(map[string]any)
require.Equal(t, "group-hanmac-family", tree["id"])
require.NotContains(t, toJSONString(t, got), "private-team")
require.NotContains(t, toJSONString(t, got), "root-other")
}
func TestTenantHandler_GetOrgContextJSONScopesByTenantSlug(t *testing.T) {
app := fiber.New()
mockSvc := new(MockTenantService)
mockUsers := new(MockUserRepoForHandler)
h := &TenantHandler{Service: mockSvc, UserRepo: mockUsers}
app.Use(func(c *fiber.Ctx) error {
c.Locals("apiKeyName", "orgfront-ssot-client")
return c.Next()
})
app.Get("/org-context", h.GetOrgContext)
now := time.Date(2026, 5, 13, 12, 0, 0, 0, time.UTC)
parent := func(id string) *string { return &id }
tenants := []domain.Tenant{
{ID: "group-hanmac-family", Type: domain.TenantTypeCompanyGroup, Name: "한맥가족", Slug: "hanmac-family", Status: domain.TenantStatusActive, CreatedAt: now, UpdatedAt: now},
{ID: "company-hanmac", Type: domain.TenantTypeCompany, ParentID: parent("group-hanmac-family"), Name: "한맥기술", Slug: "hanmac", Status: domain.TenantStatusActive, CreatedAt: now, UpdatedAt: now},
{ID: "dept-platform", Type: domain.TenantTypeUserGroup, ParentID: parent("company-hanmac"), Name: "플랫폼실", Slug: "platform", Status: domain.TenantStatusActive, CreatedAt: now, UpdatedAt: now},
{ID: "company-other", Type: domain.TenantTypeCompany, ParentID: parent("group-hanmac-family"), Name: "다른회사", Slug: "other", Status: domain.TenantStatusActive, CreatedAt: now, UpdatedAt: now},
}
mockSvc.On("ListTenants", mock.Anything, 10000, 0, "").Return(tenants, int64(len(tenants)), nil)
mockUsers.On("FindByTenantIDs", mock.Anything, []string{"company-hanmac", "dept-platform"}).Return([]domain.User{}, nil)
mockUsers.On("FindByCompanyCodes", mock.Anything, []string{"hanmac", "platform"}).Return([]domain.User{}, nil)
req := httptest.NewRequest(http.MethodGet, "/org-context?tenantSlug=hanmac", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
var got map[string]any
require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
scope := got["scope"].(map[string]any)
require.Equal(t, "company-hanmac", scope["tenantId"])
require.Equal(t, "hanmac", scope["tenantSlug"])
require.Contains(t, toJSONString(t, got), "dept-platform")
require.NotContains(t, toJSONString(t, got), "company-other")
}
func TestTenantHandler_GetOrgContextJSONRequiresApiKey(t *testing.T) {
app := fiber.New()
h := &TenantHandler{}
app.Get("/org-context", h.GetOrgContext)
req := httptest.NewRequest(http.MethodGet, "/org-context", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
}
func TestTenantHandler_ListTenantsReturnsServiceUnavailableWhenProjectionStatusFails(t *testing.T) {
app := fiber.New()
mockSvc := new(MockTenantService)
@@ -518,6 +641,62 @@ func TestTenantHandler_ExportTenantsCSV_OmitsIDsAndUsesParentSlug(t *testing.T)
mockSvc.AssertExpectations(t)
}
func TestTenantHandler_ExportTenantsCSV_FiltersDescendantsByParentIDWithIDs(t *testing.T) {
app := fiber.New()
mockSvc := new(MockTenantService)
h := &TenantHandler{Service: mockSvc}
app.Get("/tenants/export", h.ExportTenantsCSV)
parentID := "11111111-2222-4333-8444-555555555555"
childID := "aaaaaaaa-bbbb-4ccc-8ddd-eeeeeeeeeeee"
grandchildID := "bbbbbbbb-cccc-4ddd-8eee-ffffffffffff"
unrelatedID := "cccccccc-dddd-4eee-8fff-111111111111"
tenants := []domain.Tenant{
{
ID: parentID,
Name: "Parent Org",
Type: domain.TenantTypeCompany,
Slug: "parent-org",
},
{
ID: childID,
Name: "Child Org",
Type: domain.TenantTypeOrganization,
ParentID: &parentID,
Slug: "child-org",
},
{
ID: grandchildID,
Name: "Leaf Team",
Type: domain.TenantTypeUserGroup,
ParentID: &childID,
Slug: "leaf-team",
},
{
ID: unrelatedID,
Name: "Unrelated Org",
Type: domain.TenantTypeOrganization,
Slug: "unrelated-org",
},
}
mockSvc.On("ListTenants", mock.Anything, 10000, 0, "").Return(tenants, int64(len(tenants)), nil)
req := httptest.NewRequest("GET", "/tenants/export?includeIds=true&parentId="+parentID, nil)
resp, _ := app.Test(req)
body, _ := io.ReadAll(resp.Body)
text := string(body)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Contains(t, text, "tenant_id,name,type,parent_tenant_id,parent_tenant_slug,slug,memo,email_domain,visibility,org_unit_type")
assert.Contains(t, text, childID+",Child Org,ORGANIZATION,"+parentID+",parent-org,child-org,")
assert.Contains(t, text, grandchildID+",Leaf Team,USER_GROUP,"+childID+",child-org,leaf-team,")
assert.NotContains(t, text, unrelatedID)
assert.NotContains(t, text, "Parent Org")
mockSvc.AssertExpectations(t)
}
func TestTenantHandler_ImportTenantsCSVCreatesTenant(t *testing.T) {
app := fiber.New()
mockSvc := new(MockTenantService)

View File

@@ -570,9 +570,10 @@ func (h *UserHandler) CreateUser(c *fiber.Ctx) error {
// [Resolve TenantID and Custom Login IDs before Kratos creation]
var tenantID string
requestedPrimaryTenantID := primaryTenantIDFromRequest(req.PrimaryTenantID, req.Metadata, req.AdditionalAppointments)
if req.CompanyCode == "" && h.TenantService != nil {
if primaryTenantID := primaryTenantIDFromRequest(req.PrimaryTenantID, req.Metadata, req.AdditionalAppointments); primaryTenantID != "" {
if tenant, err := h.TenantService.GetTenant(c.Context(), primaryTenantID); err == nil && tenant != nil {
if requestedPrimaryTenantID != "" {
if tenant, err := h.TenantService.GetTenant(c.Context(), requestedPrimaryTenantID); err == nil && tenant != nil {
tenantID = tenant.ID
req.CompanyCode = tenant.Slug
}
@@ -583,6 +584,17 @@ func (h *UserHandler) CreateUser(c *fiber.Ctx) error {
tenantID = tenant.ID
}
}
if tenantID == "" {
if req.CompanyCode != "" || requestedPrimaryTenantID != "" {
return errorJSON(c, fiber.StatusBadRequest, "invalid tenant assignment")
}
tenant, err := createPersonalTenantForUser(c.Context(), h.TenantService, email)
if err != nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "failed to create personal tenant")
}
tenantID = tenant.ID
req.CompanyCode = tenant.Slug
}
// Collect and sync all custom login IDs based on tenant schemas
loginIDRecords := syncCustomLoginIDs(c.Context(), h.TenantService, attributes, req.Metadata, "")
@@ -857,6 +869,14 @@ func (h *UserHandler) BulkCreateUsers(c *fiber.Ctx) error {
return tItem, true
}
createPersonalTenantItem := func(email string) (tenantCacheItem, error) {
tenant, err := createPersonalTenantForUser(c.Context(), h.TenantService, email)
if err != nil {
return tenantCacheItem{}, err
}
return cacheTenantItem(buildTenantCacheItem(tenant)), nil
}
for _, item := range req.Users {
email := strings.TrimSpace(item.Email)
name := strings.TrimSpace(item.Name)
@@ -898,8 +918,12 @@ func (h *UserHandler) BulkCreateUsers(c *fiber.Ctx) error {
}
}
if tenantSlug == "" {
results = append(results, bulkUserResult{Email: email, Success: false, Message: "tenant assignment is required"})
continue
tItem, err = createPersonalTenantItem(email)
if err != nil {
results = append(results, bulkUserResult{Email: email, Success: false, Message: "failed to create personal tenant"})
continue
}
tenantSlug = tItem.Slug
}
}

View File

@@ -176,6 +176,14 @@ func (m *MockTenantServiceForUser) ProvisionTenantByDomain(ctx context.Context,
return args.Get(0).(*domain.Tenant), args.Error(1)
}
func (m *MockTenantServiceForUser) RegisterTenant(ctx context.Context, name, slug, tenantType, description string, domains []string, parentID *string, creatorID string) (*domain.Tenant, error) {
args := m.Called(ctx, name, slug, tenantType, description, domains, parentID, creatorID)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.Tenant), args.Error(1)
}
// --- Tests ---
func TestUserHandler_ExportUsersCSV_UsesTenantSlugAliasAndOmitsRole(t *testing.T) {
@@ -1411,6 +1419,85 @@ func TestUserHandler_CreateUser_UsesAdditionalAppointmentAsPrimaryTenant(t *test
mockOry.AssertExpectations(t)
}
func TestUserHandler_CreateUser_AutoCreatesPersonalTenantWhenAssignmentMissing(t *testing.T) {
app := fiber.New()
mockKratos := new(MockKratosAdmin)
mockOry := new(MockOryProvider)
mockTenant := new(MockTenantServiceForUser)
h := &UserHandler{
KratosAdmin: mockKratos,
OryProvider: mockOry,
TenantService: mockTenant,
}
app.Post("/users", h.CreateUser)
personalTenantID := "01970f0d-9666-7548-963d-2890351f03dd"
mockOry.On("GetPasswordPolicy").Return(&domain.PasswordPolicy{MinLength: 8}, nil)
mockTenant.On(
"RegisterTenant",
mock.Anything,
"Personal - personal-user@example.com",
mock.MatchedBy(func(slug string) bool { return strings.HasPrefix(slug, "personal-") }),
domain.TenantTypePersonal,
"Automatically provisioned personal tenant",
[]string(nil),
(*string)(nil),
"",
).Return(&domain.Tenant{
ID: personalTenantID,
Slug: "personal-01970f0d96667548963d2890351f03dd",
Name: "Personal - personal-user@example.com",
Type: domain.TenantTypePersonal,
Status: domain.TenantStatusActive,
Config: domain.JSONMap{},
}, nil).Once()
mockTenant.On("GetTenant", mock.Anything, personalTenantID).Return(&domain.Tenant{
ID: personalTenantID,
Slug: "personal-01970f0d96667548963d2890351f03dd",
Name: "Personal - personal-user@example.com",
Type: domain.TenantTypePersonal,
Status: domain.TenantStatusActive,
Config: domain.JSONMap{},
}, nil).Once()
mockTenant.On("GetTenantBySlug", mock.Anything, "personal-01970f0d96667548963d2890351f03dd").Return(&domain.Tenant{
ID: personalTenantID,
Slug: "personal-01970f0d96667548963d2890351f03dd",
Name: "Personal - personal-user@example.com",
Type: domain.TenantTypePersonal,
Status: domain.TenantStatusActive,
Config: domain.JSONMap{},
}, nil).Once()
mockOry.On("CreateUser", mock.MatchedBy(func(user *domain.BrokerUser) bool {
return user.Email == "personal-user@example.com" &&
user.Attributes["tenant_id"] == personalTenantID &&
user.Attributes["companyCode"] == "personal-01970f0d96667548963d2890351f03dd"
}), mock.Anything).Return("u-personal", nil).Once()
mockKratos.On("GetIdentity", mock.Anything, "u-personal").Return(&service.KratosIdentity{
ID: "u-personal",
Traits: map[string]interface{}{
"email": "personal-user@example.com",
"name": "Personal User",
"companyCode": "personal-01970f0d96667548963d2890351f03dd",
"tenant_id": personalTenantID,
},
State: "active",
}, nil).Once()
payload := map[string]interface{}{
"email": "personal-user@example.com",
"name": "Personal User",
}
body, _ := json.Marshal(payload)
req := httptest.NewRequest("POST", "/users", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req)
assert.Equal(t, http.StatusCreated, resp.StatusCode)
mockTenant.AssertExpectations(t)
mockOry.AssertExpectations(t)
mockKratos.AssertExpectations(t)
}
func TestUserHandler_MapToLocalUserKeepsRoleAndGradeSeparate(t *testing.T) {
handler := &UserHandler{}
identity := service.KratosIdentity{