forked from baron/baron-sso
Implement tenant import and RP auto login policies
This commit is contained in:
@@ -85,20 +85,21 @@ const (
|
||||
)
|
||||
|
||||
type AuthHandler struct {
|
||||
SmsService domain.SmsService
|
||||
EmailService domain.EmailService
|
||||
RedisService domain.RedisRepository
|
||||
HeadlessJWKS *service.HeadlessJWKSCacheService
|
||||
KratosAdmin service.KratosAdminService
|
||||
IdpProvider domain.IdentityProvider
|
||||
AuditRepo domain.AuditRepository
|
||||
OathkeeperRepo domain.OathkeeperLogRepository
|
||||
Hydra *service.HydraAdminService
|
||||
TenantService service.TenantService
|
||||
KetoService service.KetoService
|
||||
KetoOutboxRepo repository.KetoOutboxRepository
|
||||
UserRepo repository.UserRepository
|
||||
ConsentRepo repository.ClientConsentRepository
|
||||
SmsService domain.SmsService
|
||||
EmailService domain.EmailService
|
||||
RedisService domain.RedisRepository
|
||||
HeadlessJWKS *service.HeadlessJWKSCacheService
|
||||
KratosAdmin service.KratosAdminService
|
||||
IdpProvider domain.IdentityProvider
|
||||
AuditRepo domain.AuditRepository
|
||||
OathkeeperRepo domain.OathkeeperLogRepository
|
||||
Hydra *service.HydraAdminService
|
||||
TenantService service.TenantService
|
||||
KetoService service.KetoService
|
||||
KetoOutboxRepo repository.KetoOutboxRepository
|
||||
UserRepo repository.UserRepository
|
||||
ConsentRepo repository.ClientConsentRepository
|
||||
RPUserMetadataRepo repository.RPUserMetadataRepository
|
||||
}
|
||||
|
||||
type signupState struct {
|
||||
@@ -1157,6 +1158,120 @@ func withOidcSessionMetadata(claims map[string]any, sessionID string) map[string
|
||||
return claims
|
||||
}
|
||||
|
||||
func (h *AuthHandler) withRPProfileClaims(ctx context.Context, claims map[string]any, client domain.HydraClient, subject string) map[string]any {
|
||||
if claims == nil {
|
||||
claims = map[string]any{}
|
||||
}
|
||||
if h == nil || h.RPUserMetadataRepo == nil {
|
||||
return claims
|
||||
}
|
||||
|
||||
clientID := strings.TrimSpace(client.ClientID)
|
||||
subject = strings.TrimSpace(subject)
|
||||
if clientID == "" || subject == "" {
|
||||
return claims
|
||||
}
|
||||
|
||||
claimKeys := extractClaimEnabledCustomUserSchemaKeys(client.Metadata)
|
||||
if len(claimKeys) == 0 {
|
||||
return claims
|
||||
}
|
||||
|
||||
row, err := h.RPUserMetadataRepo.Get(ctx, clientID, subject)
|
||||
if err != nil || row == nil || len(row.Metadata) == 0 {
|
||||
return claims
|
||||
}
|
||||
|
||||
fields := make(map[string]any)
|
||||
for _, key := range claimKeys {
|
||||
raw, ok := row.Metadata[key]
|
||||
if !ok || raw == nil {
|
||||
continue
|
||||
}
|
||||
if value, ok := raw.(string); ok {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
fields[key] = value
|
||||
continue
|
||||
}
|
||||
fields[key] = raw
|
||||
}
|
||||
if len(fields) == 0 {
|
||||
return claims
|
||||
}
|
||||
|
||||
profile := map[string]any{
|
||||
"client_id": clientID,
|
||||
"fields": fields,
|
||||
}
|
||||
if existing, ok := claims["rp_profiles"].([]any); ok {
|
||||
claims["rp_profiles"] = append(existing, profile)
|
||||
return claims
|
||||
}
|
||||
if existing, ok := claims["rp_profiles"].([]interface{}); ok {
|
||||
claims["rp_profiles"] = append(existing, profile)
|
||||
return claims
|
||||
}
|
||||
claims["rp_profiles"] = []any{profile}
|
||||
return claims
|
||||
}
|
||||
|
||||
func extractClaimEnabledCustomUserSchemaKeys(metadata map[string]interface{}) []string {
|
||||
if metadata == nil {
|
||||
return nil
|
||||
}
|
||||
rawSchema, ok := metadata["customUserSchema"]
|
||||
if !ok || rawSchema == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var items []interface{}
|
||||
switch schema := rawSchema.(type) {
|
||||
case []interface{}:
|
||||
items = schema
|
||||
case []map[string]interface{}:
|
||||
items = make([]interface{}, 0, len(schema))
|
||||
for _, item := range schema {
|
||||
items = append(items, item)
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(items))
|
||||
seen := make(map[string]struct{})
|
||||
for _, item := range items {
|
||||
field, ok := item.(map[string]interface{})
|
||||
if !ok {
|
||||
if typed, typedOK := item.(map[string]any); typedOK {
|
||||
field = typed
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
enabled, _ := field["claimEnabled"].(bool)
|
||||
if !enabled {
|
||||
enabled, _ = field["claim_enabled"].(bool)
|
||||
}
|
||||
if !enabled {
|
||||
continue
|
||||
}
|
||||
key, _ := field["key"].(string)
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[key]; exists {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
keys = append(keys, key)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func collectEmailList(traits map[string]any, primaryEmail string) []string {
|
||||
emails := make([]string, 0)
|
||||
seen := make(map[string]struct{})
|
||||
@@ -4792,6 +4907,8 @@ type linkedRpSummary struct {
|
||||
Logo string `json:"logo,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
InitURL string `json:"init_url,omitempty"`
|
||||
AutoLoginSupported bool `json:"auto_login_supported"`
|
||||
AutoLoginURL string `json:"auto_login_url,omitempty"`
|
||||
LastAuthenticatedAt string `json:"lastAuthenticatedAt,omitempty"`
|
||||
Status string `json:"status"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
@@ -4872,19 +4989,23 @@ func (h *AuthHandler) ListLinkedRps(c *fiber.Ctx) error {
|
||||
if len(scopes) == 0 && strings.TrimSpace(client.Scope) != "" {
|
||||
scopes = strings.Fields(client.Scope)
|
||||
}
|
||||
initURL := resolveLinkedRPInitURL(client.ClientID, scopes, client.RedirectURIs)
|
||||
autoLoginSupported := resolveLinkedRPAutoLoginSupported(client.ClientID, client.Metadata)
|
||||
autoLoginURL := resolveLinkedRPAutoLoginURL(client.ClientID, client.Metadata)
|
||||
initURL := resolveLinkedRPInitURL(client.ClientID, client.Metadata)
|
||||
|
||||
existing := records[clientID]
|
||||
if existing == nil {
|
||||
records[clientID] = &linkedRpRecord{
|
||||
linkedRpSummary: linkedRpSummary{
|
||||
ID: clientID,
|
||||
Name: name,
|
||||
Logo: extractHydraClientLogo(client.Metadata),
|
||||
URL: clientURL,
|
||||
InitURL: initURL,
|
||||
Status: "active", // Hydra 세션이 있으면 활성
|
||||
Scopes: scopes,
|
||||
ID: clientID,
|
||||
Name: name,
|
||||
Logo: extractHydraClientLogo(client.Metadata),
|
||||
URL: clientURL,
|
||||
InitURL: initURL,
|
||||
AutoLoginSupported: autoLoginSupported,
|
||||
AutoLoginURL: autoLoginURL,
|
||||
Status: "active", // Hydra 세션이 있으면 활성
|
||||
Scopes: scopes,
|
||||
},
|
||||
lastAuth: lastAuth,
|
||||
}
|
||||
@@ -4903,6 +5024,12 @@ func (h *AuthHandler) ListLinkedRps(c *fiber.Ctx) error {
|
||||
if existing.InitURL == "" {
|
||||
existing.InitURL = initURL
|
||||
}
|
||||
if !existing.AutoLoginSupported {
|
||||
existing.AutoLoginSupported = autoLoginSupported
|
||||
}
|
||||
if existing.AutoLoginURL == "" {
|
||||
existing.AutoLoginURL = autoLoginURL
|
||||
}
|
||||
existing.Scopes = mergeScopes(existing.Scopes, scopes)
|
||||
if lastAuth.After(existing.lastAuth) {
|
||||
existing.lastAuth = lastAuth
|
||||
@@ -4943,11 +5070,13 @@ func (h *AuthHandler) ListLinkedRps(c *fiber.Ctx) error {
|
||||
)
|
||||
}
|
||||
if record.InitURL == "" {
|
||||
record.InitURL = resolveLinkedRPInitURL(
|
||||
client.ClientID,
|
||||
record.Scopes,
|
||||
client.RedirectURIs,
|
||||
)
|
||||
record.InitURL = resolveLinkedRPInitURL(client.ClientID, client.Metadata)
|
||||
}
|
||||
if !record.AutoLoginSupported {
|
||||
record.AutoLoginSupported = resolveLinkedRPAutoLoginSupported(client.ClientID, client.Metadata)
|
||||
}
|
||||
if record.AutoLoginURL == "" {
|
||||
record.AutoLoginURL = resolveLinkedRPAutoLoginURL(client.ClientID, client.Metadata)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4999,21 +5128,21 @@ func (h *AuthHandler) ListLinkedRps(c *fiber.Ctx) error {
|
||||
client.ClientURI,
|
||||
client.RedirectURIs,
|
||||
)
|
||||
initURL := resolveLinkedRPInitURL(
|
||||
client.ClientID,
|
||||
dc.GrantedScopes,
|
||||
client.RedirectURIs,
|
||||
)
|
||||
autoLoginSupported := resolveLinkedRPAutoLoginSupported(client.ClientID, client.Metadata)
|
||||
autoLoginURL := resolveLinkedRPAutoLoginURL(client.ClientID, client.Metadata)
|
||||
initURL := resolveLinkedRPInitURL(client.ClientID, client.Metadata)
|
||||
|
||||
records[dc.ClientID] = &linkedRpRecord{
|
||||
linkedRpSummary: linkedRpSummary{
|
||||
ID: dc.ClientID,
|
||||
Name: name,
|
||||
Logo: extractHydraClientLogo(client.Metadata),
|
||||
URL: clientURL,
|
||||
InitURL: initURL,
|
||||
Status: status,
|
||||
Scopes: dc.GrantedScopes,
|
||||
ID: dc.ClientID,
|
||||
Name: name,
|
||||
Logo: extractHydraClientLogo(client.Metadata),
|
||||
URL: clientURL,
|
||||
InitURL: initURL,
|
||||
AutoLoginSupported: autoLoginSupported,
|
||||
AutoLoginURL: autoLoginURL,
|
||||
Status: status,
|
||||
Scopes: dc.GrantedScopes,
|
||||
},
|
||||
lastAuth: dc.UpdatedAt,
|
||||
}
|
||||
@@ -5087,11 +5216,9 @@ func (h *AuthHandler) ListLinkedRps(c *fiber.Ctx) error {
|
||||
}
|
||||
}
|
||||
record.URL = clientURL
|
||||
record.InitURL = resolveLinkedRPInitURL(
|
||||
client.ClientID,
|
||||
scopes,
|
||||
client.RedirectURIs,
|
||||
)
|
||||
record.InitURL = resolveLinkedRPInitURL(client.ClientID, client.Metadata)
|
||||
record.AutoLoginSupported = resolveLinkedRPAutoLoginSupported(client.ClientID, client.Metadata)
|
||||
record.AutoLoginURL = resolveLinkedRPAutoLoginURL(client.ClientID, client.Metadata)
|
||||
} else {
|
||||
// Hydra 정보 없음 (삭제됨 등) -> Audit 정보나 ID로 대체
|
||||
if record.Name == "" {
|
||||
@@ -5239,6 +5366,7 @@ func (h *AuthHandler) GetConsentRequest(c *fiber.Ctx) error {
|
||||
buildOidcClaimsFromTraits(identity.Traits, consentRequest.RequestedScope, tenantID),
|
||||
currentSessionID,
|
||||
)
|
||||
sessionClaims = h.withRPProfileClaims(c.Context(), sessionClaims, consentRequest.Client, consentRequest.Subject)
|
||||
acceptResp, err := h.Hydra.AcceptConsentRequest(c.Context(), challenge, consentRequest, sessionClaims)
|
||||
if err == nil {
|
||||
return c.JSON(acceptResp)
|
||||
@@ -5268,6 +5396,7 @@ func (h *AuthHandler) GetConsentRequest(c *fiber.Ctx) error {
|
||||
buildOidcClaimsFromTraits(identity.Traits, consentRequest.RequestedScope, tenantID),
|
||||
currentSessionID,
|
||||
)
|
||||
sessionClaims = h.withRPProfileClaims(c.Context(), sessionClaims, consentRequest.Client, consentRequest.Subject)
|
||||
|
||||
// [Debug] 실제 생성된 클레임 출력 (요청사항 확인용 - 자동 승인 시)
|
||||
appEnv := strings.ToLower(os.Getenv("APP_ENV"))
|
||||
@@ -5450,6 +5579,7 @@ func (h *AuthHandler) AcceptConsentRequest(c *fiber.Ctx) error {
|
||||
buildOidcClaimsFromTraits(identity.Traits, consentRequest.RequestedScope, tenantID),
|
||||
currentSessionID,
|
||||
)
|
||||
sessionClaims = h.withRPProfileClaims(c.Context(), sessionClaims, consentRequest.Client, consentRequest.Subject)
|
||||
|
||||
// [Debug] 실제 생성된 클레임 출력 (요청사항 확인용)
|
||||
appEnv := strings.ToLower(os.Getenv("APP_ENV"))
|
||||
@@ -7255,6 +7385,10 @@ func resolveLinkedRPURL(clientID string, clientURI string, redirectURIs []string
|
||||
if value := strings.TrimSpace(os.Getenv("DEVFRONT_URL")); value != "" {
|
||||
return value
|
||||
}
|
||||
case "orgfront":
|
||||
if value := strings.TrimSpace(os.Getenv("ORGFRONT_URL")); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
clientURL := strings.TrimSpace(clientURI)
|
||||
@@ -7271,10 +7405,22 @@ func resolveLinkedRPURL(clientID string, clientURI string, redirectURIs []string
|
||||
return ""
|
||||
}
|
||||
|
||||
func resolveLinkedRPInitURL(clientID string, scopes []string, redirectURIs []string) string {
|
||||
func resolveLinkedRPAutoLoginSupported(clientID string, metadata map[string]interface{}) bool {
|
||||
if readMetadataBoolValue(metadata, domain.MetadataAutoLoginSupported) {
|
||||
return true
|
||||
}
|
||||
switch strings.TrimSpace(clientID) {
|
||||
case "adminfront", "devfront", "orgfront":
|
||||
return resolveLinkedRPAutoLoginURL(clientID, nil) != ""
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func resolveLinkedRPAutoLoginURL(clientID string, metadata map[string]interface{}) string {
|
||||
clientID = strings.TrimSpace(clientID)
|
||||
if clientID == "" {
|
||||
return ""
|
||||
if metadataURL := readMetadataStringValue(metadata, domain.MetadataAutoLoginURL); metadataURL != "" {
|
||||
return metadataURL
|
||||
}
|
||||
|
||||
switch clientID {
|
||||
@@ -7286,8 +7432,23 @@ func resolveLinkedRPInitURL(clientID string, scopes []string, redirectURIs []str
|
||||
if value := strings.TrimRight(strings.TrimSpace(os.Getenv("DEVFRONT_URL")), "/"); value != "" {
|
||||
return value + "/login?auto=1&returnTo=%2Fclients"
|
||||
}
|
||||
case "orgfront":
|
||||
if value := strings.TrimRight(strings.TrimSpace(os.Getenv("ORGFRONT_URL")), "/"); value != "" {
|
||||
return value + "/login?auto=1"
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func resolveLinkedRPInitURL(clientID string, metadata map[string]interface{}) string {
|
||||
if !resolveLinkedRPAutoLoginSupported(clientID, metadata) {
|
||||
return ""
|
||||
}
|
||||
return resolveLinkedRPAutoLoginURL(clientID, metadata)
|
||||
}
|
||||
|
||||
func buildHydraAuthorizationURL(clientID string, scopes []string, redirectURIs []string) string {
|
||||
hydraPublicURL := strings.TrimRight(os.Getenv("HYDRA_PUBLIC_URL"), "/")
|
||||
if hydraPublicURL == "" {
|
||||
userfrontURL := strings.TrimRight(os.Getenv("USERFRONT_URL"), "/")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
@@ -178,6 +179,103 @@ func TestAcceptConsentRequest_DynamicClaims(t *testing.T) {
|
||||
assert.Equal(t, "Architect", capturedClaims["position"])
|
||||
}
|
||||
|
||||
func TestAcceptConsentRequest_IncludesRPProfileClaims(t *testing.T) {
|
||||
var capturedClaims map[string]interface{}
|
||||
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent" && r.URL.Query().Get("consent_challenge") == "challenge-rp-profile" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]interface{}{
|
||||
"challenge": "challenge-rp-profile",
|
||||
"requested_scope": []string{"openid", "profile"},
|
||||
"subject": "user-123",
|
||||
"client": map[string]interface{}{
|
||||
"client_id": "client-app",
|
||||
"metadata": map[string]interface{}{
|
||||
"customUserSchema": []map[string]interface{}{
|
||||
{
|
||||
"key": "approvalLevel",
|
||||
"label": "승인 등급",
|
||||
"type": "text",
|
||||
"claimEnabled": true,
|
||||
},
|
||||
{
|
||||
"key": "internalMemo",
|
||||
"label": "내부 메모",
|
||||
"type": "text",
|
||||
"claimEnabled": false,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Path == "/oauth2/auth/requests/consent/accept" && r.URL.Query().Get("consent_challenge") == "challenge-rp-profile" {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var acceptReq map[string]interface{}
|
||||
json.Unmarshal(body, &acceptReq)
|
||||
if session, ok := acceptReq["session"].(map[string]interface{}); ok {
|
||||
capturedClaims = session["id_token"].(map[string]interface{})
|
||||
}
|
||||
|
||||
return httpJSONAny(r, http.StatusOK, map[string]interface{}{
|
||||
"redirect_to": "http://rp/cb",
|
||||
}), nil
|
||||
}
|
||||
return httpResponse(r, http.StatusNotFound, "not found"), nil
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
h := &AuthHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: client,
|
||||
},
|
||||
KratosAdmin: new(MockKratosAdminService),
|
||||
}
|
||||
h.KratosAdmin.(*MockKratosAdminService).On("GetIdentity", mock.Anything, "user-123").Return(&service.KratosIdentity{
|
||||
ID: "user-123",
|
||||
Traits: map[string]interface{}{
|
||||
"email": "user@test.com",
|
||||
"name": "Test User",
|
||||
},
|
||||
}, nil)
|
||||
repo := new(devMockRPUserMetadataRepo)
|
||||
repo.On("Get", mock.Anything, "client-app", "user-123").Return(&domain.RPUserMetadata{
|
||||
ClientID: "client-app",
|
||||
UserID: "user-123",
|
||||
Metadata: domain.JSONMap{
|
||||
"approvalLevel": "A",
|
||||
"internalMemo": "관리자 전용",
|
||||
},
|
||||
}, nil).Once()
|
||||
h.RPUserMetadataRepo = repo
|
||||
|
||||
app := fiber.New()
|
||||
app.Post("/api/v1/auth/consent/accept", h.AcceptConsentRequest)
|
||||
|
||||
reqBody, _ := json.Marshal(map[string]interface{}{
|
||||
"consent_challenge": "challenge-rp-profile",
|
||||
"grant_scope": []string{"openid", "profile"},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/consent/accept", bytes.NewReader(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
assert.NotNil(t, capturedClaims)
|
||||
rpProfiles, ok := capturedClaims["rp_profiles"].([]interface{})
|
||||
assert.True(t, ok)
|
||||
assert.Len(t, rpProfiles, 1)
|
||||
profile := rpProfiles[0].(map[string]interface{})
|
||||
assert.Equal(t, "client-app", profile["client_id"])
|
||||
fields := profile["fields"].(map[string]interface{})
|
||||
assert.Equal(t, "A", fields["approvalLevel"])
|
||||
assert.NotContains(t, fields, "internalMemo")
|
||||
repo.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestGetConsentRequest_Skip_DynamicClaims(t *testing.T) {
|
||||
var capturedClaims map[string]interface{}
|
||||
|
||||
|
||||
@@ -55,6 +55,21 @@ func TestListLinkedRps_PriorityAndAggregation(t *testing.T) {
|
||||
"grant_scope": []string{"openid", "profile"},
|
||||
"handled_at": time.Now().Format(time.RFC3339),
|
||||
},
|
||||
{
|
||||
"client": map[string]interface{}{
|
||||
"client_id": "orgfront",
|
||||
"client_name": "OrgFront",
|
||||
"metadata": map[string]interface{}{
|
||||
"auto_login_supported": true,
|
||||
"auto_login_url": "http://localhost:5175/login?auto=1",
|
||||
},
|
||||
"redirect_uris": []string{
|
||||
"http://localhost:5175/auth/callback",
|
||||
},
|
||||
},
|
||||
"grant_scope": []string{"openid", "profile"},
|
||||
"handled_at": time.Now().Format(time.RFC3339),
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
if r.URL.Path == "/admin/clients/client-audit" {
|
||||
@@ -129,16 +144,18 @@ func TestListLinkedRps_PriorityAndAggregation(t *testing.T) {
|
||||
|
||||
var res struct {
|
||||
Items []struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
Scopes []string `json:"scopes"`
|
||||
InitURL string `json:"init_url"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
Scopes []string `json:"scopes"`
|
||||
InitURL string `json:"init_url"`
|
||||
AutoLoginSupported bool `json:"auto_login_supported"`
|
||||
AutoLoginURL string `json:"auto_login_url"`
|
||||
} `json:"items"`
|
||||
}
|
||||
json.NewDecoder(resp.Body).Decode(&res)
|
||||
|
||||
assert.Equal(t, 3, len(res.Items))
|
||||
assert.Equal(t, 4, len(res.Items))
|
||||
|
||||
statusMap := make(map[string]string)
|
||||
for _, item := range res.Items {
|
||||
@@ -146,6 +163,7 @@ func TestListLinkedRps_PriorityAndAggregation(t *testing.T) {
|
||||
}
|
||||
|
||||
assert.Equal(t, "active", statusMap["devfront"])
|
||||
assert.Equal(t, "active", statusMap["orgfront"])
|
||||
assert.Equal(t, "inactive", statusMap["client-consent"])
|
||||
assert.Equal(t, "inactive", statusMap["client-audit"])
|
||||
|
||||
@@ -164,6 +182,23 @@ func TestListLinkedRps_PriorityAndAggregation(t *testing.T) {
|
||||
assert.Equal(t, "/login", parsedInitURL.Path)
|
||||
assert.Equal(t, "1", parsedInitURL.Query().Get("auto"))
|
||||
assert.Equal(t, "/clients", parsedInitURL.Query().Get("returnTo"))
|
||||
|
||||
var orgfrontItem struct {
|
||||
InitURL string
|
||||
AutoLoginSupported bool
|
||||
AutoLoginURL string
|
||||
}
|
||||
for _, item := range res.Items {
|
||||
if item.ID == "orgfront" {
|
||||
orgfrontItem.InitURL = item.InitURL
|
||||
orgfrontItem.AutoLoginSupported = item.AutoLoginSupported
|
||||
orgfrontItem.AutoLoginURL = item.AutoLoginURL
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, orgfrontItem.AutoLoginSupported)
|
||||
assert.Equal(t, "http://localhost:5175/login?auto=1", orgfrontItem.AutoLoginURL)
|
||||
assert.Equal(t, orgfrontItem.AutoLoginURL, orgfrontItem.InitURL)
|
||||
}
|
||||
|
||||
func TestListLinkedRps_EnrichesLogoFromHydraClientWhenConsentSessionOmitsMetadata(t *testing.T) {
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -24,19 +25,20 @@ import (
|
||||
)
|
||||
|
||||
type DevHandler struct {
|
||||
Hydra *service.HydraAdminService
|
||||
Redis domain.RedisRepository
|
||||
HeadlessJWKS *service.HeadlessJWKSCacheService
|
||||
SecretRepo domain.ClientSecretRepository
|
||||
AuditRepo domain.AuditRepository
|
||||
KratosAdmin service.KratosAdminService
|
||||
ConsentRepo repository.ClientConsentRepository
|
||||
Keto service.KetoService
|
||||
KetoOutbox repository.KetoOutboxRepository
|
||||
RPSvc service.RelyingPartyService
|
||||
TenantSvc service.TenantService
|
||||
DeveloperSvc *service.DeveloperService
|
||||
Auth interface {
|
||||
Hydra *service.HydraAdminService
|
||||
Redis domain.RedisRepository
|
||||
HeadlessJWKS *service.HeadlessJWKSCacheService
|
||||
SecretRepo domain.ClientSecretRepository
|
||||
AuditRepo domain.AuditRepository
|
||||
KratosAdmin service.KratosAdminService
|
||||
ConsentRepo repository.ClientConsentRepository
|
||||
Keto service.KetoService
|
||||
KetoOutbox repository.KetoOutboxRepository
|
||||
RPSvc service.RelyingPartyService
|
||||
TenantSvc service.TenantService
|
||||
DeveloperSvc *service.DeveloperService
|
||||
RPUserMetadataRepo repository.RPUserMetadataRepository
|
||||
Auth interface {
|
||||
GetEnrichedProfile(c *fiber.Ctx) (*domain.UserProfileResponse, error)
|
||||
}
|
||||
}
|
||||
@@ -1377,6 +1379,86 @@ func (h *DevHandler) publicHeadlessJWKSCacheState(clientID string) (*domain.Head
|
||||
return h.HeadlessJWKS.PublicState(clientID)
|
||||
}
|
||||
|
||||
func (h *DevHandler) GetRPUserMetadata(c *fiber.Ctx) error {
|
||||
clientID := strings.TrimSpace(c.Params("id"))
|
||||
userID := strings.TrimSpace(c.Params("userId"))
|
||||
if clientID == "" || userID == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "client id and user id are required")
|
||||
}
|
||||
if h.RPUserMetadataRepo == nil {
|
||||
return errorJSON(c, fiber.StatusServiceUnavailable, "rp user metadata repository unavailable")
|
||||
}
|
||||
|
||||
profile := h.getCurrentProfile(c)
|
||||
if profile == nil {
|
||||
return errorJSON(c, fiber.StatusUnauthorized, "unauthorized: authentication required")
|
||||
}
|
||||
|
||||
summary, err := h.loadClientSummary(c.Context(), clientID)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusNotFound, "client not found")
|
||||
}
|
||||
if !h.canViewClientByPermit(c, profile, summary) {
|
||||
return errorJSON(c, fiber.StatusForbidden, "forbidden: insufficient permission to view client metadata")
|
||||
}
|
||||
|
||||
metadata, err := h.RPUserMetadataRepo.Get(c.Context(), clientID, userID)
|
||||
if err != nil {
|
||||
return c.JSON(fiber.Map{
|
||||
"clientId": clientID,
|
||||
"userId": userID,
|
||||
"metadata": domain.JSONMap{},
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(metadata)
|
||||
}
|
||||
|
||||
func (h *DevHandler) UpsertRPUserMetadata(c *fiber.Ctx) error {
|
||||
clientID := strings.TrimSpace(c.Params("id"))
|
||||
userID := strings.TrimSpace(c.Params("userId"))
|
||||
if clientID == "" || userID == "" {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "client id and user id are required")
|
||||
}
|
||||
if h.RPUserMetadataRepo == nil {
|
||||
return errorJSON(c, fiber.StatusServiceUnavailable, "rp user metadata repository unavailable")
|
||||
}
|
||||
|
||||
profile := h.getCurrentProfile(c)
|
||||
if profile == nil {
|
||||
return errorJSON(c, fiber.StatusUnauthorized, "unauthorized: authentication required")
|
||||
}
|
||||
|
||||
summary, err := h.loadClientSummary(c.Context(), clientID)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusNotFound, "client not found")
|
||||
}
|
||||
if !h.canManageClientRelations(c, profile, summary) {
|
||||
return errorJSON(c, fiber.StatusForbidden, "forbidden: insufficient permission to update client metadata")
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Metadata map[string]any `json:"metadata"`
|
||||
}
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
|
||||
}
|
||||
if req.Metadata == nil {
|
||||
req.Metadata = map[string]any{}
|
||||
}
|
||||
|
||||
row := &domain.RPUserMetadata{
|
||||
ClientID: clientID,
|
||||
UserID: userID,
|
||||
Metadata: domain.JSONMap(req.Metadata),
|
||||
}
|
||||
if err := h.RPUserMetadataRepo.Upsert(c.Context(), row); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(row)
|
||||
}
|
||||
|
||||
func (h *DevHandler) syncHeadlessJWKSCache(ctx context.Context, client domain.HydraClient, reason string) {
|
||||
if h.HeadlessJWKS == nil {
|
||||
h.HeadlessJWKS = service.NewHeadlessJWKSCacheService(h.Redis, nil)
|
||||
@@ -1574,6 +1656,10 @@ func (h *DevHandler) CreateClient(c *fiber.Ctx) error {
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, err.Error())
|
||||
}
|
||||
metadata, err = normalizeClientAutoLoginMetadata(metadata)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, err.Error())
|
||||
}
|
||||
|
||||
tokenAuthMethod := strings.TrimSpace(valueOr(req.TokenEndpointAuthMethod, ""))
|
||||
if tokenAuthMethod == "" {
|
||||
@@ -1766,6 +1852,10 @@ func (h *DevHandler) UpdateClient(c *fiber.Ctx) error {
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, err.Error())
|
||||
}
|
||||
metadata, err = normalizeClientAutoLoginMetadata(metadata)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, err.Error())
|
||||
}
|
||||
resolvedClientType := currentSummary.Type
|
||||
if clientType != "" {
|
||||
resolvedClientType = clientType
|
||||
@@ -2575,6 +2665,30 @@ func readMetadataBoolValue(metadata map[string]interface{}, key string) bool {
|
||||
return value
|
||||
}
|
||||
|
||||
func normalizeClientAutoLoginMetadata(metadata map[string]interface{}) (map[string]interface{}, error) {
|
||||
if metadata == nil {
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
supported := readMetadataBoolValue(metadata, domain.MetadataAutoLoginSupported)
|
||||
rawURL := strings.TrimSpace(readMetadataStringValue(metadata, domain.MetadataAutoLoginURL))
|
||||
metadata[domain.MetadataAutoLoginSupported] = supported
|
||||
if !supported {
|
||||
delete(metadata, domain.MetadataAutoLoginURL)
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
if rawURL == "" {
|
||||
return nil, errors.New("auto_login_url is required when auto_login_supported is true")
|
||||
}
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil || parsed.Scheme == "" || parsed.Host == "" || (parsed.Scheme != "https" && parsed.Scheme != "http") {
|
||||
return nil, errors.New("auto_login_url must be an http or https URL")
|
||||
}
|
||||
metadata[domain.MetadataAutoLoginURL] = rawURL
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
func normalizeHeadlessClientConfig(
|
||||
clientType string,
|
||||
tokenAuthMethod string,
|
||||
|
||||
94
backend/internal/handler/dev_handler_rp_metadata_test.go
Normal file
94
backend/internal/handler/dev_handler_rp_metadata_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"baron-sso-backend/internal/service"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type devMockRPUserMetadataRepo struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *devMockRPUserMetadataRepo) Get(ctx context.Context, clientID, userID string) (*domain.RPUserMetadata, error) {
|
||||
args := m.Called(ctx, clientID, userID)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.RPUserMetadata), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *devMockRPUserMetadataRepo) Upsert(ctx context.Context, metadata *domain.RPUserMetadata) error {
|
||||
args := m.Called(ctx, metadata)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestDevHandler_RPUserMetadataRoundTrip(t *testing.T) {
|
||||
transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path == "/clients/client-1" {
|
||||
return httpJSONAny(r, http.StatusOK, map[string]interface{}{
|
||||
"client_id": "client-1",
|
||||
"client_name": "Client One",
|
||||
"metadata": map[string]interface{}{
|
||||
"tenant_id": "tenant-1",
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
return httpJSONAny(r, http.StatusNotFound, nil), nil
|
||||
})
|
||||
|
||||
repo := new(devMockRPUserMetadataRepo)
|
||||
repo.On("Upsert", mock.Anything, mock.MatchedBy(func(row *domain.RPUserMetadata) bool {
|
||||
return row.ClientID == "client-1" &&
|
||||
row.UserID == "user-1" &&
|
||||
row.Metadata["approvalLevel"] == "A"
|
||||
})).Return(nil).Once()
|
||||
repo.On("Get", mock.Anything, "client-1", "user-1").Return(&domain.RPUserMetadata{
|
||||
ClientID: "client-1",
|
||||
UserID: "user-1",
|
||||
Metadata: domain.JSONMap{"approvalLevel": "A"},
|
||||
}, nil).Once()
|
||||
|
||||
h := &DevHandler{
|
||||
Hydra: &service.HydraAdminService{
|
||||
AdminURL: "http://hydra.test",
|
||||
HTTPClient: &http.Client{Transport: transport},
|
||||
},
|
||||
RPUserMetadataRepo: repo,
|
||||
}
|
||||
app := fiber.New()
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{ID: "admin", Role: domain.RoleSuperAdmin})
|
||||
return c.Next()
|
||||
})
|
||||
app.Put("/api/v1/dev/clients/:id/users/:userId/metadata", h.UpsertRPUserMetadata)
|
||||
app.Get("/api/v1/dev/clients/:id/users/:userId/metadata", h.GetRPUserMetadata)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"metadata": map[string]any{"approvalLevel": "A"},
|
||||
})
|
||||
putReq := httptest.NewRequest(http.MethodPut, "/api/v1/dev/clients/client-1/users/user-1/metadata", bytes.NewReader(body))
|
||||
putReq.Header.Set("Content-Type", "application/json")
|
||||
putResp, _ := app.Test(putReq, -1)
|
||||
assert.Equal(t, http.StatusOK, putResp.StatusCode)
|
||||
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/api/v1/dev/clients/client-1/users/user-1/metadata", nil)
|
||||
getResp, _ := app.Test(getReq, -1)
|
||||
assert.Equal(t, http.StatusOK, getResp.StatusCode)
|
||||
|
||||
var got map[string]any
|
||||
assert.NoError(t, json.NewDecoder(getResp.Body).Decode(&got))
|
||||
assert.Equal(t, "client-1", got["clientId"])
|
||||
assert.Equal(t, "user-1", got["userId"])
|
||||
assert.Equal(t, "A", got["metadata"].(map[string]any)["approvalLevel"])
|
||||
repo.AssertExpectations(t)
|
||||
}
|
||||
@@ -1300,6 +1300,36 @@ func TestCreateClient_DefaultsSkipConsentToTrue(t *testing.T) {
|
||||
assert.True(t, *captured.SkipConsent)
|
||||
}
|
||||
|
||||
func TestNormalizeClientAutoLoginMetadata(t *testing.T) {
|
||||
t.Run("keeps supported flag and URL", func(t *testing.T) {
|
||||
metadata, err := normalizeClientAutoLoginMetadata(map[string]interface{}{
|
||||
"auto_login_supported": true,
|
||||
"auto_login_url": "https://rp.example.com/login?auto=1",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, metadata["auto_login_supported"])
|
||||
assert.Equal(t, "https://rp.example.com/login?auto=1", metadata["auto_login_url"])
|
||||
})
|
||||
|
||||
t.Run("requires URL when supported", func(t *testing.T) {
|
||||
_, err := normalizeClientAutoLoginMetadata(map[string]interface{}{
|
||||
"auto_login_supported": true,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("removes URL when unsupported", func(t *testing.T) {
|
||||
metadata, err := normalizeClientAutoLoginMetadata(map[string]interface{}{
|
||||
"auto_login_supported": false,
|
||||
"auto_login_url": "https://rp.example.com/login?auto=1",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, false, metadata["auto_login_supported"])
|
||||
_, exists := metadata["auto_login_url"]
|
||||
assert.False(t, exists)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateClient_AllowsExplicitSkipConsentFalse(t *testing.T) {
|
||||
var captured domain.HydraClient
|
||||
|
||||
|
||||
244
backend/internal/handler/hanmac_email_policy.go
Normal file
244
backend/internal/handler/hanmac_email_policy.go
Normal file
@@ -0,0 +1,244 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"baron-sso-backend/internal/domain"
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const hanmacFamilyTenantSlug = "hanmac-family"
|
||||
|
||||
type hanmacEmailScope struct {
|
||||
TenantIDs map[string]bool
|
||||
Slugs map[string]bool
|
||||
IDList []string
|
||||
SlugList []string
|
||||
}
|
||||
|
||||
type hanmacEmailEvaluation struct {
|
||||
Email string
|
||||
OriginalEmail string
|
||||
SuggestedEmail string
|
||||
Status string
|
||||
Warnings []string
|
||||
Message string
|
||||
Blocking bool
|
||||
LocalPart string
|
||||
}
|
||||
|
||||
func (h *UserHandler) evaluateHanmacImportEmail(ctx context.Context, item bulkUserItem, scope *hanmacEmailScope, usedLocalParts map[string]bool) hanmacEmailEvaluation {
|
||||
originalEmail := strings.TrimSpace(item.Email)
|
||||
name := strings.TrimSpace(item.Name)
|
||||
evaluation := hanmacEmailEvaluation{
|
||||
Email: originalEmail,
|
||||
OriginalEmail: originalEmail,
|
||||
Status: "valid",
|
||||
}
|
||||
|
||||
localPart, domainPart, err := domain.SplitEmailDomain(originalEmail)
|
||||
if err != nil {
|
||||
evaluation.Status = "blockingError"
|
||||
evaluation.Message = "invalid email format"
|
||||
evaluation.Blocking = true
|
||||
return evaluation
|
||||
}
|
||||
|
||||
base, needsReview, _ := domain.BuildKoreanNameEmailBase(name)
|
||||
if needsReview {
|
||||
evaluation.Warnings = append(evaluation.Warnings, "needsReview")
|
||||
evaluation.Status = "needsReview"
|
||||
}
|
||||
|
||||
if localPart == "" {
|
||||
if base == "" {
|
||||
evaluation.Status = "blockingError"
|
||||
evaluation.Message = "이름으로 이메일 ID를 제안할 수 없습니다."
|
||||
evaluation.Blocking = true
|
||||
return evaluation
|
||||
}
|
||||
nextLocalPart := nextAvailableHanmacLocalPart(base, usedLocalParts)
|
||||
evaluation.Email = nextLocalPart + "@" + domainPart
|
||||
evaluation.SuggestedEmail = evaluation.Email
|
||||
evaluation.LocalPart = nextLocalPart
|
||||
evaluation.Status = "suggested"
|
||||
evaluation.Warnings = appendUniqueString(evaluation.Warnings, "suggested")
|
||||
return evaluation
|
||||
}
|
||||
|
||||
evaluation.LocalPart = localPart
|
||||
if usedLocalParts[localPart] {
|
||||
evaluation.Status = "blockingError"
|
||||
evaluation.Message = "한맥가족 내에서 이미 사용 중인 이메일 ID입니다."
|
||||
evaluation.Blocking = true
|
||||
return evaluation
|
||||
}
|
||||
|
||||
if base != "" && !domain.MatchesSuggestedNameRule(localPart, base) {
|
||||
evaluation.Status = "ruleMismatch"
|
||||
evaluation.Warnings = appendUniqueString(evaluation.Warnings, "ruleMismatch")
|
||||
}
|
||||
|
||||
if evaluation.Status == "needsReview" && len(evaluation.Warnings) == 0 {
|
||||
evaluation.Warnings = append(evaluation.Warnings, "needsReview")
|
||||
}
|
||||
_ = scope
|
||||
return evaluation
|
||||
}
|
||||
|
||||
func (h *UserHandler) ensureHanmacCreateEmailAllowed(ctx context.Context, email string, tenantSlug string, tenantID string) error {
|
||||
scope, err := h.resolveHanmacEmailScope(ctx)
|
||||
if err != nil || scope == nil || !scope.ContainsTenant(tenantID, tenantSlug) {
|
||||
return nil
|
||||
}
|
||||
|
||||
localPart, err := domain.ExtractNormalizedEmailLocalPart(email)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
usedLocalParts, err := h.loadHanmacLocalParts(ctx, scope)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usedLocalParts[localPart] {
|
||||
return fmt.Errorf("한맥가족 내에서 이미 사용 중인 이메일 ID입니다.")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *UserHandler) resolveHanmacEmailScope(ctx context.Context) (*hanmacEmailScope, error) {
|
||||
if h.TenantService == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
tenants, _, err := h.TenantService.ListTenants(ctx, 10000, 0, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var rootID string
|
||||
for _, tenant := range tenants {
|
||||
if strings.EqualFold(strings.TrimSpace(tenant.Slug), hanmacFamilyTenantSlug) {
|
||||
rootID = tenant.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
if rootID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
tenantByID := make(map[string]domain.Tenant, len(tenants))
|
||||
for _, tenant := range tenants {
|
||||
tenantByID[tenant.ID] = tenant
|
||||
}
|
||||
|
||||
scope := &hanmacEmailScope{
|
||||
TenantIDs: make(map[string]bool),
|
||||
Slugs: make(map[string]bool),
|
||||
}
|
||||
for _, tenant := range tenants {
|
||||
if isTenantDescendantOf(tenant, rootID, tenantByID) {
|
||||
scope.TenantIDs[tenant.ID] = true
|
||||
scope.Slugs[strings.ToLower(strings.TrimSpace(tenant.Slug))] = true
|
||||
scope.IDList = append(scope.IDList, tenant.ID)
|
||||
scope.SlugList = append(scope.SlugList, tenant.Slug)
|
||||
}
|
||||
}
|
||||
return scope, nil
|
||||
}
|
||||
|
||||
func (h *UserHandler) loadHanmacLocalParts(ctx context.Context, scope *hanmacEmailScope) (map[string]bool, error) {
|
||||
used := make(map[string]bool)
|
||||
if h.UserRepo == nil || scope == nil {
|
||||
return used, nil
|
||||
}
|
||||
|
||||
if len(scope.IDList) > 0 {
|
||||
users, err := h.UserRepo.FindByTenantIDs(ctx, scope.IDList)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addUserEmailLocalParts(used, users)
|
||||
}
|
||||
|
||||
if len(scope.SlugList) > 0 {
|
||||
users, err := h.UserRepo.FindByCompanyCodes(ctx, scope.SlugList)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addUserEmailLocalParts(used, users)
|
||||
}
|
||||
|
||||
return used, nil
|
||||
}
|
||||
|
||||
func (s *hanmacEmailScope) ContainsTenant(tenantID string, slug string) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
if tenantID != "" && s.TenantIDs[tenantID] {
|
||||
return true
|
||||
}
|
||||
return s.Slugs[strings.ToLower(strings.TrimSpace(slug))]
|
||||
}
|
||||
|
||||
func isTenantDescendantOf(tenant domain.Tenant, rootID string, tenantByID map[string]domain.Tenant) bool {
|
||||
if tenant.ID == rootID {
|
||||
return true
|
||||
}
|
||||
visited := make(map[string]bool)
|
||||
parentID := ""
|
||||
if tenant.ParentID != nil {
|
||||
parentID = *tenant.ParentID
|
||||
}
|
||||
for parentID != "" {
|
||||
if parentID == rootID {
|
||||
return true
|
||||
}
|
||||
if visited[parentID] {
|
||||
return false
|
||||
}
|
||||
visited[parentID] = true
|
||||
parent, ok := tenantByID[parentID]
|
||||
if !ok || parent.ParentID == nil {
|
||||
return false
|
||||
}
|
||||
parentID = *parent.ParentID
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func addUserEmailLocalParts(target map[string]bool, users []domain.User) {
|
||||
for _, user := range users {
|
||||
localPart, err := domain.ExtractNormalizedEmailLocalPart(user.Email)
|
||||
if err == nil && localPart != "" {
|
||||
target[localPart] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func nextAvailableHanmacLocalPart(base string, usedLocalParts map[string]bool) string {
|
||||
base = strings.ToLower(strings.TrimSpace(base))
|
||||
if base == "" {
|
||||
return ""
|
||||
}
|
||||
if !usedLocalParts[base] {
|
||||
return base
|
||||
}
|
||||
for index := 1; ; index++ {
|
||||
candidate := fmt.Sprintf("%s%d", base, index)
|
||||
if !usedLocalParts[candidate] {
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func appendUniqueString(values []string, value string) []string {
|
||||
for _, existing := range values {
|
||||
if existing == value {
|
||||
return values
|
||||
}
|
||||
}
|
||||
return append(values, value)
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"baron-sso-backend/internal/service"
|
||||
"baron-sso-backend/internal/utils"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/csv"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -68,14 +69,22 @@ type tenantImportResult struct {
|
||||
Errors []string `json:"errors"`
|
||||
}
|
||||
|
||||
type tenantDomainConflict struct {
|
||||
Domain string `json:"domain"`
|
||||
TenantID string `json:"tenantId"`
|
||||
TenantName string `json:"tenantName"`
|
||||
TenantSlug string `json:"tenantSlug"`
|
||||
}
|
||||
|
||||
type tenantCSVRecord struct {
|
||||
TenantID string
|
||||
Name string
|
||||
Type string
|
||||
ParentTenantID *string
|
||||
Slug string
|
||||
Memo string
|
||||
Domains []string
|
||||
TenantID string
|
||||
Name string
|
||||
Type string
|
||||
ParentTenantID *string
|
||||
ParentTenantSlug string
|
||||
Slug string
|
||||
Memo string
|
||||
Domains []string
|
||||
}
|
||||
|
||||
func (h *TenantHandler) RegisterTenantPublic(c *fiber.Ctx) error {
|
||||
@@ -258,13 +267,24 @@ func (h *TenantHandler) ExportTenantsCSV(c *fiber.Ctx) error {
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer := csv.NewWriter(&buf)
|
||||
if err := writer.Write([]string{"tenant_id", "name", "type", "parent_tenant_id", "slug", "memo", "email_domain"}); err != nil {
|
||||
includeIDs := includeCSVIds(c)
|
||||
if includeIDs {
|
||||
if err := writer.Write([]string{"tenant_id", "name", "type", "parent_tenant_id", "parent_tenant_slug", "slug", "memo", "email_domain"}); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
} else if err := writer.Write([]string{"name", "type", "parent_tenant_slug", "slug", "memo", "email_domain"}); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
slugByID := make(map[string]string, len(tenants))
|
||||
for _, tenant := range tenants {
|
||||
slugByID[tenant.ID] = tenant.Slug
|
||||
}
|
||||
for _, tenant := range tenants {
|
||||
parentID := ""
|
||||
parentSlug := ""
|
||||
if tenant.ParentID != nil {
|
||||
parentID = *tenant.ParentID
|
||||
parentSlug = slugByID[parentID]
|
||||
}
|
||||
domains := make([]string, 0, len(tenant.Domains))
|
||||
for _, domainName := range tenant.Domains {
|
||||
@@ -273,15 +293,27 @@ func (h *TenantHandler) ExportTenantsCSV(c *fiber.Ctx) error {
|
||||
domains = append(domains, domainName)
|
||||
}
|
||||
}
|
||||
if err := writer.Write([]string{
|
||||
tenant.ID,
|
||||
row := []string{
|
||||
tenant.Name,
|
||||
tenant.Type,
|
||||
parentID,
|
||||
parentSlug,
|
||||
tenant.Slug,
|
||||
tenant.Description,
|
||||
strings.Join(domains, ";"),
|
||||
}); err != nil {
|
||||
}
|
||||
if includeIDs {
|
||||
row = []string{
|
||||
tenant.ID,
|
||||
tenant.Name,
|
||||
tenant.Type,
|
||||
parentID,
|
||||
parentSlug,
|
||||
tenant.Slug,
|
||||
tenant.Description,
|
||||
strings.Join(domains, ";"),
|
||||
}
|
||||
}
|
||||
if err := writer.Write(row); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
}
|
||||
@@ -305,33 +337,60 @@ func (h *TenantHandler) ImportTenantsCSV(c *fiber.Ctx) error {
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, err.Error())
|
||||
}
|
||||
records = orderTenantCSVRecordsByParentSlug(records)
|
||||
|
||||
creatorID := ""
|
||||
if profile, ok := c.Locals("user_profile").(*domain.UserProfileResponse); ok && profile != nil {
|
||||
creatorID = profile.ID
|
||||
}
|
||||
|
||||
tenantIDBySlug := make(map[string]string)
|
||||
if h.Service != nil {
|
||||
if tenants, _, err := h.Service.ListTenants(c.Context(), 10000, 0, ""); err == nil {
|
||||
for _, tenant := range tenants {
|
||||
tenantIDBySlug[strings.ToLower(tenant.Slug)] = tenant.ID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := tenantImportResult{Errors: make([]string, 0)}
|
||||
for i, record := range records {
|
||||
rowNumber := i + 2
|
||||
if record.ParentTenantID == nil && record.ParentTenantSlug != "" {
|
||||
parentID := tenantIDBySlug[strings.ToLower(record.ParentTenantSlug)]
|
||||
if parentID == "" {
|
||||
result.Failed++
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("row %d: parent tenant slug not found: %s", rowNumber, record.ParentTenantSlug))
|
||||
continue
|
||||
}
|
||||
record.ParentTenantID = &parentID
|
||||
}
|
||||
if record.TenantID != "" || (h.DB != nil && record.Slug != "") {
|
||||
updated, err := h.upsertTenantCSVRecord(c, record)
|
||||
tenant, updated, err := h.upsertTenantCSVRecord(c, record)
|
||||
if err != nil {
|
||||
result.Failed++
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("row %d: %s", rowNumber, err.Error()))
|
||||
continue
|
||||
}
|
||||
if updated {
|
||||
tenantIDBySlug[strings.ToLower(record.Slug)] = tenant.ID
|
||||
result.Updated++
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.createTenantCSVRecord(c, record, creatorID); err != nil {
|
||||
tenant, err := h.createTenantCSVRecord(c, record, creatorID)
|
||||
if err != nil {
|
||||
result.Failed++
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("row %d: %s", rowNumber, err.Error()))
|
||||
continue
|
||||
}
|
||||
if tenant == nil {
|
||||
result.Failed++
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("row %d: tenant creation returned empty result", rowNumber))
|
||||
continue
|
||||
}
|
||||
tenantIDBySlug[strings.ToLower(record.Slug)] = tenant.ID
|
||||
result.Created++
|
||||
}
|
||||
|
||||
@@ -414,13 +473,14 @@ func parseTenantCSVRecords(r io.Reader) ([]tenantCSVRecord, error) {
|
||||
}
|
||||
|
||||
records = append(records, tenantCSVRecord{
|
||||
TenantID: tenantCSVValue(row, header, "tenant_id"),
|
||||
Name: name,
|
||||
Type: tenantType,
|
||||
ParentTenantID: parentID,
|
||||
Slug: slug,
|
||||
Memo: tenantCSVValue(row, header, "memo"),
|
||||
Domains: splitTenantCSVDomains(tenantCSVValue(row, header, "email_domain")),
|
||||
TenantID: tenantCSVValue(row, header, "tenant_id"),
|
||||
Name: name,
|
||||
Type: tenantType,
|
||||
ParentTenantID: parentID,
|
||||
ParentTenantSlug: tenantCSVValue(row, header, "parent_tenant_slug"),
|
||||
Slug: slug,
|
||||
Memo: tenantCSVValue(row, header, "memo"),
|
||||
Domains: splitTenantCSVDomains(tenantCSVValue(row, header, "email_domain")),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -430,23 +490,25 @@ func parseTenantCSVRecords(r io.Reader) ([]tenantCSVRecord, error) {
|
||||
func tenantCSVHeaderIndex(header []string) map[string]int {
|
||||
index := make(map[string]int, len(header))
|
||||
aliases := map[string]string{
|
||||
"id": "tenant_id",
|
||||
"tenantid": "tenant_id",
|
||||
"tenant_id": "tenant_id",
|
||||
"name": "name",
|
||||
"type": "type",
|
||||
"parentid": "parent_tenant_id",
|
||||
"parent_id": "parent_tenant_id",
|
||||
"parenttenantid": "parent_tenant_id",
|
||||
"parent_tenant_id": "parent_tenant_id",
|
||||
"slug": "slug",
|
||||
"memo": "memo",
|
||||
"description": "memo",
|
||||
"email-domain": "email_domain",
|
||||
"emaildomain": "email_domain",
|
||||
"email_domain": "email_domain",
|
||||
"domain": "email_domain",
|
||||
"domains": "email_domain",
|
||||
"id": "tenant_id",
|
||||
"tenantid": "tenant_id",
|
||||
"tenant_id": "tenant_id",
|
||||
"name": "name",
|
||||
"type": "type",
|
||||
"parentid": "parent_tenant_id",
|
||||
"parent_id": "parent_tenant_id",
|
||||
"parenttenantid": "parent_tenant_id",
|
||||
"parent_tenant_id": "parent_tenant_id",
|
||||
"parenttenantslug": "parent_tenant_slug",
|
||||
"parent_tenant_slug": "parent_tenant_slug",
|
||||
"slug": "slug",
|
||||
"memo": "memo",
|
||||
"description": "memo",
|
||||
"email-domain": "email_domain",
|
||||
"emaildomain": "email_domain",
|
||||
"email_domain": "email_domain",
|
||||
"domain": "email_domain",
|
||||
"domains": "email_domain",
|
||||
}
|
||||
for i, column := range header {
|
||||
key := strings.ToLower(strings.TrimSpace(column))
|
||||
@@ -475,6 +537,40 @@ func tenantCSVRowIsEmpty(row []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func includeCSVIds(c *fiber.Ctx) bool {
|
||||
value := strings.ToLower(strings.TrimSpace(c.Query("includeIds")))
|
||||
return value == "true" || value == "1" || value == "yes"
|
||||
}
|
||||
|
||||
func orderTenantCSVRecordsByParentSlug(records []tenantCSVRecord) []tenantCSVRecord {
|
||||
bySlug := make(map[string]tenantCSVRecord, len(records))
|
||||
for _, record := range records {
|
||||
bySlug[strings.ToLower(record.Slug)] = record
|
||||
}
|
||||
|
||||
ordered := make([]tenantCSVRecord, 0, len(records))
|
||||
visited := make(map[string]bool, len(records))
|
||||
var visit func(record tenantCSVRecord)
|
||||
visit = func(record tenantCSVRecord) {
|
||||
key := strings.ToLower(record.Slug)
|
||||
if visited[key] {
|
||||
return
|
||||
}
|
||||
if record.ParentTenantSlug != "" {
|
||||
if parent, ok := bySlug[strings.ToLower(record.ParentTenantSlug)]; ok {
|
||||
visit(parent)
|
||||
}
|
||||
}
|
||||
visited[key] = true
|
||||
ordered = append(ordered, record)
|
||||
}
|
||||
|
||||
for _, record := range records {
|
||||
visit(record)
|
||||
}
|
||||
return ordered
|
||||
}
|
||||
|
||||
func splitTenantCSVDomains(value string) []string {
|
||||
value = strings.ReplaceAll(value, "\n", ";")
|
||||
value = strings.ReplaceAll(value, ",", ";")
|
||||
@@ -492,12 +588,203 @@ func splitTenantCSVDomains(value string) []string {
|
||||
return domains
|
||||
}
|
||||
|
||||
func (h *TenantHandler) upsertTenantCSVRecord(c *fiber.Ctx, record tenantCSVRecord) (bool, error) {
|
||||
func normalizeTenantDomainInputs(values []string) []string {
|
||||
seen := make(map[string]bool, len(values))
|
||||
domains := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
for _, part := range strings.FieldsFunc(value, func(r rune) bool {
|
||||
return r == ',' || r == ';' || r == '\n' || r == '\r' || r == '\t' || r == ' '
|
||||
}) {
|
||||
domainName := strings.ToLower(strings.TrimSpace(part))
|
||||
if domainName == "" || seen[domainName] {
|
||||
continue
|
||||
}
|
||||
seen[domainName] = true
|
||||
domains = append(domains, domainName)
|
||||
}
|
||||
}
|
||||
return domains
|
||||
}
|
||||
|
||||
func normalizeTenantConfig(config map[string]any) (domain.JSONMap, error) {
|
||||
normalized := make(domain.JSONMap, len(config))
|
||||
for key, value := range config {
|
||||
if key == "userSchema" {
|
||||
fields, err := normalizeTenantUserSchema(value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normalized[key] = fields
|
||||
continue
|
||||
}
|
||||
normalized[key] = value
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
func normalizeTenantUserSchema(value any) ([]any, error) {
|
||||
if value == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rawFields, ok := value.([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("userSchema must be an array")
|
||||
}
|
||||
|
||||
fields := make([]any, 0, len(rawFields))
|
||||
for _, raw := range rawFields {
|
||||
field, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("userSchema fields must be objects")
|
||||
}
|
||||
|
||||
normalized := make(map[string]any, len(field))
|
||||
for key, value := range field {
|
||||
if key == "maxLength" {
|
||||
continue
|
||||
}
|
||||
normalized[key] = value
|
||||
}
|
||||
|
||||
isLoginID, _ := normalized["isLoginId"].(bool)
|
||||
if isLoginID {
|
||||
fieldType, _ := normalized["type"].(string)
|
||||
if fieldType != "" && fieldType != "text" {
|
||||
return nil, fmt.Errorf("login ID fields must be text")
|
||||
}
|
||||
normalized["type"] = "text"
|
||||
normalized["indexed"] = true
|
||||
} else if indexed, ok := normalized["indexed"].(bool); !ok || !indexed {
|
||||
normalized["indexed"] = false
|
||||
}
|
||||
|
||||
fields = append(fields, normalized)
|
||||
}
|
||||
|
||||
return fields, nil
|
||||
}
|
||||
|
||||
func normalizeTenantDomainForceSet(values []string) map[string]bool {
|
||||
domains := normalizeTenantDomainInputs(values)
|
||||
force := make(map[string]bool, len(domains))
|
||||
for _, domainName := range domains {
|
||||
force[domainName] = true
|
||||
}
|
||||
return force
|
||||
}
|
||||
|
||||
func tenantDomainConflictJSON(c *fiber.Ctx, conflicts []tenantDomainConflict) error {
|
||||
return c.Status(fiber.StatusConflict).JSON(fiber.Map{
|
||||
"code": "tenant_domain_conflict",
|
||||
"error": "domain is already assigned to another tenant",
|
||||
"conflicts": conflicts,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *TenantHandler) findTenantDomainConflicts(ctx context.Context, tenantID string, domains []string, forceDomains []string) ([]tenantDomainConflict, error) {
|
||||
if h.DB == nil || h.DB.Config == nil || len(domains) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
force := normalizeTenantDomainForceSet(forceDomains)
|
||||
var rows []domain.TenantDomain
|
||||
query := h.DB.WithContext(ctx).Where("domain IN ?", domains)
|
||||
if tenantID != "" {
|
||||
query = query.Where("tenant_id <> ?", tenantID)
|
||||
}
|
||||
if err := query.Find(&rows).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conflicts := make([]tenantDomainConflict, 0, len(rows))
|
||||
tenantIDs := make([]string, 0, len(rows))
|
||||
seenTenantIDs := make(map[string]bool, len(rows))
|
||||
for _, row := range rows {
|
||||
if force[row.Domain] {
|
||||
continue
|
||||
}
|
||||
if !seenTenantIDs[row.TenantID] {
|
||||
seenTenantIDs[row.TenantID] = true
|
||||
tenantIDs = append(tenantIDs, row.TenantID)
|
||||
}
|
||||
}
|
||||
|
||||
tenantsByID := make(map[string]domain.Tenant, len(tenantIDs))
|
||||
if len(tenantIDs) > 0 {
|
||||
var tenants []domain.Tenant
|
||||
if err := h.DB.WithContext(ctx).Where("id IN ?", tenantIDs).Find(&tenants).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, tenant := range tenants {
|
||||
tenantsByID[tenant.ID] = tenant
|
||||
}
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
if force[row.Domain] {
|
||||
continue
|
||||
}
|
||||
conflict := tenantDomainConflict{
|
||||
Domain: row.Domain,
|
||||
TenantID: row.TenantID,
|
||||
}
|
||||
if tenant, ok := tenantsByID[row.TenantID]; ok {
|
||||
conflict.TenantName = tenant.Name
|
||||
conflict.TenantSlug = tenant.Slug
|
||||
}
|
||||
conflicts = append(conflicts, conflict)
|
||||
}
|
||||
|
||||
return conflicts, nil
|
||||
}
|
||||
|
||||
func (h *TenantHandler) replaceTenantDomains(ctx context.Context, tenantID string, domains []string, forceDomains []string) error {
|
||||
if h.DB == nil {
|
||||
return errors.New("database not available")
|
||||
}
|
||||
if h.DB.Config == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
deleteQuery := h.DB.WithContext(ctx).Where("tenant_id = ?", tenantID)
|
||||
if len(domains) > 0 {
|
||||
deleteQuery = deleteQuery.Where("domain NOT IN ?", domains)
|
||||
}
|
||||
if err := deleteQuery.Delete(&domain.TenantDomain{}).Error; err != nil {
|
||||
return fmt.Errorf("failed to clear old domains: %w", err)
|
||||
}
|
||||
|
||||
for _, domainName := range domains {
|
||||
var existing domain.TenantDomain
|
||||
err := h.DB.WithContext(ctx).Unscoped().
|
||||
Where("tenant_id = ? AND domain = ?", tenantID, domainName).
|
||||
First(&existing).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
if err := repository.NewTenantRepository(h.DB).AddDomain(ctx, tenantID, domainName, true); err != nil {
|
||||
return fmt.Errorf("failed to add domain: %s", domainName)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.DB.WithContext(ctx).Unscoped().Model(&existing).Updates(map[string]any{
|
||||
"verified": true,
|
||||
"deleted_at": nil,
|
||||
}).Error; err != nil {
|
||||
return fmt.Errorf("failed to add domain: %s", domainName)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *TenantHandler) upsertTenantCSVRecord(c *fiber.Ctx, record tenantCSVRecord) (*domain.Tenant, bool, error) {
|
||||
if h.DB == nil {
|
||||
if record.TenantID != "" {
|
||||
return false, errors.New("database not available for tenant update")
|
||||
return nil, false, errors.New("database not available for tenant update")
|
||||
}
|
||||
return false, nil
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
var tenant domain.Tenant
|
||||
@@ -510,10 +797,10 @@ func (h *TenantHandler) upsertTenantCSVRecord(c *fiber.Ctx, record tenantCSVReco
|
||||
}
|
||||
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return false, nil
|
||||
return nil, false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
tenant.Name = record.Name
|
||||
@@ -526,29 +813,29 @@ func (h *TenantHandler) upsertTenantCSVRecord(c *fiber.Ctx, record tenantCSVReco
|
||||
}
|
||||
|
||||
if err := h.DB.Save(&tenant).Error; err != nil {
|
||||
return false, err
|
||||
return nil, false, err
|
||||
}
|
||||
if err := h.DB.Delete(&domain.TenantDomain{}, "tenant_id = ?", tenant.ID).Error; err != nil {
|
||||
return false, err
|
||||
return nil, false, err
|
||||
}
|
||||
repo := repository.NewTenantRepository(h.DB)
|
||||
for _, domainName := range record.Domains {
|
||||
if err := repo.AddDomain(c.Context(), tenant.ID, domainName, true); err != nil {
|
||||
return false, err
|
||||
return nil, false, err
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
return &tenant, true, nil
|
||||
}
|
||||
|
||||
func (h *TenantHandler) createTenantCSVRecord(c *fiber.Ctx, record tenantCSVRecord, creatorID string) error {
|
||||
func (h *TenantHandler) createTenantCSVRecord(c *fiber.Ctx, record tenantCSVRecord, creatorID string) (*domain.Tenant, error) {
|
||||
if h.DB != nil && record.TenantID != "" {
|
||||
var exists int64
|
||||
if err := h.DB.Unscoped().Model(&domain.Tenant{}).Where("slug = ?", record.Slug).Count(&exists).Error; err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if exists > 0 {
|
||||
return errors.New("tenant slug already exists")
|
||||
return nil, errors.New("tenant slug already exists")
|
||||
}
|
||||
|
||||
tenant := domain.Tenant{
|
||||
@@ -561,7 +848,7 @@ func (h *TenantHandler) createTenantCSVRecord(c *fiber.Ctx, record tenantCSVReco
|
||||
Status: domain.TenantStatusActive,
|
||||
}
|
||||
if err := h.DB.Create(&tenant).Error; err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if h.KetoOutbox != nil {
|
||||
_ = h.KetoOutbox.Create(c.Context(), &domain.KetoOutbox{
|
||||
@@ -595,14 +882,14 @@ func (h *TenantHandler) createTenantCSVRecord(c *fiber.Ctx, record tenantCSVReco
|
||||
repo := repository.NewTenantRepository(h.DB)
|
||||
for _, domainName := range record.Domains {
|
||||
if err := repo.AddDomain(c.Context(), tenant.ID, domainName, true); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return &tenant, nil
|
||||
}
|
||||
|
||||
_, err := h.Service.RegisterTenant(c.Context(), record.Name, record.Slug, record.Type, record.Memo, record.Domains, record.ParentTenantID, creatorID)
|
||||
return err
|
||||
tenant, err := h.Service.RegisterTenant(c.Context(), record.Name, record.Slug, record.Type, record.Memo, record.Domains, record.ParentTenantID, creatorID)
|
||||
return tenant, err
|
||||
}
|
||||
|
||||
func (h *TenantHandler) GetTenant(c *fiber.Ctx) error {
|
||||
@@ -646,14 +933,15 @@ func (h *TenantHandler) CreateTenant(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Status string `json:"status"`
|
||||
Domains []string `json:"domains"`
|
||||
ParentID *string `json:"parentId"`
|
||||
Config map[string]any `json:"config"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Status string `json:"status"`
|
||||
Domains []string `json:"domains"`
|
||||
ForceDomains []string `json:"forceDomainConflicts"`
|
||||
ParentID *string `json:"parentId"`
|
||||
Config map[string]any `json:"config"`
|
||||
}
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
|
||||
@@ -701,7 +989,16 @@ func (h *TenantHandler) CreateTenant(c *fiber.Ctx) error {
|
||||
creatorID = profile.ID
|
||||
}
|
||||
|
||||
tenant, err := h.Service.RegisterTenant(c.Context(), name, slug, tenantType, req.Description, req.Domains, parentID, creatorID)
|
||||
normalizedDomains := normalizeTenantDomainInputs(req.Domains)
|
||||
conflicts, err := h.findTenantDomainConflicts(c.Context(), "", normalizedDomains, req.ForceDomains)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
if len(conflicts) > 0 {
|
||||
return tenantDomainConflictJSON(c, conflicts)
|
||||
}
|
||||
|
||||
tenant, err := h.Service.RegisterTenant(c.Context(), name, slug, tenantType, req.Description, nil, parentID, creatorID)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "already exists") {
|
||||
return errorJSON(c, fiber.StatusConflict, err.Error())
|
||||
@@ -713,10 +1010,20 @@ func (h *TenantHandler) CreateTenant(c *fiber.Ctx) error {
|
||||
summary.MemberCount = 0
|
||||
|
||||
if req.Config != nil {
|
||||
tenant.Config = req.Config
|
||||
config, err := normalizeTenantConfig(req.Config)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, err.Error())
|
||||
}
|
||||
tenant.Config = config
|
||||
h.DB.Save(tenant)
|
||||
summary.Config = tenant.Config
|
||||
}
|
||||
if err := h.replaceTenantDomains(c.Context(), tenant.ID, normalizedDomains, req.ForceDomains); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
if len(normalizedDomains) > 0 {
|
||||
summary.Domains = normalizedDomains
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusCreated).JSON(summary)
|
||||
}
|
||||
@@ -740,14 +1047,15 @@ func (h *TenantHandler) UpdateTenant(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Name *string `json:"name"`
|
||||
Type *string `json:"type"`
|
||||
Slug *string `json:"slug"`
|
||||
Description *string `json:"description"`
|
||||
Status *string `json:"status"`
|
||||
ParentID *string `json:"parentId"`
|
||||
Domains []string `json:"domains"`
|
||||
Config map[string]any `json:"config"`
|
||||
Name *string `json:"name"`
|
||||
Type *string `json:"type"`
|
||||
Slug *string `json:"slug"`
|
||||
Description *string `json:"description"`
|
||||
Status *string `json:"status"`
|
||||
ParentID *string `json:"parentId"`
|
||||
Domains []string `json:"domains"`
|
||||
ForceDomains []string `json:"forceDomainConflicts"`
|
||||
Config map[string]any `json:"config"`
|
||||
}
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, "invalid request body")
|
||||
@@ -835,7 +1143,11 @@ func (h *TenantHandler) UpdateTenant(c *fiber.Ctx) error {
|
||||
}
|
||||
}
|
||||
if req.Config != nil {
|
||||
tenant.Config = req.Config
|
||||
config, err := normalizeTenantConfig(req.Config)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusBadRequest, err.Error())
|
||||
}
|
||||
tenant.Config = config
|
||||
}
|
||||
|
||||
if err := h.DB.Save(&tenant).Error; err != nil {
|
||||
@@ -844,18 +1156,16 @@ func (h *TenantHandler) UpdateTenant(c *fiber.Ctx) error {
|
||||
|
||||
// Update domains if provided
|
||||
if req.Domains != nil {
|
||||
// Simple approach: Delete existing and recreate
|
||||
if err := h.DB.Delete(&domain.TenantDomain{}, "tenant_id = ?", tenant.ID).Error; err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "failed to clear old domains")
|
||||
normalizedDomains := normalizeTenantDomainInputs(req.Domains)
|
||||
conflicts, err := h.findTenantDomainConflicts(c.Context(), tenant.ID, normalizedDomains, req.ForceDomains)
|
||||
if err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
for _, d := range req.Domains {
|
||||
if strings.TrimSpace(d) == "" {
|
||||
continue
|
||||
}
|
||||
// Use repository for consistency
|
||||
if err := repository.NewTenantRepository(h.DB).AddDomain(c.Context(), tenant.ID, strings.TrimSpace(d), true); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, "failed to add domain: "+d)
|
||||
}
|
||||
if len(conflicts) > 0 {
|
||||
return tenantDomainConflictJSON(c, conflicts)
|
||||
}
|
||||
if err := h.replaceTenantDomains(c.Context(), tenant.ID, normalizedDomains, req.ForceDomains); err != nil {
|
||||
return errorJSON(c, fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -189,7 +189,7 @@ func TestTenantHandler_CreateTenant(t *testing.T) {
|
||||
}
|
||||
body, _ := json.Marshal(input)
|
||||
|
||||
mockSvc.On("RegisterTenant", mock.Anything, "Test Tenant", "test-tenant", domain.TenantTypeCompany, "", []string{"test.com"}, (*string)(nil), "").
|
||||
mockSvc.On("RegisterTenant", mock.Anything, "Test Tenant", "test-tenant", domain.TenantTypeCompany, "", []string(nil), (*string)(nil), "").
|
||||
Return(&domain.Tenant{ID: "t1", Name: "Test Tenant", Slug: "test-tenant"}, nil)
|
||||
|
||||
req := httptest.NewRequest("POST", "/tenants", bytes.NewReader(body))
|
||||
@@ -278,15 +278,55 @@ func TestTenantHandler_ExportTenantsCSV(t *testing.T) {
|
||||
|
||||
mockSvc.On("ListTenants", mock.Anything, 10000, 0, "").Return(tenants, int64(1), nil)
|
||||
|
||||
req := httptest.NewRequest("GET", "/tenants/export", nil)
|
||||
req := httptest.NewRequest("GET", "/tenants/export?includeIds=true", nil)
|
||||
resp, _ := app.Test(req)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Contains(t, resp.Header.Get("Content-Disposition"), "tenants.csv")
|
||||
assert.Equal(t, "text/csv", strings.Split(resp.Header.Get("Content-Type"), ";")[0])
|
||||
assert.Contains(t, string(body), "tenant_id,name,type,parent_tenant_id,slug,memo,email_domain")
|
||||
assert.Contains(t, string(body), "t1,Tenant A,COMPANY,parent-1,tenant-a,Primary tenant,tenant-a.example.com;login.tenant-a.example.com")
|
||||
assert.Contains(t, string(body), "tenant_id,name,type,parent_tenant_id,parent_tenant_slug,slug,memo,email_domain")
|
||||
assert.Contains(t, string(body), "t1,Tenant A,COMPANY,parent-1,,tenant-a,Primary tenant,tenant-a.example.com;login.tenant-a.example.com")
|
||||
}
|
||||
|
||||
func TestTenantHandler_ExportTenantsCSV_OmitsIDsAndUsesParentSlug(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockSvc := new(MockTenantService)
|
||||
h := &TenantHandler{Service: mockSvc}
|
||||
|
||||
app.Get("/tenants/export", h.ExportTenantsCSV)
|
||||
|
||||
parentID := "parent-1"
|
||||
tenants := []domain.Tenant{
|
||||
{
|
||||
ID: parentID,
|
||||
Name: "Parent Tenant",
|
||||
Type: domain.TenantTypeCompanyGroup,
|
||||
Slug: "parent-tenant",
|
||||
},
|
||||
{
|
||||
ID: "child-1",
|
||||
Name: "Child Tenant",
|
||||
Type: domain.TenantTypeUserGroup,
|
||||
ParentID: &parentID,
|
||||
Slug: "child-tenant",
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc.On("ListTenants", mock.Anything, 10000, 0, "").Return(tenants, int64(2), nil)
|
||||
|
||||
req := httptest.NewRequest("GET", "/tenants/export?includeIds=false", nil)
|
||||
resp, _ := app.Test(req)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
text := string(body)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Contains(t, text, "name,type,parent_tenant_slug,slug,memo,email_domain")
|
||||
assert.Contains(t, text, "Child Tenant,USER_GROUP,parent-tenant,child-tenant,,")
|
||||
assert.NotContains(t, text, "tenant_id")
|
||||
assert.NotContains(t, text, "parent_tenant_id")
|
||||
assert.NotContains(t, text, "child-1")
|
||||
mockSvc.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestTenantHandler_ImportTenantsCSVCreatesTenant(t *testing.T) {
|
||||
@@ -304,6 +344,7 @@ func TestTenantHandler_ImportTenantsCSVCreatesTenant(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, writer.Close())
|
||||
|
||||
mockSvc.On("ListTenants", mock.Anything, 10000, 0, "").Return([]domain.Tenant{}, int64(0), nil).Once()
|
||||
mockSvc.On(
|
||||
"RegisterTenant",
|
||||
mock.Anything,
|
||||
@@ -331,6 +372,127 @@ func TestTenantHandler_ImportTenantsCSVCreatesTenant(t *testing.T) {
|
||||
mockSvc.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestTenantHandler_ImportTenantsCSVResolvesParentSlugToID(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockSvc := new(MockTenantService)
|
||||
h := &TenantHandler{Service: mockSvc}
|
||||
|
||||
app.Post("/tenants/import", h.ImportTenantsCSV)
|
||||
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
part, err := writer.CreateFormFile("file", "tenants.csv")
|
||||
assert.NoError(t, err)
|
||||
_, err = part.Write([]byte("name,type,parent_tenant_slug,slug,memo,email_domain\nParent Tenant,COMPANY,,parent-slug,,\nChild Tenant,USER_GROUP,parent-slug,child-slug,,\n"))
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, writer.Close())
|
||||
|
||||
parentID := "parent-id"
|
||||
mockSvc.On("ListTenants", mock.Anything, 10000, 0, "").Return([]domain.Tenant{}, int64(0), nil).Once()
|
||||
mockSvc.On(
|
||||
"RegisterTenant",
|
||||
mock.Anything,
|
||||
"Parent Tenant",
|
||||
"parent-slug",
|
||||
domain.TenantTypeCompany,
|
||||
"",
|
||||
[]string{},
|
||||
(*string)(nil),
|
||||
"",
|
||||
).Return(&domain.Tenant{ID: parentID, Name: "Parent Tenant", Slug: "parent-slug"}, nil).Once()
|
||||
mockSvc.On(
|
||||
"RegisterTenant",
|
||||
mock.Anything,
|
||||
"Child Tenant",
|
||||
"child-slug",
|
||||
domain.TenantTypeUserGroup,
|
||||
"",
|
||||
[]string{},
|
||||
mock.MatchedBy(func(got *string) bool {
|
||||
return got != nil && *got == parentID
|
||||
}),
|
||||
"",
|
||||
).Return(&domain.Tenant{ID: "child-id", Name: "Child Tenant", Slug: "child-slug"}, nil).Once()
|
||||
|
||||
req := httptest.NewRequest("POST", "/tenants/import", &body)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
resp, _ := app.Test(req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
var got map[string]interface{}
|
||||
json.NewDecoder(resp.Body).Decode(&got)
|
||||
assert.Equal(t, float64(2), got["created"])
|
||||
assert.Equal(t, float64(0), got["failed"])
|
||||
mockSvc.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestTenantCSVAllowedDomainsRoundTrip(t *testing.T) {
|
||||
records, err := parseTenantCSVRecords(strings.NewReader(
|
||||
"name,type,parent_tenant_slug,slug,memo,email_domain\n" +
|
||||
"Hanmac,COMPANY,,hanmac,,\"samaneng.com, hanmaceng.co.kr;login.hmac.kr\"\n",
|
||||
))
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, records, 1)
|
||||
assert.Equal(t, []string{"samaneng.com", "hanmaceng.co.kr", "login.hmac.kr"}, records[0].Domains)
|
||||
}
|
||||
|
||||
func TestNormalizeTenantDomainInputsSplitsCommaAndWhitespace(t *testing.T) {
|
||||
got := normalizeTenantDomainInputs([]string{
|
||||
"samaneng.com, hanmaceng.co.kr",
|
||||
" LOGIN.HMAC.KR\nportal.hmac.kr ",
|
||||
"samaneng.com",
|
||||
})
|
||||
|
||||
assert.Equal(t, []string{
|
||||
"samaneng.com",
|
||||
"hanmaceng.co.kr",
|
||||
"login.hmac.kr",
|
||||
"portal.hmac.kr",
|
||||
}, got)
|
||||
}
|
||||
|
||||
func TestNormalizeTenantConfigForcesIndexedForLoginIDFields(t *testing.T) {
|
||||
config, err := normalizeTenantConfig(map[string]any{
|
||||
"userSchema": []any{
|
||||
map[string]any{
|
||||
"key": "emp_no",
|
||||
"label": "사번",
|
||||
"type": "text",
|
||||
"indexed": false,
|
||||
"isLoginId": true,
|
||||
"maxLength": 20,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
fields, ok := config["userSchema"].([]any)
|
||||
assert.True(t, ok)
|
||||
assert.Len(t, fields, 1)
|
||||
|
||||
field, ok := fields[0].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, true, field["indexed"])
|
||||
assert.Equal(t, true, field["isLoginId"])
|
||||
assert.NotContains(t, field, "maxLength")
|
||||
}
|
||||
|
||||
func TestNormalizeTenantConfigRejectsNonTextLoginIDFields(t *testing.T) {
|
||||
_, err := normalizeTenantConfig(map[string]any{
|
||||
"userSchema": []any{
|
||||
map[string]any{
|
||||
"key": "emp_no",
|
||||
"type": "number",
|
||||
"isLoginId": true,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "login ID fields must be text")
|
||||
}
|
||||
|
||||
func TestTenantHandler_ApproveTenant(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockSvc := new(MockTenantService)
|
||||
|
||||
@@ -425,6 +425,15 @@ func (h *UserHandler) CreateUser(c *fiber.Ctx) error {
|
||||
attributes["tenant_id"] = tenantID
|
||||
}
|
||||
|
||||
if h.UserRepo != nil {
|
||||
if err := h.ensureHanmacCreateEmailAllowed(c.Context(), email, req.CompanyCode, tenantID); err != nil {
|
||||
if strings.Contains(err.Error(), "한맥가족") {
|
||||
return errorJSON(c, fiber.StatusConflict, err.Error())
|
||||
}
|
||||
return errorJSON(c, fiber.StatusBadRequest, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Merge custom metadata into attributes
|
||||
for k, v := range req.Metadata {
|
||||
// Don't overwrite core fields
|
||||
@@ -534,10 +543,14 @@ type bulkUserItem struct {
|
||||
}
|
||||
|
||||
type bulkUserResult struct {
|
||||
Email string `json:"email"`
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message,omitempty"`
|
||||
UserID string `json:"userId,omitempty"`
|
||||
Email string `json:"email"`
|
||||
OriginalEmail string `json:"originalEmail,omitempty"`
|
||||
SuggestedEmail string `json:"suggestedEmail,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
Warnings []string `json:"warnings,omitempty"`
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message,omitempty"`
|
||||
UserID string `json:"userId,omitempty"`
|
||||
}
|
||||
|
||||
func (h *UserHandler) BulkCreateUsers(c *fiber.Ctx) error {
|
||||
@@ -565,6 +578,9 @@ func (h *UserHandler) BulkCreateUsers(c *fiber.Ctx) error {
|
||||
|
||||
requester, _ := c.Locals("user_profile").(*domain.UserProfileResponse)
|
||||
results := make([]bulkUserResult, 0, len(req.Users))
|
||||
var hanmacScope *hanmacEmailScope
|
||||
var hanmacLocalParts map[string]bool
|
||||
hanmacScopeLoaded := false
|
||||
|
||||
// Pre-fetch tenant data to avoid redundant DB calls
|
||||
type tenantCacheItem struct {
|
||||
@@ -638,6 +654,53 @@ func (h *UserHandler) BulkCreateUsers(c *fiber.Ctx) error {
|
||||
}
|
||||
}
|
||||
|
||||
if h.UserRepo != nil && !hanmacScopeLoaded {
|
||||
hanmacScopeLoaded = true
|
||||
var err error
|
||||
hanmacScope, err = h.resolveHanmacEmailScope(c.Context())
|
||||
if err != nil {
|
||||
results = append(results, bulkUserResult{Email: email, Success: false, Message: "failed to resolve Hanmac family tenant scope"})
|
||||
continue
|
||||
}
|
||||
if hanmacScope != nil {
|
||||
hanmacLocalParts, err = h.loadHanmacLocalParts(c.Context(), hanmacScope)
|
||||
if err != nil {
|
||||
results = append(results, bulkUserResult{Email: email, Success: false, Message: "failed to validate Hanmac family email policy"})
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
userEmail := email
|
||||
var emailEvaluation hanmacEmailEvaluation
|
||||
if h.UserRepo != nil && hanmacScope != nil && hanmacScope.ContainsTenant(tItem.ID, tenantSlug) {
|
||||
emailEvaluation = h.evaluateHanmacImportEmail(c.Context(), item, hanmacScope, hanmacLocalParts)
|
||||
if emailEvaluation.Blocking {
|
||||
results = append(results, bulkUserResult{
|
||||
Email: emailEvaluation.Email,
|
||||
OriginalEmail: emailEvaluation.OriginalEmail,
|
||||
Status: emailEvaluation.Status,
|
||||
Warnings: emailEvaluation.Warnings,
|
||||
Success: false,
|
||||
Message: emailEvaluation.Message,
|
||||
})
|
||||
continue
|
||||
}
|
||||
userEmail = emailEvaluation.Email
|
||||
if emailEvaluation.LocalPart != "" {
|
||||
hanmacLocalParts[emailEvaluation.LocalPart] = true
|
||||
}
|
||||
} else {
|
||||
if _, _, err := domain.SplitEmailDomain(email); err != nil {
|
||||
results = append(results, bulkUserResult{Email: email, Success: false, Status: "blockingError", Message: "invalid email format"})
|
||||
continue
|
||||
}
|
||||
if localPart, err := domain.ExtractNormalizedEmailLocalPart(email); err != nil || localPart == "" {
|
||||
results = append(results, bulkUserResult{Email: email, Success: false, Status: "blockingError", Message: "invalid email format"})
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
password, _ := utils.GeneratePasswordWithPolicy(policy)
|
||||
role := item.Role
|
||||
if role == "" {
|
||||
@@ -665,7 +728,6 @@ func (h *UserHandler) BulkCreateUsers(c *fiber.Ctx) error {
|
||||
}
|
||||
}
|
||||
|
||||
userEmail := email
|
||||
userPhone := normalizePhoneNumber(item.Phone)
|
||||
|
||||
// Validate all collected LoginIDs
|
||||
@@ -673,7 +735,7 @@ func (h *UserHandler) BulkCreateUsers(c *fiber.Ctx) error {
|
||||
valid := true
|
||||
for _, lid := range collectedIDs {
|
||||
if err := domain.ValidateLoginID(lid, userEmail, userPhone); err != nil {
|
||||
results = append(results, bulkUserResult{Email: email, Success: false, Message: "Invalid LoginID (" + lid + "): " + err.Error()})
|
||||
results = append(results, bulkUserResult{Email: userEmail, OriginalEmail: emailEvaluation.OriginalEmail, SuggestedEmail: emailEvaluation.SuggestedEmail, Status: emailEvaluation.Status, Warnings: emailEvaluation.Warnings, Success: false, Message: "Invalid LoginID (" + lid + "): " + err.Error()})
|
||||
valid = false
|
||||
break
|
||||
}
|
||||
@@ -692,14 +754,14 @@ func (h *UserHandler) BulkCreateUsers(c *fiber.Ctx) error {
|
||||
if err != nil {
|
||||
// 만약 이미 존재하는 사용자라면 로컬 DB 및 Keto 관계만 업데이트(Sync)를 시도
|
||||
if strings.Contains(err.Error(), "409") || strings.Contains(err.Error(), "already exists") || strings.Contains(err.Error(), "exists already") {
|
||||
identityID, err = h.KratosAdmin.FindIdentityIDByIdentifier(c.Context(), email)
|
||||
identityID, err = h.KratosAdmin.FindIdentityIDByIdentifier(c.Context(), userEmail)
|
||||
if err != nil || identityID == "" {
|
||||
results = append(results, bulkUserResult{Email: email, Success: false, Message: "이미 다른 사용자가 해당 식별자(이메일/사번 등)를 사용 중입니다."})
|
||||
results = append(results, bulkUserResult{Email: userEmail, OriginalEmail: emailEvaluation.OriginalEmail, SuggestedEmail: emailEvaluation.SuggestedEmail, Status: "blockingError", Warnings: emailEvaluation.Warnings, Success: false, Message: "이미 다른 사용자가 해당 식별자(이메일/사번 등)를 사용 중입니다."})
|
||||
continue
|
||||
}
|
||||
slog.Info("BulkCreate: User already exists, syncing local DB and Keto", "email", email, "identityID", identityID)
|
||||
} else {
|
||||
results = append(results, bulkUserResult{Email: email, Success: false, Message: err.Error()})
|
||||
results = append(results, bulkUserResult{Email: userEmail, OriginalEmail: emailEvaluation.OriginalEmail, SuggestedEmail: emailEvaluation.SuggestedEmail, Status: emailEvaluation.Status, Warnings: emailEvaluation.Warnings, Success: false, Message: err.Error()})
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -709,7 +771,7 @@ func (h *UserHandler) BulkCreateUsers(c *fiber.Ctx) error {
|
||||
if h.UserRepo != nil {
|
||||
localUser := &domain.User{
|
||||
ID: identityID,
|
||||
Email: email,
|
||||
Email: userEmail,
|
||||
Name: name,
|
||||
Phone: normalizePhoneNumber(item.Phone),
|
||||
Role: role,
|
||||
@@ -776,7 +838,15 @@ func (h *UserHandler) BulkCreateUsers(c *fiber.Ctx) error {
|
||||
}
|
||||
}
|
||||
|
||||
results = append(results, bulkUserResult{Email: email, Success: true, UserID: identityID})
|
||||
results = append(results, bulkUserResult{
|
||||
Email: userEmail,
|
||||
OriginalEmail: emailEvaluation.OriginalEmail,
|
||||
SuggestedEmail: emailEvaluation.SuggestedEmail,
|
||||
Status: emailEvaluation.Status,
|
||||
Warnings: emailEvaluation.Warnings,
|
||||
Success: true,
|
||||
UserID: identityID,
|
||||
})
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(fiber.Map{
|
||||
@@ -870,12 +940,19 @@ func (h *UserHandler) ExportUsersCSV(c *fiber.Ctx) error {
|
||||
defer writer.Flush()
|
||||
|
||||
// Header row
|
||||
header := []string{"ID", "Email", "Name", "Phone", "Status", "Tenant", "Position", "JobTitle", "CreatedAt"}
|
||||
includeIDs := includeCSVIds(c)
|
||||
header := []string{"Email", "Name", "Phone", "Status", "tenant_slug", "Position", "JobTitle", "CreatedAt"}
|
||||
if includeIDs {
|
||||
header = []string{"user_id", "Email", "Name", "Phone", "Status", "tenant_id", "tenant_slug", "Position", "JobTitle", "CreatedAt"}
|
||||
}
|
||||
|
||||
// Collect all possible metadata keys for dynamic columns
|
||||
metaKeysMap := make(map[string]bool)
|
||||
for _, u := range filtered {
|
||||
for k := range u.Metadata {
|
||||
if !includeIDs && csvMetadataKeyIsID(k) {
|
||||
continue
|
||||
}
|
||||
metaKeysMap[k] = true
|
||||
}
|
||||
}
|
||||
@@ -891,8 +968,11 @@ func (h *UserHandler) ExportUsersCSV(c *fiber.Ctx) error {
|
||||
|
||||
// Data rows
|
||||
for _, u := range filtered {
|
||||
tenantID := ""
|
||||
if u.TenantID != nil {
|
||||
tenantID = *u.TenantID
|
||||
}
|
||||
row := []string{
|
||||
u.ID,
|
||||
u.Email,
|
||||
u.Name,
|
||||
u.Phone,
|
||||
@@ -902,6 +982,20 @@ func (h *UserHandler) ExportUsersCSV(c *fiber.Ctx) error {
|
||||
u.JobTitle,
|
||||
u.CreatedAt.Format(time.RFC3339),
|
||||
}
|
||||
if includeIDs {
|
||||
row = []string{
|
||||
u.ID,
|
||||
u.Email,
|
||||
u.Name,
|
||||
u.Phone,
|
||||
u.Status,
|
||||
tenantID,
|
||||
u.CompanyCode,
|
||||
u.Position,
|
||||
u.JobTitle,
|
||||
u.CreatedAt.Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
// Append metadata values in order
|
||||
for _, k := range metaKeys {
|
||||
val := ""
|
||||
@@ -918,6 +1012,11 @@ func (h *UserHandler) ExportUsersCSV(c *fiber.Ctx) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func csvMetadataKeyIsID(key string) bool {
|
||||
normalized := strings.ToLower(strings.TrimSpace(key))
|
||||
return normalized == "id" || normalized == "user_id" || normalized == "tenant_id" || normalized == "tenantid"
|
||||
}
|
||||
|
||||
func (h *UserHandler) BulkUpdateUsers(c *fiber.Ctx) error {
|
||||
var req struct {
|
||||
UserIDs []string `json:"userIds"`
|
||||
|
||||
@@ -126,6 +126,14 @@ func (m *MockTenantServiceForUser) ListManageableTenants(ctx context.Context, us
|
||||
return args.Get(0).([]domain.Tenant), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForUser) ListTenants(ctx context.Context, limit, offset int, parentID string) ([]domain.Tenant, int64, error) {
|
||||
args := m.Called(ctx, limit, offset, parentID)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Get(1).(int64), args.Error(2)
|
||||
}
|
||||
return args.Get(0).([]domain.Tenant), args.Get(1).(int64), args.Error(2)
|
||||
}
|
||||
|
||||
func (m *MockTenantServiceForUser) ProvisionTenantByDomain(ctx context.Context, domainName string) (*domain.Tenant, error) {
|
||||
args := m.Called(ctx, domainName)
|
||||
if args.Get(0) == nil {
|
||||
@@ -167,20 +175,66 @@ func TestUserHandler_ExportUsersCSV_UsesTenantSlugAliasAndOmitsRole(t *testing.T
|
||||
},
|
||||
}, int64(1), nil).Once()
|
||||
|
||||
req := httptest.NewRequest("GET", "/users/export?tenantSlug=test-tenant", nil)
|
||||
req := httptest.NewRequest("GET", "/users/export?tenantSlug=test-tenant&includeIds=true", nil)
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
body := strings.TrimPrefix(string(bodyBytes), "\ufeff")
|
||||
assert.Contains(t, body, "ID,Email,Name,Phone,Status,Tenant,Position,JobTitle,CreatedAt")
|
||||
assert.Contains(t, body, "u-1,user@test.com,Test User,010-1111-2222,active,test-tenant")
|
||||
assert.Contains(t, body, "user_id,Email,Name,Phone,Status,tenant_id,tenant_slug,Position,JobTitle,CreatedAt")
|
||||
assert.Contains(t, body, "u-1,user@test.com,Test User,010-1111-2222,active,,test-tenant")
|
||||
assert.NotContains(t, body, "Role")
|
||||
assert.NotContains(t, body, "Department")
|
||||
mockRepo.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestUserHandler_ExportUsersCSV_OmitsIDsAndUsesTenantSlug(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockRepo := new(MockUserRepoForHandler)
|
||||
h := &UserHandler{UserRepo: mockRepo}
|
||||
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals("user_profile", &domain.UserProfileResponse{
|
||||
Role: domain.RoleSuperAdmin,
|
||||
})
|
||||
return c.Next()
|
||||
})
|
||||
app.Get("/users/export", h.ExportUsersCSV)
|
||||
|
||||
createdAt := time.Date(2026, 4, 29, 12, 0, 0, 0, time.UTC)
|
||||
tenantID := "tenant-uuid"
|
||||
mockRepo.On("List", mock.Anything, 0, 10000, "", "").
|
||||
Return([]domain.User{
|
||||
{
|
||||
ID: "user-uuid",
|
||||
Email: "user@test.com",
|
||||
Name: "Test User",
|
||||
Phone: "010-1111-2222",
|
||||
Status: "active",
|
||||
CompanyCode: "test-tenant",
|
||||
TenantID: &tenantID,
|
||||
Position: "책임",
|
||||
JobTitle: "플랫폼 운영",
|
||||
CreatedAt: createdAt,
|
||||
},
|
||||
}, int64(1), nil).Once()
|
||||
|
||||
req := httptest.NewRequest("GET", "/users/export?includeIds=false", nil)
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
body := strings.TrimPrefix(string(bodyBytes), "\ufeff")
|
||||
assert.Contains(t, body, "Email,Name,Phone,Status,tenant_slug,Position,JobTitle,CreatedAt")
|
||||
assert.Contains(t, body, "user@test.com,Test User,010-1111-2222,active,test-tenant")
|
||||
assert.NotContains(t, body, "user-uuid")
|
||||
assert.NotContains(t, body, "tenant-uuid")
|
||||
assert.NotContains(t, body, "ID,")
|
||||
mockRepo.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestUserHandler_BulkCreateUsers(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockKratos := new(MockKratosAdmin)
|
||||
@@ -355,6 +409,170 @@ func TestUserHandler_BulkCreateUsers(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserHandler_BulkCreateUsers_HanmacEmailPolicy(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockKratos := new(MockKratosAdmin)
|
||||
mockOry := new(MockOryProvider)
|
||||
mockTenant := new(MockTenantServiceForUser)
|
||||
mockRepo := new(MockUserRepoForHandler)
|
||||
|
||||
h := &UserHandler{
|
||||
KratosAdmin: mockKratos,
|
||||
OryProvider: mockOry,
|
||||
TenantService: mockTenant,
|
||||
UserRepo: mockRepo,
|
||||
}
|
||||
|
||||
app.Post("/users/bulk", h.BulkCreateUsers)
|
||||
|
||||
rootID := "hanmac-family-id"
|
||||
companyID := "hanmac-id"
|
||||
tenants := []domain.Tenant{
|
||||
{ID: rootID, Slug: "hanmac-family", Name: "한맥가족"},
|
||||
{ID: companyID, Slug: "hanmac", Name: "한맥기술", ParentID: &rootID},
|
||||
{ID: "external-id", Slug: "external", Name: "외부사"},
|
||||
}
|
||||
|
||||
t.Run("domain only email receives suggested final email with next suffix", func(t *testing.T) {
|
||||
mockTenant.On("GetTenantBySlug", mock.Anything, "hanmac").Return(&domain.Tenant{
|
||||
ID: companyID,
|
||||
Slug: "hanmac",
|
||||
}, nil).Once()
|
||||
mockTenant.On("GetTenant", mock.Anything, companyID).Return(&domain.Tenant{
|
||||
ID: companyID,
|
||||
Slug: "hanmac",
|
||||
}, nil).Maybe()
|
||||
mockTenant.On("ListTenants", mock.Anything, 10000, 0, "").Return(tenants, int64(len(tenants)), nil).Once()
|
||||
mockRepo.On("FindByTenantIDs", mock.Anything, []string{rootID, companyID}).Return([]domain.User{
|
||||
{Email: "cyhan@hanmaceng.co.kr", CompanyCode: "hanmac", TenantID: &companyID},
|
||||
{Email: "cyhan1@samaneng.com", CompanyCode: "hanmac", TenantID: &companyID},
|
||||
}, nil).Once()
|
||||
mockRepo.On("FindByCompanyCodes", mock.Anything, []string{"hanmac-family", "hanmac"}).Return([]domain.User{}, nil).Once()
|
||||
mockOry.On("GetPasswordPolicy").Return(&domain.PasswordPolicy{MinLength: 8}, nil).Once()
|
||||
mockOry.On("CreateUser", mock.MatchedBy(func(user *domain.BrokerUser) bool {
|
||||
return user.Email == "cyhan2@hanmaceng.co.kr"
|
||||
}), mock.Anything).Return("u-hanmac", nil).Once()
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"users": []map[string]interface{}{
|
||||
{
|
||||
"email": "@hanmaceng.co.kr",
|
||||
"name": "한치영",
|
||||
"tenantSlug": "hanmac",
|
||||
},
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
req := httptest.NewRequest("POST", "/users/bulk", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, _ := app.Test(req)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var result map[string]interface{}
|
||||
json.NewDecoder(resp.Body).Decode(&result)
|
||||
results := result["results"].([]interface{})
|
||||
row := results[0].(map[string]interface{})
|
||||
assert.True(t, row["success"].(bool))
|
||||
assert.Equal(t, "cyhan2@hanmaceng.co.kr", row["email"])
|
||||
assert.Equal(t, "@hanmaceng.co.kr", row["originalEmail"])
|
||||
assert.Contains(t, row["warnings"].([]interface{}), "suggested")
|
||||
})
|
||||
|
||||
t.Run("full email duplicate local part is blocking error", func(t *testing.T) {
|
||||
mockTenant.On("GetTenantBySlug", mock.Anything, "hanmac").Return(&domain.Tenant{
|
||||
ID: companyID,
|
||||
Slug: "hanmac",
|
||||
}, nil).Once()
|
||||
mockTenant.On("GetTenant", mock.Anything, companyID).Return(&domain.Tenant{
|
||||
ID: companyID,
|
||||
Slug: "hanmac",
|
||||
}, nil).Maybe()
|
||||
mockTenant.On("ListTenants", mock.Anything, 10000, 0, "").Return(tenants, int64(len(tenants)), nil).Once()
|
||||
mockRepo.On("FindByTenantIDs", mock.Anything, []string{rootID, companyID}).Return([]domain.User{
|
||||
{Email: "han@hanmaceng.co.kr", CompanyCode: "hanmac", TenantID: &companyID},
|
||||
}, nil).Once()
|
||||
mockRepo.On("FindByCompanyCodes", mock.Anything, []string{"hanmac-family", "hanmac"}).Return([]domain.User{}, nil).Once()
|
||||
mockOry.On("GetPasswordPolicy").Return(&domain.PasswordPolicy{MinLength: 8}, nil).Once()
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"users": []map[string]interface{}{
|
||||
{
|
||||
"email": "han@samaneng.com",
|
||||
"name": "한치영",
|
||||
"tenantSlug": "hanmac",
|
||||
},
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
req := httptest.NewRequest("POST", "/users/bulk", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, _ := app.Test(req)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var result map[string]interface{}
|
||||
json.NewDecoder(resp.Body).Decode(&result)
|
||||
results := result["results"].([]interface{})
|
||||
row := results[0].(map[string]interface{})
|
||||
assert.False(t, row["success"].(bool))
|
||||
assert.Equal(t, "blockingError", row["status"])
|
||||
assert.Contains(t, row["message"].(string), "한맥가족 내에서 이미 사용 중인 이메일 ID입니다.")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserHandler_CreateUser_HanmacEmailPolicyBlocksDuplicateLocalPart(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockKratos := new(MockKratosAdmin)
|
||||
mockOry := new(MockOryProvider)
|
||||
mockTenant := new(MockTenantServiceForUser)
|
||||
mockRepo := new(MockUserRepoForHandler)
|
||||
|
||||
h := &UserHandler{
|
||||
KratosAdmin: mockKratos,
|
||||
OryProvider: mockOry,
|
||||
TenantService: mockTenant,
|
||||
UserRepo: mockRepo,
|
||||
}
|
||||
|
||||
app.Post("/users", h.CreateUser)
|
||||
|
||||
rootID := "hanmac-family-id"
|
||||
companyID := "hanmac-id"
|
||||
tenants := []domain.Tenant{
|
||||
{ID: rootID, Slug: "hanmac-family", Name: "한맥가족"},
|
||||
{ID: companyID, Slug: "hanmac", Name: "한맥기술", ParentID: &rootID},
|
||||
}
|
||||
|
||||
mockOry.On("GetPasswordPolicy").Return(&domain.PasswordPolicy{MinLength: 8}, nil).Once()
|
||||
mockTenant.On("GetTenantBySlug", mock.Anything, "hanmac").Return(&domain.Tenant{
|
||||
ID: companyID,
|
||||
Slug: "hanmac",
|
||||
}, nil).Once()
|
||||
mockTenant.On("ListTenants", mock.Anything, 10000, 0, "").Return(tenants, int64(len(tenants)), nil).Once()
|
||||
mockRepo.On("FindByTenantIDs", mock.Anything, []string{rootID, companyID}).Return([]domain.User{
|
||||
{Email: "han@hanmaceng.co.kr", CompanyCode: "hanmac", TenantID: &companyID},
|
||||
}, nil).Once()
|
||||
mockRepo.On("FindByCompanyCodes", mock.Anything, []string{"hanmac-family", "hanmac"}).Return([]domain.User{}, nil).Once()
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"email": "han@samaneng.com",
|
||||
"name": "한치영",
|
||||
"companyCode": "hanmac",
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
req := httptest.NewRequest("POST", "/users", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, _ := app.Test(req)
|
||||
assert.Equal(t, http.StatusConflict, resp.StatusCode)
|
||||
|
||||
var result map[string]interface{}
|
||||
json.NewDecoder(resp.Body).Decode(&result)
|
||||
assert.Contains(t, result["error"].(string), "한맥가족 내에서 이미 사용 중인 이메일 ID입니다.")
|
||||
mockOry.AssertNotCalled(t, "CreateUser")
|
||||
}
|
||||
|
||||
func TestUserHandler_BulkUpdateUsers(t *testing.T) {
|
||||
app := fiber.New()
|
||||
mockKratos := new(MockKratosAdmin)
|
||||
|
||||
Reference in New Issue
Block a user