1
0
forked from baron/baron-sso

Implement tenant import and RP auto login policies

This commit is contained in:
2026-04-30 15:45:34 +09:00
parent 24807eab0f
commit f7e4d43b16
76 changed files with 5307 additions and 441 deletions

View File

@@ -85,20 +85,21 @@ const (
)
type AuthHandler struct {
SmsService domain.SmsService
EmailService domain.EmailService
RedisService domain.RedisRepository
HeadlessJWKS *service.HeadlessJWKSCacheService
KratosAdmin service.KratosAdminService
IdpProvider domain.IdentityProvider
AuditRepo domain.AuditRepository
OathkeeperRepo domain.OathkeeperLogRepository
Hydra *service.HydraAdminService
TenantService service.TenantService
KetoService service.KetoService
KetoOutboxRepo repository.KetoOutboxRepository
UserRepo repository.UserRepository
ConsentRepo repository.ClientConsentRepository
SmsService domain.SmsService
EmailService domain.EmailService
RedisService domain.RedisRepository
HeadlessJWKS *service.HeadlessJWKSCacheService
KratosAdmin service.KratosAdminService
IdpProvider domain.IdentityProvider
AuditRepo domain.AuditRepository
OathkeeperRepo domain.OathkeeperLogRepository
Hydra *service.HydraAdminService
TenantService service.TenantService
KetoService service.KetoService
KetoOutboxRepo repository.KetoOutboxRepository
UserRepo repository.UserRepository
ConsentRepo repository.ClientConsentRepository
RPUserMetadataRepo repository.RPUserMetadataRepository
}
type signupState struct {
@@ -1157,6 +1158,120 @@ func withOidcSessionMetadata(claims map[string]any, sessionID string) map[string
return claims
}
func (h *AuthHandler) withRPProfileClaims(ctx context.Context, claims map[string]any, client domain.HydraClient, subject string) map[string]any {
if claims == nil {
claims = map[string]any{}
}
if h == nil || h.RPUserMetadataRepo == nil {
return claims
}
clientID := strings.TrimSpace(client.ClientID)
subject = strings.TrimSpace(subject)
if clientID == "" || subject == "" {
return claims
}
claimKeys := extractClaimEnabledCustomUserSchemaKeys(client.Metadata)
if len(claimKeys) == 0 {
return claims
}
row, err := h.RPUserMetadataRepo.Get(ctx, clientID, subject)
if err != nil || row == nil || len(row.Metadata) == 0 {
return claims
}
fields := make(map[string]any)
for _, key := range claimKeys {
raw, ok := row.Metadata[key]
if !ok || raw == nil {
continue
}
if value, ok := raw.(string); ok {
value = strings.TrimSpace(value)
if value == "" {
continue
}
fields[key] = value
continue
}
fields[key] = raw
}
if len(fields) == 0 {
return claims
}
profile := map[string]any{
"client_id": clientID,
"fields": fields,
}
if existing, ok := claims["rp_profiles"].([]any); ok {
claims["rp_profiles"] = append(existing, profile)
return claims
}
if existing, ok := claims["rp_profiles"].([]interface{}); ok {
claims["rp_profiles"] = append(existing, profile)
return claims
}
claims["rp_profiles"] = []any{profile}
return claims
}
func extractClaimEnabledCustomUserSchemaKeys(metadata map[string]interface{}) []string {
if metadata == nil {
return nil
}
rawSchema, ok := metadata["customUserSchema"]
if !ok || rawSchema == nil {
return nil
}
var items []interface{}
switch schema := rawSchema.(type) {
case []interface{}:
items = schema
case []map[string]interface{}:
items = make([]interface{}, 0, len(schema))
for _, item := range schema {
items = append(items, item)
}
default:
return nil
}
keys := make([]string, 0, len(items))
seen := make(map[string]struct{})
for _, item := range items {
field, ok := item.(map[string]interface{})
if !ok {
if typed, typedOK := item.(map[string]any); typedOK {
field = typed
} else {
continue
}
}
enabled, _ := field["claimEnabled"].(bool)
if !enabled {
enabled, _ = field["claim_enabled"].(bool)
}
if !enabled {
continue
}
key, _ := field["key"].(string)
key = strings.TrimSpace(key)
if key == "" {
continue
}
if _, exists := seen[key]; exists {
continue
}
seen[key] = struct{}{}
keys = append(keys, key)
}
return keys
}
func collectEmailList(traits map[string]any, primaryEmail string) []string {
emails := make([]string, 0)
seen := make(map[string]struct{})
@@ -4792,6 +4907,8 @@ type linkedRpSummary struct {
Logo string `json:"logo,omitempty"`
URL string `json:"url,omitempty"`
InitURL string `json:"init_url,omitempty"`
AutoLoginSupported bool `json:"auto_login_supported"`
AutoLoginURL string `json:"auto_login_url,omitempty"`
LastAuthenticatedAt string `json:"lastAuthenticatedAt,omitempty"`
Status string `json:"status"`
Scopes []string `json:"scopes,omitempty"`
@@ -4872,19 +4989,23 @@ func (h *AuthHandler) ListLinkedRps(c *fiber.Ctx) error {
if len(scopes) == 0 && strings.TrimSpace(client.Scope) != "" {
scopes = strings.Fields(client.Scope)
}
initURL := resolveLinkedRPInitURL(client.ClientID, scopes, client.RedirectURIs)
autoLoginSupported := resolveLinkedRPAutoLoginSupported(client.ClientID, client.Metadata)
autoLoginURL := resolveLinkedRPAutoLoginURL(client.ClientID, client.Metadata)
initURL := resolveLinkedRPInitURL(client.ClientID, client.Metadata)
existing := records[clientID]
if existing == nil {
records[clientID] = &linkedRpRecord{
linkedRpSummary: linkedRpSummary{
ID: clientID,
Name: name,
Logo: extractHydraClientLogo(client.Metadata),
URL: clientURL,
InitURL: initURL,
Status: "active", // Hydra 세션이 있으면 활성
Scopes: scopes,
ID: clientID,
Name: name,
Logo: extractHydraClientLogo(client.Metadata),
URL: clientURL,
InitURL: initURL,
AutoLoginSupported: autoLoginSupported,
AutoLoginURL: autoLoginURL,
Status: "active", // Hydra 세션이 있으면 활성
Scopes: scopes,
},
lastAuth: lastAuth,
}
@@ -4903,6 +5024,12 @@ func (h *AuthHandler) ListLinkedRps(c *fiber.Ctx) error {
if existing.InitURL == "" {
existing.InitURL = initURL
}
if !existing.AutoLoginSupported {
existing.AutoLoginSupported = autoLoginSupported
}
if existing.AutoLoginURL == "" {
existing.AutoLoginURL = autoLoginURL
}
existing.Scopes = mergeScopes(existing.Scopes, scopes)
if lastAuth.After(existing.lastAuth) {
existing.lastAuth = lastAuth
@@ -4943,11 +5070,13 @@ func (h *AuthHandler) ListLinkedRps(c *fiber.Ctx) error {
)
}
if record.InitURL == "" {
record.InitURL = resolveLinkedRPInitURL(
client.ClientID,
record.Scopes,
client.RedirectURIs,
)
record.InitURL = resolveLinkedRPInitURL(client.ClientID, client.Metadata)
}
if !record.AutoLoginSupported {
record.AutoLoginSupported = resolveLinkedRPAutoLoginSupported(client.ClientID, client.Metadata)
}
if record.AutoLoginURL == "" {
record.AutoLoginURL = resolveLinkedRPAutoLoginURL(client.ClientID, client.Metadata)
}
}
@@ -4999,21 +5128,21 @@ func (h *AuthHandler) ListLinkedRps(c *fiber.Ctx) error {
client.ClientURI,
client.RedirectURIs,
)
initURL := resolveLinkedRPInitURL(
client.ClientID,
dc.GrantedScopes,
client.RedirectURIs,
)
autoLoginSupported := resolveLinkedRPAutoLoginSupported(client.ClientID, client.Metadata)
autoLoginURL := resolveLinkedRPAutoLoginURL(client.ClientID, client.Metadata)
initURL := resolveLinkedRPInitURL(client.ClientID, client.Metadata)
records[dc.ClientID] = &linkedRpRecord{
linkedRpSummary: linkedRpSummary{
ID: dc.ClientID,
Name: name,
Logo: extractHydraClientLogo(client.Metadata),
URL: clientURL,
InitURL: initURL,
Status: status,
Scopes: dc.GrantedScopes,
ID: dc.ClientID,
Name: name,
Logo: extractHydraClientLogo(client.Metadata),
URL: clientURL,
InitURL: initURL,
AutoLoginSupported: autoLoginSupported,
AutoLoginURL: autoLoginURL,
Status: status,
Scopes: dc.GrantedScopes,
},
lastAuth: dc.UpdatedAt,
}
@@ -5087,11 +5216,9 @@ func (h *AuthHandler) ListLinkedRps(c *fiber.Ctx) error {
}
}
record.URL = clientURL
record.InitURL = resolveLinkedRPInitURL(
client.ClientID,
scopes,
client.RedirectURIs,
)
record.InitURL = resolveLinkedRPInitURL(client.ClientID, client.Metadata)
record.AutoLoginSupported = resolveLinkedRPAutoLoginSupported(client.ClientID, client.Metadata)
record.AutoLoginURL = resolveLinkedRPAutoLoginURL(client.ClientID, client.Metadata)
} else {
// Hydra 정보 없음 (삭제됨 등) -> Audit 정보나 ID로 대체
if record.Name == "" {
@@ -5239,6 +5366,7 @@ func (h *AuthHandler) GetConsentRequest(c *fiber.Ctx) error {
buildOidcClaimsFromTraits(identity.Traits, consentRequest.RequestedScope, tenantID),
currentSessionID,
)
sessionClaims = h.withRPProfileClaims(c.Context(), sessionClaims, consentRequest.Client, consentRequest.Subject)
acceptResp, err := h.Hydra.AcceptConsentRequest(c.Context(), challenge, consentRequest, sessionClaims)
if err == nil {
return c.JSON(acceptResp)
@@ -5268,6 +5396,7 @@ func (h *AuthHandler) GetConsentRequest(c *fiber.Ctx) error {
buildOidcClaimsFromTraits(identity.Traits, consentRequest.RequestedScope, tenantID),
currentSessionID,
)
sessionClaims = h.withRPProfileClaims(c.Context(), sessionClaims, consentRequest.Client, consentRequest.Subject)
// [Debug] 실제 생성된 클레임 출력 (요청사항 확인용 - 자동 승인 시)
appEnv := strings.ToLower(os.Getenv("APP_ENV"))
@@ -5450,6 +5579,7 @@ func (h *AuthHandler) AcceptConsentRequest(c *fiber.Ctx) error {
buildOidcClaimsFromTraits(identity.Traits, consentRequest.RequestedScope, tenantID),
currentSessionID,
)
sessionClaims = h.withRPProfileClaims(c.Context(), sessionClaims, consentRequest.Client, consentRequest.Subject)
// [Debug] 실제 생성된 클레임 출력 (요청사항 확인용)
appEnv := strings.ToLower(os.Getenv("APP_ENV"))
@@ -7255,6 +7385,10 @@ func resolveLinkedRPURL(clientID string, clientURI string, redirectURIs []string
if value := strings.TrimSpace(os.Getenv("DEVFRONT_URL")); value != "" {
return value
}
case "orgfront":
if value := strings.TrimSpace(os.Getenv("ORGFRONT_URL")); value != "" {
return value
}
}
clientURL := strings.TrimSpace(clientURI)
@@ -7271,10 +7405,22 @@ func resolveLinkedRPURL(clientID string, clientURI string, redirectURIs []string
return ""
}
func resolveLinkedRPInitURL(clientID string, scopes []string, redirectURIs []string) string {
func resolveLinkedRPAutoLoginSupported(clientID string, metadata map[string]interface{}) bool {
if readMetadataBoolValue(metadata, domain.MetadataAutoLoginSupported) {
return true
}
switch strings.TrimSpace(clientID) {
case "adminfront", "devfront", "orgfront":
return resolveLinkedRPAutoLoginURL(clientID, nil) != ""
default:
return false
}
}
func resolveLinkedRPAutoLoginURL(clientID string, metadata map[string]interface{}) string {
clientID = strings.TrimSpace(clientID)
if clientID == "" {
return ""
if metadataURL := readMetadataStringValue(metadata, domain.MetadataAutoLoginURL); metadataURL != "" {
return metadataURL
}
switch clientID {
@@ -7286,8 +7432,23 @@ func resolveLinkedRPInitURL(clientID string, scopes []string, redirectURIs []str
if value := strings.TrimRight(strings.TrimSpace(os.Getenv("DEVFRONT_URL")), "/"); value != "" {
return value + "/login?auto=1&returnTo=%2Fclients"
}
case "orgfront":
if value := strings.TrimRight(strings.TrimSpace(os.Getenv("ORGFRONT_URL")), "/"); value != "" {
return value + "/login?auto=1"
}
}
return ""
}
func resolveLinkedRPInitURL(clientID string, metadata map[string]interface{}) string {
if !resolveLinkedRPAutoLoginSupported(clientID, metadata) {
return ""
}
return resolveLinkedRPAutoLoginURL(clientID, metadata)
}
func buildHydraAuthorizationURL(clientID string, scopes []string, redirectURIs []string) string {
hydraPublicURL := strings.TrimRight(os.Getenv("HYDRA_PUBLIC_URL"), "/")
if hydraPublicURL == "" {
userfrontURL := strings.TrimRight(os.Getenv("USERFRONT_URL"), "/")

View File

@@ -1,6 +1,7 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes"
"encoding/json"
@@ -178,6 +179,103 @@ func TestAcceptConsentRequest_DynamicClaims(t *testing.T) {
assert.Equal(t, "Architect", capturedClaims["position"])
}
func TestAcceptConsentRequest_IncludesRPProfileClaims(t *testing.T) {
var capturedClaims map[string]interface{}
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-rp-profile" {
return httpJSONAny(r, http.StatusOK, map[string]interface{}{
"challenge": "challenge-rp-profile",
"requested_scope": []string{"openid", "profile"},
"subject": "user-123",
"client": map[string]interface{}{
"client_id": "client-app",
"metadata": map[string]interface{}{
"customUserSchema": []map[string]interface{}{
{
"key": "approvalLevel",
"label": "승인 등급",
"type": "text",
"claimEnabled": true,
},
{
"key": "internalMemo",
"label": "내부 메모",
"type": "text",
"claimEnabled": false,
},
},
},
},
}), nil
}
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-rp-profile" {
body, _ := io.ReadAll(r.Body)
var acceptReq map[string]interface{}
json.Unmarshal(body, &acceptReq)
if session, ok := acceptReq["session"].(map[string]interface{}); ok {
capturedClaims = session["id_token"].(map[string]interface{})
}
return httpJSONAny(r, http.StatusOK, map[string]interface{}{
"redirect_to": "http://rp/cb",
}), nil
}
return httpResponse(r, http.StatusNotFound, "not found"), nil
})
client := &http.Client{Transport: transport}
h := &AuthHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: client,
},
KratosAdmin: new(MockKratosAdminService),
}
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
ID: "user-123",
Traits: map[string]interface{}{
"email": "user@test.com",
"name": "Test User",
},
}, nil)
repo := new(devMockRPUserMetadataRepo)
repo.On("Get", mock.Anything, "client-app", "user-123").Return(&domain.RPUserMetadata{
ClientID: "client-app",
UserID: "user-123",
Metadata: domain.JSONMap{
"approvalLevel": "A",
"internalMemo": "관리자 전용",
},
}, nil).Once()
h.RPUserMetadataRepo = repo
app := fiber.New()
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
reqBody, _ := json.Marshal(map[string]interface{}{
"consent_challenge": "challenge-rp-profile",
"grant_scope": []string{"openid", "profile"},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(reqBody))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.NotNil(t, capturedClaims)
rpProfiles, ok := capturedClaims["rp_profiles"].([]interface{})
assert.True(t, ok)
assert.Len(t, rpProfiles, 1)
profile := rpProfiles[0].(map[string]interface{})
assert.Equal(t, "client-app", profile["client_id"])
fields := profile["fields"].(map[string]interface{})
assert.Equal(t, "A", fields["approvalLevel"])
assert.NotContains(t, fields, "internalMemo")
repo.AssertExpectations(t)
}
func TestGetConsentRequest_Skip_DynamicClaims(t *testing.T) {
var capturedClaims map[string]interface{}

View File

@@ -55,6 +55,21 @@ func TestListLinkedRps_PriorityAndAggregation(t *testing.T) {
"grant_scope": []string{"openid", "profile"},
"handled_at": time.Now().Format(time.RFC3339),
},
{
"client": map[string]interface{}{
"client_id": "orgfront",
"client_name": "OrgFront",
"metadata": map[string]interface{}{
"auto_login_supported": true,
"auto_login_url": "http://localhost:5175/login?auto=1",
},
"redirect_uris": []string{
"http://localhost:5175/auth/callback",
},
},
"grant_scope": []string{"openid", "profile"},
"handled_at": time.Now().Format(time.RFC3339),
},
}), nil
}
if r.URL.Path == "/admin/clients/client-audit" {
@@ -129,16 +144,18 @@ func TestListLinkedRps_PriorityAndAggregation(t *testing.T) {
var res struct {
Items []struct {
ID string `json:"id"`
Name string `json:"name"`
Status string `json:"status"`
Scopes []string `json:"scopes"`
InitURL string `json:"init_url"`
ID string `json:"id"`
Name string `json:"name"`
Status string `json:"status"`
Scopes []string `json:"scopes"`
InitURL string `json:"init_url"`
AutoLoginSupported bool `json:"auto_login_supported"`
AutoLoginURL string `json:"auto_login_url"`
} `json:"items"`
}
json.NewDecoder(resp.Body).Decode(&res)
assert.Equal(t, 3, len(res.Items))
assert.Equal(t, 4, len(res.Items))
statusMap := make(map[string]string)
for _, item := range res.Items {
@@ -146,6 +163,7 @@ func TestListLinkedRps_PriorityAndAggregation(t *testing.T) {
}
assert.Equal(t, "active", statusMap["devfront"])
assert.Equal(t, "active", statusMap["orgfront"])
assert.Equal(t, "inactive", statusMap["client-consent"])
assert.Equal(t, "inactive", statusMap["client-audit"])
@@ -164,6 +182,23 @@ func TestListLinkedRps_PriorityAndAggregation(t *testing.T) {
assert.Equal(t, "/login", parsedInitURL.Path)
assert.Equal(t, "1", parsedInitURL.Query().Get("auto"))
assert.Equal(t, "/clients", parsedInitURL.Query().Get("returnTo"))
var orgfrontItem struct {
InitURL string
AutoLoginSupported bool
AutoLoginURL string
}
for _, item := range res.Items {
if item.ID == "orgfront" {
orgfrontItem.InitURL = item.InitURL
orgfrontItem.AutoLoginSupported = item.AutoLoginSupported
orgfrontItem.AutoLoginURL = item.AutoLoginURL
break
}
}
assert.True(t, orgfrontItem.AutoLoginSupported)
assert.Equal(t, "http://localhost:5175/login?auto=1", orgfrontItem.AutoLoginURL)
assert.Equal(t, orgfrontItem.AutoLoginURL, orgfrontItem.InitURL)
}
func TestListLinkedRps_EnrichesLogoFromHydraClientWhenConsentSessionOmitsMetadata(t *testing.T) {

View File

@@ -14,6 +14,7 @@ import (
"io"
"log/slog"
"net/http"
"net/url"
"os"
"strconv"
"strings"
@@ -24,19 +25,20 @@ import (
)
type DevHandler struct {
Hydra *service.HydraAdminService
Redis domain.RedisRepository
HeadlessJWKS *service.HeadlessJWKSCacheService
SecretRepo domain.ClientSecretRepository
AuditRepo domain.AuditRepository
KratosAdmin service.KratosAdminService
ConsentRepo repository.ClientConsentRepository
Keto service.KetoService
KetoOutbox repository.KetoOutboxRepository
RPSvc service.RelyingPartyService
TenantSvc service.TenantService
DeveloperSvc *service.DeveloperService
Auth interface {
Hydra *service.HydraAdminService
Redis domain.RedisRepository
HeadlessJWKS *service.HeadlessJWKSCacheService
SecretRepo domain.ClientSecretRepository
AuditRepo domain.AuditRepository
KratosAdmin service.KratosAdminService
ConsentRepo repository.ClientConsentRepository
Keto service.KetoService
KetoOutbox repository.KetoOutboxRepository
RPSvc service.RelyingPartyService
TenantSvc service.TenantService
DeveloperSvc *service.DeveloperService
RPUserMetadataRepo repository.RPUserMetadataRepository
Auth interface {
GetEnrichedProfile(c *fiber.Ctx) (*domain.UserProfileResponse, error)
}
}
@@ -1377,6 +1379,86 @@ func (h *DevHandler) publicHeadlessJWKSCacheState(clientID string) (*domain.Head
return h.HeadlessJWKS.PublicState(clientID)
}
func (h *DevHandler) GetRPUserMetadata(c *fiber.Ctx) error {
clientID := strings.TrimSpace(c.Params("id"))
userID := strings.TrimSpace(c.Params("userId"))
if clientID == "" || userID == "" {
return errorJSON(c, fiber.StatusBadRequest, "client id and user id are required")
}
if h.RPUserMetadataRepo == nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "rp user metadata repository unavailable")
}
profile := h.getCurrentProfile(c)
if profile == nil {
return errorJSON(c, fiber.StatusUnauthorized, "unauthorized: authentication required")
}
summary, err := h.loadClientSummary(c.Context(), clientID)
if err != nil {
return errorJSON(c, fiber.StatusNotFound, "client not found")
}
if !h.canViewClientByPermit(c, profile, summary) {
return errorJSON(c, fiber.StatusForbidden, "forbidden: insufficient permission to view client metadata")
}
metadata, err := h.RPUserMetadataRepo.Get(c.Context(), clientID, userID)
if err != nil {
return c.JSON(fiber.Map{
"clientId": clientID,
"userId": userID,
"metadata": domain.JSONMap{},
})
}
return c.JSON(metadata)
}
func (h *DevHandler) UpsertRPUserMetadata(c *fiber.Ctx) error {
clientID := strings.TrimSpace(c.Params("id"))
userID := strings.TrimSpace(c.Params("userId"))
if clientID == "" || userID == "" {
return errorJSON(c, fiber.StatusBadRequest, "client id and user id are required")
}
if h.RPUserMetadataRepo == nil {
return errorJSON(c, fiber.StatusServiceUnavailable, "rp user metadata repository unavailable")
}
profile := h.getCurrentProfile(c)
if profile == nil {
return errorJSON(c, fiber.StatusUnauthorized, "unauthorized: authentication required")
}
summary, err := h.loadClientSummary(c.Context(), clientID)
if err != nil {
return errorJSON(c, fiber.StatusNotFound, "client not found")
}
if !h.canManageClientRelations(c, profile, summary) {
return errorJSON(c, fiber.StatusForbidden, "forbidden: insufficient permission to update client metadata")
}
var req struct {
Metadata map[string]any `json:"metadata"`
}
if err := c.BodyParser(&req); err != nil {
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
}
if req.Metadata == nil {
req.Metadata = map[string]any{}
}
row := &domain.RPUserMetadata{
ClientID: clientID,
UserID: userID,
Metadata: domain.JSONMap(req.Metadata),
}
if err := h.RPUserMetadataRepo.Upsert(c.Context(), row); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
return c.JSON(row)
}
func (h *DevHandler) syncHeadlessJWKSCache(ctx context.Context, client domain.HydraClient, reason string) {
if h.HeadlessJWKS == nil {
h.HeadlessJWKS = service.NewHeadlessJWKSCacheService(h.Redis, nil)
@@ -1574,6 +1656,10 @@ func (h *DevHandler) CreateClient(c *fiber.Ctx) error {
if err != nil {
return errorJSON(c, fiber.StatusBadRequest, err.Error())
}
metadata, err = normalizeClientAutoLoginMetadata(metadata)
if err != nil {
return errorJSON(c, fiber.StatusBadRequest, err.Error())
}
tokenAuthMethod := strings.TrimSpace(valueOr(req.TokenEndpointAuthMethod, ""))
if tokenAuthMethod == "" {
@@ -1766,6 +1852,10 @@ func (h *DevHandler) UpdateClient(c *fiber.Ctx) error {
if err != nil {
return errorJSON(c, fiber.StatusBadRequest, err.Error())
}
metadata, err = normalizeClientAutoLoginMetadata(metadata)
if err != nil {
return errorJSON(c, fiber.StatusBadRequest, err.Error())
}
resolvedClientType := currentSummary.Type
if clientType != "" {
resolvedClientType = clientType
@@ -2575,6 +2665,30 @@ func readMetadataBoolValue(metadata map[string]interface{}, key string) bool {
return value
}
func normalizeClientAutoLoginMetadata(metadata map[string]interface{}) (map[string]interface{}, error) {
if metadata == nil {
return metadata, nil
}
supported := readMetadataBoolValue(metadata, domain.MetadataAutoLoginSupported)
rawURL := strings.TrimSpace(readMetadataStringValue(metadata, domain.MetadataAutoLoginURL))
metadata[domain.MetadataAutoLoginSupported] = supported
if !supported {
delete(metadata, domain.MetadataAutoLoginURL)
return metadata, nil
}
if rawURL == "" {
return nil, errors.New("auto_login_url is required when auto_login_supported is true")
}
parsed, err := url.Parse(rawURL)
if err != nil || parsed.Scheme == "" || parsed.Host == "" || (parsed.Scheme != "https" && parsed.Scheme != "http") {
return nil, errors.New("auto_login_url must be an http or https URL")
}
metadata[domain.MetadataAutoLoginURL] = rawURL
return metadata, nil
}
func normalizeHeadlessClientConfig(
clientType string,
tokenAuthMethod string,

View File

@@ -0,0 +1,94 @@
package handler
import (
"baron-sso-backend/internal/domain"
"baron-sso-backend/internal/service"
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
type devMockRPUserMetadataRepo struct {
mock.Mock
}
func (m *devMockRPUserMetadataRepo) Get(ctx context.Context, clientID, userID string) (*domain.RPUserMetadata, error) {
args := m.Called(ctx, clientID, userID)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*domain.RPUserMetadata), args.Error(1)
}
func (m *devMockRPUserMetadataRepo) Upsert(ctx context.Context, metadata *domain.RPUserMetadata) error {
args := m.Called(ctx, metadata)
return args.Error(0)
}
func TestDevHandler_RPUserMetadataRoundTrip(t *testing.T) {
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path == "/clients/client-1" {
return httpJSONAny(r, http.StatusOK, map[string]interface{}{
"client_id": "client-1",
"client_name": "Client One",
"metadata": map[string]interface{}{
"tenant_id": "tenant-1",
},
}), nil
}
return httpJSONAny(r, http.StatusNotFound, nil), nil
})
repo := new(devMockRPUserMetadataRepo)
repo.On("Upsert", mock.Anything, mock.MatchedBy(func(row *domain.RPUserMetadata) bool {
return row.ClientID == "client-1" &&
row.UserID == "user-1" &&
row.Metadata["approvalLevel"] == "A"
})).Return(nil).Once()
repo.On("Get", mock.Anything, "client-1", "user-1").Return(&domain.RPUserMetadata{
ClientID: "client-1",
UserID: "user-1",
Metadata: domain.JSONMap{"approvalLevel": "A"},
}, nil).Once()
h := &DevHandler{
Hydra: &service.HydraAdminService{
AdminURL: "http://hydra.test",
HTTPClient: &http.Client{Transport: transport},
},
RPUserMetadataRepo: repo,
}
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{ID: "admin", Role: domain.RoleSuperAdmin})
return c.Next()
})
app.Put("/api/v1/dev/clients/:id/users/:userId/metadata", h.UpsertRPUserMetadata)
app.Get("/api/v1/dev/clients/:id/users/:userId/metadata", h.GetRPUserMetadata)
body, _ := json.Marshal(map[string]any{
"metadata": map[string]any{"approvalLevel": "A"},
})
putReq := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-1/users/user-1/metadata", bytes.NewReader(body))
putReq.Header.Set("Content-Type", "application/json")
putResp, _ := app.Test(putReq, -1)
assert.Equal(t, http.StatusOK, putResp.StatusCode)
getReq := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-1/users/user-1/metadata", nil)
getResp, _ := app.Test(getReq, -1)
assert.Equal(t, http.StatusOK, getResp.StatusCode)
var got map[string]any
assert.NoError(t, json.NewDecoder(getResp.Body).Decode(&got))
assert.Equal(t, "client-1", got["clientId"])
assert.Equal(t, "user-1", got["userId"])
assert.Equal(t, "A", got["metadata"].(map[string]any)["approvalLevel"])
repo.AssertExpectations(t)
}

View File

@@ -1300,6 +1300,36 @@ func TestCreateClient_DefaultsSkipConsentToTrue(t *testing.T) {
assert.True(t, *captured.SkipConsent)
}
func TestNormalizeClientAutoLoginMetadata(t *testing.T) {
t.Run("keeps supported flag and URL", func(t *testing.T) {
metadata, err := normalizeClientAutoLoginMetadata(map[string]interface{}{
"auto_login_supported": true,
"auto_login_url": "https://rp.example.com/login?auto=1",
})
assert.NoError(t, err)
assert.Equal(t, true, metadata["auto_login_supported"])
assert.Equal(t, "https://rp.example.com/login?auto=1", metadata["auto_login_url"])
})
t.Run("requires URL when supported", func(t *testing.T) {
_, err := normalizeClientAutoLoginMetadata(map[string]interface{}{
"auto_login_supported": true,
})
assert.Error(t, err)
})
t.Run("removes URL when unsupported", func(t *testing.T) {
metadata, err := normalizeClientAutoLoginMetadata(map[string]interface{}{
"auto_login_supported": false,
"auto_login_url": "https://rp.example.com/login?auto=1",
})
assert.NoError(t, err)
assert.Equal(t, false, metadata["auto_login_supported"])
_, exists := metadata["auto_login_url"]
assert.False(t, exists)
})
}
func TestCreateClient_AllowsExplicitSkipConsentFalse(t *testing.T) {
var captured domain.HydraClient

View File

@@ -0,0 +1,244 @@
package handler
import (
"baron-sso-backend/internal/domain"
"context"
"fmt"
"strings"
)
const hanmacFamilyTenantSlug = "hanmac-family"
type hanmacEmailScope struct {
TenantIDs map[string]bool
Slugs map[string]bool
IDList []string
SlugList []string
}
type hanmacEmailEvaluation struct {
Email string
OriginalEmail string
SuggestedEmail string
Status string
Warnings []string
Message string
Blocking bool
LocalPart string
}
func (h *UserHandler) evaluateHanmacImportEmail(ctx context.Context, item bulkUserItem, scope *hanmacEmailScope, usedLocalParts map[string]bool) hanmacEmailEvaluation {
originalEmail := strings.TrimSpace(item.Email)
name := strings.TrimSpace(item.Name)
evaluation := hanmacEmailEvaluation{
Email: originalEmail,
OriginalEmail: originalEmail,
Status: "valid",
}
localPart, domainPart, err := domain.SplitEmailDomain(originalEmail)
if err != nil {
evaluation.Status = "blockingError"
evaluation.Message = "invalid email format"
evaluation.Blocking = true
return evaluation
}
base, needsReview, _ := domain.BuildKoreanNameEmailBase(name)
if needsReview {
evaluation.Warnings = append(evaluation.Warnings, "needsReview")
evaluation.Status = "needsReview"
}
if localPart == "" {
if base == "" {
evaluation.Status = "blockingError"
evaluation.Message = "이름으로 이메일 ID를 제안할 수 없습니다."
evaluation.Blocking = true
return evaluation
}
nextLocalPart := nextAvailableHanmacLocalPart(base, usedLocalParts)
evaluation.Email = nextLocalPart + "@" + domainPart
evaluation.SuggestedEmail = evaluation.Email
evaluation.LocalPart = nextLocalPart
evaluation.Status = "suggested"
evaluation.Warnings = appendUniqueString(evaluation.Warnings, "suggested")
return evaluation
}
evaluation.LocalPart = localPart
if usedLocalParts[localPart] {
evaluation.Status = "blockingError"
evaluation.Message = "한맥가족 내에서 이미 사용 중인 이메일 ID입니다."
evaluation.Blocking = true
return evaluation
}
if base != "" && !domain.MatchesSuggestedNameRule(localPart, base) {
evaluation.Status = "ruleMismatch"
evaluation.Warnings = appendUniqueString(evaluation.Warnings, "ruleMismatch")
}
if evaluation.Status == "needsReview" && len(evaluation.Warnings) == 0 {
evaluation.Warnings = append(evaluation.Warnings, "needsReview")
}
_ = scope
return evaluation
}
func (h *UserHandler) ensureHanmacCreateEmailAllowed(ctx context.Context, email string, tenantSlug string, tenantID string) error {
scope, err := h.resolveHanmacEmailScope(ctx)
if err != nil || scope == nil || !scope.ContainsTenant(tenantID, tenantSlug) {
return nil
}
localPart, err := domain.ExtractNormalizedEmailLocalPart(email)
if err != nil {
return err
}
usedLocalParts, err := h.loadHanmacLocalParts(ctx, scope)
if err != nil {
return err
}
if usedLocalParts[localPart] {
return fmt.Errorf("한맥가족 내에서 이미 사용 중인 이메일 ID입니다.")
}
return nil
}
func (h *UserHandler) resolveHanmacEmailScope(ctx context.Context) (*hanmacEmailScope, error) {
if h.TenantService == nil {
return nil, nil
}
tenants, _, err := h.TenantService.ListTenants(ctx, 10000, 0, "")
if err != nil {
return nil, err
}
var rootID string
for _, tenant := range tenants {
if strings.EqualFold(strings.TrimSpace(tenant.Slug), hanmacFamilyTenantSlug) {
rootID = tenant.ID
break
}
}
if rootID == "" {
return nil, nil
}
tenantByID := make(map[string]domain.Tenant, len(tenants))
for _, tenant := range tenants {
tenantByID[tenant.ID] = tenant
}
scope := &hanmacEmailScope{
TenantIDs: make(map[string]bool),
Slugs: make(map[string]bool),
}
for _, tenant := range tenants {
if isTenantDescendantOf(tenant, rootID, tenantByID) {
scope.TenantIDs[tenant.ID] = true
scope.Slugs[strings.ToLower(strings.TrimSpace(tenant.Slug))] = true
scope.IDList = append(scope.IDList, tenant.ID)
scope.SlugList = append(scope.SlugList, tenant.Slug)
}
}
return scope, nil
}
func (h *UserHandler) loadHanmacLocalParts(ctx context.Context, scope *hanmacEmailScope) (map[string]bool, error) {
used := make(map[string]bool)
if h.UserRepo == nil || scope == nil {
return used, nil
}
if len(scope.IDList) > 0 {
users, err := h.UserRepo.FindByTenantIDs(ctx, scope.IDList)
if err != nil {
return nil, err
}
addUserEmailLocalParts(used, users)
}
if len(scope.SlugList) > 0 {
users, err := h.UserRepo.FindByCompanyCodes(ctx, scope.SlugList)
if err != nil {
return nil, err
}
addUserEmailLocalParts(used, users)
}
return used, nil
}
func (s *hanmacEmailScope) ContainsTenant(tenantID string, slug string) bool {
if s == nil {
return false
}
if tenantID != "" && s.TenantIDs[tenantID] {
return true
}
return s.Slugs[strings.ToLower(strings.TrimSpace(slug))]
}
func isTenantDescendantOf(tenant domain.Tenant, rootID string, tenantByID map[string]domain.Tenant) bool {
if tenant.ID == rootID {
return true
}
visited := make(map[string]bool)
parentID := ""
if tenant.ParentID != nil {
parentID = *tenant.ParentID
}
for parentID != "" {
if parentID == rootID {
return true
}
if visited[parentID] {
return false
}
visited[parentID] = true
parent, ok := tenantByID[parentID]
if !ok || parent.ParentID == nil {
return false
}
parentID = *parent.ParentID
}
return false
}
func addUserEmailLocalParts(target map[string]bool, users []domain.User) {
for _, user := range users {
localPart, err := domain.ExtractNormalizedEmailLocalPart(user.Email)
if err == nil && localPart != "" {
target[localPart] = true
}
}
}
func nextAvailableHanmacLocalPart(base string, usedLocalParts map[string]bool) string {
base = strings.ToLower(strings.TrimSpace(base))
if base == "" {
return ""
}
if !usedLocalParts[base] {
return base
}
for index := 1; ; index++ {
candidate := fmt.Sprintf("%s%d", base, index)
if !usedLocalParts[candidate] {
return candidate
}
}
}
func appendUniqueString(values []string, value string) []string {
for _, existing := range values {
if existing == value {
return values
}
}
return append(values, value)
}

View File

@@ -6,6 +6,7 @@ import (
"baron-sso-backend/internal/service"
"baron-sso-backend/internal/utils"
"bytes"
"context"
"encoding/csv"
"errors"
"fmt"
@@ -68,14 +69,22 @@ type tenantImportResult struct {
Errors []string `json:"errors"`
}
type tenantDomainConflict struct {
Domain string `json:"domain"`
TenantID string `json:"tenantId"`
TenantName string `json:"tenantName"`
TenantSlug string `json:"tenantSlug"`
}
type tenantCSVRecord struct {
TenantID string
Name string
Type string
ParentTenantID *string
Slug string
Memo string
Domains []string
TenantID string
Name string
Type string
ParentTenantID *string
ParentTenantSlug string
Slug string
Memo string
Domains []string
}
func (h *TenantHandler) RegisterTenantPublic(c *fiber.Ctx) error {
@@ -258,13 +267,24 @@ func (h *TenantHandler) ExportTenantsCSV(c *fiber.Ctx) error {
var buf bytes.Buffer
writer := csv.NewWriter(&buf)
if err := writer.Write([]string{"tenant_id", "name", "type", "parent_tenant_id", "slug", "memo", "email_domain"}); err != nil {
includeIDs := includeCSVIds(c)
if includeIDs {
if err := writer.Write([]string{"tenant_id", "name", "type", "parent_tenant_id", "parent_tenant_slug", "slug", "memo", "email_domain"}); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
} else if err := writer.Write([]string{"name", "type", "parent_tenant_slug", "slug", "memo", "email_domain"}); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
slugByID := make(map[string]string, len(tenants))
for _, tenant := range tenants {
slugByID[tenant.ID] = tenant.Slug
}
for _, tenant := range tenants {
parentID := ""
parentSlug := ""
if tenant.ParentID != nil {
parentID = *tenant.ParentID
parentSlug = slugByID[parentID]
}
domains := make([]string, 0, len(tenant.Domains))
for _, domainName := range tenant.Domains {
@@ -273,15 +293,27 @@ func (h *TenantHandler) ExportTenantsCSV(c *fiber.Ctx) error {
domains = append(domains, domainName)
}
}
if err := writer.Write([]string{
tenant.ID,
row := []string{
tenant.Name,
tenant.Type,
parentID,
parentSlug,
tenant.Slug,
tenant.Description,
strings.Join(domains, ";"),
}); err != nil {
}
if includeIDs {
row = []string{
tenant.ID,
tenant.Name,
tenant.Type,
parentID,
parentSlug,
tenant.Slug,
tenant.Description,
strings.Join(domains, ";"),
}
}
if err := writer.Write(row); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
}
@@ -305,33 +337,60 @@ func (h *TenantHandler) ImportTenantsCSV(c *fiber.Ctx) error {
if err != nil {
return errorJSON(c, fiber.StatusBadRequest, err.Error())
}
records = orderTenantCSVRecordsByParentSlug(records)
creatorID := ""
if profile, ok := c.Locals("user_profile").(*domain.UserProfileResponse); ok && profile != nil {
creatorID = profile.ID
}
tenantIDBySlug := make(map[string]string)
if h.Service != nil {
if tenants, _, err := h.Service.ListTenants(c.Context(), 10000, 0, ""); err == nil {
for _, tenant := range tenants {
tenantIDBySlug[strings.ToLower(tenant.Slug)] = tenant.ID
}
}
}
result := tenantImportResult{Errors: make([]string, 0)}
for i, record := range records {
rowNumber := i + 2
if record.ParentTenantID == nil && record.ParentTenantSlug != "" {
parentID := tenantIDBySlug[strings.ToLower(record.ParentTenantSlug)]
if parentID == "" {
result.Failed++
result.Errors = append(result.Errors, fmt.Sprintf("row %d: parent tenant slug not found: %s", rowNumber, record.ParentTenantSlug))
continue
}
record.ParentTenantID = &parentID
}
if record.TenantID != "" || (h.DB != nil && record.Slug != "") {
updated, err := h.upsertTenantCSVRecord(c, record)
tenant, updated, err := h.upsertTenantCSVRecord(c, record)
if err != nil {
result.Failed++
result.Errors = append(result.Errors, fmt.Sprintf("row %d: %s", rowNumber, err.Error()))
continue
}
if updated {
tenantIDBySlug[strings.ToLower(record.Slug)] = tenant.ID
result.Updated++
continue
}
}
if err := h.createTenantCSVRecord(c, record, creatorID); err != nil {
tenant, err := h.createTenantCSVRecord(c, record, creatorID)
if err != nil {
result.Failed++
result.Errors = append(result.Errors, fmt.Sprintf("row %d: %s", rowNumber, err.Error()))
continue
}
if tenant == nil {
result.Failed++
result.Errors = append(result.Errors, fmt.Sprintf("row %d: tenant creation returned empty result", rowNumber))
continue
}
tenantIDBySlug[strings.ToLower(record.Slug)] = tenant.ID
result.Created++
}
@@ -414,13 +473,14 @@ func parseTenantCSVRecords(r io.Reader) ([]tenantCSVRecord, error) {
}
records = append(records, tenantCSVRecord{
TenantID: tenantCSVValue(row, header, "tenant_id"),
Name: name,
Type: tenantType,
ParentTenantID: parentID,
Slug: slug,
Memo: tenantCSVValue(row, header, "memo"),
Domains: splitTenantCSVDomains(tenantCSVValue(row, header, "email_domain")),
TenantID: tenantCSVValue(row, header, "tenant_id"),
Name: name,
Type: tenantType,
ParentTenantID: parentID,
ParentTenantSlug: tenantCSVValue(row, header, "parent_tenant_slug"),
Slug: slug,
Memo: tenantCSVValue(row, header, "memo"),
Domains: splitTenantCSVDomains(tenantCSVValue(row, header, "email_domain")),
})
}
@@ -430,23 +490,25 @@ func parseTenantCSVRecords(r io.Reader) ([]tenantCSVRecord, error) {
func tenantCSVHeaderIndex(header []string) map[string]int {
index := make(map[string]int, len(header))
aliases := map[string]string{
"id": "tenant_id",
"tenantid": "tenant_id",
"tenant_id": "tenant_id",
"name": "name",
"type": "type",
"parentid": "parent_tenant_id",
"parent_id": "parent_tenant_id",
"parenttenantid": "parent_tenant_id",
"parent_tenant_id": "parent_tenant_id",
"slug": "slug",
"memo": "memo",
"description": "memo",
"email-domain": "email_domain",
"emaildomain": "email_domain",
"email_domain": "email_domain",
"domain": "email_domain",
"domains": "email_domain",
"id": "tenant_id",
"tenantid": "tenant_id",
"tenant_id": "tenant_id",
"name": "name",
"type": "type",
"parentid": "parent_tenant_id",
"parent_id": "parent_tenant_id",
"parenttenantid": "parent_tenant_id",
"parent_tenant_id": "parent_tenant_id",
"parenttenantslug": "parent_tenant_slug",
"parent_tenant_slug": "parent_tenant_slug",
"slug": "slug",
"memo": "memo",
"description": "memo",
"email-domain": "email_domain",
"emaildomain": "email_domain",
"email_domain": "email_domain",
"domain": "email_domain",
"domains": "email_domain",
}
for i, column := range header {
key := strings.ToLower(strings.TrimSpace(column))
@@ -475,6 +537,40 @@ func tenantCSVRowIsEmpty(row []string) bool {
return true
}
func includeCSVIds(c *fiber.Ctx) bool {
value := strings.ToLower(strings.TrimSpace(c.Query("includeIds")))
return value == "true" || value == "1" || value == "yes"
}
func orderTenantCSVRecordsByParentSlug(records []tenantCSVRecord) []tenantCSVRecord {
bySlug := make(map[string]tenantCSVRecord, len(records))
for _, record := range records {
bySlug[strings.ToLower(record.Slug)] = record
}
ordered := make([]tenantCSVRecord, 0, len(records))
visited := make(map[string]bool, len(records))
var visit func(record tenantCSVRecord)
visit = func(record tenantCSVRecord) {
key := strings.ToLower(record.Slug)
if visited[key] {
return
}
if record.ParentTenantSlug != "" {
if parent, ok := bySlug[strings.ToLower(record.ParentTenantSlug)]; ok {
visit(parent)
}
}
visited[key] = true
ordered = append(ordered, record)
}
for _, record := range records {
visit(record)
}
return ordered
}
func splitTenantCSVDomains(value string) []string {
value = strings.ReplaceAll(value, "\n", ";")
value = strings.ReplaceAll(value, ",", ";")
@@ -492,12 +588,203 @@ func splitTenantCSVDomains(value string) []string {
return domains
}
func (h *TenantHandler) upsertTenantCSVRecord(c *fiber.Ctx, record tenantCSVRecord) (bool, error) {
func normalizeTenantDomainInputs(values []string) []string {
seen := make(map[string]bool, len(values))
domains := make([]string, 0, len(values))
for _, value := range values {
for _, part := range strings.FieldsFunc(value, func(r rune) bool {
return r == ',' || r == ';' || r == '\n' || r == '\r' || r == '\t' || r == ' '
}) {
domainName := strings.ToLower(strings.TrimSpace(part))
if domainName == "" || seen[domainName] {
continue
}
seen[domainName] = true
domains = append(domains, domainName)
}
}
return domains
}
func normalizeTenantConfig(config map[string]any) (domain.JSONMap, error) {
normalized := make(domain.JSONMap, len(config))
for key, value := range config {
if key == "userSchema" {
fields, err := normalizeTenantUserSchema(value)
if err != nil {
return nil, err
}
normalized[key] = fields
continue
}
normalized[key] = value
}
return normalized, nil
}
func normalizeTenantUserSchema(value any) ([]any, error) {
if value == nil {
return nil, nil
}
rawFields, ok := value.([]any)
if !ok {
return nil, fmt.Errorf("userSchema must be an array")
}
fields := make([]any, 0, len(rawFields))
for _, raw := range rawFields {
field, ok := raw.(map[string]any)
if !ok {
return nil, fmt.Errorf("userSchema fields must be objects")
}
normalized := make(map[string]any, len(field))
for key, value := range field {
if key == "maxLength" {
continue
}
normalized[key] = value
}
isLoginID, _ := normalized["isLoginId"].(bool)
if isLoginID {
fieldType, _ := normalized["type"].(string)
if fieldType != "" && fieldType != "text" {
return nil, fmt.Errorf("login ID fields must be text")
}
normalized["type"] = "text"
normalized["indexed"] = true
} else if indexed, ok := normalized["indexed"].(bool); !ok || !indexed {
normalized["indexed"] = false
}
fields = append(fields, normalized)
}
return fields, nil
}
func normalizeTenantDomainForceSet(values []string) map[string]bool {
domains := normalizeTenantDomainInputs(values)
force := make(map[string]bool, len(domains))
for _, domainName := range domains {
force[domainName] = true
}
return force
}
func tenantDomainConflictJSON(c *fiber.Ctx, conflicts []tenantDomainConflict) error {
return c.Status(fiber.StatusConflict).JSON(fiber.Map{
"code": "tenant_domain_conflict",
"error": "domain is already assigned to another tenant",
"conflicts": conflicts,
})
}
func (h *TenantHandler) findTenantDomainConflicts(ctx context.Context, tenantID string, domains []string, forceDomains []string) ([]tenantDomainConflict, error) {
if h.DB == nil || h.DB.Config == nil || len(domains) == 0 {
return nil, nil
}
force := normalizeTenantDomainForceSet(forceDomains)
var rows []domain.TenantDomain
query := h.DB.WithContext(ctx).Where("domain IN ?", domains)
if tenantID != "" {
query = query.Where("tenant_id <> ?", tenantID)
}
if err := query.Find(&rows).Error; err != nil {
return nil, err
}
conflicts := make([]tenantDomainConflict, 0, len(rows))
tenantIDs := make([]string, 0, len(rows))
seenTenantIDs := make(map[string]bool, len(rows))
for _, row := range rows {
if force[row.Domain] {
continue
}
if !seenTenantIDs[row.TenantID] {
seenTenantIDs[row.TenantID] = true
tenantIDs = append(tenantIDs, row.TenantID)
}
}
tenantsByID := make(map[string]domain.Tenant, len(tenantIDs))
if len(tenantIDs) > 0 {
var tenants []domain.Tenant
if err := h.DB.WithContext(ctx).Where("id IN ?", tenantIDs).Find(&tenants).Error; err != nil {
return nil, err
}
for _, tenant := range tenants {
tenantsByID[tenant.ID] = tenant
}
}
for _, row := range rows {
if force[row.Domain] {
continue
}
conflict := tenantDomainConflict{
Domain: row.Domain,
TenantID: row.TenantID,
}
if tenant, ok := tenantsByID[row.TenantID]; ok {
conflict.TenantName = tenant.Name
conflict.TenantSlug = tenant.Slug
}
conflicts = append(conflicts, conflict)
}
return conflicts, nil
}
func (h *TenantHandler) replaceTenantDomains(ctx context.Context, tenantID string, domains []string, forceDomains []string) error {
if h.DB == nil {
return errors.New("database not available")
}
if h.DB.Config == nil {
return nil
}
deleteQuery := h.DB.WithContext(ctx).Where("tenant_id = ?", tenantID)
if len(domains) > 0 {
deleteQuery = deleteQuery.Where("domain NOT IN ?", domains)
}
if err := deleteQuery.Delete(&domain.TenantDomain{}).Error; err != nil {
return fmt.Errorf("failed to clear old domains: %w", err)
}
for _, domainName := range domains {
var existing domain.TenantDomain
err := h.DB.WithContext(ctx).Unscoped().
Where("tenant_id = ? AND domain = ?", tenantID, domainName).
First(&existing).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
if err := repository.NewTenantRepository(h.DB).AddDomain(ctx, tenantID, domainName, true); err != nil {
return fmt.Errorf("failed to add domain: %s", domainName)
}
continue
}
if err != nil {
return err
}
if err := h.DB.WithContext(ctx).Unscoped().Model(&existing).Updates(map[string]any{
"verified": true,
"deleted_at": nil,
}).Error; err != nil {
return fmt.Errorf("failed to add domain: %s", domainName)
}
}
return nil
}
func (h *TenantHandler) upsertTenantCSVRecord(c *fiber.Ctx, record tenantCSVRecord) (*domain.Tenant, bool, error) {
if h.DB == nil {
if record.TenantID != "" {
return false, errors.New("database not available for tenant update")
return nil, false, errors.New("database not available for tenant update")
}
return false, nil
return nil, false, nil
}
var tenant domain.Tenant
@@ -510,10 +797,10 @@ func (h *TenantHandler) upsertTenantCSVRecord(c *fiber.Ctx, record tenantCSVReco
}
if errors.Is(err, gorm.ErrRecordNotFound) {
return false, nil
return nil, false, nil
}
if err != nil {
return false, err
return nil, false, err
}
tenant.Name = record.Name
@@ -526,29 +813,29 @@ func (h *TenantHandler) upsertTenantCSVRecord(c *fiber.Ctx, record tenantCSVReco
}
if err := h.DB.Save(&tenant).Error; err != nil {
return false, err
return nil, false, err
}
if err := h.DB.Delete(&domain.TenantDomain{}, "tenant_id = ?", tenant.ID).Error; err != nil {
return false, err
return nil, false, err
}
repo := repository.NewTenantRepository(h.DB)
for _, domainName := range record.Domains {
if err := repo.AddDomain(c.Context(), tenant.ID, domainName, true); err != nil {
return false, err
return nil, false, err
}
}
return true, nil
return &tenant, true, nil
}
func (h *TenantHandler) createTenantCSVRecord(c *fiber.Ctx, record tenantCSVRecord, creatorID string) error {
func (h *TenantHandler) createTenantCSVRecord(c *fiber.Ctx, record tenantCSVRecord, creatorID string) (*domain.Tenant, error) {
if h.DB != nil && record.TenantID != "" {
var exists int64
if err := h.DB.Unscoped().Model(&domain.Tenant{}).Where("slug = ?", record.Slug).Count(&exists).Error; err != nil {
return err
return nil, err
}
if exists > 0 {
return errors.New("tenant slug already exists")
return nil, errors.New("tenant slug already exists")
}
tenant := domain.Tenant{
@@ -561,7 +848,7 @@ func (h *TenantHandler) createTenantCSVRecord(c *fiber.Ctx, record tenantCSVReco
Status: domain.TenantStatusActive,
}
if err := h.DB.Create(&tenant).Error; err != nil {
return err
return nil, err
}
if h.KetoOutbox != nil {
_ = h.KetoOutbox.Create(c.Context(), &domain.KetoOutbox{
@@ -595,14 +882,14 @@ func (h *TenantHandler) createTenantCSVRecord(c *fiber.Ctx, record tenantCSVReco
repo := repository.NewTenantRepository(h.DB)
for _, domainName := range record.Domains {
if err := repo.AddDomain(c.Context(), tenant.ID, domainName, true); err != nil {
return err
return nil, err
}
}
return nil
return &tenant, nil
}
_, err := h.Service.RegisterTenant(c.Context(), record.Name, record.Slug, record.Type, record.Memo, record.Domains, record.ParentTenantID, creatorID)
return err
tenant, err := h.Service.RegisterTenant(c.Context(), record.Name, record.Slug, record.Type, record.Memo, record.Domains, record.ParentTenantID, creatorID)
return tenant, err
}
func (h *TenantHandler) GetTenant(c *fiber.Ctx) error {
@@ -646,14 +933,15 @@ func (h *TenantHandler) CreateTenant(c *fiber.Ctx) error {
}
var req struct {
Name string `json:"name"`
Slug string `json:"slug"`
Type string `json:"type"`
Description string `json:"description"`
Status string `json:"status"`
Domains []string `json:"domains"`
ParentID *string `json:"parentId"`
Config map[string]any `json:"config"`
Name string `json:"name"`
Slug string `json:"slug"`
Type string `json:"type"`
Description string `json:"description"`
Status string `json:"status"`
Domains []string `json:"domains"`
ForceDomains []string `json:"forceDomainConflicts"`
ParentID *string `json:"parentId"`
Config map[string]any `json:"config"`
}
if err := c.BodyParser(&req); err != nil {
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
@@ -701,7 +989,16 @@ func (h *TenantHandler) CreateTenant(c *fiber.Ctx) error {
creatorID = profile.ID
}
tenant, err := h.Service.RegisterTenant(c.Context(), name, slug, tenantType, req.Description, req.Domains, parentID, creatorID)
normalizedDomains := normalizeTenantDomainInputs(req.Domains)
conflicts, err := h.findTenantDomainConflicts(c.Context(), "", normalizedDomains, req.ForceDomains)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
if len(conflicts) > 0 {
return tenantDomainConflictJSON(c, conflicts)
}
tenant, err := h.Service.RegisterTenant(c.Context(), name, slug, tenantType, req.Description, nil, parentID, creatorID)
if err != nil {
if strings.Contains(err.Error(), "already exists") {
return errorJSON(c, fiber.StatusConflict, err.Error())
@@ -713,10 +1010,20 @@ func (h *TenantHandler) CreateTenant(c *fiber.Ctx) error {
summary.MemberCount = 0
if req.Config != nil {
tenant.Config = req.Config
config, err := normalizeTenantConfig(req.Config)
if err != nil {
return errorJSON(c, fiber.StatusBadRequest, err.Error())
}
tenant.Config = config
h.DB.Save(tenant)
summary.Config = tenant.Config
}
if err := h.replaceTenantDomains(c.Context(), tenant.ID, normalizedDomains, req.ForceDomains); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
if len(normalizedDomains) > 0 {
summary.Domains = normalizedDomains
}
return c.Status(fiber.StatusCreated).JSON(summary)
}
@@ -740,14 +1047,15 @@ func (h *TenantHandler) UpdateTenant(c *fiber.Ctx) error {
}
var req struct {
Name *string `json:"name"`
Type *string `json:"type"`
Slug *string `json:"slug"`
Description *string `json:"description"`
Status *string `json:"status"`
ParentID *string `json:"parentId"`
Domains []string `json:"domains"`
Config map[string]any `json:"config"`
Name *string `json:"name"`
Type *string `json:"type"`
Slug *string `json:"slug"`
Description *string `json:"description"`
Status *string `json:"status"`
ParentID *string `json:"parentId"`
Domains []string `json:"domains"`
ForceDomains []string `json:"forceDomainConflicts"`
Config map[string]any `json:"config"`
}
if err := c.BodyParser(&req); err != nil {
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
@@ -835,7 +1143,11 @@ func (h *TenantHandler) UpdateTenant(c *fiber.Ctx) error {
}
}
if req.Config != nil {
tenant.Config = req.Config
config, err := normalizeTenantConfig(req.Config)
if err != nil {
return errorJSON(c, fiber.StatusBadRequest, err.Error())
}
tenant.Config = config
}
if err := h.DB.Save(&tenant).Error; err != nil {
@@ -844,18 +1156,16 @@ func (h *TenantHandler) UpdateTenant(c *fiber.Ctx) error {
// Update domains if provided
if req.Domains != nil {
// Simple approach: Delete existing and recreate
if err := h.DB.Delete(&domain.TenantDomain{}, "tenant_id = ?", tenant.ID).Error; err != nil {
return errorJSON(c, fiber.StatusInternalServerError, "failed to clear old domains")
normalizedDomains := normalizeTenantDomainInputs(req.Domains)
conflicts, err := h.findTenantDomainConflicts(c.Context(), tenant.ID, normalizedDomains, req.ForceDomains)
if err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
for _, d := range req.Domains {
if strings.TrimSpace(d) == "" {
continue
}
// Use repository for consistency
if err := repository.NewTenantRepository(h.DB).AddDomain(c.Context(), tenant.ID, strings.TrimSpace(d), true); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, "failed to add domain: "+d)
}
if len(conflicts) > 0 {
return tenantDomainConflictJSON(c, conflicts)
}
if err := h.replaceTenantDomains(c.Context(), tenant.ID, normalizedDomains, req.ForceDomains); err != nil {
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
}
}

View File

@@ -189,7 +189,7 @@ func TestTenantHandler_CreateTenant(t *testing.T) {
}
body, _ := json.Marshal(input)
mockSvc.On("RegisterTenant", mock.Anything, "Test Tenant", "test-tenant", domain.TenantTypeCompany, "", []string{"test.com"}, (*string)(nil), "").
mockSvc.On("RegisterTenant", mock.Anything, "Test Tenant", "test-tenant", domain.TenantTypeCompany, "", []string(nil), (*string)(nil), "").
Return(&domain.Tenant{ID: "t1", Name: "Test Tenant", Slug: "test-tenant"}, nil)
req := httptest.NewRequest("POST", "/tenants", bytes.NewReader(body))
@@ -278,15 +278,55 @@ func TestTenantHandler_ExportTenantsCSV(t *testing.T) {
mockSvc.On("ListTenants", mock.Anything, 10000, 0, "").Return(tenants, int64(1), nil)
req := httptest.NewRequest("GET", "/tenants/export", nil)
req := httptest.NewRequest("GET", "/tenants/export?includeIds=true", nil)
resp, _ := app.Test(req)
body, _ := io.ReadAll(resp.Body)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Contains(t, resp.Header.Get("Content-Disposition"), "tenants.csv")
assert.Equal(t, "text/csv", strings.Split(resp.Header.Get("Content-Type"), ";")[0])
assert.Contains(t, string(body), "tenant_id,name,type,parent_tenant_id,slug,memo,email_domain")
assert.Contains(t, string(body), "t1,Tenant A,COMPANY,parent-1,tenant-a,Primary tenant,tenant-a.example.com;login.tenant-a.example.com")
assert.Contains(t, string(body), "tenant_id,name,type,parent_tenant_id,parent_tenant_slug,slug,memo,email_domain")
assert.Contains(t, string(body), "t1,Tenant A,COMPANY,parent-1,,tenant-a,Primary tenant,tenant-a.example.com;login.tenant-a.example.com")
}
func TestTenantHandler_ExportTenantsCSV_OmitsIDsAndUsesParentSlug(t *testing.T) {
app := fiber.New()
mockSvc := new(MockTenantService)
h := &TenantHandler{Service: mockSvc}
app.Get("/tenants/export", h.ExportTenantsCSV)
parentID := "parent-1"
tenants := []domain.Tenant{
{
ID: parentID,
Name: "Parent Tenant",
Type: domain.TenantTypeCompanyGroup,
Slug: "parent-tenant",
},
{
ID: "child-1",
Name: "Child Tenant",
Type: domain.TenantTypeUserGroup,
ParentID: &parentID,
Slug: "child-tenant",
},
}
mockSvc.On("ListTenants", mock.Anything, 10000, 0, "").Return(tenants, int64(2), nil)
req := httptest.NewRequest("GET", "/tenants/export?includeIds=false", nil)
resp, _ := app.Test(req)
body, _ := io.ReadAll(resp.Body)
text := string(body)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Contains(t, text, "name,type,parent_tenant_slug,slug,memo,email_domain")
assert.Contains(t, text, "Child Tenant,USER_GROUP,parent-tenant,child-tenant,,")
assert.NotContains(t, text, "tenant_id")
assert.NotContains(t, text, "parent_tenant_id")
assert.NotContains(t, text, "child-1")
mockSvc.AssertExpectations(t)
}
func TestTenantHandler_ImportTenantsCSVCreatesTenant(t *testing.T) {
@@ -304,6 +344,7 @@ func TestTenantHandler_ImportTenantsCSVCreatesTenant(t *testing.T) {
assert.NoError(t, err)
assert.NoError(t, writer.Close())
mockSvc.On("ListTenants", mock.Anything, 10000, 0, "").Return([]domain.Tenant{}, int64(0), nil).Once()
mockSvc.On(
"RegisterTenant",
mock.Anything,
@@ -331,6 +372,127 @@ func TestTenantHandler_ImportTenantsCSVCreatesTenant(t *testing.T) {
mockSvc.AssertExpectations(t)
}
func TestTenantHandler_ImportTenantsCSVResolvesParentSlugToID(t *testing.T) {
app := fiber.New()
mockSvc := new(MockTenantService)
h := &TenantHandler{Service: mockSvc}
app.Post("/tenants/import", h.ImportTenantsCSV)
var body bytes.Buffer
writer := multipart.NewWriter(&body)
part, err := writer.CreateFormFile("file", "tenants.csv")
assert.NoError(t, err)
_, err = part.Write([]byte("name,type,parent_tenant_slug,slug,memo,email_domain\nParent Tenant,COMPANY,,parent-slug,,\nChild Tenant,USER_GROUP,parent-slug,child-slug,,\n"))
assert.NoError(t, err)
assert.NoError(t, writer.Close())
parentID := "parent-id"
mockSvc.On("ListTenants", mock.Anything, 10000, 0, "").Return([]domain.Tenant{}, int64(0), nil).Once()
mockSvc.On(
"RegisterTenant",
mock.Anything,
"Parent Tenant",
"parent-slug",
domain.TenantTypeCompany,
"",
[]string{},
(*string)(nil),
"",
).Return(&domain.Tenant{ID: parentID, Name: "Parent Tenant", Slug: "parent-slug"}, nil).Once()
mockSvc.On(
"RegisterTenant",
mock.Anything,
"Child Tenant",
"child-slug",
domain.TenantTypeUserGroup,
"",
[]string{},
mock.MatchedBy(func(got *string) bool {
return got != nil && *got == parentID
}),
"",
).Return(&domain.Tenant{ID: "child-id", Name: "Child Tenant", Slug: "child-slug"}, nil).Once()
req := httptest.NewRequest("POST", "/tenants/import", &body)
req.Header.Set("Content-Type", writer.FormDataContentType())
resp, _ := app.Test(req)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var got map[string]interface{}
json.NewDecoder(resp.Body).Decode(&got)
assert.Equal(t, float64(2), got["created"])
assert.Equal(t, float64(0), got["failed"])
mockSvc.AssertExpectations(t)
}
func TestTenantCSVAllowedDomainsRoundTrip(t *testing.T) {
records, err := parseTenantCSVRecords(strings.NewReader(
"name,type,parent_tenant_slug,slug,memo,email_domain\n" +
"Hanmac,COMPANY,,hanmac,,\"samaneng.com, hanmaceng.co.kr;login.hmac.kr\"\n",
))
assert.NoError(t, err)
assert.Len(t, records, 1)
assert.Equal(t, []string{"samaneng.com", "hanmaceng.co.kr", "login.hmac.kr"}, records[0].Domains)
}
func TestNormalizeTenantDomainInputsSplitsCommaAndWhitespace(t *testing.T) {
got := normalizeTenantDomainInputs([]string{
"samaneng.com, hanmaceng.co.kr",
" LOGIN.HMAC.KR\nportal.hmac.kr ",
"samaneng.com",
})
assert.Equal(t, []string{
"samaneng.com",
"hanmaceng.co.kr",
"login.hmac.kr",
"portal.hmac.kr",
}, got)
}
func TestNormalizeTenantConfigForcesIndexedForLoginIDFields(t *testing.T) {
config, err := normalizeTenantConfig(map[string]any{
"userSchema": []any{
map[string]any{
"key": "emp_no",
"label": "사번",
"type": "text",
"indexed": false,
"isLoginId": true,
"maxLength": 20,
},
},
})
assert.NoError(t, err)
fields, ok := config["userSchema"].([]any)
assert.True(t, ok)
assert.Len(t, fields, 1)
field, ok := fields[0].(map[string]any)
assert.True(t, ok)
assert.Equal(t, true, field["indexed"])
assert.Equal(t, true, field["isLoginId"])
assert.NotContains(t, field, "maxLength")
}
func TestNormalizeTenantConfigRejectsNonTextLoginIDFields(t *testing.T) {
_, err := normalizeTenantConfig(map[string]any{
"userSchema": []any{
map[string]any{
"key": "emp_no",
"type": "number",
"isLoginId": true,
},
},
})
assert.Error(t, err)
assert.Contains(t, err.Error(), "login ID fields must be text")
}
func TestTenantHandler_ApproveTenant(t *testing.T) {
app := fiber.New()
mockSvc := new(MockTenantService)

View File

@@ -425,6 +425,15 @@ func (h *UserHandler) CreateUser(c *fiber.Ctx) error {
attributes["tenant_id"] = tenantID
}
if h.UserRepo != nil {
if err := h.ensureHanmacCreateEmailAllowed(c.Context(), email, req.CompanyCode, tenantID); err != nil {
if strings.Contains(err.Error(), "한맥가족") {
return errorJSON(c, fiber.StatusConflict, err.Error())
}
return errorJSON(c, fiber.StatusBadRequest, err.Error())
}
}
// Merge custom metadata into attributes
for k, v := range req.Metadata {
// Don't overwrite core fields
@@ -534,10 +543,14 @@ type bulkUserItem struct {
}
type bulkUserResult struct {
Email string `json:"email"`
Success bool `json:"success"`
Message string `json:"message,omitempty"`
UserID string `json:"userId,omitempty"`
Email string `json:"email"`
OriginalEmail string `json:"originalEmail,omitempty"`
SuggestedEmail string `json:"suggestedEmail,omitempty"`
Status string `json:"status,omitempty"`
Warnings []string `json:"warnings,omitempty"`
Success bool `json:"success"`
Message string `json:"message,omitempty"`
UserID string `json:"userId,omitempty"`
}
func (h *UserHandler) BulkCreateUsers(c *fiber.Ctx) error {
@@ -565,6 +578,9 @@ func (h *UserHandler) BulkCreateUsers(c *fiber.Ctx) error {
requester, _ := c.Locals("user_profile").(*domain.UserProfileResponse)
results := make([]bulkUserResult, 0, len(req.Users))
var hanmacScope *hanmacEmailScope
var hanmacLocalParts map[string]bool
hanmacScopeLoaded := false
// Pre-fetch tenant data to avoid redundant DB calls
type tenantCacheItem struct {
@@ -638,6 +654,53 @@ func (h *UserHandler) BulkCreateUsers(c *fiber.Ctx) error {
}
}
if h.UserRepo != nil && !hanmacScopeLoaded {
hanmacScopeLoaded = true
var err error
hanmacScope, err = h.resolveHanmacEmailScope(c.Context())
if err != nil {
results = append(results, bulkUserResult{Email: email, Success: false, Message: "failed to resolve Hanmac family tenant scope"})
continue
}
if hanmacScope != nil {
hanmacLocalParts, err = h.loadHanmacLocalParts(c.Context(), hanmacScope)
if err != nil {
results = append(results, bulkUserResult{Email: email, Success: false, Message: "failed to validate Hanmac family email policy"})
continue
}
}
}
userEmail := email
var emailEvaluation hanmacEmailEvaluation
if h.UserRepo != nil && hanmacScope != nil && hanmacScope.ContainsTenant(tItem.ID, tenantSlug) {
emailEvaluation = h.evaluateHanmacImportEmail(c.Context(), item, hanmacScope, hanmacLocalParts)
if emailEvaluation.Blocking {
results = append(results, bulkUserResult{
Email: emailEvaluation.Email,
OriginalEmail: emailEvaluation.OriginalEmail,
Status: emailEvaluation.Status,
Warnings: emailEvaluation.Warnings,
Success: false,
Message: emailEvaluation.Message,
})
continue
}
userEmail = emailEvaluation.Email
if emailEvaluation.LocalPart != "" {
hanmacLocalParts[emailEvaluation.LocalPart] = true
}
} else {
if _, _, err := domain.SplitEmailDomain(email); err != nil {
results = append(results, bulkUserResult{Email: email, Success: false, Status: "blockingError", Message: "invalid email format"})
continue
}
if localPart, err := domain.ExtractNormalizedEmailLocalPart(email); err != nil || localPart == "" {
results = append(results, bulkUserResult{Email: email, Success: false, Status: "blockingError", Message: "invalid email format"})
continue
}
}
password, _ := utils.GeneratePasswordWithPolicy(policy)
role := item.Role
if role == "" {
@@ -665,7 +728,6 @@ func (h *UserHandler) BulkCreateUsers(c *fiber.Ctx) error {
}
}
userEmail := email
userPhone := normalizePhoneNumber(item.Phone)
// Validate all collected LoginIDs
@@ -673,7 +735,7 @@ func (h *UserHandler) BulkCreateUsers(c *fiber.Ctx) error {
valid := true
for _, lid := range collectedIDs {
if err := domain.ValidateLoginID(lid, userEmail, userPhone); err != nil {
results = append(results, bulkUserResult{Email: email, Success: false, Message: "Invalid LoginID (" + lid + "): " + err.Error()})
results = append(results, bulkUserResult{Email: userEmail, OriginalEmail: emailEvaluation.OriginalEmail, SuggestedEmail: emailEvaluation.SuggestedEmail, Status: emailEvaluation.Status, Warnings: emailEvaluation.Warnings, Success: false, Message: "Invalid LoginID (" + lid + "): " + err.Error()})
valid = false
break
}
@@ -692,14 +754,14 @@ func (h *UserHandler) BulkCreateUsers(c *fiber.Ctx) error {
if err != nil {
// 만약 이미 존재하는 사용자라면 로컬 DB 및 Keto 관계만 업데이트(Sync)를 시도
if strings.Contains(err.Error(), "409") || strings.Contains(err.Error(), "already exists") || strings.Contains(err.Error(), "exists already") {
identityID, err = h.KratosAdmin.FindIdentityIDByIdentifier(c.Context(), email)
identityID, err = h.KratosAdmin.FindIdentityIDByIdentifier(c.Context(), userEmail)
if err != nil || identityID == "" {
results = append(results, bulkUserResult{Email: email, Success: false, Message: "이미 다른 사용자가 해당 식별자(이메일/사번 등)를 사용 중입니다."})
results = append(results, bulkUserResult{Email: userEmail, OriginalEmail: emailEvaluation.OriginalEmail, SuggestedEmail: emailEvaluation.SuggestedEmail, Status: "blockingError", Warnings: emailEvaluation.Warnings, Success: false, Message: "이미 다른 사용자가 해당 식별자(이메일/사번 등)를 사용 중입니다."})
continue
}
slog.Info("BulkCreate: User already exists, syncing local DB and Keto", "email", email, "identityID", identityID)
} else {
results = append(results, bulkUserResult{Email: email, Success: false, Message: err.Error()})
results = append(results, bulkUserResult{Email: userEmail, OriginalEmail: emailEvaluation.OriginalEmail, SuggestedEmail: emailEvaluation.SuggestedEmail, Status: emailEvaluation.Status, Warnings: emailEvaluation.Warnings, Success: false, Message: err.Error()})
continue
}
}
@@ -709,7 +771,7 @@ func (h *UserHandler) BulkCreateUsers(c *fiber.Ctx) error {
if h.UserRepo != nil {
localUser := &domain.User{
ID: identityID,
Email: email,
Email: userEmail,
Name: name,
Phone: normalizePhoneNumber(item.Phone),
Role: role,
@@ -776,7 +838,15 @@ func (h *UserHandler) BulkCreateUsers(c *fiber.Ctx) error {
}
}
results = append(results, bulkUserResult{Email: email, Success: true, UserID: identityID})
results = append(results, bulkUserResult{
Email: userEmail,
OriginalEmail: emailEvaluation.OriginalEmail,
SuggestedEmail: emailEvaluation.SuggestedEmail,
Status: emailEvaluation.Status,
Warnings: emailEvaluation.Warnings,
Success: true,
UserID: identityID,
})
}
return c.Status(fiber.StatusOK).JSON(fiber.Map{
@@ -870,12 +940,19 @@ func (h *UserHandler) ExportUsersCSV(c *fiber.Ctx) error {
defer writer.Flush()
// Header row
header := []string{"ID", "Email", "Name", "Phone", "Status", "Tenant", "Position", "JobTitle", "CreatedAt"}
includeIDs := includeCSVIds(c)
header := []string{"Email", "Name", "Phone", "Status", "tenant_slug", "Position", "JobTitle", "CreatedAt"}
if includeIDs {
header = []string{"user_id", "Email", "Name", "Phone", "Status", "tenant_id", "tenant_slug", "Position", "JobTitle", "CreatedAt"}
}
// Collect all possible metadata keys for dynamic columns
metaKeysMap := make(map[string]bool)
for _, u := range filtered {
for k := range u.Metadata {
if !includeIDs && csvMetadataKeyIsID(k) {
continue
}
metaKeysMap[k] = true
}
}
@@ -891,8 +968,11 @@ func (h *UserHandler) ExportUsersCSV(c *fiber.Ctx) error {
// Data rows
for _, u := range filtered {
tenantID := ""
if u.TenantID != nil {
tenantID = *u.TenantID
}
row := []string{
u.ID,
u.Email,
u.Name,
u.Phone,
@@ -902,6 +982,20 @@ func (h *UserHandler) ExportUsersCSV(c *fiber.Ctx) error {
u.JobTitle,
u.CreatedAt.Format(time.RFC3339),
}
if includeIDs {
row = []string{
u.ID,
u.Email,
u.Name,
u.Phone,
u.Status,
tenantID,
u.CompanyCode,
u.Position,
u.JobTitle,
u.CreatedAt.Format(time.RFC3339),
}
}
// Append metadata values in order
for _, k := range metaKeys {
val := ""
@@ -918,6 +1012,11 @@ func (h *UserHandler) ExportUsersCSV(c *fiber.Ctx) error {
return nil
}
func csvMetadataKeyIsID(key string) bool {
normalized := strings.ToLower(strings.TrimSpace(key))
return normalized == "id" || normalized == "user_id" || normalized == "tenant_id" || normalized == "tenantid"
}
func (h *UserHandler) BulkUpdateUsers(c *fiber.Ctx) error {
var req struct {
UserIDs []string `json:"userIds"`

View File

@@ -126,6 +126,14 @@ func (m *MockTenantServiceForUser) ListManageableTenants(ctx context.Context, us
return args.Get(0).([]domain.Tenant), args.Error(1)
}
func (m *MockTenantServiceForUser) ListTenants(ctx context.Context, limit, offset int, parentID string) ([]domain.Tenant, int64, error) {
args := m.Called(ctx, limit, offset, parentID)
if args.Get(0) == nil {
return nil, args.Get(1).(int64), args.Error(2)
}
return args.Get(0).([]domain.Tenant), args.Get(1).(int64), args.Error(2)
}
func (m *MockTenantServiceForUser) ProvisionTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
args := m.Called(ctx, domainName)
if args.Get(0) == nil {
@@ -167,20 +175,66 @@ func TestUserHandler_ExportUsersCSV_UsesTenantSlugAliasAndOmitsRole(t *testing.T
},
}, int64(1), nil).Once()
req := httptest.NewRequest("GET", "/users/export?tenantSlug=test-tenant", nil)
req := httptest.NewRequest("GET", "/users/export?tenantSlug=test-tenant&includeIds=true", nil)
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
bodyBytes, _ := io.ReadAll(resp.Body)
body := strings.TrimPrefix(string(bodyBytes), "\ufeff")
assert.Contains(t, body, "ID,Email,Name,Phone,Status,Tenant,Position,JobTitle,CreatedAt")
assert.Contains(t, body, "u-1,user@test.com,Test User,010-1111-2222,active,test-tenant")
assert.Contains(t, body, "user_id,Email,Name,Phone,Status,tenant_id,tenant_slug,Position,JobTitle,CreatedAt")
assert.Contains(t, body, "u-1,user@test.com,Test User,010-1111-2222,active,,test-tenant")
assert.NotContains(t, body, "Role")
assert.NotContains(t, body, "Department")
mockRepo.AssertExpectations(t)
}
func TestUserHandler_ExportUsersCSV_OmitsIDsAndUsesTenantSlug(t *testing.T) {
app := fiber.New()
mockRepo := new(MockUserRepoForHandler)
h := &UserHandler{UserRepo: mockRepo}
app.Use(func(c *fiber.Ctx) error {
c.Locals("user_profile", &domain.UserProfileResponse{
Role: domain.RoleSuperAdmin,
})
return c.Next()
})
app.Get("/users/export", h.ExportUsersCSV)
createdAt := time.Date(2026, 4, 29, 12, 0, 0, 0, time.UTC)
tenantID := "tenant-uuid"
mockRepo.On("List", mock.Anything, 0, 10000, "", "").
Return([]domain.User{
{
ID: "user-uuid",
Email: "user@test.com",
Name: "Test User",
Phone: "010-1111-2222",
Status: "active",
CompanyCode: "test-tenant",
TenantID: &tenantID,
Position: "책임",
JobTitle: "플랫폼 운영",
CreatedAt: createdAt,
},
}, int64(1), nil).Once()
req := httptest.NewRequest("GET", "/users/export?includeIds=false", nil)
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
bodyBytes, _ := io.ReadAll(resp.Body)
body := strings.TrimPrefix(string(bodyBytes), "\ufeff")
assert.Contains(t, body, "Email,Name,Phone,Status,tenant_slug,Position,JobTitle,CreatedAt")
assert.Contains(t, body, "user@test.com,Test User,010-1111-2222,active,test-tenant")
assert.NotContains(t, body, "user-uuid")
assert.NotContains(t, body, "tenant-uuid")
assert.NotContains(t, body, "ID,")
mockRepo.AssertExpectations(t)
}
func TestUserHandler_BulkCreateUsers(t *testing.T) {
app := fiber.New()
mockKratos := new(MockKratosAdmin)
@@ -355,6 +409,170 @@ func TestUserHandler_BulkCreateUsers(t *testing.T) {
})
}
func TestUserHandler_BulkCreateUsers_HanmacEmailPolicy(t *testing.T) {
app := fiber.New()
mockKratos := new(MockKratosAdmin)
mockOry := new(MockOryProvider)
mockTenant := new(MockTenantServiceForUser)
mockRepo := new(MockUserRepoForHandler)
h := &UserHandler{
KratosAdmin: mockKratos,
OryProvider: mockOry,
TenantService: mockTenant,
UserRepo: mockRepo,
}
app.Post("/users/bulk", h.BulkCreateUsers)
rootID := "hanmac-family-id"
companyID := "hanmac-id"
tenants := []domain.Tenant{
{ID: rootID, Slug: "hanmac-family", Name: "한맥가족"},
{ID: companyID, Slug: "hanmac", Name: "한맥기술", ParentID: &rootID},
{ID: "external-id", Slug: "external", Name: "외부사"},
}
t.Run("domain only email receives suggested final email with next suffix", func(t *testing.T) {
mockTenant.On("GetTenantBySlug", mock.Anything, "hanmac").Return(&domain.Tenant{
ID: companyID,
Slug: "hanmac",
}, nil).Once()
mockTenant.On("GetTenant", mock.Anything, companyID).Return(&domain.Tenant{
ID: companyID,
Slug: "hanmac",
}, nil).Maybe()
mockTenant.On("ListTenants", mock.Anything, 10000, 0, "").Return(tenants, int64(len(tenants)), nil).Once()
mockRepo.On("FindByTenantIDs", mock.Anything, []string{rootID, companyID}).Return([]domain.User{
{Email: "cyhan@hanmaceng.co.kr", CompanyCode: "hanmac", TenantID: &companyID},
{Email: "cyhan1@samaneng.com", CompanyCode: "hanmac", TenantID: &companyID},
}, nil).Once()
mockRepo.On("FindByCompanyCodes", mock.Anything, []string{"hanmac-family", "hanmac"}).Return([]domain.User{}, nil).Once()
mockOry.On("GetPasswordPolicy").Return(&domain.PasswordPolicy{MinLength: 8}, nil).Once()
mockOry.On("CreateUser", mock.MatchedBy(func(user *domain.BrokerUser) bool {
return user.Email == "cyhan2@hanmaceng.co.kr"
}), mock.Anything).Return("u-hanmac", nil).Once()
payload := map[string]interface{}{
"users": []map[string]interface{}{
{
"email": "@hanmaceng.co.kr",
"name": "한치영",
"tenantSlug": "hanmac",
},
},
}
body, _ := json.Marshal(payload)
req := httptest.NewRequest("POST", "/users/bulk", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var result map[string]interface{}
json.NewDecoder(resp.Body).Decode(&result)
results := result["results"].([]interface{})
row := results[0].(map[string]interface{})
assert.True(t, row["success"].(bool))
assert.Equal(t, "cyhan2@hanmaceng.co.kr", row["email"])
assert.Equal(t, "@hanmaceng.co.kr", row["originalEmail"])
assert.Contains(t, row["warnings"].([]interface{}), "suggested")
})
t.Run("full email duplicate local part is blocking error", func(t *testing.T) {
mockTenant.On("GetTenantBySlug", mock.Anything, "hanmac").Return(&domain.Tenant{
ID: companyID,
Slug: "hanmac",
}, nil).Once()
mockTenant.On("GetTenant", mock.Anything, companyID).Return(&domain.Tenant{
ID: companyID,
Slug: "hanmac",
}, nil).Maybe()
mockTenant.On("ListTenants", mock.Anything, 10000, 0, "").Return(tenants, int64(len(tenants)), nil).Once()
mockRepo.On("FindByTenantIDs", mock.Anything, []string{rootID, companyID}).Return([]domain.User{
{Email: "han@hanmaceng.co.kr", CompanyCode: "hanmac", TenantID: &companyID},
}, nil).Once()
mockRepo.On("FindByCompanyCodes", mock.Anything, []string{"hanmac-family", "hanmac"}).Return([]domain.User{}, nil).Once()
mockOry.On("GetPasswordPolicy").Return(&domain.PasswordPolicy{MinLength: 8}, nil).Once()
payload := map[string]interface{}{
"users": []map[string]interface{}{
{
"email": "han@samaneng.com",
"name": "한치영",
"tenantSlug": "hanmac",
},
},
}
body, _ := json.Marshal(payload)
req := httptest.NewRequest("POST", "/users/bulk", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var result map[string]interface{}
json.NewDecoder(resp.Body).Decode(&result)
results := result["results"].([]interface{})
row := results[0].(map[string]interface{})
assert.False(t, row["success"].(bool))
assert.Equal(t, "blockingError", row["status"])
assert.Contains(t, row["message"].(string), "한맥가족 내에서 이미 사용 중인 이메일 ID입니다.")
})
}
func TestUserHandler_CreateUser_HanmacEmailPolicyBlocksDuplicateLocalPart(t *testing.T) {
app := fiber.New()
mockKratos := new(MockKratosAdmin)
mockOry := new(MockOryProvider)
mockTenant := new(MockTenantServiceForUser)
mockRepo := new(MockUserRepoForHandler)
h := &UserHandler{
KratosAdmin: mockKratos,
OryProvider: mockOry,
TenantService: mockTenant,
UserRepo: mockRepo,
}
app.Post("/users", h.CreateUser)
rootID := "hanmac-family-id"
companyID := "hanmac-id"
tenants := []domain.Tenant{
{ID: rootID, Slug: "hanmac-family", Name: "한맥가족"},
{ID: companyID, Slug: "hanmac", Name: "한맥기술", ParentID: &rootID},
}
mockOry.On("GetPasswordPolicy").Return(&domain.PasswordPolicy{MinLength: 8}, nil).Once()
mockTenant.On("GetTenantBySlug", mock.Anything, "hanmac").Return(&domain.Tenant{
ID: companyID,
Slug: "hanmac",
}, nil).Once()
mockTenant.On("ListTenants", mock.Anything, 10000, 0, "").Return(tenants, int64(len(tenants)), nil).Once()
mockRepo.On("FindByTenantIDs", mock.Anything, []string{rootID, companyID}).Return([]domain.User{
{Email: "han@hanmaceng.co.kr", CompanyCode: "hanmac", TenantID: &companyID},
}, nil).Once()
mockRepo.On("FindByCompanyCodes", mock.Anything, []string{"hanmac-family", "hanmac"}).Return([]domain.User{}, nil).Once()
payload := map[string]interface{}{
"email": "han@samaneng.com",
"name": "한치영",
"companyCode": "hanmac",
}
body, _ := json.Marshal(payload)
req := httptest.NewRequest("POST", "/users", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, _ := app.Test(req)
assert.Equal(t, http.StatusConflict, resp.StatusCode)
var result map[string]interface{}
json.NewDecoder(resp.Body).Decode(&result)
assert.Contains(t, result["error"].(string), "한맥가족 내에서 이미 사용 중인 이메일 ID입니다.")
mockOry.AssertNotCalled(t, "CreateUser")
}
func TestUserHandler_BulkUpdateUsers(t *testing.T) {
app := fiber.New()
mockKratos := new(MockKratosAdmin)